#!/usr/bin/env python
"""Machine learning tool"""

import sys
import logging
import argparse
import math
try:
    import cPickle as pickle
except ImportError:
    import pickle
from itertools import izip
from functools import partial
from operator import itemgetter

from mltool.decisiontree import train_decision_tree, impurity_split
from mltool.forest import train_random_forest
from mltool.utils import read_input_file
from mltool.evaluate import evaluate_model


# Module-level logger. Its handler and message format come from the
# logging.basicConfig call made when this file is run as a script.
log = logging.getLogger(__name__)


def print_gain(dataset, gain):
    """Print the gain of every feature to stdout, highest gain first.

    ``gain`` maps a feature index to its gain value. Each index is
    translated into a readable name via ``dataset.feature_names``, and
    the gain itself is shown with ``%r`` (its repr).
    """
    print('Feature gains:')
    ranked = sorted(gain.items(), key=itemgetter(1), reverse=True)
    for feature_idx, feature_gain in ranked:
        name = dataset.feature_names[feature_idx]
        print('  %s\t%r' % (name, feature_gain))
    

def dt_train_command(args):
    """Run the ``dt-train`` sub-command: train a single decision tree.

    Reads the training set from ``args.dataset`` and calls
    ``train_decision_tree`` with ``args.max_depth``, an ``impurity_split``
    configured with ``ff=args.feature_fraction``, and ``args.seed``.
    When ``args.output`` is given, the model is pickled to it with the
    highest available pickle protocol. The per-feature gains are then
    printed.
    """
    log.info('Reading training set...')
    with args.dataset as fin:
        dataset = read_input_file(fin)

    # Count features through feature_names, the same list print_gain
    # indexes by feature. len(dataset.samples) is a per-item container
    # and would report the number of items, not features.
    log.info('Items: %s, Features: %s',
             len(dataset.labels),
             len(dataset.feature_names))

    log.info('Training...')
    split_func = partial(impurity_split, ff=args.feature_fraction)
    model, gain = train_decision_tree(dataset, args.max_depth, split_func, args.seed)

    if args.output:
        with args.output as fout:
            # -1 selects pickle.HIGHEST_PROTOCOL (a binary format).
            pickle.dump(model, fout, -1)

    print_gain(dataset, gain)


def rf_callback_with_eval(num_trees, datasets, ndcg_at=10):
    """Coroutine that reports evaluation scores while a forest grows.

    Usage: prime it with ``next()``, then ``send()`` each new tree.

    For every dataset in ``datasets`` it keeps a running sum of the
    per-tree predictions. It evaluates their mean (sum divided by the
    number of trees received so far) with ``evaluate_preds``, forwarding
    ``ndcg_at``.

    After each tree it prints one tab-separated line to stdout: the tree
    count, then an RMSE/NDCG pair for each dataset.

    Throwing ``StopIteration`` into the coroutine ends it. ``num_trees``
    is used only in the progress log message.
    """
    import numpy as np
    from mltool.predict import predict_all
    from mltool.evaluate import evaluate_preds

    # One accumulator per dataset, updated in place with each tree's output.
    pred_sums = [np.zeros(len(ds.labels)) for ds in datasets]

    n_seen = 0
    while True:
        try:
            tree = yield
        except StopIteration:
            return
        n_seen += 1

        scores = []
        for pred_sum, ds in zip(pred_sums, datasets):
            pred_sum += list(predict_all(tree, ds))
            rmse, ndcg = evaluate_preds(pred_sum / n_seen, ds, ndcg_at)
            scores += [rmse, ndcg]

        print('%d\t%s' % (n_seen, '\t'.join('%f' % s for s in scores)))
        sys.stdout.flush()
        log.info('Trees %d/%d generated.', n_seen, num_trees)


def rf_train_command(args):
    """Run the ``rf-train`` sub-command: train a random forest.

    Loads the training set from ``args.dataset`` and each file in
    ``args.validation``. It then calls ``train_random_forest`` with the
    ``send`` method of an ``rf_callback_with_eval`` coroutine as its
    callback; presumably the callback fires once per generated tree, so
    scores are printed as the forest grows.

    The model is pickled to ``args.output`` when one is given. The
    per-feature gains are printed last.
    """
    log.info('Reading training set...')
    with args.dataset as fin:
        training_set = read_input_file(fin)

    # Datasets to score: the training set first, then each validation file
    # in command-line order.
    scored_sets = [training_set]
    for vfile in args.validation:
        log.info('Reading validation set %s...', vfile.name)
        with vfile as fin:
            scored_sets.append(read_input_file(fin))

    progress = rf_callback_with_eval(args.trees, scored_sets)
    next(progress)  # advance to the first ``yield`` so send() is accepted

    log.info('Training forest...')
    model, gain = train_random_forest(
        training_set, args.trees, args.max_depth, args.feature_fraction,
        args.seed, processors=args.processors, callback=progress.send)

    if args.output:
        with args.output as fout:
            pickle.dump(model, fout, -1)

    print_gain(training_set, gain)


