"""Classes to create and manage streams and views in vision datasets"""

import tables
import os
from os import path
import shutil
import viset.download
import viset.util
from viset.util import ishdf5, isfile, isarchive, quietprint, timestamp
import numpy as np 
import urllib
import socket
import hashlib
import time
from itertools import imap, islice, count
import urlparse
from viset.show import imshow
import sys
from viset.cache import CachedImage, CachedResult
import viset.partition
import copy

class View(object):
    """Base class wrapping a PyTables node and its owning database object as a view.

    Subclasses expose either tables (streams) or groups stored in the HDF5 file.
    """
    _node = None       # wrapped PyTables node (table or group)
    _db = None         # owning database object
    _dbrow = None      # cached table Row reused by subclass write() for fast pytables writes
    _dbrowidx = None   # row counter returned by subclass write() methods

    def __init__(self, node=None, db=None):
        self._node = node
        self._db = db

    def viewname(self):
        """Return the ``viewname`` attribute stored on the wrapped node.

        ``_v_attrs`` is used because PyTables defines it on every node, whereas the
        ``attrs`` alias only exists on leaves (tables); on a group, ``attrs`` would be
        looked up as a child node and fail.
        """
        return self._node._v_attrs.viewname

    def viewclass(self):
        """Return the ``viewclass`` attribute: the name of the View subclass for this node."""
        return self._node._v_attrs.viewclass

    def isgroup(self):
        """Return True when the wrapped node exposes ``_v_nchildren`` (i.e. is a group)."""
        return hasattr(self._node, '_v_nchildren')


class GroupView(View):
    """View over a PyTables group whose child nodes are exposed as attributes."""

    def __len__(self):
        """Return the total number of children in current group"""
        return self._node._v_nchildren

    def __getattr__(self, viewname):
        """Return the child view whose stored ``viewname`` attribute equals ``viewname``.

        The child node is wrapped in the class of this module named by the child's
        ``viewclass`` attribute.

        Raises:
            AttributeError: if no child carries a matching ``viewname``.
        """
        if self._node is None:
            raise AttributeError(viewname)
        for childnode in self._node._f_iter_nodes():
            # Children lacking a viewname attribute are skipped rather than letting
            # their missing attribute surface as a misleading AttributeError('viewname').
            if getattr(childnode._v_attrs, 'viewname', None) == viewname:
                thismodule = sys.modules[__name__]
                viewclass = getattr(thismodule, childnode._v_attrs.viewclass)
                return viewclass(childnode, self._db)
        raise AttributeError(viewname)

    def __iter__(self):
        raise NotImplementedError('FIXME: iterate over child nodes in group?')

    def __getitem__(self, k):
        raise NotImplementedError('FIXME: get child in group?')

    def views(self):
        """List the path of every node walked below this group, with '/' replaced by '.'.

        For example ``/annotation/categorization`` becomes ``.annotation.categorization``.
        """
        return [str(node._v_pathname).replace('/', '.') for node in self._node._f_walknodes()]
    
class KFoldStreamView(View):
    """View for splitting a stream into kfold partitions.

    ``split`` is a sequence of folds, each indexable as ``(train, test, validate)``
    row-index collections (``validate`` may be None). Iterating or indexing yields a
    view positioned on one fold, whose ``train``/``test``/``validate`` attributes
    return SplitStreamView objects read through ``readfunc``.
    """
    _split = None
    _id_train = None
    _id_test = None
    _id_validate = None
    _read = None

    def __init__(self, node=None, db=None, split=None, readfunc=None):
        self._node = node
        self._db = db
        self._split = split
        self._read = readfunc

    def __len__(self):
        """Return the number of folds."""
        return len(self._split)

    def __getattr__(self, name):
        """Return the train/test/validate SplitStreamView of the current fold.

        Raises:
            AttributeError: for any other name, or when that partition is not set.
        """
        ids = {'train': self._id_train,
               'test': self._id_test,
               'validate': self._id_validate}.get(name)
        if ids is None:
            raise AttributeError(name)
        return SplitStreamView(self._node, self._db, ids, readfunc=self._read)

    def __iter__(self):
        """Lazily yield one independent view per fold, in fold order.

        Independent copies are yielded so that materialising the folds (e.g.
        ``list(kfold)``) gives distinct folds instead of aliases of a single object.
        """
        return (self._fold(k) for k in range(len(self._split)))

    def _set_fold(self, k):
        """Internal function that positions this view on fold ``k`` in place and returns it."""
        self._id_train = self._split[k][0]
        self._id_test = self._split[k][1]
        self._id_validate = self._split[k][2]
        return self

    def _fold(self, k):
        """Return a shallow copy of this view positioned on fold ``k``; ``self`` is left unchanged."""
        return copy.copy(self)._set_fold(k)

    def __getitem__(self, k):
        """Return a view positioned on fold ``k``."""
        return self._fold(k)
    
