###############################################################
# Tools used to perform data analysis on the data (tiles or pixels)
#
# Authors:
#  NOVELTIS: Cedric Bacour / Ivan Price
###############################################################

#from __init__ import *

import numpy as np

import adam





def get_stats(ref, data):
    ''' Returns land / ocean / snow statistics for a reflectance array
    (typically the output of process_brdf.compute()).

      - Computes the mean, std, min, max (see calculate_stats)
      - Separate statistics for land, ocean and snow pixels

    The first two (spatial) axes of `ref` are merged into a single pixel
    axis so that the pixel indices stored in `data` select whole pixels;
    any trailing axes (e.g. wavebands, viewing angles) are kept.

    Inputs:
      - ref  : reflectance array whose first two axes are spatial, e.g.
               [nlon, nlat], [nlon, nlat, nlmbd] or [nlon, nlat, nlmbd, nang]
               (arrays with at most 2 axes are simply flattened)
      - data : dictionary holding the index arrays 'idx_land', 'idx_ocean'
               and 'idx_snow' into the merged pixel axis
               (NOTE(review): assumed to be built with the same lon/lat
               flattening order -- confirm with the caller)

    Outputs: dictionary stats containing
      - 'land', 'ocean' and 'snow' elements
      - an element is None when its index array is empty; otherwise it is
        the dictionary returned by calculate_stats, with the keys
          - 'mean'
          - 'min'
          - 'max'
          - 'std'
          - 'npts' : number of selected pixels
    '''

    shape = ref.shape
    if len(shape) > 2:
        # Merge lon/lat into one pixel axis and keep the spectral/angular
        # axes intact. Previously only 5-D input was reshaped this way, so
        # the documented 3-D/4-D inputs were fully flattened and the pixel
        # indices ended up mixing wavebands/angles together.
        ref = ref.reshape((shape[0] * shape[1],) + tuple(shape[2:]))
    else:
        # For [nlon, nlat] (or 1-D) input, flattening already yields the
        # pixel axis.
        ref = ref.ravel()

    # -
    # - Separate land, ocean and snow in the analysis
    # -
    stats = {}
    for surface, idx_key in (('land', 'idx_land'),
                             ('ocean', 'idx_ocean'),
                             ('snow', 'idx_snow')):
        idx = data[idx_key]
        if len(idx) == 0:
            # No pixel of this type: nothing to compute.
            stats[surface] = None
            continue
        stats[surface] = calculate_stats(ref[idx])
        stats[surface]['npts'] = len(idx)

    return stats






def calculate_stats(data):
    ''' Reduce `data` along its first axis into mean, std, min and max.

    The first axis (e.g. the pixel axis) is the one being summarised; all
    remaining axes (e.g. wavebands, viewing angles) are preserved and then
    flattened into a single axis.

    Inputs :
      - data : array [ndim0], [ndim0, ndim1] or [ndim0, ndim1, ndim2, ...]

    Outputs: dictionary with the following keys (in this order)
      - 'mean' : 1-D array of length prod(shape[1:]) (length 1 for 1-D input)
      - 'std'  : same layout as 'mean' (numpy default, i.e. ddof=0)
      - 'min'  : same layout as 'mean' for >= 2-D input,
                 a numpy scalar for 1-D input
      - 'max'  : same layout as 'min'
    '''

    ndim = len(data.shape)

    result = {
        'mean': np.mean(data, axis=0).ravel(),
        'std': np.std(data, axis=0).ravel(),
    }

    if ndim == 1:
        # 1-D input: reduce to plain scalars rather than 1-element arrays.
        flat = data.ravel()
        result['min'] = np.minimum.reduce(flat)
        result['max'] = np.maximum.reduce(flat)
    elif ndim > 1:
        result['min'] = np.amin(data, axis=0).ravel()
        result['max'] = np.amax(data, axis=0).ravel()

    return result
