"""Classes to create and manage views in vision datasets"""

import tables
import os
from os import path
import shutil
import viset.download
import viset.util
from viset.util import ishdf5, isfile, isarchive, quietprint
import numpy as np 
import urllib
import socket
import hashlib
import time
from itertools import imap, islice, count
import urlparse
from viset.show import imshow
import sys
from viset.cache import CachedImage, CachedResult
import viset.partition
import copy

class View(object):
    """Base class for views that wrap a PyTables node and its owning database.

    Attributes:
        _node: the wrapped PyTables node (group or leaf), or None.
        _db: the database object passed to the constructor.
        _dbrow: cached table row buffer, set lazily by subclasses' write().
    """
    _node = None
    _db = None
    _dbrow = None  # fast pytables writes

    def __init__(self, node=None, db=None):
        self._node = node
        self._db = db

    def viewname(self):
        """Return the node's ``viewname`` attribute (the name a view is looked up by)."""
        # _v_attrs is available on every PyTables node; the ``attrs`` alias
        # exists only on leaves, so group views could not answer otherwise.
        return self._node._v_attrs.viewname

    def viewclass(self):
        """Return the node's ``viewclass`` attribute (name of the view class for it)."""
        return self._node._v_attrs.viewclass

    def isgroup(self):
        """Return True if the wrapped node is a group (it exposes ``_v_nchildren``)."""
        return hasattr(self._node, '_v_nchildren')


class GroupView(View):
    """View over a PyTables group whose children are exposed as attributes.

    Each child node is looked up by the ``viewname`` entry in its
    ``_v_attrs``; its ``viewclass`` entry names a class in this module that
    is instantiated to wrap the child.
    """

    def __len__(self):
        """Return the total number of children in current group"""
        return self._node._v_nchildren

    def __getattr__(self, viewname):
        """Return a view for the child node whose ``viewname`` matches.

        Raises:
            AttributeError: if ``viewname`` is a special (dunder) name or no
                child declares a matching ``viewname``.
        """
        # Special names (e.g. __setstate__ probed by copy/pickle) are never
        # child views; don't walk the HDF5 group for them.
        if viewname.startswith('__') and viewname.endswith('__'):
            raise AttributeError(viewname)
        for childnode in self._node._f_iter_nodes():
            childattrs = childnode._v_attrs
            # Children without view metadata are skipped rather than
            # aborting the search with an unrelated AttributeError.
            if getattr(childattrs, 'viewname', None) == viewname:
                thismodule = sys.modules[__name__]
                viewclass = getattr(thismodule, childattrs.viewclass)
                return viewclass(childnode, self._db)
        raise AttributeError(viewname)

    def __iter__(self):
        raise NotImplementedError('FIXME: iterate over child nodes in group?')

    def __getitem__(self, k):
        raise NotImplementedError('FIXME: get child in group?')

    
