"""
This package contains the Gradient Boosted Trees model class and the create function.
"""

import graphlab as _graphlab
import graphlab.connect as _mt
import graphlab.toolkits.evaluation as evaluation
import json
import time
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.data_structures.sarray import SArray as _SArray
from graphlab.deps import HAS_PANDAS as _HAS_PANDAS, pandas as _pandas
from graphlab.toolkits.main import ToolkitError

# Default values for the tunable training parameters accepted by ``create``.
DEFAULT_TRAINING_PARAM = dict(
    max_depth=6,
    step_size=0.3,
    min_loss_reduction=0,
    min_child_weight=0.1,
    subsample=1,
)

# Keys reported by the model, grouped the way ``BoostedTreesModel.__repr__``
# prints them: model parameters, training results, training-data description.
_BOOSTED_TREES_MODEL_PARAMS_KEYS = [
    'num_trees',
    'step_size',
    'max_depth',
    'num_iterations',
    'min_child_weight',
    'min_loss_reduction',
    'subsample',
]
_BOOSTED_TREE_TRAINING_PARAMS_KEYS = [
    'objective',
    'training_time',
    'training_error',
    'validation_error',
    'evaluation_metric',
]

_BOOSTED_TREE_TRAINING_DATA_PARAMS_KEYS = [
    'target',
    'features',
    'num_features',
    'num_training_examples',
    'num_validation_examples',
]

# Default candidate values for hyper-parameter search: list entries are the
# alternatives to try, scalar entries are held fixed.
DEFAULT_HYPER_PARAMETER_RANGE = dict(
    params=dict(
        max_depth=[6, 8, 10],
        step_size=0.3,
        min_loss_reduction=[0, 1, 10],
        min_child_weight=0.1,
        subsample=1,
    ),
    num_iterations=[10, 50, 100],
)


