"""Classes to create and manage vision datasets"""

import tables
import os
from os import path
import shutil
import viset.download
import viset.util
from viset.util import ishdf5, isfile, isarchive, quietprint
import numpy as np 
import urllib
import socket
import hashlib
import time
import random
from itertools import imap, islice, count
import viset.types
import urlparse


class AboutTable(tables.IsDescription):
    """Row schema of the `about` table: dataset-level metadata.

    Viset.format writes a single row of this type; Viset.open, name(),
    version() and dbimtype() read row 0.
    """
    # NOTE(review): `viset` and `name` both declare pos=1 - confirm the
    # resulting column order is the intended one.
    viset = tables.StringCol(256, pos=1)   # signature URL; Viset.open rejects files where this is not a URL
    version = tables.Int64Col(pos=0)   # dataset version number
    name = tables.StringCol(256, pos=1)   # dataset name; also used as the cache subdirectory name
    imtype = tables.StringCol(256, pos=2)  # hint for binary decode

class PackageTable(tables.IsDescription):
    """Row schema of the `packages` table: downloadable files of a dataset.

    Rows are written by Viset.add_package and consumed by Viset.fetch.
    """
    # NOTE(review): all three columns declare pos=1 - confirm the resulting
    # column order is the intended one.
    url = tables.StringCol(512, pos=1)   # remote location of the package
    sha1 = tables.StringCol(512, pos=1)     # checksum handed to the downloader as `sha1=`
    path = tables.StringCol(512, pos=1)   # path relative to the cache directory; fetch skips the download if it exists

class SplitTable(tables.IsDescription):
    """Row schema of the `splits` table: ten unsigned 8-bit columns.

    The table is created by Viset.format but nothing in this file reads or
    writes it (Viset.split raises a TODO error for stored splits).
    NOTE(review): presumably one column per stored split assignment - the
    meaning of the stored values is not defined here; confirm before use.
    """
    split0 = tables.UInt8Col(pos=0)   
    split1 = tables.UInt8Col(pos=1)   
    split2 = tables.UInt8Col(pos=2)   
    split3 = tables.UInt8Col(pos=3)   
    split4 = tables.UInt8Col(pos=4)   
    split5 = tables.UInt8Col(pos=5)   
    split6 = tables.UInt8Col(pos=6)   
    split7 = tables.UInt8Col(pos=7)   
    split8 = tables.UInt8Col(pos=8)   
    split9 = tables.UInt8Col(pos=9)                     

  
class ImageTable(tables.IsDescription):
    """Row schema of the `images` table; rows are written by Viset.add_image."""
    id = tables.Int64Col(pos=0)   # image id; Viset.image uses it as a row index into this table
    url = tables.StringCol(512, pos=1)   # image location passed to viset.types.Image as `uri`
    id_anno = tables.Int64Col(pos=2)   # presumably the id of the first annotation of this image - TODO confirm
    n_anno = tables.Int32Col(pos=3)   # presumably the number of annotations of this image - TODO confirm
  
class CategorizationAnnotationTable(tables.IsDescription):
    """Row schema of `annotation/categorization`: one category label per row.

    Rows are written by Viset.add_categorization.
    """
    id = tables.Int64Col(pos=0)  # NOTE(review): never assigned by add_categorization - confirm the default is intended
    id_img = tables.Int64Col(pos=1)  # used by Viset.read as a row index into the images table
    id_category = tables.Int32Col(pos=2)  # numeric category id (see Viset.labels and Viset.y)
    category = tables.StringCol(256, pos=3)      # category name

class DetectionAnnotationTable(tables.IsDescription):
    """Row schema of `annotation/detection`: a category label plus bounding box.

    Rows are written by Viset.add_detection.
    """
    id = tables.Int64Col(pos=0)   # NOTE(review): never assigned by add_detection - confirm the default is intended
    id_img = tables.Int64Col(pos=1)  # used by Viset.read as a row index into the images table
    id_category = tables.Int32Col(pos=2)  # numeric category id (see Viset.labels and Viset.y)
    category = tables.StringCol(256, pos=3)        # category name
    # Bounding box extents; units and origin are not defined in this file.
    bbox_xmin = tables.Float32Col(pos=4)
    bbox_xmax = tables.Float32Col(pos=5)
    bbox_ymin = tables.Float32Col(pos=6)
    bbox_ymax = tables.Float32Col(pos=7)    
  
