import tables
import os
from os import path
import shutil
import viset.download
import pylab
import numpy as np 
import urllib
import socket
import hashlib
import time

# Default timeout, in seconds, applied to sockets created after this call.
# It is set module-wide for the urllib.urlretrieve download in Viset.image.
socket.setdefaulttimeout(10)  # for urlretrieve

class AboutTable(tables.IsDescription):
  # Row schema of the 'about' table: dataset-level metadata.
  # Viset.format() appends exactly one row and the Viset accessors read row 0.
  #
  #   version -- integer version (format() writes 0)
  #   name    -- dataset name; Viset uses it as the cache sub-directory name
  #   task    -- task label; format() recognises 'Categorization' and 'Detection'
  #   image   -- image storage type; Viset.image() dispatches on
  #              'Package', 'url', 'MNIST' and '80MTinyImages'
  #
  # NOTE(review): name, task and image all declare pos=1 -- confirm how
  # PyTables orders columns with equal positions before relying on column order.
  version = tables.Int64Col(pos=0)
  name = tables.StringCol(256, pos=1)
  task = tables.StringCol(256, pos=1)
  image = tables.StringCol(256, pos=1)  # hint for binary decode

class PackageTable(tables.IsDescription):
  # Row schema of the 'packages' table: archives that Viset.install() fetches
  # into the dataset cache directory.
  #
  #   url  -- passed to viset.download.download_and_extract as the source
  #   sha1 -- passed to download_and_extract as its sha1 keyword argument
  #   path -- path relative to the cache directory; install() skips the
  #           download when this path already exists
  #
  # NOTE(review): all three columns declare pos=1 -- confirm how PyTables
  # orders columns with equal positions before relying on column order.
  url = tables.StringCol(512, pos=1)
  sha1 = tables.StringCol(512, pos=1)
  path = tables.StringCol(512, pos=1)

class ImageTable(tables.IsDescription):
  # Row schema of the 'images' table: one row per image.
  #
  #   id      -- Viset.image() uses this value as the row index into the table
  #   url     -- read by Viset.image(): passed straight to imread for
  #              'Package' datasets, downloaded and cached for 'url' datasets
  #   id_anno -- not read anywhere in this file; presumably the id of the
  #              image's first annotation row -- TODO confirm
  #   n_anno  -- not read anywhere in this file; presumably the number of
  #              annotations for the image -- TODO confirm
  id = tables.Int64Col(pos=0)
  url = tables.StringCol(512, pos=1)
  id_anno = tables.Int64Col(pos=2)
  n_anno = tables.Int32Col(pos=3)
  
class CategorizationAnnotationTable(tables.IsDescription):
  # Row schema of the 'annotation' table when the task is 'Categorization'
  # (selected in Viset.format()).
  #
  #   id          -- annotation id
  #   id_img      -- presumably the id of the annotated image row -- TODO confirm
  #   id_category -- integer category label; read by Viset.labels() and Viset.y()
  #   category    -- category name; Viset.read() filters rows on this column
  id = tables.Int64Col(pos=0)
  id_img = tables.Int64Col(pos=1)
  id_category = tables.Int32Col(pos=2)
  category = tables.StringCol(256, pos=3)

class DetectionAnnotationTable(tables.IsDescription):
  # Row schema of the 'annotation' table when the task is 'Detection'
  # (selected in Viset.format()).
  #
  #   id        -- annotation id
  #   id_img    -- presumably the id of the annotated image row -- TODO confirm
  #   category  -- integer category label (a string column in
  #                CategorizationAnnotationTable)
  #   bbox_*    -- bounding-box extents stored as unsigned 16-bit integers;
  #                presumably pixel coordinates -- TODO confirm
  #
  # id_img previously shared pos=0 with id; it now has its own position (1),
  # matching CategorizationAnnotationTable. The remaining columns keep no
  # explicit position, exactly as before.
  id = tables.Int64Col(pos=0)
  id_img = tables.Int64Col(pos=1)
  category = tables.Int32Col()
  bbox_xmin = tables.UInt16Col()
  bbox_xmax = tables.UInt16Col()
  bbox_ymin = tables.UInt16Col()
  bbox_ymax = tables.UInt16Col()
  
