#!/usr/bin/env python
from __future__ import print_function, unicode_literals
import pllpy
import argparse
import json
import sys


def parse_args():
    """
    Use argparse to collect command line flags.

    The --fix_* switches are stored inverted (action='store_false') under the
    matching opt* destination, so optalpha, optfreqs, optrates and opttopol
    are all True unless the corresponding switch is given.

    :return: argparse namespace
    """
    cli = argparse.ArgumentParser()

    # Plain string options, registered in the order they appear in --help.
    string_options = (
        ('-a', '--alignment'),
        ('-t', '--tree'),
        ('-p', '--partitions'),
        ('-m', '--model_data'),
        ('-o', '--output'),
    )
    for short_flag, long_flag in string_options:
        cli.add_argument(short_flag, long_flag, type=str)

    # Model parameter values; freqs and rates each take one or more floats.
    for short_flag, long_flag in (('-f', '--freqs'), ('-r', '--rates')):
        cli.add_argument(short_flag, long_flag, nargs='+', type=float)
    cli.add_argument('-g', '--alpha', type=float)

    # Each --fix_X switch turns off the matching optX setting.
    for target in ('alpha', 'freqs', 'rates', 'topol'):
        cli.add_argument('--fix_' + target, dest='opt' + target, action='store_false')

    for switch in ('--emp_freqs', '--equal_freqs'):
        cli.add_argument(switch, action='store_true')

    cli.add_argument('--threads', type=int, default=1)
    cli.add_argument('--rns', type=int, default=0xDEADBEEF)
    cli.add_argument('--epsilon', type=float, default=0.01)
    cli.add_argument('--brlen_opts', type=int, default=32)
    return cli.parse_args()


def read_model_json(json_file):
    """
    Reads a json format file of model parameters.

    JSON object keys are always strings, so the keys of the 'partitions'
    mapping are converted to int, allowing lookup by partition number.

    :param json_file: file path of the json file to read
    :return: dict constructed from json file, with int keys in model['partitions']
    """
    with open(json_file) as handle:
        model = json.load(handle)

    # dict.items() exists on both Python 2 and 3; dict.iteritems() is
    # Python 2 only and raises AttributeError on Python 3.
    model['partitions'] = {int(k): v for (k, v) in model['partitions'].items()}
    return model


def write_model_json(instance, json_file):
    """
    Collects the tree, likelihood and per-partition parameter values from a
    PLL instance and dumps them as json (indent=4).

    Each partition entry holds 'alpha', 'frequencies' and 'name'; 'rates' is
    included only for partitions where instance.is_dna(i) is true.

    :param instance: PLL instance being summarised
    :param json_file: Either a filepath or a file-like stream (e.g. sys.stdout);
                      anything with a 'write' attribute is treated as a stream
    :return: void
    """
    summary = {
        'tree': instance.get_tree(),
        'likelihood': instance.get_likelihood(),
        'partitions': {},
    }
    for index in range(instance.get_number_of_partitions()):
        entry = {
            'alpha': instance.get_alpha(index),
            'frequencies': instance.get_frequencies_vector(index),
        }
        if instance.is_dna(index):
            entry['rates'] = instance.get_rates_vector(index)
        entry['name'] = instance.get_partition_name(index)
        summary['partitions'][index] = entry

    # A plain path is opened (and closed) here; a stream is written as-is
    # and left open for the caller.
    if not hasattr(json_file, 'write'):
        with open(json_file, 'w') as stream:
            json.dump(summary, stream, indent=4)
        return
    json.dump(summary, json_file, indent=4)


def set_partition_model_parameters(instance, partition, alpha, freqs, rates, empirical_freqs, equal_freqs):
    """
    Applies alpha, state frequencies and substitution rates to one partition.
    Any of the three that ends up as None is left untouched on the instance.

    :param instance: PLL instance being modified
    :param partition: Number of the partition having its parameters set
    :param alpha: Alpha parameter of the 4-category discrete gamma rates distribution
    :param freqs: Equilibrium frequencies of states (4 for DNA, 20 for aa)
    :param rates: Relative substitution rate parameters - values of the upper triangle of 4x4 matrix,
                  so 6 numbers in all. The sixth value must be 1.0. Assume matrix is in "acgt" order.
                  Only applies to DNA data; protein models all use empirical rates.
    :param empirical_freqs: Use empirical estimates for state frequencies. Overwrites 'freqs',
                            and takes precedence over 'equal_freqs'.
    :param equal_freqs: Set all state frequencies to 1/num_states. Overwrites 'freqs'.
    :return: void
    """
    # Resolve which frequencies to apply: empirical beats equal beats explicit.
    if empirical_freqs:
        freqs = instance.get_empirical_frequencies()[partition]
    elif equal_freqs:
        freqs = [0.25] * 4 if instance.is_dna(partition) else [0.05] * 20

    # Setters are looked up by name only when there is a value to apply,
    # and are always called in the order alpha, frequencies, rates.
    pending = (
        ('set_alpha', alpha),
        ('set_frequencies', freqs),
        ('set_rates', rates),
    )
    for setter_name, value in pending:
        if value is not None:
            getattr(instance, setter_name)(value, partition, True)


