from graphlab.connect import _get_metric_tracker
from inspect import getmodule as _get_module
from itertools import product as _product
from os.path import dirname as _dirname, join as _path_join
from random import sample as _random_sample
from time import time as _time

import graphlab as _gl
from graphlab import SFrame as _SFrame
from graphlab import SArray as _SArray
from graphlab.deps import pandas, HAS_PANDAS
from graphlab.deploy import _Pipeline as _Pipeline, Task as _Task


def _append_inputs(task):
    '''
    Combiner step: loads one SFrame from every value in task.params and appends them all into a
    single SFrame.

    Every value in task.params is handed to the SFrame constructor as its data argument. The
    combined SFrame is stored in task.outputs['save_path'].
    '''
    result = _SFrame()

    # values() (rather than the Python 2 only itervalues()) works on both Python 2 and Python 3.
    for cur_path in task.params.values():
        cur = _SFrame(data=cur_path)
        result = result.append(cur)

    task.outputs['save_path'] = result


def _get_all_parameters_combinations(parameters):
    '''
    Takes a dictionary where the keys are parameter names. The value of a key is a list of all
    possible values parameter.

    Returns a list of all possible parameter combinations. Each parameter set is a dictionary.

    For example an input of {'foo':[1,2], 'bar':['a','b']} will produce
    [{'foo':1, 'bar':'a'}, {'foo':1, 'bar':'b'}, {'foo':2, 'bar':'a'}, {'foo':2, 'bar':'b'}]
    '''

    # Get all possible combinations
    parameter_names = parameters.keys()
    arg_list = []
    for i in parameter_names:
        arg_list.append(parameters[i])
    param_iter = _product(*arg_list)

    # Construct the output
    result = []
    for param_tuple in param_iter:
        param_dict = {}
        for i in range(len(param_tuple)):
            cur_arg_name = parameter_names[i]
            cur_arg_value = param_tuple[i]
            param_dict[cur_arg_name] = cur_arg_value
        result.append(param_dict)

    return result


def _get_all_model_fields(model):
    '''
    Saves all fields that a model exposes.
    '''
    result = {}
    for field_name in model.list_fields():
        result[field_name] = model[field_name]
    return result


def _get_default_parameter_range(model_factory, factory_params):
    '''
    Returns the default hyper parameter range for a model factory.

    The range is read from the DEFAULT_HYPER_PARAMETER_RANGE attribute of the module that defines
    the factory. The generic recommender factory is first swapped for the matrix factorization
    factory when factory_params contains 'target_column' and its 'method' is either absent or
    'matrix_factorization'.

    Raises a TypeError if that module has no DEFAULT_HYPER_PARAMETER_RANGE attribute.
    '''
    factory = model_factory
    if factory is _gl.recommender.create:
        wants_matrix_factorization = ('method' not in factory_params
                                      or factory_params['method'] == 'matrix_factorization')
        if wants_matrix_factorization and 'target_column' in factory_params:
            factory = _gl.recommender.matrix_factorization.create

    defining_module = _get_module(factory)
    if hasattr(defining_module, 'DEFAULT_HYPER_PARAMETER_RANGE'):
        return defining_module.DEFAULT_HYPER_PARAMETER_RANGE
    raise TypeError("%s does not support hyper parameter tuning" % factory)