class Viset(object):
    """Vision dataset stored in an HDF5 database with an on-disk cache.

    The database holds `about`, `packages`, `splits` and `images` tables plus
    one table per annotation type under `/annotation` (see format()).
    Instances either open an existing database (pass `dbfile` to __init__)
    or are created empty so that a subclass can call create().

    The attributes below are class-level defaults; __init__ overrides the
    relevant ones per instance.
    """
    cacheroot = viset.util.cacheroot()  # evaluated once, when the class is defined
    cachedir = None  # per-dataset cache directory: <cacheroot>/<dataset name>
    dbname = None
    dbversion = 0
    dburl = None   # remote file
    dbfile = None  # local file in cache
    dbtype = None    
    # NOTE(review): this attribute is replaced by the dbimtype() method
    # defined later in this class body - confirm which one is intended.
    dbimtype = None        
    dbobj = None  # database object
    dbtask = None  # annotation table selected by set_task()
    dbtaskname = None
    dbiter = None     # result of split()
    dbasyncread = True # asynchronous dataset read
    dbverbose = False
    dbannorow = None  # fast writes
    dbimrow = None    # fast writes
    dbpkgrow = None    # fast writes    
    dbimread = None
    
    def __init__(self, dbfile=dbfile, task=None, strategy=None, kfold=0, datastep=1, asyncread=True, verbose=False):
        """Open a dataset database and prepare it for iteration.

        dbfile    -- database filename or URL; None leaves the instance
                     uninitialized so that a subclass can call create()
        task      -- annotation type to select (None selects the first one)
        strategy, kfold, datastep -- forwarded to split()
        asyncread -- forwarded to split() and stored in `dbasyncread`
        verbose   -- stored in `dbverbose`, which controls progress output
        """
        if dbfile is None:
            # empty initialization for subclass calling create
            return

        # Store the flags first: open() consults self.dbverbose for logging
        # and downloading, so assigning it afterwards silently ignored the
        # caller's `verbose` argument during open.
        self.dbverbose = verbose
        self.dbasyncread = asyncread

        (self.dbobj, self.dbname, self.dbfile, self.dburl, self.cachedir) = self.open(dbfile)
        (self.dbtask, self.dbtaskname) = self.set_task(task)
        self.dbiter = self.split(strategy=strategy, kfold=kfold, datastep=datastep, asyncread=asyncread)
        self.fetch()
        
    def __len__(self):
        """Number of rows in the annotation table of the selected task."""
        task_table = self.dbtask
        return task_table.nrows
    
    def __getitem__(self, k):
        """Index into the split container built at construction time."""
        splits = self.dbiter
        return splits[k]
      
    def __del__(self):
        """Close the database on destruction, if one was ever opened."""
        if self.dbobj is None:
            return
        self.close()
        
    def __str__(self):
        """Return the result of summary(), which prints the report itself."""
        text = self.summary()
        return text

    def write(self):
        """Flush the database object."""
        database = self.dbobj
        database.flush()
            
    def summary(self):
        """Print a human-readable report about the database and return ''.

        The report is written to stdout line by line; the return value is
        always the empty string.
        """
        print('Database file: ' + self.dbfile)
        print('Database version: ' + str(self.version()))
        print('Cache root directory: ' + self.cacheroot)
        print('Cache directory: ' + self.cachedir)
        print("Image storage type: '" + self.dbimtype() + "'")
        print('Database annotations: ' + str(self.annotypes()))
        print('Dataset name: ' + str(self.name()))
        print("Dataset annotation: '" + self.annotype() + "'")
        print('Number of images: ' + str(self.num_images()))
        print('Number of annotations: ' + str(self.num_annotations()))
        print('HDF5 database file structure: ')
        # Indent every line of the PyTables file description by two spaces
        print('  ' + str(self.dbobj).replace("\n", "\n  "))
        return ''

    def add_package(self, URL, SHA1, PATH):
        """Append one (url, sha1, path) record to the packages table."""
        row = self.dbpkgrow
        if row is None:
            # Cache the row handle on first use (fast writes)
            row = self.dbpkgrow = self.dbobj.root.packages.row
        for (column, value) in (('url', URL), ('sha1', SHA1), ('path', PATH)):
            row[column] = value
        row.append()

    def add_image(self, id_img, url, n_anno, id_anno):
        """Append one record to the images table."""
        row = self.dbimrow
        if row is None:
            # Cache the row handle on first use (fast writes)
            row = self.dbimrow = self.dbobj.root.images.row
        row['id'] = id_img
        row['url'] = url
        row['n_anno'] = n_anno
        row['id_anno'] = id_anno
        row.append()
        
    def add_categorization(self, id_category, category, id_img):
        """Append one record to the categorization annotation table.

        NOTE(review): the cached row handle `dbannorow` is shared with
        add_detection; whichever method runs first decides which table the
        handle belongs to - confirm both are never mixed on one instance.
        """
        row = self.dbannorow
        if row is None:
            # Cache the row handle on first use (fast writes)
            row = self.dbannorow = self.dbobj.root.annotation.categorization.row
        row['id_category'] = id_category
        row['category'] = category
        row['id_img'] = id_img
        row.append()
        
    def add_detection(self, id_category, category, id_img, xmin, xmax, ymin, ymax):
        """Append one record (label plus bounding box) to the detection table.

        NOTE(review): the cached row handle `dbannorow` is shared with
        add_categorization; whichever method runs first decides which table
        the handle belongs to - confirm both are never mixed on one instance.
        """
        row = self.dbannorow
        if row is None:
            # Cache the row handle on first use (fast writes)
            row = self.dbannorow = self.dbobj.root.annotation.detection.row
        row['id_category'] = id_category
        row['category'] = category
        row['id_img'] = id_img
        row['bbox_xmin'] = xmin
        row['bbox_xmax'] = xmax
        row['bbox_ymin'] = ymin
        row['bbox_ymax'] = ymax
        row.append()
        
    def split(self, strategy = None, kfold = None, datastep = 1, asyncread=True):
        """Dataset cross validation split.

        strategy  -- None: every `datastep`-th annotation in stored order;
                     'randomize': the same subset in shuffled order;
                     'train_test_shuffle': shuffled subset cut into two halves;
                     'kfold': `kfold` shuffled train/test folds
        kfold     -- number of folds, required (>= 1) for strategy 'kfold'
        datastep  -- stride used to subsample the annotation indices
        asyncread -- exactly False selects read(); anything else selects
                     readasync()

        Returns a one-element list of iterators (None, 'randomize'), a dict
        with keys 'train', 'test', 'validate' ('train_test_shuffle'), or a
        list of such dicts, one per fold ('kfold').  The iterators lazily
        yield the tuples produced by read()/readasync().

        Raises ValueError for an invalid `kfold` and IOError for any other
        strategy (stored splits are not implemented).
        """
        if asyncread is False:
            readfunc = self.read
        else:
            readfunc = self.readasync

        def shuffled_ids():
            # Subsampled annotation indices in random order
            ids = list(range(self.num_annotations()))[::datastep]
            random.shuffle(ids)
            return ids

        if strategy is None:
            itrain = imap(readfunc, islice(count(), 0, self.num_annotations(), datastep))
            return [itrain]

        # Strategy names are compared by value: `is` on a string only works
        # by accident of interning.
        if strategy == 'randomize':
            return [imap(readfunc, shuffled_ids())]

        if strategy == 'train_test_shuffle':
            ids = shuffled_ids()
            half = len(ids) // 2
            return {'train': imap(readfunc, ids[:half]),
                    'test': imap(readfunc, ids[half:]),
                    'validate': None}

        if strategy == 'kfold':
            if kfold is None or kfold < 1:
                raise ValueError('kfold must be a positive integer for strategy "kfold"')
            ids = shuffled_ids()
            foldsize = len(ids) // kfold
            folds = []
            for k in range(kfold):
                u = k*foldsize
                # The last fold absorbs the remainder so that every sample
                # is tested exactly once.
                v = len(ids) if k == kfold - 1 else u + foldsize
                # Build the training ids as a new list; deleting from an
                # alias of `ids` would corrupt all following folds.
                id_train = ids[:u] + ids[v:]
                id_test = ids[u:v]
                folds.append({'train': imap(readfunc, id_train),
                              'test': imap(readfunc, id_test),
                              'validate': None})
            return folds

        raise IOError('TODO: retrieve stored splits in database' )

    
    def open(self, dbfile):
        """Locate, optionally download, and open a dataset database read-only.

        `dbfile` is resolved in this order: an existing local file, a file of
        the same basename inside the cache root, or a URL that is downloaded
        into the cache root.

        Returns the tuple (dbobj, dbname, dbfile, dburl, cachedir); `dburl`
        is None unless the database was downloaded.  The per-dataset cache
        directory is created if missing.

        Raises IOError if `dbfile` is not a string, lacks an HDF5 extension,
        cannot be located, or does not carry the viset signature; errors
        raised while opening the file propagate after a diagnostic print.
        """
        # Database location resolution
        if not isinstance(dbfile, str):
            raise IOError('database filename must be a string')
        if not ishdf5(dbfile):
            raise IOError('invalid database file extension - must be HDF5 (.h5)')

        cached = path.join(self.cacheroot, path.basename(dbfile))
        dburl = None
        if isfile(dbfile):
            pass
        elif isfile(cached):
            dbfile = cached
        elif viset.util.isurl(dbfile):
            # Download next to the other cached databases; the per-dataset
            # cache directory is unknown until the file has been opened.
            dburl = dbfile
            dbfile = cached
            viset.download.download(dburl, dbfile, verbose=self.dbverbose)
        else:
            raise IOError('database file "' + dbfile + '" not found locally or in cache, and is not a URL')

        # Open in read mode
        quietprint('Opening viset "' + dbfile + '"', self.dbverbose)
        dbobj = None
        try:
            dbobj = tables.open_file(dbfile, mode = "r")
            if not viset.util.isurl(dbobj.root.about[0]['viset']):
                raise IOError('missing viset signature in "about" table')
        except Exception:
            # Do not leak the HDF5 handle of a file we are about to reject
            if dbobj is not None:
                dbobj.close()
            print('Invalid database file "' + str(dbfile) + '"')
            raise

        # Initialize cache
        dbname = dbobj.root.about[0]['name']
        cachedir = path.join(self.cacheroot, dbname)
        if not path.exists(cachedir):
            os.makedirs(cachedir)

        # Return tuple
        return (dbobj, dbname, dbfile, dburl, cachedir)

    
    def create(self, dbname, dbtype, dbimtype, dbversion):
        """Create the cache directory and a freshly formatted database.

        The database file is `<cacheroot>/<dbname>.h5`.  Returns the tuple
        (dbobj, dbname, dbfile, dburl, cachedir) with `dburl` set to None.
        """
        root = self.cacheroot

        # Per-dataset cache directory
        cachedir = path.join(root, dbname)
        if not path.exists(cachedir):
            os.makedirs(cachedir)

        # New, empty database written next to the cache directory
        dbfile = path.join(root, dbname+'.h5')
        dbobj = self.format(dbfile, dbname, dbtype, dbimtype, dbversion)
        dbobj.flush()

        return (dbobj, dbname, dbfile, None, cachedir)
        

    def close(self):
        """Close the database object."""
        database = self.dbobj
        database.close()

    def set_task(self, task):
        """Select an annotation table by its `annotype` attribute.

        Returns (node, annotype): the first annotation node when `task` is
        None, otherwise the first node whose annotype equals `task`.
        Raises IOError when no node qualifies.
        """
        for node in self.dbobj.root.annotation:
            if task is None:
                # No preference: the first annotation table wins
                return (node, node.attrs.annotype)
            if node.attrs.annotype == task:
                return (node, task)
        raise IOError('Invalid dataset task "' + str(task) + '"')
        
    def delete(self):
        """Remove the dataset's cache directory and everything inside it."""
        message = 'Deleting all cached data in "' + self.cachedir + '" for dataset "' + self.dbfile + '"'
        quietprint(message, self.dbverbose)
        shutil.rmtree(self.cachedir)

    def name(self):
        """Dataset name stored in row 0 of the `about` table."""
        about_row = self.dbobj.root.about[0]
        return about_row['name']
  
    def version(self):
        """Dataset version stored in row 0 of the `about` table."""
        about_row = self.dbobj.root.about[0]
        return about_row['version']

    def dbimtype(self):
        """Image storage type stored in row 0 of the `about` table."""
        about_row = self.dbobj.root.about[0]
        return about_row['imtype']
  
    def about(self):
        """Row 0 of the `about` table (all dataset metadata fields)."""
        about_table = self.dbobj.root.about
        return about_table[0]

    def labels(self):
        """Unique category ids of the selected annotation table."""
        category_ids = self.dbtask.cols.id_category[:]
        return np.unique(category_ids)

    def num_categories(self):
        """Number of distinct category ids, as reported by labels()."""
        unique_ids = self.labels()
        return len(unique_ids)
  
    def num_annotations(self):
        """Row count of the selected annotation table."""
        task_table = self.dbtask
        return task_table.nrows

    def annotype(self):
        """The `annotype` attribute of the selected annotation table."""
        attributes = self.dbtask.attrs
        return attributes.annotype
  
    def num_annotypes(self):
        """Count of annotation types, as reported by annotypes()."""
        available = self.annotypes()
        return len(available)

    def annotypes(self):
        """The `annotype` attribute of every node under `/annotation`, as a list."""
        return [node.attrs.annotype for node in self.dbobj.root.annotation]

    def num_images(self):
        """Row count of the `images` table."""
        image_table = self.dbobj.root.images
        return image_table.nrows
  
    def y(self, id_category):
        """Binary label vector for one category id.

        Entries are +1 where the annotation's `id_category` equals the given
        id and -1 elsewhere, one entry per row of the selected annotation
        table.
        """
        category_ids = self.dbtask.cols.id_category[:]
        is_match = (category_ids == id_category)
        return 2*np.int32(is_match) - 1

  
    def image(self, im):
        """Build the image object for one row of the images table.

        `im` must provide an 'id' entry, which is used as a row index into
        the images table.  The kind of object depends on the database image
        storage type; an unknown type raises IOError.
        """
        index = im['id']
        imtype = self.dbimtype()

        # 'package' and 'url' databases are handled the same way: the stored
        # url is handed to viset.types.Image.
        if imtype in ('package', 'url'):
            uri = self.dbobj.root.images[index]['url']
            return viset.types.Image(uri=uri, cacheroot=self.cacheroot, verbose=self.dbverbose)

        if imtype == 'custom_mnist':
            return viset.types.CustomImage(cachedir=self.cachedir, index=index, verbose=self.dbverbose, imreadfunc=viset.mnist.imread)

        raise IOError("undefined database image type "+str(imtype))

    def query(self, table = None, row = 0, field = None):
        """Return one row, or one field of a row, from a database node.

        table -- node path inside the database; when None, `table`, `row` and
                 `field` are taken from the query string of the database URL
        row   -- row index inside the node
        field -- optional field of the row: an integer position or a column
                 name; None returns the whole row

        Returns None when no table is given and none can be derived from the
        database URL.
        """
        if table is None:
            # Fall back to the URL recorded when the database was opened.
            # `dburi` is honored if some subclass defines it; the attribute
            # maintained by this class is `dburl`.
            uri = getattr(self, 'dburi', None) or self.dburl
            if uri is None:
                return None
            params = urlparse.parse_qs(urlparse.urlsplit(uri).query)
            if 'table' in params:
                table = params['table'][0]
            if 'row' in params:
                row = int(params['row'][0])
            if 'field' in params:
                field = params['field'][0]
                try:
                    field = int(field)
                except ValueError:
                    pass  # not a position: keep it as a column name

        if table is None:
            return None

        # get_node is the PyTables 3.x spelling, matching open_file and
        # create_table used elsewhere in this file.
        record = self.dbobj.get_node(table)[row]
        if field is None:
            return record
        return record[field]
   
    def read(self, id_anno):
        """Blocking read of one annotation and its image.

        Returns (im, anno) where `im` is the result of calling .get() on the
        image object and `anno` is the annotation row.
        """
        task_table = self.dbtask
        image_row = self.dbobj.root.images[task_table[id_anno]['id_img']]
        im = self.image(image_row).get()
        anno = task_table[id_anno]
        return (im,anno)  # numpy structured array tuple

    def readasync(self, id_anno):
        """Read one annotation and its image object without calling .get().

        Returns (im, anno) where `im` is the image object itself and `anno`
        is the annotation row.
        """
        task_table = self.dbtask
        image_row = self.dbobj.root.images[task_table[id_anno]['id_img']]
        im = self.image(image_row)  # no .get()
        anno = task_table[id_anno]
        return (im,anno)  # numpy structured array tuple
    
    def format(self, dbfile, dbname, dbtype=[], dbimtype='url', dbversion=0):
        """Create a new, empty database file with the viset table layout.

        dbfile    -- path of the HDF5 file to create (opened in write mode)
        dbname    -- dataset name; used as file title and stored in `about`
        dbtype    -- annotation type or list of types ('Categorization',
                     'Detection'); one annotation table is created per type.
                     The default list is never mutated here.
        dbimtype  -- image storage type stored in `about`
        dbversion -- dataset version stored in `about`

        Returns the open database object.  Raises IOError for an unknown
        annotation type; the file handle is closed before any error
        propagates.
        """
        # annotation type -> (node name, row schema)
        annotables = {
            'Categorization': ('categorization', CategorizationAnnotationTable),
            'Detection': ('detection', DetectionAnnotationTable),
        }

        # HDF5 file structure
        db = tables.open_file(dbfile, mode = "w", title = dbname)
        try:
            tbl_about = db.create_table(db.root, 'about', AboutTable, expectedrows=1, title='About')
            db.create_table(db.root, 'packages', PackageTable, title='Packages', expectedrows=10)
            db.create_table(db.root, 'splits', SplitTable, title='Splits', expectedrows=1000000)
            db.create_table(db.root, 'images', ImageTable, title='Images', filters=tables.Filters(complevel=1), expectedrows=1000000)

            # Annotation tables
            group = db.create_group(db.root, 'annotation', 'Annotation')
            for annotype in viset.util.tolist(dbtype):
                if annotype not in annotables:
                    raise IOError('Invalid annotation type ' + str(annotype))
                (nodename, schema) = annotables[annotype]
                tbl_anno = db.create_table(group, nodename, schema, title=annotype, filters=tables.Filters(complevel=1), expectedrows=1000000)
                tbl_anno.attrs.annotype = annotype
                tbl_anno.flush()

            # Metadata
            r = tbl_about.row
            r['viset'] = 'http://www.visym.com'
            r['version'] = dbversion
            r['name'] = dbname
            r['imtype'] = dbimtype
            r.append()

            # Write
            db.flush()
        except Exception:
            # Do not leave a write-mode handle open on a half-built file
            db.close()
            raise

        # Return newly created database object
        return db

    def fetch(self):
        """Download every package whose target path is missing from the cache.

        Packages with an archive URL are downloaded and extracted into the
        cache directory; all others are downloaded to their stored path
        inside the cache directory.
        """
        package_table = self.dbobj.root.packages
        if package_table.nrows <= 0:
            return

        for pkg in package_table.iterrows():
            target = path.join(self.cachedir, pkg['path'])
            if path.exists(target):
                continue  # already cached
            if isarchive(pkg['url']):
                viset.download.download_and_extract(pkg['url'], self.cachedir, sha1=pkg['sha1'], verbose=self.dbverbose)
            else:
                viset.download.download(pkg['url'], target, sha1=pkg['sha1'], verbose=self.dbverbose)



