###############################################################
# Tools used to compute the reflectance spectra for
# the data (tiles or pixels)
#
# Authors:
#  NOVELTIS: Cedric Bacour / Ivan Price
###############################################################



import numpy as np
import process_brdf







# ===================================================================
# Calculate the reflectance over the 300-4000nm spectral domain
#
# (note: this is done for each month)
#
def main(job, do_compute_error = True):
    '''
    Calculate the reflectance spectrum over the whole spectral domain (300-4000 nm)
    for each pixel depending on the type of surface (land / snow / ocean):
    - Land: spectral extrapolation based on the normalised reflectances in the 7 MODIS bands
    (+ associated uncertainty when do_compute_error is True)
    - Snow: fitting of a snow model to the normalised reflectances in the 7 MODIS bands
    - Ocean: reflectance model depending on the chlorophyll content

     Inputs:
     - job: job object whose job.data dictionary is
       * containing the variables read from the tile file for the month considered

           - ref_land:               array [ nlon, nlat, nlmbd_MODIS = 7]
           - ref_land_covar:         array [ nlon, nlat, 28] (upper triangle of the
                                     7x7 covariance matrix, stored row by row)
           - chloro_conc:            array [ nlon, nlat]

       * containing the flattened (nlon*nlat) indices of the land, snow and ocean
         pixels, previously determined using job.calculate_surface_masks()

           - idx_land
           - idx_snow
           - idx_ocean

       job.cfg provides lmbd, missval_2b, sza_std / vza_std / phi_std and the
       land / ocean model parameters.
     - do_compute_error: also propagate the MODIS covariance to an uncertainty
       on the land reflectance spectrum

     Side effect:
     * when there are snow pixels, job.data['snow'][idx] receives the fitted
       'fveg', 'snow_grain_size', 'rveg' and 'rtot' of each snow pixel idx

     Outputs:
     * List of 2 elements:

         - reflectance:             array [ nlon, nlat, nlmbd = 3701] (float32,
                                    missval_2b where no surface type applies)
         - err_reflectance_land:    array [ nlon, nlat, nlmbd = 3701] (float32,
                                    NaN outside land pixels or when not computed)

    '''
    n_modis = 7  # number of MODIS bands in ref_land

    # -- Initialisation
    data = job.data
    nlmbd = len(job.cfg.lmbd)
    original_shape = data['ref_land'].shape

    # flatten to 2d and 1d arrays (one row per pixel)
    ref_land = data['ref_land'].reshape(original_shape[0] * original_shape[1],
                                        original_shape[2])
    covar_in = data['ref_land_covar']
    ref_land_covar = covar_in.reshape(covar_in.shape[0] * covar_in.shape[1],
                                      covar_in.shape[2])
    chloro_conc = data['chloro_conc'].flatten()
    npix = ref_land.shape[0]

    # - initialise the variables to be returned
    reflectance = np.full((npix, nlmbd), job.cfg.missval_2b, dtype=np.float32)
    err_reflectance_land = np.full((npix, nlmbd), np.nan, dtype=np.float32)

    # -
    # - Land
    # -
    # ref_land_covar packs the upper triangle of the symmetric 7x7 covariance
    # row by row: (0,0),(0,1)..(0,6),(1,1),..,(6,6) -- exactly the order of
    # np.triu_indices, so the index pairs are computed once for all pixels.
    tri_rows, tri_cols = np.triu_indices(n_modis)
    n_tri = tri_rows.size  # 28 packed covariance terms

    for idx in data['idx_land']:
        reflectance[idx, :] = reflectance_spectrum_land(ref_land[idx, :], job.cfg)

        # Now compute error on reflectance
        if do_compute_error:
            packed = ref_land_covar[idx, :n_tri]
            covar = np.zeros((n_modis, n_modis), np.float64)
            covar[tri_rows, tri_cols] = packed
            covar[tri_cols, tri_rows] = packed
            err_reflectance_land[idx, :] = reflectance_spectrum_land(covar,
                                                                     job.cfg,
                                                                     error = 1)

    # -
    # - Snow
    # -
    # the snow model is fitted to the MODIS reflectance observations
    if len(data['idx_snow']) > 0:

        # initialise the structure for saving information on snow pixels
        job.data['snow'] = {}

        for idx in data['idx_snow']:
            ans = process_brdf.function_brdf_snow( job.cfg,
                                                   sza  = job.cfg.sza_std,
                                                   vza  = job.cfg.vza_std,
                                                   phi  = job.cfg.phi_std,
                                                   lmbd = job.cfg.lmbd,
                                                   fit2obs = True,
                                                   ref_obs = ref_land[idx,:])
            reflectance[idx, :] = ans[0].ravel()

            snow_info = ans[1]
            job.data['snow'][idx] = dict(
                (key, snow_info[key])
                for key in ('fveg', 'snow_grain_size', 'rveg', 'rtot'))

    # -
    # - Ocean processing
    # -
    # The ocean model is evaluated pixel by pixel, so restricting it to the
    # ocean pixels gives the same values without processing land/snow pixels.
    ocean_idx = np.asarray(data['idx_ocean'], dtype=np.intp)
    if ocean_idx.size > 0:
        reflectance[ocean_idx, :] = reflectance_spectrum_ocean(chloro_conc[ocean_idx],
                                                               job.cfg.lmbd,
                                                               job.cfg)

    # -
    # - Reshape back to the original (grid selection) dimensions
    # -
    grid_shape = (original_shape[0], original_shape[1], nlmbd)
    return [reflectance.reshape(grid_shape),
            err_reflectance_land.reshape(grid_shape)]
