#!/usr/local/bin/python2.7
# encoding: utf-8
'''
hmf (script) provides command-line access to much of the functionality of hmf.

This script takes the given input arguments, runs all combinations of them
through :func:`hmf.tools.get_hmf`, and writes the results to files whose names
start with the specified filename. Only the requested attributes are written
out, which keeps the output to what is needed.

.. note :: at this time, this script is in alpha. It works for the obvious arguments!
'''

import sys
import os
import traceback

import hmf
import numpy as np
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter

# Nothing is exported by ``from <module> import *``; this file is a script.
__all__ = []
# Version of the hmf library; shown in the -V/--version message built in main().
__version__ = hmf.hmf.version
# Creation date, interpolated into the license/description text in main().
__date__ = "2014 - 01 - 23"
# Last-updated date, shown as the build date in the version message.
__updated__ = "2014 - 01 - 23"

# When true: "-h" and "-v" are appended to sys.argv at start-up, and main()
# re-raises exceptions instead of reporting them and returning 2.
DEBUG = 0
# When true: doctest.testmod() is run at start-up, and main() re-raises
# exceptions instead of reporting them and returning 2.
TESTRUN = 0
# When true: main() is run under cProfile, the statistics are written to
# "profile_stats.txt", and the process exits with status 0.
PROFILE = 0

class CLIError(Exception):
    '''Generic exception to raise and log different fatal errors.

    The message passed in is stored, prefixed with ``"E: "``, in ``self.msg``
    and is what ``str()`` / ``unicode()`` of the exception return.
    '''
    def __init__(self, msg):
        # Use the two-argument form of super(): the one-argument form yields
        # an unbound super object, so calling __init__ on it never reaches
        # Exception.__init__.
        super(CLIError, self).__init__(msg)
        self.msg = "E: %s" % msg

    def __str__(self):
        return self.msg

    def __unicode__(self):
        return self.msg

