#!/usr/bin/env python
# encoding: utf-8
"""
Combine several datasets by interleaving

the product of the values in shape must be equal to the 
number of files given as arguments

ie.
$    slinterlace shape=2,2 file00.rsf file01.rsf file10.rsf file11.rsf > out.rsf   
"""

import signal_lab as slab
import sys
import numpy as np

def allsame( iterable ):
    """
    Return True when every element of `iterable` equals its first element.

    An empty input has nothing to disagree and therefore yields True.
    The input is materialised once, so one-shot iterators and generators
    are accepted in addition to indexable sequences such as lists and
    numpy arrays.
    """
    items = list( iterable )
    # A generator expression lets all() stop at the first mismatch.
    return all( item == items[0] for item in items )

def copy_slfile( env, tagin , tagout ):
    """
    Copy the binary payload of one signal_lab file into another.

    `tagin` is opened as the source file; `tagout` is opened with the
    source passed as its `input=` and finalized before any data moves.
    The data is then transferred through a buffer of `shape[0]` samples,
    once per count reported by `leftsize(1)`, and both files are closed.
    """
    source = slab.File( tagin, env=env )
    target = slab.File( tagout, env=env, input=source )
    target.finalize( )

    reader = source.binary_file
    writer = target.binary_file

    # One reusable scratch vector sized to the first axis of the source.
    trace = np.zeros( source.shape[0], dtype=source.dtype, order=source.order )

    for _trace_no in range( source.leftsize(1) ):
        reader.readinto( trace )
        writer.write( trace )

    source.close( )
    target.close( )

def interlace(env):
    """
    Interleave the input files named in `env.args` and write the result
    to '$stdout'.

    `env['shape']` gives, per axis, how many files are interleaved along
    that axis.  The product of its entries must equal the number of input
    tags.  The opened files are arranged on a grid of that shape in
    row-major order, and the first grid axis selects the files that are
    interleaved sample-by-sample along the first data axis.

    When exactly one file is expected it is simply copied with
    copy_slfile().

    Raises slab.error when
      * the number of tags does not match the product of the shape,
      * the inputs do not all have the same number of dimensions,
      * the shape has more entries than the inputs have dimensions,
      * the per-axis sizes of the inputs do not add up consistently.
    """

    shape = list(env['shape'])
    tags = env.args

    files_to_expect = np.prod( shape )

    if files_to_expect != len(tags):
        raise slab.error( 'Expected %i files to interlace for a shape of %s (got %i files)' %(files_to_expect,shape,len(tags)) )

    if files_to_expect == 1:
        copy_slfile( env, tags[0], '$stdout' )
        return

    slfiles = [slab.File(tag,env=env) for tag in tags]

    ndims = np.array([slfile.ndim for slfile in slfiles])

    if not allsame( ndims ):
        # Report every header with its dimensionality before bailing out.
        for slfile in slfiles:
            sys.stderr.write( "%s ndim=%i\n" % (slfile.header, slfile.ndim) )
        raise slab.error( 'Expected all input files have the same dimention' )

    ndim = ndims[0]
    if len(shape) > ndim:
        raise slab.error( 'lenght of shape= is greater than the number of dimentions in each file ... what to do?' )

    # Axes not mentioned in shape= are not interleaved (one file each).
    while ndim > len(shape):
        shape.append(1)

    # shapes[g0, g1, ..., d] is the size along data axis d of the file at
    # grid position (g0, g1, ...).
    shapes = np.array([slfile.shape for slfile in slfiles])
    shapes = shapes.reshape( list(shape)+[ndim] )

    # Move the data-axis index to the front: bshapes[d] is the grid of
    # sizes along data axis d.  list() keeps the concatenation valid when
    # range() is not a list.
    t = [shapes.ndim-1] + list(range( shapes.ndim - 1 ))
    bshapes = shapes.transpose( t )

    # The output size along axis i is the sum of the input sizes over grid
    # axis i; that sum has to agree for every remaining grid position.
    outshape = [ ]
    for i in range( ndim ):
        dimsizes = list(bshapes[ i ].sum( axis=i ).flat)
        if not allsame( dimsizes ):
            raise slab.error( 'Interlace in deimention %i does not sum to be the same for all windows' %(i+1) )
        outshape.append( dimsizes[0] )

    output = slab.File( '$stdout', env=env, input=slfiles[0] )
    output.shape = outshape
    output.step = list(np.divide(slfiles[0].step, shape))

    output.finalize( )

    n1 = outshape[0]
    out_array = np.zeros( n1, dtype=output.dtype, order=output.order )

    n1_size = bshapes[0]
    max_in_n1 = n1_size.max( )

    if shape[0] > 1: #multiple n1 files
        # One scratch vector per file interleaved along the first axis,
        # each large enough for the longest input vector.
        # NOTE(review): dtype/order are taken from the first shape[0] files
        # in argument order -- presumably all inputs share them; confirm.
        in_arrays = [ np.zeros( max_in_n1, dtype=slfiles[i].dtype, order=slfiles[i].order )
                      for i in range( shape[0] ) ]

    # Iterate over the outer output axes with axis 2 varying fastest:
    # ndindex varies its last argument fastest, hence the reversal here
    # and of each index below.
    outshape2 = list( outshape[1:] )
    outshape2.reverse( )
    # The builtin `object` is the dtype the old `np.object` alias stood for.
    slarray = np.array( slfiles, dtype=object ).reshape( shape )

    n1_step = shape[0]

    sys.stderr.write( "Finished calculations, writing data\n" )
    for index in np.ndindex( *outshape2 ):
        index = list(index)
        index.reverse( )
        # Grid position on the outer axes that supplies this output vector.
        choose_files = tuple(np.mod( index, shape[1:] ) )

        # All files along the first grid axis at that outer position.
        fancy_indexing = tuple([list(range(slarray.shape[0]))]+list(choose_files))
        n1_files = list(slarray[fancy_indexing].flat)

        if shape[0] > 1: #multiple n1 files
            out_array[:] = 0
            for i,(slfile,in_array) in enumerate(zip(n1_files,in_arrays)):
                in_n1 = slfile.shape[0]
                slfile.binary_file.readinto( in_array[:in_n1].data )
                # File i supplies every n1_step-th sample, starting at i.
                out_array[i::n1_step] = in_array[:in_n1]
        else:
            # Single file along the first axis: read straight into the
            # output buffer.
            n1_files[0].binary_file.readinto( out_array.data )

        output.binary_file.write( out_array.data )

    # Release the files explicitly, as copy_slfile() does for its files.
    for slfile in slfiles:
        slfile.close( )
    output.close( )

def parse_shape( text ):
    """
    Convert the text of the shape= command line argument into a tuple.

    Comma separated input such as "2,2" evaluates to a sequence and is
    returned as a tuple.  A single value such as "2" evaluates to a scalar,
    which is wrapped into a one-element tuple instead of failing.
    """
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line.  It is kept so existing invocations keep working; consider a
    # stricter parser if this value can ever come from an untrusted source.
    value = eval( text )
    try:
        return tuple( value )
    except TypeError:
        # Not iterable: a lone scalar was given.
        return ( value, )


if __name__ == '__main__':

    shape_par = slab.Parameter( 'shape', (parse_shape,'tuple'), help='shape of files to interlace' )
    user_arguments =[shape_par]

    env = slab.Environment( sys.argv,

                            #define the help for this program and its options
                            help=__doc__,
                            use_stdin=False,inputs=['file00.rsf','...', 'fileXY.rsf'],
                            user_arguments=user_arguments )

    interlace(env)

