from viset.library.caltech101 import Caltech101
from viset.dataset import Viset
from viset.show import imshow
from time import sleep
import os.path


def test_create():
    print '[test_viset.test_create]: creating new viset for caltech101'    
    return Caltech101().export()
    
def test_open(dbfile):
    print '[test_viset.test_open]: open database'  
    db = Viset(dbfile)

    print '[test_viset.test_open]: open database from cache'
    db = Viset(os.path.basename(dbfile))
    
    print '[test_viset.test_open]: trying to open invalid database'  
    try:
        try:
            db = Viset('invalid.path.to.h5')  
            raise IOError('no error opening invalid dataset')
        except:
            print 'Correctly raised exception'        
        try:
            db = Viset(1)  
            raise IOError('no error opening invalid dataset')            
        except:
            print 'Correctly raised exceptions'        
        try:
            db = Viset('http://www.visym.com')
            raise IOError('no error opening invalid dataset')                        
        except:
            print 'Correctly raised exceptions'        
    except:
        raise

    
def test_summary(dbfile):
    print '[test_viset.test_summary]: displaying summary of database'    
    print Viset(dbfile)


def test_iterator(dbfile):
    # why does this fail?  need to explicitly set db=Viset(...) then db.image in iterator
    for im in Viset(dbfile, verbose=True).image:
        print type(im)
    
def test_split(dbfile):
    print '[test_viset.test_split]: dataset splits'  
    db = Viset(dbfile, strategy='kfold', kfold=3, datastep=1000)
    for k in range(0,2):
        for (im,anno) in db[k]['test']:
            print anno
    for (im,anno) in Viset(dbfile, datastep=1000)[0]:
        print anno
    db = Viset(dbfile, strategy='train_test_shuffle', datastep=1000)
    for (im,anno) in db['train']:
        print anno
    for (im,anno) in db['test']:
        print anno
    for (im,anno) in Viset(dbfile).view().split(strategy='train_test_shuffle', datastep=1000)['train']:
        print anno
    for (im,anno) in Viset(dbfile).view().split(strategy='train_test_shuffle', datastep=1000)['test']:
        print im
        print anno
    for (im,anno) in Viset(dbfile).view(async=False).split(strategy='train_test_shuffle', datastep=1000)['test']:
        print anno
        imshow(im)

def test_labeling(dbfile):
    print '[test_viset.test_split]: dataset labeling'
    db = Viset(dbfile)
    print db.y(db.labels()[0])

    
def test_verbosity(dbfile):
    print '[test_viset.test_verbosity]: verbosity is true'
    db = Viset(dbfile, verbose=True)
    print '[test_viset.test_verbosity]: verbosity is false'
    db = Viset(dbfile, verbose=False)

def main():
    """Create the Caltech101 viset, then run the enabled checks against it.

    test_split and test_iterator are currently disabled. test_labeling is
    defined but not part of the run.
    """
    dbfile = test_create()
    # Disabled for now: test_split, test_iterator.
    enabled_checks = (test_open, test_summary, test_verbosity)
    for check in enabled_checks:
        check(dbfile)
    
if __name__ == '__main__':
    # Run the test driver only when executed as a script, not on import.
    main()

  
