from viset.library.dummy import Dummy
from viset.dataset import Viset
from viset.show import imshow
from time import sleep
import os.path
from nose.tools import assert_equal
import numpy
from numpy.testing import assert_array_equal

# Export the dummy dataset once at import time so every test below can open
# the same database file through the module-level name ``dbfile``.
dbfile = Dummy().export()

def test_create():
    """Export the dummy dataset and check the exported file is named 'dummy.h5'."""
    print('[test_viset.test_create]: creating new viset for dummy dataset')
    # Separate local name so the module-level ``dbfile`` shared by the other
    # tests is not shadowed inside this function.
    exported = Dummy().export()
    assert_equal(os.path.basename(exported), 'dummy.h5')
    
def test_open():
    """Open the dummy database by path and by cache name; reject invalid inputs.

    Each invalid input must make ``Viset`` raise. If any of them opens
    without error, the test fails with an AssertionError naming that input.
    """
    print('[test_viset.test_open]: open database')
    db = Viset(dbfile)

    print('[test_viset.test_open]: open database from cache')
    db = Viset(os.path.basename(dbfile))

    print('[test_viset.test_open]: trying to open invalid database')
    for invalid in ('invalid.path.to.h5', 1, 'http://www.visym.com'):
        try:
            Viset(invalid)
        except Exception:
            print('Correctly raised exception')
        else:
            # Raised from the ``else`` clause so the handler above cannot
            # swallow the failure signal.
            raise AssertionError('no error opening invalid dataset: %r' % (invalid,))

    
def test_summary():
    """Print the string form (summary) of the dummy database."""
    print('[test_viset.test_summary]: displaying summary of database')
    db = Viset(dbfile)
    print(db)


def test_iterator():
    """Iterate over every image, then check annotation 'id_img' values count up from 0.

    NOTE(review): a previous comment reported that iterating
    ``Viset(dbfile).image`` directly, without binding the Viset to a name
    first, fails. The cause is not visible here, so the database is
    deliberately bound to ``db`` before iterating. Confirm and fix in viset.
    """
    db = Viset(dbfile)
    for _ in db.image:
        pass
    for k, (_, anno) in enumerate(db.annotation.categorization):
        assert_equal(anno['id_img'], k)
    
def _print_and_show(prefix, im, annotation):
    """Print ``prefix`` followed by the annotation's image id and category, then show ``im``."""
    # %-formatting tolerates non-string ids/categories, which '+' concatenation did not
    print('%s%s, Category=%s' % (prefix, annotation['id_img'], annotation['category']))
    imshow(im)


def test_split():
    """Display dummy-database images through single access, stepped iteration and splits.

    Covers ``db.image[0]``, ``db.image(step=1)``, plain categorization
    iteration, the iterated and non-iterated 'train-test' strategies, and a
    2-fold 'kfold' strategy. Every visited image is shown with ``imshow``.
    """
    print('[test_viset.test_split]: dataset splits')
    db = Viset(dbfile)

    # Show dataset image
    print('[test_viset.test_split]: show single image')
    imshow(db.image[0])

    # Show all images with a step
    print('[test_viset.test_split]: show all images with step')
    for im in db.image(step=1):
        imshow(im)

    # Show all annotations
    print('[test_viset.test_split]: show all annotations with step')
    for (im, annotation) in db.annotation.categorization(step=1):
        _print_and_show('Image=', im, annotation)

    # Deterministic (randomize=False) training/testing split, iterated
    print('[test_viset.test_split]: training/testing split')
    for split in db.annotation.categorization(step=1, strategy='train-test', randomize=False):
        for (im, annotation) in split.train:
            _print_and_show('Training image=', im, annotation)
        # Test images are shown as well, like in every other split loop here
        for (im, annotation) in split.test:
            _print_and_show('Testing image=', im, annotation)

    # Random training/testing split returned directly instead of iterated
    print('[test_viset.test_split]: random training/testing split')
    (train, test) = db.annotation.categorization.split(step=1, strategy='train-test', randomize=True)
    for (im, annotation) in train:
        _print_and_show('Training image=', im, annotation)
    for (im, annotation) in test:
        _print_and_show('Testing image=', im, annotation)

    # kfold splits
    for fold in db.annotation.categorization(step=1, strategy='kfold', folds=2):
        print('[test_viset.test_split]: kfold split - ' + str(len(fold)))
        for (im, annotation) in fold.train:
            _print_and_show('Training image=', im, annotation)
        for (im, annotation) in fold.test:
            _print_and_show('Testing image=', im, annotation)

    
def test_labeling():
    """Check that categorization.y(0) on the dummy database equals [1, -1, -1, -1, -1]."""
    print('[test_viset.test_labeling]: dataset labeling')
    db = Viset(dbfile)
    assert_array_equal(db.annotation.categorization.y(0), numpy.array([1, -1, -1, -1, -1]))

    
def test_verbosity():
    """Open the dummy database with verbose output switched on, then off."""
    cases = (
        ('[test_viset.test_verbosity]: verbosity is true', True),
        ('[test_viset.test_verbosity]: verbosity is false', False),
    )
    for message, flag in cases:
        print(message)
        # Rebinding ``db`` keeps object lifetimes the same as two sequential assignments
        db = Viset(dbfile, verbose=flag)

if __name__ == '__main__':
    # This module is meant to be collected by nose; when it is run directly,
    # just remind the user of the command line to use instead.
    usage = 'sh> nosetests /path/to/file.py -s'
    print(usage)

  
