###############################################################
# Tools used to compute BRDF
#
# Authors:
#  NOVELTIS: Cedric Bacour / Ivan Price
###############################################################


import numpy as np
import process_reflectance




# ===================================================================
# compute: Compute the BRDF in the prescribed observation geometry
#
def main(reflectance, data, cfg, sza, lmbd, vza, phi,
         sza_max_limit=None, vza_max_limit=None, beyond_za_limit=None,
         do_spectral_averaging=None, do_compute_error=False):
    ''' Compute the BRDF in the prescribed observation geometry.

         Inputs
          - reflectance: array [nlon, nlat] or [nlon, nlat, nlmbd]
          - data: dictionary providing
              'NDVI'        array reshaped to [nlon*nlat]
              'wind_speed'  array reshaped to [nlon*nlat]
              'idx_land', 'idx_snow', 'idx_ocean'
                            indices of the land / snow / ocean pixels in the
                            flattened [nlon*nlat] grid
              'snow'        per-pixel snow information, only read for the
                            pixels listed in 'idx_snow'
          - cfg: configuration object (angular limits, missing value and the
            parameters of the surface models)
          - sza: sun zenith angle, scalar
          - lmbd: wavelengths, forwarded to the snow and ocean models
          - vza: view zenith angles, 1-D array
          - phi: relative azimuth angles, 1-D array
          - sza_max_limit, vza_max_limit, beyond_za_limit (optional):
            override the values stored in cfg. beyond_za_limit selects what
            happens beyond the zenith angle limits:
              '0'        the result is set to 0
              'nan'      the result is set to NaN
              'constant' the angle is clipped to the limit
          - do_spectral_averaging (optional): forwarded to the snow and ocean
            models
          - do_compute_error (optional): also compute the land BRDF error

         Outputs
          - list [brdf, brdf_error], both of shape
            [nlon, nlat, nlmbd, nvza, nphi]
            - brdf (float32) holds cfg.missval_2b where no surface model was
              applied
            - brdf_error is NaN wherever no error was computed

         Note: when the sun zenith angle is beyond its limit and
         beyond_za_limit is '0' or 'nan', the arrays returned early have the
         same 5-D shape as in the nominal case.
    '''

    if sza_max_limit is None  : sza_max_limit = cfg.sza_max_limit
    if vza_max_limit is None  : vza_max_limit = cfg.vza_max_limit
    if beyond_za_limit is None: beyond_za_limit = cfg.beyond_za_limit

    npix_X = reflectance.shape[0]
    npix_Y = reflectance.shape[1]
    npix = npix_X * npix_Y

    # allow for when there is only one spectral domain
    nlmbd = reflectance.shape[2] if len(reflectance.shape) > 2 else 1

    nvza = len(vza)
    nphi = len(phi)
    pixel_shape = (1, nlmbd, nvza, nphi)
    grid_shape = (npix_X, npix_Y, nlmbd, nvza, nphi)

    ref  = np.reshape(reflectance, (npix, nlmbd))
    ndvi = np.reshape(data['NDVI'], (npix,))
    ws   = np.reshape(data['wind_speed'], (npix,))

    ref_ans = np.zeros((npix, nlmbd, nvza, nphi), np.float32) + cfg.missval_2b
    ref_err_ans = np.full(ref_ans.shape, np.nan)

    # -
    # - Handle sza values out of the physically acceptable range of definition
    # -
    if abs(sza) > sza_max_limit:
        if beyond_za_limit == '0':
            return [np.reshape(ref_ans * 0, grid_shape),
                    np.reshape(ref_err_ans, grid_shape)]
        if beyond_za_limit == 'nan':
            return [np.reshape(ref_ans * np.nan, grid_shape),
                    np.reshape(ref_err_ans, grid_shape)]
        if beyond_za_limit == 'constant':
            sza = np.clip(sza, -sza_max_limit, sza_max_limit)

    if np.max(np.abs(vza)) > vza_max_limit:
        if beyond_za_limit == 'constant':
            vza = np.clip(vza, -vza_max_limit, vza_max_limit)

    # -
    # - Land processing
    # -
    for idx in data['idx_land']:
        ans = function_brdf_land(ref[idx,:], ndvi[idx], cfg, sza=sza, vza=vza, phi=phi)
        ref_ans[idx,:,:,:] = np.reshape(ans, pixel_shape)

        if do_compute_error:
            ans_err = function_brdf_land_error(ref[idx,:],
                                               cfg,
                                               sza=sza,
                                               vza=vza,
                                               phi=phi,
                                               sza_std=cfg.sza_std,
                                               vza_std=cfg.vza_std,
                                               phi_std=cfg.phi_std)
            ref_err_ans[idx,:,:,:] = np.reshape(ans_err, pixel_shape)

    # -
    # - Snow processing
    # -
    for idx in data['idx_snow']:
        ans = function_brdf_snow(cfg,
                                 sza=sza,
                                 vza=vza,
                                 phi=phi,
                                 lmbd=lmbd,
                                 do_spectral_averaging=do_spectral_averaging,
                                 data_snow=data['snow'][idx])
        # function_brdf_snow returns [reflectance, data_snow]
        ref_ans[idx,:,:,:] = np.resize(ans[0], pixel_shape)

    # -
    # - Ocean processing
    # -
    idx_ocean = data['idx_ocean']
    if len(idx_ocean) > 0:
        # Foam component: isotropic, hence replicated over all directions
        ans_foam = function_ocean_foam(cfg, ws[idx_ocean], lmbd,
                                       do_spectral_averaging=do_spectral_averaging)
        ref_foam = np.zeros((len(idx_ocean), nlmbd, nvza, nphi))
        ref_foam[...] = np.asarray(ans_foam)[:, :, np.newaxis, np.newaxis]

        for iteration, idx in enumerate(idx_ocean):
            ans = function_brdf_ocean(ref[idx,:],
                                      ws[idx],
                                      cfg,
                                      sza=sza,
                                      vza=vza,
                                      phi=phi,
                                      lmbd=lmbd,
                                      do_spectral_averaging=do_spectral_averaging,
                                      multi_dir=True)

            # dimension is pixel_no, wavelength, vza, phi
            # a resize is needed because for numpy an array of shape (17, 100)
            # is not the same as (1, 17, 100)
            ref_ans[idx,:,:,:] = np.resize(ans, pixel_shape) + ref_foam[iteration,:,:,:]

    ref_ans = np.reshape(ref_ans, grid_shape)
    ref_err_ans = np.reshape(ref_err_ans, grid_shape)

    # - Eliminate out of range calculations with respect to view zenith angle
    #   (the view zenith angle is axis 3 of the 5-D arrays)
    if beyond_za_limit in ['0', 'nan']:
        out_of_range = np.abs(vza) > vza_max_limit
        fill_value = 0 if beyond_za_limit == '0' else np.nan
        ref_ans[:, :, :, out_of_range, :] = fill_value
        ref_err_ans[:, :, :, out_of_range, :] = fill_value

    return [ref_ans, ref_err_ans]

