import atexit
import codecs
import errno
import glob
import os
import time
import sys
from inspect import getsource
from itertools import chain
from md5 import md5
try:
    import cPickle as pickle
except ImportError:
    import pickle

# All relative paths used by the build are resolved against the directory of
# the script being executed, so switch there right away.
script_path = os.path.dirname(os.path.realpath(sys.argv[0]))
os.chdir(script_path)
# Pickled checksum state lives next to the build script.
checksum_file = os.path.join(script_path, '.musket')

# checksum -> path (or None) recorded by the previous run; filled by
# load_checksums().
OLD_CHECKSUMS = dict()

@atexit.register
def save_meta():
    all_path = set(Fragment._all_checksum.values())
    for checksum, path in OLD_CHECKSUMS.items():
        if path is not None and path not in all_path:
            if not os.path.exists(path):
                continue
            print 'delete %s' % path
            os.remove(path)

    with open(checksum_file, 'w') as fh:
        pickle.dump(Fragment._all_checksum, fh)

def load_checksums():
    global OLD_CHECKSUMS
    if os.path.exists(checksum_file):
        with open(checksum_file) as fh:
            OLD_CHECKSUMS= pickle.load(fh)

load_checksums()


class Fragment(object):
    """Base class for every node tracked by the build cache.

    Subclasses implement ``get_cheksum()`` and ``get_content()``:

    * ``get_cheksum()`` returns either a checksum string or a
      ``(checksum, path)`` tuple.
    * ``get_content()`` is called lazily by the ``content`` property.

    On construction, the fragment's checksum and path are registered in the
    shared ``_all_checksum`` mapping. The fragment is flagged ``expired``
    when its checksum is absent from ``OLD_CHECKSUMS``.
    """

    # checksum -> path (or None) for every fragment created in this process;
    # one dict shared by all subclasses.
    _all_checksum = {}

    def __init__(self, encoding=None):
        self._content = None
        self.encoding = encoding

        reported = self.get_cheksum()
        frag_path = None
        if isinstance(reported, tuple):
            self.checksum, frag_path = reported[0], reported[1]
        else:
            self.checksum = reported

        self._all_checksum[self.checksum] = frag_path
        self.expired = self.is_expired()

    @property
    def content(self):
        """The fragment's content, computed on first access and cached."""
        cached = self._content
        if cached is None:
            cached = self._content = self.get_content()
        return cached

    def is_expired(self):
        """Return True if the previous run did not record this checksum."""
        known = OLD_CHECKSUMS
        return self.checksum not in known

class File(Fragment):

    _all_files = {}

    def __init__(self, path, encoding='utf-8'):
        if path in self._all_files:
            self.__dict__ = self._all_files[path].__dict__
            # Current object is already __init__'ed
            assert self.encoding == encoding, "Encoding mismatch for %s" % path
            return

        self.path = path
        File._all_files[self.path] = self
        try:
            self.timestamp = os.path.getmtime(self.path)
        except os.error:
            self.timestamp = None
        super(File, self).__init__(encoding=encoding)

    def get_cheksum(self):
        return md5(self.path + str(self.timestamp)).hexdigest(), self.path

    def get_content(self):
        if not self._content:
            self._content = self.read()
        return self._content

    def read(self):
        print 'read', self.path
        with codecs.open(self.path, mode='r', encoding=self.encoding) as fh:
            return fh.read()


class FileList(Fragment):
    """An ordered collection of File fragments.

    ``paths`` is either an iterable of paths or a glob pattern string.
    The content is the list of the files' contents, in the same order.
    """

    def __init__(self, paths):
        if isinstance(paths, basestring):
            # glob.glob returns matches in arbitrary order. Sort them so the
            # checksum and the content order are stable between runs.
            paths = sorted(glob.glob(paths))
        self.files = map(File, paths)

        super(FileList, self).__init__()

    def get_cheksum(self):
        """Checksum over the member files' checksums, in list order."""
        m = md5(''.join(f.checksum for f in self.files))
        return m.hexdigest()

    def get_content(self):
        """Return the list of file contents, computed once."""
        # `is None` so an empty list is cached too (matches Fragment.content).
        if self._content is None:
            self._content = [f.content for f in self.files]
        return self._content


class FileDict(Fragment):
    """A mapping of keys to File fragments.

    ``paths`` is either a mapping of key -> path, or a glob pattern string.
    For a pattern, each matched path is used as its own key. The content
    is a dict of key -> file content.
    """

    def __init__(self, paths):
        if isinstance(paths, basestring):
            paths = dict((p, p) for p in glob.glob(paths))
        self.files = dict((k, File(v)) for k, v in paths.iteritems())
        super(FileDict, self).__init__()

    def get_cheksum(self):
        """Checksum over (key, file checksum) pairs, visited in key order.

        Keys are part of the content handed to compilers, so they must be
        part of the checksum. Otherwise renaming a key would not expire
        dependent hunks. Sorting makes the result independent of dict
        iteration order.
        """
        m = md5()
        for key in sorted(self.files):
            m.update(repr(key))
            m.update(self.files[key].checksum)
        return m.hexdigest()

    def get_content(self):
        """Return a dict of key -> file content, computed once."""
        # `is None` so an empty dict is cached too (matches Fragment.content).
        if self._content is None:
            self._content = dict((k, v.content)
                                 for k, v in self.files.iteritems())
        return self._content


