import tables
import os
from os import path
import viset.download
from viset.dataset import DetectionViset, ImageArchiveViset
import numpy as np

class ETHZShapes(DetectionViset, ImageArchiveViset):
    """Exporter shared by the ETHZ shape datasets.

    Subclasses are expected to provide the class attributes read by
    ``export``: ``URL`` and ``SHA1`` (archive location and checksum),
    ``PATH`` (directory of the extracted archive inside ``cachedir``) and
    ``LABELS`` (one image sub-directory per category).
    """

    def export(self, redo=False):
        """Download the archive and write its images and detections.

        For every label in ``LABELS`` the ``.jpg`` files of the matching
        sub-directory are scanned; each image's ``.groundtruth`` file is
        parsed and one detection is added per non-blank line.

        Args:
            redo: accepted for interface compatibility; not used here.

        Returns:
            ``self.dbfile`` after the database has been closed.
        """
        # Create empty database
        self.create()

        # Fetch data necessary for initial construction
        viset.download.cache_and_extract(self.URL, self.cachedir, sha1=self.SHA1)

        # Write dataset
        id_img = 0
        id_anno = 0
        for id_category, label in enumerate(self.LABELS):
            imdir = path.join(self.cachedir, self.PATH, label)
            # NOTE: os.listdir order is arbitrary, so image ids follow
            # whatever order the filesystem reports.
            for filename in os.listdir(imdir):
                if not filename.endswith(".jpg") or filename.startswith('.'):
                    continue

                # Write detections
                n_detections = 0
                id_detectionstart = id_anno
                gtfile = self._groundtruth_file(imdir, filename, label)
                # The context manager closes the annotation file even when a
                # malformed line raises; previously the handle was leaked.
                with open(gtfile, 'r') as gt:
                    for line in gt:
                        line = line.strip()
                        if line == '':
                            continue
                        # Each annotation line holds four whitespace-separated
                        # tokens, read here as xmin ymin xmax ymax.
                        (xmin, ymin, xmax, ymax) = line.split()
                        # NOTE(review): values are passed as xmin, xmax, ymin,
                        # ymax -- confirm against add_detection's signature.
                        self.dbview.add_detection(
                            id_category, label, id_img,
                            np.float32(xmin), np.float32(xmax),
                            np.float32(ymin), np.float32(ymax))
                        id_anno += 1
                        n_detections += 1

                # Write image
                self.add_image(id_img, path.join(imdir, filename),
                               n_detections, id_detectionstart)
                id_img += 1

        # Add one data package for download
        self.add_package(self.URL, self.SHA1)

        # Cleanup
        self.close()
        return self.dbfile

    @staticmethod
    def _groundtruth_file(imdir, filename, label):
        """Return the annotation path for *filename* inside *imdir*.

        The expected name is ``<stem>_<label>.groundtruth`` with the label
        lower-cased; when that file does not exist, the variant with an
        extra ``s`` after the label is returned instead (plural hack).
        """
        prefix = os.path.splitext(path.basename(filename))[0] + '_' + label.lower()
        gtfile = path.join(imdir, prefix + '.groundtruth')
        if not os.path.isfile(gtfile):
            gtfile = path.join(imdir, prefix + 's.groundtruth')  # plural hack
        return gtfile
    
class ETHZShapeClasses(ETHZShapes):
    """Constants describing the ``ethz_shape_classes_v12.tgz`` archive."""

    # Name of the database for this dataset (consumed outside this class;
    # presumably by the viset base classes -- confirm).
    _dbname = 'ethzshapes'

    # Download location of the archive and its expected SHA1 digest.
    URL = 'http://www.vision.ee.ethz.ch/datasets/downloads/ethz_shape_classes_v12.tgz'
    SHA1 = 'ae9b8fad2d170e098e5126ea9181d0843505a84b'

    # Top-level directory of the extracted archive, and the per-category
    # sub-directories that ETHZShapes.export walks in this order.
    PATH = 'ETHZShapeClasses-V1.2'
    LABELS = [
        'Applelogos',
        'Bottles',
        'Giraffes',
        'Mugs',
        'Swans',
    ]
  
class ETHZExtendedShapeClasses(ETHZShapes):
    """Constants describing the ``extended_ethz_shapes.tgz`` archive."""

    # Name of the database for this dataset (consumed outside this class;
    # presumably by the viset base classes -- confirm).
    _dbname = 'ethzshapes_extended'

    # Download location of the archive. No SHA1 digest is recorded, so None
    # is what gets handed to cache_and_extract and add_package
    # (presumably this skips checksum verification -- confirm).
    URL = 'http://www.vision.ee.ethz.ch/datasets/downloads/extended_ethz_shapes.tgz'
    SHA1 = None

    # Top-level directory of the extracted archive, and the per-category
    # sub-directories that ETHZShapes.export walks in this order.
    PATH = 'extended_ethz_shapes'
    LABELS = [
        'apple',
        'bottle',
        'giraffe',
        'hat',
        'mug',
        'starfish',
        'swan',
    ]