# END compute
# ===================================================================





# ===================================================================
# Compute BRDF over Land on a pixel basis from the reflectance and the NDVI
#
def function_brdf_land(ref, ndvi, cfg, sza=None, vza=None, phi=None):
    '''Compute BRDF over Land on a pixel basis from the reflectance and NDVI.

    The reflectance is scaled by the ratio of the two-kernel model
    (1 + R*f1 + V*f2) in the requested geometry to the same model evaluated
    with cfg.f1_std / cfg.f2_std. R and V are linear functions of NDVI
    computed for two bands (indices 0 and 1 of the cfg.land slope/intercept
    tables, named "r" and "nir" in the code) and blended with
    cfg.land.alpha_RV.

         Inputs
          - ref: spectral reflectance, scalar or 1-D array [nlmbd]
          - ndvi: NDVI of the pixel (scalar)
          - cfg: configuration object
          - sza: sun zenith angle (optional, default cfg.sza_std): scalar
          - vza: view zenith angle (optional, default cfg.vza_std):
            scalar or vector
          - phi: relative azimuth angle (optional, default cfg.phi_std):
            scalar or vector

         Outputs
          - brdf, float32 array [nlmbd, nvza, nphi] when ref is an array
          - brdf, array [nvza*nphi] when ref is a scalar

    '''
    if sza is None: sza = cfg.sza_std
    if vza is None: vza = cfg.vza_std
    if phi is None: phi = cfg.phi_std

    # synchronise shapes: every (vza, phi) combination becomes one direction;
    # flattening is vza-major, so it reshapes back to (nvza, nphi)
    phi_grid, vza_grid = np.meshgrid(phi, vza)
    nvza, nphi = vza_grid.shape
    vza_meshed = vza_grid.ravel()
    phi_meshed = phi_grid.ravel()

    # BRDF functions
    f1 = brdf_land_f1(sza, vza_meshed, phi_meshed)
    f2 = brdf_land_f2(sza, vza_meshed, phi_meshed)

    # Scaling of the R and V coefficients as a function of NDVI
    land = cfg.land
    Rr   = land.slope_R[0]*ndvi + land.intercept_R[0]
    Rnir = land.slope_R[1]*ndvi + land.intercept_R[1]
    Vr   = land.slope_V[0]*ndvi + land.intercept_V[0]
    Vnir = land.slope_V[1]*ndvi + land.intercept_V[1]

    R = land.alpha_RV * Rnir + (1-land.alpha_RV)*Rr
    V = land.alpha_RV * Vnir + (1-land.alpha_RV)*Vr

    angular = 1 + R*f1 + V*f2
    angular_std = 1 + R*cfg.f1_std + V*cfg.f2_std

    # scalar reflectance: one value per direction, kept flat
    if np.ndim(ref) == 0:
        return ref*angular/angular_std

    # spectral reflectance: outer product wavelength x direction. A single
    # broadcast expression is valid whatever the relative sizes of nlmbd and
    # nvza*nphi.
    ref = np.asarray(ref)
    brdf = np.zeros((ref.shape[0], nvza, nphi), np.float32)
    brdf[...] = (ref[:, np.newaxis]*angular/angular_std).reshape(brdf.shape)

    return brdf