# END main
# ===================================================================


# ===================================================================
# Compute NDVI from the MODIS observations
#
def calculate_ndvi(ref, cfg, lmbd=None, domains=None):
    '''
    Return the NDVI, (R_PIR - R_R) / (R_PIR + R_R), over land surfaces,
    computed from the MODIS normalised reflectances at 858 nm (PIR) and
    645 nm (R).

    - ref:     reflectance; after band selection it is indexed as
               [:, :, band], so presumably [nlon, nlat, nlmbd] -- TODO confirm
    - cfg:     configuration; cfg.land.chanPIR and cfg.land.chanR give the
               positions of the PIR and R bands among the selected bands,
               cfg.lmbd is the default wavelength grid
    - lmbd:    wavelength grid of ref (defaults to cfg.lmbd)
    - domains: forwarded to spectral_selection to extract the MODIS bands.
               NOTE(review): spectral_selection calls len() on it, so the
               default None cannot work as-is
    '''
    wavelengths = cfg.lmbd if lmbd is None else lmbd
    modis = spectral_selection(ref, wavelengths, domains=domains)

    nir = modis[:, :, cfg.land.chanPIR]
    red = modis[:, :, cfg.land.chanR]
    return (nir - red) / (nir + red)
# END calculate_ndvi
# ===================================================================

# ===================================================================
# Get the spectral reflectance in the selected instrument wavebands
#
def get_instrument(instrument_wavebands, ref_in, ref_err_in, lmbd):
    '''
    Extract the reflectance and its error in the wavebands of an instrument.

    Both arrays go through spectral_selection with the same band definition:
    - more than one entry in instrument_wavebands: the list is passed as is
      (sampling in the narrow bands of the instrument);
    - a single entry: that entry is unwrapped and passed on its own (intended
      for averaging over the instrument broad band; spectral_selection only
      averages when that entry is itself a list of [lmbd_min, lmbd_max] pairs).

    Inputs:
    - instrument_wavebands:  list [nwl]
    - ref_in:                array [nlon, nlat, nlmbd=3701] or [nlon, nlat, nlmbd=3701, nvza, nphi]
    - ref_err_in:            same layout as ref_in
    - lmbd:                  array [nlmbd=3701]

    Outputs (tuple):
    - ref:                   ref_in with the spectral axis replaced by the selected bands
    - ref_err:               ref_err_in with the spectral axis replaced by the selected bands
    '''
    if len(instrument_wavebands) > 1:
        bands = instrument_wavebands
    else:
        bands = instrument_wavebands[0]

    return (spectral_selection(ref_in, lmbd, bands),
            spectral_selection(ref_err_in, lmbd, bands))

# END
# ===================================================================