class SplitStreamView(View):
    """Read-only view over one partition (train, test or validate) of a stream.

    ``id_split`` holds the row indices belonging to the partition; ``readfunc`` is
    called with one of those indices to produce an element.
    """
    _id_split = None
    _read = None

    def __init__(self, node=None, db=None, id_split=None, readfunc=None):
        self._node, self._db = node, db
        self._id_split, self._read = id_split, readfunc

    def __len__(self):
        """Number of rows in this partition."""
        return len(self._id_split)

    def __iter__(self):
        """Lazily read every row index of the partition, in stored order."""
        return imap(self._read, self._id_split)

    def __getitem__(self, k):
        """Read the element whose row index is at position ``k`` of the partition."""
        row_index = self._id_split[k]
        return self._read(row_index)

    
class StreamView(View):
    """Stream view class.

    Wraps a table node whose rows are produced by the subclass ``read(idx)``.
    ``_id`` holds the row indices visited by iteration, and ``_split`` holds the
    current partition (folds of ``(train, test, validate)`` index collections)
    computed by ``viset.partition``.
    """
    _split = None
    _id = None

    def __init__(self, node=None, db=None):
        self._node = node
        self._db = db
        if self._node is not None:
            self._set_split(strategy=None)

    def __len__(self):
        """Number of visible rows: ``len(_id)`` when set, otherwise the table row count."""
        if self._id is not None:
            return len(self._id)
        return self._node.nrows

    def _first_fold(self):
        """Return the first fold of the current split, or None when no split is set."""
        if self._split is None or len(self._split) == 0:
            return None
        return self._split[0]

    def __getattr__(self, name):
        """Return dynamically constructed ``train``/``test``/``validate`` views.

        Each is a SplitStreamView over the matching partition of the first fold.

        Raises:
            AttributeError: for any other name, when no split is set, or when the
                requested partition is absent.
        """
        position = {'train': 0, 'test': 1, 'validate': 2}.get(name)
        if position is None:
            raise AttributeError(name)
        # _split is still the class default None for streams built without a node,
        # so it must be checked before indexing to keep raising AttributeError.
        fold = self._first_fold()
        if fold is None or fold[position] is None:
            raise AttributeError(name)
        return SplitStreamView(self._node, self._db, id_split=fold[position], readfunc=self.read)

    def __iter__(self):
        """Lazily read each visible row: ``_id`` order when set, otherwise all rows."""
        ids = self._id if self._id is not None else range(len(self))
        return (self.read(i) for i in ids)

    def __getitem__(self, k):
        return self.read(k)

    def __call__(self, *args, **kwargs):
        return self.split(*args, **kwargs)

    def _set_split(self, strategy=None, split=None, folds=0, step=1, randomize=False, stratify=False):
        """Internal function that sets a split without returning copied split tuples.

        Args:
            strategy: 'kfold', 'leave-one-out'/'loo'/'LOO', 'train-test', or
                None/'all'/'leave_zero_out' (no partition; only resets ``_id``).
            split: an explicit precomputed split; takes precedence over ``strategy``.
            folds: number of folds, used by 'kfold'.
            step, randomize, stratify: forwarded to the ``viset.partition`` function.

        Raises:
            IOError: if ``strategy`` is not one of the supported values.
        """
        if split is not None:
            self._split = split
        elif strategy == 'kfold':
            self._split = viset.partition.kfold(len(self), folds, step, randomize=randomize, stratify=stratify)
        elif strategy in ['leave-one-out', 'loo', 'LOO']:
            self._split = viset.partition.leave_one_out(len(self), step, randomize=randomize, stratify=stratify)
        elif strategy in ['train-test']:
            self._split = viset.partition.train_test(len(self), step, randomize=randomize, stratify=stratify)
        elif strategy in [None, 'all', 'leave_zero_out']:
            everything = viset.partition.leave_zero_out(len(self), step, randomize=randomize, stratify=stratify)
            self._id = everything[0][0]
            self._split = None
        else:
            # str() so a non-string strategy still produces this error rather than a TypeError.
            raise IOError('unsupported splitting strategy ' + str(strategy))
        self._split = viset.util.tolist(self._split)

    def split(self, *args, **kwargs):
        """Apply a split (same arguments as ``_set_split``) and return views over it.

        Returns:
            self, when the first fold is None (no partition);
            a KFoldStreamView, when there is more than one fold;
            otherwise a ``(train, test)`` or ``(train, test, validate)`` tuple of
            SplitStreamView objects, depending on whether a validate partition exists.
        """
        self._set_split(*args, **kwargs)
        fold = self._split[0]
        if fold is None:
            return self
        if len(self._split) > 1:
            return KFoldStreamView(self._node, self._db, split=self._split, readfunc=self.read)
        nparts = 3 if fold[2] is not None else 2
        return tuple(SplitStreamView(self._node, self._db, id_split=fold[i], readfunc=self.read)
                     for i in range(nparts))

    def read(self, idx):
        """Return element ``idx``; must be overridden by subclasses."""
        raise NotImplementedError('overloaded by subclass')