# END function_brdf_land
# ===================================================================


# ===================================================================
# Compute the error on BRDF over Land on a pixel basis
#
def function_brdf_land_error(ref, cfg, sza=None, vza=None, phi=None,
                 sza_std=None, vza_std=None, phi_std=None):
    '''Compute the error on BRDF over Land on a pixel basis.

    For each waveband the error is
        sqrt( ref**2 * ( err_R**2*(f1-f1_std)**2 + err_V**2*(f2-f2_std)**2 ) )
    with err_R / err_V read from cfg.land and f1_std / f2_std from cfg.

         Inputs
          - ref: spectral reflectance, array [nlmbd]
          - cfg: configuration object
          - sza: sun zenith angle (optional, default cfg.sza_std): scalar
          - vza: view zenith angle (optional, default cfg.vza_std): vector
          - phi: relative azimuth angle (optional, default cfg.phi_std): vector
          - sza_std, vza_std, phi_std (optional): default to the cfg values;
            they are resolved but not used by the computation

         Outputs
          - error on brdf, float32 array [nlmbd, nvza, nphi]

    '''
    sza = cfg.sza_std if sza is None else sza
    vza = cfg.vza_std if vza is None else vza
    phi = cfg.phi_std if phi is None else phi
    sza_std = cfg.sza_std if sza_std is None else sza_std
    vza_std = cfg.vza_std if vza_std is None else vza_std
    phi_std = cfg.phi_std if phi_std is None else phi_std

    # one flattened direction per (vza, phi) combination
    azimuth_grid, view_grid = np.meshgrid(phi, vza)
    view_flat = view_grid.ravel()
    azimuth_flat = azimuth_grid.ravel()

    # BRDF functions
    f1 = brdf_land_f1(sza, view_flat, azimuth_flat)
    f2 = brdf_land_f2(sza, view_flat, azimuth_flat)

    nband = ref.shape[0]
    err_brdf = np.zeros([nband, len(vza), len(phi)], np.float32)

    # computation for each waveband
    for band in range(nband):
        variance = ref[band]**2 * (cfg.land.err_R**2*(f1-cfg.f1_std)**2 + cfg.land.err_V**2*(f2-cfg.f2_std)**2)
        band_view = err_brdf[band,:]
        err_brdf[band,:] = np.sqrt(variance).reshape(band_view.shape)

    # only reached if the kernel functions returned a 0-d value
    if not f1.shape:
        err_brdf = np.reshape(err_brdf, nband)

    return err_brdf
# END function_brdf_land_error
# ===================================================================