class BlockingViset(Viset):
    """Viset whose class-level `dbasyncread` default is False.

    NOTE(review): Viset.__init__ assigns `self.dbasyncread = asyncread`
    (default True) whenever a database file is given, and split() is driven
    by the `asyncread` argument - confirm this override has the intended
    effect.
    """
    dbasyncread = False
                        
class CategorizationViset(Viset):
    """Viset whose `dbtype` is 'Categorization' (an annotation type accepted by Viset.format)."""
    dbtype = 'Categorization'

class DetectionViset(Viset):
    """Viset whose `dbtype` is 'Detection' (an annotation type accepted by Viset.format)."""
    dbtype = 'Detection'
    
class ImageArchiveViset(Viset):
    """Viset declaring the 'package' image storage type.

    NOTE(review): Viset also defines a method named `dbimtype`; this string
    attribute replaces it on the subclass, so calls such as
    `self.dbimtype()` in Viset.image would hit a str - confirm and resolve.
    """
    dbimtype = 'package'

class UrlViset(Viset):
    """Viset declaring the 'url' image storage type.

    NOTE(review): Viset also defines a method named `dbimtype`; this string
    attribute replaces it on the subclass, so calls such as
    `self.dbimtype()` in Viset.image would hit a str - confirm and resolve.
    """
    dbimtype = 'url'    

class CustomImageViset(Viset):
    """Viset declaring the 'custom' image storage type.

    NOTE(review): Viset.image only recognizes 'package', 'url' and
    'custom_mnist', and this string attribute replaces the inherited
    `dbimtype()` method - confirm how subclasses are expected to use it.
    """
    dbimtype = 'custom'
    readfunc = None  # no read function by default; not referenced elsewhere in this file