class LeafView(View):
    """View over a PyTables table whose rows can be partitioned into splits.

    Attribute access drives a small state machine that mutates and returns
    ``self``: ``view.train`` / ``view.test`` / ``view.validate`` select the
    id subset used by iteration and indexing, and ``view.kfold`` (alias
    ``view.fold``) selects iteration over folds.
    """
    _split = None         # per-fold (train_ids, test_ids, validate_ids) entries
    _id_train = None
    _id_test = None
    _id_validate = None
    _state = 'leaf'       # one of 'leaf', 'kfold', 'train', 'test', 'validate'
    _SUBSETS = ('train', 'test', 'validate')

    def __init__(self, node=None, db=None):
        super(LeafView, self).__init__(node, db)
        if self._node is not None:
            # strategy=None -> viset.partition.leave_zero_out, fold 0 selected.
            self._set_split(strategy=None)
        self._state = 'leaf'

    def __len__(self):
        """Return the number of folds in 'kfold' state, else the table row count."""
        if self._state == 'kfold':
            return len(self._split)
        return self._node.nrows

    def __getattr__(self, name):
        """Switch state for ``kfold``/``fold``/``train``/``test``/``validate``.

        Returns ``self`` so accessors chain (``view.kfold``, ``view.test.train``).

        Raises:
            AttributeError: for any other name, or for ``kfold``/``fold``
                while a train/test/validate subset is selected.
        """
        if name in ('kfold', 'fold') and self._state in ('leaf', 'kfold'):
            # Re-selecting kfold from 'kfold' is allowed (needed by split()).
            self._state = 'kfold'
        elif name in self._SUBSETS:
            self._state = name
        else:
            raise AttributeError(name)
        return self

    def _ids(self):
        """Return the id collection for the current state ('leaf' uses train ids).

        Raises:
            ValueError: if the current state has no id collection.
        """
        if self._state in ('train', 'leaf'):
            return self._id_train
        if self._state == 'test':
            return self._id_test
        if self._state == 'validate':
            return self._id_validate
        raise ValueError('no row ids for view state %r' % (self._state,))

    def __iter__(self):
        """Iterate folds in 'kfold' state, otherwise read each id of the subset.

        In 'kfold' state every item is ``self`` switched to the next fold.
        """
        if self._state == 'kfold':
            return imap(self._set_fold, range(len(self._split)))
        return imap(self.read, iter(self._ids()))

    def __getitem__(self, k):
        """Read the element(s) at position ``k`` of the selected subset.

        ``k`` may be an integer (negative counts from the end of the subset)
        or a slice. A one-element result is returned unwrapped; otherwise a
        list is returned.

        Raises:
            IndexError: in 'kfold' state, or when ``k`` is out of range.
        """
        if self._state == 'kfold':
            raise IndexError('unsupported indexing of kfolds - use iterator instead')
        ids = self._ids()
        if isinstance(k, slice):
            # Resolve the slice against the subset, not the whole table.
            positions = range(*k.indices(len(ids)))
        else:
            positions = [k]
        readlist = [self.read(ids[j]) for j in positions]
        if len(readlist) == 1:
            return readlist[0]
        return readlist

    def __call__(self, *args, **kwargs):
        """Apply a split (same arguments as ``_set_split``) and return self."""
        self._set_split(*args, **kwargs)
        return self

    def _set_split(self, strategy=None, split=None, folds=0, step=1, randomize=False, stratify=False):
        """Install a partition of the table rows and select fold 0.

        Args:
            strategy: 'kfold', 'leave-one-out'/'loo'/'LOO', 'train-test', or
                None/'all'/'leave_zero_out'.
            split: precomputed per-fold partition; overrides ``strategy``.
            folds: number of folds (used by 'kfold' only).
            step, randomize, stratify: forwarded to viset.partition.

        The resulting state is 'leaf' for None/'all'/'leave_zero_out' and
        'kfold' otherwise.

        Raises:
            ValueError: if ``strategy`` is not recognized.
        """
        if split is not None:
            self._split = split
            state = 'kfold'
        else:
            # Partition over table rows; len(self) would report the fold
            # count if a previous split left the view in 'kfold' state.
            nrows = self._node.nrows
            if strategy == 'kfold':
                self._split = viset.partition.kfold(nrows, folds, step, randomize=randomize, stratify=stratify)
                state = 'kfold'
            elif strategy in ('leave-one-out', 'loo', 'LOO'):
                self._split = viset.partition.leave_one_out(nrows, step, randomize=randomize, stratify=stratify)
                state = 'kfold'
            elif strategy == 'train-test':
                self._split = viset.partition.train_test(nrows, step, randomize=randomize, stratify=stratify)
                state = 'kfold'
            elif strategy in (None, 'all', 'leave_zero_out'):
                self._split = viset.partition.leave_zero_out(nrows, step, randomize=randomize, stratify=stratify)
                state = 'leaf'
            else:
                raise ValueError('unsupported splitting strategy %r' % (strategy,))
        self._set_fold(0)
        self._state = state

    def _set_fold(self, k):
        """Select fold ``k`` of the current split, switch to 'kfold' state, return self."""
        fold = self._split[k]
        self._id_train = fold[0]
        self._id_test = fold[1]
        self._id_validate = fold[2]
        self._state = 'kfold'
        return self

    def split(self, *args, **kwargs):
        """Apply a split and return views over it.

        Returns a copy of this view in 'kfold' state when the split has more
        than one fold; otherwise a tuple of copies (train, test), or
        (train, test, validate) when validation ids exist.
        """
        self._set_split(*args, **kwargs)
        if self._state == 'kfold' and len(self._split) > 1:
            return copy.copy(self)
        if self._id_validate is not None:
            return (copy.copy(self.train), copy.copy(self.test), copy.copy(self.validate))
        return (copy.copy(self.train), copy.copy(self.test))

    def read(self, *args, **kwargs):
        """Read one element by id; subclasses must override this."""
        raise NotImplementedError('overloaded by subclass')