# ===================================================================
# Compute BRDF over Ocean on a pixel basis from the Cox and Munk model
#
def function_brdf_ocean(ref, wind_speed, cfg, sza=None, vza=None, phi=None, lmbd = None,
                        multi_dir = False, do_spectral_averaging = None):
    '''
    Compute BRDF over Ocean on a pixel basis from the Cox and Munk model
    See Breon FM (1993), An analytical model for the Cloud-Free Atmosphere/Ocean
    system reflectance, Remote Sensing of Environment, 43:179-192
    The specular reflectance is computed with the Fresnel equations, using the
    real part of the water refractive index. The latter comes from
    http://refractiveindex.info/?group=LIQUIDS&material=Water /
    G. M. Hale and M. R. Querry. Optical Constants of Water in the 200-nm to
    200-microm Wavelength Region, Appl. Opt. 12, 555-563 (1973)
    doi:10.1364/AO.12.000555

    Inputs
        - ref: spectral reflectance of the pixel, array [nlmbd]
        - wind_speed: wind speed of the pixel, scalar
        - cfg: configuration object (cfg.lmbd, cfg.ocean.nwater,
          cfg.ocean.lmbd_nwater and the default angles)
        - sza: sun zenith angle (default cfg.sza_std): scalar
        - vza: view zenith angle (default cfg.vza_std): scalar or vector
        - phi: relative azimuth angle (default cfg.phi_std): scalar or vector
        - lmbd: wavelengths (default cfg.lmbd): vector. When its length
          differs from cfg.lmbd, the refractive index is subsampled at the
          entries of cfg.ocean.lmbd_nwater equal to the requested wavelengths
        - multi_dir: accepted for backward compatibility, not used
        - do_spectral_averaging: when True the specular reflectance is
          averaged over wavelength before being applied

    Outputs
        - reflectance, float array [nlmbd, nvza, nphi]

    '''

    if sza is None: sza = cfg.sza_std
    if vza is None: vza = cfg.vza_std
    if phi is None: phi = cfg.phi_std
    if lmbd is None: lmbd = cfg.lmbd

    # synchronise shapes: every (vza, phi) combination becomes one direction;
    # flattening is vza-major, so it reshapes back to (nvza, nphi)
    phi_grid, vza_grid = np.meshgrid(phi, vza)
    nvza, nphi = vza_grid.shape
    vza_meshed = vza_grid.ravel()
    phi_meshed = phi_grid.ravel()

    dr = np.pi/180.

    ws = wind_speed

    # - Specular reflectance
    nwater = cfg.ocean.nwater

    # spectral subsampling
    if len(lmbd) != len(cfg.lmbd):
        idxWave = []
        for wl in lmbd:
            ind = np.ma.nonzero(cfg.ocean.lmbd_nwater == wl)[0]
            idxWave.extend(ind.tolist())
        nwater = cfg.ocean.nwater[idxWave]

    mus = np.cos(sza*dr)
    muv = np.cos(vza_meshed*dr)
    sin_s = np.sin(sza*dr)
    sin_v = np.sin(vza_meshed*dr)

    # incidence angle = phase angle / 2
    # (clipped: rounding can push the cosine slightly outside [-1, 1], which
    # would otherwise give NaN)
    cos_phase = np.clip(mus*muv + sin_s*sin_v*np.cos(phi_meshed*dr), -1., 1.)
    xi = np.arccos(cos_phase)/2.
    xi = xi/dr

    # specular reflectance, [nlmbd, nvza*nphi]
    ref_sp = ocean_specular_reflectance(xi, nwater)

    # only if a broad spectral domain is considered
    if do_spectral_averaging == True:
        ref_sp = np.average(ref_sp, axis=0)[np.newaxis, :]

    # - BRDF: wave slope of the facet reflecting the sun towards the sensor
    m = 1/(mus + muv)
    Zx = (sin_v*np.cos(phi_meshed*dr) + sin_s) * m
    Zy = (sin_v*np.sin(phi_meshed*dr)) * m

    tan_beta = np.sqrt(Zx**2 + Zy**2)
    beta = np.arctan(tan_beta)
    # slope variance as a function of wind speed, and slope probability
    sigma2 = 0.003 + 5.12*1e-3*ws
    Pws = np.exp(-tan_beta**2/sigma2)/(np.pi*sigma2)

    # glint contribution, [nlmbd (or 1), nvza*nphi]
    glint = np.pi*Pws*ref_sp / (4*np.cos(beta)**4 * muv * mus)

    # The input reflectance is added to the glint for every direction; one
    # broadcast expression is valid whatever the relative sizes of nlmbd and
    # the number of directions.
    brdf = np.asarray(ref, dtype=float).reshape(-1, 1) + glint

    return brdf.reshape(-1, nvza, nphi)

# END function_brdf_ocean
# ===================================================================


# > MODIF_CB
# ===================================================================
# Compute reflectance over Ocean due to foam
#
def function_ocean_foam(cfg, ws, lmbd, do_spectral_averaging=False):
    '''
    Compute the foam component of ocean reflectance based on
        - Koepke P. (1984), Effective reflectance of oceanic whitecaps, Applied Optics, 23:11, 1816-1824.
          to determine the contribution of the foam reflectance as a function of wind speed
        - Kokhanovsky A.A. (2004), Spectral reflectance of whitecaps, Journal of Geophysical Research, 109, C05021,
          doi:10.1029/2003JC002177
          to determine its spectral dependency

    The foam reflectance is assumed isotropic

    Inputs
        - cfg: configuration object (cfg.lmbd and, in cfg.ocean, foam_specmod,
          lmbd_nwater, foam_ws_std, foam_ref_std, foam_normfac)
        - ws: wind speed [npixels]
        - lmbd: wavelengths. When their number differs from cfg.lmbd, the
          spectral model is subsampled at the entries of cfg.ocean.lmbd_nwater
          equal to the requested wavelengths
        - do_spectral_averaging: when true, the spectral model is averaged
          and a single spectral value is returned

    Outputs
        - reflectance, float32 array [npixels, nlmbd]
          (nlmbd = 1 with spectral averaging)

    Notes
        - The contribution is linearly interpolated in the table
          (foam_ws_std, foam_ref_std) and held constant outside of it.
          foam_ws_std is assumed to be increasing.

    '''

    # determine the spectral dimension
    foam_specmod = cfg.ocean.foam_specmod[:]
    nlmbd = len(lmbd)

    # spectral subsampling
    if len(lmbd) != len(cfg.lmbd):
        idxWave = []
        for wl in lmbd:
            ind = np.ma.nonzero(cfg.ocean.lmbd_nwater == wl)[0]
            idxWave.extend(ind.tolist())
        foam_specmod = cfg.ocean.foam_specmod[idxWave]

    # spectral averaging
    if do_spectral_averaging:
        foam_specmod = np.average(foam_specmod)
        nlmbd = 1

    # Contribution (from Koepke): interpolation in the wind speed table.
    # np.interp clamps at both ends, so a wind speed equal to or above the
    # last tabulated value returns the last tabulated reflectance.
    ws = np.asarray(ws, dtype=float).ravel()
    contrib_foam = np.interp(ws, cfg.ocean.foam_ws_std, cfg.ocean.foam_ref_std)

    # Apply the spectral model
    ref_foam = np.zeros((len(ws), nlmbd), np.float32)
    ref_foam[...] = contrib_foam[:, np.newaxis]*foam_specmod/cfg.ocean.foam_normfac

    # return
    return ref_foam

