from MDAnalysis.analysis.align import *
import numpy as np
from scipy import spatial,cluster
from cofasu import aatb, rmsd

class Optimizer():
    '''
    Coco optimizer class. Initialised with a set of trajectory data and,
    optionally, an ideal reference structure, it then provides a method to
    optimise approximate structures generated by the Coco method (or any other).

    Attributes set by the constructor:
        davg   -- mean interatomic distances over all frames, in the
                  condensed order returned by scipy's pdist.
        k      -- force constants 1/var(d) (variances floored at tol*0.1),
                  set to zero where var(d) is greater than tol.
        rblist -- list of rigid groups, each a list of atom indices.
        dups   -- for each atom, the number of rigid groups containing it.
        ideal  -- reference coordinates supplying the rigid-group geometries.
    '''
    def __init__(self, cofasu, ideal=None, tol=0.02, log=None):
        '''
        Initialise the optimizer.

        A collection of snapshots (in the form of a cofasu) is analysed to
        identify rigid subgroups: sets of atoms whose interatomic distances
        are close to invariant across the snapshots. The 'ideal' structure
        provides the subgroup geometries used by optimize().

        Arguments:
            cofasu -- trajectory data; must provide numframes(), coords(i)
                      and natoms.
            ideal  -- optional (natoms, 3) reference structure. If None, the
                      frame of the cofasu whose interatomic distances have
                      the smallest mean squared error from the mean values
                      over the whole dataset is used.
            tol    -- threshold on the variance of an interatomic distance.
                      It is used both for the complete-linkage clustering
                      that identifies the rigid groups and for pruning force
                      constants. Be aware that if tol is set too tight, rigid
                      groups of fewer than three atoms may be generated,
                      which will (currently) break the least-squares fitting
                      step of optimize().
            log    -- optional logger; progress is reported via log.info().
        '''
        n = cofasu.numframes()
        natoms = cofasu.natoms
        if log is not None:
            log.info("calculating distances...")

        # Step 1: accumulate the sum and the sum of squares of every
        # interatomic distance over all frames, giving <d> and var(d).
        d = spatial.distance.pdist(cofasu.coords(0), 'euclidean')
        s = d.copy()  # explicit copy: the accumulator must not alias a frame
        s2 = d * d
        for i in range(1, n):
            d = spatial.distance.pdist(cofasu.coords(i), 'euclidean')
            s += d
            s2 += d * d

        if log is not None:
            log.info("making k matrix...")
        self.davg = s / n
        # var(d) = <d^2> - <d>^2. Floating-point cancellation can make this
        # slightly negative, which is meaningless for a variance: clamp to 0.
        k = np.maximum(s2 / n - self.davg * self.davg, 0.0)

        # Step 2: complete-linkage clustering of the variances, cut at tol,
        # gives non-overlapping clusters of atoms.
        if log is not None:
            log.info("clustering...")
        z = cluster.hierarchy.complete(k)
        rb_id = cluster.hierarchy.fcluster(z, tol, criterion='distance')

        if log is not None:
            log.info("creating rigid groups...")
        # Step 3: expand each cluster into a (possibly overlapping) rigid
        # group containing every atom whose distance variance with *all* of
        # the cluster's members is below tol.
        linked = spatial.distance.squareform(k) < tol
        numrb = rb_id.max()
        self.rblist = []
        self.dups = np.zeros(natoms)
        for i in range(numrb):
            members = np.flatnonzero(rb_id == i + 1)  # fcluster ids start at 1
            in_group = linked[members].all(axis=0)
            self.dups += in_group  # count group memberships per atom
            self.rblist.append(np.flatnonzero(in_group).tolist())

        if log is not None:
            log.info("number of groups: {}".format(len(self.rblist)))

        # Step 4: convert variances into force constants 1/var(d). Variances
        # are floored at tol*0.1 so the inverse stays finite, and force
        # constants for distances whose variance exceeds tol are zeroed.
        self.k = 1.0 / np.maximum(k, tol * 0.1)
        self.k[self.k < 1.0 / tol] = 0.0

        # Step 5: use the given ideal structure, or pick the lowest-energy
        # frame from the cofasu.
        if ideal is not None:
            # asarray so optimize() can use .take() even if a list was given
            self.ideal = np.asarray(ideal, dtype=float)
            emin = self.energy(self.ideal)
        else:
            imin = 0
            emin = self.energy(cofasu.coords(0))
            for i in range(1, n):
                e = self.energy(cofasu.coords(i))
                if e < emin:
                    imin, emin = i, e
            self.ideal = np.asarray(cofasu.coords(imin))

        if log is not None:
            log.info("ideal structure has energy: {}".format(emin))

    def optimize(self, coords, dtol=0.1, maxcycles=100):
        '''
        Optimize a crude structure, provided as a (natoms, 3) numpy array.

        In each iteration, the coordinates of every rigid group are taken
        from the ideal structure and least-squares fitted (via aatb) onto
        the current structure. Each atom is then moved to the mean of its
        fitted positions over all the groups that contain it. Atoms that
        belong to no rigid group keep their current coordinates.

        Iteration stops when the rmsd between the structures from successive
        iterations falls below dtol, or after maxcycles iterations.

        Returns the optimized coordinates as a new array; the input is not
        modified.
        '''
        opt = np.array(coords)  # always a copy, so the caller's data is untouched
        if not np.issubdtype(opt.dtype, np.floating):
            # integer input cannot hold the averaged, fitted coordinates
            opt = opt.astype(float)
        covered = self.dups > 0
        counts = self.dups[covered][:, np.newaxis]
        # the ideal geometry of each group is invariant across iterations
        ideal_groups = [self.ideal.take(group, axis=0) for group in self.rblist]
        testopt = np.zeros_like(opt)

        for _ in range(maxcycles):
            testopt[:] = 0.0
            for group, ideal_group in zip(self.rblist, ideal_groups):
                testopt[group] += aatb(ideal_group, opt.take(group, axis=0))
            # mean over the groups each atom appears in; atoms in no group
            # would otherwise be 0/0
            testopt[covered] /= counts
            testopt[~covered] = opt[~covered]

            r = rmsd(opt, testopt)
            opt, testopt = testopt, opt  # reuse the old buffer next iteration
            if r < dtol:
                break

        return opt

    def energy(self, coords):
        '''
        Return the "energy" of a structure: the mean squared deviation of its
        interatomic distances from the mean distances self.davg.
        The force constants self.k are not applied (unweighted).
        '''
        e = spatial.distance.pdist(coords, 'euclidean') - self.davg
        return (e * e).mean()
