#! /usr/bin/env python

import ebf
#import pdb
import numpy
import sys
import time
import os

def isfloat(mystr):
    """Tell whether the builtin ``float`` accepts ``mystr``.

    Returns True when ``float(mystr)`` succeeds and False when it raises
    ValueError.  Any other exception raised by ``float`` propagates.
    """
    try:
        float(mystr)
    except ValueError:
        return False
    else:
        return True

def cformat2type_width(myformat):
    """Convert a ``%``-separated column format into datatypes and widths.

    The string is stripped of surrounding blanks and quotes and split on
    ``%``; every resulting item describes one fixed-width column.  Two
    spellings of an item are accepted:

    * type letter first, e.g. ``%F8.2`` or ``%I10``;
    * C printf order ``[flags][width][.precision]type``, e.g. ``%-8.2f``.

    Leading flag characters (space, ``+``, ``-``, ``0``, ``#``) are dropped.
    Every space remaining in an item (typically the gap before the next
    ``%``) widens that column by one character.

    Parameters:
        myformat: the format string, optionally wrapped in quotes.

    Returns:
        ``[datatypel, widthl]`` -- two parallel lists holding, per column,
        the datatype name ('float64', 'int64', 'uint64' or 'S') and the
        integer width.

    Raises:
        RuntimeError: an item is empty or carries no known type letter.
        ValueError: the width part of an item is missing or not an integer.
    """
    type_codes = {}
    for code in 'fFgGeE':
        type_codes[code] = 'float64'
    for code in 'iId':
        type_codes[code] = 'int64'
    type_codes['u'] = 'uint64'
    for code in 'ASs':
        type_codes[code] = 'S'

    items = myformat.strip(' \t\n\r\"\'').split('%')
    # Text before the first '%' is only an item when it is non-empty.
    if len(items[0]) == 0:
        items = items[1:]

    datatypel = []
    widthl = []
    for item in items:
        item = item.lstrip(' +-0#')
        padding = item.count(' ')
        item = item.strip()
        if len(item) == 0:
            raise RuntimeError('Unknown Data type format')
        if item[0] in type_codes:
            # Type letter first: "F8.2"
            code = item[0]
            number = item[1:]
        elif item[-1] in type_codes:
            # printf order: "8.2f"
            code = item[-1]
            number = item[:-1]
        else:
            print(item[0])
            raise RuntimeError('Unknown Data type format')
        # Only the part before '.' is the column width; precision is ignored.
        widthl.append(padding + int(number.split('.')[0]))
        datatypel.append(type_codes[code])
    print(datatypel)
    print(widthl)
    return [datatypel, widthl]