# END function_ocean_foam
# ===================================================================
# < MODIF_CB

# ===================================================================
# Compute BRDF over Snow on a pixel basis
#
def function_brdf_snow( cfg, sza=None, vza=None, phi=None, lmbd=None,
                        do_spectral_averaging=False,
                        fit2obs = False,  ref_obs = None, data_snow = None):
    '''
    Compute the BRDF over Snow on a pixel basis.

    Model of A.Kokhanovsky & E.P.Zege (2004, Scattering of snow, Applied
    Optics, 43(7):1589-1602), evaluated through snow_model().

    Two modes:
        - fit2obs == True: the snow grain size (and, for mixed pixels, the
          reflectance of the non-snow fraction) is estimated from ref_obs and
          a new data_snow dictionary is built; any data_snow passed in is
          discarded.
        - fit2obs == False: the BRDF is computed from a previously built
          data_snow dictionary, which is then required.

    Inputs
        - cfg: configuration object (cfg.lmbd, cfg.land.kice,
          cfg.land.lmbd_kice, cfg.indices_modis_bands and the default angles)
        - sza: sun zenith angle (optional, default cfg.sza_std):       scalar
        - vza: view zenith angle (optional, default cfg.vza_std):      scalar or vector
        - phi: relative azimuth angle (optional, default cfg.phi_std): scalar or vector
        - lmbd: wavelengths (optional, default cfg.lmbd): vector
        - do_spectral_averaging: currently ignored (the averaging code at the
          end of the function is commented out)
        - fit2obs: selects the mode, see above
        - ref_obs: observed reflectances, required when fit2obs is True.
          Elements 2 and 4 are read explicitly; presumably these are MODIS
          bands matching cfg.indices_modis_bands -- TODO confirm with caller
        - data_snow: dictionary with keys 'fveg', 'snow_grain_size', 'rtot'
          (and 'rveg'), required when fit2obs is False

    Outputs
        - list [ans, data_snow]
          With fit2obs False, ans is a reflectance array
          [nlmbd, nvza, nphi].

    '''

    if sza is None: sza = cfg.sza_std
    if vza is None: vza = cfg.vza_std
    if phi is None: phi = cfg.phi_std
    if lmbd is None: lmbd = cfg.lmbd

    # wrap anything that is not a numpy array so that len() works below
    if type(sza) != type(np.zeros(1)): sza=np.array([sza])
    if type(vza) != type(np.zeros(1)): vza=np.array([vza])
    if type(phi) != type(np.zeros(1)): phi=np.array([phi])


    # - Synchronise shapes: every (vza, phi) combination becomes one flattened
    #   direction, and sza is replicated to the same length
    phi_meshed, vza_meshed = np.meshgrid(phi, vza)
    vza_meshed = vza_meshed.ravel()
    phi_meshed = phi_meshed.ravel()
    sza_meshed = np.ones(len(vza_meshed))*sza

    # NOTE(review): dr is not used in this function
    dr = np.pi/180


    # -
    # - if fit2obs == True, then the model is "fitted" to MODIS observations
    # -
        #  Principles:
        #    1. Determine the fraction of non-snow cover
        #    2. Determine the reflectance associated to non-snow (assuming a linear composition of the 2 components)
    # if fveg > 0 (non-snow elements cannot be neglected)
        #    3. Apply spectral interpolation procedure to non-snow reflectance: rveg
        #    4. Recompute the pseudo-observation of snow reflectance as a function of MODIS observations and previous rveg
        #    5. Fit the snow model at 1240 nm by estimating the snow grain size
        #    6. Repeat 3
        #    7. Repeat 4
        #    8. Determine the total reflectance
        # if fveg < 0.2
        #    9. Fit the model to the observation at 1240 nm by estimating the snow grain size

    if fit2obs == True:

        data_snow = {} # structure containing the snow information for each pixel

        # - Model: computed over all wavebands, with the default grain size
        RefSnow  = snow_model(cfg, sza_meshed, vza_meshed, phi_meshed, cfg.lmbd, cfg.land.kice)
        rsnow = RefSnow.ravel()

        # 1 - non-snow fraction derived from element 2 of the observations
        fveg = np.clip((1-ref_obs[2])/0.95,0,1)
        if fveg > 0: # mixed pixel

            # iterative process for estimating snow grain size and reflectance of non snow elements
            for istep in range(2):
                # 2 - 6: non-snow reflectance at the observed bands
                rveg_obs = (ref_obs - (1-fveg)*rsnow[cfg.indices_modis_bands].ravel())/fveg
                # 3 - 7: spectral interpolation of the non-snow reflectance
                rveg = process_reflectance.reflectance_spectrum_land(rveg_obs, cfg.lmbd, cfg)

                # the grain size is only estimated during the first pass; the
                # second pass refreshes rveg with the updated rsnow
                if istep ==0:
                    # 4
                    # NOTE(review): divides by (1-fveg), which is 0 when fveg == 1
                    rsnow_obs_new = (ref_obs-fveg*rveg[cfg.indices_modis_bands])/(1-fveg)
                    snow_grain_size = snow_model(cfg, sza_meshed, vza_meshed, phi_meshed, cfg.lmbd, cfg.land.kice,
                                                 invert_snow_grain_size = True,
                                                 invert_ref_obs = rsnow_obs_new[4],
                                                 invert_ilmbd = cfg.indices_modis_bands[4])
                    # 5
                    rsnow = snow_model(cfg, sza_meshed, vza_meshed, phi_meshed, cfg.lmbd, cfg.land.kice, snow_grain_size = snow_grain_size).ravel()

            # 8 - linear mixture of the two components
            ans = fveg*rveg+(1-fveg)*rsnow
            ans = ans.reshape((len(cfg.lmbd),len(vza),len(phi)))

         # 9
        # NOTE: for 0 < fveg < 0.2 both this block and the mixed-pixel block
        # above run, and the results of this one overwrite the previous ones
        if fveg <0.2:      # only snow
            snow_grain_size = snow_model(cfg, sza_meshed, vza_meshed, phi_meshed, cfg.lmbd, cfg.land.kice,
                                         invert_snow_grain_size = True,
                                         invert_ref_obs = ref_obs[4],
                                         invert_ilmbd = cfg.indices_modis_bands[4])
            rsnow  = snow_model(cfg, sza_meshed, vza_meshed, phi_meshed, cfg.lmbd, cfg.land.kice, snow_grain_size = snow_grain_size).ravel()
            rsnow = rsnow.reshape((len(cfg.lmbd),len(vza),len(phi)))
            ans = rsnow.ravel()
            # non-snow reflectance set to zero, same shape as rsnow
            rveg = rsnow[:]*0

        # per-pixel snow information, reused later when fit2obs is False
        data_snow['fveg'] = fveg
        data_snow['snow_grain_size'] = snow_grain_size
        data_snow['rveg'] = rveg
        data_snow['rtot'] = ans.ravel()


    # -
    # - Compute the BRDF
    # -
        # if fveg >= 0.2: mixed pixel considered as isotropic
        # only snow otherwise
    else:
        if data_snow['fveg'] < 0.2: # only snow
            # model evaluated on the full cfg.lmbd grid, weighted by the snow fraction
            RefSnow  = snow_model(cfg, sza_meshed, vza_meshed, phi_meshed, cfg.lmbd, cfg.land.kice, snow_grain_size = data_snow['snow_grain_size'] )*(1-data_snow['fveg'])
            RefSnow = RefSnow.reshape((len(cfg.lmbd),len(vza),len(phi)))
            ans = RefSnow

            # - Spectral subsampling: keep the entries of cfg.land.lmbd_kice
            #   equal to the requested wavelengths
            if len(lmbd) != len(cfg.lmbd):
                idxWave = []
                for wl in lmbd:
                    ind = np.ma.nonzero(cfg.land.lmbd_kice == wl)[0]
                    idxWave.extend(ind.tolist())
                ans = RefSnow[idxWave,:,:]

        else: # mixed pixel: isotropic
            # the stored total reflectance is replicated over all directions
            RefSnow = np.zeros((len(lmbd),len(vza),len(phi)))
            for iwl in range(len(lmbd)): RefSnow[iwl,:,:] = data_snow['rtot'][iwl]
            ans = RefSnow


        # - Spectral averaging (disabled)
        # if do_spectral_averaging:
        #     ans  = np.zeros((1,len(vza_meshed)), np.float)
        #     RefSnow = RefSnow.reshape(nlmbd,len(vza_meshed))
        #     for i in range(len(vza)):  ans[:,i] = np.average(RefSnow[:,i])
        #     ans = ans.reshape(1,len(vza),len(phi))

    # -
    # - return value
    # -
    return  [ans, data_snow]