def set_model_parameters(instance, alpha, freqs, rates, empirical_freqs, equal_freqs):
    """
    Applies the same parameter values to every partition of the instance by
    calling set_partition_model_parameters once per partition. If different
    partitions need different values, use a json model file.

    :param instance: PLL instance being modified
    :param alpha: Alpha parameter of the 4-category discrete gamma rates distribution
    :param freqs: Equilibrium frequencies of states (4 for DNA, 20 for aa)
    :param rates: Relative substitution rate parameters - values of the upper triangle of 4x4 matrix,
                  so 6 numbers in all. The sixth value must be 1.0. Assume matrix is in "acgt" order.
                  Only applies to DNA data; protein models all use empirical rates.
    :param empirical_freqs: Use empirical estimates for state frequencies. Overwrites 'freqs'.
    :param equal_freqs: Set all state frequencies to 1/num_states
    :return: void
    """
    n_partitions = instance.get_number_of_partitions()
    for partition in range(n_partitions):
        set_partition_model_parameters(
            instance, partition,
            alpha=alpha, freqs=freqs, rates=rates,
            empirical_freqs=empirical_freqs, equal_freqs=equal_freqs
        )


def set_parameter_optimisation(instance, optalpha, optfreqs, optrates):
    """
    Flags, for every partition, which parameters PLL may optimise and which
    stay at their current values.

    The rates flag is only set on partitions for which instance.is_dna(i) is
    true; on other partitions it is not touched.

    :param instance: PLL instance being modified
    :param optalpha: (bool) Optimise alpha?
    :param optfreqs: (bool) Optimise equilibrium frequencies?
    :param optrates: (bool) Optimise substitution rates?
    :return: void
    """
    n_partitions = instance.get_number_of_partitions()
    for part in range(n_partitions):
        instance.set_optimisable_alpha(part, optalpha)
        instance.set_optimisable_frequencies(part, optfreqs)
        if not instance.is_dna(part):
            continue
        instance.set_optimisable_rates(part, optrates)


def construct_from_json(json_file, args):
    """
    Builds a new PLL instance from a json model file, and applies parameter settings.

    The starting tree is chosen as follows:
      * args.tree == 'parsimony' -> pllpy.pll is called with True in the tree position
      * args.tree == 'random'    -> pllpy.pll is called with False in the tree position
      * any other non-empty args.tree is passed to pllpy.pll unchanged
      * otherwise the 'tree' entry of the json model is used (converted with
        str() and stripped of trailing whitespace)

    After construction, the alpha, frequencies and rates stored for each
    partition in the model file are applied; entries missing from a
    partition's record are skipped.

    :param json_file: File path to json model file
    :param args: Argparse namespace; alignment, partitions, tree, threads and rns are read
    :return: PLL instance.
    :raises ValueError: if no tree is given on the command line and the model
                        file has no 'tree' entry
    :raises KeyError: if the model file lacks a record for one of the
                      instance's partitions
    """
    model = read_model_json(json_file)

    if args.tree in ('random', 'parsimony'):
        # A boolean in the tree position: True for 'parsimony', False for 'random'.
        tree = (args.tree == 'parsimony')
    elif args.tree:
        tree = args.tree
    else:
        tree = model.get('tree')
        if tree is None:
            # Without this guard str(None) would hand the literal text "None"
            # to pllpy as the tree.
            raise ValueError(
                'No starting tree: pass -t/--tree or add a "tree" entry to {0}'.format(json_file))
        tree = str(tree).rstrip()

    pll = pllpy.pll(args.alignment, args.partitions, tree, args.threads, args.rns)

    p_info = model['partitions']
    for i in range(pll.get_number_of_partitions()):
        record = p_info[i]
        set_partition_model_parameters(
            pll, i, record.get('alpha'), record.get('frequencies'), record.get('rates'),
            False, False)
    return pll


def construct_from_args(args):
    """
    Builds a new PLL instance using the parameters passed on the command line. Suitable for
    simple models, such as those involving a single partition or locus.

    Note: when --emp_freqs or --equal_freqs was given, args.optfreqs is set to
    False on the namespace itself, so the change is visible to the caller.

    :param args: Argparse namespace containing the settings passed on the command line
    :return: PLL instance
    """
    # 'parsimony' and 'random' become a boolean in the tree position; any
    # other value of args.tree is forwarded to pllpy unchanged.
    if args.tree == 'parsimony':
        start = True
    elif args.tree == 'random':
        start = False
    else:
        start = args.tree
    pll = pllpy.pll(args.alignment, args.partitions, start, args.threads, args.rns)

    set_model_parameters(pll, args.alpha, args.freqs, args.rates, args.emp_freqs, args.equal_freqs)

    # Frequencies fixed to empirical or equal values are not optimised.
    if args.emp_freqs or args.equal_freqs:
        args.optfreqs = False

    set_parameter_optimisation(pll, args.optalpha, args.optfreqs, args.optrates)
    return pll


def main():
    """
    Runs the show: parses the command line, builds the PLL instance (from a
    json model file when --model_data is given, otherwise from the command
    line flags), runs the optimisation and writes the resulting model as json
    to --output, or to stdout when no output path is given.

    :return: 0 for success. Failures are not translated into a return value;
             any exception propagates to the caller.
    """
    # The former bare "except: raise" wrapper re-raised everything and made the
    # trailing "return 1" unreachable, so it has been removed.
    args = parse_args()

    if args.model_data:
        pll = construct_from_json(args.model_data, args)
    else:
        pll = construct_from_args(args)

    pll.set_epsilon(args.epsilon)
    set_parameter_optimisation(pll, args.optalpha, args.optfreqs, args.optrates)

    # Return value is unused.
    # NOTE(review): presumably called to force an initial likelihood evaluation - confirm
    pll.get_likelihood()

    if args.opttopol:
        # The flag passed is True only when alpha, frequencies and rates are all optimisable.
        pll.optimise(args.optalpha and args.optfreqs and args.optrates)
    else:
        # Topology fixed: optimise the model, then the branch lengths.
        pll.optimise_model()
        pll.optimise_branch_lengths(args.brlen_opts)

    write_model_json(pll, args.output if args.output else sys.stdout)
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