class ImageStream(StreamView):
    """Stream of image URLs stored in the ``/image`` table."""

    class Table(tables.IsDescription):
        url = tables.StringCol(512, pos=0)

    def _create(self, dbobj):
        """Create the ``/image`` table in the open file ``dbobj``, tag it, and return it."""
        tbl = dbobj.create_table(dbobj.root, 'image', ImageStream.Table, title='Image',
                                 filters=tables.Filters(complevel=6), expectedrows=1000000)
        tbl.attrs.viewname = 'image'
        tbl.attrs.viewclass = 'ImageStream'
        tbl.flush()
        return tbl

    def write(self, url, reader=None, idx=None, subpath=None):
        """Append one image URL row and return its row index.

        ``reader``, ``idx`` and ``subpath`` are encoded with
        ``viset.util.dict2querystring``; a non-empty result is appended to ``url``
        as a ``#`` fragment.
        """
        url_fragment = viset.util.dict2querystring({'reader':reader, 'idx':idx, 'subpath':subpath})
        if len(url_fragment) > 0:
            url = url + '#' + url_fragment
        if self._dbrow is None:
            self._dbrow = self._node.row
            # Start from the rows already stored so appending to an existing table
            # returns the true row index rather than restarting at 0.
            self._dbrowidx = self._node.nrows
        else:
            self._dbrowidx += 1
        self._dbrow['url'] = url
        self._dbrow.append()
        return self._dbrowidx

    def read(self, idx):
        """Return image ``idx``: the CachedImage when the db is async, else ``CachedImage.get()``."""
        imurl = self._node[idx]['url']
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        return im if self._db._async else im.get()

    def num_images(self):
        """Total number of images in database"""
        return len(self)
        
        