def ascii_init(filename):
    """Inspect an ascii table and work out how its columns should be read.

    The file is scanned four times:

    1. the leading comment block is searched for ``key=value`` lines with
       the keys ``format``, ``fields``, ``datatypes``, ``units``, ``widths``;
    2. the first non-comment, non-blank line selects the delimiter
       (',' then '|' then ';', otherwise None, i.e. whitespace);
    3. the first line containing a numeric item marks the start of data;
       an earlier all-text line supplies the field names when the header
       gave none;
    4. datatypes still unknown are guessed from the data rows
       ('int64' for digit-only items, 'float64' for anything ``float``
       parses, 'S' otherwise), skipping rows whose item is empty.

    Parameters:
        filename: path of the ascii file.

    Returns:
        ``[fields, datatypes, widths, units, delimiter, datastart]``.
        ``datastart`` is the 0-based line index of the first data row, or
        -1 when none was found.  ``widths`` is all zeros unless the header
        declared widths or a format.

    Raises:
        RuntimeError: the number of datatypes, fields, units or widths does
            not match the number of columns in the first data row.
    """
    strip_chars = ' \t\n\r\'\"'
    fields = []
    datatypes = []
    widths = []
    units = []
    ncols = 0
    # Bound up front so that a file without data rows no longer fails with
    # UnboundLocalError when the delimiter is reported or returned.
    delimiter = None
    datastart = -1

    with open(filename, 'r') as f:
        # Pass 1: key=value declarations in the commented header.
        for line in f:
            stripped = line.strip(' \t\n\r')
            if (line[0] != '#') and (len(stripped) > 0):
                break
            parts = line.strip(' \t\n\r#').split('=')
            quant = parts[0].strip()
            if len(parts) < 2:
                continue
            if quant == 'format':
                [datatypes, widths] = cformat2type_width(parts[1])
                continue
            values = [value.strip(strip_chars)
                      for value in parts[1].strip(" []").split(',')]
            if quant == 'fields':
                fields = values
            elif quant == 'datatypes':
                datatypes = values
            elif quant == 'units':
                units = values
            elif quant == 'widths':
                for value in values:
                    widths.append(int(value))

        # Blank names in a declared field list get generated names; the
        # counter advances only for the blank entries.
        if len(fields) > 0:
            named = []
            blanks = 0
            for field in fields:
                if len(field) == 0:
                    named.append('col' + str(blanks))
                    blanks = blanks + 1
                else:
                    named.append(field)
            fields = named

        # Pass 2: delimiter from the first non-comment, non-blank line.
        f.seek(0)
        for line in f:
            dataline = line.strip(' \t\n\r')
            if (line[0] == '#') or (len(dataline) == 0):
                continue
            if ',' in dataline:
                delimiter = ','
            elif '|' in dataline:
                delimiter = '|'
            elif ';' in dataline:
                delimiter = ';'
            else:
                delimiter = None
            break

        # Pass 3: first data row, and csv-style column names if needed.
        f.seek(0)
        for i, line in enumerate(f):
            dataline = line.strip(' \t\n\r')
            if (line[0] == '#') or (len(dataline) == 0):
                continue
            values = dataline.split(delimiter)
            for value in values:
                value1 = value.strip(strip_chars)
                if value1.isdigit() or isfloat(value1):
                    datastart = i
                    break
            if (datastart == -1) and (len(fields) == 0):
                # All-text line and no names yet: treat it as the name row.
                for value in values:
                    fields.append(value.strip(strip_chars))
            else:
                datastart = i
                ncols = len(values)
                break

        # Pass 4: guess the datatypes that were not declared.
        f.seek(0)
        for i, line in enumerate(f):
            dataline = line.strip(' \t\n\r')
            if (line[0] == '#') or (i < datastart) or (len(dataline) == 0):
                continue
            values = dataline.split(delimiter)
            if len(datatypes) == 0:
                datatypes = [''] * len(values)
            unresolved = False
            for j, value in enumerate(values):
                if j >= len(datatypes):
                    # Row is wider than the declared datatypes; stop here
                    # and let the column-count check below report it.
                    break
                if len(datatypes[j]) != 0:
                    continue
                value1 = value.strip(strip_chars)
                if len(value1) == 0:
                    # Nothing to guess from; try the next row.
                    unresolved = True
                elif value1.isdigit():
                    datatypes[j] = 'int64'
                elif isfloat(value1):
                    datatypes[j] = 'float64'
                else:
                    datatypes[j] = 'S'
            if not unresolved:
                break

    def _require_ncols(label, items):
        # Report and reject a per-column list whose length is not ncols.
        if len(items) != ncols:
            print('%s %s delimiter= %s' % (len(items), ncols, delimiter))
            raise RuntimeError("len(" + label + ") != ncols")

    # Fill in defaults and check every list against the column count.
    _require_ncols('datatypes', datatypes)

    if len(fields) == 0:
        for i in range(0, ncols):
            fields.append('col' + str(i))
    else:
        _require_ncols('fields', fields)

    if len(units) == 0:
        for i in range(0, ncols):
            units.append('')
    else:
        _require_ncols('units', units)

    if len(widths) == 0:
        for i in range(0, ncols):
            widths.append(0)
    else:
        _require_ncols('widths', widths)

    return [fields, datatypes, widths, units, delimiter, datastart]

