"""Classes to create and manage views in vision datasets"""

import tables
import os
from os import path
import shutil
import viset.download
import viset.util
from viset.util import ishdf5, isfile, isarchive, quietprint
import numpy as np 
import urllib
import socket
import hashlib
import time
import random
from itertools import imap, islice, count
import viset.types
import urlparse


def bind(viewname, db):
    """Instantiate the dataset view registered under ``viewname``.

    Args:
        viewname: one of 'Categorization', 'Detection' or 'Segmentation'.
        db: viset database object handed to the view's constructor.

    Returns:
        A new view object of the matching class from ``viset.views``.

    Raises:
        IOError: if ``viewname`` is not one of the supported names.
    """
    # (label, class attribute in viset.views), checked in this order.
    registry = (
        ('Categorization', 'CategorizationView'),
        ('Detection', 'DetectionView'),
        ('Segmentation', 'SegmentationView'),
    )
    for label, class_attr in registry:
        if viewname == label:
            view_class = getattr(viset.views, class_attr)
            return view_class(db)
    raise IOError('Unknown database view \'' + str(viewname) + '\'')


class AnnotationTable(tables.IsDescription):
    """PyTables row description shared by every annotation view table.

    Each view's nested ``Table`` description subclasses this, so every
    annotation row carries ``id_img`` at column position 0; view-specific
    columns are declared from ``pos=1`` onwards.
    """
    id_img = tables.Int64Col(pos=0)  # index of the annotated row in root.images


class VisetView(object):
    db = None   # viset object
    dbtbl = None   # annotation table object for view
    viewiter = None
    async = None
    
    _viewname = None  # overloaded
    _dbrow = None  # fast pytables writes
    
    def __init__(self, db, async=True):
        self.db = db
        self.async = async
        pass

    def set(self, async=True):
        self.async = async
        return self
    
    def open(self):
        """Open table name to table object in annotation group"""
        dbtbl = None
        for node in self.db.dbobj.root.annotation:
            if node.attrs.name == self._viewname:
                dbtbl = node
                break
        if dbtbl is None:
            raise IOError('Invalid dataset view "' + str(self._viewname) + '"')
        self.dbtbl = dbtbl
        return self
        
    def num_annotations(self):
        """Number of annotations in current view"""    
        return self.dbtbl.nrows

    def name(self):        
        """Name of current view"""
        return self.dbtbl.attrs.name
    
    def split(self, strategy=None, kfold=None, datastep=1):
        """Dataset cross validation split"""

        readfunc = self.read
        if strategy is None:
            itrain = imap(readfunc, islice(count(), 0, self.num_annotations(), datastep))
            viewiter = [itrain]
        elif strategy is 'randomize':
            id_split = range(self.num_annotations())
            id_split = id_split[::datastep]
            random.shuffle(id_split)
            itrain = imap(readfunc, id_split)
            viewiter = [itrain]
        elif strategy == 'train_test_shuffle':
            id_split = range(self.num_annotations())
            id_split = id_split[::datastep]
            random.shuffle(id_split)
            itrain = imap(readfunc, (id_split[:len(id_split)/2]))
            itest = imap(readfunc, (id_split[len(id_split)/2:]))      
            viewiter = {'train':itrain, 'test':itest, 'validate':None}
        elif strategy == 'kfold':            
            id_split = range(self.num_annotations())  
            id_split = id_split[::datastep]
            random.shuffle(id_split)
            foldsize = len(id_split)/kfold
            viewiter = [None]*kfold  # empty list
            for k in range(0,kfold-1):
                u = k*foldsize
                v = u + foldsize        
                id_test = id_split[u:v]
                id_train = id_split
                id_train[u:v] = []
                itrain = imap(readfunc, id_train)
                itest = imap(readfunc, id_test)
                viewiter[k] =  {'train':itrain, 'test':itest, 'validate':None}
        else:
            raise IOError('TODO: retrieve stored splits in database' )
        self.viewiter = viewiter
        return self.viewiter
    
class CategorizationView(VisetView):
    _viewname = 'Categorization'
    
    class Table(AnnotationTable, tables.IsDescription):
        id_category = tables.Int32Col(pos=1)  
        category = tables.StringCol(64, pos=2)      

    def create(self,group):
        tbl = self.db.dbobj.create_table(group, self._viewname.lower(), CategorizationView.Table, title=self._viewname, filters=tables.Filters(complevel=6), expectedrows=1000000)        
        tbl.attrs.name = self._viewname
        tbl.flush()
        return tbl

    def num_categories(self):
        return len(self.labels())
    
    def labels(self):
        """Unique category ids in dataset"""        
        return np.unique(self.dbtbl.cols.id_category[:])

    def read(self, id_anno):
        id_img = self.dbtbl[id_anno]['id_img']
        imurl = self.db.dbobj.root.images[id_img]['url']
        im = viset.types.Image(imurl, cache=self.db.dbcache, verbose=self.db.verbose)
        anno = self.dbtbl[id_anno]
        if self.async:
            return (im,anno)  # numpy structured array tuple
        else:
            return (im.get(),anno)  # numpy structured array tuple
        
    def add_categorization(self, id_category, category, id_img):
        if self._dbrow is None:
            self._dbrow = self.db.dbobj.root.annotation.categorization.row
        self._dbrow['id_img'] = id_img        
        self._dbrow['id_category'] = id_category
        self._dbrow['category'] = category
        self._dbrow.append()

    def y(self, id_category):
        """Return a binary label vector y \in {-1,1} for dataset elements of a given category id"""        
        id = self.dbtask.cols.id_category[:]
        return 2*(np.int32(id == id_category)) - 1

    
