#!/home/scox/dev/grayson/venv/bin/python

''' Grayson unit tests. '''
import glob
import json
import logging
import unittest
import xmlrunner
import shutil
import os
import re
import sys
import traceback

from grayson.packager import GraysonPackager
from grayson.compiler.compiler import GraysonCompiler
from grayson.compiler.exception import CycleException
from grayson.executor import Executor
from grayson.common.util import GraysonUtil

from Pegasus.DAX3 import ADAG
from Pegasus.DAX3 import DAX
from Pegasus.DAX3 import Executable
from Pegasus.DAX3 import File
from Pegasus.DAX3 import Job
from Pegasus.DAX3 import Link
from Pegasus.DAX3 import PFN
from Pegasus.DAX3 import Profile
from Pegasus.DAX3 import Transformation
import Pegasus.DAX3

# Module-level logger named after this module.
logger = logging.getLogger (__name__)
# Flag read by the __main__ runner loop, which stops iterating a suite's tests once it is true.
broken = False

# When true, do_test_model writes baseline digests instead of verifying against them.
# The __main__ section switches it on when "--baseline" is the sole command-line argument.
generateBaseline = False

''' Basic test case '''
class GraysonTestCase (unittest.TestCase):
    ''' Shared fixture and path helpers for the Grayson test cases. '''

    def setUp (self):
        ''' Load test/test.conf and set the output and baseline locations. '''
        self.outputDir = "target"
        self.invocations = json.loads (GraysonUtil.readFile ("test/test.conf"))
        # The lambda test takes its data location as an absolute path.
        lambdaDefines = self.invocations ["lambda"]["define"]
        lambdaDefines ["data"] = os.path.realpath (lambdaDefines ["data"])
        self.baselinePath = os.path.join ("test", "baseline")

    def getOutputFile (self, fileName):
        ''' Return the path of fileName inside the output directory. '''
        return os.path.join (self.outputDir, fileName)

    def getSamplePath (self, sample):
        ''' Return the model directory of a sample: samples/<sample>/model.

        sample is either a directory name or a list/tuple of path components.
        '''
        if isinstance (sample, (list, tuple)):
            # Unpack the components; os.path.join does not accept a list as one argument.
            sample = os.path.join (*sample)
        return os.path.join ("samples", sample, "model")

    def getSampleModelPath (self, sample, modelPath):
        ''' Map each modelPath entry to a sample model directory.

        The literal entry "samples" is passed through unchanged; every other
        entry is expanded with getSamplePath. The sample argument is unused.
        '''
        return [ model if model == "samples" else self.getSamplePath (model)
                 for model in modelPath ]

