import tables
import os
from os import path
import viset.download
from viset.dataset import Viset
import numpy as np

class ETHZShapes(Viset):
    """Exporter that writes an ETHZ shapes archive into a viset HDF5 detection database.

    Subclasses provide the class attributes read by this class: ``URL``,
    ``SHA1``, ``SUBDIR``, ``IMSTORAGE``, ``LABELS`` and ``dbname``.
    ``cacheroot`` is read from ``self`` but is not defined here (presumably
    supplied by ``Viset`` -- confirm).
    """

    def __init__(self):
        """Create the dataset cache directory if needed and set ``self.dbfile``."""
        self.cachedir = path.join(self.cacheroot, self.dbname)
        if not path.exists(self.cachedir):
            os.makedirs(self.cachedir)
        self.dbfile = path.join(self.cacheroot, self.dbname + '.h5')

    def _groundtruth_file(self, imdir, filename, label):
        """Return the ground-truth file for an image, or None if there is none.

        Looks next to the image for ``<stem>_<label>.groundtruth`` first and
        then for the plural form ``<stem>_<label>s.groundtruth``, where
        ``<label>`` is the lower-cased category directory name.
        """
        stem = os.path.splitext(path.basename(filename))[0] + '_' + label.lower()
        for suffix in ('.groundtruth', 's.groundtruth'):  # singular, then plural
            candidate = path.join(imdir, stem + suffix)
            if os.path.isfile(candidate):
                return candidate
        return None

    def export(self, redo=False):
        """Download the archive if needed and (re)build the HDF5 database.

        Args:
            redo: when truthy, download and extract the archive even if the
                extracted ``SUBDIR`` already exists in the cache directory.

        Returns:
            The path of the database file (``self.dbfile``).
        """
        # Fetch image data for construction
        if redo or not path.exists(path.join(self.cachedir, self.SUBDIR)):
            viset.download.download_and_extract(self.URL, self.cachedir, sha1=self.SHA1)

        # Format dataset
        super(ETHZShapes, self).format(self.dbfile, self.dbname, 'Detection', self.IMSTORAGE)

        # Write dataset
        db = tables.open_file(self.dbfile, mode="a", title=self.dbname)
        try:
            id_img = -1
            id_anno = 0
            tbl_images = db.root.images
            tbl_anno = db.root.annotation.detection
            # The category id is the position of the label in LABELS.
            for id_category, label in enumerate(self.LABELS):
                imdir = path.join(self.cachedir, self.SUBDIR, label)
                for filename in os.listdir(imdir):
                    if not filename.endswith(".jpg") or filename.startswith('.'):
                        continue

                    # NOTE(review): the id is consumed even when the image is
                    # skipped below, so image ids can have gaps. Kept as in the
                    # original numbering -- confirm whether gaps are intended.
                    id_img += 1

                    gtfile = self._groundtruth_file(imdir, filename, label)
                    if gtfile is None:
                        continue  # no annotation file: image is not exported

                    # Write image
                    imrow = tbl_images.row
                    imrow['id'] = id_img
                    imrow['url'] = path.join(imdir, filename)
                    first_anno = id_anno
                    imrow['id_anno'] = first_anno

                    # Detections: one whitespace-separated
                    # "xmin ymin xmax ymax" box per non-blank line.
                    # The context manager closes the file after each image.
                    with open(gtfile, 'r') as gt:
                        for line in gt:
                            fields = line.split()
                            if not fields:
                                continue
                            (xmin, ymin, xmax, ymax) = fields
                            annorow = tbl_anno.row
                            annorow['bbox_xmin'] = np.float32(xmin)
                            annorow['bbox_xmax'] = np.float32(xmax)
                            annorow['bbox_ymin'] = np.float32(ymin)
                            annorow['bbox_ymax'] = np.float32(ymax)
                            annorow['id_category'] = id_category
                            annorow['id_img'] = id_img
                            annorow['category'] = label
                            annorow.append()
                            id_anno += 1

                    # Number of detections written for this image
                    imrow['n_anno'] = id_anno - first_anno
                    imrow.append()
                    tbl_images.flush()
                    tbl_anno.flush()

            # Add one data package for download
            tbl_package = db.root.packages
            r = tbl_package.row
            r['url'] = self.URL
            r['sha1'] = self.SHA1
            r['path'] = self.SUBDIR
            r.append()
            tbl_package.flush()
        finally:
            # Cleanup: close the database even if export fails part-way
            db.close()
        return self.dbfile
    
class ETHZShapeClasses(ETHZShapes):
    """Settings for the ``ethz_shape_classes_v12.tgz`` archive.

    Only class-level constants are declared here; all behavior is inherited
    from ``ETHZShapes``.
    """

    # Name used for the cache sub-directory and the ``.h5`` database file
    dbname = "ethzshapes"

    # Archive location and the checksum passed along with the download
    URL = "http://www.vision.ee.ethz.ch/datasets/downloads/ethz_shape_classes_v12.tgz"
    SHA1 = "ae9b8fad2d170e098e5126ea9181d0843505a84b"

    # Directory (inside the cache directory) holding the per-label folders
    SUBDIR = "ETHZShapeClasses-V1.2"

    # Image storage mode handed to ``format``
    IMSTORAGE = "Package"

    # Category folders; list position becomes the category id
    LABELS = [
        "Applelogos",
        "Bottles",
        "Giraffes",
        "Mugs",
        "Swans",
    ]
  
class ETHZExtendedShapeClasses(ETHZShapes):
    """Settings for the ``extended_ethz_shapes.tgz`` archive.

    Only class-level constants are declared here; all behavior is inherited
    from ``ETHZShapes``.
    """

    # Name used for the cache sub-directory and the ``.h5`` database file
    dbname = "ethzshapes_extended"

    # Archive location; no SHA1 checksum is recorded for this archive
    URL = "http://www.vision.ee.ethz.ch/datasets/downloads/extended_ethz_shapes.tgz"
    SHA1 = None

    # Directory (inside the cache directory) holding the per-label folders
    SUBDIR = "extended_ethz_shapes"

    # Image storage mode handed to ``format``
    IMSTORAGE = "Package"

    # Category folders; list position becomes the category id
    LABELS = [
        "apple",
        "bottle",
        "giraffe",
        "hat",
        "mug",
        "starfish",
        "swan",
    ]