class ImageView(LeafView):
    """Leaf view over the ``image`` table, one row per image.

    A row holds the image id, its URL (possibly with a '#'-fragment carrying
    reader/id/item parameters) and the id and count of its annotations.
    """

    class Table(tables.IsDescription):
        id = tables.Int64Col(pos=0)
        url = tables.StringCol(512, pos=1)
        id_anno = tables.Int64Col(pos=2)
        n_anno = tables.Int32Col(pos=3)

    def create(self, dbobj):
        """Create the ``image`` table under the file root and return it."""
        table = dbobj.create_table(
            dbobj.root,
            'image',
            ImageView.Table,
            title='Image',
            filters=tables.Filters(complevel=6),
            expectedrows=1000000,
        )
        table.attrs.viewname = 'image'
        table.attrs.viewclass = 'ImageView'
        table.flush()
        return table

    def write(self, id_img, relpath, n_anno, id_anno, url=None, reader=None, id=None):
        """Append one image row; return ``(id_img, url, n_anno, id_anno)``.

        Non-None ``reader``, ``id`` and ``relpath`` (as key ``item``) are
        url-encoded, then unquoted, and appended to ``url`` after '#'.
        NOTE(review): ``url`` must be a string whenever any of them is given;
        None raises TypeError on concatenation.
        """
        params = {}
        for key, value in (('reader', reader), ('id', id), ('item', relpath)):
            if value is not None:
                params[key] = value
        fragment = urlparse.unquote(urllib.urlencode(params))
        if fragment:
            url = url + '#' + fragment
        if self._dbrow is None:
            # Reuse one row buffer for all appends on this view object.
            self._dbrow = self._node.row
        row = self._dbrow
        row['id'] = id_img
        row['url'] = url
        row['n_anno'] = n_anno
        row['id_anno'] = id_anno
        row.append()
        return (id_img, url, n_anno, id_anno)

    # for updating values use row.update()

    def read(self, id_img):
        """Return the image at row ``id_img``.

        In async mode the CachedImage itself is returned; otherwise its get().
        """
        cached = CachedImage(self._node[id_img]['url'], cache=self._db._cache, verbose=self._db._verbose)
        return cached if self._db._async else cached.get()

    def num_images(self):
        """Return len(self): the table row count (the fold count in 'kfold' state)."""
        return len(self)
        
        

class AnnotationView(GroupView):
    """GroupView subclass that adds no behavior of its own (named for annotation groups)."""


class TaskView(GroupView):
    """GroupView subclass that adds no behavior of its own (named for task groups)."""


    
class CategorizationAnnotationView(LeafView):
    """Leaf view over ``/annotation/categorization``: one category label per row."""

    class Table(tables.IsDescription):
        id_img = tables.Int64Col(pos=0)
        id_category = tables.Int32Col(pos=1)
        category = tables.StringCol(64, pos=2)

    def create(self, dbobj):
        """Create the ``categorization`` table, creating ``/annotation`` if needed."""
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')
        except tables.NodeError:
            # The annotation group already exists; reuse it.
            pass
        else:
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        tbl = dbobj.create_table(dbobj.root.annotation, 'categorization', CategorizationAnnotationView.Table, title='categorization', filters=tables.Filters(complevel=6), expectedrows=1000000)
        tbl.attrs.viewname = 'categorization'
        # Literal class name: GroupView resolves it via getattr on this module.
        tbl.attrs.viewclass = 'CategorizationAnnotationView'
        dbobj.flush()

    def read(self, id_anno):
        """Return ``(image, annotation_row)`` for annotation ``id_anno``.

        The image is a CachedImage in async mode, otherwise its get() result.
        """
        anno = self._node[id_anno]
        imurl = self._db._obj.root.image[anno['id_img']]['url']
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        if self._db._async:
            return (im, anno)  # numpy structured array tuple
        return (im.get(), anno)  # numpy structured array tuple

    def write(self, id_category, category, id_img):
        """Append one categorization row."""
        if self._dbrow is None:
            self._dbrow = self._node.row
        self._dbrow['id_img'] = id_img
        self._dbrow['id_category'] = id_category
        self._dbrow['category'] = category
        self._dbrow.append()

    def delete(self, *args, **kwargs):
        """Remove rows; arguments are forwarded to the table's remove_rows()."""
        self._node.remove_rows(*args, **kwargs)

    def labels(self):
        """Unique category ids in dataset"""
        return np.unique(self._node.cols.id_category[:])

    def num_categories(self):
        """Return the number of unique category ids."""
        return len(self.labels())

    def y(self, id_category):
        """Return a label vector with values in {-1, 1}, one entry per row.

        An entry is +1 where the row's id_category equals ``id_category``.
        """
        category_ids = self._node.cols.id_category[:]
        return 2 * (np.int32(category_ids == id_category)) - 1

    
