"""
Simple minimizer is a wrapper around scipy.optimize.leastsq, allowing a
user to build a fitting model as a function of general-purpose
Fit Parameters that can be fixed or floated, bounded, and written
as a simple expression of other Fit Parameters.  Simulated annealing and
L-BFGS-B engines are also available.

The user sets up a model in terms of an instance of Parameters and writes a
function-to-be-minimized (residual function) in terms of these Parameters.
"""

from numpy import sqrt
from asteval import Interpreter, NameFinder
from astutils import valid_symbol_name

from scipy.optimize import leastsq as scipy_leastsq
from scipy.optimize import anneal as scipy_anneal
from scipy.optimize import fmin_l_bfgs_b as scipy_lbfgsb

from UserDict import DictMixin

class Parameters(dict, DictMixin):
    """a custom dictionary of Parameters.  All keys must be
    strings, and valid Python symbol names, and all values
    must be Parameters (None is also accepted as a placeholder).

    Every insertion -- item assignment, construction, or update() --
    goes through __setitem__, which validates the key and value and
    sets each Parameter's ``name`` to its key.

    Custom methods:
    ---------------

    add()
    add_many()
    """
    def __init__(self, *args, **kwds):
        self.update(*args, **kwds)

    def update(self, *args, **kwds):
        """same signature as dict.update(), but every item is stored via
        __setitem__ so it is validated and named.  The inherited
        dict.update() writes directly into the dict and would skip both.
        """
        for key, value in dict(*args, **kwds).items():
            self[key] = value

    def __setitem__(self, key, value):
        """store a Parameter (or None) under a valid symbol name.

        Raises KeyError for a new key that is not a valid symbol name,
        and ValueError for a value that is neither None nor a Parameter.
        """
        if key not in self:
            if not valid_symbol_name(key):
                raise KeyError("'%s' is not a valid Parameters name" % key)
        if value is not None and not isinstance(value, Parameter):
            raise ValueError("'%s' is not a Parameter" % value)
        dict.__setitem__(self, key, value)
        # None is explicitly allowed above, and has no name to set.
        if value is not None:
            value.name = key

    def add(self, name, value=None, vary=True, expr=None,
            min=None, max=None):
        """convenience function for adding a Parameter:
        with   p = Parameters()
        p.add(name, value=XX, ....)

        is equivalent to
        p[name] = Parameter(name=name, value=XX, ....)
        """
        self.__setitem__(name, Parameter(value=value, name=name, vary=vary,
                                         expr=expr, min=min, max=max))

    def add_many(self, *parlist):
        """convenience function for adding a list of Parameters:
        Here, you must provide a sequence of tuples, each containing:
            name, value, vary, min, max, expr
        with   p = Parameters()
        p.add_many( (name1, val1, True, None, None, None),
                    (name2, val2, True,  0.0, None, None),
                    (name3, val3, False, None, None, None))

        """
        for name, value, vary, min, max, expr in parlist:
            self.add(name, value=value, vary=vary,
                     min=min, max=max, expr=expr)

class Parameter(object):
    """A Parameter is the basic Parameter going
    into Fit Model.  The Parameter holds many attributes:
    value, vary, min, max, name, and an optional constraint
    expression (expr).  After a fit it may also carry stderr
    (estimated uncertainty) and correl (correlations with the
    other varied Parameters, keyed by name).

    Values are stored as given; the minimizer later stores
    value as a float when it evaluates the Parameter.  Extra
    keyword arguments are accepted and ignored.
    """
    def __init__(self, value=None, vary=True, name=None,
                 min=None, max=None, expr=None, **kws):
        self.value = value
        self.min = min
        self.max = max
        self.vary = vary
        self.expr = expr
        # honor the name argument (previously it was always discarded)
        self.name = name
        self.stderr = None
        self.correl = None

    def __repr__(self):
        s = []
        if self.name is not None:
            s.append("'%s'" % self.name)
        val = repr(self.value)
        if self.vary and self.stderr is not None:
            val = "value=%s +/- %.3g" % (repr(self.value), self.stderr)
        elif not self.vary:
            val = "value=%s (fixed)" % (repr(self.value))
        s.append(val)
        s.append("bounds=[%s:%s]" % (repr(self.min),repr(self.max)))
        if self.expr is not None:
            s.append("expr='%s'" % (self.expr))

        return "<Parameter %s>" % ', '.join(s)