class BoostedTreesModel(Model):
    """
    The gradient boosted trees model can be used for
    regression and classification tasks.

    The prediction is based on a collection of base learners,
    `regression trees or classification trees
    <http://en.wikipedia.org/wiki/Decision_tree_learning>`_ depending on the
    objective, and combines them through a technique called `gradient boosting
    <http://en.wikipedia.org/wiki/Gradient_boosting>`_.

    Different from linear models, e.g. linear regression or logistic regression,
    the gradient boosted trees model is able to model non-linear interactions
    between the features and the target using decision trees as the subroutine.
    It is good for handling numerical features and categorical features with
    tens of categories but is less suitable for highly sparse features such as
    text data.

    An instance of this model can be created using
    :py:func:`~graphlab.tree_ensembles.boosted_trees.create`. Do NOT construct
    the model directly.

    See Also
    --------
    create
    """
    def __init__(self, model_proxy, _params, _finalized=False):
        """__init__(self)

        Wrap a model proxy together with its parameters.

        Parameters
        ----------
        model_proxy : object
            Handle to the underlying model; stored as ``__proxy__``.

        _params : dict
            Parameter values. The known model, training and training-data
            keys are copied out of it; keys that are absent are recorded
            as None.

        _finalized : bool, optional
            If True, ``_params`` is also stored as ``__params__``. Otherwise
            ``__params__`` is not set until ``__finalize__`` is called.
        """
        self.__proxy__ = model_proxy
        # Lists of (key, value) pairs keep the display order fixed for __repr__.
        self._model_params = [(k, _params.get(k)) for k in _BOOSTED_TREES_MODEL_PARAMS_KEYS]
        self._training_params = [(k, _params.get(k)) for k in _BOOSTED_TREE_TRAINING_PARAMS_KEYS]
        self._training_data_params = [(k, _params.get(k)) for k in _BOOSTED_TREE_TRAINING_DATA_PARAMS_KEYS]
        if _finalized:
            self.__params__ = _params

    def __str__(self):
        """Return the same text as ``__repr__``."""
        return self.__repr__()

    def __repr__(self):
        """
        Return a multi-line, human readable description of the model.

        The class name is printed first, followed by the training parameters,
        the training-data parameters and the model parameters, each group
        terminated by an empty line. Float values are rounded to 4 digits.
        """
        width = 24
        key_str = "{:<{}}: {}"

        # exclude this information from __repr__
        keys_to_exclude = set(['features'])

        ret = [key_str.format("Class", width, self.__class__.__name__)]

        for params in [self._training_params, self._training_data_params, self._model_params]:
            for k, v in params:
                if k in keys_to_exclude:
                    continue
                if isinstance(v, float):
                    try:
                        v = round(v, 4)
                    except (TypeError, ValueError, OverflowError):
                        # Best effort only: fall back to the unrounded value.
                        pass
                ret.append(key_str.format(k, width, v))
            ret.append("")
        return '\n'.join(ret)

    def _get_wrapper(self):
        """
        Return a function that wraps a model proxy in a finalized
        BoostedTreesModel sharing this model's ``__params__``.
        """
        def model_wrapper(model_proxy):
            return BoostedTreesModel(model_proxy, self.__params__, _finalized=True)
        return model_wrapper

    def summary(self):
        """
        Prints a summary of the model

        Examples
        --------
        >>> m.summary()
        """
        _mt._get_metric_tracker().track('toolkit.tree_ensembles.boosted_trees.summary')
        # Parenthesised single-argument form prints identically on Python 2
        # and is also valid Python 3 syntax.
        print(self.__str__())

    def list_fields(self):
        """
        Return the fields in the BoostedTreesModel.

        Returns
        -------
        out : list[str]
            A list of fields that can be queried using the
            ``get`` or ``m[key]`` method.

        Examples
        --------
        >>> m.list_fields()
        """
        _mt._get_metric_tracker().track('toolkit.tree_ensembles.boosted_trees.list_fields')
        # list(...) so the concatenation also works where dict.keys() is a view.
        return self.__proxy__.list_fields() + list(self.__params__)

    def get(self, field):
        """
        Get the value of a given field. The list of all queryable fields is
        detailed below, and can be obtained programmatically using the
        :func:`~graphlab.boosted_trees.list_fields` method.

        +-------------------------+-------------------------------------------------------------+
        | Field                   | Description                                                 |
        +=========================+=============================================================+
        | target                  | Name of the target column                                   |
        +-------------------------+-------------------------------------------------------------+
        | features                | Names of the feature columns                                |
        +-------------------------+-------------------------------------------------------------+
        | num_features            | Number of features in the model                             |
        +-------------------------+-------------------------------------------------------------+
        | num_training_examples   | Number of training examples                                 |
        +-------------------------+-------------------------------------------------------------+
        | num_validation_examples | Number of validation examples                               |
        +-------------------------+-------------------------------------------------------------+
        | step_size               | Step_size used for combining the weight of individual trees |
        +-------------------------+-------------------------------------------------------------+
        | max_depth               | The maximum depth of individual trees                       |
        +-------------------------+-------------------------------------------------------------+
        | num_iterations          | Number of iterations, equals to the number of trees         |
        +-------------------------+-------------------------------------------------------------+
        | min_child_weight        | Minimum weight required on the leaf nodes                   |
        +-------------------------+-------------------------------------------------------------+
        | min_loss_reduction      | Minimum loss reduction required for splitting a node        |
        +-------------------------+-------------------------------------------------------------+
        | subsample               | Percentage of the samples for training each individual tree |
        +-------------------------+-------------------------------------------------------------+
        | objective               | The learning objective -- regression or classification      |
        +-------------------------+-------------------------------------------------------------+
        | training_time           | Time spent on training the model in seconds                 |
        +-------------------------+-------------------------------------------------------------+
        | training_error          | Error on training data                                      |
        +-------------------------+-------------------------------------------------------------+
        | validation_error        | Error on validation data                                    |
        +-------------------------+-------------------------------------------------------------+
        | evaluation_metric       | The evaluation metric                                       |
        +-------------------------+-------------------------------------------------------------+
        | trees                   | A list of string representing each tree                     |
        +-------------------------+-------------------------------------------------------------+
        | history                 | A list of string for the training history                   |
        +-------------------------+-------------------------------------------------------------+

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out : [various]
            The current value of the requested field.

        Raises
        ------
        KeyError
            If ``field`` is not one of the values returned by ``list_fields``.

        See Also
        --------
        list_fields

        Examples
        --------
        >>> m.get('training_error')
        """
        available = self.list_fields()
        if field not in available:
            # Build the message from the list we just validated against.
            raise KeyError('Key \"%s\" not in model. Available keys are %s.'
                           % (field, ', '.join(str(f) for f in available)))

        # Locally stored parameters take precedence over the proxy's fields.
        if field in self.__params__:
            return self.__params__[field]
        else:
            return self.__proxy__.get(field)

    def show_tree(self, tree_id, vlabel_hover=False):
        """
        Plot the tree as an SGraph in canvas.

        Parameters
        ----------
        tree_id : int
            The id of the tree to show. Starting from 0 to num_iterations-1.

        vlabel_hover : bool, optional
            If True, hide the label on vertex, and only show the label
            with mouse hovering.

        Examples
        --------
        >>> m.show_tree(0)
        """
        tree_json = self.__proxy__.get('json_trees')[tree_id]
        # Create a sgraph from the json string of the ith tree
        g = _SGraphFromJsonTree(tree_json)

        # Make the label for each vertex based on its type
        def get_vlabel(record):
            if record['type'] == 'leaf':
                return str(record['value'])
            elif record['type'] == 'indicator':
                return str(record['name'])
            else:
                return str(record['name']) + '<' + str(record['value'])
        g.vertices['__repr__'] = g.vertices.apply(get_vlabel, str)

        # RGB triples; only fuchsia, blue and orange are used below.
        fuchsia = [0.69, 0., 0.48]
        blue = [0.039, 0.55, 0.77]
        orange = [1.0, 0.33, 0.]
        green = [0.52, 0.74, 0.]

        # Make the color for leaf nodes
        def get_leaf_vcolor(record, is_classification):
            '''assign color to leaf vertex, negative_leaf->blue, positive_leaf->orange'''
            if is_classification:
                return blue if record['value'] < 0 else orange
            else:
                return orange
        leaf_vertices = g.vertices[g.vertices['type'] == 'leaf']
        is_classification = self['objective'] == 'classification'
        # Vertex id 0 is highlighted as the root.
        root_vertex_color = {0: fuchsia}
        leaf_vertex_color = dict([(x['__id'], get_leaf_vcolor(x, is_classification)) for x in leaf_vertices])
        highlight = {}
        highlight.update(root_vertex_color)
        highlight.update(leaf_vertex_color)

        # Hack: we want the canvas to show tree_{i} instead of g, so here is how we do it.
        # NOTE(review): this relies on locals() returning the same mutable
        # dict on successive calls inside a function; that holds on the
        # CPython 2 this module targets but is not guaranteed by newer
        # interpreters -- confirm before porting.
        graph_name = 'tree_' + str(tree_id)
        locals().update({graph_name: g})
        del g
        locals()[graph_name].show(vlabel='__repr__', elabel='value', vlabel_hover=vlabel_hover, highlight=highlight, arrows=True)

    def evaluate(self, dataset, metric='auto', threshold=0.5):
        """
        Evaluate the model on the given dataset.

        The default metric depends on model objective:

            - classification : confusion_matrix

            - regression : rmse

        Parameters
        ----------
        dataset : SFrame
            Dataset in the same format used for training. The columns names and
            types of the dataset must be the same as that used in training.

        metric : str, optional
            Name of the evaluation metric.  Possible values are:
            'auto': automatically choose according to objective
            'auc': Area under curve
            'rmse': Root mean squared error
            'error': Classification error
            'confusion_matrix': Confusion matrix, and classification accuracy

        threshold : float, optional
            Probability threshold for classifying an example as positive.
            Default is 0.5. This argument is currently not used by the
            implementation; see Notes.

        Returns
        -------
        out : dict
            A dictionary containing the evaluation result. For
            'confusion_matrix' the keys are 'accuracy' and 'confusion_table'.

        Raises
        ------
        TypeError
            If ``dataset`` is not an SFrame or pandas.DataFrame, if ``metric``
            is not a str, or if a classification metric is requested and the
            target column is not of type int.
        ToolkitError
            If ``metric`` is not one of the supported metrics.
        ValueError
            If a classification metric is requested on a regression model.

        Examples
        --------
        >>> m.evaluate(test_data, 'rmse')

        Notes
        -----
        When evaluating for classification metrics (e.g. auc,
        confusion_matrix), the classification threshold is set to 0.5. For more
        flexible classification accuracy, please use functions in the
        :py:mod:`~graphlab.toolkits.evaluation` module.
        """
        _mt._get_metric_tracker().track('toolkit.tree_ensembles.boosted_trees.evaluate')
        if not isinstance(dataset, _SFrame) and not (_HAS_PANDAS and isinstance(dataset, _pandas.DataFrame)):
            raise TypeError("Input 'dataset' must be an SFrame or pandas.DataFrame")

        if not isinstance(metric, str):
            raise TypeError('metric type must be str')
        supported_metrics = ['auto', 'auc', 'rmse', 'error', 'confusion_matrix']
        if metric not in supported_metrics:
            raise ToolkitError('Unsupported metric %s. Supported metrics are: %s' % (str(metric), str(supported_metrics)))

        objective = self.__params__['objective']
        if objective == 'regression' and metric not in ['rmse', 'auto']:
            raise ValueError('Cannot evaluate %s on regression model.' % metric)

        features = self.__params__['features']
        target = self.__params__['target']
        testdata = dataset[features]

        if metric in ('confusion_matrix', 'auc', 'error'):
            if dataset[target].dtype() is not int:
                raise TypeError('dataset target column type must be int for evaluation metric %s. Found type %s.' % (metric, str(dataset[target].dtype())))

        # overwrite the default evaluation metric for classification to confusion matrix
        if metric == 'auto' and objective == 'classification':
            metric = 'confusion_matrix'

        if metric == 'confusion_matrix':
            # Accuracy comes from the 'error' metric; the recursive call also
            # performs the int-target check when we arrived here via 'auto'.
            acc = 1 - self.evaluate(dataset, 'error')['error']
            pred = self.predict(testdata, output_type='probability')
            return {'accuracy': acc,
                    'confusion_table': evaluation.confusion_matrix(dataset[target], pred)
                    }

        testlabel = dataset[[target]]
        opts = {'model': self.__proxy__,
                'data': testdata,
                'label': testlabel,
                'metric': metric}
        return _graphlab.toolkits.main.run('xgboost_evaluate', opts)

    def predict(self, dataset, output_type='probability'):
        """
        Predict the target column of the given dataset.

        The target column is provided during
        :func:`~graphlab.boosted_trees.create`. If the target column is in the
        `dataset` it will be ignored.

        Parameters
        ----------
        dataset : SFrame
          A dataset that has the same columns that were used during training.
          If the target column exists in ``dataset`` it will be ignored
          while making predictions.

        output_type : {'probability', 'margin'}, optional.
          Only applicable for models trained with objective=='classification'.
          If output_type is 'probability', then predict will output the class
          probability between [0, 1]. Otherwise, it will output the margin
          score before transforming to probability using the logistic function.

        Returns
        -------
        out : SArray
           Predicted target value for each example (i.e. row) in the dataset.

        Raises
        ------
        TypeError
            If ``dataset`` is not an SFrame or pandas.DataFrame.
        ValueError
            If ``output_type`` is not 'probability' or 'margin'.
        ToolkitError
            If a feature column used for training is missing from ``dataset``.

        Examples
        --------
        >>> m.predict(testdata, output_type='probability')
        """
        _mt._get_metric_tracker().track('toolkit.tree_ensembles.boosted_trees.predict')
        if not isinstance(dataset, _SFrame) and not (_HAS_PANDAS and isinstance(dataset, _pandas.DataFrame)):
            raise TypeError("Input 'dataset' must be an SFrame or pandas.DataFrame")
        valid_output_types = ['probability', 'margin']
        if output_type not in valid_output_types:
            raise ValueError('Unsupported output_type. Valid output_types are %s' % str(valid_output_types))

        # Keep only the training feature columns so that the target column and
        # any other extra columns in `dataset` are ignored, as documented.
        # NOTE(review): column_names() is called on `dataset` directly, so a
        # pandas.DataFrame presumably fails here despite passing the type
        # check above -- confirm.
        features = self.__params__['features']
        available_columns = dataset.column_names()
        testdata = _graphlab.SFrame()
        for f in features:
            if f not in available_columns:
                raise ToolkitError('Cannot find feature column %s in dataset' % f)
            testdata[f] = dataset[f]

        # Send the filtered frame; previously the unfiltered `dataset` was
        # passed and `testdata` was built but never used.
        opts = {'model': self.__proxy__,
                'data': testdata}
        if output_type == 'margin':
            opts['is_raw'] = 1
        response = _graphlab.toolkits.main.run("xgboost_predict", opts)
        # Convert predictions to an SArray
        return _SArray(None, _proxy=response['predictions'])

    def __finalize__(self, training_time):
        '''
        Internal method to finalize model state after training finished.

        Fills in the evaluation metric, training/validation error, training
        time and number of trees, then builds ``__params__`` from all three
        parameter groups.
        '''
        evaluation_metric, training_error, validation_error = self.__parse_error_from_history__()
        self._training_params = self.__update_params__(
            self._training_params,
            {'evaluation_metric': evaluation_metric,
             'training_error': training_error,
             'validation_error': validation_error,
             'training_time': training_time})
        self._model_params = self.__update_params__(self._model_params, {'num_trees': len(self.__proxy__.get('trees'))})
        all_params = self._training_params + self._training_data_params + self._model_params
        self.__params__ = dict(all_params)

    def __parse_error_from_history__(self):
        '''
        Internal utility method to parse training/validation error from history.
        Returns a tuple of (evaluation_metric, training_error, validation_error).

        History is a list of strings; only the last entry is parsed. Each
        entry consists of tab-separated fields in the format of:
        [ITER] <TAB> train-E:V <TAB> validation-E:V
        where E is the error type, and V is the error value. The validation
        field is optional.

        With an empty history the result is (None, 0, 0). The metric name
        'error' is reported as 'classification_error'.

        TODO: Get training/validation error from server.
        '''
        training_error = 0
        validation_error = 0
        evaluation_metric = None
        history = self.__proxy__.get('history')
        if len(history) > 0:
            elements = history[-1].split('\t')
            # elements[1] looks like "train-<metric>:<value>"
            train_field = elements[1].split(':')
            evaluation_metric = train_field[0].split('-')[1]
            training_error = float(train_field[1])
            if len(elements) > 2:
                validation_error = float(elements[2].split(':')[1])

        if evaluation_metric == 'error':
            evaluation_metric = 'classification_error'
        return (evaluation_metric, training_error, validation_error)

    def __update_params__(self, params, new_params):
        '''
        Internal utility function to update the model parameters
        stored in the form of list of key-value pairs.

        Returns a new list in the same order; a value is replaced only when
        its key is present in ``new_params``. Keys of ``new_params`` that are
        not already in ``params`` are ignored.
        '''
        return [(k, new_params[k]) if k in new_params else (k, v) for (k, v) in params]


