"""Classes to create and manage streams and views in vision datasets"""

import tables
import os
from os import path
import shutil
import viset.download
import viset.util
from viset.util import ishdf5, isfile, isarchive, quietprint, timestamp
import numpy as np 
import urllib
import socket
import hashlib
import time
from itertools import imap, islice, count
import urlparse
from viset.show import imshow
import sys
from viset.cache import CachedImage, CachedResult
import viset.partition
import copy

class View(object):
    """Base class wrapping a pytables node together with its owning database object.

    Subclasses specialise this for group nodes (GroupView) and table nodes
    (StreamView and its subclasses).
    """
    _node = None      # wrapped pytables node
    _db = None        # owning database object, passed through to child views
    _dbrow = None     # cached row handle for fast pytables writes
    _dbrowidx = None  # index of the last row written through _dbrow
    _async = False    # when True, stream reads return handles instead of loaded data

    def __init__(self, node=None, db=None):
        """Store the wrapped node and the database object it belongs to."""
        self._node = node
        self._db = db

    def viewname(self):
        """Name of current node which mirrors class name of corresponding view"""
        # _v_attrs is available on both group and leaf nodes, whereas the
        # 'attrs' shorthand only exists on leaves; using it keeps this method
        # usable from GroupView as well as from the stream views.
        return self._node._v_attrs.viewname

    def viewclass(self):
        """Class name of current node which mirrors class name of corresponding view"""
        return self._node._v_attrs.viewclass

    def isgroup(self):
        """Return True if the wrapped node exposes a child count (_v_nchildren)."""
        return hasattr(self._node, '_v_nchildren')


class GroupView(View):
    """View over a group node whose children are exposed as attributes."""

    def __len__(self):
        """Return the total number of children in current group"""
        return self._node._v_nchildren

    def __getattr__(self, viewname):
        """Return corresponding class object for viewname

        Only invoked when normal attribute lookup fails.  The direct children
        of the wrapped group are scanned for one whose 'viewname' attribute
        equals the requested name; the class named by that child's 'viewclass'
        attribute is looked up in this module and instantiated on the child.

        Raises AttributeError if no child carries the requested view name.
        """
        for childnode in self._node._f_iter_nodes():
            childattrs = childnode._v_attrs
            if childattrs.viewname == viewname:
                viewclass = getattr(sys.modules[__name__], childattrs.viewclass)
                return viewclass(childnode, self._db)
        # Call form of raise: valid in both Python 2 and Python 3, unlike the
        # legacy "raise Class, value" statement.
        raise AttributeError(viewname)

    def __iter__(self):
        """Not implemented: iteration over child nodes is still an open FIXME."""
        raise NotImplementedError('FIXME: iterate over child nodes in group?')

    def __getitem__(self, k):
        """Not implemented: indexing of child nodes is still an open FIXME."""
        raise NotImplementedError('FIXME: get child in group?')

    def views(self):
        """List of views in group

        Each entry is the node's pytables path name with '/' replaced by '.'.
        """
        return [str(node._v_pathname).replace('/', '.')
                for node in self._node._f_walknodes()]
    
