import tables
import os
from os import path
import shutil
import viset.download
import viset.util
import pylab
import numpy as np 
import urllib
import socket
import hashlib
import time
import random
from itertools import imap, islice, count



class AboutTable(tables.IsDescription):
  """Schema of the single-row 'about' table holding database metadata.

  Column positions are unique and follow declaration order. The original
  schema gave 'viset' and 'name' the same position, which left their
  relative order to PyTables' tie-break. Rows are always accessed by field
  name, so this does not affect reading existing files.
  """
  version = tables.Int64Col(pos=0)              # format version; Viset() compares it with its `version` argument
  viset = tables.StringCol(256, pos=1)          # marker that Viset() requires to be a URL
  name = tables.StringCol(256, pos=2)           # dataset name; Viset() also uses it as the cache subdirectory
  imstoragetype = tables.StringCol(256, pos=3)  # hint for binary decode ('Package', 'url', 'MNIST', '80MTinyImages')

class PackageTable(tables.IsDescription):
  """Schema of the 'packages' table: archives to fetch into the dataset cache.

  Every column previously shared pos=1. The positions now follow declaration
  order, and rows are accessed by field name only.
  """
  url = tables.StringCol(512, pos=0)    # where the archive is downloaded from
  sha1 = tables.StringCol(512, pos=1)   # checksum handed to the downloader
  path = tables.StringCol(512, pos=2)   # path under the cache dir whose existence marks the package installed

class SplitTable(tables.IsDescription):
  """Schema of the 'splits' table: ten unsigned 8-bit columns split0..split9.

  NOTE(review): stored splits are not read yet (Viset treats them as a TODO).
  Presumably each row flags one element's membership in up to ten predefined
  splits; confirm this before relying on it.
  """
  split0 = tables.UInt8Col(pos=0)
  split1 = tables.UInt8Col(pos=1)
  split2 = tables.UInt8Col(pos=2)
  split3 = tables.UInt8Col(pos=3)
  split4 = tables.UInt8Col(pos=4)
  split5 = tables.UInt8Col(pos=5)
  split6 = tables.UInt8Col(pos=6)
  split7 = tables.UInt8Col(pos=7)
  split8 = tables.UInt8Col(pos=8)
  split9 = tables.UInt8Col(pos=9)

  
class ImageTable(tables.IsDescription):
  """Schema of the 'images' table: one row per image."""
  id = tables.Int64Col(pos=0)                   # Viset.image() uses this value as the row index into 'images'
  url = tables.StringCol(itemsize=512, pos=1)   # remote URL or local path, depending on the about.imstoragetype hint
  id_anno = tables.Int64Col(pos=2)              # presumably the first annotation id of this image (not read here)
  n_anno = tables.Int32Col(pos=3)               # presumably the number of annotations of this image (not read here)
  
class CategorizationAnnotationTable(tables.IsDescription):
  """Schema of the '/annotation/categorization' table: one category label per row."""
  id = tables.Int64Col(pos=0)
  id_img = tables.Int64Col(pos=1)                    # row index into the 'images' table (used by Viset.read)
  id_category = tables.Int32Col(pos=2)               # integer label (used by Viset.labels and Viset.y)
  category = tables.StringCol(itemsize=256, pos=3)   # category name

class DetectionAnnotationTable(tables.IsDescription):
  """Schema of the '/annotation/detection' table: one labelled bounding box per row."""
  id = tables.Int64Col(pos=0)
  id_img = tables.Int64Col(pos=1)                    # row index into the 'images' table (used by Viset.read)
  id_category = tables.Int32Col(pos=2)               # integer label (used by Viset.labels and Viset.y)
  category = tables.StringCol(itemsize=256, pos=3)   # category name
  # Box extents as float32; units (pixels vs. normalised) are not established here -- TODO confirm.
  bbox_xmin = tables.Float32Col(pos=4)
  bbox_xmax = tables.Float32Col(pos=5)
  bbox_ymin = tables.Float32Col(pos=6)
  bbox_ymax = tables.Float32Col(pos=7)
  
