import atexit
import codecs
import errno
import glob
import os
import time
import sys
from itertools import chain
try:
    import cPickle as pickle
except ImportError:
    import pickle

# Build metadata is persisted in a hidden file next to the running script.
script_path = os.path.split(os.path.realpath(sys.argv[0]))[0]
meta_file = os.path.join(script_path, '.musket')
# In-memory metadata: source path -> {'timestamp': ..., 'sinks': set(...)}.
META = dict()

@atexit.register
def save_meta():
    """Persist build metadata to ``meta_file`` when the interpreter exits.

    Before writing, deletes output files (sinks) that are considered stale:
    sinks of sources that were not instantiated in this run, and sinks that
    an expired source previously had but did not produce again -- except
    those that some expired source did produce in this run.
    """
    stale = set()

    # Sources recorded last time but not instantiated in this run are
    # considered gone: forget them and schedule their sinks for deletion.
    for src in list(META):
        if src not in File._all_files:
            stale |= META.pop(src)['sinks']

    # Record the timestamp and sink set of every source seen in this run.
    for path, f in File._all_files.iteritems():
        known = META[path]['sinks'] if path in META else set()
        if f.expired:
            # Rebuilt source: recorded sinks it did not produce again are stale.
            stale |= known - f.sinks
            known = f.sinks
        else:
            # Untouched source: keep the old sinks and add any new ones.
            known |= f.sinks
        META[path] = {'timestamp': f.timestamp, 'sinks': known}

    # Never delete a sink that an expired source (re)generated in this run.
    for f in File._all_files.itervalues():
        if f.expired:
            stale -= f.sinks

    remove(stale)

    with open(meta_file, 'w') as fh:
        pickle.dump(META, fh)

def load_meta():
    """Merge metadata pickled by a previous run into ``META``.

    Does nothing when ``meta_file`` does not exist (e.g. on the first run).
    """
    if not os.path.exists(meta_file):
        return
    with open(meta_file) as fh:
        previous = pickle.load(fh)
    META.update(previous)

# Load the metadata of the previous run at import time so that
# File.is_expired() can compare current modification times with recorded ones.
load_meta()


def remove(paths):
    """Delete every file in *paths*, announcing each deletion on stdout.

    Paths that do not exist are skipped silently; any other OS error
    (e.g. permission denied, path is a directory) is propagated.

    :param paths: iterable of file paths to delete.
    """
    for p in paths:
        # EAFP instead of os.path.exists(): avoids a check-then-act race and
        # also removes dangling symlinks, for which exists() returns False.
        try:
            os.remove(p)
        except OSError as exc:
            if exc.errno == errno.ENOENT:
                continue
            raise
        print("delete %s" % p)


class Fragment(object):
    """Base class for lazily evaluated build inputs.

    Subclasses provide ``is_expired()`` and ``get_content()``.  The expiry
    state is evaluated once, when the fragment is constructed; the content
    is computed on first access and cached.
    """

    def __init__(self, encoding=None):
        self.encoding = encoding
        self._content = None
        # Evaluated last: subclasses set the attributes their is_expired()
        # reads before delegating to this initialiser.
        self.expired = self.is_expired()

    @property
    def content(self):
        """Return the fragment's content, computing it on first access."""
        cached = self._content
        if cached is None:
            cached = self._content = self.get_content()
        return cached