def read_from_schema(schemafile):
    """Read column definitions from a schema file.

    The text between the first '(' and the following ')' is split on
    commas; each non-empty item holds up to four blank-separated entries
    ``name datatype [width [unit]]``.  An entry enclosed in single or
    double quotes may contain blanks (e.g. ``'km / s'``), which is what
    the error message below asks users to do.  A comma always ends an
    item, even inside quotes, because the comma split happens first.

    Parameters:
        schemafile: path of the schema file.

    Returns:
        ``[fields, datatypes, widths, units]`` -- parallel lists; a missing
        width becomes 0 and a missing unit becomes ''.

    Raises:
        RuntimeError: an item has more than four entries.
        IndexError: an item has a name but no datatype.
        ValueError: a width entry is not an integer.
    """
    import re

    with open(schemafile, 'r') as f:
        mystr = f.read()
    mystr = mystr.replace('\n', '')
    print(mystr)
    items = mystr.partition('(')[2].partition(')')[0].split(',')
    print(items)

    # An entry is a run of non-blank characters and/or quoted segments, so
    # blanks inside a pair of quotes do not split it.
    entry = re.compile(r'''(?:"[^"]*"|'[^']*'|\S)+''')

    fields = []
    datatypes = []
    widths = []
    units = []
    for item in items:
        item = item.strip()
        if len(item) == 0:
            continue
        word = entry.findall(item)
        if len(word) > 4:
            raise RuntimeError('Max 4 space separated word allowed. Enclose units in quotes if needed')
        fields.append(word[0].strip(' \'\"'))
        datatypes.append(word[1].strip('\'\"'))
        if len(word) > 2:
            widths.append(int(word[2].strip('\'\"')))
        else:
            widths.append(0)
        if len(word) > 3:
            units.append(word[3].strip('\'\"'))
        else:
            units.append('')

    return [fields, datatypes, widths, units]


