# -*- coding: utf-8 -*-

###############################################################
# Tools for Reading Inputs/Writing Outputs
#
# Authors:
#  NOVELTIS: Cedric Bacour / Ivan Price
###############################################################
#from __init__ import *


import netCDF4
import numpy as np

import adam




def get_netcdf_data(extent, month_index, cfg):
    ''' Fetch data from the global netCDF file for a spatial window and a month,
        add the land error covariance for that month and the surface masks.

    Inputs:
        - extent: dictionary with 'minx', 'maxx', 'miny', 'maxy' keys, values in degrees
        - month_index: month index from 0->11
        - cfg: configuration object. Attributes read here: netcdf_datafile,
          netcdf_file_covarland_template, vars_land, vars_ocean, missval_2b
          (cfg is also passed on to calculate_surface_masks)

    Output:
        - data dict with the keys:
            'latitude', 'longitude'  : coordinate values of the selected window
            'month_index'            : the month_index given as input
            one key per variable name in cfg.vars_land + cfg.vars_ocean
            'ref_land_covar'         : land error covariance of the window
            'mask_land', 'mask_ocean', 'mask_snow',
            'idx_land', 'idx_ocean', 'idx_snow' : see calculate_surface_masks

    Raises:
        - Exception if the main data file or the covariance file cannot be opened

        '''

    # -- Main File --
    # open the source NetCDF read-only
    try:
        source_dataset = netCDF4.Dataset(cfg.netcdf_datafile, 'r')
    except Exception:
        raise Exception('Error opening source data file, location given was: %s' % cfg.netcdf_datafile)

    return_dataset = {}

    # the try/finally guarantees the file is closed even if reading fails
    try:
        # read the full coordinate axes once, they are needed to locate the window
        longitudes = source_dataset.variables['longitude'][:]
        latitudes = source_dataset.variables['latitude'][:]

        # determine the index extremities corresponding to the spatial extent requested
        if (extent['maxx'] - extent['minx'] > 359) and (extent['maxy'] - extent['miny'] > 179):
            # we have a select all download !
            ind_min_x = 0
            ind_min_y = 0
            ind_max_x = len(longitudes)
            ind_max_y = len(latitudes)
        else:
            # it's a window download so we need to determine the corresponding
            # indexes in the lat/long variables
            ind_min_x = get_pos_index_from_coord(extent['minx'], longitudes)
            ind_min_y = get_pos_index_from_coord(extent['miny'], latitudes)
            ind_max_x = get_pos_index_from_coord(extent['maxx'], longitudes)
            ind_max_y = get_pos_index_from_coord(extent['maxy'], latitudes)

        # double check the min and max are actually the min and max and not inversed
        # (get_pos_index_from_coord counts latitude indexes from +90 downwards, so
        # 'miny' normally yields the larger index)
        if ind_min_x > ind_max_x:
            ind_min_x, ind_max_x = ind_max_x, ind_min_x
        if ind_min_y > ind_max_y:
            ind_min_y, ind_max_y = ind_max_y, ind_min_y

        # populate the variables according to the required range
        return_dataset['latitude'] = source_dataset.variables['latitude'][ind_min_y:ind_max_y]
        return_dataset['longitude'] = source_dataset.variables['longitude'][ind_min_x:ind_max_x]

        return_dataset['month_index'] = month_index

        # NOTE(review): every variable is indexed below as
        # [x range, y range, (band,) month] -- confirm this matches the dimension
        # order actually stored in the file.
        for varname in cfg.vars_land + cfg.vars_ocean:
            variable = source_dataset.variables[varname]

            if len(variable.shape) == 3:
                # variables with 3 dimensions (chloro_conc and wind speed)
                data = variable[ind_min_x:ind_max_x, ind_min_y:ind_max_y, month_index]

                # NOTE(review): the previous version also computed
                # np.ma.masked_equal(data, cfg.ocean.missval) here but never used
                # the result. If values equal to cfg.ocean.missval were meant to
                # be masked before the fill below, that has to be reinstated.

                # unpack chlorophyll content (second entry of cfg.vars_ocean),
                # which is raised as a power of 10
                if varname == cfg.vars_ocean[1]:
                    data = 10**(data)

                # masked elements become 0
                return_dataset[varname] = np.ma.filled(data, 0)

            else:
                # a variable with 4 dimensions (ref_land)
                data = variable[ind_min_x:ind_max_x, ind_min_y:ind_max_y, :, month_index]

                # NOTE(review): as above, a discarded
                # np.ma.masked_equal(data, cfg.land.missval) used to be computed here.

                # masked elements become cfg.missval_2b
                return_dataset[varname] = np.ma.filled(data, cfg.missval_2b)
    finally:
        # close off the source NetCDF
        source_dataset.close()

    # -- replace the cfg.missval_2b elements in ref_land by nan
    idx_missing = np.ma.where(return_dataset['ref_land'] == cfg.missval_2b)
    return_dataset['ref_land'][idx_missing] = np.nan

    # -- Error covariance matrix over land --
    # 28 values per pixel (presumably the upper triangle of a 7x7 band
    # covariance matrix -- TODO confirm)
    return_dataset['ref_land_covar'] = np.empty((ind_max_x-ind_min_x, ind_max_y-ind_min_y, 28))

    # one covariance file per month, numbered from 1
    error_data_filename = cfg.netcdf_file_covarland_template % (month_index+1)
    try:
        covar_dataset = netCDF4.Dataset(error_data_filename, 'r')
    except Exception:
        raise Exception('Error opening error (covariance) data file, location given was: %s' % error_data_filename)

    # read data
    try:
        data = covar_dataset.variables['ref_land_covar'][ind_min_x:ind_max_x, ind_min_y:ind_max_y, :]
        return_dataset['ref_land_covar'][:, :, :] = data
    finally:
        covar_dataset.close()

    # automatically calculate the land/water/snow masks
    mask_land, mask_ocean, mask_snow, idx_land, idx_ocean, idx_snow = calculate_surface_masks(
        return_dataset['ref_land'],
        return_dataset['chloro_conc'],
        return_dataset['wind_speed'],
        cfg
    )

    return_dataset['mask_land']  = mask_land
    return_dataset['mask_ocean'] = mask_ocean
    return_dataset['mask_snow']  = mask_snow
    return_dataset['idx_land']   = idx_land
    return_dataset['idx_ocean']  = idx_ocean
    return_dataset['idx_snow']   = idx_snow

    return return_dataset

