#!/usr/bin/env python

"""compare.py: Compare data generated by MOOSE and NEURON.

Last modified: Wed May 28, 2014  02:48PM

"""
    
__author__           = "Dilawar Singh"
__copyright__        = "Copyright 2013, NCBS Bangalore"
__credits__          = ["NCBS Bangalore", "Bhalla Lab"]
__license__          = "GPL"
__version__          = "1.0.0"
__maintainer__       = "Dilawar Singh"
__email__            = "dilawars@ncbs.res.in"
__status__           = "Development"

import os
import sys
from collections import defaultdict
import pylab
import numpy as np

# Absolute tolerance used by compareData when checking that two x-axis
# samples are equal.
EPSILON = 1e-10

def findMaxima(y, x, filters=None, **kwargs):
    """Find the strict local maxima (peaks) of a sampled curve.

    A sample ``y[k]`` with ``0 < k < len(y) - 1`` is a peak when it is
    strictly greater than both of its neighbours ``y[k-1]`` and ``y[k+1]``.
    End points are never reported. A flat-topped peak, where two or more
    equal samples form the top, is not detected.

    :param y: sequence of y-values.
    :param x: sequence of x-values aligned with ``y``. The x-position of each
        peak is read from it.
    :param filters: optional iterable of predicates ``f(value) -> bool``.
        A peak is kept only if every predicate accepts its y-value.
    :param kwargs: accepted for interface compatibility and ignored.
    :returns: ``(positions, maxima)``, the x-positions of the accepted peaks
        and their y-values, both lists in order of appearance.
    """
    # Materialise once so a generator of filters is not exhausted after the
    # first peak, and avoid a shared mutable default argument.
    filters = () if filters is None else tuple(filters)
    positions = []
    maxima = []
    for k in range(1, len(y) - 1):
        value = y[k]
        if value > y[k - 1] and value > y[k + 1] and all(f(value) for f in filters):
            positions.append(x[k])
            maxima.append(value)
    return positions, maxima

def findMinima(y, x, filters=None, **kwargs):
    """Find the strict local minima (troughs) of a sampled curve.

    A sample ``y[k]`` with ``0 < k < len(y) - 1`` is a trough when it is
    strictly smaller than both of its neighbours ``y[k-1]`` and ``y[k+1]``.
    End points are never reported. A flat-bottomed trough is not detected.

    :param y: sequence of y-values.
    :param x: sequence of x-values aligned with ``y``. The x-position of each
        trough is read from it, matching ``findMaxima``.
    :param filters: optional iterable of predicates ``f(value) -> bool``.
        A trough is kept only if every predicate accepts its y-value.
    :param kwargs: accepted for interface compatibility and ignored.
    :returns: ``(positions, minima)``, the x-positions of the accepted troughs
        and their y-values, both lists in order of appearance.
    """
    # Materialise once so a generator of filters is not exhausted after the
    # first trough, and avoid a shared mutable default argument.
    filters = () if filters is None else tuple(filters)
    positions = []
    minima = []
    for k in range(1, len(y) - 1):
        value = y[k]
        # Strictly below both neighbours; the previous version used '>' here
        # and so returned maxima.
        if value < y[k - 1] and value < y[k + 1] and all(f(value) for f in filters):
            positions.append(x[k])
            minima.append(value)
    return positions, minima