class StreamView(View):
    """Stream view class"""
    _fold = None
    _split = {'train':None, 'test':None, 'validate':None, 'all':None}
    _idx_stream = None
    
    def __init__(self, node=None, db=None, idx_stream=None, split=None, async=False):
        self._node = node
        self._db = db
        self._idx_stream = idx_stream
        self._async = async
        self._split = split
        
    def __len__(self):
        if self._idx_stream is not None:
            return len(self._idx_stream)
        else:
            return self._node.nrows

    def _get_viewclass(self):
        viewmodule = sys.modules['viset.stream']
        viewclass = getattr(viewmodule, self._node._v_attrs.viewclass)        
        return viewclass
    
    def __getattr__(self, name):
        """Attribute factory"""        
        if self._split is not None:
            viewclass = self._get_viewclass()
            if name == 'train':
                return viewclass(self._node, self._db, split=None, idx_stream=self._split['train'], async=self._async)
            elif name == 'test':
                return viewclass(self._node, self._db, split=None, idx_stream=self._split['test'], async=self._async)
            elif name == 'validate':
                return viewclass(self._node, self._db, split=None, idx_stream=self._split['validate'], async=self._async)
            else:  raise AttributeError, name
        else:  raise AttributeError, name                
                
    def __iter__(self):
        f = lambda x: self.read(x, async=self._async) 
        if self._fold is not None:
            return iter(self._fold)
        elif self._idx_stream is not None:
            return imap(f, iter(self._idx_stream))
        else:
            return imap(f, iter(range(len(self)))) # HACK                         
        
    def __getitem__(self, k):
        f = lambda x: self.read(x, async=self._async) 
        if self._fold is not None:
            return self._fold[k]
        elif self._idx_stream is not None:
            return f(self._idx_stream[k])
        else:
            return f(k)

    def __call__(self, *args, **kwargs):
        self._set_split(*args, **kwargs)
        return self        
    
    def _set_split(self, strategy=None, folds=0, step=1, randomize=False, stratify=False, async=False):
        """Internal function that sets a split without returning copied split tuples"""
        self._async = async            
        if strategy == 'kfold':            
            self._split = viset.partition.kfold(len(self), folds, step, randomize=randomize, stratify=stratify)
        elif strategy in ['leave-one-out', 'loo', 'LOO']:
            self._split = viset.partition.leave_one_out(len(self), step, randomize=randomize, stratify=stratify)
        elif strategy in ['train-test']:
            self._split = viset.partition.train_test(len(self), step, randomize=randomize, stratify=stratify)            
        elif strategy in ['all', 'leave_zero_out'] or (step > 1):
            self._split = viset.partition.leave_zero_out(len(self), step, randomize=randomize, stratify=stratify)            
            self._idx_stream = self._split['train']
            self._split = None
        elif strategy is not None:
            IOError('unsupported splitting strategy ' + strategy)
        if self._split is not None:
            viewclass = self._get_viewclass()
            self._fold = [viewclass(self._node, self._db, split=split, idx_stream=None, async=self._async) for split in viset.util.tolist(self._split)]
            if len(self._fold) > 1:
                self._split = None


    def where(self):
        """FIXME: Support for returning iterators over rows matching complex query criteria using pytables 'where' syntax"""
        # add a keyword 'where' to split which is used to create a pytables iterator
        # this iterator is used instead of our iterators
        pass
    
    def split(self, *args, **kwargs):
        """Return a tuple containing split views"""
        self._set_split(*args, **kwargs)
        if self._split is not None:
            if self._split['validate'] is not None:
                return (copy.copy(self._fold[0].train), copy.copy(self._fold[0].test), copy.copy(self._fold[0].validate))
            else:
                return (copy.copy(self._fold[0].train), copy.copy(self._fold[0].test))
        elif self._fold is not None:
            return self._fold
        else:
            return self
            
    def read(self):
        IOError('overloaded by subclass')

    def flush(self):
        self._node.flush()
        
class ImageStream(StreamView):
    class Table(tables.IsDescription):
        url = tables.StringCol(512, pos=0)   

    def _create(self,dbobj):
        tbl = dbobj.create_table(dbobj.root, 'image', ImageStream.Table, title='Image', filters=tables.Filters(complevel=6), expectedrows=1000000)
        tbl.attrs.viewname = 'image'
        tbl.attrs.viewclass = 'ImageStream'
        tbl.flush()
        return tbl

    def write(self, url, reader=None, idx=None, subpath=None):
        url_fragment = viset.util.dict2querystring({'reader':reader, 'idx':idx, 'subpath':subpath})
        if len(url_fragment) > 0:
            url = url + '#' + url_fragment
        if self._dbrow is None:
            self._dbrow = self._node.row            
            self._dbrowidx = 0
        else:
            self._dbrowidx += 1            
        self._dbrow['url'] = url
        self._dbrow.append()
        return self._dbrowidx
    
    def read(self, idx, async=False):
        imurl = self._node[idx]['url']	
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        if async or self._async:
            return im  
        else:
            return im.get()

    def num_images(self):
        """Total number of images in database"""
        return len(self)
        
        