# ===================================================================
# Sub-sampling of the reflectance spectrum over 300-4000 nm at a few
# spectral bands, or averaging over broad spectral bands
#
def spectral_selection(ref, lmbd, domains=None):
    ''' Reflectance subsampling in narrow bands OR averaging by spectral domains

     Inputs:
     - ref:      reflectance, array [nlmbd] (one pixel), [nlon, nlat, nlmbd]
                 or [nlon, nlat, nlmbd, ...] (e.g. [..., nvza, nphi]); the
                 spectral axis is axis 0 for one pixel and axis 2 otherwise
     - lmbd:     wavelengths of the spectral axis, array-like [nlmbd]
     - domains: either:

         * a list of [lmbd_min, lmbd_max] pairs => the reflectance is averaged
           over each spectral domain (bounds inclusive, in either order; an
           empty domain gives NaN)
         * or a list of wavelengths => the reflectance at exactly these
           wavelengths (wavelengths absent from lmbd are skipped)

     Outputs:
     - reflectance with the spectral axis replaced by the selection:
       array [ndomains] or [nlon, nlat, ndomains, ...]
       (float32 when averaging, dtype of ref when sampling)
    '''
    lmbd = np.asarray(lmbd)
    ndomains = len(domains)

    # Entries that have a length ([min, max] pairs) mean averaging; scalar
    # wavelengths mean sampling. An empty list falls back to sampling.
    try:
        len(domains[0])
    except (TypeError, IndexError):
        averaging = False
    else:
        averaging = True

    spectral_axis = 0 if np.ndim(ref) == 1 else 2

    # -- Averaging over spectral domains
    if averaging:
        ref = np.asarray(ref)
        out_shape = (ref.shape[:spectral_axis] + (ndomains,)
                     + ref.shape[spectral_axis + 1:])
        ans = np.zeros(out_shape, np.float32)
        lead = (slice(None),) * spectral_axis

        for i, domain in enumerate(domains):
            lo, hi = sorted((domain[0], domain[1]))
            inside = np.flatnonzero(~((lmbd < lo) | (lmbd > hi)))
            ans[lead + (i,)] = np.average(np.take(ref, inside, axis=spectral_axis),
                                          axis=spectral_axis)

    # -- Sub-sampling for the required narrow wavebands
    else:
        indices = [k for wl in domains for k in np.flatnonzero(lmbd == wl)]
        ans = np.take(ref, np.asarray(indices, dtype=np.intp), axis=spectral_axis)

    # -- Return
    return ans
# END spectral selection
# ===================================================================



# ===================================================================
# Compute reflectance spectrum over Land on a pixel basis from
# normalized land surface reflectance in the 7 MODIS wavebands
#
def reflectance_spectrum_land(data_in, cfg, error = None):
    '''
    Compute reflectance over Land on a pixel basis from the
    normalised reflectances in the 7 MODIS wavebands
    OR
    Compute reflectance uncertainty (standard deviation) over Land on a pixel
    basis from the variance-covariance matrix among the MODIS spectral bands

     Inputs:
      - data_in:
          * reflectance computation (error is None):   array [nlmbd = 7]
          * error computation (error is not None):     array [nlmbd = 7, nlmbd = 7]
      - cfg: configuration; reads cfg.land.ACP_mat (EOF matrix, one row per
        wavelength of cfg.lmbd), cfg.land.ACP_ObsMean, cfg.land.ACP_SpecMean
        and cfg.lmbd
      - error: None for the reflectance; any other value selects the
        uncertainty computation

     Outputs:
      - reflectance spectrum from 300 to 4000 nm, floored at 0.005, or its
        standard deviation:                        array [nlmbd = 3701]
    '''
    if error is None:
        # Project the anomaly w.r.t. the mean MODIS observation onto the
        # spectral EOFs, then add back the mean spectrum.
        anomaly = data_in - cfg.land.ACP_ObsMean
        ans = np.dot(cfg.land.ACP_mat, anomaly) + cfg.land.ACP_SpecMean

        # reflectance values must be >= 0.005 (NaN is left untouched)
        ans = np.asarray(np.maximum(ans, 0.005))

    else:
        # Error propagation: var_i = a_i . C . a_i^T for each EOF row a_i,
        # evaluated for all wavelengths in one vectorised call instead of a
        # Python loop over the ~3700 wavelengths of every land pixel.
        nlmbd = cfg.lmbd.shape[0]
        acp = np.asarray(cfg.land.ACP_mat, dtype=np.float64)[:nlmbd]
        covar = np.asarray(data_in, dtype=np.float64)
        variance = np.einsum('ij,jk,ik->i', acp, covar, acp)

        ans = np.zeros(cfg.lmbd.shape, np.float64)
        ans[:] = np.sqrt(variance)

    # return value
    return ans