def compareData(x1, y1, x2, y2, peakThreshold=20):
    """Given two plots (x1, y1) and (x2, y2), check that they share the same
    x-axis and plot the peaks of both y-series for visual comparison.

    :param x1: x-values of the first trace.
    :param y1: y-values of the first trace.
    :param x2: x-values of the second trace.
    :param y2: y-values of the second trace.
    :param peakThreshold: only local maxima strictly above this value are
        plotted. The default of 20 was previously hard-coded; it is
        presumably in mV after ``compare`` rescales the MOOSE data (TODO:
        confirm the units).
    :raises AssertionError: if ``x1`` and ``x2`` differ in length, or if any
        pair of x-values is not within ``EPSILON`` of each other.
    :returns: ``(peaks1, peaks2)``. Each item is the ``(positions, values)``
        pair returned by ``findMaxima`` for that trace.
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    y1 = np.asarray(y1)
    y2 = np.asarray(y2)

    # The x-axes must match sample-for-sample. Raise explicitly rather than
    # via ``assert`` so the check still runs under ``python -O``. The
    # exception type stays AssertionError for existing callers.
    if len(x1) != len(x2):
        raise AssertionError("X axis must have equal no of entries")
    # Use ``~(diff < EPSILON)`` rather than ``diff >= EPSILON`` so that NaN
    # x-values are reported as mismatches too.
    bad = np.nonzero(~(np.absolute(x1 - x2) < EPSILON))[0]
    if bad.size:
        i = bad[0]
        msg = "Value mismatch in x-axis: {}-{} = {}".format(x1[i], x2[i], x1[i] - x2[i])
        raise AssertionError(msg)

    # Overlay the above-threshold peaks of both traces. No numeric
    # y-statistic (e.g. RMS) is computed here.
    pylab.figure()
    aboveThreshold = [lambda v: v > peakThreshold]
    peaks1 = findMaxima(y1, x1, filters=aboveThreshold)
    peaks2 = findMaxima(y2, x2, filters=aboveThreshold)
    pylab.plot(peaks1[0], peaks1[1], '^')
    pylab.plot(peaks2[0], peaks2[1], 'o')
    pylab.show()
    return peaks1, peaks2
                

def compare(mooseData, nrnData, outputFile = None):
    """Compare every MOOSE data column against the matching NEURON column.

    Both arguments are ``(x, columns)`` pairs of the shape returned by
    ``txtToData``: ``x`` is a list of x-values, and ``columns`` maps a column
    key to that column's list of values. The MOOSE x-values and every MOOSE
    column are multiplied by 1e3 before comparison. Columns are paired by
    key, and the result of ``compareData`` for each pair is printed.

    :param mooseData: ``(x, columns)`` read from the MOOSE output file.
    :param nrnData: ``(x, columns)`` read from the NEURON output file.
    :param outputFile: currently unused. The plot-saving code that used it is
        disabled (see the string literal after this function).
    :raises ValueError: if a MOOSE column has no NEURON counterpart.
    """
    mooseX, mooseY = mooseData
    nrnX, nrnY = nrnData
    mooseX = [x * 1e3 for x in mooseX]
    # Build scaled copies instead of overwriting the caller's lists in place.
    scaledMooseY = {key: [1e3 * y for y in values] for key, values in mooseY.items()}
    # Pair columns by key rather than by position in ``dict.values()``, whose
    # order is not guaranteed and which cannot be indexed on Python 3.
    for key in sorted(scaledMooseY):
        # Test membership first: indexing a defaultdict would silently
        # create an empty column.
        if key not in nrnY:
            raise ValueError(
                "NEURON data has no column {} to match MOOSE data".format(key))
        peaks = compareData(mooseX, scaledMooseY[key], nrnX, nrnY[key])
        print(peaks)
# NOTE(review): the string below is disabled plotting code and is never
# executed. It refers to names (mooseX, mooseY, i, outputFile) that exist only
# inside compare(), so it cannot simply be un-quoted at module level.
"""
        pylab.figure()
        pylab.plot(mooseX, mooseY.values()[i])
        pylab.plot(nrnX, nrnY.values()[i])
        if outputFile is None:
            pylab.show()
        else:
            outFile = "{}{}.png".format(outputFile, i)
            print("[INFO] Dumping plot to {}".format(outFile))
            pylab.savefig(outFile)
"""

def txtToData(txt):
    """Parse whitespace-separated numeric text into an x vector and columns.

    Blank lines are skipped. On every other line the first field becomes the
    next x-value. The remaining fields are appended to per-column lists keyed
    0, 1, 2, ... in the order they appear on the line.

    :param txt: the full text to parse, with lines separated by ``"\\n"``.
    :returns: ``(xs, columns)``, where ``xs`` is a list of floats and
        ``columns`` is a ``defaultdict(list)`` mapping column index to a list
        of floats.
    :raises ValueError: if any field cannot be parsed as a float.
    """
    columns = defaultdict(list)
    xs = []
    for rawLine in txt.split("\n"):
        # str.split() with no argument already discards surrounding whitespace.
        fields = rawLine.split()
        if len(fields) == 0:
            continue
        head, tail = fields[0], fields[1:]
        xs.append(float(head))
        for col, token in enumerate(tail):
            columns[col].append(float(token))
    return xs, columns

def main():
    """Command-line entry point.

    Usage: ``compare.py mooseFile nrnFile [outputFile]``.

    Reads both whitespace-separated data files and passes them to
    ``compare``. The values from the MOOSE (first) file are scaled by 1e3
    there. Exits with a usage message if either file argument is missing.
    """
    if len(sys.argv) < 3:
        sys.exit("Usage: {} mooseFile nrnFile [outputFile]".format(sys.argv[0]))
    # compare() rescales the MOOSE data, which comes from the FIRST argument.
    print("[INFO] First file (MOOSE data) will be scaled by a factor of 1e3")
    mooseFile = sys.argv[1]
    nrnFile = sys.argv[2]
    outputFile = sys.argv[3] if len(sys.argv) > 3 else None

    with open(mooseFile, "r") as f:
        mooseData = txtToData(f.read())
    with open(nrnFile, "r") as f:
        nrnData = txtToData(f.read())
    compare(mooseData, nrnData, outputFile)

# Run the command-line comparison only when executed as a script, not when
# this module is imported.
if __name__ == '__main__':
    main()