class MinimizerException(Exception):
    """General Purpose Exception raised by the minimizer.

    msg -- the error message; it is kept on ``.msg`` and also passed to
           Exception so that ``.args`` is populated and instances can be
           pickled and re-created.
    """
    def __init__(self, msg):
        Exception.__init__(self, msg)
        self.msg = msg

    def __str__(self):
        return "\n%s" % (self.msg)

def check_ast_errors(error):
    """Raise MinimizerException if the expression interpreter reported errors.

    error -- the interpreter's error list (``Interpreter.error``).  Each
             entry's ``get_error()`` result is joined with newlines, and
             the messages for all entries are combined, so no error is lost.
    Returns None when the list is empty.
    """
    if not error:
        return
    msg = ['\n'.join(err.get_error()) for err in error]
    raise MinimizerException('\n'.join(msg))


class Minimizer(object):
    """general minimizer

    Holds a user residual function and a Parameters instance, and runs one
    of several scipy optimizers (leastsq, anneal, fmin_l_bfgs_b) over the
    Parameters that vary.  The residual function is called as
    ``userfcn(params, *fcn_args, **fcn_kws)``; its result is used as an
    array of residuals (leastsq) or squared and summed (anneal, lbfgsb).
    After any fit, the Parameters hold the values at the best point found.
    """
    err_nonparam = "params must be a minimizer.Parameters() instance"
    # formatted with the maxfev limit in leastsq() when scipy reports ier == 5
    err_maxfev   = """Too many function calls (max set to  %i)!  Use:
    minimize(func, params, ...., maxfev=NNN)
or set  leastsq_kws['maxfev']  to increase this maximum."""

    def __init__(self, userfcn, params, fcn_args=None, fcn_kws=None, **kws):
        """
        userfcn  -- residual function, called as
                    userfcn(params, *fcn_args, **fcn_kws)
        params   -- Parameters instance (type-checked in prepare_fit)
        fcn_args -- extra positional arguments for userfcn (default: none)
        fcn_kws  -- extra keyword arguments for userfcn (default: none)
        kws      -- default options merged into each scipy optimizer call
        """
        self.userfcn = userfcn
        self.params = params

        self.userargs = fcn_args
        if self.userargs is None:
            self.userargs = []

        self.userkws = fcn_kws
        if self.userkws is None:
            self.userkws = {}
        self.kws = kws
        self.var_map = []
        self.asteval = Interpreter()
        self.namefinder = NameFinder()

    def __update_paramval(self, name):
        """
        update parameter value, including setting bounds.
        For a constrained parameter (one with an expr defined),
        this first updates (recursively) all parameters on which
        the parameter depends (using the 'deps' field).
        The resulting value is stored as a float on the Parameter and
        in the interpreter's symbol table.
        NOTE(review): a cycle of expression dependencies would recurse
        without bound; nothing here detects it.
        """
        # Has this param already been updated?
        # if this is called as an expression dependency,
        # it may have been!
        if self.updated[name]:
            return

        par = self.params[name]
        val = par.value
        if par.expr is not None:
            for dep in par.deps:
                self.__update_paramval(dep)
            val = self.asteval.interp(par.ast)
            check_ast_errors(self.asteval.error)
        # apply min/max
        if par.min is not None:
            val = max(val, par.min)
        if par.max is not None:
            val = min(val, par.max)

        self.asteval.symtable[name] = par.value = float(val)
        self.updated[name] = True

    def __residual(self, fvars):
        """
        residual function used for least-squares fit.
        With the new, candidate values of fvars (the fitting variables,
        in var_map order), this evaluates all parameters, including
        setting bounds and evaluating constraints, and then passes those
        to the user-supplied function to calculate the residual.
        """
        # set parameter values
        for varname, val in zip(self.var_map, fvars):
            self.params[varname].value = val

        self.updated = dict.fromkeys(self.params, False)
        for name in self.params:
            self.__update_paramval(name)

        return self.userfcn(self.params, *self.userargs, **self.userkws)

    def prepare_fit(self):
        """prepare parameters for fit

        Raises MinimizerException if params is not a Parameters instance.
        Compiles constraint expressions (recording their dependencies and
        forcing vary=False), collects the varying Parameters into
        var_map / vars / vmin / vmax, and evaluates every Parameter once
        from its initial value.
        """

        # determine which parameters are actually variables
        # and which are defined expressions.
        if not isinstance(self.params, Parameters):
            raise MinimizerException(self.err_nonparam)

        self.var_map = []
        self.vars = []
        self.vmin = []
        self.vmax = []
        for name, par in self.params.items():
            if par.expr is not None:
                par.ast = self.asteval.compile(par.expr)
                check_ast_errors(self.asteval.error)
                par.vary = False
                par.deps = []
                self.namefinder.names = []
                self.namefinder.generic_visit(par.ast)
                for symname in self.namefinder.names:
                    if (symname in self.params and
                        symname not in par.deps):
                        par.deps.append(symname)

            elif par.vary:
                self.var_map.append(name)
                self.vars.append(par.value)
                self.vmin.append(par.min)
                self.vmax.append(par.max)

            self.asteval.symtable[name] = par.value
            par.init_value = par.value
            if par.name is None:
                par.name = name

        self.nvarys = len(self.vars)

        # now evaluate make sure initial values
        # are used to set values of the defined expressions.
        # this also acts as a check of expression syntax.
        self.updated = dict.fromkeys(self.params, False)
        for name in self.params:
            self.__update_paramval(name)

    def anneal(self, schedule='cauchy', **kws):
        """
        use simulated annealing (scipy.optimize.anneal) to minimize the
        sum of squared residuals.

        schedule -- 'cauchy' or 'boltzmann'; any other value selects 'fast'.
        kws      -- extra options for scipy.optimize.anneal (override self.kws)

        The raw optimizer output is stored in ``self.sa_out``.
        NOTE(review): scipy.optimize.anneal was removed in SciPy 0.16, and
        unbounded Parameters pass None in upper/lower -- confirm anneal
        accepts that for the SciPy version in use.
        """
        print("Simulated Annealing...")
        sched = 'fast'
        if schedule in ('cauchy', 'boltzmann'):
            sched = schedule

        self.prepare_fit()
        sakws = dict(full_output=1, schedule=sched,
                     maxiter = 2000 * (self.nvarys + 1),
                     upper = self.vmax, lower=self.vmin)

        sakws.update(self.kws)
        sakws.update(kws)

        def penalty(params):
            r = self.__residual(params)
            return (r*r).sum()

        saout = scipy_anneal(penalty, self.vars, **sakws)
        self.sa_out = saout
        # leave the Parameters at the best point found, not the last one
        # tried; with full_output the best point is the first item.
        best = saout[0] if sakws.get('full_output') else saout
        self.__residual(best)
        return

    def lbfgsb(self, **kws):
        """
        use l-bfgs-b minimization (scipy.optimize.fmin_l_bfgs_b) of the
        sum of squared residuals, honoring Parameter min/max as bounds
        (None means unbounded on that side).

        kws -- extra options for fmin_l_bfgs_b (override self.kws)

        Sets ``self.nfev`` and ``self.message`` from the optimizer info.
        """
        self.prepare_fit()
        lb_kws = dict(factr=1000.0, approx_grad=True, m=20,
                      maxfun = 2000 * (self.nvarys + 1),
                      # a real list: fmin_l_bfgs_b needs len(bounds)
                      bounds = list(zip(self.vmin, self.vmax)))
        lb_kws.update(self.kws)
        lb_kws.update(kws)

        def penalty(params):
            r = self.__residual(params)
            return (r*r).sum()

        xout, fout, info = scipy_lbfgsb(penalty, self.vars, **lb_kws)
        # the last evaluation may have been a finite-difference gradient
        # probe; re-evaluate so the Parameters hold the optimum xout.
        self.__residual(xout)

        self.nfev = info['funcalls']
        self.message = info['task']

    def leastsq(self, **kws):
        """
        use Levenberg-Marquardt minimization to perform fit.
        This assumes that the Parameters have been set up and a residual
        function to minimize has been supplied to the Minimizer.

        This wraps scipy.optimize.leastsq, and keyword arguments are passed
        directly as options to scipy.optimize.leastsq

        When possible, this calculates the estimated uncertainties and
        variable correlations from the covariance matrix.

        writes outputs to many internal attributes (residual, ier,
        lmdif_message, message, success, nfev, ndata, chisqr, nfree,
        redchi, errorbars), and returns True if fit was successful,
        False if not.
        """

        self.prepare_fit()
        lskws = dict(full_output=1, xtol=1.e-7, ftol=1.e-7,
                     maxfev= 1000 * (self.nvarys + 1))
        lskws.update(self.kws)
        lskws.update(kws)

        lsout = scipy_leastsq(self.__residual, self.vars, **lskws)
        vbest, cov, infodict, errmsg, ier = lsout

        # leastsq's last call to the residual is not necessarily at vbest
        # (e.g. a rejected trial step), so re-evaluate there to leave every
        # Parameter, including constrained ones, at the best-fit values.
        # This must happen before the compiled expressions are removed below.
        self.__residual(vbest)

        self.residual = resid = infodict['fvec']

        self.ier = ier
        self.lmdif_message = errmsg
        # scipy.optimize.leastsq: ier in 1..4 means a solution was found
        self.success = ier in [1, 2, 3, 4]

        if self.success:
            self.message = 'Fit succeeded.'
        elif ier == 0:
            self.message = 'Invalid Input Parameters.'
        elif ier == 5:
            self.message = self.err_maxfev % lskws['maxfev']
        else:
            self.message = 'Tolerance seems to be too small.'

        self.nfev =  infodict['nfev']
        self.ndata = len(resid)

        sum_sqr = (resid**2).sum()
        self.chisqr = sum_sqr
        self.nfree = (self.ndata - self.nvarys)
        self.redchi = sum_sqr / self.nfree

        for par in self.params.values():
            par.stderr = 0
            par.correl = None
            if hasattr(par, 'ast'):
                delattr(par, 'ast')

        if cov is None:
            self.errorbars = False
            self.message = '%s. Could not estimate error-bars' % self.message
        else:
            self.errorbars = True
            # scale the covariance by the reduced chi-square
            cov = cov * sum_sqr / self.nfree
            for ivar, varname in enumerate(self.var_map):
                par = self.params[varname]
                par.stderr = sqrt(cov[ivar, ivar])
                par.correl = {}
                for jvar, varn2 in enumerate(self.var_map):
                    if jvar != ivar:
                        par.correl[varn2] = (cov[ivar, jvar]/
                                        (par.stderr * sqrt(cov[jvar, jvar])))

        return self.success

def minimize(fcn, params, engine='leastsq', args=None, kws=None, **fit_kws):
    """Create a Minimizer for fcn/params, run the chosen engine, return it.

    fcn     -- residual function, called as fcn(params, *args, **kws)
    params  -- Parameters instance to fit
    engine  -- 'anneal' or 'lbfgsb'; any other value runs leastsq
    args    -- extra positional arguments for fcn
    kws     -- extra keyword arguments for fcn
    fit_kws -- options stored on the Minimizer and forwarded to the
               chosen scipy optimizer

    The returned Minimizer carries the fit results as attributes.
    """
    fitter = Minimizer(fcn, params, fcn_args=args, fcn_kws=kws, **fit_kws)
    run = fitter.leastsq  # fallback when engine matches nothing below
    for label, method in (('anneal', fitter.anneal),
                          ('lbfgsb', fitter.lbfgsb)):
        if engine == label:
            run = method
            break
    run()
    return fitter