class GraysonCompilerTests (GraysonTestCase):
    ''' Compiles sample models and checks a digest of each generated workflow
    against a stored baseline (or regenerates the baselines). '''

    def setUp (self):
        super(GraysonCompilerTests, self).setUp ()

    def areEqual (self, A, B, skip=None):
        ''' Return True when every key of A, except those in skip, maps to an equal value in B.

        Nested dicts are compared recursively. skip applies to the top level
        only. A key missing from B counts as a mismatch rather than an error.
        '''
        skipped = () if skip is None else skip
        for key in A:
            if key in skipped:
                continue
            if key not in B:
                return False
            value = A [key]
            other = B [key]
            if isinstance (value, dict):
                if not isinstance (other, dict) or not self.areEqual (value, other):
                    return False
            elif value != other:
                return False
        return True

    def output (self, text):
        ''' Write text to stdout and to the module logger. '''
        # Single-argument parenthesised form prints the same on Python 2 and 3.
        print (text)
        logger.info (text)

    def do_compile_model (self, test_name):
        ''' Package and then compile the model configured for test_name.

        The first pass compiles the .graphml model with packaging enabled; the
        second compiles the resulting .grayson package from the output directory.
        Returns the test's configuration entry from self.invocations.
        '''
        test_case = self.invocations [test_name]
        models = [ test_case ["model"] ]
        modelPath = self.getSampleModelPath (test_name, test_case ["modelPath"])
        inputProperties = test_case.get ("define", {})

        banner = """
        ========================================================================
        == P A C K A G E                                                        
        ==             test: (%s)
        ==            model: (%s)
        ==             path: (%s)
        == input-properties: (%s)
        ========================================================================""" % (
            test_name, models, modelPath, inputProperties)

        self.output (banner)
        outputPath = os.path.join (self.outputDir, test_name)
        GraysonCompiler.compile (models = models,
                                 modelPath = modelPath,
                                 outputdir = outputPath,
                                 logLevel = "debug",
                                 modelProperties = inputProperties,
                                 output=models[0].replace (".graphml", ".dax"),
                                 packaging = True)
        banner = """
        ========================================================================
        == C O M P I L E                                                        
        ==             test: (%s)
        ==            model: (%s)
        ==             path: (%s)
        == input-properties: (%s)
        =========================================================================""" % (
            test_name, models, modelPath, inputProperties)

        self.output (banner)
        # The second pass reads the package written into the output directory.
        modelName = models[0].replace (".graphml", ".grayson")
        modelName = os.path.join (outputPath, os.path.basename (modelName))
        models[0] = modelName
        GraysonCompiler.compile (models = models,
                                 outputdir = outputPath,
                                 output=models[0].replace (".grayson", ".dax"),
                                 toLogFile="%s/%s" % (outputPath, "log.txt") )

        return test_case

    def do_baseline_model (self, test_name):
        ''' Compile test_name and store a digest of each verified output file as its baseline. '''
        outputPath = os.path.join (self.outputDir, test_name)
        test_case = self.do_compile_model (test_name)
        if "verify" not in test_case:
            return
        baselineDir = os.path.join (self.baselinePath, test_name)
        for filename in test_case ["verify"]:
            if not os.path.exists (baselineDir):
                os.makedirs (baselineDir)
            digest = self.generate_digest (os.path.join (outputPath, filename))
            # Baselines are later used as regular expressions, so the
            # machine-specific working directory becomes a wildcard.
            digest = digest.replace (os.getcwd (), ".*")
            digest = re.sub (r" value:\<(.*)\>", " value:<...>", digest)
            GraysonUtil.writeFile (os.path.join (baselineDir, filename), digest)

    def do_test_model (self, test_name):
        ''' Compile test_name and verify its outputs against the stored baselines.

        When the module-level generateBaseline flag is set, the baselines are
        regenerated instead.
        '''
        if generateBaseline:
            self.do_baseline_model (test_name)
            return

        # Load baselines for the requested test only. Looping over every
        # invocation with test_name as the loop variable clobbered the argument.
        verify = self.invocations [test_name].get ("verify", {})
        for name in verify:
            baseline = os.path.join (self.baselinePath, test_name, name)
            verify [name] = GraysonUtil.readFile (baseline)

        outputPath = os.path.join (self.outputDir, test_name)
        test_case = self.do_compile_model (test_name)
        if "verify" in test_case:
            verifications = test_case ["verify"]
            for filename in verifications:
                output_workflow = os.path.join (outputPath, filename)
                baseline = verifications [filename]
                self.verify_compiler_output (output_workflow, baseline.rstrip ())

    def bufferlog (self, buffer, text):
        ''' Append one digest line to buffer. '''
        buffer.append (text)

    def _digest_profiles (self, buffer, profiles):
        ''' Append one digest line per profile, in the order given. '''
        for profile in profiles:
            self.bufferlog (buffer, "         profile: namespace:<%s> key:<%s> value:<%s>" % (profile.namespace, profile.key, profile.value))

    def _digest_pfns (self, buffer, pfns):
        ''' Append one digest line per PFN, each followed by that PFN's profiles. '''
        for pfn in pfns:
            self.bufferlog (buffer, "      pfn: <%s>" % str(pfn.url))
            self._digest_profiles (buffer, pfn.profiles)

    def generate_digest (self, output_workflow):
        ''' Parse the DAX file output_workflow and return a text digest of it.

        The digest lists files, jobs, transformations, executables and
        dependencies, each section sorted so the text is stable across runs.
        '''
        buffer = []
        dag = Pegasus.DAX3.parse (output_workflow)
        self.bufferlog (buffer, "dax name: <%s>" % dag.name)

        self.bufferlog (buffer, "files:")
        for entry in sorted (dag.files, key=lambda f: f.name):
            self.bufferlog (buffer, "   file: <%s>" % str(entry.name))
            self._digest_pfns (buffer, entry.pfns)

        self.bufferlog (buffer, "jobs:")
        njobs = []
        for job in dag.jobs:
            # Entries that are not Job objects are resolved through the DAG.
            if not isinstance (job, Job):
                job = dag.getJob (job)
            njobs.append (job)
        jobs = sorted (njobs, key=lambda job: job.file.name if isinstance (job, DAX) else job.name)
        for job in jobs:
            if isinstance (job, DAX):
                self.bufferlog (buffer, "   dax-job: <%s>" % job.file.name)
            else:
                self.bufferlog (buffer, "   job: <%s>, id: <%s> namespace: <%s> version: <%s>" % ( job.name, job.id, job.namespace, job.version))
            for used in sorted (job.used, key=lambda u: u.name):
                self.bufferlog (buffer, "      usedfile: <%s>" % used.name)
            profiles = sorted (job.profiles, key=lambda p: "%s-%s-%s" % (p.namespace, p.key, p.value))
            self._digest_profiles (buffer, profiles)

        self.bufferlog (buffer, "transformations:")
        for trans in sorted (dag.transformations, key=lambda trans: trans.name):
            self.bufferlog (buffer, "   transformation: <%s> namespace: <%s>, version: <%s>" % (trans.name, trans.namespace, trans.version))
            for use in trans.used:
                self.bufferlog (buffer, "       uses: <%s>" % use.name)

        self.bufferlog (buffer, "executables:")
        for exe in sorted (dag.executables, key=lambda ex: ex.name):
            self.bufferlog (buffer, "   executable: <%s> namespace: <%s>, version: <%s>, arch: <%s>, os: <%s>" % ( exe.name, exe.namespace, exe.version, exe.arch, exe.os))
            self._digest_pfns (buffer, exe.pfns)

        self.bufferlog (buffer, "dependencies:")
        for dep in sorted (dag.dependencies, key=lambda dep: dep.child):
            child = dag.getJob (dep.child)
            self.bufferlog (buffer, "   dependency: <%s>" % child.id)
            parent = dag.getJob (dep.parent)
            name = parent.file.name if isinstance (parent, DAX) else parent.name
            self.bufferlog (buffer, "      parent: parent=><%s> edge_label=><%s>" % ( name, "None"))
        return '\n'.join (buffer).rstrip ()

    def verify_compiler_output (self, output_workflow, baseline=None):
        ''' Fail the test unless the digest of output_workflow matches baseline.

        baseline is applied as a regular expression with re.search. A missing
        baseline (None) is treated as a mismatch. On mismatch the module-level
        broken flag is set, a report is written and the test fails.
        '''
        # Without this declaration the assignment below only creates a local.
        global broken
        actual = self.generate_digest (output_workflow)
        if baseline is not None and re.search (baseline, actual):
            matched = "== ** (%s) : output matched baseline expectations" % output_workflow
            self.output (matched)
            return

        broken = True
        error = """
=========================================================================================================
== E_R_R_O_R
=========================================================================================================
== GOT:
%s
== EXPECTED:
%s
=========================================================================================================
== E_R_R_O_R
=========================================================================================================
""" % (actual, baseline)
        self.output (error)
        # Report through unittest instead of exit (1), so the failure is
        # recorded and later tests can still be reported.
        self.fail ('Expected Baseline Output: \n[%s] \nReceived Actual Output: \n[%s]' % (baseline, actual))

    def tearDown (self):
        pass