def data2ebf(data,fields,datatypes,units,outfile,structon=0):
    """Convert columns of raw items to numpy arrays and write them with ebf.

    Parameters:
        data: dict mapping each field name to a list of raw items.  Items
            equal to 'NULL' or to backslash-N are replaced IN PLACE by a
            default missing value before conversion.
        fields: column names, in output order.
        datatypes: numpy datatype name per column.
        units: unit string per column.
        outfile: name of the ebf file to write (opened with mode 'w').
        structon: 1 writes one structured array under '/data'; any other
            value writes a dict of arrays under '/'.

    Missing values: the largest representable integer for datatypes that
    start with 'i' or 'u', 'inf' for other non-string datatypes, and ''
    for strings.  A column that cannot be converted to its datatype is
    retried as float64.

    Raises:
        ValueError: columns differ in length, or a column converts neither
            to its datatype nor to float64.
        TypeError: numpy rejects the datatype of a column.
    """
    verbose=0
    dt=[]
    nsize=0
    for j,field in enumerate(fields):
        if j == 0:
            nsize=len(data[field])
        elif len(data[field]) != nsize:
            print('%s %s %s' % (field,len(data[field]),nsize))
            print('size of fields are not same')
            raise ValueError('size of fields are not same')
        dt.append((field,datatypes[j]))

    # The structured array is only needed (and only built) for structon == 1.
    if structon == 1:
        data1=numpy.zeros(nsize,dtype=dt)
    else:
        data1={}

    for j,field in enumerate(fields):
        if verbose == 1:
            print("{0:20s} {1:10s} {2:d}".format(field,datatypes[j],len(data[field])))
        kind=datatypes[j][0]
        missing=''
        if kind != 'S':
            if (kind == 'i') or (kind == 'u'):
                missing=numpy.iinfo(datatypes[j]).max
            else:
                missing='inf'
        # Locate NULL markers through a string view of the column.
        temp=numpy.array(data[field],dtype='S')
        ind=numpy.where((temp == 'NULL') | (temp == '\\N'))[0]
        if len(ind) > 0:
            if verbose == 1:
                print('missing elements found using default missing value= %s %s' % (missing,len(data[field])))
                print('%s %s' % (ind,len(ind)))
            for i in ind:
                data[field][i]=missing

        try:
            data1[field]=numpy.array(data[field],dtype=datatypes[j])
        except ValueError:
            try:
                # e.g. a column guessed as integer that contains decimals
                data1[field]=numpy.array(data[field],dtype='float64')
            except ValueError:
                # Report the first item that refuses to convert.
                for x in data[field]:
                    try:
                        numpy.array(x,dtype=datatypes[j])
                    except ValueError:
                        print('%s %s %s  Is it NULL? %s' % (field,datatypes[j],x,(x == 'NULL')))
                        raise ValueError('cannot convert field '+field+' to '+datatypes[j])
                # No single item failed, yet the column did: do not drop
                # the column silently.
                raise ValueError('cannot convert field '+field+' to '+datatypes[j])
        except TypeError:
            # data1[field] may not exist at this point, so it is not printed.
            print('%s %s' % (field,datatypes[j]))
            raise TypeError(' datattype='+datatypes[j])

    dataunitdict={}
    for j,field in enumerate(fields):
        dataunitdict[field]=units[j]
    if isinstance(data1,dict):
        # Units follow the key order of the dict that is actually written.
        # NOTE(review): presumably ebf.write pairs dataunit with
        # data1.keys() in order -- confirm against ebf.write.
        dataunit1=[dataunitdict[field] for field in data1.keys()]
    else:
        dataunit1=units

    if structon == 1:
        ebf.write(outfile,'/data',data1,'w',dataunit=dataunit1)
    else:
        ebf.write(outfile,'/',data1,'w',dataunit=dataunit1)



def ascii2ebf(filename,schema,structon=0,maxsize=1000000,initfile=None):
    """Convert one ascii table to an ebf file and return the ebf file name.

    Parameters:
        filename: ascii file to convert.  The output name is ``filename``
            with a trailing .txt/.dat/.ascii/.csv replaced by .ebf (or
            with .ebf appended).
        schema: path of a schema file, or '' for none.  When given, its
            fields, datatypes, widths and units take precedence over
            everything detected from the data files.
        structon: passed to data2ebf (1 = structured array).
        maxsize: number of rows collected before a '.partN' file is
            written; the parts are joined and deleted at the end.
        initfile: optional file whose detected column layout (fields,
            datatypes, widths, units) replaces the one detected from
            ``filename``.  Delimiter and data start always come from
            ``filename``.

    Rows are split on the detected delimiter unless any width is non-zero,
    in which case fixed-width slices of each line are used.  Empty items in
    delimited rows are stored as 'NULL'.

    Returns:
        The name of the ebf file written.
    """
    t1=time.time()
    [fields,datatypes,widths,units,delimiter,datastart]=ascii_init(filename)
    if initfile is not None:
        [fields,datatypes,widths,units]=ascii_init(initfile)[0:4]
    # Applied last so that an explicit schema is not overridden by the
    # layout guessed from initfile.
    if schema != '':
        [fields,datatypes,widths,units]=read_from_schema(schema)

    print('fields= %s' % (fields,))
    print('datatypes= %s' % (datatypes,))
    print('widths= %s' % (widths,))
    print('units= %s' % (units,))
    print('delimiter= %s' % (delimiter,))
    print('datastart= %s' % (datastart,))

    outfile=filename
    for suffix in ('.txt','.dat','.ascii','.csv'):
        if outfile.endswith(suffix):
            outfile=outfile[:-len(suffix)]
            break
    outfile=outfile+'.ebf'

    data={}
    for field in fields:
        data[field]=[]
    filelist1=[]

    def _flush_part():
        # Write the rows collected so far to the next part file and reset.
        outfile1=outfile+'.part'+str(len(filelist1))
        data2ebf(data,fields,datatypes,units,outfile1,structon)
        filelist1.append(outfile1)
        for field in fields:
            data[field]=[]

    fixedwidth=(sum(widths) != 0)
    start=[]
    end=[]
    if fixedwidth:
        # Column j occupies line[start[j]:end[j]].
        temp=0
        for x in widths:
            start.append(temp)
            temp=temp+x
            end.append(temp)

    jmax=len(fields)
    di=0
    with open(filename,'r') as f:
        for i,line in enumerate(f):
            if (line[0] == '#') or (i < datastart) or (len(line.strip(' \t\r\n')) == 0):
                continue
            if fixedwidth:
                for j in range(0,len(widths)):
                    data[fields[j]].append(line[start[j]:end[j]])
            else:
                # NOTE(review): a row with fewer items than fields leaves
                # the columns with unequal lengths; data2ebf then raises
                # ValueError when the chunk is written.
                for j,word in enumerate(line.strip().strip('\r').split(delimiter)):
                    if j == jmax:
                        print(line)
                        print('%s %s' % (j,jmax))
                        print('line has too many fields')
                        break
                    word=word.strip()
                    if len(word) == 0:
                        word='NULL'
                    data[fields[j]].append(word)
            di=di+1
            if (di%maxsize) == 0:
                _flush_part()

    if len(filelist1) > 0:
        if len(data[fields[0]]) > 0:
            _flush_part()
        print('structon %s' % (structon,))
        ebf.join(filelist1,'/',outfile,'/','w')
        for temp in filelist1:
            os.remove(temp)
    else:
        data2ebf(data,fields,datatypes,units,outfile,structon)

    print('Time taken= %s  s' % (time.time()-t1,))
    return outfile
    
