"""Classes to create and manage vision datasets"""

import tables
import os
from os import path
import shutil
import viset.download
import viset.util
from viset.util import ishdf5, isfile, isarchive, quietprint, setfetched, wasfetched
import numpy as np 
import urllib
import socket
import hashlib
import time
import random
from itertools import imap, islice, count
import viset.types
import urlparse
import viset.views 
import datetime

class AboutTable(tables.IsDescription):
    """Row schema of the 'about' table holding dataset-level metadata.

    Viset.format() creates this table and appends a single row to it;
    Viset.open() and the name()/version()/imcodec() accessors read row 0.
    """
    # Signature field: format() stores a URL here and open() rejects files
    # whose value is not a URL.
    viset = tables.StringCol(256, pos=0)   
    # Dataset name; open() also uses it as the cache sub-directory name.
    name = tables.StringCol(256, pos=1)   
    # Dataset version number.
    version = tables.Int64Col(pos=2)   
    # NOTE(review): pos=3 is skipped; presumably harmless since positions
    # only define relative column order - confirm.
    imcodec = tables.StringCol(256, pos=4)  # image codec name, a hint for binary decode

class PackageTable(tables.IsDescription):
    """Row schema of the 'packages' table: one row per downloadable package.

    Rows are written by Viset.add_package() and consumed by Viset.fetch().
    """
    # NOTE(review): both columns declare pos=1. This looks like a copy/paste
    # slip (url was presumably meant to be pos=0); the resulting column order
    # is then decided by PyTables' tie-breaking rule - confirm before changing,
    # since altering it changes the layout of newly created files.
    url = tables.StringCol(512, pos=1)   # package location
    sha1 = tables.StringCol(512, pos=1)     # checksum passed to the download helpers as `sha1`

class SplitTable(tables.IsDescription):
    """Row schema of the 'splits' table: ten unsigned 8-bit columns, split0..split9.

    Viset.format() creates the table; nothing in this module writes to it.
    NOTE(review): presumably each column stores a per-row split/fold
    assignment used by the view objects - confirm against viset.views.
    """
    split0 = tables.UInt8Col(pos=0)   
    split1 = tables.UInt8Col(pos=1)   
    split2 = tables.UInt8Col(pos=2)   
    split3 = tables.UInt8Col(pos=3)   
    split4 = tables.UInt8Col(pos=4)   
    split5 = tables.UInt8Col(pos=5)   
    split6 = tables.UInt8Col(pos=6)   
    split7 = tables.UInt8Col(pos=7)   
    split8 = tables.UInt8Col(pos=8)   
    split9 = tables.UInt8Col(pos=9)                     
  
class ImageTable(tables.IsDescription):
    """Row schema of the 'images' table: one row per image.

    Rows are written by Viset.add_image().
    """
    id = tables.Int64Col(pos=0)   # image identifier (add_image's `id_img`)
    url = tables.StringCol(512, pos=1)   # image location
    # NOTE(review): presumably the index of this image's first annotation row
    # and the number of annotations that belong to it - confirm with the views.
    id_anno = tables.Int64Col(pos=2)   
    n_anno = tables.Int32Col(pos=3)   
    