class GraysonAlphaTest (GraysonCompilerTests):
    ''' Compiler test for the "alpha" sample model. '''

    def setUp (self):
        ''' Use the shared compiler-test fixture unchanged. '''
        super (GraysonAlphaTest, self).setUp ()

    def test_alpha (self):
        # Runs the "alpha" model through do_test_model.
        sample = "alpha"
        self.do_test_model (sample)

    def baseline_alpha (self):
        ''' Run do_baseline_model for the "alpha" model. '''
        sample = "alpha"
        self.do_baseline_model (sample)

class GraysonNucleosomeTest (GraysonCompilerTests):
    ''' Compiler test for the "nucleosome" sample model. '''

    def setUp (self):
        super (GraysonNucleosomeTest, self).setUp ()

    def test_nucleosome (self):
        # Runs the "nucleosome" model through do_test_model.
        self.do_test_model ("nucleosome")

    def baseline_nucleosome (self):
        ''' Run do_baseline_model for the "nucleosome" model. '''
        self.do_baseline_model ("nucleosome")

    # The method was originally named baseline_alpha (copy-paste); keep that
    # name working for any existing caller.
    baseline_alpha = baseline_nucleosome

class GraysonBlackdiamondTest (GraysonCompilerTests):
    ''' Compiler test for the "blackdiamond" sample model. '''

    def setUp (self):
        super (GraysonBlackdiamondTest, self).setUp ()

    def test_blackdiamond (self):
        # Runs the "blackdiamond" model through do_test_model.
        self.do_test_model ("blackdiamond")

    def baseline_blackdiamond (self):
        ''' Run do_baseline_model for the "blackdiamond" model. '''
        self.do_baseline_model ("blackdiamond")

    # The method was originally named baseline_alpha (copy-paste); keep that
    # name working for any existing caller.
    baseline_alpha = baseline_blackdiamond