#    t1=time.time()
#    data=asciitable.read(filename)
#    print 'Time taken=',time.time()-t1,' s'

#    return [data1,units]



def fits2ebf(infile,outfile,kw_attr=0,kw_name=0):
    """Copy the data units of a FITS file into an ebf file.

    Parameters:
        infile: FITS file to read (opened with pyfits).
        outfile: ebf file to write; '' derives the name from ``infile`` by
            replacing a trailing .fits/.fit/.fts with .ebf (or appending
            .ebf).
        kw_attr: 1 additionally writes COMMENT, HISTORY and every other
            header keyword under '/<name>_attributes/'.
        kw_name: 1 names a data unit after ``hdu.name`` when that is not
            empty; otherwise units are named 'du' + index.

    For every HDU that has data, the data is written under '/<name>'
    (primary/IMAGE) or '/<name>/' (TABLE/BINTABLE), and the header cards
    are written as strings under '/<name>_fitsheader'.  The first write
    uses mode 'w', later writes use mode 'a'.

    Raises:
        RuntimeError: an HDU with data is neither an image nor a table.
    """
    import pyfits
    if outfile == '':
        outfile=infile
        for suffix in ('.fits','.fit','.fts'):
            if outfile.endswith(suffix):
                outfile=outfile[:-len(suffix)]
                break
        outfile=outfile+'.ebf'
    mode='w'
    hdulist = pyfits.open(infile)
    try:
        for i,hdu in enumerate(hdulist):
            name='du'+str(i)
            print('%s %s' % (i, hdu.header[0]))
            # Identity test: "!= None" on a numpy array compares elementwise.
            if hdu.data is None:
                continue
            if (kw_name==1)and(hdu.name != ''):
                name=hdu.name
            # The first header value is True for a primary HDU (SIMPLE = T)
            # and the extension type otherwise.
            if (hdu.header[0] == True) or (hdu.header[0] == 'IMAGE'):
                ebf.write(outfile,'/'+name,hdu.data,mode)
            elif (hdu.header[0] == 'BINTABLE') or (hdu.header[0] == 'TABLE'):
                ebf.write(outfile,'/'+name+'/',hdu.data,mode)
            else:
                raise RuntimeError("Unknown FITS extension, must be IMAGE,TABLE or BINTABLE")
            if mode == 'w':
                mode='a'

            mycard=[]
            for card in hdu.header.ascardlist():
                mystr=repr(card)
                if len(mystr) > 0:
                    mycard.append(mystr)
            if len(mycard) > 0:
                ebf.write(outfile,'/'+name+'_fitsheader',numpy.array(mycard),mode)
            if kw_attr == 1:
                if hdu.header.has_key('COMMENT'):
                    ebf.write(outfile,'/'+name+'_attributes/COMMENT',numpy.array(hdu.header.get_comment()),mode)
                if hdu.header.has_key('HISTORY'):
                    ebf.write(outfile,'/'+name+'_attributes/HISTORY',numpy.array(hdu.header.get_history()),mode)
                for item in hdu.header.items():
                    if (item[0] == 'COMMENT') or (item[0] == 'HISTORY'):
                        continue
                    if len(item[0]) == 0:
                        continue
                    dataname='/'+name+'_attributes/'+item[0]
                    # A keyword already written gets the suffix '1'.
                    if (ebf.containsKey(outfile, dataname) == 1):
                        ebf.write(outfile,dataname+str(1),numpy.array(item[1]),mode)
                    else:
                        ebf.write(outfile,dataname,numpy.array(item[1]),mode)
    finally:
        # Release the FITS file even when a write fails.
        hdulist.close()

    