def _SGraphFromJsonTree(json_str):
    """
    Build an SGraph from the JSON description of a single tree.

    The JSON document must have a 'vertices' list (each entry with an 'id')
    and an 'edges' list (each entry with 'src' and 'dst'). All remaining
    keys of an entry become attributes of the vertex or edge, with the
    attribute names converted to str.
    """
    tree = json.loads(json_str)

    def attributes(record, reserved):
        # Everything except the structural keys is carried over as attributes.
        return dict((str(key), value) for key, value in record.items() if key not in reserved)

    vertex_list = []
    for vertex in tree['vertices']:
        vertex_list.append(_graphlab.Vertex(vertex['id'], attributes(vertex, ('id',))))

    edge_list = []
    for edge in tree['edges']:
        edge_list.append(_graphlab.Edge(edge['src'], edge['dst'], attributes(edge, ('src', 'dst'))))

    graph = _graphlab.SGraph()
    graph = graph.add_vertices(vertex_list)
    return graph.add_edges(edge_list)


def __mkdata(dataset, target, features, weight_column):
    """
    Validate the dataset and split it into feature and label SFrames.

    Parameters
    ----------
    dataset : SFrame
        A data set that has the feature columns and the target column.

    target : str
        The name of the target column

    features : list[str], optional
        If specified, the returned feature SFrame will only contain the
        specified columns. If None, every column except the target is used.

    weight_column : str, optional
        The name of the weight column. It is added to the label SFrame only
        when it is not None and present in ``dataset``.

    Returns
    -------
    out : SFrame, SFrame, set
        The feature SFrame (columns in the order they appear in ``dataset``),
        the label SFrame (target plus the optional weight column), and the
        set of feature column names used by the model.

    Raises
    ------
    TypeError
        If ``dataset`` is not an SFrame.
    ToolkitError
        If the target or a requested feature is missing from ``dataset``, or
        if the target is listed among the features.
    """
    if not isinstance(dataset, _SFrame):
        raise TypeError('dataset must be an SFrame')

    column_types = dataset.column_types()
    column_names = dataset.column_names()
    if target not in column_names:
        raise ToolkitError('Cannot find target %s in dataset' % target)

    # The label frame always holds the target, plus the weight when available.
    label_columns = [target]
    if weight_column is not None and weight_column in column_names:
        label_columns.append(weight_column)

    if features is None:
        features = [name for name in column_names if name != target]
    elif target in features:
        raise ToolkitError("Target column '%s' cannot be in feature columns, please remove that." % target)
    else:
        candidates = set(column_names)
        candidates.discard(target)
        for name in features:
            if name not in candidates:
                raise ToolkitError('Cannot find feature column %s in dataset' % name)

    feature_set = set(features)
    # Select the features following the dataset's own column order.
    ordered_features = [name for name, _ in zip(column_names, column_types) if name in feature_set]
    return dataset[ordered_features], dataset[label_columns], feature_set