class Viset(object):
  """Read-only access to a viset HDF5 database and its local download cache.

  A database file holds the following nodes:
    - a single-row 'about' metadata table;
    - an optional 'packages' table of archives to download;
    - an 'images' table;
    - one or more annotation tables under '/annotation'.

  Constructing a Viset does five things:
    1. resolves the file;
    2. validates its marker and version;
    3. installs its packages into a per-database cache directory;
    4. selects one annotation table as the active dataset;
    5. exposes that table through lazy (image, annotation) iterators in `evalset`.
  """

  # Root of the on-disk cache; VISYM_CACHE overrides the default ~/.visym.
  cacheroot = os.environ.get('VISYM_CACHE')
  if cacheroot is None:
    cacheroot = path.join(path.expanduser('~'), '.visym')
  cachedir = None   # per-database cache directory (set in __init__)
  dbname = None     # database name read from the 'about' table
  dbfile = None     # resolved local path of the HDF5 file
  db = None         # open PyTables file handle
  dataset = None    # active annotation table
  evalset = None    # dict of iterators, or a list of such dicts for split='kfold'

  def __init__(self, dbfile, version=0, task=None, split=None, kfold=0, datastep=1):
    """Open a viset database and prepare its evaluation set.

    dbfile -- one of:
        - a local HDF5 path;
        - a URL, which is downloaded into the cache root;
        - a file name relative to the cache root.
    version -- required value of the database's 'version' field.
    task -- annotation type to use (matched against each table's 'annotype'
        attribute). None selects the first annotation table.
    split -- one of:
        - None: every `datastep`-th annotation, under key 'dataset';
        - '2fold': random half/half 'train'/'test';
        - 'kfold': a list of `kfold` random cross-validation folds.
    kfold -- number of folds for split='kfold'; must be at least 2.
    datastep -- stride through the annotations when split is None.

    Raises:
        IOError -- for an unusable database file, an unknown task or an
            unsupported split.
        ValueError -- for fewer than 2 folds with split='kfold'.
    """
    dbfile = self._resolve_dbfile(dbfile)
    self._open_database(dbfile, version)
    self.dbfile = dbfile
    self.dbname = self.name()
    self.cachedir = path.join(self.cacheroot, self.dbname)
    if not path.exists(self.cachedir):
      os.makedirs(self.cachedir)
    self.install()
    self.dataset = self._select_dataset(task)
    self.evalset = self._make_evalset(split, kfold, datastep)

  def _resolve_dbfile(self, dbfile):
    """Return a local path for `dbfile`, downloading it first when it is a URL."""
    try:
      if not isinstance(dbfile, str):
        raise IOError('database file must be given as a string')
      if path.isfile(dbfile):
        return dbfile
      if viset.util.isurl(dbfile):
        # self.cachedir depends on the database name, which is unknown until
        # the file has been opened. Downloaded database files therefore go
        # directly into the cache root.
        if not path.exists(self.cacheroot):
          os.makedirs(self.cacheroot)
        filename = path.join(self.cacheroot, path.basename(dbfile))
        viset.download.download(dbfile, filename)
        return filename
      candidate = path.join(self.cacheroot, dbfile)
      if path.isfile(candidate):
        return candidate
      raise IOError('database file not found')
    except Exception:
      print('Invalid database file "' + str(dbfile) + '"')
      raise

  def _open_database(self, dbfile, version):
    """Open `dbfile` read-only into self.db and check its viset marker and version."""
    try:
      self.db = tables.open_file(dbfile, mode="r")
      about = self.db.root.about[0]
      if not viset.util.isurl(about['viset']):
        raise IOError('missing viset marker in the "about" table')
      if about['version'] != version:
        raise IOError('database version ' + str(about['version']) +
                      ' does not match requested version ' + str(version))
    except Exception:
      # If the file did open, self.db stays set, and __del__ will close it.
      print('Invalid database file "' + str(dbfile) + '"')
      raise

  def _select_dataset(self, task):
    """Return the annotation table for `task`, or the first table when task is None."""
    nodes = iter(self.db.root.annotation)
    if task is None:
      dataset = next(nodes, None)
    else:
      dataset = next((node for node in nodes if node.attrs.annotype == task), None)
    if dataset is None:
      print('Invalid dataset task "' + str(task) + '"')
      raise IOError('no annotation table for task "' + str(task) + '"')
    return dataset

  def _make_evalset(self, split, kfold, datastep):
    """Build the lazy (image, annotation) iterators for the requested split."""
    n = self.num_annotations()
    if split is None:
      return {'dataset': imap(self.read, islice(count(), 0, n, datastep))}
    elif split == '2fold':
      ids = list(range(n))
      random.shuffle(ids)
      half = len(ids) // 2
      return {'train': imap(self.read, ids[:half]),
              'test': imap(self.read, ids[half:]),
              'validate': None}
    elif split == 'kfold':
      if kfold < 2:
        raise ValueError('kfold must be at least 2, got ' + str(kfold))
      ids = list(range(n))
      random.shuffle(ids)
      foldsize = len(ids) // kfold
      folds = []
      for k in range(kfold):
        u = k * foldsize
        # The last fold absorbs the remainder, so every annotation is tested exactly once.
        v = len(ids) if k == kfold - 1 else u + foldsize
        # Build the training ids as a new list; never mutate the shared `ids`.
        folds.append({'train': imap(self.read, ids[:u] + ids[v:]),
                      'test': imap(self.read, ids[u:v]),
                      'validate': None})
      return folds
    # Retrieve split stored in SplitTable in database?
    print('TODO: retrieve stored splits in database')
    raise IOError('unsupported split "' + str(split) + '"')

  def __len__(self):
    """Return the total number of annotations in the current dataset"""
    return self.num_annotations()

  def __getitem__(self, k):
    """Return an evaluation set: a key of the evalset dict, or a fold index for 'kfold'."""
    return self.evalset[k]

  def __del__(self):
    # Release the HDF5 handle when the object is garbage-collected.
    if self.db is not None:
      self.close()

  def __str__(self):
    # Return the description without printing it as a side effect.
    return self._summary_text()

  def _summary_text(self):
    """Return the multi-line description shared by summary() and str()."""
    lines = [
      'Database file: ' + self.dbfile,
      'Database version: ' + str(self.version()),
      'Cache root directory: ' + self.cacheroot,
      'Cache directory: ' + self.cachedir,
      'Image storage type: ' + '\'' + self.imstoragetype() + '\'',
      'Database annotations: ' + str(self.annotypes()),
      'Dataset name: ' + str(self.name()),
      'Dataset annotation: ' + '\'' + self.annotype() + '\'',
      'Number of images: ' + str(self.num_images()),
      'Number of annotations: ' + str(self.num_annotations()),
      'HDF5 database file structure: ',
      '  ' + "\n  ".join(str(self.db).split("\n")),
    ]
    return "\n".join(lines)

  def summary(self):
    """Print a description of the database and the active dataset; return ''."""
    print(self._summary_text())
    return str()

  def open(self):
    """Open the database file read-only into self.db and return the handle.

    NOTE(review): self.dataset is not re-bound to the new handle. After
    close() followed by open(), read() still uses the old table node.
    """
    self.db = tables.open_file(self.dbfile, mode="r")
    return self.db

  def close(self):
    """Close the database handle, if there is one."""
    if self.db is not None:
      self.db.close()

  def delete(self):
    """Remove this database's cache directory and everything in it."""
    print('Deleting all cached data in "' + self.cachedir + '" for dataset "' + self.dbfile + '"')
    shutil.rmtree(self.cachedir)

  def about(self):
    """Return the single metadata row of the 'about' table."""
    return self.db.root.about[0]

  def name(self):
    """Dataset name recorded in the 'about' table."""
    return self.about()['name']

  def version(self):
    """Format version recorded in the 'about' table."""
    return self.about()['version']

  def imstoragetype(self):
    """Image storage hint recorded in the 'about' table."""
    return self.about()['imstoragetype']

  def labels(self):
    """Sorted unique category ids in the dataset."""
    return np.unique(self.dataset.cols.id_category[:])

  def num_categories(self):
    """Number of distinct category ids in the dataset."""
    return len(self.labels())

  def num_annotations(self):
    """Number of annotations in dataset"""
    return self.dataset.nrows

  def annotype(self):
    """Annotation type of dataset"""
    return self.dataset.attrs.annotype

  def num_annotypes(self):
    """Total number of annotation types in database"""
    return len(self.annotypes())

  def annotypes(self):
    """List of annotation types in database"""
    return [node.attrs.annotype for node in self.db.root.annotation]

  def num_images(self):
    """Total number of images in database"""
    return self.db.root.images.nrows

  def y(self, id_category):
    """Return a label vector with one entry per annotation: 1 where the category id matches, else -1."""
    categories = self.dataset.cols.id_category[:]
    return 2 * np.int32(categories == id_category) - 1

  def image(self, im):
    """Return the pixel data for the image-table row `im`, or None if it is unavailable.

    The behaviour depends on the database's imstoragetype:
      'Package'  -- the row's 'url' field is read directly as a file path.
      'url'      -- the image is downloaded once into the cache directory,
                    then read from there. Returns None if the download or
                    the decode fails.
      'MNIST', '80MTinyImages' -- readers are not implemented; returns None.
    Raises IOError for any other storage type.
    """
    id_img = im['id']
    storage = self.imstoragetype()
    if storage == 'Package':
      # NOTE(review): the path goes to imread unchanged. Confirm whether it is
      # absolute or relative to the package cache.
      return pylab.imread(self.db.root.images[id_img]['url'])
    elif storage == 'url':
      return self._fetch_cached_image(self.db.root.images[id_img]['url'])
    elif storage == 'MNIST':
      return None  # TODO: binary reader for MNIST
    elif storage == '80MTinyImages':
      return None  # TODO: binary reader for 80 million tiny images
    raise IOError("undefined image type " + str(storage))

  def _fetch_cached_image(self, url):
    """Return the image at `url`, downloading it into the cache on first use; None on failure."""
    filename = path.join(self.cachedir, hashlib.sha1(url).hexdigest() + path.splitext(url)[1])
    if path.isfile(filename):
      return pylab.imread(filename)
    print('Retrieving ' + url)
    previous_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(10)  # bound this download only; restored in `finally`
    try:
      urllib.urlretrieve(url, filename)
      return pylab.imread(filename)
    except Exception:
      print('Download failed ... skipping\n')
      # A partial or undecodable file would otherwise count as a cache hit next time.
      try:
        if path.isfile(filename):
          os.remove(filename)
      except OSError:
        pass  # best-effort cleanup; the failure has already been reported
      time.sleep(1)  # brief pause between failed downloads
      return None
    finally:
      socket.setdefaulttimeout(previous_timeout)

  def read(self, id_anno):
    """Return (image, annotation) for row `id_anno` of the active dataset.

    The annotation is the table row as returned by PyTables indexing, a NumPy
    structured record (fields in anno.dtype.names; see
    http://docs.scipy.org/doc/numpy/user/basics.rec.html). The image is
    whatever image() returns for the referenced 'images' row.
    """
    anno = self.dataset[id_anno]
    im = self.image(self.db.root.images[anno['id_img']])
    return (im, anno)  # tuple

  def format(self, dbfile, dbname, annotationlist=None, imstoragetype='url', version=0):
    """Create an empty viset database file at `dbfile`, opened in mode "w".

    dbname -- dataset name stored in the 'about' table and used as the HDF5 title.
    annotationlist -- annotation type(s), normalised with viset.util.tolist.
        'Categorization' and 'Detection' each get a table under '/annotation'.
    imstoragetype -- image storage hint recorded in the 'about' table.
    version -- format version recorded in the 'about' table. The default 0
        matches the default that Viset() checks for.

    Raises IOError for an unknown annotation type. The file is closed in every case.
    """
    if annotationlist is None:
      annotationlist = []
    compressed = tables.Filters(complevel=1)
    db = tables.open_file(dbfile, mode="w", title=dbname)
    try:
      tbl_about = db.create_table(db.root, 'about', AboutTable, expectedrows=1, title='About')
      tbl_packages = db.create_table(db.root, 'packages', PackageTable, title='Packages', expectedrows=10)
      tbl_splits = db.create_table(db.root, 'splits', SplitTable, title='Splits', expectedrows=1000000)
      tbl_images = db.create_table(db.root, 'images', ImageTable, title='Images', filters=compressed, expectedrows=1000000)

      # Annotation tables, one per requested annotation type
      group = db.create_group(db.root, 'annotation', 'Annotation')
      for annotype in viset.util.tolist(annotationlist):
        if annotype == 'Categorization':
          tbl_anno = db.create_table(group, 'categorization', CategorizationAnnotationTable, title='Categorization', filters=compressed, expectedrows=1000000)
        elif annotype == 'Detection':
          tbl_anno = db.create_table(group, 'detection', DetectionAnnotationTable, title='Detection', filters=compressed, expectedrows=1000000)
        else:
          raise IOError('Invalid annotation type ' + str(annotype))
        # Viset() selects the active table by this attribute.
        tbl_anno.attrs.annotype = annotype
        tbl_anno.flush()

      # Metadata: the single 'about' row
      r = tbl_about.row
      r['viset'] = 'http://www.visym.com'
      r['version'] = version
      r['name'] = dbname
      r['imstoragetype'] = imstoragetype
      r.append()

      for tbl in (tbl_about, tbl_packages, tbl_images, tbl_splits):
        tbl.flush()
    finally:
      db.close()

  def install(self):
    """Download and extract each listed package whose 'path' is not yet in the cache directory."""
    for pkg in self.db.root.packages.iterrows():
      if not path.exists(path.join(self.cachedir, pkg['path'])):
        viset.download.download_and_extract(pkg['url'], self.cachedir, sha1=pkg['sha1'])