class CategorizationAnnotationStream(StreamView):
    class Table(tables.IsDescription):
        idx_image = tables.Int64Col(pos=0)
        category = tables.StringCol(64, pos=1)              
        idx_category = tables.Int32Col(pos=2)  

    def _create(self, dbobj):
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')            
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        except:
            # already created
            pass  
        tbl = dbobj.create_table(dbobj.root.annotation, 'categorization', CategorizationAnnotationStream.Table, title='categorization', filters=tables.Filters(complevel=6), expectedrows=1000000) 
        tbl.attrs.viewname = 'categorization'
        tbl.attrs.viewclass = 'CategorizationAnnotationStream' 
        dbobj.flush()

    def read(self, idx, async=False):
        idx_image = self._node[idx]['idx_image']
        imurl = self._db._obj.root.image[idx_image]['url']
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        anno = self._node[idx]
        if async or self._async:
            return (im, anno)  # numpy structured array tuple
        else:
            return (im.get(), anno)  # numpy structured array tuple
        
    def write(self, category, idx_category, idx_image=None, imurl=None):
        if imurl is not None and idx_image is None:
            idx_image = self._db.image.write(imurl)
        elif idx_image is None:
            raise IOError('annotation requires either image index or image url')
            
        if self._dbrow is None:
            self._dbrow = self._node.row
            self._dbrowidx = 0
        else:
            self._dbrowidx += 1            
        self._dbrow['idx_image'] = idx_image        
        self._dbrow['idx_category'] = idx_category
        self._dbrow['category'] = category
        self._dbrow.append()
        return self._dbrowidx
        
    def delete(self, *args, **kwargs):
        self._node.remove_rows(*args, **kwargs)
        pass
                    
    def labels(self):
        """Unique category labels in dataset"""        
        return list(set(self._node.cols.category[:]))

    def num_categories(self):
        return len(self.labels())

    def y(self, idx_category):
        """Return a binary label vector y \in {-1,1} for dataset elements of a given category id"""        
        idx = self._node.cols.idx_category[:]
        return 2*(np.int32(idx == idx_category)) - 1

    
class DetectionAnnotationStream(StreamView):
    class Table(tables.IsDescription):
        category = tables.StringCol(64, pos=0)      
        idx_category = tables.Int32Col(pos=1)  
        idx_image = tables.Int64Col(pos=2)
        bbox_xmin = tables.Float32Col(pos=3)
        bbox_xmax = tables.Float32Col(pos=4)
        bbox_ymin = tables.Float32Col(pos=5)
        bbox_ymax = tables.Float32Col(pos=6)    

    def _create(self, dbobj):
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')            
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        except:
            pass   # already created
        tbl = dbobj.create_table(dbobj.root.annotation, 'detection', DetectionAnnotationStream.Table, title='detection', filters=tables.Filters(complevel=6), expectedrows=1000000)        
        tbl.attrs.viewname = 'detection'
        tbl.attrs.viewclass = 'DetectionAnnotationStream'
        dbobj.flush()
        
    def write(self, category, idx_category, idx_image, xmin, xmax, ymin, ymax):
        if self._dbrow is None:
            self._dbrow = self._node.row
            self._dbrowidx = 0
        else:
            self._dbrowidx += 1                
        self._dbrow['category'] = category
        self._dbrow['idx_category'] = idx_category
        self._dbrow['idx_image'] = idx_image        
        self._dbrow['bbox_xmin'] = xmin
        self._dbrow['bbox_xmax'] = xmax
        self._dbrow['bbox_ymin'] = ymin
        self._dbrow['bbox_ymax'] = ymax                        
        self._dbrow.append()
        return self._dbrowidx

    def read(self, idx, async=False):
        idx_image = self._node[idx]['idx_image']
        impath = self._db._obj.root.image[idx_image]['url']
        im = CachedImage(impath, cache=self._db._cache, verbose=self._db._verbose)
        anno = self._node[idx]
        if async or self._async:
            return (im,anno)  # numpy structured array tuple
        else:
            return (im.get(),anno)  # numpy structured array tuple
        

