import tables
import os
from os import path
import viset.download
from viset.dataset import CategorizationViset, CustomImageViset, ImageArchiveViset
import numpy as np
from viset.util import isfile
import gzip
import struct
from array import array
import numpy
import urlparse

def labels(gzfile):
    """Read the label bytes from a gzip-compressed MNIST label file.

    The file starts with an 8-byte big-endian header holding two unsigned
    32-bit integers: a magic number, which must be 2049, and a second field
    (presumably the number of labels; it is read but not used here).
    Everything after the header is returned as unsigned bytes.

    Args:
        gzfile: Path of the gzip-compressed label file.

    Returns:
        An ``array.array`` of typecode ``"B"`` with one entry per byte that
        follows the header.

    Raises:
        ValueError: If the magic number in the header is not 2049.
    """
    with gzip.open(gzfile, 'rb') as stream:
        magic, _count = struct.unpack(">II", stream.read(8))
        if magic != 2049:
            # The trailing space matters: the two literals are concatenated.
            raise ValueError('Magic number mismatch, expected 2049, '
                             'got %d' % magic)
        values = array("B", stream.read())
    return values


def imread(url):
    """Read MNIST encoded images, adapted from: https://github.com/sorki/python-mnist/blob/master/mnist/loader.py

    The URL has the form ``<path to gzip image file>#<index>``: the part
    before ``#`` is opened with gzip, and the fragment selects which image
    to read.

    Args:
        url: Path of the gzip-compressed image file followed by ``#`` and
            the integer index of the image to read.

    Returns:
        A numpy array of shape ``(rows, cols)`` holding the pixel values of
        the selected image, where ``rows`` and ``cols`` come from the file
        header.

    Raises:
        ValueError: If the fragment is not an integer, if the magic number
            in the header is not 2051, or if the index lies outside the
            image count declared in the header.
    """
    # Split once: everything before '#' is the archive, the fragment is the index.
    gzfile, fragment = urlparse.urldefrag(url)
    gzfile = str(gzfile)
    index = int(fragment)

    with gzip.open(gzfile, 'rb') as stream:
        # 16-byte big-endian header: magic, image count, rows, cols.
        magic, size, rows, cols = struct.unpack(">IIII", stream.read(16))
        if magic != 2051:
            raise ValueError('Magic number mismatch, expected 2051, got %d' % magic)
        # Without this check a bad index seeks outside the image data and
        # surfaces as an unrelated reshape/seek failure (or wrong pixels).
        if index < 0 or index >= size:
            raise ValueError('Image index %d out of range, file holds %d images' % (index, size))
        # Images are stored back to back after the header, rows*cols bytes each.
        stream.seek(index*rows*cols + 16)
        image = numpy.asarray(array("B", stream.read(rows*cols)).tolist())
        return numpy.reshape(image, (rows, cols))
		

class MNIST(CategorizationViset, ImageArchiveViset, CustomImageViset):
    """Exporter that builds a viset database from the four MNIST archives.

    The class attributes hold the download URL and expected SHA1 digest of
    the training/test image and label archives.
    """
    TRAIN_IMG_URL = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
    TRAIN_IMG_SHA1 = '6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d'
    TRAIN_LBL_URL = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    TRAIN_LBL_SHA1 = '2a80914081dc54586dbdf242f9805a6b8d2a15fc'
    TEST_IMG_URL = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
    TEST_IMG_SHA1 = 'c3a25af1f52dad7f726cce8cacb138654b760d48'
    TEST_LBL_URL = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
    TEST_LBL_SHA1 = '763e7fa3757d93b0cdec073cef058b2004252c17'
    _dbname = 'mnist'
    _imcodec = 'mnist'  # for baked in read (HACK?)

    def export(self, redo=False):
        """Download the MNIST archives and write them into the database.

        Archives are cached in ``self.cachedir`` under their original file
        names. The first 60000 records come from the training archives and
        the next 10000 (ids 60000 onwards) from the test archives.

        Args:
            redo: When exactly ``True``, every archive is downloaded again
                even if a cached copy already exists.

        Returns:
            ``self.dbfile`` after the database has been closed.
        """
        # Start from an empty database.
        self.create()

        # Work out where every archive is cached before fetching anything.
        archives = (
            (self.TRAIN_IMG_URL, self.TRAIN_IMG_SHA1),
            (self.TRAIN_LBL_URL, self.TRAIN_LBL_SHA1),
            (self.TEST_IMG_URL, self.TEST_IMG_SHA1),
            (self.TEST_LBL_URL, self.TEST_LBL_SHA1),
        )
        cached = [path.join(self.cachedir, path.basename(url)) for url, _ in archives]
        train_lbl_file = cached[1]
        test_lbl_file = cached[3]

        # Fetch whatever is missing (or everything, when redo is True).
        for (url, sha1), target in zip(archives, cached):
            if not isfile(target) or redo is True:
                viset.download.download(url, target, sha1=sha1)

        # Training records: ids 0 .. 59999.
        # NOTE(review): the meaning of the add_categorization/add_image
        # arguments is defined by the viset base classes - confirm there.
        n_train = 60000
        self.write()
        digits = labels(train_lbl_file).tolist()
        for k in range(n_train):
            digit = digits[k]
            self.dbview.add_categorization(int(digit), str(digit), k)
        self.write()
        train_archive = path.basename(self.TRAIN_IMG_URL)
        for k in range(n_train):
            # Image URL is "<archive name>#<index within that archive>".
            self.add_image(k, '#'.join((train_archive, str(k))), 1, k)
        self.write()

        # Test records: ids continue after the training ones.
        n_test = 10000
        digits = labels(test_lbl_file).tolist()
        for k in range(n_test):
            digit = digits[k]
            self.dbview.add_categorization(int(digit), str(digit), n_train + k)
        self.write()
        test_archive = path.basename(self.TEST_IMG_URL)
        for k in range(n_test):
            record = n_train + k
            self.add_image(record, '#'.join((test_archive, str(k))), 1, record)
        self.write()

        # Register both image archives as downloadable packages.
        for url, sha1 in (archives[0], archives[2]):
            self.add_package(url, sha1)

        # Add splits
        # self.add_split(...)

        # Cleanup
        self.close()
        return self.dbfile
    