def _build_parser(h, m_attrs, k_attrs, description, version_message):
    '''Build the command-line parser.

    ``h`` is a default-constructed ``hmf.MassFunction``; it is only used to
    quote default values in the help strings. ``m_attrs`` and ``k_attrs`` are
    the attribute names accepted by ``--get``.
    '''
    parser = ArgumentParser(description=description, formatter_class=RawDescriptionHelpFormatter)
    parser.add_argument("-v", "--verbose", dest="verbose", action="count",
                        help="set verbosity level [default: %(default)s]")
    parser.add_argument('-V', '--version', action='version', version=version_message)

    # HMF specific arguments
    config = parser.add_argument_group("Config", "Variables of Configuration")
    config.add_argument("filename", help="filename to write to")
    config.add_argument("--get", nargs="*", default=["M", "dndm"],
                        choices=m_attrs + k_attrs)

    hmfargs = parser.add_argument_group("HMF", "HMF-specific arguments")
    log_m = np.log10(h.M)
    hmfargs.add_argument("--M", nargs=3, type=float,
                         help="the mass range and intervals, min max step [default: %s %s %s]" %
                         (log_m[0], log_m[-1], log_m[1] - log_m[0]))
    hmfargs.add_argument("--mf-fit", nargs="*", choices=hmf.Fits.mf_fits + ["all"],
                         help="fitting function(s) to use. 'all' uses all of them [default: %s]" % h.mf_fit)
    # Each option quotes its own default (delta_h for --delta-h, delta_wrt
    # for --delta-wrt).
    hmfargs.add_argument("--delta-h", nargs="*", type=float,
                         help="overdensity of halo w.r.t delta_wrt [default %s]" % h.delta_h)
    hmfargs.add_argument("--delta-wrt", choices=["mean", "crit"],
                         help="what delta_h is with respect to [default: %s]" % h.delta_wrt)
    # The quotes are part of the format string so that the whole quoted
    # default lands inside the brackets.
    hmfargs.add_argument("--user-fit",
                         help="a custom fitting function defined as a string in terms of x for sigma [default: '%s']" % h.user_fit)
    hmfargs.add_argument("--no-cut-fit", action="store_true",
                         help="whether to cut the fitting function at tested boundaries")
    hmfargs.add_argument("--z2", nargs="*", type=float, help="upper redshift for volume weighting")
    hmfargs.add_argument("--nz", nargs="*", type=float, help="number of redshift bins for volume weighting")
    hmfargs.add_argument("--delta-c", nargs="*", type=float,
                         help="critical overdensity for collapse [default: %s]" % h.delta_c)

    # Transfer-specific arguments
    lnk = h.transfer.lnk
    transferargs = parser.add_argument_group("Transfer", "Transfer-specific arguments")
    transferargs.add_argument("--z", nargs="*", type=float,
                              help="redshift of analysis [default: %s]" % h.transfer.z)
    transferargs.add_argument("--lnk", nargs=3, type=float,
                              help="the wavenumber range and intervals, min max step [default: %s %s %s]" %
                              (lnk[0], lnk[-1], lnk[1] - lnk[0]))
    transferargs.add_argument("--maxk", nargs="*", type=float, default=2e4,
                              help="maximum wavenumber of analysis [default: %s]" % np.exp(lnk[-1]))
    transferargs.add_argument("--numk", nargs="*", type=int, default=250,
                              help="number of wavenumbers in analysis [default: %s]" % len(lnk))
    transferargs.add_argument("--wdm-mass", nargs="*", type=float, help="warm dark matter mass (0 is CDM)")
    transferargs.add_argument("--transfer-fit", nargs="*", choices=hmf.transfer.Transfer.fits + ['all'],
                              help="which fit for the transfer function to use ('all' uses all of them) [default: %s]" % h.transfer.transfer_fit)

    camb_defaults = h.transfer._camb_options
    cambargs = parser.add_argument_group("CAMB", "CAMB-specific arguments")
    cambargs.add_argument("--Scalar-initial-condition", nargs="*", type=int, choices=[1, 2, 3, 4, 5],
                          help="[CAMB] initial scalar perturbation mode [default: %s]" % camb_defaults["Scalar_initial_condition"])
    cambargs.add_argument("--lAccuracyBoost", nargs="*", type=float,
                          help="[CAMB] optional accuracy boost [default: %s]" % camb_defaults["lAccuracyBoost"])
    cambargs.add_argument("--AccuracyBoost", nargs="*", type=float,
                          help="[CAMB] optional accuracy boost [default: %s]" % camb_defaults["AccuracyBoost"])
    cambargs.add_argument("--w-perturb", action="store_true", help="[CAMB] whether w should be perturbed or not")
    cambargs.add_argument("--transfer--k-per-logint", nargs="*", type=float,
                          help="[CAMB] number of estimated wavenumbers per interval [default: %s]" % camb_defaults["transfer__k_per_logint"])
    cambargs.add_argument("--transfer--kmax", nargs='*', type=float,
                          help="[CAMB] maximum wavenumber to estimate [default: %s]" % camb_defaults["transfer__kmax"])
    cambargs.add_argument("--ThreadNum", type=int,
                          help="number of threads to use (0 is automatic detection) [default: %s]" % camb_defaults["ThreadNum"])

    # Cosmo-specific arguments
    cosmoargs = parser.add_argument_group("Cosmology", "Cosmology arguments")
    cosmoargs.add_argument("--default", nargs="*", choices=['planck1_base'],
                           help="base cosmology to use [default: %s]" % h.transfer.cosmo.default)
    cosmoargs.add_argument("--force-flat", action="store_true",
                           help="force cosmology to be flat (changes omega_lambda) [default: %s]" % h.transfer.cosmo.force_flat)
    cosmoargs.add_argument("--sigma-8", nargs="*", type=float, help="mass variance in top-hat spheres with r=8")
    cosmoargs.add_argument("--n", nargs="*", type=float, help="spectral index")
    cosmoargs.add_argument("--w", nargs="*", type=float, help="dark energy equation of state")
    cosmoargs.add_argument("--cs2-lam", nargs="*", type=float, help="constant comoving sound speed of dark energy")

    h_group = cosmoargs.add_mutually_exclusive_group()
    h_group.add_argument("--h", nargs="*", type=float, help="The hubble parameter")
    h_group.add_argument("--H0", nargs="*", type=float, help="The hubble constant")

    omegab_group = cosmoargs.add_mutually_exclusive_group()
    omegab_group.add_argument("--omegab", nargs="*", type=float, help="baryon density")
    omegab_group.add_argument("--omegab-h2", nargs="*", type=float, help="baryon density by h^2")

    omegac_group = cosmoargs.add_mutually_exclusive_group()
    omegac_group.add_argument("--omegac", nargs="*", type=float, help="cdm density")
    omegac_group.add_argument("--omegac-h2", nargs="*", type=float, help="cdm density by h^2")
    omegac_group.add_argument("--omegam", nargs="*", type=float, help="total matter density")

    cosmoargs.add_argument("--omegav", type=float, nargs="*", help="the dark energy density")
    return parser


