import tables
import os
from os import path
import viset.download
from viset.dataset import CategorizationViset, ImageArchiveViset

class Caltech(CategorizationViset, ImageArchiveViset):
    """Shared exporter for the Caltech image archives.

    Subclasses supply ``URL``, ``SHA1``, ``PATH`` and ``dbname``. After
    extraction, ``PATH`` inside the cache directory is walked as one
    subdirectory per category label; every entry inside a label directory is
    registered as an image of that category.
    """

    def export(self, redo=False):
        """Populate the database with images and category labels.

        Args:
            redo: When exactly ``True``, download and extract the archive
                again even if ``PATH`` already exists in the cache directory.

        Returns:
            The ``dbfile`` value returned by ``self.create``.
        """
        # Create new database
        (self.dbobj, _dbname, dbfile, _dburl, cachedir) = self.create(
            self.dbname, self.dbtype, self.dbimtype, self.dbversion)

        labeldir = path.join(cachedir, self.PATH)

        # Fetch data necessary to initial construction
        if not path.exists(labeldir) or redo is True:
            viset.download.download_and_extract(self.URL, cachedir, sha1=self.SHA1)

        # Snapshot the directory tree once, in sorted order.  os.listdir()
        # returns entries in arbitrary order, so walking the tree separately
        # for images and for labels could pair an annotation id with the wrong
        # image id, and ids would not be reproducible between runs.  Stray
        # non-directory entries next to the label directories are skipped.
        categories = []
        for label in sorted(os.listdir(labeldir)):
            imdir = path.join(labeldir, label)
            if not path.isdir(imdir):
                continue
            categories.append((label, imdir, sorted(os.listdir(imdir))))

        # Write images to database
        id_img = 0
        for (_label, imdir, images) in categories:
            for im in images:
                # NOTE(review): imdir already contains cachedir and self.PATH,
                # so when cachedir is absolute os.path.join drops the leading
                # self.PATH component.  Possibly path.join(self.PATH, label, im)
                # was intended - confirm against what add_image expects.
                self.add_image(id_img, path.join(self.PATH, imdir, im), 1, id_img)
                id_img += 1
        self.write()

        # Write labels to database; annotation ids advance in the same order
        # as the image ids assigned above.
        id_anno = 0
        for (id_category, (label, _imdir, images)) in enumerate(categories):
            for _im in images:
                self.add_categorization(id_category, label, id_anno)
                id_anno += 1
        self.write()

        # Add one data package for download
        self.add_package(self.URL, self.SHA1, self.PATH)

        # Cleanup
        self.close()
        return dbfile
    
class Caltech101(Caltech):
    """Exporter settings for the ``101_ObjectCategories`` archive."""

    # Archive location and the SHA1 digest passed along when downloading it.
    URL = 'http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz'
    SHA1 = 'b8ca4fe15bcd0921dfda882bd6052807e63b4c96'
    # Directory (relative to the cache directory) holding the label folders.
    PATH = '101_ObjectCategories'
    # Name handed to ``create`` when the database is built.
    dbname = 'caltech101'
  
class Caltech256(Caltech):
    """Exporter settings for the ``256_ObjectCategories`` archive."""

    # Archive location and the SHA1 digest passed along when downloading it.
    URL = 'http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar'
    SHA1 = '2195e9a478cf78bd23a1fe51f4dabe1c33744a1c'
    # Directory (relative to the cache directory) holding the label folders.
    PATH = '256_ObjectCategories'
    # Name handed to ``create`` when the database is built.
    dbname = 'caltech256'
    