class CategorizationAnnotationStream(StreamView):
    """Stream of per-image category labels in ``/annotation/categorization``."""

    class Table(tables.IsDescription):
        idx_image = tables.Int64Col(pos=0)
        category = tables.StringCol(64, pos=1)
        idx_category = tables.Int32Col(pos=2)

    def _create(self, dbobj):
        """Create ``/annotation`` if needed, then the ``categorization`` table beneath it."""
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        except tables.NodeError:
            pass  # /annotation already exists (created by another annotation stream)
        tbl = dbobj.create_table(dbobj.root.annotation, 'categorization', CategorizationAnnotationStream.Table,
                                 title='categorization', filters=tables.Filters(complevel=6), expectedrows=1000000)
        tbl.attrs.viewname = 'categorization'
        tbl.attrs.viewclass = 'CategorizationAnnotationStream'
        dbobj.flush()

    def read(self, idx):
        """Return ``(image, row)`` for annotation ``idx``.

        The image is a CachedImage when the db is async, otherwise ``CachedImage.get()``;
        ``row`` is the table row ``self._node[idx]``.
        """
        anno = self._node[idx]
        imurl = self._db._obj.root.image[anno['idx_image']]['url']
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        if self._db._async:
            return (im, anno)
        return (im.get(), anno)

    def write(self, category, idx_category, idx_image=None, imurl=None):
        """Append a categorization row and return its row index.

        When ``idx_image`` is None, ``imurl`` is first written to the image stream
        and the returned image index is used.

        Raises:
            IOError: if neither ``idx_image`` nor ``imurl`` is given.
        """
        if imurl is not None and idx_image is None:
            idx_image = self._db.image.write(imurl)
        elif idx_image is None:
            raise IOError('annotation requires either image index or image url')

        if self._dbrow is None:
            self._dbrow = self._node.row
            # Continue after rows already in the table so returned indices are correct on append.
            self._dbrowidx = self._node.nrows
        else:
            self._dbrowidx += 1
        self._dbrow['idx_image'] = idx_image
        self._dbrow['idx_category'] = idx_category
        self._dbrow['category'] = category
        self._dbrow.append()
        return self._dbrowidx

    def delete(self, *args, **kwargs):
        """Remove rows; arguments are passed straight to the table's ``remove_rows``.

        NOTE(review): the cached write counter ``_dbrowidx`` is not adjusted here.
        """
        self._node.remove_rows(*args, **kwargs)

    def labels(self):
        """Unique category ids in dataset"""
        return np.unique(self._node.cols.idx_category[:])

    def num_categories(self):
        """Number of unique category ids in the dataset."""
        return len(self.labels())

    def y(self, idx_category):
        """Return a label vector with values in {-1, 1}: +1 for rows whose category id is ``idx_category``, -1 otherwise."""
        idx = self._node.cols.idx_category[:]
        return 2*(np.int32(idx == idx_category)) - 1

    
class DetectionAnnotationStream(StreamView):
    """Stream of labelled bounding boxes in ``/annotation/detection``."""

    class Table(tables.IsDescription):
        category = tables.StringCol(64, pos=0)
        idx_category = tables.Int32Col(pos=1)
        idx_image = tables.Int64Col(pos=2)
        bbox_xmin = tables.Float32Col(pos=3)
        bbox_xmax = tables.Float32Col(pos=4)
        bbox_ymin = tables.Float32Col(pos=5)
        bbox_ymax = tables.Float32Col(pos=6)

    def _create(self, dbobj):
        """Create ``/annotation`` if needed, then the ``detection`` table beneath it."""
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        except tables.NodeError:
            pass  # /annotation already exists
        tbl = dbobj.create_table(dbobj.root.annotation, 'detection', DetectionAnnotationStream.Table,
                                 title='detection', filters=tables.Filters(complevel=6), expectedrows=1000000)
        tbl.attrs.viewname = 'detection'
        tbl.attrs.viewclass = 'DetectionAnnotationStream'
        dbobj.flush()

    def write(self, category, idx_category, idx_image, xmin, xmax, ymin, ymax):
        """Append one bounding-box annotation row and return its row index."""
        if self._dbrow is None:
            self._dbrow = self._node.row
            # Continue after rows already in the table so returned indices are correct on append.
            self._dbrowidx = self._node.nrows
        else:
            self._dbrowidx += 1
        self._dbrow['category'] = category
        self._dbrow['idx_category'] = idx_category
        self._dbrow['idx_image'] = idx_image
        self._dbrow['bbox_xmin'] = xmin
        self._dbrow['bbox_xmax'] = xmax
        self._dbrow['bbox_ymin'] = ymin
        self._dbrow['bbox_ymax'] = ymax
        self._dbrow.append()
        return self._dbrowidx

    def read(self, idx):
        """Return ``(image, row)`` for annotation ``idx``.

        The image is a CachedImage when the db is async, otherwise ``CachedImage.get()``;
        ``row`` is the table row ``self._node[idx]``.
        """
        anno = self._node[idx]
        impath = self._db._obj.root.image[anno['idx_image']]['url']
        im = CachedImage(impath, cache=self._db._cache, verbose=self._db._verbose)
        if self._db._async:
            return (im, anno)
        return (im.get(), anno)
        