def _usage():
    """Print the command line help of ebfconvert to standard output.

    Each entry of ``lines`` is printed on its own line; an empty entry
    gives a blank line, and entries ending in a newline are followed by a
    blank line.
    """
    lines = (
        "NAME:",
        '\t >>EBF<<  (Efficient and Easy to use Binary File Format)',
        "\t ebfconvert 0.0.2 - converts ascii and fits files to  EBF format",
        "\t Can also handle very large files.",
        "\t Copyright (c) 2013 Sanjib Sharma ",
        "USAGE:",
        "\t ebfconvert [OPTIONS] filename",
        '\t ebfconvert [OPTIONS] "prefix*suffix"',
        "\t\t wild card must be in double quotes",
        "DESCRIPTION:",
        "\t Outfile name is constructed from filename with suffix replaced by .ebf \n",
        "\t Files with suffix .fits .fit or .fts are assumed to be fits rest are",
        "\t  treated as ascii \n",
        "\t Empty lines and lines beginning with # are ignored.\n",
        "\t Following delimiters are allowed (tab,space, comma, pipe, semicolon).",
        "\t First non commented line determines the delimiter. \n",
        "\t Tab and spaces can be mixed but not others.\n",
        "\t The program tries to guess datatype  from data and chooses from",
        "\t float64, int64 and string\n",
        "\t Default column names are of form 'col'+str(i).\n",
        "\t If items in first non-commented lines are all strings then they treated",
        "\t as field names.\n",
        "\t For explicit control over formatting before the start of data",
        "\t following lines can be added \n",
        "\t #fields   =[name_1 , name_2, name_3, ...name_n]",
        "\t #datatypes=[float32, int64, S      , ...float64]",
        "\t #units    =[m/s    ,      , kg     , ...m*kg^{-2}]",
        "",
        "\t for fixed format data one can specify widths of fields as",
        "\t #widths   =[5      , 4    , 10     , ...12]",
        "",
        "\t instead of datatypes and widths one can also provide",
        '\t #format   ="%-8.2f%+3.2e %10i %#3u %04d" ',
        "\t This is c printf format %[flags][width][.precision]type",
        "\t Note spaces between two % changes width of fields",
        "\t Allowed formats codes (e, E, f, F, g, G, i, I, u, d, s, S, A)",
        "",
        "\t format and width if present should not have blank members.",
        "\t fields, datatypes, units can have blank members.",
        "\t Allowed datatypes are (int8, int16, int32, int64, uint8, uint16, ",
        "\t uint32, uint64, float32, float64, S)",
        "",
        "OPTIONS:",
        "\t --join=join_file",
        "\t\t When using wildcard this option joins multiple files into one.",
        "\t\t All files to be joined should have same column names.",
        "",
        "\t --schema=schema_file",
        "\t\t This is an alternate way to explicity specify formatting",
        "\t\t An example schema file is given below.Each field is separated",
        "\t\t by comma. A field contains space separated entities",
        "\t\t name datatype width unit. The width is for fixed width",
        "\t\t format only, for other formats set it to 0.",
        "",
        "\t\t (ra       float32 0 degree,",
        "\t\t dec      float32         ,",
        "\t\t velocity float32 0       )",
        "",
        "\t -struct",
        "\t\t will write data as numpy structure",
        "\t -names",
        "\t\t For fits file this will name data items using header.name.",
        "\t\t Default is to name as du+str(i)",
        "\t -attributes",
        "\t\t For fits file this will parse the header and write keyword",
        "\t\t value pairs as dataname+_attributes/. Default is to write",
        "\t\t dataname_fitsheader",
    )
    for line in lines:
        print(line)

