# This file is part of Neuroinfo Toolkit.
#
# Neuroinfo Toolkit is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Neuroinfo Toolkit is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Neuroinfo Toolkit.  If not, see <http://www.gnu.org/licenses/>.

import re
import os
import sets
import json
import shlex
import difflib
import mimetypes
import subprocess
from lxml import etree
from neuro.base import Object
from StringIO import StringIO

class XDiff(Object):
    '''
    Extensible diff class
    '''
    
    def __init__(self):
        '''
        Constructor
        '''
        self._output = {}
        self._preproc = []
        
        ## --- default map
        self._map = { "text/plain"          : self.text_diff_enhanced,
                      "application/nifti1"  : XDiff.nifti1_diff,
                      "*"                   : XDiff.binary_diff }
        
        ## --- add MIME types
        mimetypes.add_type("application/nifti1", ".nii", strict=True)
        
    def add_map(self, type, fun):
        '''
        Add a type/function mapping
        
        :param type:     
        :type type: str
        :param fun: function
        '''
        self._map[type] = fun
        
    def add_preproc_rule(self, file_pattern, replace):
        '''
        Text file preprocessing rule
        
        :param file_pattern:
        :type file_pattern: str 
        :param replace:
        :type replace: tuple
        '''
        self._preproc.append({ "fname_pattern"    : file_pattern,
                               "search_pattern"   : replace[0],
                               "replace_content"  : replace[1] })
        
    def run(self, dir_a, dir_b, suppress_diffs=False, skip_patterns=[]):
        '''
        Run diff on two directories
        
        :param dir_a:
        :type dir_a: str
        :param dir_b:
        :type dir_b: str
        :param suppress_diffs:
        :type suppress_diffs: bool
        :param skip_patterns:
        :type skip_patterns: list
        '''
        a_set = sets.Set(XDiff.listdirs(dir_a, skip_patterns=skip_patterns))
        b_set = sets.Set(XDiff.listdirs(dir_b, skip_patterns=skip_patterns))
        
        self._output["unique_a"] = list(a_set.difference(b_set))
        self._output["unique_b"] = list(b_set.difference(a_set))
        
        self._output["difference"] = []
        self._output["summary"] = { "different_files": [] }
        
        ## --- iterate over intersection
        for file in a_set.intersection(b_set):
            file_in_a = os.path.join(dir_a, file)
            file_in_b = os.path.join(dir_b, file)
            
            summary = self._output["summary"]
            
            ## --- initial file diff dictionary
            file_diff = { "file_a"  : os.path.realpath(file_in_a),
                          "file_b"  : os.path.realpath(file_in_b),
                          "differ"  : False,
                          "diff_blob"    : [] }
            
            ## --- get MIME type and apply the associated diff function
            file_diff["mime_type"] = mimetypes.guess_type(file_in_a)[0]
            fun = self._get_diff_fun(file_diff["mime_type"])
            file_diff["comparator"] = fun.__name__
            if(fun):
                diff = fun(file_in_a, file_in_b)
                if(diff[0] == True):
                    file_diff["differ"] = True
                    summary["different_files"].append(file_in_a)
                if(not suppress_diffs):
                    file_diff["diff_blob"] = diff[1]
                
            self._output["difference"].append(file_diff)
            
    def _get_diff_fun(self, type):
        '''
        Get diffing function for type

        :param type:
        :type type: str
        :rtype: function
        '''
        if type not in self._map:
            print "[WARN]: Using default function for type '" + str(type) + "'"
            return self._map["*"]
        else:
            return self._map[type]
        
    def text_diff_enhanced(self, file_a, file_b):
        '''
        Text diff with search/replace
        
        Reads both files, applies every preprocessing rule whose file name
        pattern matches ``file_a`` to the content of *both* files, and then
        diffs the resulting content with text_diff().
        
        :param file_a: path of the first file
        :type file_a: str
        :param file_b: path of the second file
        :type file_b: str
        :returns: (bool, str) as returned by text_diff()
        :rtype: tuple
        '''
        ## --- context managers make sure both handles are closed
        with open(file_a, "rb") as handle:
            content_a = handle.read()
        with open(file_b, "rb") as handle:
            content_b = handle.read()
        
        for rule in self._preproc:
            ## --- a rule is selected by the name of file_a only
            if re.search(rule["fname_pattern"], file_a):
                search_pattern = rule["search_pattern"]
                replace_content = rule["replace_content"]
                content_a = re.sub(search_pattern, replace_content, content_a)
                content_b = re.sub(search_pattern, replace_content, content_b)
                
        return XDiff.text_diff(StringIO(content_a), StringIO(content_b))
    
    @staticmethod
    def text_diff(file_a, file_b):
        '''
        Text diff

        :param file_a:
        :type file_a: str
        :param file_b:
        :type file_b: str
        :returns: (bool, list)
        :rtype: tuple
        '''
        if isinstance(file_a, basestring):
            file_a = open(file_a, "rb")
        if isinstance(file_b, basestring):
            file_b = open(file_b, "rb")
            
        content_a = file_a.read()
        content_b = file_b.read()
        
        diff = ""
        d = difflib.unified_diff(content_a.splitlines(True), content_b.splitlines(True))
        for line in d:
            diff += line
        if len(diff) > 0:
            return (True, diff)
        else:
            return (False, diff)
        
    @staticmethod
    def listdirs(dir, skip_patterns=[]):
        '''
        List all files in a directory, recursively
        
        :param dir:
        :type dir: str
        :param skip_patterns:
        :type skip_patterns: list
        :rtype: list
        '''
        file_list = []
        for root, dirs, files in os.walk(dir):
            for name in files:
                skip = False
                full_file = os.path.join(root, name)
                for pattern in skip_patterns:
                    if re.match(pattern, full_file):
                        skip = True
                        break
                if(not skip):
                    file_list.append(os.path.relpath(full_file, dir))
                        
        return file_list
    
    @staticmethod
    def no_diff(file_a, file_b):
        '''
        Empty diff, returns False all of the time and no diff blob
        
        :param file_a:
        :type file_a: str
        :param file_b:
        :type file_b: str
        :returns: (bool, str)
        :rtype: tuple
        '''
        return (False, "")

    @staticmethod
    def binary_diff(file_a, file_b):
        '''
        Binary file diff
        
        :param file_a:
        :type file_a: str
        :param file_b:
        :type file_b: str
        :returns: (bool, str)
        :rtype: tuple
        '''
        content_a = open(file_a, "rb").read()
        content_b = open(file_b, "rb").read()
        
        diff = ""
        d = difflib.unified_diff(content_a, content_b)
        
        if d:
            return (True, "")
        else:
            return (False, "")
        
    @staticmethod
    def nifti1_diff(file_a, file_b):
        '''
        NIFTI-1 diff
        
        Runs the external ``niftidiff`` command on the two files through
        run_command() and reports its standard output as the diff blob.
        Any non-zero return code, including a negative one, is reported
        as "different" (assumes niftidiff returns 0 only when the files
        match -- confirm against the tool).
        
        :param file_a: path of the first file
        :type file_a: str
        :param file_b: path of the second file
        :type file_b: str
        :returns: (bool, str)
        :rtype: tuple
        '''
        ## --- shlex.quote on Python 3, pipes.quote on Python 2
        try:
            from shlex import quote
        except ImportError:
            from pipes import quote
        
        ## --- run_command() splits the string with shlex, so each path is
        ## --- quoted to survive spaces and other special characters
        comm = "niftidiff " + quote(file_a) + " " + quote(file_b)
        output = XDiff.run_command(comm)
        return (output["return_code"] != 0, output["stdout"])
        
    @staticmethod
    def run_command(command):
        '''
        Execute a shell command
        
        :param command:
        :type command: str
        :returns: {stdout: str, stderr: str, return_code: int}
        :rtype: dict
        '''
        output = {}
        command = shlex.split(command)
        p = subprocess.Popen(command, stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE)
        (out, err) = p.communicate()
        output["stdout"] = out
        output["stderr"] = err
        output["return_code"] = p.returncode
        return output

    def to_json(self, indent=0):
        '''
        Return diff output as JSON
            
        :param indent: Indentation level
        :type indent: int
        :rtype: str
        '''
        return json.dumps(self._output, indent=indent)

    def to_xml(self, as_string=True):
        '''
        Convert the stored output into an XML tree rooted at ``<output>``
        
        Dictionary keys become child element names, list entries become
        ``<item>`` elements, and any other value becomes element text
        through ``str()``.

        :param as_string: return pretty-printed text instead of the element
        :type as_string: bool
        :rtype: str, lxml.Element
        '''
        def _populate(parent, payload):
            if isinstance(payload, dict):
                for tag, child_payload in payload.items():
                    _populate(etree.SubElement(parent, tag), child_payload)
                return
            if isinstance(payload, list):
                for entry in payload:
                    _populate(etree.SubElement(parent, "item"), entry)
                return
            parent.text = str(payload)

        root = etree.Element("output")
        _populate(root, self._output)

        if not as_string:
            return root
        return etree.tostring(root, pretty_print=True)

    def to_dict(self):
        '''
        Return object output as dictionary
        
        :rtype: dict
        '''
        return self._output