class Viset(object):
  """Read-only handle on a visual dataset stored in an HDF5 (PyTables) file.

  The file holds a ``/dataset`` group with ``about``, ``packages``, ``images``
  and (for known tasks) ``annotation`` tables; see ``format`` for the layout.
  Downloaded packages and images are cached under ``cacheroot/<dataset name>``.
  """
  dbname = None
  dbfile = None
  db = None
  # expanduser('~') honours $HOME when it is set, but unlike os.environ['HOME']
  # it does not raise KeyError at import time when the variable is missing.
  cacheroot = path.join(path.expanduser('~'), '.visym')
  cachedir = None

  def __init__(self, dbfile):
    """Open ``dbfile`` read-only, create its cache directory and install packages."""
    self.db = tables.open_file(dbfile, mode="r")
    self.dbfile = dbfile
    about = self.db.root.dataset.about[0]
    self.dbname = about['name']
    self.cachedir = path.join(self.cacheroot, self.dbname)
    if not path.exists(self.cachedir):
      os.makedirs(self.cachedir)
    self.install()
    self.imtype = about['image']

  def open(self):
    """Return an open read-only handle on the dataset file.

    The file is (re)opened only when there is no live handle, so calling this
    on an already-open dataset no longer leaks the previous handle.
    """
    if self.db is None or not self.db.isopen:
      self.db = tables.open_file(self.dbfile, mode="r")
    return self.db

  def close(self):
    """Close the underlying HDF5 file."""
    self.db.close()

  def delete(self):
    """Remove this dataset's cache directory and everything in it."""
    shutil.rmtree(self.cachedir)

  def task(self):
    """Return the ``task`` field of the about row."""
    return self.db.root.dataset.about[0]['task']

  def name(self):
    """Return the ``name`` field of the about row."""
    return self.db.root.dataset.about[0]['name']

  def version(self):
    """Return the ``version`` field of the about row."""
    return self.db.root.dataset.about[0]['version']

  def about(self):
    """Return the whole about row (row 0 of the about table)."""
    return self.db.root.dataset.about[0]

  def labels(self):
    """Return the sorted unique values of the annotation ``id_category`` column."""
    return np.unique(self.db.root.dataset.annotation.cols.id_category[:])

  def y(self, id_category):
    """Return a +1/-1 label vector: +1 where a row's ``id_category`` matches."""
    ids = self.db.root.dataset.annotation.cols.id_category[:]
    return 2*(np.int32(ids == id_category)) - 1

  def image(self, im):
    """Load the pixels for an image row.

    ``im`` is a row of the images table; its ``id`` field is used as the row
    index. Returns whatever ``pylab.imread`` returns, or ``None`` when a
    download fails or the image type has no reader yet ('MNIST',
    '80MTinyImages').

    Raises IOError for an unrecognised image type.
    """
    row_id = im['id']

    if self.imtype == 'Package':
      # The url column holds a local path for packaged datasets.
      return pylab.imread(self.db.root.dataset.images[row_id]['url'])

    if self.imtype == 'url':
      url = self.db.root.dataset.images[row_id]['url']
      # Cache file name: sha1 of the url plus the url's own extension.
      filename = path.join(self.cachedir, hashlib.sha1(url).hexdigest() + path.splitext(url)[1])
      if path.isfile(filename):
        return pylab.imread(filename)

      print('Retrieving ' + url)
      try:
        urllib.urlretrieve(url, filename)
        return pylab.imread(filename)
      except Exception:
        # Best effort: skip images that cannot be fetched or decoded.
        # Catching Exception (not a bare except) lets Ctrl-C and SystemExit
        # propagate, so no sleep is needed to give the user a chance to abort.
        print('Download failed ... skipping\n')
        # Drop any partial or undecodable file so it is not later mistaken
        # for a valid cache entry.
        try:
          os.remove(filename)
        except OSError:
          pass
        return None

    if self.imtype == 'MNIST':
      # binary reader for MNIST (not implemented)
      return None

    if self.imtype == '80MTinyImages':
      # binary reader for 80 million tiny images (not implemented)
      return None

    raise IOError("undefined image type"+str(self.imtype))

  def read(self, id_start=None, n_skip=1, id_end=None, category=None):
    """Yield ``(image, annotation)`` tuples for rows ``id_start:id_end:n_skip``.

    ``id_start`` defaults to 0 and ``id_end`` to the number of image rows.
    ``category`` may be a single name or a collection of names; when given,
    only rows whose annotation ``category`` is among them are yielded.
    ``image`` is the result of ``self.image`` and may be ``None``.

    Per the original author's note, rows read from the tables are numpy
    structured-array records (see dtype.names / dtype.fields in
    http://docs.scipy.org/doc/numpy/user/basics.rec.html).
    """
    images = self.db.root.dataset.images
    annotations = self.db.root.dataset.annotation
    if id_start is None:
      id_start = 0
    if id_end is None:
      id_end = images.nrows
    if isinstance(category, (str, bytes)):
      # A bare string would otherwise be substring-matched by ``in``.
      category = [category]  # singleton list
    for row_id in range(id_start, id_end, n_skip):
      anno = annotations[row_id]
      if (category is None) or (anno['category'] in category):
        yield (self.image(images[row_id]), anno)  # tuple

  def format(self, dbfile, dbname, task, imtype):
    """Create a new, empty dataset file at ``dbfile``.

    Layout: a ``/dataset`` group holding ``about`` (one metadata row written
    here), ``packages`` and ``images`` tables, plus an ``annotation`` table
    whose schema depends on ``task`` ('Categorization' or 'Detection'). Any
    other task gets no annotation table; previously that case failed with a
    NameError when flushing. The file is closed even if creation fails.
    """
    annotation_schemas = {
      'Categorization': CategorizationAnnotationTable,
      'Detection': DetectionAnnotationTable,
    }

    db = tables.open_file(dbfile, mode="w", title=dbname)
    try:
      # HDF5 file structure
      group = db.create_group("/", 'dataset', 'Dataset')
      tbl_about = db.create_table(group, 'about', AboutTable, expectedrows=1, title='About')
      tbl_packages = db.create_table(group, 'packages', PackageTable, title='Packages', expectedrows=10)
      tbl_images = db.create_table(group, 'images', ImageTable, title='Images', filters=tables.Filters(complevel=1), expectedrows=1000000)
      created = [tbl_about, tbl_packages, tbl_images]

      schema = annotation_schemas.get(task)
      if schema is not None:
        tbl_anno = db.create_table(group, 'annotation', schema, title='Annotation', filters=tables.Filters(complevel=1), expectedrows=1000000)
        created.append(tbl_anno)

      # Metadata
      r = tbl_about.row
      r['version'] = 0
      r['name'] = dbname
      r['task'] = task
      r['image'] = imtype
      r.append()

      # Write
      for tbl in created:
        tbl.flush()
    finally:
      db.close()

  def install(self):
    """Download and extract every package whose target path is not yet cached."""
    packages = self.db.root.dataset.packages
    if packages.nrows > 0:
      for pkg in packages.iterrows():
        if not path.exists(path.join(self.cachedir, pkg['path'])):
          viset.download.download_and_extract(pkg['url'], self.cachedir, sha1=pkg['sha1'])


