import tables
import os
from os import path
import viset.download
from viset.dataset import SegmentationViset
from viset.util import isimg

class WeizmannHorsesSingleScale(SegmentationViset):
    """Weizmann horses (single-scale) segmentation dataset.

    ``export`` downloads the archive at ``URL`` (passing the expected
    ``SHA1``), adds every image found under ``IMGPATH`` and attaches the mask
    at the same sorted position under ``MASKPATH`` to the same image id.
    """
    URL = 'http://jamie.shotton.org/work/data/WeizmannSingleScale.zip'
    SHA1 = '2d90eedfedea31ebd97294a9268c06fbd4e52332'
    IMGPATH = os.path.join('horses', 'images')
    MASKPATH = os.path.join('horses', 'masks')
    _dbname = 'weizmann_horses_singlescale'
    verbose = True

    def export(self):
        """Build the database and return the path of the database file.

        Images are assigned ids ``0..N-1`` in sorted filename order; the k-th
        mask in sorted filename order is attached to image id ``k``.

        Returns:
            ``self.dbfile`` after the database has been written and closed.
        """
        # Create empty database
        self.create()

        # Fetch the dataset package through the cache, passing the expected SHA1
        pkgdir = self.dbcache.get(self.URL, sha1=self.SHA1)

        # Write images to database. os.listdir() returns entries in arbitrary
        # order, so sort to make the image/mask pairing by index deterministic.
        imdir = path.join(pkgdir, self.IMGPATH)
        imgfiles = sorted(f for f in os.listdir(imdir) if isimg(f))
        for id_img, filename in enumerate(imgfiles):
            self.add_image(id_img, path.join(self.IMGPATH, filename), 1, id_img, url=self.URL)
        self.write()

        # Write labels to database: the k-th sorted mask belongs to image k.
        # NOTE(review): assumes mask filenames sort in the same order as the
        # image filenames (e.g. identical base names) -- confirm on the archive.
        maskdir = path.join(pkgdir, self.MASKPATH)
        maskfiles = sorted(f for f in os.listdir(maskdir) if isimg(f))
        for id_img, filename in enumerate(maskfiles):
            self.dbview.add_segmentation(path.join(self.MASKPATH, filename), id_img, url=self.URL)
        self.write()

        # Cleanup
        self.close()
        return self.dbfile

class WeizmannHorsesMultiScale(SegmentationViset):
    """Weizmann horses (multi-scale) segmentation dataset.

    Horse and background images are interleaved in fixed runs of image ids
    (see ``_LABEL_RUNS``). Only horse images have a mask under ``MASKPATH``;
    background images get a segmentation entry with no mask file.
    """
    URL = 'http://jamie.shotton.org/work/data/WeizmannMultiScale.zip'
    SHA1 = None
    IMGPATH = 'Images'
    MASKPATH = 'Masks'
    _dbname = 'weizmann_horses_multiscale'

    # Half-open [start, stop) runs of image ids (ids follow sorted image
    # filename order) and the label shared by every image in the run, i.e.
    # inclusive runs 0-49, 50-99, 100-149, 150-199, 200-247 and 248-655.
    # NOTE(review): these bounds are hard-coded from the dataset layout --
    # confirm them against the archive contents.
    _LABEL_RUNS = (
        (0, 50, 'horses'),
        (50, 100, 'background'),
        (100, 150, 'horses'),
        (150, 200, 'background'),
        (200, 248, 'horses'),
        (248, 656, 'background'),
    )

    def export(self, redo=False):
        """Build the database and return the path of the database file.

        Args:
            redo: accepted for interface compatibility; currently unused.

        Returns:
            ``self.dbfile`` after the database has been written and closed.

        Raises:
            ValueError: if ``MASKPATH`` holds fewer masks than there are
                horse image ids in ``_LABEL_RUNS``.
        """
        # Create empty database
        self.create()

        # Fetch the dataset package through the cache, passing the expected SHA1
        pkgdir = self.dbcache.get(self.URL, sha1=self.SHA1)

        # Write images to database. os.listdir() order is arbitrary and the
        # hard-coded id runs depend on a stable order, so sort the listing.
        imdir = path.join(pkgdir, self.IMGPATH)
        imgfiles = sorted(f for f in os.listdir(imdir) if isimg(f))
        for id_img, filename in enumerate(imgfiles):
            self.add_image(id_img, path.join(self.IMGPATH, filename), 1, id_img, url=self.URL)
        self.write()

        # Masks exist only for horse images; the k-th sorted mask belongs to
        # the k-th horse image id across all 'horses' runs.
        maskdir = path.join(pkgdir, self.MASKPATH)
        maskfiles = sorted(f for f in os.listdir(maskdir) if isimg(f))
        n_horses = sum(stop - start for start, stop, label in self._LABEL_RUNS
                       if label == 'horses')
        if len(maskfiles) < n_horses:
            raise ValueError('expected at least %d masks in %s, found %d'
                             % (n_horses, maskdir, len(maskfiles)))

        # Write labels to database, one run at a time
        masks = iter(maskfiles)
        for start, stop, label in self._LABEL_RUNS:
            for id_img in range(start, stop):
                if label == 'horses':
                    maskfile = path.join(self.MASKPATH, next(masks))
                    self.dbview.add_segmentation(maskfile, id_img, label, url=self.URL)
                else:
                    self.dbview.add_segmentation(None, id_img, label)
            self.write()

        # FIXME: add bounding boxes

        # Cleanup
        self.close()
        return self.dbfile
    