def _flatten_to_single_row_sframe(dic):
    '''
    Takes a dictionary where the keys are strings and converts it into a one row SFrame.

    Each value is stored according to its type:
      - dict: flattened one level. Every inner key becomes its own column and the outer key is
        not used as a column name.
      - pandas DataFrame: stored in a single cell under its key, as a list with one dictionary
        per row.
      - SFrame: stored in a single cell under its key, as a list with one entry per row, each
        built by zipping the column names with that row.
      - SArray: stored in a single cell under its key, as a list of its elements.
      - anything else: stored in a single cell under its key.

    This is best effort: if storing a field raises an exception, the string
    "Unable to store field" is assigned to that key instead.

    Returns an empty SFrame if dic is None.
    '''
    if dic is None:
        return _SFrame()

    data = _SFrame()
    for key, value in dic.items():
        try:
            if isinstance(value, dict):
                for inner_key, inner_value in value.items():
                    data[inner_key] = [inner_value]
            elif HAS_PANDAS and isinstance(value, pandas.DataFrame):
                # iterrows yields (index, Series) pairs; keep the Series and convert it to a dict
                inner_value = []
                for i in value.iterrows():
                    inner_value.append(dict(i[1]))
                data[key] = [inner_value]
            elif isinstance(value, _SFrame):
                # Quick fix kept from the original: this materializes the whole SFrame, which is
                # probably a bad idea when the value SFrame is large.
                keys = value.column_names()
                inner_value = []
                for i in value:
                    inner_value.append(dict(zip(keys, i)))
                data[key] = [inner_value]
            elif isinstance(value, _SArray):
                # Wrap the list so the whole array lands in one cell. Assigning the bare list
                # would create one row per element and break the single-row layout.
                data[key] = [list(value)]
            else:
                # A lot more type checking is needed here
                data[key] = [value]
        except Exception:
            # Deliberately best effort, but do not swallow KeyboardInterrupt / SystemExit.
            data[key] = "Unable to store field"
    return data


def _partition_list(input_list, n):
    '''
    Breaks the input_list in n partition. Each partition, except the last one, is of size
    floor(len(input_list) / n). The last partition contains the remainder of the elements.
    '''
    assert(n > 0)
    result = []
    non_last_partition_len =  len(input_list) / n
    for i in range(n-1):
        cur_partition = input_list[ i * non_last_partition_len : (i + 1) * non_last_partition_len]
        result.append(cur_partition)
    result.append(input_list[ (n - 1) * non_last_partition_len : ])
    return result


def _model_parameter_search(task):
    '''
    This is the actual top level function that will be run (possibly remotely) to do the actual work
    of creating and evaluating models with different parameters.

    Reads the following entries of task.params:
      - 'train_set': passed to the SFrame constructor to load the train set.
      - 'test_set' (optional, may be None): passed to the SFrame constructor to load the test set.
      - 'model_factory': called as model_factory(train_set, **parameters) for each combination.
      - 'standard_model_params': parameters added to every combination. On a name clash these
        take precedence over the searched value.
      - 'search_space': an iterable of parameter dictionaries, one model is created for each.

    For each model, all of its fields (plus the output of model.evaluate(test_set) when a test
    set is given) are flattened into one row. The rows are accumulated into a single SFrame that
    is stored in task.outputs['save_path'].
    '''
    params = task.params

    train_set = _SFrame(params['train_set'])
    test_set = None
    if 'test_set' in params and params['test_set'] is not None:
        test_set = _SFrame(params['test_set'])

    model_factory = params['model_factory']

    result_accumulator = _SFrame()
    for cur_params in params['search_space']:

        # Create the model. Merge into a copy so the dictionaries stored in the task's search
        # space are not modified as a side effect.
        model_kwargs = dict(cur_params)
        model_kwargs.update(params['standard_model_params'])
        cur_model = model_factory(train_set, **model_kwargs)

        model_info = _get_all_model_fields(cur_model)

        # Save test info
        if test_set is not None:
            model_info.update(cur_model.evaluate(test_set))

        # Write results
        result_accumulator = result_accumulator.append(_flatten_to_single_row_sframe(model_info))

    task.outputs['save_path'] = result_accumulator

