from viset.library.caltech101 import Caltech101
from viset.dataset import Viset
from viset.show import imshow
from time import sleep
import os.path


def test_create():
    """Export a fresh Caltech101 viset and return what ``export()`` produced.

    Callers treat the returned value as the database file path (it is passed
    to ``Viset(...)`` and ``os.path.basename`` by the other tests).
    """
    print('[test_viset.test_create]: creating new viset for caltech101')
    builder = Caltech101()
    return builder.export()
    
def _expect_open_failure(source):
    """Raise IOError if ``Viset(source)`` opens successfully.

    Returns normally only when constructing ``Viset(source)`` raised an
    ``Exception``. The IOError is raised outside the ``try`` so it cannot be
    swallowed by the handler that detects the expected failure.
    """
    try:
        Viset(source)
    except Exception:
        return
    raise IOError('no error opening invalid dataset')


def test_open(dbfile):
    """Open the database by full path and by basename, then check invalid inputs fail.

    Args:
        dbfile: path of the exported database (as returned by test_create).

    Raises:
        IOError: if any of the invalid sources opens without raising.
    """
    print('[test_viset.test_open]: open database')
    db = Viset(dbfile)

    print('[test_viset.test_open]: open database from cache')
    # Rebinding keeps the first Viset alive while the second is constructed,
    # matching the original object lifetimes.
    db = Viset(os.path.basename(dbfile))

    print('[test_viset.test_open]: trying to open invalid database')
    invalid_sources = (
        ('invalid.path.to.h5', 'Correctly raised exception'),
        (1, 'Correctly raised exceptions'),
        ('http://www.visym.com', 'Correctly raised exceptions'),
    )
    for source, message in invalid_sources:
        _expect_open_failure(source)
        print(message)

    
def test_summary(dbfile):
    """Print the string form of the Viset opened from ``dbfile``."""
    print('[test_viset.test_summary]: displaying summary of database')
    summary = Viset(dbfile)
    print(summary)


def test_iterator(dbfile):
    """Iterate over ``db.image`` and print the type of every item.

    The original author noted that iterating ``Viset(dbfile, verbose=True).image``
    directly on an unnamed temporary failed, and that binding ``db`` first and
    then iterating ``db.image`` worked. This version applies that workaround.

    NOTE(review): the root cause is not visible here. Presumably the temporary
    Viset is released mid-iteration and its backing store closed. If so, it
    should be fixed inside Viset itself. TODO confirm.
    """
    db = Viset(dbfile, verbose=True)  # keep the Viset referenced for the whole loop
    for im in db.image:
        print(type(im))
    
def _report_pairs(pairs, label, show=True):
    """Print one line per (image, annotation) pair and optionally display the image.

    Each printed line is
    ``label + '=' + str(annotation['id_img']) + ' Category=' + annotation['category']``.

    Args:
        pairs: iterable of (image, annotation) tuples.
        label: line prefix, e.g. 'Image', 'Training image', 'Testing image'.
        show: when True, pass each image to imshow after printing it.
    """
    for (im, annotation) in pairs:
        print(label + '=' + str(annotation['id_img']) + ' Category=' + annotation['category'])
        if show:
            imshow(im)


def test_split(dbfile):
    """Exercise image access, annotation iteration and train/test/k-fold splits."""
    print('[test_viset.test_split]: dataset splits')
    db = Viset(dbfile)

    # Show a single dataset image
    print('[test_viset.test_split]: show single image')
    imshow(db.image[0])

    # Show images iterated with step=1000
    print('[test_viset.test_split]: show all images with step')
    for im in db.image(step=1000):
        imshow(im)

    # Show annotated images iterated with step=1000
    print('[test_viset.test_split]: show all annotations with step')
    _report_pairs(db.annotation.categorization(step=1000), 'Image')

    # Deterministic (randomize=False) training/testing split
    print('[test_viset.test_split]: training/testing split')
    for split in db.annotation.categorization(step=1000, strategy='train-test', randomize=False):
        _report_pairs(split.train, 'Training image')
        # NOTE(review): testing images are only printed here, not displayed,
        # unlike every other split in this test. Confirm this is intended.
        _report_pairs(split.test, 'Testing image', show=False)

    # Non-iterated, randomized training/testing split
    print('[test_viset.test_split]: random training/testing split')
    (train, test) = db.annotation.categorization.split(step=500, strategy='train-test', randomize=True)
    _report_pairs(train, 'Training image')
    _report_pairs(test, 'Testing image')

    # k-fold splits
    for fold in db.annotation.categorization(step=1000, strategy='kfold', folds=3):
        print('[test_viset.test_split]: kfold split - ' + str(len(fold)))
        _report_pairs(fold.train, 'Training image')
        _report_pairs(fold.test, 'Testing image')

    
def test_labeling(dbfile):
    """Print ``db.y(...)`` for the first entry returned by ``db.labels()``.

    The log tag previously read ``[test_viset.test_split]``, a copy-paste
    error. It now names this test.
    """
    print('[test_viset.test_labeling]: dataset labeling')
    db = Viset(dbfile)
    print(db.y(db.labels()[0]))

    
def test_verbosity(dbfile):
    """Open ``dbfile`` first with verbose=True, then with verbose=False."""
    cases = (
        ('[test_viset.test_verbosity]: verbosity is true', True),
        ('[test_viset.test_verbosity]: verbosity is false', False),
    )
    for banner, verbose in cases:
        print(banner)
        # Bound to a name so the previous Viset lives until the next one is built.
        db = Viset(dbfile, verbose=verbose)

def main():
    """Create the Caltech101 database, then run the enabled tests against it in order."""
    dbfile = test_create()
    for check in (test_open, test_summary, test_split, test_verbosity):
        check(dbfile)
    # test_iterator and test_labeling are currently not run from here
    # (see the note inside test_iterator).
    
if __name__ == '__main__':
    main()

  