# END reflectance_spectrum_land
# ===================================================================



# ===================================================================
# Compute the ocean reflectance spectrum of every pixel from its chlorophyll content
#
def reflectance_spectrum_ocean(chl, lmbd, cfg):
    '''
    Compute the ocean reflectance of every pixel from its chlorophyll content

    The spectrum is taken from the tabulated spectra cfg.ocean.ref_chl_std
    (one row per tabulated content cfg.ocean.chl, presumably in increasing
    order -- TODO confirm):
      - chl below the first tabulated value:  first tabulated spectrum
      - chl above the last tabulated value:   last tabulated spectrum
      - chl in between: interpolation between the two bracketing spectra,
        linear in log(chl)
    Only the first 501 spectral samples (columns 0..500) are filled; the other
    columns are left at 0. Pixels whose chl is NaN match none of the cases
    and stay at 0.

     Inputs
      - chl:    chlorophyll content, array [nlon*nlat]
      - lmbd:   spectral bands, array [nlmbd] (only its length is used)
      - cfg:    configuration; reads cfg.ocean.chl and cfg.ocean.ref_chl_std

     Outputs
      - reflectance:               array [nlon*nlat, nlmbd] (float64)
    '''
    idx800 = 500              # last filled spectral column
    n_filled = idx800 + 1

    chl_tab = cfg.ocean.chl
    ref_tab = cfg.ocean.ref_chl_std

    ref_ocean = np.zeros([len(chl), len(lmbd)])

    # Classify the pixels. np.ma.getmaskarray always returns a full boolean
    # array, whereas .mask is the scalar `nomask` when a class is empty, and
    # np.ma.where / nonzero on that 0-d value is deprecated in NumPy.
    ind_in = np.flatnonzero(np.ma.getmaskarray(
        np.ma.masked_inside(chl, chl_tab[0], chl_tab[-1])))
    ind_inf = np.flatnonzero(np.ma.getmaskarray(
        np.ma.masked_less(chl, chl_tab[0])))
    ind_sup = np.flatnonzero(np.ma.getmaskarray(
        np.ma.masked_greater(chl, chl_tab[-1])))

    # Computation for pixels whose chl is within the tabulated min and max
    for ind in ind_in:
        # keep only tabulated contents >= chl[ind]
        buf = np.ma.masked_less(chl_tab - chl[ind], 0)

        if np.ma.getmaskarray(buf).all():
            print('ERROR buf all masked in proc_reflectance !')
            continue

        # argmin (masked entries ignored) is the smallest tabulated content
        # >= chl[ind]; ind_min is the entry just below it. When chl[ind]
        # equals the first tabulated value, ind_min is -1 and x is 0, which
        # selects ref_tab[0].
        ind_min = np.ma.argmin(buf) - 1
        x = (np.log(chl_tab[ind_min + 1] / chl[ind])
             / np.log(chl_tab[ind_min + 1] / chl_tab[ind_min]))
        ref_interp = x * ref_tab[ind_min, :] + (1 - x) * ref_tab[ind_min + 1, :]
        ref_ocean[ind, 0:n_filled] = np.ravel(ref_interp)

    # Computation for pixels whose chl is lower than the tabulated min
    ref_ocean[ind_inf, 0:n_filled] = ref_tab[0, :]

    # Computation for pixels whose chl is greater than the tabulated max
    ref_ocean[ind_sup, 0:n_filled] = ref_tab[-1, :]

    return ref_ocean

# END reflectance_spectrum_ocean
# ===================================================================