class DetectionView(VisetView):
    _viewname = 'Detection'
    
    class Table(AnnotationTable, tables.IsDescription):
        id_category = tables.Int32Col(pos=1)  
        category = tables.StringCol(64, pos=2)      
        bbox_xmin = tables.Float32Col(pos=3)
        bbox_xmax = tables.Float32Col(pos=4)
        bbox_ymin = tables.Float32Col(pos=5)
        bbox_ymax = tables.Float32Col(pos=6)    
        
    def create(self,group):
        tbl = self.db.dbobj.create_table(group, self._viewname.lower(), DetectionView.Table, title=self._viewname, filters=tables.Filters(complevel=6), expectedrows=1000000)        
        tbl.attrs.name = self._viewname
        tbl.flush()
        return tbl

    
    def add_detection(self, id_category, category, id_img, xmin, xmax, ymin, ymax):
        if self._dbrow is None:
            self._dbrow = self.db.dbobj.root.annotation.detection.row
        self._dbrow['id_category'] = id_category
        self._dbrow['category'] = category
        self._dbrow['id_img'] = id_img        
        self._dbrow['bbox_xmin'] = xmin
        self._dbrow['bbox_xmax'] = xmax
        self._dbrow['bbox_ymin'] = ymin
        self._dbrow['bbox_ymax'] = ymax                        
        self._dbrow.append()

    def read(self, id_anno):
        id_img = self.dbtbl[id_anno]['id_img']
        impath = self.db.dbobj.root.images[id_img]['url']
        im = viset.types.Image(impath, cache=self.db.dbcache, verbose=self.db.verbose)
        anno = self.dbtbl[id_anno]
        if self.async:
            return (im,anno)  # numpy structured array tuple
        else:
            return (im.get(),anno)  # numpy structured array tuple
        

class SegmentationView(VisetView):
    _viewname = 'Segmentation'

    class Table(AnnotationTable, tables.IsDescription):
        id_category = tables.Int32Col(pos=1)  
        category = tables.StringCol(64, pos=2)      
        bbox_xmin = tables.Float32Col(pos=3)
        bbox_xmax = tables.Float32Col(pos=4)
        bbox_ymin = tables.Float32Col(pos=5)
        bbox_ymax = tables.Float32Col(pos=6)    
        mask = tables.StringCol(512, pos=7)   

    def create(self,group):
        tbl = self.db.dbobj.create_table(group, self._viewname.lower(), SegmentationView.Table, title=self._viewname, filters=tables.Filters(complevel=6), expectedrows=1000000)        
        tbl.attrs.name = self._viewname
        tbl.flush()
        return tbl
        
    def add_segmentation(self, relpath, id_img, category=None, url=None):
        qs = {}
        if relpath is not None:
            qs['item'] = relpath
        url_fragment = urlparse.unquote(urllib.urlencode(qs))
        if len(url_fragment) > 0:
            url = url + '#' + url_fragment
        if self._dbrow is None:
            self._dbrow = self.db.dbobj.root.annotation.segmentation.row
        self._dbrow['mask'] = url
        self._dbrow['id_img'] = id_img        
        self._dbrow['category'] = category  
        self._dbrow.append()
    
    def read(self, id_anno):
        id_img = self.dbtbl[id_anno]['id_img']
        imurl = self.db.dbobj.root.images[id_img]['url']	
        im = viset.types.Image(imurl, cache=self.db.dbcache, verbose=self.db.verbose)
        maskurl = self.dbtbl[id_anno]['mask']
        if maskurl == '':
			mask = None
        else:
			mask = viset.types.Image(maskurl, cache=self.db.dbcache, verbose=self.db.verbose)
        anno = self.dbtbl[id_anno]
        if self.async:
            anno = {'id_category':anno['id_category'], 'category':anno['category'], 'mask':mask}            
            return (im,anno)  # numpy structured array tuple
        else:
            if mask is not None:
                mask = mask.get()
            anno = {'id_category':anno['id_category'], 'category':anno['category'], 'mask':mask}                        
            return (im.get(),anno)  # numpy structured array tuple