#fits2ebf('/home/sharma/Desktop/sdss_catalogue.fits','')
#ebf1.info('/home/sharma/Desktop/sdss_catalogue.ebf')
#fits2ebf('check.fits','')
#ebf1.info('check.ebf')
import unittest
class _ebfconvert_test(unittest.TestCase):
    """Round-trip checks: convert ascii files with ascii2ebf, read them back.

    All input and output files live in the relative directory 'asciidata/'.
    test0 and test2 read files that must already exist there; test1 and
    test3 create 'asciidata/test_split.txt' themselves.
    """

    def test0(self):
        """Space, tab and mixed separated files give the same three columns."""
        x=numpy.array([0,1,2])
        y=numpy.array([3,4,5])
        z=numpy.array([6,7,8])
        expected=(("x",x),("y",y),("z",z))
        for stem in ['test_space','test_tab','test_spacetab']:
            ascii2ebf('asciidata/'+stem+'.txt','')
            data=ebf.read('asciidata/'+stem+'.ebf','/')
            print(stem+'.txt')
            for key,want in expected:
                self.assertTrue(numpy.all(data[key]==want))


    def test1(self):
        """A 100 row csv converted in parts of 10 rows reads back unchanged."""
        x=numpy.linspace(0,99,100)
        y=numpy.linspace(100,199,100)
        z=numpy.linspace(200,299,100)
        f=open('asciidata/test_split.txt','w')
        f.write('# fields=[x, y, z] \n')
        for row in zip(x,y,z):
            f.write(','.join([str(item) for item in row])+'\n')
        f.close()
        ascii2ebf('asciidata/test_split.txt','',0,10)
        data=ebf.read('asciidata/test_split.ebf')
        for key,want in (("x",x),("y",y)):
            self.assertTrue(numpy.all(data[key]==want))
        print(numpy.sum(data["z"]-z))
        self.assertTrue(numpy.all(data["z"]==z))

    def test3(self):
        """Same as test1 with structon=1; the declared units are kept too."""
        x=numpy.linspace(0,99,100)
        y=numpy.linspace(100,199,100)
        z=numpy.linspace(200,299,100)
        units1=['m','km','s']
        f=open('asciidata/test_split.txt','w')
        f.write('# fields=[x, y, z] \n')
        f.write('# units=[m, km, s] \n')
        for row in zip(x,y,z):
            f.write(','.join([str(item) for item in row])+'\n')
        f.close()
        ascii2ebf('asciidata/test_split.txt','',1,10)
        data=ebf.read('asciidata/test_split.ebf','/data')
        ebf.info('asciidata/test_split.ebf')
        units2=ebf.unit('asciidata/test_split.ebf','/data')
        for key,want in (("x",x),("y",y)):
            self.assertTrue(numpy.all(data[key]==want))
        print(numpy.sum(data["z"]-z))
        self.assertTrue(numpy.all(data["z"]==z))
        print(units2)
        self.assertTrue(numpy.all(units1==units2))

    def test2(self):
        """Files declared through #format and #widths hold columns of 1..5."""
        ascii2ebf('asciidata/test_format.txt','',0,10)
        ascii2ebf('asciidata/test_width.txt','',0,10)
        data1=ebf.read('asciidata/test_format.ebf')
        data2=ebf.read('asciidata/test_width.ebf')
        # Column "xk" is expected to hold the constant k in all 100 rows.
        constants=[numpy.ones(100)]
        for factor in (2,3,4,5):
            constants.append(numpy.ones(100)*factor)
        for data in (data1,data2):
            for k,want in enumerate(constants):
                self.assertTrue(numpy.all(data["x"+str(k+1)]==want))
        