def model_parameter_search(environment, model_factory, train_set, save_path, test_set=None,
                           standard_model_params=None, hyper_params=None, max_num_models='all', name=None):
    '''
    Search for optimal model parameters. Automatically creates models using different parameters.
    Optionally, evaluates these models using a test set.

    model_parameter_search is supported for: :py:class:`~graphlab.linear_regression.LinearRegressionModel`, :py:class:`~graphlab.logistic_regression.LogisticRegressionModel`, :class:`~graphlab.recommender.MatrixFactorizationModel` and :py:class:`~graphlab.svm.SVMModel`, :py:class:`~graphlab.kmeans.KmeansModel`

    Parameters
    ----------
    environment : Environment
        Used to run the job. Its max degree of parallelism determines how many worker tasks the
        search space is split across.

    model_factory : function
        The function used to create the models.

    train_set : str
        Path to an SFrame containing the train set.

    save_path : str
        Path to save the result. Results will be saved as an SFrame. When more than one worker
        task is used, each worker's intermediate result is written to the directory containing
        save_path, under the worker's index ('0', '1', ...).

    test_set : str (optional)
        Path to an SFrame containing the test set. This SFrame should be in the same format as the train set.

    standard_model_params : dict (optional)
        A set of arguments that should be used to create each model. Defaults to no extra
        arguments.

    hyper_params : dict (optional)
        The keys in the dictionary should be strings that correspond to the names of parameters accepted by model_factory.
        The values in the dictionary should be lists. This is the list of values that will be tried for the corresponding parameter.

        If this dictionary is not specified, a set of default parameter combinations, based on the model type, will be tried.

    max_num_models : int or 'all' (optional)
        The max number of models to test. If the search space contains more combinations than
        this, a random sample of this size is used. The default, 'all', tests every combination.

    name : str (optional)
        Name for the Job created. If not specified then will be 'Model-Parameter-Search-(timestamp)'

    Returns
    -------
    out : Job object
        The job is run using the environment parameter.

    Examples
    --------
    If you want to do a parameter search for a recommender model, given a saved SFrame with columns
    ``user_id``, ``item_id`` and ``rating``:

    >>> job = model_parameter_search(self.env, graphlab.recommender.create, train_file_path,
                                      result_file_path, test_set = SFrame_test_file_path,
                                      standard_model_params = {'target_column': 'rating'}
                                      )
    '''
    _get_metric_tracker().track('jobs.model_parameter_search')

    if name is None:
        name = "Model-Parameter-Search-%s" % _time()

    # Use a fresh dictionary per call rather than a shared mutable default argument.
    if standard_model_params is None:
        standard_model_params = {}

    # Determine search space
    if hyper_params is None:
        hyper_params = _get_default_parameter_range(model_factory, standard_model_params)
    search_space = _get_all_parameters_combinations(hyper_params)
    if max_num_models != 'all' and max_num_models < len(search_space):
        search_space = _random_sample(search_space, max_num_models)

    # Partition search space
    degrees_of_parallelism = environment.get_max_degree_of_parallelism()
    search_partitions = _partition_list(search_space, degrees_of_parallelism)

    # Create a worker task for each partition
    workers = []
    intermediate_results_dir = _dirname(save_path)
    intermediate_results = []
    for i in range(degrees_of_parallelism):
        # NOTE(review): '%4d' pads the index with spaces, not zeros ('%04d' may have been
        # intended). Left unchanged because it determines the task names.
        cur_worker = _Task('%s-Task-%4d' % (name, i))
        cur_worker.set_code(_model_parameter_search)

        # All parameters are passed in a single call so the outcome does not depend on whether
        # set_params merges with, or replaces, previously set parameters.
        cur_worker.set_params({
            'train_set': train_set,
            'test_set': test_set,
            'model_factory': model_factory,
            'standard_model_params': standard_model_params,
            'hyper_params': hyper_params,
            'max_num_models': max_num_models,
            'search_space': search_partitions[i],
        })

        work_result_path = _path_join(intermediate_results_dir, str(i))
        intermediate_results.append(work_result_path)
        cur_worker.set_outputs({'save_path': work_result_path})

        workers.append(cur_worker)

    # The pipeline is a list of steps; the first step holds all the workers.
    tasks = [workers]

    if degrees_of_parallelism > 1:
        # Create combiner task: a second step that merges the intermediate results
        combiner = _Task('Combiner')
        combiner_params = {}
        for i, result_path in enumerate(intermediate_results):
            combiner_params['intermediate_results - %d' % i] = result_path
        combiner.set_params(combiner_params)
        combiner.set_code(_append_inputs)
        combiner.set_outputs({'save_path': save_path})
        tasks.append([combiner])
    else:
        # No need for a combiner: the single worker writes straight to save_path
        workers[0].set_outputs({'save_path': save_path})

    pipeline = _Pipeline(name + "-Pipeline")
    pipeline.set_tasks(tasks)
    return _gl.deploy.job.create([pipeline], name=name, environment=environment)
