"""
Methods for creating and querying a nearest neighbors model.
"""

import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.deps import pandas as _pandas, HAS_PANDAS as _HAS_PANDAS


def create(dataset, label, features=None, distance='auto', method='auto',
           verbose=True, **kwargs):
    """
    Create a nearest neighbor model, which can be searched efficiently and
    quickly for the nearest neighbors of a query observation. If the `method`
    argument is specified as `auto`, the type of model is chosen automatically
    based on the type of data in `dataset`.

    Parameters
    ----------
    dataset : SFrame or pandas.DataFrame
        Reference data. It must have a column with labels for each row. A
        pandas.DataFrame is converted to an SFrame before use.

    label : string
        Name of the column with row labels.

    features : list[string], optional
        Name of the columns with features to use in computing distances between
        observations and the query points. 'None' (the default) indicates that
        all columns except the label should be used as features. Each column can
        be one of the following types:

        - *Numeric*: values of numeric type integer or float.

        - *Array*: list of numeric (integer or float) values. Each list element
          is treated as a separate variable in the model.

        - *Dictionary*: key-value pairs with numeric (integer or float) values.
          Each key indicates a separate variable in the model.

        Columns of type *list* are not supported. Convert them to array columns
        if all entries in the list are of numeric types.

    distance : {'auto', 'euclidean', 'manhattan', 'jaccard', 'cosine'}, optional
        Function that measures the distances between two observations. Please
        see the notes and references for detailed descriptions of the distances.
        Note that for sparse vectors, euclidean, manhattan, and cosine distance
        assume missing keys have value 0.0.

        - *auto* (default): the model chooses a reasonable distance based on the
          data types in 'dataset'.

        - *euclidean*

        - *manhattan*

        - *jaccard*: works only with variables in a dictionary feature, where
          the keys are treated as a set and the values are ignored. See the
          notes for the definition.

        - *cosine*: works only with the 'brute-force' method because it is not a
          true metric. Please see `Wikipedia
          <http://en.wikipedia.org/wiki/Cosine_similarity>`_ for more detail.

    method : {'auto', 'ball-tree', 'brute-force'}, optional
        Method for computing nearest neighbors. The options are:

        - *auto* (default): the method is chosen automatically, based on the
          type of data and the distance. If the distance is 'auto', 'manhattan'
          or 'euclidean', none of the feature columns is a dictionary, and the
          first row contains at most 100 variables, then the 'ball-tree' method
          is used. Otherwise, the 'brute-force' method is used.

        - *ball-tree*: use a tree structure to find the k-closest neighbors to
          each query point. The ball tree model is slower to construct than the
          brute force model, but queries are faster than linear time. This
          method is not implemented for jaccard distance and is not applicable
          for cosine distance. See `Liu, et al (2004)
          <http://papers.nips.cc/paper/2666-an-investigation-of-practical-approximate-nearest-neighbor-algorithms>`_
          for implementation details.

        - *brute-force*: compute the distance from a query point to all
          reference observations. There is no computation time for model
          creation with the brute force method (although the reference data is
          held in the model), but each query takes linear time.

    verbose : bool, optional
        If True, print progress updates and model details. The flag is also
        forwarded to the toolkit call that builds the model.

    **kwargs : optional
        Options for the distance function and query method. Option names are
        converted to lower case before being passed on.

        - *leaf_size*: for the ball tree method, the number of points in each
          leaf of the tree. The default is to use the max of 1,000 and n/(2^11),
          which ensures a maximum tree depth of 12. The default leaf size is
          indicated by a "0" in the
          :func:`~graphlab.nearest_neighbors.NearestNeighborsModel.get_default_options`
          method.

    Returns
    -------
    out : NearestNeighborsModel
        A structure for efficiently computing the nearest neighbors in 'dataset'
        of new query points.

    Raises
    ------
    TypeError
        If 'dataset' is not an SFrame or pandas.DataFrame, if 'features' is not
        an iterable of strings, if 'label' is not a string, or if the
        'ball-tree' method is combined with 'cosine' or 'jaccard' distance.

    ValueError
        If 'dataset' has no rows or no columns, if 'label' is not a column of
        'dataset', if there are no feature columns, or if 'method' is not one
        of 'auto', 'ball-tree', or 'brute-force'.

    See Also
    --------
    NearestNeighborsModel.query

    Notes
    -----
    - If the features should be weighted equally in the distance calculations
      but are measured on different scales, it is important to standardize the
      features. One way to do this is to subtract the mean of each column and
      divide by the standard deviation.

    - Distance definitions. Suppose :math:`u` and :math:`v` are observations
      with :math:`d` variables each.

        - `euclidean`

            .. math:: D(u, v) = \\sqrt{\\sum_i^d (u_i - v_i)^2}

        - `manhattan`

            .. math:: D(u, v) = \\sum_i^d |u_i - v_i|

        - `jaccard`, where :math:`S` and :math:`T` are the sets of keys from
          two observations' dictionaries

            .. math:: D(S, T) = 1 - \\frac{|S \\cap T|}{|S \\cup T|}

        - `cosine`

            .. math::

                D(u, v) = 1 - \\frac{\\sum_i^d u_i v_i}
                {\\sqrt{\\sum_i^d u_i^2}\\sqrt{\\sum_i^d v_i^2}}

    References
    ----------
    - `Wikipedia - nearest neighbor
      search <http://en.wikipedia.org/wiki/Nearest_neighbor_search>`_

    - `Wikipedia - ball tree <http://en.wikipedia.org/wiki/Ball_tree>`_

    - Ball tree implementation: Liu, T., et al. (2004) `An Investigation of
      Practical Approximate Nearest Neighbor Algorithms
      <http://papers.nips.cc/paper/2666-an-investigation-of-practical-approximate-nearest-neighbor-algorithms>`_.
      Advances in Neural Information Processing Systems pp. 825-832.

    - `Wikipedia - Jaccard distance
      <http://en.wikipedia.org/wiki/Jaccard_index>`_

    - `Wikipedia - Cosine distance
      <http://en.wikipedia.org/wiki/Cosine_similarity>`_

    Examples
    --------
    Construct a nearest neighbors model with automatically determined method and
    distance:

    >>> sf = graphlab.SFrame({'label': [str(x) for x in range(3)],
    ...                       'feature1': [0.98, 0.62, 0.11],
    ...                       'feature2': [0.69, 0.58, 0.36]})
    >>> model = graphlab.nearest_neighbors.create(sf, 'label')

    For datasets with a large number of rows and up to about 100 variables, the
    ball tree method often leads to much faster queries.

    >>> model = graphlab.nearest_neighbors.create(sf, 'label',
    ...                                           distance='manhattan',
    ...                                           method='ball-tree')
    """

    ## Validate input
    if not (isinstance(dataset, _SFrame) or
            (_HAS_PANDAS and isinstance(dataset, _pandas.DataFrame))):
        raise TypeError("Input 'dataset' must be an SFrame or pandas.DataFrame")

    # Anything that got past the check above but is not an SFrame is a pandas
    # DataFrame; convert it so the rest of the function sees a single type.
    if not isinstance(dataset, _SFrame):
        dataset = _SFrame(dataset)

    if dataset.num_rows() == 0 or dataset.num_cols() == 0:
        raise ValueError("Input 'dataset' has no data.")

    if features is not None:
        if not hasattr(features, '__iter__'):
            raise TypeError("Input 'features' must be an iterable type.")

        # Materialise once so that a one-shot iterable is not exhausted by the
        # validation below before the columns are selected.
        features = list(features)

        if not all(isinstance(x, str) for x in features):
            raise TypeError("Input 'features' must contain only strings.")

    if not isinstance(label, str):
        raise TypeError("Input 'label' must be a string type.")

    if label not in dataset.column_names():
        raise ValueError("Input 'label' must be the name of a column in the "
                         "reference SFrame 'dataset'.")


    ## Extract the features and labels
    if features is None:
        features = dataset.column_names()
        features.remove(label)

    # Without at least one feature column the first-row inspection below has
    # nothing to look at, so fail here with an explicit message instead.
    if len(features) == 0:
        raise ValueError("The nearest neighbors model requires at least one "
                         "feature column.")

    sf_features = dataset.select_columns(features)
    sf_label = dataset.select_columns([label])


    ## Get feature types and decide which method to use
    # The variable count is taken from the first row only, i.e. it assumes the
    # number of elements in array/dictionary features does not change by row.
    first_row = sf_features[0]
    num_variables = sum(len(x) if hasattr(x, '__iter__') else 1
                        for x in first_row.values())

    if method == 'auto':
        metric_ok = distance in ('euclidean', 'manhattan', 'auto')
        has_dict_feature = dict in sf_features.column_types()

        if metric_ok and not has_dict_feature and num_variables <= 100:
            _method = 'ball-tree'
        else:
            _method = 'brute-force'
    else:
        _method = method


    if _method == 'ball-tree' and distance == 'cosine':
        raise TypeError("The ball tree method does not work with cosine "
                        "distance. Please try 'euclidean' or 'manhattan' "
                        "distance instead.")

    if _method == 'ball-tree' and distance == 'jaccard':
        raise TypeError("The ball tree method does not currently work with "
                        "jaccard distance. Please try 'euclidean' or 'manhattan' "
                        "distance instead.")


    ## Pick the right model name for the method
    if _method == 'ball-tree':
        model_name = 'nearest_neighbors_ball_tree'
        _mt._get_metric_tracker().track('toolkit.nearest_neighbors_balltree.create')

    elif _method == 'brute-force':
        model_name = 'nearest_neighbors_brute_force'
        _mt._get_metric_tracker().track('toolkit.nearest_neighbors_brute.create')

    else:
        raise ValueError("Method must be 'brute-force', 'ball-tree', or 'auto'")


    ## Clean the method options and create the options dictionary
    # User options go in first so that the entries set explicitly below take
    # precedence over any keyword argument with the same name.
    opts = {k.lower(): v for k, v in kwargs.items()}
    opts.update(
        {'model_name': model_name,
         'label': sf_label,
         'features': sf_features,
         'distance': distance})

    ## Construct the nearest neighbors model
    if verbose:
        print("Starting model construction...")

    result = _graphlab.toolkits.main.run('nearest_neighbors_train', opts,
                                         verbose)
    model = NearestNeighborsModel(result['model'])

    # Progress output and the model summary are only shown on request.
    if verbose:
        model.summary()
        print("")

    return model