class Hunk(Fragment):
    """The lazily computed result of calling a compiler function.

    Fragment arguments (positional or keyword) are replaced by their
    ``content`` when the function runs, and contribute their checksums to
    this hunk's checksum.

    ``meta`` starts as the dict passed in. It is then extended with the
    meta of nested Hunk arguments and with any meta returned by the
    function.
    """

    def __init__(self, compiler, args, kwargs, meta=None):
        self.compiler = compiler
        self.args = args
        self.kwargs = kwargs
        self.meta = meta or {}
        super(Hunk, self).__init__()

    def get_cheksum(self):
        """Checksum the compiler's source plus its Fragment arguments."""
        m = md5(getsource(self.compiler.fn))
        # Visit keyword *values* (iteritems() would yield (key, value)
        # tuples, which are never Fragments), in key order for determinism.
        kw_values = (self.kwargs[key] for key in sorted(self.kwargs))
        for arg in chain(self.args, kw_values):
            if isinstance(arg, Fragment):
                m.update(arg.checksum)
        # NOTE(review): non-Fragment arguments are not hashed, so changing
        # only a literal option does not expire this hunk.
        return m.hexdigest()

    def _resolve(self, arg):
        """Return ``arg.content`` for a Fragment, else ``arg`` unchanged.

        A nested Hunk's meta is merged after its content is computed,
        because computing it can extend that meta.
        """
        value = arg.content if isinstance(arg, Fragment) else arg
        if isinstance(arg, Hunk):
            self.meta.update(arg.meta)
        return value

    def get_content(self):
        """Run the compiler function once and cache its result."""
        if self._content is None:
            args = [self._resolve(arg) for arg in self.args]
            # Build a fresh kwargs dict; self.kwargs keeps the Fragments.
            kwargs = dict((key, self._resolve(arg))
                          for key, arg in self.kwargs.iteritems())

            res = self.compiler.fn(*args, **kwargs)
            if isinstance(res, (tuple, list)):
                # A sequence result is interpreted as (content, meta_update).
                self._content = res[0]
                self.meta.update(res[1])
            else:
                self._content = res

        return self._content

    def save(self, dest, encoding='utf-8'):
        """Write the content to ``dest`` if the output is out of date.

        ``dest`` is a path, or a callable taking this hunk's ``meta`` and
        returning one.
        """
        sink = Sink(self, dest, encoding=encoding)
        if sink.expired:
            sink.save()


class Sink(Fragment):

    def __init__(self, hunk, dest, encoding):
        self.hunk = hunk
        self.dest = dest
        super(Sink, self).__init__(encoding=encoding)

    def get_cheksum(self):
        m = md5(self.hunk.checksum)
        if callable(self.dest):
            m.update(getsource(self.dest))
            path = None
        else:
            m.update(self.dest)
            path = self.dest

        csum = m.hexdigest()
        # Inject last known value of destination path
        path or OLD_CHECKSUMS.get(csum)
        return csum, path

    def save(self):
        data = self.hunk.content

        if callable(self.dest):
            dest = self.dest(self.hunk.meta)
        else:
            dest = self.dest

        Fragment._all_checksum[self.checksum] = dest

        save_dir = os.path.dirname(dest)
        if save_dir:
            try:
                os.makedirs(save_dir)
            except OSError as exc:
                if exc.errno != errno.EEXIST:
                    raise

        print ' -> write %s' % dest
        with codecs.open(dest, mode='w', encoding=self.encoding) as fh:
            fh.write(data)

class Compiler(object):
    """Decorator that turns a function into a factory of lazy Hunks.

    Use as ``@Compiler(*args, **kwargs)``. The first call captures the
    function. Every later call returns a Hunk that will invoke the function
    with:

    * the decorator's positional args prepended to the call's args;
    * the decorator's kwargs merged over the call's kwargs.
    """

    def __init__(self, *args, **kwargs):
        self.args = list(args)
        self.kwargs = kwargs
        self.fn = None

    def __call__(self, *args, **kwargs):
        # The first call is the decoration itself: capture the function.
        if self.fn is None:
            self.fn = args[0]
            return self

        # Record the path of the (last seen) File argument, positional or
        # keyword, in the Hunk's meta. Sink passes that meta to callable
        # destinations. Iterate keyword *values*: iteritems() yields
        # (key, value) tuples, which are never Files.
        meta = {}
        for arg in chain(args, kwargs.itervalues()):
            if isinstance(arg, File):
                meta['path'] = arg.path

        args = self.args + list(args)
        # NOTE: decorator kwargs take precedence over call-time kwargs.
        kwargs.update(self.kwargs)

        return Hunk(self, args, kwargs, meta)