class File(Fragment):
    """A single source file on disk, interned by path.

    Constructing ``File(path)`` more than once returns the same object, so
    per-file state (expiry, sinks, cached content) is shared by every user.
    """

    # Registry of every File created in this run, keyed by path.
    _all_files = {}

    def __new__(cls, path, encoding='utf-8'):
        # Reuse the existing instance for this path; Python still calls
        # __init__ on it, which detects this case and returns early.
        if path in cls._all_files:
            return cls._all_files[path]
        # object.__new__ must not receive the constructor arguments:
        # Python 3 raises TypeError, Python 2 emits a DeprecationWarning.
        return super(File, cls).__new__(cls)

    def __init__(self, path, encoding='utf-8'):
        if path in self._all_files:
            # Current object is already __init__'ed
            assert self.encoding == encoding, "Encoding mismatch for %s" % path
            return
        self.path = path
        # Output paths generated (at least partly) from this file in this run.
        self.sinks = set()
        File._all_files[self.path] = self
        try:
            self.timestamp = os.path.getmtime(self.path)
        except os.error:
            # Missing or inaccessible file: no modification time.
            self.timestamp = None
        super(File, self).__init__(encoding=encoding)

    def is_expired(self):
        """Return True if the file must be treated as changed since last run.

        A file is expired when no metadata was recorded for it, or when its
        modification time differs from the recorded one.  ``!=`` (rather
        than ``>``) also catches a file replaced by an older copy and never
        orders a ``None`` timestamp against a number.
        """
        if self.path not in META:
            return True
        return self.timestamp != META[self.path]['timestamp']

    def get_content(self):
        """Return the decoded file text, reading the file at most once."""
        # ``is None`` rather than truthiness so empty files are not re-read.
        if self._content is None:
            self._content = self.read()
        return self._content

    def read(self):
        """Read and decode the whole file using ``self.encoding``."""
        print('read %s' % self.path)
        with codecs.open(self.path, mode='r', encoding=self.encoding) as fh:
            return fh.read()

    def add_sink(self, dest):
        """Record that output *dest* was generated from this file."""
        self.sinks.add(dest)

class FileList(Fragment):
    """An ordered collection of Files whose content is a list of file texts."""

    def __init__(self, paths):
        """
        :param paths: a glob pattern (string) or an iterable of file paths.
        """
        if isinstance(paths, basestring):
            # glob.glob returns matches in arbitrary (filesystem) order; sort
            # them so the resulting content order is reproducible.
            paths = sorted(glob.glob(paths))
        # A real list (not a lazy map object): it is iterated several times.
        self.files = [File(p) for p in paths]

        super(FileList, self).__init__()

    def is_expired(self):
        """Return True if any member file has changed."""
        return any(f.expired for f in self.files)

    def get_content(self):
        """Return the member file contents, in ``self.files`` order."""
        # ``is None`` so an empty list is not recomputed on every access.
        if self._content is None:
            self._content = [f.content for f in self.files]
        return self._content

    def add_sink(self, dest):
        """Register *dest* as an output of every member file."""
        for f in self.files:
            f.add_sink(dest)


class FileDict(Fragment):
    """A mapping of keys to Files whose content is a dict of file texts."""

    def __init__(self, paths):
        """
        :param paths: a glob pattern (each match is keyed by its own path)
            or a mapping of key -> file path.
        """
        if isinstance(paths, basestring):
            paths = dict((p, p) for p in glob.glob(paths))
        self.files = dict((k, File(v)) for k, v in paths.iteritems())
        super(FileDict, self).__init__()

    def is_expired(self):
        """Return True if any member file has changed."""
        return any(f.expired for f in self.files.itervalues())

    def get_content(self):
        """Return a dict mapping each key to its file's content."""
        # ``is None`` so an empty dict is not recomputed on every access.
        if self._content is None:
            self._content = dict((k, v.content)
                                 for k, v in self.files.iteritems())
        return self._content

    def add_sink(self, dest):
        """Register *dest* as an output of every member file."""
        for f in self.files.itervalues():
            f.add_sink(dest)