def _collect_kwargs(args):
    '''Translate the parsed command line into keyword arguments for get_hmf.

    Only options that were actually supplied are included, so anything left
    out falls back to whatever ``hmf.tools.get_hmf`` does by default.
    '''
    # NOTE(review): --maxk, --numk, --wdm-mass, --default, --force-flat and
    # -v/--verbose are parsed but not forwarded here; confirm whether
    # get_hmf should receive them.
    passthrough = ["omegab", "omegab_h2", "omegac", "omegac_h2", "omegam", "h", "H0",
                   "sigma_8", "n", "w", "cs2_lam", "omegav", "ThreadNum", "transfer__kmax",
                   "transfer__k_per_logint", "AccuracyBoost", "lAccuracyBoost",
                   "Scalar_initial_condition", "z", "z2", "nz", "delta_c", "user_fit", "delta_h",
                   "delta_wrt"]
    kwargs = {}
    for name in passthrough:
        value = getattr(args, name)
        if value is not None:
            kwargs[name] = value

    if args.M is not None:
        kwargs["M"] = np.arange(args.M[0], args.M[1], args.M[2])

    if args.mf_fit is not None:
        if "all" in args.mf_fit:
            # Build a new list: removing from hmf.Fits.mf_fits in place would
            # alter the library's own list for the rest of the process.
            kwargs["mf_fit"] = [fit for fit in hmf.Fits.mf_fits if fit != "user_model"]
        else:
            kwargs["mf_fit"] = list(args.mf_fit)

    if args.user_fit is not None:
        # A user fit may be given without --mf-fit; in that case only the
        # user model is run.
        fits = kwargs.setdefault("mf_fit", [])
        if "user_model" not in fits:
            fits.append("user_model")

    if args.no_cut_fit:
        kwargs["cut_fit"] = False

    if args.w_perturb:
        kwargs["w_perturb"] = True

    if args.lnk is not None:
        kwargs["lnk"] = np.arange(args.lnk[0], args.lnk[1], args.lnk[2])

    if args.transfer_fit is not None:
        if 'all' in args.transfer_fit:
            # Same attribute path as the one used for the --transfer-fit choices.
            kwargs["transfer_fit"] = list(hmf.transfer.Transfer.fits)
        else:
            kwargs["transfer_fit"] = list(args.transfer_fit)

    return kwargs


def _stack_columns(result, attrs):
    '''Return a 2-D float array whose i-th column is ``getattr(result, attrs[i])``.

    The number of rows is taken from the attribute values themselves rather
    than from a default-constructed instance, so results computed on a
    user-supplied grid have the right shape. Scalar values are broadcast
    down their column.
    '''
    columns = [np.atleast_1d(getattr(result, name)) for name in attrs]
    nrows = max(len(column) for column in columns)
    table = np.empty((nrows, len(columns)))
    for i, column in enumerate(columns):
        table[:, i] = column
    return table