class SegmentationAnnotationStream(StreamView):
    class Table(tables.IsDescription):
        category = tables.StringCol(64, pos=0)      
        idx_category = tables.Int32Col(pos=1)  
        idx_image = tables.Int64Col(pos=2)
        bbox_xmin = tables.Float32Col(pos=3)
        bbox_xmax = tables.Float32Col(pos=4)
        bbox_ymin = tables.Float32Col(pos=5)
        bbox_ymax = tables.Float32Col(pos=6)    
        maskurl = tables.StringCol(512, pos=7)   

    def _create(self, dbobj):
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')            
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        except:
            # already created
            pass  
        tbl = dbobj.create_table(dbobj.root.annotation, 'segmentation', SegmentationAnnotationStream.Table, title='segmentation', filters=tables.Filters(complevel=6), expectedrows=1000000)        
        tbl.attrs.viewname = 'segmentation'
        tbl.attrs.viewclass = 'SegmentationAnnotationStream'  
        dbobj.flush()
        
    def write(self, category, idx_category, idx_image, maskurl, subpath=None):
        url_fragment = viset.util.dict2querystring({'reader':None, 'idx':None, 'subpath':subpath})
        if len(url_fragment) > 0:
            maskurl = maskurl + '#' + url_fragment
        if self._dbrow is None:
            self._dbrow = self._node.row
            self._dbrowidx = 0
        else:
            self._dbrowidx += 1
        self._dbrow['maskurl'] = maskurl
        self._dbrow['idx_image'] = idx_image        
        self._dbrow['category'] = category  
        self._dbrow['idx_category'] = idx_category          
        self._dbrow.append()
        return self._dbrowidx
        
    def read(self, idx, async=False):
        idx_image = self._node[idx]['idx_image']
        imurl = self._db._obj.root.image[idx_image]['url']	
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        maskurl = self._node[idx]['maskurl']
        if maskurl == '':
			mask = None
        else:
			mask = CachedImage(maskurl, cache=self._db._cache, verbose=self._db._verbose)
        anno = self._node[idx]
        if async or self._async:
            anno = {'id_category':anno['idx_category'], 'category':anno['category'], 'mask':mask}            
            return (im,anno)  # numpy structured array tuple
        else:
            if mask is not None:
                mask = mask.get()
            anno = {'id_category':anno['idx_category'], 'category':anno['category'], 'mask':mask}                        
            return (im.get(),anno)  # numpy structured array tuple


class CategorizationOutputGroup(GroupView):
    """Group view over /output/categorization holding categorization output streams."""

    def _create(self, dbobj):
        """Create the 'output' group if needed and the 'categorization' group in it."""
        try:
            group = dbobj.create_group(dbobj.root, 'output', 'Output')
        except tables.NodeError:
            # 'output' group already created.  Only this error is tolerated so
            # that unrelated failures are not hidden.
            pass
        else:
            group._v_attrs.viewname = 'output'
            group._v_attrs.viewclass = 'GroupView'

        group = dbobj.create_group(dbobj.root.output, 'categorization', 'Categorization')
        group._v_attrs.viewname = 'categorization'
        group._v_attrs.viewclass = 'CategorizationOutputGroup'
        dbobj.flush()

    def new(self, streamname=None):
        """generate a new stream in the categorization output group

        streamname is passed as the view name of the new stream; returns the
        value of CategorizationOutputStream._create for that stream.
        """
        return CategorizationOutputStream(None, self._db)._create(self._db._obj, viewname=streamname)

    