# END function_brdf_snow
# ===================================================================



# ===================================================================
# Compute reflectance (spectro-directional variations) over Snow
#
def snow_model( cfg, sza, vza, phi, lmbd, kice,
                snow_grain_size = None,
                invert_snow_grain_size = False, # needed to estimate snow grain size
                invert_ref_obs         = None,          # needed to estimate snow grain size
                invert_ilmbd           = None           # needed to estimate snow grain size
                ):

    '''Model of A.Kokhanovsky & E.P.Zege (2004, Scattering of snow, Applied
    Optics, 43(7):1589-1602)

    Inputs
    - cfg: configuration object, only read for cfg.land.snow_grain_size when
      snow_grain_size is not given (normal mode)
    - sza: sun zenith angle (deg):       scalar or vector [nang]
    - vza: view zenith angle (deg):      vector [nang]
    - phi: relative azimuth angle (deg): scalar or vector [nang]
    - lmbd: wavelength: vector [nlmbd]; multiplied by 1e-9 in the code
    - kice: imaginary part of ice refraction index: vector [nlmbd]
    - snow_grain_size: grain size used in normal mode
      (default cfg.land.snow_grain_size)
    - invert_snow_grain_size: when true, run the inverse mode
    - invert_ref_obs: observed reflectance used by the inverse mode
    - invert_ilmbd: index in lmbd / kice of the wavelength of invert_ref_obs

    Outputs
        - normal mode: reflectance, float array [nlmbd, nang]
        - inverse mode: estimated snow grain size, one value per direction

    '''

    dr = np.pi/180

    mu_s = np.cos(sza*dr)
    mu_v = np.cos(vza*dr)

    # scattering angle (deg); the cosine is clipped because rounding can push
    # it slightly outside [-1, 1], which would otherwise give NaN
    cos_ksi = np.clip(-mu_s*mu_v - np.sin(sza*dr)*np.sin(vza*dr)*np.cos(phi*dr), -1., 1.)
    ksi = np.arccos(cos_ksi) * 180./np.pi
    # phase function
    PF = (11.1*np.exp(-0.087*ksi)+1.1*np.exp(-0.014*ksi))
    # value of R at zero absorption
    A = 1.247
    B = 1.186
    C = 5.157
    Ref0 = (A+B*(mu_s+mu_v)+C*mu_s*mu_v+PF)/(4*(mu_s+mu_v))
    Ks = 3./7.*(1+2*mu_s)
    Kv = 3./7.*(1+2*mu_v)

    # - Inverse mode: estimate snow grain size from the measurement at a given wavelength
    if invert_snow_grain_size:
        alpha = -Ref0*np.log(invert_ref_obs/Ref0)/(Ks*Kv)
        snow_grain_size = alpha*alpha*lmbd[invert_ilmbd]*1e-9/(4.*np.pi*kice[invert_ilmbd]*13)
        return snow_grain_size

    # - Normal mode: compute reflectance
    # ("is None": an array grain size must not be compared element-wise)
    if snow_grain_size is None: snow_grain_size = cfg.land.snow_grain_size

    nlmbd = len(lmbd)
    nang = len(vza)

    # spectral term, one value per wavelength
    alpha = np.sqrt(4.*np.pi* kice *(13*snow_grain_size)/(lmbd*1e-9))
    alpha = np.broadcast_to(alpha, (nlmbd,))[:, np.newaxis]

    # angular terms, one value per direction
    Ref0 = np.broadcast_to(Ref0, (nang,))
    Ks = np.broadcast_to(Ks, (nang,))
    Kv = np.broadcast_to(Kv, (nang,))

    # outer combination wavelength x direction, valid whatever the relative
    # sizes of nlmbd and nang
    RefS = Ref0*np.exp(-alpha*Ks*Kv/Ref0)

    # -
    # - return value
    # -
    return np.asarray(RefS, dtype=float)