def main(argv=None):
    '''Generate halo mass functions and write them to file (BETA).

    :param argv: optional sequence of extra command-line arguments. They are
        parsed after the arguments already present in ``sys.argv[1:]``;
        ``sys.argv`` itself is not modified.
    :returns: 0 on success or keyboard interrupt, 2 if an error was reported.
        When ``DEBUG`` or ``TESTRUN`` is set, errors are re-raised instead.

    For each ``(result, label)`` pair yielded by ``hmf.tools.get_hmf`` the
    requested mass-based attributes are written to
    ``<filename>_MDATA_<label>`` and the wavenumber-based attributes to
    ``<filename>_KDATA_<label>``.
    '''
    if argv is None:
        cli_args = sys.argv[1:]
    else:
        cli_args = sys.argv[1:] + list(argv)

    program_name = os.path.basename(sys.argv[0])
    program_version = "v%s" % __version__
    program_build_date = str(__updated__)
    program_version_message = '%%(prog)s %s (%s)' % (program_version, program_build_date)
    # The short description is the first text line of the running script's
    # docstring; tolerate a missing or one-line docstring.
    doc_lines = (__import__('__main__').__doc__ or "").split("\n")
    program_shortdesc = doc_lines[1] if len(doc_lines) > 1 else ""
    program_license = '''%s

  Created by user_name on %s.
  Copyright 2014 organization_name. All rights reserved.

  Licensed under the Apache License 2.0
  http://www.apache.org/licenses/LICENSE-2.0

  Distributed on an "AS IS" basis without warranties
  or conditions of any kind, either express or implied.

USAGE
''' % (program_shortdesc, str(__date__))

    try:
        # Default instance, used only to quote default values in the help text.
        h = hmf.MassFunction()
        m_attrs = ["M", "dndlog10m", "lnsigma", "n_eff", "sigma",
                   "dndm", "ngtm", "fsigma", "mgtm", "nltm", "dndlnm",
                   "how_big", "mltm", "_sigma_0", "_dlnsdlnm"]
        k_attrs = ["power", "delta_k", "lnk", "transfer", "nonlinear_power",
                   "_lnP_0", "_lnP_cdm_0", "_lnT_cdm", "_unnormalised_lnP",
                   "_unnormalised_lnT"]

        parser = _build_parser(h, m_attrs, k_attrs, program_license, program_version_message)
        args = parser.parse_args(cli_args)
        kwargs = _collect_kwargs(args)

        m_att = [a for a in args.get if a in m_attrs]
        k_att = [a for a in args.get if a in k_attrs]

        # Run the hmf and write one file per result and attribute family.
        for res, label in hmf.tools.get_hmf(args.get, **kwargs):
            if m_att:
                np.savetxt(args.filename + "_MDATA_" + label,
                           _stack_columns(res, m_att), header="\t".join(m_att))
            if k_att:
                np.savetxt(args.filename + "_KDATA_" + label,
                           _stack_columns(res, k_att), header="\t".join(k_att))

        return 0
    except KeyboardInterrupt:
        ### handle keyboard interrupt ###
        return 0
    except Exception as e:
        if DEBUG or TESTRUN:
            # Bare raise keeps the original traceback.
            raise
        traceback.print_exc()
        indent = len(program_name) * " "
        sys.stderr.write(program_name + ": " + repr(e) + "\n")
        sys.stderr.write(indent + "  for help use --help\n")
        return 2

# Script entry point: optionally inject debug flags, run doctests, or profile
# main(), then exit with main()'s return code.
if __name__ == "__main__":
    if DEBUG:
        sys.argv.append("-h")
        sys.argv.append("-v")
    if TESTRUN:
        import doctest
        doctest.testmod()
    if PROFILE:
        import cProfile
        import pstats
        profile_filename = 'scripts.hmfrun_profile.txt'
        cProfile.run('main()', profile_filename)
        # pstats writes text to the stream, so open in text mode; the with
        # block closes the file even if printing the statistics fails.
        with open("profile_stats.txt", "w") as statsfile:
            p = pstats.Stats(profile_filename, stream=statsfile)
            stats = p.strip_dirs().sort_stats('cumulative')
            stats.print_stats()
        sys.exit(0)
    sys.exit(main())