class DetectionAnnotationView(LeafView):
    """Leaf view over ``/annotation/detection``: one bounding box per row."""
    _viewname = 'Detection'

    class Table(tables.IsDescription):
        id_img = tables.Int64Col(pos=0)
        id_category = tables.Int32Col(pos=1)
        category = tables.StringCol(64, pos=2)
        bbox_xmin = tables.Float32Col(pos=3)
        bbox_xmax = tables.Float32Col(pos=4)
        bbox_ymin = tables.Float32Col(pos=5)
        bbox_ymax = tables.Float32Col(pos=6)

    def create(self, dbobj):
        """Create the ``detection`` table, creating ``/annotation`` if needed."""
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')
        except tables.NodeError:
            # The annotation group already exists; reuse it.
            pass
        else:
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        tbl = dbobj.create_table(dbobj.root.annotation, 'detection', DetectionAnnotationView.Table, title='detection', filters=tables.Filters(complevel=6), expectedrows=1000000)
        tbl.attrs.viewname = 'detection'
        # Literal class name: GroupView resolves it via getattr on this module.
        tbl.attrs.viewclass = 'DetectionAnnotationView'
        dbobj.flush()

    def write(self, id_category, category, id_img, xmin, xmax, ymin, ymax):
        """Append one detection row with its bounding box."""
        if self._dbrow is None:
            self._dbrow = self._node.row
        self._dbrow['id_category'] = id_category
        self._dbrow['category'] = category
        self._dbrow['id_img'] = id_img
        self._dbrow['bbox_xmin'] = xmin
        self._dbrow['bbox_xmax'] = xmax
        self._dbrow['bbox_ymin'] = ymin
        self._dbrow['bbox_ymax'] = ymax
        self._dbrow.append()

    def read(self, id_anno):
        """Return ``(image, annotation_row)`` for annotation ``id_anno``.

        The image is a CachedImage in async mode, otherwise its get() result.
        """
        anno = self._node[id_anno]
        imurl = self._db._obj.root.image[anno['id_img']]['url']
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        if self._db._async:
            return (im, anno)  # numpy structured array tuple
        return (im.get(), anno)  # numpy structured array tuple
        

class SegmentationAnnotationView(LeafView):
    """Leaf view over ``/annotation/segmentation``: one mask URL per row."""

    class Table(tables.IsDescription):
        id_img = tables.Int64Col(pos=0)
        id_category = tables.Int32Col(pos=1)
        category = tables.StringCol(64, pos=2)
        bbox_xmin = tables.Float32Col(pos=3)
        bbox_xmax = tables.Float32Col(pos=4)
        bbox_ymin = tables.Float32Col(pos=5)
        bbox_ymax = tables.Float32Col(pos=6)
        mask = tables.StringCol(512, pos=7)

    def create(self, dbobj):
        """Create the ``segmentation`` table, creating ``/annotation`` if needed."""
        try:
            group = dbobj.create_group(dbobj.root, 'annotation', 'Annotation')
        except tables.NodeError:
            # The annotation group already exists; reuse it.
            pass
        else:
            group._v_attrs.viewname = 'annotation'
            group._v_attrs.viewclass = 'GroupView'
        tbl = dbobj.create_table(dbobj.root.annotation, 'segmentation', SegmentationAnnotationView.Table, title='segmentation', filters=tables.Filters(complevel=6), expectedrows=1000000)
        tbl.attrs.viewname = 'segmentation'
        # Literal class name: GroupView resolves it via getattr on this module.
        tbl.attrs.viewclass = 'SegmentationAnnotationView'
        dbobj.flush()

    def write(self, relpath, id_img, category=None, url=None):
        """Append one segmentation row whose ``mask`` column is the mask URL.

        A non-None ``relpath`` is appended to ``url`` as a '#item=' fragment.
        The id_category and bbox_* columns are left at their column defaults.
        NOTE(review): ``url`` must be a string when ``relpath`` is given, and
        category=None is assigned as-is to a string column; confirm intent.
        """
        qs = {}
        if relpath is not None:
            qs['item'] = relpath
        url_fragment = urlparse.unquote(urllib.urlencode(qs))
        if len(url_fragment) > 0:
            url = url + '#' + url_fragment
        if self._dbrow is None:
            self._dbrow = self._node.row
        self._dbrow['mask'] = url
        self._dbrow['id_img'] = id_img
        self._dbrow['category'] = category
        self._dbrow.append()

    def read(self, id_anno):
        """Return ``(image, anno)`` where anno has id_category, category and mask.

        In async mode the image and mask are CachedImage objects; otherwise
        their get() results. ``mask`` is None when the row has no mask URL.
        """
        row = self._node[id_anno]
        imurl = self._db._obj.root.image[row['id_img']]['url']
        im = CachedImage(imurl, cache=self._db._cache, verbose=self._db._verbose)
        maskurl = row['mask']
        # `not` treats both '' and b'' (bytes-typed string columns) as empty.
        if not maskurl:
            mask = None
        else:
            mask = CachedImage(maskurl, cache=self._db._cache, verbose=self._db._verbose)
        if self._db._async:
            anno = {'id_category': row['id_category'], 'category': row['category'], 'mask': mask}
            return (im, anno)
        if mask is not None:
            mask = mask.get()
        anno = {'id_category': row['id_category'], 'category': row['category'], 'mask': mask}
        return (im.get(), anno)

