#!/usr/bin/env python
from __future__ import print_function, unicode_literals
import pllpy
import argparse
import json
import sys


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Options are registered in groups that share the same argparse settings:
    plain string options, float-list options, ``--fix_*`` switches (which
    store ``False`` into the corresponding ``opt*`` attribute when given),
    boolean ``store_true`` switches, and the numeric run settings.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args``.
    """
    cli = argparse.ArgumentParser()

    # String-valued options.
    string_options = (
        ('-a', '--alignment'),
        ('-t', '--tree'),
        ('-p', '--partitions'),
        ('-m', '--model_data'),
        ('-o', '--output'),
    )
    for short_flag, long_flag in string_options:
        cli.add_argument(short_flag, long_flag, type=str)

    # Options taking one or more floats, followed by the single-float alpha.
    for short_flag, long_flag in (('-f', '--freqs'), ('-r', '--rates')):
        cli.add_argument(short_flag, long_flag, nargs='+', type=float)
    cli.add_argument('-g', '--alpha', type=float)

    # "--fix_x" clears the matching "optx" attribute (which is True otherwise).
    fix_switches = (
        ('--fix_alpha', 'optalpha'),
        ('--fix_freqs', 'optfreqs'),
        ('--fix_rates', 'optrates'),
        ('--fix_topol', 'opttopol'),
    )
    for flag, target in fix_switches:
        cli.add_argument(flag, dest=target, action='store_false')

    for flag in ('--emp_freqs', '--equal_freqs'):
        cli.add_argument(flag, action='store_true')

    cli.add_argument('--threads', type=int, default=1)
    cli.add_argument('--rns', type=int, default=0xDEADBEEF)
    cli.add_argument('--epsilon', type=float, default=0.01)

    return cli.parse_args()


def read_model_json(json_file):
    """Decode a JSON model description and key its partitions by integer.

    JSON object keys are always strings, so partition indices arrive as
    ``"0"``, ``"1"``, ...; they are converted to ``int`` so callers can
    index ``model['partitions']`` with a partition number.

    Args:
        json_file: JSON text, or an open file-like object (anything with a
            ``read`` method) that yields the JSON text.

    Returns:
        The decoded dict, with ``model['partitions']`` re-keyed by ``int``.

    Raises:
        KeyError: if the document has no ``'partitions'`` entry.
        ValueError: if the text is not valid JSON or a partition key is not
            an integer.
    """
    if hasattr(json_file, 'read'):
        model = json.load(json_file)
    else:
        model = json.loads(json_file)
    # dict.items() exists on both Python 2 and 3; iteritems() is Python 2 only.
    model['partitions'] = {int(k): v for (k, v) in model['partitions'].items()}
    return model


def write_model_json(instance, json_file):
    """Serialise the current state of ``instance`` as a JSON string.

    The result holds the likelihood, the tree, and one entry per partition
    (keyed by partition index) with its alpha, frequencies, model name and
    partition name. A ``'rates'`` entry is included only for partitions for
    which ``instance.is_dna`` is true.

    Args:
        instance: object exposing the ``get_*`` / ``is_dna`` accessors used
            below.
        json_file: not used by this function; nothing is written to disk.

    Returns:
        The JSON text, indented by four spaces.
    """
    summary = {}
    summary['likelihood'] = instance.get_likelihood()
    summary['tree'] = instance.get_tree()

    per_partition = {}
    partition_count = instance.get_number_of_partitions()
    for index in range(partition_count):
        entry = {}
        entry['alpha'] = instance.get_alpha(index)
        entry['frequencies'] = instance.get_frequencies_vector(index)
        # Rates are only reported for DNA partitions.
        if instance.is_dna(index):
            entry['rates'] = instance.get_rates_vector(index)
        entry['model'] = instance.get_model_name(index)
        entry['name'] = instance.get_partition_name(index)
        per_partition[index] = entry

    summary['partitions'] = per_partition
    return json.dumps(summary, indent=4)


def set_model_parameters_from_json(instance, json_file):
    """Load per-partition model parameters from JSON and apply them.

    The JSON is decoded with ``read_model_json``; for every partition of
    ``instance`` the stored ``'alpha'``, ``'frequencies'`` and ``'rates'``
    values (each optional) are passed to ``set_partition_model_parameters``.

    Args:
        instance: object whose partitions are to be configured.
        json_file: forwarded unchanged to ``read_model_json``.

    Raises:
        AssertionError: if the highest partition index in the JSON does not
            match the number of partitions of ``instance``.
        KeyError: if a partition index below that maximum is missing.
    """
    model = read_model_json(json_file)
    pinfo = model['partitions']
    n_partitions = instance.get_number_of_partitions()
    assert max(pinfo) == n_partitions - 1, (
        'model JSON describes partitions up to index {}, '
        'but the instance has {} partition(s)'.format(max(pinfo), n_partitions))
    for i in range(n_partitions):
        curr = pinfo[i]
        alpha = curr.get('alpha')
        freqs = curr.get('frequencies')
        rates = curr.get('rates')
        # The stored values are applied as-is, so neither the empirical nor
        # the equal-frequency override is requested. The callee requires
        # both flags; omitting them raised TypeError on every call.
        set_partition_model_parameters(instance, i, alpha, freqs, rates, False, False)


def set_partition_model_parameters(instance, partition, alpha, freqs, rates,
                                   empirical_freqs=False, equal_freqs=False):
    """Apply alpha, frequencies and rates to a single partition.

    Each of ``alpha``, ``freqs`` and ``rates`` is applied only when it is not
    ``None``. The two flags replace ``freqs`` before it is applied;
    ``empirical_freqs`` takes precedence over ``equal_freqs``.

    Args:
        instance: object exposing ``set_alpha``, ``set_frequencies``,
            ``set_rates``, ``is_dna`` and ``get_empirical_frequencies``.
        partition: index of the partition to configure.
        alpha: alpha value, or ``None`` to leave it untouched.
        freqs: frequency vector, or ``None`` to leave it untouched.
        rates: rate vector, or ``None`` to leave it untouched.
        empirical_freqs: if true, use
            ``instance.get_empirical_frequencies()[partition]`` as ``freqs``.
        equal_freqs: if true (and ``empirical_freqs`` is false), use uniform
            frequencies: four values of 0.25 when ``instance.is_dna`` is true
            for the partition, otherwise twenty values of 0.05.

    Both flags default to ``False`` so that callers which only have explicit
    values (such as ``set_model_parameters_from_json``) can omit them.
    """
    if empirical_freqs:
        freqs = instance.get_empirical_frequencies()[partition]
    elif equal_freqs:
        if instance.is_dna(partition):
            freqs = [0.25] * 4
        else:
            freqs = [0.05] * 20
    # NOTE(review): the trailing ``True`` passed to each setter is defined by
    # the pllpy API and its meaning is not visible in this file - confirm.
    if alpha is not None:
        instance.set_alpha(alpha, partition, True)
    if freqs is not None:
        instance.set_frequencies(freqs, partition, True)
    if rates is not None:
        instance.set_rates(rates, partition, True)


def set_model_parameters(instance, alpha, freqs, rates, empirical_freqs, equal_freqs):
    """Apply the same model parameters to every partition of ``instance``.

    All arguments are forwarded unchanged, together with the partition
    index, to ``set_partition_model_parameters``.
    """
    partition_count = instance.get_number_of_partitions()
    for partition in range(partition_count):
        set_partition_model_parameters(
            instance, partition, alpha, freqs, rates,
            empirical_freqs, equal_freqs)


def set_parameter_optimisation(instance, optalpha, optfreqs, optrates):
    """Set the per-partition "optimisable" flags on ``instance``.

    For every partition the alpha and frequency flags are set to ``optalpha``
    and ``optfreqs``. The rates flag is set to ``optrates`` only for
    partitions for which ``instance.is_dna`` is true.
    """
    partition_count = instance.get_number_of_partitions()
    for index in range(partition_count):
        instance.set_optimisable_alpha(index, optalpha)
        instance.set_optimisable_frequencies(index, optfreqs)
        if not instance.is_dna(index):
            # The rates flag is left untouched for non-DNA partitions.
            continue
        instance.set_optimisable_rates(index, optrates)


def main():
    """Run one optimisation from the command line and print the model JSON.

    Steps: parse the arguments, construct the ``pllpy.pll`` instance, apply
    the requested model parameters and optimisation flags, run the
    optimisation calls, then print the output of ``write_model_json``.

    Returns:
        0, used as the process exit status.
    """
    args = parse_args()

    # The third constructor argument is True for 'parsimony', False for
    # 'random', and otherwise whatever was passed to --tree.
    # NOTE(review): presumably this chooses the kind of starting tree, or
    # names a tree file - confirm against the pllpy API.
    if args.tree in ['random', 'parsimony']:
        start_tree = (args.tree == 'parsimony')
    else:
        start_tree = args.tree
    pll = pllpy.pll(args.alignment, args.partitions, start_tree, args.threads, args.rns)

    set_model_parameters(pll, args.alpha, args.freqs, args.rates, args.emp_freqs, args.equal_freqs)

    # Frequencies chosen via --emp_freqs / --equal_freqs are never optimised.
    if args.emp_freqs or args.equal_freqs:
        args.optfreqs = False
    set_parameter_optimisation(pll, args.optalpha, args.optfreqs, args.optrates)

    if args.opttopol:
        optimise_everything = args.optalpha and args.optfreqs and args.optrates
        pll.optimise(optimise_everything)
    pll.optimise_model()
    # NOTE(review): the meaning of 32 is defined by pllpy and is not visible
    # in this file - confirm.
    pll.optimise_branch_lengths(32)

    report = write_model_json(pll, '')
    print(report)
    return 0


# Script entry point: run main() and use its return value as the exit status.
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)
