import tables
import os
from os import path
import viset.download
from viset.dataset import Viset

class ImageNet(Viset):
  """Viset exporter for ImageNet-style URL lists.

  Concrete subclasses supply the class attributes ``URL`` (archive to
  download), ``SHA1`` (checksum passed to the downloader), ``dbname`` and
  ``TXTFILE`` (name of the URL list inside the extracted archive).
  ``self.cacheroot`` is read but not set here; presumably it is provided by
  the ``Viset`` base class -- TODO confirm.

  Images are stored by URL (``imstoragetype = 'url'``), and each image gets
  exactly one categorization annotation.
  """

  def __init__(self):
    # Per-dataset cache directory that receives the extracted archive.
    self.cachedir = path.join(self.cacheroot, self.dbname)
    if not path.exists(self.cachedir):
      os.makedirs(self.cachedir)
    # The HDF5 database sits next to (not inside) the cache directory.
    self.dbfile = path.join(self.cacheroot, self.dbname + '.h5')
    self.imstoragetype = 'url'

  def export(self, redo=False):
    """Build the HDF5 database from the URL list and return its path.

    Args:
      redo: when true, download and extract the archive again even if the
        text file is already present in the cache directory.

    Returns:
      The path of the HDF5 database file (``self.dbfile``).

    Each line of the text file is expected to look like
    ``<wordnet id>_<suffix><TAB><url>``. Lines that do not split into
    exactly those fields print a warning and are skipped.
    """
    txtfile = path.join(self.cachedir, self.TXTFILE)

    # Fetch textfile for construction
    if redo or not os.path.isfile(txtfile):
      viset.download.download_and_extract(self.URL, self.cachedir, sha1=self.SHA1)

    # Let the base class lay out the (empty) database structure.
    super(ImageNet, self).format(self.dbfile, self.dbname, 'Categorization', self.imstoragetype)

    db = tables.open_file(self.dbfile, mode="a", title=self.dbname)
    try:
      tbl_images = db.root.images
      tbl_anno = db.root.annotation.categorization
      id_img = 0
      id_anno = 0
      with open(txtfile, 'r') as f:
        for line in f:
          try:
            name, url = line.rstrip().split('\t')
            id_wordnet, _suffix = name.rstrip().split('_')
          except ValueError:
            # Wrong number of tab- or underscore-separated fields.
            print('Warning: Ignoring malformed line ' + line)
            continue

          annorow = tbl_anno.row
          annorow['id_img'] = id_img
          annorow['id_category'] = id_anno
          annorow['category'] = id_wordnet
          annorow.append()

          # One annotation per image, so the image points at its only annotation.
          imrow = tbl_images.row
          imrow['id'] = id_img
          imrow['url'] = url
          imrow['n_anno'] = 1
          imrow['id_anno'] = id_anno
          imrow.append()
          id_img += 1
          id_anno += 1

      tbl_anno.flush()
      tbl_images.flush()
    finally:
      # Always release the HDF5 handle, even if parsing or writing fails.
      db.close()
    return self.dbfile


class ImageNetFall2011(ImageNet):
  """ImageNet Fall 2011 release, described by its URL-list archive."""
  # Name used for the cache directory and the '<dbname>.h5' database file.
  dbname = 'imagenet_fall2011'
  # URL list found inside the extracted archive.
  TXTFILE = 'fall11_urls.txt'
  # Archive location and the checksum handed to the downloader.
  URL = 'http://www.image-net.org/archive/imagenet_fall11_urls.tgz'
  SHA1 = 'f5fd118232b871727fe333778be81df6c6fec372'
  