class SegmentationAnnotationStream(StreamView):
    """Stream of segmentation-mask annotations in ``/annotation/segmentation``."""

    class Table(tables.IsDescription):
        category = tables.StringCol(64, pos=0)
        idx_category = tables.Int32Col(pos=1)
        idx_image = tables.Int64Col(pos=2)
        bbox_xmin = tables.Float32Col(pos=3)
        bbox_xmax = tables.Float32Col(pos=4)
        bbox_ymin = tables.Float32Col(pos=5)
        bbox_ymax = tables.Float32Col(pos=6)
        maskurl = tables.StringCol(512, pos=7)

    def _create(self, dbobj):
        """Create ``/annotation`` if needed, then the ``segmentation`` table beneath it."""
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        except tables.NodeError:
            pass  # /annotation already exists
        tbl = dbobj.create_table(dbobj.root.annotation, 'segmentation', SegmentationAnnotationStream.Table,
                                 title='segmentation', filters=tables.Filters(complevel=6), expectedrows=1000000)
        tbl.attrs.viewname = 'segmentation'
        tbl.attrs.viewclass = 'SegmentationAnnotationStream'
        dbobj.flush()

    def write(self, category, idx_category, idx_image, maskurl, subpath=None):
        """Append a segmentation row and return its row index.

        ``subpath`` is encoded with ``viset.util.dict2querystring``; a non-empty
        result is appended to ``maskurl`` as a ``#`` fragment. Bounding-box columns
        are not written here.
        """
        url_fragment = viset.util.dict2querystring({'reader':None, 'idx':None, 'subpath':subpath})
        if len(url_fragment) > 0:
            maskurl = maskurl + '#' + url_fragment
        if self._dbrow is None:
            self._dbrow = self._node.row
            # Continue after rows already in the table so returned indices are correct on append.
            self._dbrowidx = self._node.nrows
        else:
            self._dbrowidx += 1
        self._dbrow['maskurl'] = maskurl
        self._dbrow['idx_image'] = idx_image
        self._dbrow['category'] = category
        self._dbrow['idx_category'] = idx_category
        self._dbrow.append()
        return self._dbrowidx

    def read(self, idx):
        """Return ``(image, anno)`` for annotation ``idx``.

        ``anno`` is a dict with keys ``id_category``, ``category`` and ``mask``; ``mask``
        is None when no mask URL is stored. When the db is async the image and mask
        are CachedImage objects; otherwise their ``get()`` results.
        """
        row = self._node[idx]
        imurl = self._db._obj.root.image[row['idx_image']]['url']
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        maskurl = row['maskurl']
        # `not` treats both '' and b'' (PyTables may return bytes) as "no mask".
        if not maskurl:
            mask = None
        else:
            mask = CachedImage(maskurl, cache=self._db._cache, verbose=self._db._verbose)
        if self._db._async:
            anno = {'id_category':row['idx_category'], 'category':row['category'], 'mask':mask}
            return (im, anno)
        if mask is not None:
            mask = mask.get()
        anno = {'id_category':row['idx_category'], 'category':row['category'], 'mask':mask}
        return (im.get(), anno)


