"""
Methods for creating and querying a nearest neighbors model.
"""

import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.deps import pandas as _pandas, HAS_PANDAS as _HAS_PANDAS


def create(dataset, label, features=None, distance='auto', method='auto',
    verbose=True, **kwargs):
    """
    Create a nearest neighbor model, which can be searched efficiently and
    quickly for the nearest neighbors of a query observation. If the `method`
    argument is specified as 'auto', the type of model is chosen automatically
    based on the type of data in 'dataset'.

    Parameters
    ----------
    dataset : SFrame
        Reference data. This SFrame must have a column with labels for each row.
        A pandas DataFrame is also accepted and converted to an SFrame.

    label : string
        Name of the SFrame column with row labels.

    features : list[string], optional
        Name of the columns with features to use in computing distances between
        observations and the query points. 'None' (the default) indicates that
        all columns except the label should be used as features. Each column can
        be one of the following types:

        - *Numeric*: values of numeric type integer or float.

        - *Array*: list of numeric (integer or float) values. Each list element
          is treated as a separate variable in the model.

        - *Dictionary*: key-value pairs with numeric (integer or float) values.
          Each key indicates a separate variable in the model.

        Columns of type *list* are not supported. Convert them to array columns
        if all entries in the list are of numeric types.

    distance : {'auto', 'euclidean', 'manhattan', 'jaccard', 'cosine'}, optional
        Function that measures the distances between two observations. Please
        see the references for detailed descriptions of the distances. Note that
        for sparse vectors, euclidean, manhattan, and cosine distance assume
        missing keys have value 0.0. For the distances below, suppose :math:`u`
        and :math:`v` are observations with :math:`d` variables each.

        - *auto* (default): the model chooses a reasonable distance based on the
          data types in 'dataset'.

        - *euclidean*:

            .. math:: D(u, v)
                = \\sqrt{\sum_i^d (u_i - v_i)^2}

        - *manhattan*:

            .. math:: D(u, v)
                = \\sum_i^d |u_i - v_i|

        - *jaccard*: works only with variables in a dictionary feature, where
          the keys are treated as a set and the values are ignored. Suppose
          :math:`S` and :math:`T` are the sets of keys from two observations'
          dictionaries.

            .. math:: D(S, T)
                = 1 - \\frac{|S \cap T|}{|S \cup T|}

        - *cosine*: works only with the 'brute-force' method because it is not a
          true metric. Please see `Wikipedia
          <http://en.wikipedia.org/wiki/Cosine_similarity>`_ for more detail.

            .. math:: D(u, v)
                = 1 - \\frac{\sum_i^d u_i v_i}
                {\sqrt{\sum_i^d u_i^2}
                \sqrt{\sum_i^d v_i^2}}

    method : {'auto', 'ball-tree', 'brute-force'}, optional
        Method for computing nearest neighbors. The options are:

        - *auto* (default): the method is chosen automatically, based on the
          type of data and the distance. If the distance is 'auto',
          'euclidean' or 'manhattan', no feature column contains
          dictionaries, and the features contain at most 100 variables in
          total, then the 'ball-tree' method is used. Otherwise, the
          'brute-force' method is used.

        - *ball-tree*: use a tree structure to find the k-closest neighbors to
          each query point. The ball tree model is slower to construct than the
          brute force model, but queries are faster than linear time. This
          method is not implemented for jaccard distance and is not applicable
          for cosine distance. See `Liu, et al (2004)
          <http://papers.nips.cc/paper/2666-an-investigation-of-practical-approximate-nearest-neighbor-algorithms>`_
          for implementation details.

        - *brute-force*: compute the distance from a query point to all
          reference observations. There is no computation time for model
          creation with the brute force method (although the reference data is
          held in the model), but each query takes linear time.

    verbose: bool, optional
        If True, print progress updates and model details. If False, nothing
        is printed by this function.

    **kwargs : optional
        Options for the distance function and query method. Option names are
        converted to lower case before being passed on.

        - *leaf_size*: for the ball tree method, the number of points in each
          leaf of the tree. The default is to use the max of 1,000 and n/(2^11),
          which ensures a maximum tree depth of 12. The default leaf size is
          indicated by a "0" in the
          :func:`~graphlab.nearest_neighbors.NearestNeighborsModel.get_default_options`
          method.

    Returns
    -------
    out : NearestNeighborsModel
        A structure for efficiently computing the nearest neighbors in 'dataset'
        of new query points.

    Raises
    ------
    TypeError
        If 'dataset', 'label' or 'features' have the wrong type, or if the
        ball tree method is combined with cosine or jaccard distance.
    ValueError
        If 'dataset' is empty, 'label' is not a column of 'dataset', there are
        no feature columns, or 'method' is not recognized.

    Notes
    -----
    - If the features should be weighted equally in the distance calculations
      but are measured on different scales, it is important to standardize the
      features. One way to do this is to subtract the mean of each column and
      divide by the standard deviation.

    Examples
    --------
    *Training*

    Given an :class:`~graphlab.SFrame` ``sf`` with a list of columns
    [``feature_1`` ... ``feature_K``] denoting features and a label column
    ``row_label``, create a
    :class:`~graphlab.nearest_neighbors.NearestNeighborsModel`:

    >>> import graphlab as gl
    >>> model = gl.nearest_neighbors.create(sf, 'row_label')

    By default, all columns in the training data except the label are used as
    features. It's also possible to select a subset of columns in the SFrame:

    >>> model = gl.nearest_neighbors.create(sf, 'row_label',
                                            ['feature_1', 'feature_2'])

    The distance function and model type are chosen automatically based on the
    type of data in 'sf', unless otherwise specified. For datasets with a large
    number of rows and up to about 100 variables, the ball tree method often
    leads to much faster queries.

    >>> model = gl.nearest_neighbors.create(sf, 'row_label',
                                            ['feature_1', 'feature_2'],
                                            distance='manhattan',
                                            method='ball-tree')

    If the ball tree is used, it's important to choose an appropriate value for
    the 'leaf_size' parameter, which controls how many observations are stored
    in each leaf of the ball tree. By default, this is set so that the tree is
    no more than 12 levels deep, but larger or smaller values may lead to
    quicker queries depending on the shape and dimension of the data.

    >>> model = gl.nearest_neighbors.create(sf, 'row_label',
                                            ['feature_1', 'feature_2'],
                                            distance='manhattan',
                                            method='ball-tree',
                                            leaf_size=2000)

    *Querying*

    Queries should be passed in the form of an SFrame that includes the same
    columns used to train the model (other columns may be present, but are
    ignored when 'features' is given explicitly). The same SFrame may be used
    as queries to get all-point nearest neighbors.

    >>> knn = model.query(sf, 'row_label', features=['feature_1', 'feature_2'], k=5)


    References
    ----------
    - `Wikipedia - nearest neighbor
      search <http://en.wikipedia.org/wiki/Nearest_neighbor_search>`_

    - `Wikipedia - ball tree <http://en.wikipedia.org/wiki/Ball_tree>`_

    - Ball tree implementation: Liu, T., et al. (2004) `An Investigation of
      Practical Approximate Nearest Neighbor Algorithms
      <http://papers.nips.cc/paper/2666-an-investigation-of-practical-approximate-nearest-neighbor-algorithms>`_.
      Advances in Neural Information Processing Systems pp. 825-832.

    - `Wikipedia - Jaccard distance
      <http://en.wikipedia.org/wiki/Jaccard_index>`_

    - `Wikipedia - Cosine distance
      <http://en.wikipedia.org/wiki/Cosine_similarity>`_
    """

    ## Validate input
    # A pandas DataFrame is accepted as well and converted to an SFrame below.
    is_dataframe = _HAS_PANDAS and isinstance(dataset, _pandas.DataFrame)
    if not (isinstance(dataset, _SFrame) or is_dataframe):
        raise TypeError("Input 'dataset' must be an SFrame.")

    if not isinstance(dataset, _SFrame):
        dataset = _SFrame(dataset)

    if dataset.num_rows() == 0 or dataset.num_cols() == 0:
        raise ValueError("Input 'dataset' has no data.")

    if features is not None:
        if isinstance(features, str) or not hasattr(features, '__iter__'):
            raise TypeError("Input 'features' must be an iterable type.")

        # Materialize once so that one-shot iterables (e.g. generators) are
        # not exhausted by the type check before the columns are selected.
        features = list(features)

        if not all(isinstance(x, str) for x in features):
            raise TypeError("Input 'features' must contain only strings.")

    if not isinstance(label, str):
        raise TypeError("Input 'label' must be a string type.")

    if label not in dataset.column_names():
        raise ValueError("Input 'label' must be the name of a column in the " +
                         "reference SFrame 'dataset'.")


    ## Extract the features and labels
    if features is None:
        features = [name for name in dataset.column_names() if name != label]

    if len(features) == 0:
        raise ValueError("At least one feature column is required.")

    sf_features = dataset.select_columns(features)
    sf_label = dataset.select_columns([label])


    ## Get feature types and decide which method to use
    # The variable count is taken from the first row only; this assumes that
    # array and dictionary features have the same length in every row.
    num_variables = sum(len(x) if hasattr(x, '__iter__') else 1
                        for x in sf_features[0].values())

    if method == 'auto':
        use_ball_tree = (distance in ('euclidean', 'manhattan', 'auto')
                         and dict not in sf_features.column_types()
                         and num_variables <= 100)
        _method = 'ball-tree' if use_ball_tree else 'brute-force'
    else:
        _method = method


    if _method == 'ball-tree' and distance == 'cosine':
        raise TypeError("The ball tree method does not work with cosine " +
                        "distance. Please try 'euclidean' or 'manhattan' " +
                        "distance instead.")

    if _method == 'ball-tree' and distance == 'jaccard':
        raise TypeError("The ball tree method does not currently work with " +
                        "jaccard distance. Please try 'euclidean' or 'manhattan' " +
                        "distance instead.")


    ## Pick the right model name for the method
    if _method == 'ball-tree':
        model_name = 'nearest_neighbors_ball_tree'
        _mt._get_metric_tracker().track('toolkit.nearest_neighbors_balltree.create')

    elif _method == 'brute-force':
        model_name = 'nearest_neighbors_brute_force'
        _mt._get_metric_tracker().track('toolkit.nearest_neighbors_brute.create')

    else:
        raise ValueError("Method must be 'brute-force', 'ball-tree', or 'auto'")


    ## Clean the method options and create the options dictionary.
    # User options are applied first so the required entries below always win.
    opts = {k.lower(): v for k, v in kwargs.items()}
    opts.update(
        {'model_name': model_name,
         'label': sf_label,
         'features': sf_features,
         'distance': distance})

    ## Construct the nearest neighbors model
    if verbose:
        print("Starting model construction...")
    result = _graphlab.toolkits.main.run('nearest_neighbors_train', opts,
                                         verbose)
    model = NearestNeighborsModel(result['model'])

    if verbose:
        model.summary()
        print("")

    return model