class NearestNeighborsModel(Model):
    """
    The NearestNeighborsModel represents rows of an SFrame in a structure that
    is used to quickly and efficiently find the nearest neighbors of a query
    point. Use the 'create' method in this module to construct a
    NearestNeighborsModel instance.
    """

    def __init__(self, model_proxy):
        """
        Wrap a model proxy in a NearestNeighborsModel.

        Parameters
        ----------
        model_proxy : object
            Handle to the underlying model. It is passed as the 'model' option
            in every toolkit call made by this class.
        """
        self.__proxy__ = model_proxy
        self.__name__ = 'nearest_neighbors'

    def _get_wrapper(self):
        """
        Return a function that builds a NearestNeighborsModel from a model
        proxy.
        """
        def model_wrapper(model_proxy):
            return NearestNeighborsModel(model_proxy)
        return model_wrapper

    def __str__(self):
        """
        Return a string description of the model to the ``print`` method.

        Returns
        -------
        out : string
            A description of the NearestNeighborsModel.
        """
        return self.__repr__()

    def __repr__(self):
        """
        Return a string description of the model, shown when the model name is
        entered in the terminal.

        Returns
        -------
        out : string
            One "name: value" line per model field. The ball tree fields are
            included only when the model's method is 'ball tree'.
        """

        width = 24
        key_str = "{:<{}}: {}"

        model_fields = [
            ("Distance", 'distance'),
            ("Method", 'method'),
            ("Examples", 'num_examples'),
            ("Features", 'num_features'),
            ("Variables", 'num_variables'),
            ("Training time (sec)", 'train_time')]

        ball_tree_fields = [
            ("Tree depth", 'tree_depth'),
            ("Leaf size", 'leaf_size')]

        def format_field(title, field):
            # Floats are rounded to four places to keep the listing compact.
            value = self.get(field)
            if isinstance(value, float):
                value = round(value, 4)
            return key_str.format(title, width, value)

        out = [key_str.format("Class", width, self.__class__.__name__)]
        out.extend([format_field(k, v) for k, v in model_fields])

        if self.get('method') == 'ball tree':
            out.extend([format_field(k, v) for k, v in ball_tree_fields])

        return '\n'.join(out)

    def summary(self):
        """
        Display a summary of the NearestNeighborsModel.

        Examples
        --------
        >>> sf = graphlab.SFrame({'label': [str(x) for x in range(3)],
        ...                       'feature1': [0.98, 0.62, 0.11],
        ...                       'feature2': [0.69, 0.58, 0.36]})
        >>> model = graphlab.nearest_neighbors.create(sf, 'label')
        >>> model.summary()
                                    Model summary
        --------------------------------------------------------
        Class                   : NearestNeighborsModel
        Distance                : euclidean
        Method                  : ball tree
        Examples                : 3
        Features                : 2
        Variables               : 2
        Training time (sec)     : 0.0232
        Tree depth              : 1
        Leaf size               : 1000
        """

        _mt._get_metric_tracker().track('toolkit.nearest_neighbors.summary')

        print("")
        print("                    Model summary                       ")
        print("--------------------------------------------------------")
        print(self.__repr__())

    def list_fields(self):
        """
        List the fields stored in the model, including data, model, and training
        options. Each field can be queried with the ``get`` method.

        Returns
        -------
        out : list
            Sorted list of fields queryable with the ``get`` method.

        See Also
        --------
        get

        Examples
        --------
        >>> sf = graphlab.SFrame({'label': [str(x) for x in range(3)],
        ...                       'feature1': [0.98, 0.62, 0.11],
        ...                       'feature2': [0.69, 0.58, 0.36]})
        >>> model = graphlab.nearest_neighbors.create(sf, 'label')
        >>> model.list_fields()
        ['distance',
         'features',
         'label',
         'leaf_size',
         'method',
         'num_examples',
         'num_features',
         'num_variables',
         'train_time',
         'tree_depth']
        """

        _mt._get_metric_tracker().track('toolkit.nearest_neighbors.list_fields')

        opts = {'model': self.__proxy__, 'model_name': self.__name__}
        response = _graphlab.toolkits.main.run('nearest_neighbors_list_keys',
                                               opts)

        return sorted(response.keys())

    def get(self, field):
        """
        Return the value of a given field. The list of all queryable fields is
        detailed below, and can be obtained with the
        :func:`~graphlab.nearest_neighbors.NearestNeighborsModel.list_fields`
        method.

        +-----------------------+----------------------------------------------+
        |      Field            | Description                                  |
        +=======================+==============================================+
        | distance              | Measure of dissimilarity between two points  |
        +-----------------------+----------------------------------------------+
        | features              | Feature column names                         |
        +-----------------------+----------------------------------------------+
        | label                 | Label column names                           |
        +-----------------------+----------------------------------------------+
        | leaf_size             | Max size of leaf nodes (ball tree only)      |
        +-----------------------+----------------------------------------------+
        | method                | Method of organizing reference data          |
        +-----------------------+----------------------------------------------+
        | num_examples          | Number of reference data observations        |
        +-----------------------+----------------------------------------------+
        | num_features          | Number of features for distance computation  |
        +-----------------------+----------------------------------------------+
        | num_variables         | Number of variables for distance computation |
        +-----------------------+----------------------------------------------+
        | train_time            | Time to create the reference structure       |
        +-----------------------+----------------------------------------------+
        | tree_depth            | Number of levels in the tree (ball tree only)|
        +-----------------------+----------------------------------------------+

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out
            Value of the requested field.

        See Also
        --------
        list_fields

        Examples
        --------
        >>> sf = graphlab.SFrame({'label': [str(x) for x in range(3)],
        ...                       'feature1': [0.98, 0.62, 0.11],
        ...                       'feature2': [0.69, 0.58, 0.36]})
        >>> model = graphlab.nearest_neighbors.create(sf, 'label')
        >>> model.get('num_features')
        2

        >>> model['train_time']
        0.023223
        """

        _mt._get_metric_tracker().track('toolkit.nearest_neighbors.get')

        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'field': field}
        response = _graphlab.toolkits.main.run('nearest_neighbors_get_value',
                                               opts)

        return response['value']

    def get_default_options(self):
        """
        Return a dictionary with the default options for the
        NearestNeighborsModel.

        Returns
        -------
        out : dict
            Dictionary with the default options.

        See Also
        --------
        get_current_options, list_fields, get

        Examples
        --------
        >>> sf = graphlab.SFrame({'label': [str(x) for x in range(3)],
        ...                       'feature1': [0.98, 0.62, 0.11],
        ...                       'feature2': [0.69, 0.58, 0.36]})
        >>> model = graphlab.nearest_neighbors.create(sf, 'label')
        >>> model.get_default_options()
        {'distance': 'auto', 'leaf_size': 0}
        """

        _mt._get_metric_tracker().track(
            'toolkit.nearest_neighbors.get_default_options')

        opts = {'model': self.__proxy__, 'model_name': self.__name__}

        return _graphlab.toolkits.main.run(
            'nearest_neighbors_get_default_options', opts)

    def get_current_options(self):
        """
        Return a dictionary with the options used to define and create the
        current NearestNeighborsModel instance.

        Returns
        -------
        out : dict
            Dictionary of options used to train the current instance of the
            NearestNeighborsModel.

        See Also
        --------
        get_default_options, list_fields, get

        Examples
        --------
        >>> sf = graphlab.SFrame({'label': [str(x) for x in range(3)],
        ...                       'feature1': [0.98, 0.62, 0.11],
        ...                       'feature2': [0.69, 0.58, 0.36]})
        >>> model = graphlab.nearest_neighbors.create(sf, 'label')
        >>> model.get_current_options()
        {'distance': 'euclidean', 'leaf_size': 1000}
        """

        _mt._get_metric_tracker().track(
            'toolkit.nearest_neighbors.get_current_options')

        opts = {'model': self.__proxy__, 'model_name': self.__name__}

        return _graphlab.toolkits.main.run(
            'nearest_neighbors_get_current_options', opts)

    def training_stats(self):
        """
        Return a dictionary of statistics collected during creation of the
        model. These statistics are also available with the ``get`` method and
        are described in more detail in that method's documentation.

        Returns
        -------
        out : dict
            Dictionary of statistics compiled during creation of the
            NearestNeighborsModel.

        See Also
        --------
        summary

        Examples
        --------
        >>> sf = graphlab.SFrame({'label': [str(x) for x in range(3)],
        ...                       'feature1': [0.98, 0.62, 0.11],
        ...                       'feature2': [0.69, 0.58, 0.36]})
        >>> model = graphlab.nearest_neighbors.create(sf, 'label')
        >>> model.training_stats()
        {'features': 'feature1, feature2',
         'label': 'label',
         'leaf_size': 1000,
         'num_examples': 3,
         'num_features': 2,
         'num_variables': 2,
         'train_time': 0.023223,
         'tree_depth': 1}
        """

        _mt._get_metric_tracker().track(
            'toolkit.nearest_neighbors.training_stats')

        opts = {'model': self.__proxy__, 'model_name': self.__name__}
        return _graphlab.toolkits.main.run("nearest_neighbors_training_stats",
                                           opts)

    def query(self, dataset, label, features=None, k=5, verbose=True):
        """
        Retrieve the nearest neighbors from the reference set for each element
        of the query set. The query SFrame must include columns with the same
        names as the label and feature columns used to create the
        NearestNeighborsModel.

        Parameters
        ----------
        dataset : SFrame
            Query data. This SFrame must have a column with labels for each row.
            If the features for each observation are numeric, they may be in
            separate columns of 'dataset' or a single column with lists of
            values. The features may also be in the form of a column of sparse
            vectors (i.e. dictionaries), with string keys and numeric values.

        label : string
            Name of the SFrame column with row labels.

        features : list[string], optional
            Name of column(s) with features to use in computing distances
            between observations and the query points. These *must* match the
            features used for constructing the reference model. 'None' (the
            default) indicates that all columns except the label are used.

        k : int, optional
            Number of nearest neighbors to return from the reference set for
            each query observation.

        verbose : bool, optional
            If True, print progress updates and model details. The flag is also
            forwarded to the toolkit call that runs the query.

        Returns
        -------
        out : SFrame
            An SFrame with the k-nearest neighbors of each query observation.
            The result contains four columns: the first is the label of the
            query observation, the second is the label of the nearby reference
            observation, the third is the distance between the query and
            reference observations, and the fourth is the rank of the reference
            observation among the query's k-nearest neighbors.

        Raises
        ------
        TypeError
            If 'dataset' is not an SFrame, if 'features' is not an iterable of
            strings, or if 'label' is not a string.

        ValueError
            If 'dataset' has no rows or no columns, if 'label' is not a column
            of 'dataset', or if 'k' is not a number greater than 0.

        Examples
        --------
        First construct a toy SFrame and create a nearest neighbors model:

        >>> sf = graphlab.SFrame({'label': [str(x) for x in range(3)],
        ...                       'feature1': [0.98, 0.62, 0.11],
        ...                       'feature2': [0.69, 0.58, 0.36]})
        >>> model = graphlab.nearest_neighbors.create(sf, 'label')

        A new SFrame contains query observations with same schema as the
        reference SFrame. This SFrame is passed to the ``query`` method.

        >>> queries = graphlab.SFrame({'label': [str(x) for x in range(3)],
        ...                            'feature1': [0.05, 0.61, 0.99],
        ...                            'feature2': [0.06, 0.97, 0.86]})
        >>> model.query(queries, 'label', k=2)
        +-------------+-----------------+----------------+------+
        | query_label | reference_label |    distance    | rank |
        +-------------+-----------------+----------------+------+
        |      0      |        2        | 0.305941170816 |  1   |
        |      0      |        1        | 0.771556867638 |  2   |
        |      1      |        1        | 0.390128184063 |  1   |
        |      1      |        0        | 0.464004310325 |  2   |
        |      2      |        0        | 0.170293863659 |  1   |
        |      2      |        1        | 0.464004310325 |  2   |
        +-------------+-----------------+----------------+------+
        """

        _mt._get_metric_tracker().track(
            'toolkit.nearest_neighbors.query')


        ## Validate input and construct the query observations
        if not isinstance(dataset, _SFrame):
            raise TypeError("Input 'dataset' must be an SFrame with query "
                            "observations.")

        if dataset.num_rows() == 0 or dataset.num_cols() == 0:
            raise ValueError("Input 'dataset' has no data.")

        if features is not None:
            if not hasattr(features, '__iter__'):
                raise TypeError("Input 'features' must be an iterable type.")

            if not all([isinstance(x, str) for x in features]):
                raise TypeError("Input 'features' must contain only strings.")

        if not isinstance(label, str):
            raise TypeError("Input 'label' must be a string type.")

        if label not in dataset.column_names():
            raise ValueError("Input 'label' must be the name of a column in "
                             "the query SFrame 'dataset'.")

        if not isinstance(k, (int, float)) or k <= 0:
            raise ValueError("Input 'k' must be a single number greater than 0.")


        ## Extract the features and labels
        if features is None:
            features = dataset.column_names()
            features.remove(label)

        sf_features = dataset.select_columns(features)
        sf_label = dataset.select_columns([label])

        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'features': sf_features,
                'label': sf_label,
                'k': k}

        # The progress message is only shown on request.
        if verbose:
            print("Starting model querying...")

        result = _graphlab.toolkits.main.run('nearest_neighbors_query', opts,
                                             verbose)
        return _SFrame(None, _proxy=result['neighbors'])
