import tables
import os
from os import path
import viset.download
from viset.dataset import SegmentationViset, ImageArchiveViset
from viset.util import isimg

class WeizmannHorsesSingleScale(SegmentationViset, ImageArchiveViset):
    """Weizmann horses (single-scale version) exported as a segmentation viset.

    Database image ``k`` is the k-th image file (sorted by filename) under
    ``IMGPATH``.  Its segmentation is the k-th image file (sorted by
    filename) under ``MASKPATH``.
    """
    URL = 'http://jamie.shotton.org/work/data/WeizmannSingleScale.zip'
    SHA1 = '2d90eedfedea31ebd97294a9268c06fbd4e52332'
    IMGPATH = os.path.join('horses', 'images')
    MASKPATH = os.path.join('horses', 'masks')
    _dbname = 'weizmann_horses_singlescale'

    def _sorted_image_files(self, subdir):
        """Return full paths of the image files in ``cachedir/subdir``, sorted.

        Sorting matters because ``os.listdir`` returns entries in arbitrary
        order, and images are paired with masks purely by position.
        """
        imdir = path.join(self.cachedir, subdir)
        return [path.join(imdir, filename)
                for filename in sorted(os.listdir(imdir))
                if isimg(filename)]

    def export(self, redo=False):
        """Download the archive, build the database and return its file path.

        ``redo`` is accepted for interface compatibility but is not used here.
        """
        # Create empty database
        self.create()

        # Fetch data necessary to initial construction
        viset.download.cache_and_extract(self.URL, self.cachedir, sha1=self.SHA1)

        # Write images to database
        for id_img, imfile in enumerate(self._sorted_image_files(self.IMGPATH)):
            self.add_image(id_img, imfile, 1, id_img)
        self.write()

        # Write labels to database: the k-th sorted mask belongs to image k
        # NOTE(review): assumes image and mask filenames sort into the same
        # order (e.g. matching basenames) — confirm against the archive.
        for id_img, maskfile in enumerate(self._sorted_image_files(self.MASKPATH)):
            self.dbview.add_segmentation(maskfile, id_img)
        self.write()

        # Add one data package for download
        self.add_package(self.URL, self.SHA1)

        # Cleanup
        self.close()
        return self.dbfile

class WeizmannHorsesMultiScale(SegmentationViset, ImageArchiveViset):
    """Weizmann horses (multi-scale version) exported as a segmentation viset.

    The sorted image list interleaves groups of horse images and background
    images, as described by ``SEGMENT_LAYOUT``.  Each horse image consumes
    the next mask, in sorted filename order, from ``MASKPATH``.  Background
    images get an empty mask path.
    """
    URL = 'http://jamie.shotton.org/work/data/WeizmannMultiScale.zip'
    SHA1 = None
    IMGPATH = 'Images'
    MASKPATH = 'Masks'
    _dbname = 'weizmann_horses_multiscale'

    # Ordered (label, count) groups covering the sorted image list.
    # - Groups labelled 'background' have no mask; every other group consumes
    #   one mask per image.
    # - A count of None means "all remaining": remaining masks for a masked
    #   group, remaining images for a 'background' group.
    # NOTE(review): the fixed sizes of 50 come from the original hard-coded
    # groups starting at ids 0, 50, 100, 150 and 200 — confirm against the
    # dataset's file ordering.
    SEGMENT_LAYOUT = (
        ('horses', 50), ('background', 50),
        ('horses', 50), ('background', 50),
        ('horses', None), ('background', None),
    )

    def _sorted_image_files(self, subdir):
        """Return full paths of the image files in ``cachedir/subdir``, sorted.

        Sorting matters because ``os.listdir`` returns entries in arbitrary
        order, and images are paired with masks purely by position.
        """
        imdir = path.join(self.cachedir, subdir)
        return [path.join(imdir, filename)
                for filename in sorted(os.listdir(imdir))
                if isimg(filename)]

    def _write_images(self):
        """Add every image to the database and return how many were added."""
        imfiles = self._sorted_image_files(self.IMGPATH)
        for id_img, imfile in enumerate(imfiles):
            self.add_image(id_img, imfile, 1, id_img)
        self.write()
        return len(imfiles)

    def _write_segmentations(self, num_images):
        """Add one segmentation per image, following ``SEGMENT_LAYOUT``.

        Raises ValueError if a masked group asks for more masks than remain.
        """
        masks = self._sorted_image_files(self.MASKPATH)
        id_img = 0
        id_mask = 0
        for label, count in self.SEGMENT_LAYOUT:
            if label == 'background':
                # Background images have no mask; None fills the remaining ids
                n = max(0, num_images - id_img) if count is None else count
                for _ in range(n):
                    self.dbview.add_segmentation('', id_img, label)
                    id_img += 1
            else:
                remaining = len(masks) - id_mask
                n = remaining if count is None else count
                if n > remaining:
                    raise ValueError(
                        'layout group %r needs %d masks but only %d remain'
                        % (label, n, remaining))
                for maskfile in masks[id_mask:id_mask + n]:
                    self.dbview.add_segmentation(maskfile, id_img, label)
                    id_img += 1
                id_mask += n
            self.write()

    def export(self, redo=False):
        """Download the archive, build the database and return its file path.

        ``redo`` is accepted for interface compatibility but is not used here.
        """
        # Create empty database
        self.create()

        # Fetch data necessary to initial construction
        viset.download.cache_and_extract(self.URL, self.cachedir, sha1=self.SHA1)

        # Write images, then labels, to database
        num_images = self._write_images()
        self._write_segmentations(num_images)

        # FIXME: add bounding boxes

        # Add one data package for download
        self.add_package(self.URL, self.SHA1)

        # Cleanup
        self.close()
        return self.dbfile
    