class CategorizationOutputGroup(GroupView):
    """Group ``/output/categorization`` holding one table per categorization output run."""

    def _create(self, dbobj):
        """Create ``/output`` if needed, then the ``categorization`` group beneath it.

        Raises:
            tables.NodeError: if ``/output/categorization`` already exists.
        """
        try:
            group = dbobj.create_group(dbobj.root, 'output', 'Output')
            group._v_attrs.viewname = 'output'
            group._v_attrs.viewclass = 'GroupView'
        except tables.NodeError:
            pass  # /output already exists

        group = dbobj.create_group(dbobj.root.output, 'categorization', 'Categorization')
        group._v_attrs.viewname = 'categorization'
        group._v_attrs.viewclass = 'CategorizationOutputGroup'
        dbobj.flush()

    def new(self, streamname=None):
        """Create a new output stream in this group and return it.

        ``streamname`` defaults (inside ``CategorizationOutputStream._create``) to a
        timestamp-based name.
        """
        return CategorizationOutputStream(None, self._db)._create(self._db._obj, viewname=streamname)

    
class CategorizationOutputStream(StreamView):
    """Stream of categorization results (scores, timings) stored under ``/output/categorization``."""

    class Table(tables.IsDescription):
        category = tables.StringCol(64, pos=0)
        idx_category = tables.Int32Col(pos=1)
        score = tables.Float32Col(pos=2)
        elapsed = tables.Float32Col(pos=3)
        idx_image = tables.Int64Col(pos=4)
        imurl = tables.StringCol(512, pos=5)

    def _create(self, dbobj, viewname=None):
        """Create the output table ``viewname``, bind it to this view, and return ``self``.

        ``viewname`` defaults to ``'_' + timestamp()``.
        """
        if viewname is None:
            viewname = '_' + timestamp()
        self._node = dbobj.create_table(dbobj.root.output.categorization, viewname, CategorizationOutputStream.Table,
                                        title=viewname, filters=tables.Filters(complevel=6), expectedrows=1000000)
        self._node.attrs.viewname = viewname
        self._node.attrs.viewclass = 'CategorizationOutputStream'
        dbobj.flush()
        return self

    def write(self, category, idx_category, score, elapsed, idx_image, imurl, imsubpath=None, imreader=None, imreaderindex=None):
        """Append one result row and return its row index.

        The image URL is stored with a fragment built from ``imsubpath``, ``imreader``
        and ``imreaderindex``. When ``idx_image`` is None, the row index is stored as
        the image index.
        """
        if self._dbrow is None:
            self._dbrow = self._node.row
            # Continue after rows already in the table so returned indices are correct on append.
            self._dbrowidx = self._node.nrows
        else:
            self._dbrowidx += 1
        self._dbrow['category'] = category
        self._dbrow['idx_category'] = idx_category
        self._dbrow['score'] = score
        self._dbrow['elapsed'] = elapsed
        # NOTE(review): ImageStream.write encodes the reader index under 'idx' but this
        # uses 'index'; confirm which key CachedImage expects.
        self._dbrow['imurl'] = viset.util.join_fragment(imurl, viset.util.dict2querystring({'subpath':imsubpath, 'reader':imreader, 'index':imreaderindex}))
        if idx_image is None:
            self._dbrow['idx_image'] = self._dbrowidx
        else:
            self._dbrow['idx_image'] = idx_image
        self._dbrow.append()
        return self._dbrowidx

    def read(self, idx):
        """Return ``(image, row)`` for result ``idx``.

        The image is a CachedImage when the db is async, otherwise ``CachedImage.get()``;
        ``row`` is the table row ``self._node[idx]``.
        """
        anno = self._node[idx]
        im = CachedImage(anno['imurl'], cache=self._db._cache, verbose=self._db._verbose)
        if self._db._async:
            return (im, anno)
        return (im.get(), anno)
    