if __name__  ==  '__main__':
    # Command line entry point: parse the options, then convert the input
    # file (or every file matching a wildcard) to ebf.
    from glob import glob
    kw_attr=0
    kw_name=0
    kw_struct=0
    schema=''
    joinfile=''
    infile=''
    ecode=0
    if len(sys.argv) == 1:
        _usage()
    else:
        for option in sys.argv[1:]:
            if option.startswith('--'):
                # partition keeps any further '=' inside the value
                key,sep,value=option.partition('=')
                if (key == "--join") and (sep == '='):
                    joinfile=value
                elif (key == "--schema") and (sep == '='):
                    schema=value
                else:
                    print('unrecognized option %s' % (option,))
                    ecode=1
            elif option.startswith('-'):
                if option == "-attributes":
                    kw_attr=1
                elif option == "-names":
                    kw_name=1
                elif option == "-struct":
                    kw_struct=1
                else:
                    print('unrecognized option %s' % (option,))
                    ecode=1
            else:
                infile=option

        # Options alone are not enough: a file name is required.
        if (ecode == 0) and (infile == ''):
            print('no input filename given')
            ecode=1

        if ecode != 0:
            _usage()
            # Non-zero status so that calling scripts can detect the failure.
            sys.exit(ecode)

        print('infile: %s' % (infile,))
        print('schema: %s' % (schema,))
        print('joinfile: %s' % (joinfile,))
        print('kw_attr: %s' % (kw_attr,))
        print('kw_names: %s' % (kw_name,))
        print('kw_struct: %s' % (kw_struct,))

        outfile=''
        if infile.endswith(('.fits','.fit','.fts')):
            if '*' in infile:
                for file1 in glob(infile):
                    print('%s %s' % (file1, outfile))
                    fits2ebf(file1,outfile,kw_attr=kw_attr,kw_name=kw_name)
            else:
                # -names and -attributes apply to a single file as well.
                fits2ebf(infile,outfile,kw_attr=kw_attr,kw_name=kw_name)
        else:
            print('infile: %s' % (infile,))
            print('schema: %s' % (schema,))
            print('joinfile: %s' % (joinfile,))
            if '*' in infile:
                filelist=glob(infile)
                filelist.sort()
                filelist1=[]
                print(filelist)
                for file1 in filelist:
                    # The first file supplies the column layout for all.
                    outfile=ascii2ebf(file1,schema,kw_struct,initfile=filelist[0])
                    filelist1.append(outfile)
                if len(joinfile) > 0:
                    filelist1.sort()
                    print('joining to  %s' % (joinfile,))
                    print(filelist1)
                    ebf.join(filelist1,'/',joinfile,'/','w')
                    for file1 in filelist1:
                        os.remove(file1)
            else:
                ascii2ebf(infile,schema,kw_struct)