class Viset(object):
    cacheroot = viset.util.cacheroot()
    cachedir = None
    dburl = None   # remote file
    dbfile = None  # local file in cache
    verbose = None
    dbobj = None  # database object
    dbview = None  # active database view from task
    _impkg = None   # overloaded for create
    _viewname = None  # overloaded for create
    _imcodec = 'universal'  # overloaded for create
    _version = 0 # overloaded for create
    _dbname = None # overloaded for create
    _imrow = None    # fast pytable writes (FIXME)
    _pkgrow = None   # fast pytable writes (FIXME)
    
    
    def __init__(self, dbfile=dbfile, task=None, strategy=None, kfold=0, datastep=1, async=True, verbose=False):
        if dbfile is not None:
            self.dbobj = self.open(dbfile)
            self.dbview = self.view(task, async=async)
            self.dbview.split(strategy=strategy, kfold=kfold, datastep=datastep)
            self.verbose = verbose            
            self.fetch()
        else:
            # empty initialization for subclass create
            pass

    def __call__(self):        
        raise IOError('FIXME: this can be used to change the split instead of calling .split()?')
    
    def __len__(self):
        """Return the total number of annotations in the current dataset"""
        return self.dbview.num_annotations()
    
    def __getitem__(self, k):
        if (len(self.dbview.viewiter) == 0) and (k == 0):
            return self.dbview.viewiter
        elif (len(self.dbview.viewiter) > 0) and (k >= 0):
            return self.dbview.viewiter[k]
        else:
            raise ValueError('invalid index')
      
    def __del__(self):
        if self.dbobj is not None:
            self.close()
        
    def __str__(self):
        return self.summary()

    def write(self):
        self.dbobj.flush()
            
    def summary(self):
        print 'Dataset name: ' + str(self.name())    
        print 'Database file: ' + self.dbfile
        print 'Database version: ' + str(self.version())
        print 'Cache root directory: ' + self.cacheroot
        print 'Cache directory: ' + self.cachedir
        print 'Image codec: ' + '\'' + self.imcodec() + '\''
        print 'Database views: ' + str(self.list_views())
        print 'Number of images: ' + str(self.num_images())
        print 'Current view: ' + '\'' + self.dbview.name() + '\''
        print 'Number of annotations: ' + str(self.dbview.num_annotations())
        print 'HDF5 database file structure: ' 
        print '  ' + "\n  ".join(str(self.dbobj).split("\n"))
        return str()

    def view(self, viewname=None, async=True):        
        """Bind viewname to a database view object"""
        if viewname is None:
            viewname = self.list_views()[0]
        self.dbview = viset.views.bind(viewname, self).set(async=async).open()
        return self.dbview
        
    def add_package(self, URL, SHA1):
        if self._pkgrow is None:
            self._pkgrow = self.dbobj.root.packages.row            
        self._pkgrow['url'] = URL
        self._pkgrow['sha1'] = SHA1
        self._pkgrow.append()

    def add_image(self, id_img, url, n_anno, id_anno):
        if self._imrow is None:
            self._imrow = self.dbobj.root.images.row            
        self._imrow['id'] = id_img
        self._imrow['url'] = url
        self._imrow['n_anno'] = n_anno
        self._imrow['id_anno'] = id_anno
        self._imrow.append()
        
        
    def open(self, dbfile):
        # Database location resolution
        if not type(dbfile) is str:
            raise IOError('database filename must be a string')
        if not ishdf5(dbfile):
            raise IOError('invalid database file extension - must be HDF5 (.h5)')                
        if isfile(dbfile):
            dburl = None
        elif isfile(path.join(self.cacheroot, path.basename(dbfile))):
            dbfile = path.join(self.cacheroot, path.basename(dbfile))
            dburl = None
        elif isurl(dbfile):
            dburl = dbfile
            dbfile = path.join(self.cachedir, path.basename(dburl))
            viset.download.download(dburl, dbfile, verbose=self.verbose)
        else:
            dbfile = None
            dburl = None

        # Open in read mode
        try:
            quietprint('Opening viset "' + dbfile + '"', self.verbose)
            dbobj = tables.open_file(dbfile, mode = "r")
            if not viset.util.isurl(dbobj.root.about[0]['viset']):
                raise IOError()        
        except:
            print 'Invalid database file "' + str(dbfile) + '"'
            raise      

        # Initialize cache
        dbname = dbobj.root.about[0]['name']
        cachedir = path.join(self.cacheroot, dbname) 
        if not path.exists(cachedir):
            os.makedirs(cachedir)

        # Update class properties
        self.dbobj = dbobj
        self.dbfile = dbfile
        self.dburl = dburl
        self.cachedir = cachedir
        return self.dbobj
    
    def create(self):
        # Initialize cache
        cachedir = path.join(self.cacheroot, self._dbname)      
        if not path.exists(cachedir):
            os.makedirs(cachedir)
        dbfile = path.join(self.cacheroot, self._dbname+'.h5')
        dburl = None

        # Format new dataset
        self.dbobj = self.format(dbfile, self._dbname, self._viewname, self._impkg, self._version, self._imcodec)

        # Create default view
        self.dbview = self.view()
        
        # Update class properties
        self.dbfile = dbfile
        self.dburl = dburl
        self.cachedir = cachedir
        
    def close(self):
        self.dbobj.close()

    def delete(self):
        quietprint('Deleting all cached data in "' + self.cachedir + '" for dataset "' + self.dbfile + '"', self.verbose)
        shutil.rmtree(self.cachedir)

    def name(self):
        return self.dbobj.root.about[0]['name']
  
    def version(self):
        return self.dbobj.root.about[0]['version']

    def imcodec(self):
        return self.dbobj.root.about[0]['imcodec']
  
    def about(self):
        return self.dbobj.root.about[0]

    def list_views(self):
        """List of views in database"""    
        viewlist = []
        for node in self.dbobj.root.annotation:
            viewlist.append(node.attrs.name)
        return viewlist

    def num_views(self):
        """Total number of views in database"""
        return len(self.views())

    def num_images(self):
        """Total number of images in database"""
        return self.dbobj.root.images.nrows
  
    def urlquery(self, url):
        """query database by using querystring format 'http://domain.com/mydb.h5?table=x&row=y&field=z'"""
        if not isurl(url):
            ValueError('Invalid URL \'' + url + '\'')
        p = urlparse.urlsplit(url)
        p = urlparse.parse_qs(p[3])
        if 'table' in p:
            table = p['table'][0]
        else:
            table = None
        if 'row' in p:
            row = int(p['row'][0])
        else:
            row = None
        if 'field' in p:
            field = int(p['field'][0])
        else: 
            field = None
        return self.query(table, row, field)
        
    def query(self, table, row, field=None):
        """database query by table[row][field]"""
        if field is None:
            return self.dbobj.getNode(table)[row] 
        else:
            return self.dbobj.getNode(table)[row][field] 
   
    
    def format(self, dbfile, dbname, dbviewname=[], dbimpkg='url', dbversion=0, dbimcodec='universal'):
        # HDF5 file structure
        self.dbobj = tables.open_file(dbfile, mode = "w", title = dbname)
        tbl_about = self.dbobj.create_table(self.dbobj.root, 'about', AboutTable, expectedrows=1, title='About')    
        tbl_packagse = self.dbobj.create_table(self.dbobj.root, 'packages', PackageTable, title='Packages', expectedrows=10)
        tbl_splits = self.dbobj.create_table(self.dbobj.root, 'splits', SplitTable, title='Splits', expectedrows=1000000)
        tbl_images = self.dbobj.create_table(self.dbobj.root, 'images', ImageTable, title='Images', filters=tables.Filters(complevel=6), expectedrows=1000000)
    
        # Annotation tables
        ingroup = self.dbobj.create_group(self.dbobj.root, 'annotation', 'Annotation')    
        for viewname in viset.util.tolist(dbviewname):
            dbview = viset.views.bind(viewname, self);
            tbl_anno = dbview.create(ingroup)
            tbl_anno.flush()
      
        # Metadata
        r = tbl_about.row
        r['viset'] = 'http://www.visym.com'
        r['version'] = dbversion
        r['name'] = dbname
        r['imcodec'] = dbimcodec
        r.append()

        # Return newly created database for writing
        self.dbobj.flush()
        return self.dbobj

    def fetch(self):
        """Fetch packages stored in package table"""
        if self.dbobj.root.packages.nrows > 0:
            if not self.iscached():
                packages = self.dbobj.root.packages.iterrows()
                for pkg in packages:
                    if isfile(path.join(self.cachedir, path.basename(pkg['url']))):
                        # already cached (from export)
                        continue
                    elif isarchive(pkg['url']):
                        viset.download.extract_and_cleanup(pkg['url'], self.cachedir, sha1=pkg['sha1'], verbose=self.verbose)      
                        os.remove(path.join(self.cachedir, path.basename(pkg['url'])))
                    else:
                        viset.download.download(pkg['url'], path.join(self.cachedir, path.basename(pkg['url'])), sha1=pkg['sha1'], verbose=self.verbose)                              
                setfetched(self.cachedir)

    def iscached(self):
        wasfetched(self.cachedir)

        
class CategorizationViset(Viset):
    """Viset whose view name, passed to format() by create(), is 'Categorization'."""
    _viewname = 'Categorization'

class ImageViset(Viset):
    """Viset whose view name, passed to format() by create(), is 'Image'."""
    _viewname = 'Image'
    
class DetectionViset(Viset):
    """Viset whose view name, passed to format() by create(), is 'Detection'."""
    _viewname = 'Detection'

class SegmentationViset(Viset):
    """Viset whose view name, passed to format() by create(), is 'Segmentation'."""
    _viewname = 'Segmentation'
    
class ImageArchiveViset(Viset):
    """Viset that sets `_impkg` to 'package'; create() forwards it to format() as `dbimpkg`."""
    _impkg = 'package'

class UrlViset(Viset):
    """Viset that sets `_impkg` to 'url'; create() forwards it to format() as `dbimpkg`."""
    _impkg = 'url'    

class CustomImageViset(Viset):
    """Viset that sets `_imcodec` to 'custom'; create() forwards it to format(), which stores it in the 'about' table."""
    _imcodec = 'custom'