# END get_netcdf_data
# ===================================================================



# =============================================================================
# Get index from coordinate
#
# Input
#  - coord: coordinate value (latitude or longitude, in degrees)
#  - dimension_variable: array of the latitude or longitude values of the grid
# =============================================================================
def get_pos_index_from_coord(coord, dimension_variable):
    ''' returns the index along a latitude or longitude axis that corresponds
     to the given coordinate; used to determine where (in terms of index) the
     required lat / longs are in a 2d grid

    Inputs:
        - coord: value of latitude or longitude, in degrees, e.g. 100
        - dimension_variable: array holding the coordinate values of the entire
          axis (it must provide len(), .min() and .max())

    Output:
        - integer index between 0 and len(dimension_variable), both included.
          The index refers to a cell edge, so it is meant to be used as a slice
          bound: a coordinate on the far edge of the grid gives
          len(dimension_variable).

     '''

    len_dim = len(dimension_variable)

    # dimensions coords are cell middles not edges so we can't use them to defined extents directly.
    # The index is therefore derived from the number of cells per degree, assuming the
    # axis spans the whole globe.
    if (dimension_variable.max() - dimension_variable.min() < 181):
        # a span of less than 181 degrees is taken to be a latitude axis,
        # whose indexes are counted from +90 downwards
        pixels_per_degree = len_dim / 180.
        index = (90 - coord) * pixels_per_degree
    else:
        # otherwise a longitude axis, whose indexes are counted from -180 upwards
        pixels_per_degree = len_dim / 360.
        index = (coord + 180) * pixels_per_degree

    # keep the index inside the grid: a coordinate beyond the edge would otherwise
    # give a negative index, which a slice interprets as an offset from the end
    return min(max(int(round(index)), 0), len_dim)