class Hunk(Fragment):
    """The lazy result of applying a Compiler's function to its arguments.

    Positional and keyword arguments that are Fragments are replaced by
    their content when the hunk is evaluated; other arguments are passed
    through unchanged.  ``meta`` collects metadata returned by the compiler
    function and by nested hunks.
    """

    def __init__(self, compiler, args, kwargs, meta=None):
        self.compiler = compiler
        self.args = args
        self.kwargs = kwargs
        self.meta = meta or {}
        super(Hunk, self).__init__()

    def _fragments(self):
        """Yield every Fragment among the positional and keyword arguments."""
        # Keyword arguments are checked by value (iteritems would yield
        # (key, value) tuples, which are never Fragments).
        for arg in chain(self.args, self.kwargs.itervalues()):
            if isinstance(arg, Fragment):
                yield arg

    def is_expired(self):
        """Return True if any Fragment argument, positional or keyword, expired."""
        return any(f.expired for f in self._fragments())

    def _resolve(self, arg):
        """Return the value to pass to the compiler function for *arg*."""
        if not isinstance(arg, Fragment):
            return arg
        value = arg.content
        # A nested hunk's meta is only complete after it has been evaluated.
        if isinstance(arg, Hunk):
            self.meta.update(arg.meta)
        return value

    def get_content(self):
        """Evaluate the compiler function once and cache its result.

        If the function returns a tuple or list, its first item is the
        content and its second a dict merged into ``self.meta``.
        """
        if self._content is None:
            args = [self._resolve(arg) for arg in self.args]
            # Resolve into a fresh dict: self.kwargs must keep the original
            # Fragments so that add_sink() can still reach them.
            kwargs = dict((key, self._resolve(arg))
                          for key, arg in self.kwargs.iteritems())

            res = self.compiler.fn(*args, **kwargs)
            if isinstance(res, (tuple, list)):
                # NOTE(review): a compiler returning a list as its content
                # would be misread as (content, meta) here -- confirm contract.
                self._content = res[0]
                self.meta.update(res[1])
            else:
                self._content = res

        return self._content

    def save(self, dest, encoding='utf-8'):
        """Write the hunk's content to *dest* if any input has expired.

        :param dest: output path, or a callable that receives ``self.meta``
            and returns the output path.
        :param encoding: encoding used when writing the file.
        """
        if not self.expired:
            return

        # Evaluate first: this populates self.meta, which a callable dest reads.
        data = self.content

        if callable(dest):
            dest = dest(self.meta)

        save_dir = os.path.dirname(dest)
        if save_dir:
            try:
                os.makedirs(save_dir)
            except OSError as exc:
                # Only an already-existing directory is acceptable.
                if exc.errno != errno.EEXIST:
                    raise

        self.add_sink(dest)

        print(' -> write %s' % dest)
        with codecs.open(dest, mode='w', encoding=encoding) as fh:
            fh.write(data)

    def add_sink(self, dest):
        """Register *dest* as an output of every Fragment argument."""
        for fragment in self._fragments():
            fragment.add_sink(dest)


class Compiler(object):
    """Decorator turning a function into a factory of lazy Hunks.

    Usage::

        @Compiler(*bound_args, **bound_kwargs)
        def fn(...): ...

    The first call stores the decorated function and returns the Compiler
    itself; every later call returns a Hunk that applies ``fn`` to
    ``bound_args + call_args`` and the merged keyword arguments.
    """

    def __init__(self, *args, **kwargs):
        self.args = list(args)
        self.kwargs = kwargs
        self.fn = None

    def __call__(self, *args, **kwargs):
        if self.fn is None:
            # Decoration step: remember the wrapped function.
            self.fn = args[0]
            return self

        # Expose the path of the (last) File argument in the hunk's meta.
        # Keyword arguments are checked by value, not as (key, value) pairs.
        meta = {}
        for arg in chain(args, kwargs.itervalues()):
            if isinstance(arg, File) and hasattr(arg, 'path'):
                meta['path'] = arg.path

        args = self.args + list(args)
        # NOTE(review): bound keyword arguments override call-site ones here;
        # confirm this precedence is intended.
        kwargs.update(self.kwargs)

        return Hunk(self, args, kwargs, meta)