# END snow_model
# ===================================================================



# ===================================================================
#  F1 function of the Li-Sparse Reciprocal BRDF model
#
#  Inputs:
#   - sun zenith angle (deg) (scalar or vector)
#   - view zenith angle (deg) (scalar or vector)
#   - relative azimuth angle (deg) (scalar or vector)
#
# Outputs:
#   - f1 (scalar or vector)
# ===================================================================
def brdf_land_f1(sza,vza,phi):
    '''F1 function of the Li-Sparse Reciprocal BRDF model.

    Inputs (scalars or arrays that broadcast together)
      - sza: sun zenith angle (deg)
      - vza: view zenith angle (deg)
      - phi: relative azimuth angle (deg)

    Outputs
      - f1 as a numpy array (0-d for scalar inputs)
    '''
    deg2rad = np.pi/180

    sun = sza*deg2rad
    view = vza*deg2rad
    azim = phi*deg2rad

    tan_sun, tan_view = np.tan(sun), np.tan(view)
    cos_sun, cos_view = np.cos(sun), np.cos(view)
    cos_azim = np.cos(azim)

    # distance term between the sun and view directions
    delta = np.sqrt(tan_sun**2 + tan_view**2 - 2.0*tan_sun*tan_view*cos_azim)
    sec_sun = 1/cos_sun
    sec_view = 1/cos_view
    sec_sum = sec_sun + sec_view

    # overlap parameter, kept inside the domain of arccos
    overlap = np.sqrt(delta**2 + (tan_sun*tan_view*np.sin(azim))**2)
    cos_t = np.clip(2*overlap/sec_sum, 0, 1)
    t = np.arccos(cos_t)
    sin_t = np.sqrt(1 - cos_t**2)

    # cosine of the phase angle
    cos_ksi = cos_sun*cos_view + np.sin(sun)*np.sin(view)*cos_azim

    overlap_term = (t - sin_t*cos_t)*sec_sum/np.pi - sec_sum
    return np.array(overlap_term + (1 + cos_ksi)*sec_view*sec_sun/2.)

# END brdf_f1
# ===================================================================