def create(dataset, target, objective='regression',
           features=None, num_iterations=10,
           params=DEFAULT_TRAINING_PARAM,
           validation_set=None,
           verbose=True):
    """
    Create a :class:`~graphlab.tree_ensembles.boosted_trees.BoostedTreesModel`
    on the given data and objective function.

    By default, the model is created to predict the target column
    as a regression problem. It can also be used to predict the class of
    a binary target variable when "objective" is set to "classification".

    Parameters
    ----------
    dataset : SFrame
        A training dataset containing feature columns and a target column.
        Only numerical typed (int, float) target column is allowed.

    target : str
        The name of the column in ``dataset`` that is the prediction target.
        This column must have a numeric type.

    objective : str, optional
        Specify the learning task and the corresponding learning objective.
        The following objective types are permitted:

        - 'regression'
            For a real valued target, minimizing squared loss.

        - 'classification'
            For binary target 0 or 1 or numerical target between [0, 1],
            minimizing logistic loss.

    features : list[str], optional
        A list of columns names of features used for training the model.
        Defaults to None, using all columns except the target.

    num_iterations : int, optional
        The number of iterations for boosting. It is also the number of trees
        in the model.

    params : dict, optional
        Tree models are easily overfitted. The following parameters are
        important to regularize the model. Keys that are not given fall back
        to the values in ``DEFAULT_TRAINING_PARAM``; the dictionary passed in
        is not modified.

        - 'max_depth' : float
            Maximum depth of a tree.
        - 'step_size' : float, [0,1]
            Step size(shrinkage) used in update to prevents overfitting.  It
            shrinks the prediction of each weak learner to make the boosting
            process more conservative.  The smaller, the more conservative the
            algorithm will be. Smaller step_size is usually used together with
            larger num_iterations.
        - 'min_loss_reduction' : float
            Minimum loss reduction required to make a further partition on a
            leaf node of the tree. The larger it is, the more conservative the
            algorithm will be.
        - 'min_child_weight' : float
            This controls number of instances needed at least for each leaf.
            The larger it is, the more conservative the algorithm will be.  Set
            it larger when you want to prevent overfitting.  Formally, this is
            minimum sum of instance weight(hessian) in each leaf.  If the tree
            partition step results in a leaf node with the sum of instance
            weight less than min_child_weight, then the building process will
            give up further partitioning. For a regression task, this simply
            corresponds to minimum number of instances needed to be in each
            node.
        - 'subsample' : float
            Subsample ratio of the training set in each iteration of tree
            construction.  This is called bagging trick and usually can help
            prevent overfitting.  Setting it to 0.5 means that model randomly
            collected half of the data instances to grow each tree.

    validation_set : SFrame, optional
        The validation set that is used to watch the validation result as
        boosting progress. It must provide the same feature columns as the
        training set.

    verbose : boolean, optional
        If True, print progress information during training.

    Returns
    -------
      out : BoostedTreesModel
          A trained gradient boosted trees model

    Raises
    ------
    TypeError
        If ``dataset`` has an unsupported type, or if the target column of
        ``dataset`` or ``validation_set`` is not int or float.
    ToolkitError
        If ``objective`` is unknown, a required column is missing, or the
        validation set does not have the same feature set as the training set.

    References
    ----------
    .. [1] `Wikipedia - Gradient tree boosting
      <http://en.wikipedia.org/wiki/Gradient_boosting#Gradient_tree_boosting>`_
    .. [2] `Trevor Hastie's slides on Boosted Trees and Random Forest
      <http://jessica2.msri.org/attachments/10778/10778-boost.pdf>`_

    See Also
    --------
    BoostedTreesModel

    Examples
    --------
    >>> url = 'http://s3.amazonaws.com/gl-testdata/xgboost/mushroom.csv'
    >>> data = graphlab.SFrame.read_csv(url)
    >>> data['label'] = data['label'] == 'e'
    >>> train, test = data.random_split(0.8)
    >>> train, validate = train.random_split(0.8)
    >>> m = graphlab.boosted_trees.create(train, target='label',
                                          objective='classification')
    >>> m.evaluate(validate)
    >>> m.predict(test)

    Notes
    -----
    Some times for regression tasks, the target variable is rescaled to values
    between 0 and 1. In these cases, using 'objective=classification' typically
    produces better results than using 'objective=regression'.
    """
    _mt._get_metric_tracker().track('toolkit.tree_ensembles.boosted_trees.create')

    # NOTE(review): a pandas.DataFrame passes this check, but __mkdata below
    # only accepts an SFrame -- confirm whether pandas input is meant to work.
    if not isinstance(dataset, _SFrame) and not (_HAS_PANDAS and isinstance(dataset, _pandas.DataFrame)):
        raise TypeError("Input 'dataset' must be an SFrame or pandas.DataFrame")

    valid_objectives = ['regression', 'classification']
    if objective not in valid_objectives:
        raise ToolkitError('Unknown objective %s. Supported objectives are %s' % (objective, ', '.join(valid_objectives)))

    # Start from a private copy of the defaults and overlay the caller's
    # values. Writing into DEFAULT_TRAINING_PARAM itself (as was done before)
    # leaked one call's settings and bookkeeping keys into every later call.
    user_params = params if params is not None else {}
    params = dict(DEFAULT_TRAINING_PARAM)
    params['verbose'] = int(verbose)
    params['num_iterations'] = num_iterations
    params['objective'] = objective
    params.update(user_params)

    # HACK: Rename some of the parameters before passing to server
    params_key_translate_map = {'step_size': 'eta',
                                'min_loss_reduction': 'gamma',
                                'num_iterations': 'max_iter'}
    tparam = {}
    for k, v in params.items():
        server_key = params_key_translate_map.get(k, k)
        if server_key == 'objective':
            tparam[server_key] = 'reg:linear' if v == 'regression' else 'binary:logistic'
        else:
            # All other options are passed to the server as strings.
            tparam[server_key] = str(v)

    # this parameter makes sure that we will run xgboost on all available threads assuming
    # unity_server is compiled with openmp.
    tparam['nthread'] = '0'
    response = _graphlab.toolkits.main.run("xgboost_init", tparam)

    weight_column = None  # not supported
    train_x, train_y, feature_set = __mkdata(dataset, target, features, weight_column)

    # Prepare the optional validation data before the model object is built,
    # because the model snapshots the data-related params in its constructor.
    valid_x = valid_y = None
    num_validation_examples = 0
    if validation_set is not None:
        valid_x, valid_y, feature_set_valid = __mkdata(validation_set, target, features, weight_column)
        if not feature_set == feature_set_valid:
            raise ToolkitError('schema of validation set must be same as training set')
        num_validation_examples = len(valid_x)

    if dataset[target].dtype() not in [int, float]:
        raise TypeError('dataset target column \"%s\" must be numeric. Found type %s.' % (target, str(dataset[target].dtype())))
    if validation_set is not None and validation_set[target].dtype() not in [int, float]:
        raise TypeError('validation_set target column \"%s\" must be numeric. Found type %s.' % (target, str(validation_set[target].dtype())))

    params['features'] = list(feature_set)
    params['target'] = target
    params['num_training_examples'] = len(train_x)
    params['num_features'] = len(feature_set)
    params['num_validation_examples'] = num_validation_examples

    # Initialize model
    m = BoostedTreesModel(response['model'], params)

    # Train the model on the given data set
    opts = {'model': m.__proxy__,
            'data': train_x,
            'label': train_y}
    if validation_set is not None:
        opts['valid_data'] = valid_x
        opts['valid_label'] = valid_y

    start = time.time()
    _graphlab.toolkits.main.run("xgboost_train", opts)
    training_time = time.time() - start

    m.__finalize__(training_time)
    return m