class CategorizationOutputStream(StreamView):
    class Table(tables.IsDescription):
        predlabel = tables.StringCol(64, pos=0)              
        score = tables.Float32Col(pos=1)           
        elapsed = tables.Float32Col(pos=2)   
        imurl = tables.StringCol(512, pos=3)   
        truelabel = tables.StringCol(64, pos=4)              
        wn_predlabel = tables.StringCol(64, pos=5)           
        wn_truelabel = tables.StringCol(64, pos=6)                                                         

    def _create(self, dbobj, viewname=None):
        if viewname is None:
            viewname = '_' + timestamp()
        self._node = dbobj.create_table(dbobj.root.output.categorization, viewname, CategorizationOutputStream.Table, title=viewname, filters=tables.Filters(complevel=6), expectedrows=1000000) 
        self._node.attrs.viewname = viewname
        self._node.attrs.viewclass = 'CategorizationOutputStream' 
        dbobj.flush()
        # FIXME: rearchitect iterators from streamview to use where syntax, and use lazy evaluation to construct until iterator is requested so _set_split is not dependent on len
        # len will be zero when we _create, so we cannot initialize _id for splits, need lazy evaluation to construct splits from options when we need them
        return self

    def write(self, predlabel, score, elapsed, imurl, imsubpath=None, imreader=None, imreaderindex=None, truelabel=None, wn_predlabel=None, wn_truelabel=None):
        if self._dbrow is None:
            self._dbrow = self._node.row
            self._dbrowidx = 0
        else:
            self._dbrowidx += 1   
        self._dbrow['predlabel'] = predlabel
        self._dbrow['score'] = score
        self._dbrow['elapsed'] = elapsed
        self._dbrow['imurl'] = viset.util.join_fragment(imurl, viset.util.dict2querystring({'subpath':imsubpath, 'reader':imreader, 'index':imreaderindex}))
        self._dbrow['truelabel'] = truelabel
        self._dbrow['wn_predlabel'] = wn_predlabel
        self._dbrow['wn_truelabel'] = wn_truelabel                
        self._dbrow.append()
        return self._dbrowidx

    
    def read(self, idx, async=False):
        im = CachedImage(self._node[idx]['imurl'], cache=self._db._cache, verbose=self._db._verbose)
        pred = {'label':self._node[idx]['predlabel'], 'score':self._node[idx]['score'], 'elapsed':self._node[idx]['elapsed']}
        anno = {'label':self._node[idx]['truelabel']}
        if async or self._async:
            return (im,pred,anno)  # numpy structured array tuple
        else:
            return (im.get(),pred,anno)  # numpy structured array tuple


    def num_categories(self):
        return len(self.predlabels())
        
    def predlabels(self):
        """Unique category labels in dataset"""        
        return list(set(self._node.cols.predlabel[:]))

    def truelabels(self):
        """Unique category labels in dataset"""        
        return list(set(self._node.cols.truelabel[:]))
    

    def y_pred(self, label, binary=False):
        """Return a binary label vector y \in {-1,1} for dataset elements of a given category id"""        
        predlabels = self._node.cols.predlabel[:]
        if binary:
            return np.int32(predlabels == label)  # {0,1}
        else:
            return 2*(np.int32(predlabels == label)) - 1  # {-1,1}

    def y_true(self, label, binary=False):
        """Return a binary label vector y \in {-1,1} for dataset elements of a given category id"""        
        labels = self._node.cols.truelabel[:]
        if binary:
            return np.int32(labels == label)  # {0,1}
        else:
            return 2*(np.int32(labels == label)) - 1  # {-1,1}
        
    def Y_pred(self):
        """Return a integer label vector Y \in {0,n} for predictions of a given category label index"""        
        idx_predlabels = {k:v for (v,k) in enumerate(self.predlabels())}
        labels = np.array(self._node.cols.predlabel[:])
        Y = np.zeros(len(labels))
        for (k, label) in enumerate(labels):
            Y[k] = idx_predlabels[label]
        return Y

    def Y_true(self):
        """Return a integer label vector Y \in {0,n} for truth of a given category label index"""        
        idx_truelabels = {k:v for (v,k) in enumerate(self.truelabels())}
        labels = np.array(self._node.cols.truelabel[:])
        Y = np.zeros(len(labels))
        for (k, label) in enumerate(labels):
            Y[k] = idx_truelabels[label]
        return Y