# ===================================================================
# F2 function of the Roujean + HS model
#
# Inputs :
#  - sun zenith angle (deg) (scalar or vector)
#  - view zenith angle (deg) (scalar or vector)
#  - relative azimuth angle (deg) (scalar or vector)
#
# Outputs:
#   - f2 (scalar or vector)
# ===================================================================
def brdf_land_f2(sza,vza,phi):
    '''F2 function of the Roujean + hot spot (HS) model.

    Inputs (scalars or arrays that broadcast together)
      - sza: sun zenith angle (deg)
      - vza: view zenith angle (deg)
      - phi: relative azimuth angle (deg)

    Outputs
      - f2 as a numpy array (0-d for scalar inputs)
    '''
    dr = np.pi/180

    cos_sza = np.cos(sza*dr)
    cos_vza = np.cos(vza*dr)

    # Cosine of the phase angle. Rounding can push it slightly above 1 in the
    # hot spot direction (vza == sza, phi == 0), which would make the square
    # root and the arccos below return NaN, hence the clipping.
    cos_ksi = cos_sza*cos_vza + np.sin(sza*dr)*np.sin(vza*dr)*np.cos(phi*dr)
    cos_ksi = np.clip(cos_ksi, -1., 1.)
    sin_ksi = np.sqrt(1.-cos_ksi**2)
    ksi = np.arccos(cos_ksi)

    # volumetric term
    volumetric = 2./(3.*np.pi) * (((np.pi - 2.0*ksi)*cos_ksi + 2.0*sin_ksi) / (cos_sza + cos_vza))
    # hot spot term, with a characteristic angle of 1.5 deg
    hot_spot = 1 + 1/(1 + ksi/(1.5*np.pi/180))

    ans = volumetric*hot_spot - 1./3.
    return np.array(ans)

# END brdf_f2
# ===================================================================



# ===================================================================
# Determine vza depending on sza for display
#
def define_vza_hotspot(sza, vza):
    '''Refine the view zenith angle sampling around the hot spot for display.

     Twenty extra angles, sza-10, sza-9, ..., sza+9, are merged with the
     input view zenith angles and the result is sorted in increasing order.

     Inputs
      - sza  scalar
      - vza  array of size [nang]

     Outputs
      - vza  sorted array of size [nang+20]
    '''
    # unit steps from -10 (included) to 10 (excluded) around the sun direction
    offsets = np.arange(-10, 10, 1)
    merged = np.concatenate((vza, sza + offsets))
    merged.sort()
    return merged
# END define_vza_hotspot
# ===================================================================



# ===================================================================
# Define vza,phi domain for displaying polar plots
#
def define_angles_polar_plot( vza, step_polar ):
    ''' Define the (vza, phi) domain used to display polar plots.

     The view zenith angle is sampled with step_polar points between 0 and
     the largest input vza, the azimuth with step_polar points between 0 and
     360. Every combination is returned, with vza varying fastest.

     Inputs
      - vza         scalar or array (only its maximum is used)
      - step_polar  number of samples along each axis

     Outputs
      - vza         flattened array [step_polar*step_polar]
      - phi         flattened array [step_polar*step_polar]
    '''
    zenith_axis = np.linspace(0, np.amax(vza), step_polar)
    azimuth_axis = np.linspace(0, 360, step_polar)

    # same ordering as the row-major flattening of np.meshgrid(zenith, azimuth)
    zenith_flat = np.tile(zenith_axis, azimuth_axis.size)
    azimuth_flat = np.repeat(azimuth_axis, zenith_axis.size)

    return zenith_flat, azimuth_flat
# END define_angles_polar_plot
# ===================================================================



# ===================================================================
# Computation of the specular reflectance over ocean / required
# for the determination of the BRDF as a function of the wind speed

def ocean_specular_reflectance( ti, n):
    '''
    Computation of the specular reflectance over ocean / required
    for the determination of the BRDF as a function of the wind speed

    Specular reflectance is determined with the Fresnel equations,
    accounting only for the real part of the water refractive index:
    the reflectances of the two polarisations are averaged.

    Inputs
        - ti: incidence angle (deg):   scalar or vector [nang]
        - n: water refractive index:   scalar or vector [nlmbd]

    Outputs
        - reflectance, float array [nlmbd, nang]

    '''
    dr = np.pi/180.

    # scalars become one-element vectors
    n = np.atleast_1d(np.asarray(n, dtype=float)).ravel()
    ti = np.atleast_1d(np.asarray(ti, dtype=float)).ravel()

    # wavelengths along the rows, angles along the columns
    n_col = n[:, np.newaxis]
    sin_ti = np.sin(ti*dr)[np.newaxis, :]
    cti = np.cos(ti*dr)[np.newaxis, :]

    # refraction angle (Snell's law)
    tv = np.arcsin(sin_ti/n_col)
    ctv = np.cos(tv)

    # Fresnel amplitude coefficients for the two polarisations
    Rs = (cti - n_col*ctv)/(cti + n_col*ctv)
    Rp = (ctv - n_col*cti)/(ctv + n_col*cti)

    rho_sp = (Rs*Rs + Rp*Rp)*0.5

    return rho_sp
# END ocean_specular_reflectance
# ===================================================================