def eval_command(args):
    """Run the ``eval`` sub-command: score a pickled model on a dataset.

    Loads the model from ``args.model`` and the dataset from
    ``args.dataset``. It prints the RMSE and NDCG returned by
    ``evaluate_model``. When ``args.output`` is given, it also writes one
    predicted value per line to that file.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load model files that come from a trusted source.
    with args.model as fin:
        model = pickle.load(fin)

    log.info('Reading dataset...')
    with args.dataset as fin:
        dataset = read_input_file(fin)

    rmse, ndcg, preds = evaluate_model(model, dataset, return_preds=True)
    print('RMSE: %s' % rmse)
    print('NDCG: %s' % ndcg)

    if args.output:
        log.info('Writing predicted labels...')
        with args.output as fout:
            for pred in preds:
                # Same text as ``print >>fout, pred``: str(pred) + newline.
                fout.write('%s\n' % (pred,))


def conv_command(args):
    """Run the ``conv`` sub-command: convert svm-light data to mltool format.

    Parses ``args.input`` with ``read_input_file_svmrank`` and writes the
    resulting dataset to ``args.output`` with ``write_dataset``. Both file
    objects are closed when the command finishes.
    """
    from mltool.utils import read_input_file_svmrank, write_dataset

    with args.input as src:
        converted = read_input_file_svmrank(src)
    with args.output as dst:
        write_dataset(dst, converted)
        

def main():
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    parser_train = subparsers.add_parser('dt-train',
                                         help='Train a decision tree')
    parser_train.add_argument('dataset', type=argparse.FileType('r'),
                              help='Dataset for training')
    parser_train.add_argument('-o', '--output', default=None,
                              type=argparse.FileType('wb'),
                              help='Output model')
    parser_train.add_argument('-f', '--feature-fraction', default=1.0,
                              type=float,
                              help='Fraction of features to use'
                              ' (default: %(default)s)')
    parser_train.add_argument('-s', '--seed', default=1, type=int,
                              help='Seed for random number generator'
                              ' (default: %(default)s)')
    parser_train.add_argument('-d', '--max-depth', type=int, default=1000,
                              help='Max depth (default: %(default)s)')
    parser_train.set_defaults(func=dt_train_command)

    parser_train = subparsers.add_parser('rf-train',
                                         help='Train a random forest model')
    parser_train.add_argument('dataset', type=argparse.FileType('r'),
                              help='Dataset for training')
    parser_train.add_argument('validation', nargs='*',
                              type=argparse.FileType('r'),
                              help='Dataset used for validation')
    parser_train.add_argument('-o', '--output', default=None,
                              type=argparse.FileType('wb'),
                              help='Output model')
    parser_train.add_argument('-f', '--feature-fraction', default=1.0,
                              type=float,
                              help='Fraction of features to use'
                              ' (default: %(default)s)')
    parser_train.add_argument('-s', '--seed', default=1, type=int,
                              help='Seed for random number generator'
                              ' (default: %(default)s)')
    parser_train.add_argument('-d', '--max-depth', type=int, default=1000,
                              help='Max depth (default: %(default)s)')
    parser_train.add_argument('-t', '--trees', type=int, default=1,
                              help='Number of trees to generate'
                              ' (default: %(default)s)')
    parser_train.add_argument('-p', '--processors', type=int, default=None,
                              help='Number of processors to use (default: all)')
    parser_train.set_defaults(func=rf_train_command)

    parser_pred = subparsers.add_parser('eval', help='Evaluate a model')
    parser_pred.add_argument('model', help='Machine learning model to use',
                             type=argparse.FileType('r'))
    parser_pred.add_argument('dataset', type=argparse.FileType('r'),
                             default=sys.stdin, nargs='?',
                             help='Input dataset (default: stdin)')
    parser_pred.add_argument('-o', '--output', default=None,
                             type=argparse.FileType('w'),
                             help='Output file where to write the predictions')
    parser_pred.set_defaults(func=eval_command)

    parser_conv = subparsers.add_parser('conv',
                                        help='Convert file from svm-light '
                                        'format to mltool format')
    parser_conv.add_argument('input', help='Input file', nargs='?',
                             type=argparse.FileType('r'),
                             default=sys.stdin)
    parser_conv.add_argument('output', help='Output file', nargs='?',
                             type=argparse.FileType('w'),
                             default=sys.stdout)
    parser_conv.set_defaults(func=conv_command)

    args = parser.parse_args()

    args.func(args)


if __name__ == '__main__':
    # Script entry point: timestamped INFO-level logging, then run the CLI
    # and use main()'s return value as the process exit status.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
    sys.exit(main())