class GraysonLambdaTest (GraysonCompilerTests):
    ''' Compiler test for the "lambda" sample model. '''

    def setUp (self):
        super (GraysonLambdaTest, self).setUp ()

    def test_lambda (self):
        # Runs the "lambda" model through do_test_model.
        self.do_test_model ("lambda")

    def baseline_lambda (self):
        ''' Run do_baseline_model for the "lambda" model. '''
        self.do_baseline_model ("lambda")

    # The method was originally named baseline_alpha (copy-paste); keep that
    # name working for any existing caller.
    baseline_alpha = baseline_lambda

class GraysonScanTest (GraysonCompilerTests):
    ''' Compiler test for the "scan" sample model, including the generated
    map operator configuration (map.grayconf). '''

    def setUp (self):
        super (GraysonScanTest, self).setUp ()

    def test_scan (self):
        # Runs the "scan" model through do_test_model, then checks the
        # generated map.grayconf against the expected configuration.
        test_name = "scan"
        self.do_test_model (test_name)
        config = os.path.join (self.outputDir, test_name, "map.grayconf")
        generatedObj = json.loads (GraysonUtil.readFile (config))
        expectedObj = {
            "outputName": "scan-flow", 
            "modelPath": "target/scan:target/scan:", 
            "operator": {
                "index": "chunk", 
                "version": "1.0", 
                "variable": "file", 
                "input": "fasta-chunks.tar.gz", 
                "flow": "scan-flow.graphml", 
                "namespace": "scan-flow", 
                "type": "dynamicMap", 
                "method": "tar"
                },
            # graysonHome and outputDir are machine dependent; compareConfs
            # does not compare them.
            "graysonHome": "/home/scox/dev/grayson", 
            "appHome": "target/scan", 
            "outputDir": "/home/scox/dev/grayson/target/scan", 
            "contextModels": "scan-test-core-context.graphml:scan-executable-context.graphml", 
            "sites": "local"
            }
        self.compareConfs (expectedObj, generatedObj)

        logger.info ("Verified generated map operator configuration: \n%s", json.dumps (generatedObj, indent=4))

    def compareConfs (self, L, R):
        ''' Assert that generated configuration R matches expected configuration L.

        Only the machine-independent top-level fields and the fields of the
        nested "operator" object are compared.
        '''
        fields = {
            "top" : [
                "outputName",
                "modelPath",
                #"graysonHome", machine dependent
                "appHome",
                #"outputDir", machine dependent
                "sites" ],
            "operator" : [
                "index",
                "version",
                "variable",
                "input",
                "flow",
                "namespace",
                "type",
                "method" ]
        }
        self.compareObj (L, R, fields ['top'])
        # The operator fields live in the nested "operator" dicts, not at the top level.
        self.assertIn ("operator", R)
        self.compareObj (L ['operator'], R ['operator'], fields ['operator'])

    def compareObj (self, L, R, fields):
        ''' Assert that L and R hold equal values for each name in fields. '''
        for field in fields:
            logger.info ("comparing map operator field: %s", field)
            self.assertIn (field, R)
            self.assertEqual (L [field], R [field],
                              "map operator field <%s>: expected <%s>, got <%s>" % (field, L [field], R [field]))