class NearestNeighborsModel(Model):
    """
    The NearestNeighborsModel represents rows of an SFrame in a structure that
    is used to quickly and efficiently find the nearest neighbors of a query
    point. Use the 'create' method in this module to construct a
    NearestNeighborsModel instance.
    """

    def __init__(self, model_proxy):
        """
        Wrap a toolkit model handle in a NearestNeighborsModel.

        Parameters
        ----------
        model_proxy
            Handle to the underlying toolkit model. It is sent as the 'model'
            option in every toolkit call made by this class.
        """
        self.__proxy__ = model_proxy
        self.__name__ = 'nearest_neighbors'

    def _get_wrapper(self):
        """
        Return a function that wraps a model proxy in a NearestNeighborsModel.
        """
        def model_wrapper(model_proxy):
            return NearestNeighborsModel(model_proxy)
        return model_wrapper

    def __str__(self):
        """
        Return a string description of the model to the ``print`` method.

        Returns
        -------
        out : string
            A description of the NearestNeighborsModel.
        """
        return self.__repr__()

    def __repr__(self):
        """
        Return a multi-line description of the model, shown when the model
        name is entered in the terminal.

        Float values are rounded to 4 decimal places. Tree depth and leaf size
        are included only for ball tree models.
        """

        width = 24
        key_str = "{:<{}}: {}"

        model_fields = [
            ("Distance", 'distance'),
            ("Method", 'method'),
            ("Examples", 'num_examples'),
            ("Features", 'num_features'),
            ("Variables", 'num_variables'),
            ("Training time (sec)", 'train_time')]

        ball_tree_fields = [
            ("Tree depth", 'tree_depth'),
            ("Leaf size", 'leaf_size')]

        def _format_line(name, field):
            # Fetch one model field and render it as an aligned "name: value"
            # line; floats are rounded for readability.
            value = self.get(field)
            if isinstance(value, float):
                value = round(value, 4)
            return key_str.format(name, width, value)

        out = [key_str.format("Class", width, self.__class__.__name__)]
        out.extend(_format_line(name, field) for name, field in model_fields)

        # NOTE(review): 'create' spells the method 'ball-tree', while this
        # check originally compared against 'ball tree'. The spelling reported
        # by the toolkit is not visible here, so separators are normalized.
        method = str(self.get('method')).replace('-', ' ').replace('_', ' ')
        if method == 'ball tree':
            out.extend(_format_line(name, field)
                       for name, field in ball_tree_fields)

        return '\n'.join(out)

    def summary(self):
        """
        Print a summary of the NearestNeighborsModel: a header followed by the
        description produced by ``__repr__``.
        """

        _mt._get_metric_tracker().track('toolkit.nearest_neighbors.summary')

        print("")
        print("                    Model summary                       ")
        print("--------------------------------------------------------")
        print(self.__repr__())

    def list_fields(self):
        """
        List the fields stored in the model, including data, model, and training
        options. Each field can be queried with the ``get`` method.

        Returns
        -------
        out : list
            Sorted list of fields queryable with the ``get`` method.
        """

        _mt._get_metric_tracker().track('toolkit.nearest_neighbors.list_fields')

        opts = {'model': self.__proxy__, 'model_name': self.__name__}
        response = _graphlab.toolkits.main.run('nearest_neighbors_list_keys',
                                               opts)

        return sorted(response.keys())

    def get(self, field):
        """
        Return the value of a given field. The list of all queryable fields is
        detailed below, and can be obtained with the
        :func:`~graphlab.nearest_neighbors.NearestNeighborsModel.list_fields`
        method.

        +-----------------------+----------------------------------------------+
        |      Field            | Description                                  |
        +=======================+==============================================+
        | train_time            | Time to create the reference structure       |
        +-----------------------+----------------------------------------------+
        | features              | Feature column names                         |
        +-----------------------+----------------------------------------------+
        | label                 | Label column name                            |
        +-----------------------+----------------------------------------------+
        | num_examples          | Number of reference data observations        |
        +-----------------------+----------------------------------------------+
        | num_features          | Number of features for distance computation  |
        +-----------------------+----------------------------------------------+
        | num_variables         | Number of variables for distance computation |
        +-----------------------+----------------------------------------------+
        | distance              | Distance function                            |
        +-----------------------+----------------------------------------------+
        | method                | Method for computing nearest neighbors       |
        +-----------------------+----------------------------------------------+
        | tree_depth            | Depth of the ball tree (ball tree only)      |
        +-----------------------+----------------------------------------------+
        | leaf_size             | Points per leaf of the tree (ball tree only) |
        +-----------------------+----------------------------------------------+

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out
            Value of the requested field.
        """

        _mt._get_metric_tracker().track('toolkit.nearest_neighbors.get')

        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'field': field}
        response = _graphlab.toolkits.main.run('nearest_neighbors_get_value',
                                               opts)

        return response['value']

    def get_default_options(self):
        """
        Return a dictionary with the default options for the
        NearestNeighborsModel.

        Returns
        -------
        out : dict
            Dictionary with the default options.
        """

        _mt._get_metric_tracker().track(
            'toolkit.nearest_neighbors.get_default_options')

        opts = {'model': self.__proxy__, 'model_name': self.__name__}

        return _graphlab.toolkits.main.run(
            'nearest_neighbors_get_default_options', opts)

    def get_current_options(self):
        """
        Return a dictionary with the options used to define and create the
        current NearestNeighborsModel instance.

        Returns
        -------
        out : dict
            Dictionary of options used to train the current instance of the
            NearestNeighborsModel.
        """

        _mt._get_metric_tracker().track(
            'toolkit.nearest_neighbors.get_current_options')

        opts = {'model': self.__proxy__, 'model_name': self.__name__}

        return _graphlab.toolkits.main.run(
            'nearest_neighbors_get_current_options', opts)

    def training_stats(self):
        """
        Return a dictionary of statistics collected during creation of the
        model. These statistics are also available with the ``get`` method and
        are described in more detail in that method's documentation.

        Returns
        -------
        out : dict
            Dictionary of statistics compiled during creation of the
            NearestNeighborsModel.
        """

        _mt._get_metric_tracker().track(
            'toolkit.nearest_neighbors.training_stats')

        opts = {'model': self.__proxy__, 'model_name': self.__name__}
        return _graphlab.toolkits.main.run("nearest_neighbors_training_stats",
                                           opts)

    def query(self, dataset, label, features=None, k=5, verbose=True):
        """
        Retrieve the nearest neighbors from the reference set for each element
        of the query set. The query SFrame must include columns with the same
        names as the label and feature columns used to create the
        NearestNeighborsModel.

        Parameters
        ----------
        dataset : SFrame
            Query data. This SFrame must have a column with labels for each row.
            If the features for each observation are numeric, they may be in
            separate columns of 'dataset' or a single column with lists of
            values. The features may also be in the form of a column of sparse
            vectors (i.e. dictionaries), with string keys and numeric values.

        label : string
            Name of the SFrame column with row labels.

        features : list[string], optional
            Name of column(s) with features to use in computing distances
            between observations and the query points. These *must* match the
            features used for constructing the reference model. 'None' (the
            default) uses all columns of 'dataset' except the label.

        k : int, optional
            Number of nearest neighbors to return from the reference set for
            each query observation.

        verbose: bool, optional
            If True, print progress updates and model details.

        Returns
        -------
        out : SFrame
            An SFrame with the k-nearest neighbors of each query observation.
            The result contains four columns: the first is the label of the
            query observation, the second is the label of the nearby reference
            observation, the third is the distance between the query and
            reference observations, and the fourth is the rank of the reference
            observation among the query's k-nearest neighbors.

        Raises
        ------
        TypeError
            If 'dataset', 'label' or 'features' have the wrong type.
        ValueError
            If 'dataset' is empty, 'label' is not a column of 'dataset', there
            are no feature columns, or 'k' is not a positive number.
        """

        _mt._get_metric_tracker().track(
            'toolkit.nearest_neighbors.query')


        ## Validate input and construct the query observations
        if not isinstance(dataset, _SFrame):
            raise TypeError("Input 'dataset' must be an SFrame with query " +
                            "observations.")

        if dataset.num_rows() == 0 or dataset.num_cols() == 0:
            raise ValueError("Input 'dataset' has no data.")

        if features is not None:
            if isinstance(features, str) or not hasattr(features, '__iter__'):
                raise TypeError("Input 'features' must be an iterable type.")

            # Materialize once so that one-shot iterables (e.g. generators)
            # are not exhausted by the type check before column selection.
            features = list(features)

            if not all(isinstance(x, str) for x in features):
                raise TypeError("Input 'features' must contain only strings.")

        if not isinstance(label, str):
            raise TypeError("Input 'label' must be a string type.")

        if label not in dataset.column_names():
            raise ValueError("Input 'label' must be the name of a column in " +
                             "the query SFrame 'dataset'.")

        if not isinstance(k, (int, float)) or k <= 0:
            raise ValueError("Input 'k' must be a single number greater than 0.")


        ## Extract the features and labels
        if features is None:
            features = [name for name in dataset.column_names()
                        if name != label]

        if len(features) == 0:
            raise ValueError("At least one feature column is required.")

        sf_features = dataset.select_columns(features)
        sf_label = dataset.select_columns([label])

        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'features': sf_features,
                'label': sf_label,
                'k': k}

        if verbose:
            print("Starting model querying...")
        result = _graphlab.toolkits.main.run('nearest_neighbors_query', opts,
                                             verbose)
        return _SFrame(None, _proxy=result['neighbors'])