def calculate_surface_masks(ref_land, chloro_conc, wind_speed, cfg):
    ''' function returning the masks and therefore indexes of the data for
    land, ocean and snow

    Inputs:
        - ref_land: 3d array (2 spatial dimensions, then band) of land
          reflectances, NaN where there is no land data
        - chloro_conc: 2d array over the same spatial dimensions, 0 where there is no data
        - wind_speed: 2d array over the same spatial dimensions, 0 where there is no data
        - cfg: configuration object. Attributes read here: land.chanB,
          land.chanPIR2, land.cond_snow_1, land.cond_snow_2

    Output (tuple), all indexes referring to the flattened spatial grid:
        - mask_land: boolean array, True where the pixel is NOT land
          (the numpy constant False when every pixel is land)
        - mask_ocean: boolean array, True where the pixel is NOT ocean
        - mask_snow: boolean array, True where the pixel is snow; it has one
          element per land pixel found before the snow pixels were removed
          (the numpy constant True when there is no land pixel at all)
        - idx_land: indexes of the land pixels, snow pixels excluded
        - idx_ocean: indexes of the ocean pixels
        - idx_snow: indexes of the snow pixels
    '''

    # flatten to 2d (pixel, band) and 1d (pixel) arrays
    ref_land    = ref_land.reshape(ref_land.shape[0]*ref_land.shape[1], ref_land.shape[2])
    chloro_conc = chloro_conc.reshape(chloro_conc.shape[0]*chloro_conc.shape[1])
    wind_speed  = wind_speed.reshape(wind_speed.shape[0]*wind_speed.shape[1])

    ###############
    # - land mask
    # A pixel is declared as Land if the corresponding ref_land != NaN
    # A pixel is declared as Ocean if wind_speed and chloro != 0 AND the pixel is not land

    # mask out (the mask will be True) all pixels whose first band is NaN/inf
    mask_land_missing = np.ma.masked_invalid(ref_land[:, 0])

    # some pixels in the land surface reflectance database are a value = 0
    # for each waveband => they can not be treated as land pixel.
    # The sum runs over the bands (axis=1) so that there is one value per pixel;
    # summing over axis=0 would give one value per band, which cannot be
    # combined with the per-pixel mask above.
    mask_land_zero = np.ma.masked_equal(np.sum(ref_land, axis=1), 0)

    # combine the two masks, so where the mask is True there is no (/invalid?) land
    # here a true will always win, so if either mask was true then the pixel will be
    # masked, and therefore designated NOT LAND
    mask_land = np.ma.mask_or(mask_land_missing.mask, mask_land_zero.mask)

    # when no pixel at all is masked, numpy collapses the mask to a single False
    # constant instead of an array, in which case every pixel is land
    if isinstance(mask_land, np.bool_):
        idx_land = np.arange(len(ref_land), dtype=int)
    else:
        idx_land = np.ma.nonzero(mask_land == False)[0]

    ###############
    # now we build the ocean mask, which is not simply the inverse of the land mask.
    mask_ocean_chl = np.ma.masked_equal(chloro_conc, 0)
    mask_ocean_ws  = np.ma.masked_equal(wind_speed, 0)

    # combine the wind and chlorophyl masks to generate the first ocean mask..
    # True here indicates that chlorophyl or wind = 0 and so the pixel is considered
    # NOT ocean. again, True here wins when there is a True and a False
    mask_ocean = np.ma.mask_or(mask_ocean_chl.mask, mask_ocean_ws.mask)

    # combine the ocean and land masks: a land pixel (~mask_land is True) can never
    # be ocean, so it is masked in the ocean mask as well
    mask_ocean = np.ma.mask_or(mask_ocean, ~mask_land, shrink=False)

    # atleast_1d: the mask is a 0-d value when both inputs above were collapsed
    # constants, and recent numpy versions refuse nonzero() on 0-d input
    idx_ocean = np.ma.nonzero(np.atleast_1d(mask_ocean == False))[0]

    ###############
    # detection of snow covered pixels
    # start with declaring no snow pixels
    idx_snow = np.array([], dtype=int)
    mask_snow = np.True_

    if len(idx_land) > 0:
        # a land pixel is flagged as snow when both conditions hold:
        #  1. its reflectance in band cfg.land.chanB is above cfg.land.cond_snow_1
        #  2. the difference between bands cfg.land.chanPIR2 and cfg.land.chanB
        #     is below cfg.land.cond_snow_2
        land_chanB = ref_land[idx_land, cfg.land.chanB]
        mask_snow_1 = np.ma.masked_greater(land_chanB, cfg.land.cond_snow_1)
        mask_snow_2 = np.ma.masked_less(ref_land[idx_land, cfg.land.chanPIR2] - land_chanB,
                                        cfg.land.cond_snow_2)

        # product of two boolean masks = logical and
        mask_snow = mask_snow_1.mask * mask_snow_2.mask

        # positions of the snow pixels within idx_land (atleast_1d for the same
        # reason as for the ocean mask)
        snow_positions = np.ma.nonzero(np.atleast_1d(mask_snow))[0]

        if len(snow_positions) > 0:
            idx_snow = idx_land[snow_positions]
            # if there were snow pixels we need to recalculate the land pixels
            # as pixels cannot be snow and land at the same time
            idx_land = np.delete(idx_land, snow_positions)

    return mask_land, mask_ocean, mask_snow, idx_land, idx_ocean, idx_snow