class GraysonNAMDTest (GraysonCompilerTests):
    ''' Compiler test for the "namd" sample model. '''

    def setUp (self):
        super (GraysonNAMDTest, self).setUp ()

    def test_namd (self):
        # Runs the "namd" model through do_test_model.
        self.do_test_model ("namd")

    def baseline_namd (self):
        ''' Run do_baseline_model for the "namd" model. '''
        self.do_baseline_model ("namd")

    # The method was originally named baseline_alpha (copy-paste); keep that
    # name working for any existing caller.
    baseline_alpha = baseline_namd


class GraysonCycleDetectionTests (GraysonTestCase):
    ''' Checks that compiling the detect-cycle model raises CycleException. '''

    def setUp (self):
        super (GraysonCycleDetectionTests, self).setUp ()

    def test_cycle_detection (self):
        # Compiling detect-cycle.graphml is expected to raise CycleException.
        banner = """
====================================
== C Y C L E   D E T E C T I O N  ==
===================================="""
        logger.info (banner)
        # Single-argument parenthesised form prints the same on Python 2 and 3.
        print (banner)
        testName = "detect-cycle"
        model = "%s.graphml" % testName
        outputPath = os.path.join (self.outputDir, testName)
        with self.assertRaises (CycleException):
            GraysonCompiler.compile (models    = [ model ],
                                     modelPath = [ self.getSamplePath ("test") ],
                                     outputdir = outputPath,
                                     output    = model.replace (".graphml", ".dax"))

    def baseline_alpha (self):
        ''' Not supported for this test case.

        This class derives from GraysonTestCase, which has no
        do_baseline_model, so the original call could only fail with an
        AttributeError. Raise an explicit error instead.
        '''
        raise NotImplementedError ("detect-cycle has no baseline: do_baseline_model is not available on this test case")

                           
class GraysonUITests (GraysonTestCase):
    ''' Runs the headless QUnit UI tests through PhantomJS when "--ui" is given. '''

    def setUp (self):
        super(GraysonUITests, self).setUp ()

    def test_ui (self):
        # Does nothing beyond the first print unless the first command-line
        # argument is "--ui". GRAYSON_HOME and PHANTOMJS_HOME are read from
        # the environment.
        argv = sys.argv
        print ("+======================================== %s" % len (argv))
        if len (argv) <= 1 or argv[1] != '--ui':
            return

        context = {
            'GRAYSON_HOME' : os.environ ['GRAYSON_HOME'],
            'PHANTOMJS_HOME' : os.environ ['PHANTOMJS_HOME']
            }

        def processor (line):
            # Echo each line handed over by the executor.
            print (" ||- %s" % line.rstrip ())

        executor = Executor (context)
        try:
            command = "$PHANTOMJS_HOME/bin/phantomjs $GRAYSON_HOME/web/graysonapp/static/js/test/phantomjs-qunit-headless.js"
            if len (argv) == 5:
                # Exactly three extra arguments are appended to the command.
                command = " ".join ([ command, argv [2], argv [3], argv [4] ])
            executor.execute (command   = command,
                              pipe      = True,
                              processor = processor)
        except ValueError as e:
            traceback.print_exc ()

if __name__ == '__main__':

    # "--baseline" as the only argument switches the run to baseline generation.
    if len (sys.argv) == 2 and sys.argv [1] == "--baseline":
        generateBaseline = True

    testClasses = [
        GraysonCycleDetectionTests,
        GraysonAlphaTest,
        GraysonBlackdiamondTest,
        GraysonNucleosomeTest,
        GraysonLambdaTest,
        GraysonNAMDTest,
        GraysonScanTest,
        GraysonUITests 
        ]
    loader = unittest.TestLoader ()
    logger.info ("Initializing test classes.")
    suite = [ loader.loadTestsFromTestCase (testClass) for testClass in testClasses ]

    for test in suite:
        # Each suite runs once on the console, then each of its tests runs
        # again through xmlrunner into its own TEST-<name>.xml report.
        unittest.TextTestRunner (verbosity=2).run (test)
        for t in test:
            try:
                name = str (t).split (' ')[0]
                with open ("TEST-%s.xml" % name, "w") as output:
                    runner = xmlrunner.XMLTestRunner (output).run (t)
                    if broken:
                        break
            except IOError as e:
                logger.error (e)

