import tables
import os
from os import path
import viset.download
from viset.dataset import Viset

class Caltech(Viset):
    """Shared exporter for the Caltech image-categorization datasets.

    Concrete subclasses provide the class attributes ``URL``, ``SHA1``,
    ``SUBDIR`` and ``dbname``.  ``cacheroot`` is read from the instance but is
    not defined in this class (presumably supplied by ``Viset`` -- confirm).
    """

    def __init__(self):
        """Create the per-dataset cache directory and derive the HDF5 path."""
        self.cachedir = path.join(self.cacheroot, self.dbname)
        if not path.exists(self.cachedir):
            os.makedirs(self.cachedir)
        self.dbfile = path.join(self.cacheroot, self.dbname + '.h5')
        self.imstoragetype = 'Package'

    def export(self, redo=False):
        """Download the archive if needed and write the dataset to ``dbfile``.

        One image row and one categorization row are appended per file found
        under ``<cachedir>/<SUBDIR>/<label>/``, followed by a single package
        row describing the download.

        Args:
            redo: When truthy, download and extract the archive even if the
                extracted directory already exists.

        Returns:
            The path of the HDF5 file that was written (``self.dbfile``).
        """
        labeldir = path.join(self.cachedir, self.SUBDIR)

        # Fetch image data for construction
        if redo or not path.exists(labeldir):
            viset.download.download_and_extract(self.URL, self.cachedir, sha1=self.SHA1)

        # Format dataset
        super(Caltech, self).format(self.dbfile, self.dbname, 'Categorization', self.imstoragetype)

        # Write dataset; the try/finally guarantees the HDF5 handle is
        # released even if a row cannot be written.
        db = tables.open_file(self.dbfile, mode="a", title=self.dbname)
        try:
            # These nodes are expected to exist after format() has run.
            tbl_images = db.root.images
            tbl_anno = db.root.annotation.categorization

            # os.listdir order is arbitrary: sort so category/image ids are
            # reproducible, and ignore stray non-directory entries that would
            # otherwise make the inner listdir raise.
            labels = sorted(entry for entry in os.listdir(labeldir)
                            if path.isdir(path.join(labeldir, entry)))

            id_img = 0
            id_anno = 0
            for id_category, label in enumerate(labels):
                imdir = path.join(labeldir, label)
                for im in sorted(os.listdir(imdir)):
                    imrow = tbl_images.row
                    imrow['id'] = id_img
                    # NOTE(review): imdir already contains cachedir and
                    # SUBDIR, so for an absolute cachedir os.path.join drops
                    # the leading SUBDIR and the absolute image path is
                    # stored.  Left unchanged because readers of 'url' are
                    # not visible here; confirm whether SUBDIR/label/im was
                    # intended.
                    imrow['url'] = path.join(self.SUBDIR, imdir, im)
                    imrow['n_anno'] = 1
                    imrow['id_anno'] = id_anno
                    imrow.append()

                    annorow = tbl_anno.row
                    annorow['id_img'] = id_img
                    annorow['id_category'] = id_category
                    annorow['category'] = label
                    annorow.append()

                    id_img += 1
                    id_anno += 1
                # Flush once per category
                tbl_images.flush()
                tbl_anno.flush()

            # Add one data package for download
            tbl_package = db.root.packages
            r = tbl_package.row
            r['url'] = self.URL
            r['sha1'] = self.SHA1
            r['path'] = self.SUBDIR
            r.append()
            tbl_package.flush()
        finally:
            # Cleanup
            db.close()
        return self.dbfile
    
class Caltech101(Caltech):
    """Caltech-101 settings: archive URL, its SHA-1, extracted folder and db name."""

    dbname = 'caltech101'
    SUBDIR = '101_ObjectCategories'
    SHA1 = 'b8ca4fe15bcd0921dfda882bd6052807e63b4c96'
    URL = 'http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz'
  
class Caltech256(Caltech):
    """Caltech-256 settings: archive URL, its SHA-1, extracted folder and db name."""

    dbname = 'caltech256'
    SUBDIR = '256_ObjectCategories'
    SHA1 = '2195e9a478cf78bd23a1fe51f4dabe1c33744a1c'
    URL = 'http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar'
    


