# Copyright (c) 2003-2014 by Mike Jarvis
#
# TreeCorr is free software: redistribution and use in source and binary forms,
# with or without modification, are permitted provided that the following
# conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions, and the disclaimer given in the accompanying LICENSE
#    file.
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions, and the disclaimer given in the documentation
#    and/or other materials provided with the distribution.


import numpy
import treecorr
import os
from numpy import pi

from test_helper import get_aardvark

def test_ascii():
    """Check that Catalog correctly reads delimited ASCII catalog files.

    Writes a random catalog twice: to data/test.dat (whitespace-delimited,
    '#' comments) and to data/test.csv (comma-delimited, '%' comments, with
    several header comment lines and extra comment lines interspersed).  It
    then reads them back with various column, unit and flag settings and
    compares against the values that were written.
    """

    nobj = 5000
    numpy.random.seed(8675309)
    x = numpy.random.random_sample(nobj)
    y = numpy.random.random_sample(nobj)
    ra = numpy.random.random_sample(nobj)
    dec = numpy.random.random_sample(nobj)
    w = numpy.random.random_sample(nobj)
    g1 = numpy.random.random_sample(nobj)
    g2 = numpy.random.random_sample(nobj)
    k = numpy.random.random_sample(nobj)

    # Set each flag bit independently on roughly 10% of the objects.
    flags = numpy.zeros(nobj).astype(int)
    for flag in [ 1, 2, 4, 8, 16 ]:
        sub = numpy.random.random_sample(nobj) < 0.1
        flags[sub] = numpy.bitwise_or(flags[sub], flag)

    # The output files live in a 'data' subdirectory; make sure it exists
    # (os.makedirs has no exist_ok on Python 2, so check first).
    if not os.path.isdir('data'):
        os.makedirs('data')

    file_name = os.path.join('data','test.dat')
    with open(file_name, 'w') as fid:
        # These are intentionally in a different order from the order we parse them.
        fid.write('# ra,dec,x,y,k,g1,g2,w,flag\n')
        for i in range(nobj):
            fid.write((('%.8f '*8)+'%d\n')%(ra[i],dec[i],x[i],y[i],k[i],g1[i],g2[i],w[i],flags[i]))

    # Check basic input.  Column numbers count from 1 (x is the 3rd column above).
    config = {
        'x_col' : 3,
        'y_col' : 4,
        'x_units' : 'rad',
        'y_units' : 'rad',
        'w_col' : 8,
        'g1_col' : 6,
        'g2_col' : 7,
        'k_col' : 5,
    }
    cat1 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat1.x, x)
    numpy.testing.assert_almost_equal(cat1.y, y)
    numpy.testing.assert_almost_equal(cat1.w, w)
    numpy.testing.assert_almost_equal(cat1.g1, g1)
    numpy.testing.assert_almost_equal(cat1.g2, g2)
    numpy.testing.assert_almost_equal(cat1.k, k)

    # Check flags: any nonzero flag is expected to zero the weight.
    config['flag_col'] = 9
    cat2 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat2.w[flags==0], w[flags==0])
    numpy.testing.assert_almost_equal(cat2.w[flags!=0], 0.)

    # Check ok_flag: with ok_flag=4, only objects whose flag is exactly 0 or 4
    # are expected to keep their weight.
    config['ok_flag'] = 4
    cat3 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat3.w[numpy.logical_or(flags==0, flags==4)], 
                                      w[numpy.logical_or(flags==0, flags==4)])
    numpy.testing.assert_almost_equal(cat3.w[numpy.logical_and(flags!=0, flags!=4)], 0.)

    # Check ignore_flag: with ignore_flag=16, only objects with bit 16 set
    # (i.e. flags >= 16, since the largest bit used is 16) lose their weight.
    del config['ok_flag']
    config['ignore_flag'] = 16
    cat4 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat4.w[flags < 16], w[flags < 16])
    numpy.testing.assert_almost_equal(cat4.w[flags >= 16], 0.)

    # Check different units for x,y
    config['x_units'] = 'arcsec'
    config['y_units'] = 'arcsec'
    cat5 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat5.x, x * (pi/180./3600.))
    numpy.testing.assert_almost_equal(cat5.y, y * (pi/180./3600.))

    config['x_units'] = 'arcmin'
    config['y_units'] = 'arcmin'
    cat5 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat5.x, x * (pi/180./60.))
    numpy.testing.assert_almost_equal(cat5.y, y * (pi/180./60.))

    config['x_units'] = 'deg'
    config['y_units'] = 'deg'
    cat5 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat5.x, x * (pi/180.))
    numpy.testing.assert_almost_equal(cat5.y, y * (pi/180.))

    del config['x_units']  # Default is radians
    del config['y_units']
    cat5 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat5.x, x)
    numpy.testing.assert_almost_equal(cat5.y, y)

    # Check ra,dec
    del config['x_col']
    del config['y_col']
    config['ra_col'] = 1
    config['dec_col'] = 2
    config['ra_units'] = 'rad'
    config['dec_units'] = 'rad'
    cat6 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat6.ra, ra)
    numpy.testing.assert_almost_equal(cat6.dec, dec)

    config['ra_units'] = 'deg'
    config['dec_units'] = 'deg'
    cat6 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat6.ra, ra * (pi/180.))
    numpy.testing.assert_almost_equal(cat6.dec, dec * (pi/180.))

    config['ra_units'] = 'hour'
    config['dec_units'] = 'deg'
    cat6 = treecorr.Catalog(file_name, config)
    numpy.testing.assert_almost_equal(cat6.ra, ra * (pi/12.))
    numpy.testing.assert_almost_equal(cat6.dec, dec * (pi/180.))

    # Check using a different delimiter, comment marker
    csv_file_name = os.path.join('data','test.csv')
    with open(csv_file_name, 'w') as fid:
        # These are intentionally in a different order from the order we parse them.
        # Each header comment must end in '\n' so that the file really does
        # start with several separate comment lines.
        fid.write('% This file uses commas for its delimiter\n')
        fid.write('% And more than one header line.\n')
        fid.write('% Plus some extra comment lines every so often.\n')
        fid.write('% And we use a weird comment marker to boot.\n')
        fid.write('% ra,dec,x,y,k,g1,g2,w,flag\n')
        for i in range(nobj):
            fid.write((('%.8f,'*8)+'%d\n')%(ra[i],dec[i],x[i],y[i],k[i],g1[i],g2[i],w[i],flags[i]))
            if i%100 == 0:
                # '%%%%' formats to '%%', i.e. a comment line between data rows.
                fid.write('%%%% Line %d\n'%i)
    config['delimiter'] = ','
    config['comment_marker'] = '%'
    cat7 = treecorr.Catalog(csv_file_name, config)
    numpy.testing.assert_almost_equal(cat7.ra, ra * (pi/12.))
    numpy.testing.assert_almost_equal(cat7.dec, dec * (pi/180.))
    numpy.testing.assert_almost_equal(cat7.g1, g1)
    numpy.testing.assert_almost_equal(cat7.g2, g2)
    numpy.testing.assert_almost_equal(cat7.w[flags < 16], w[flags < 16])
    numpy.testing.assert_almost_equal(cat7.w[flags >= 16], 0.)

 
def test_fits():
    """Read the Aardvark FITS catalog and spot-check specific rows.

    First reads the file with the column setup from Aardvark.params.  It then
    re-reads the file treating RA/DEC as x/y, MU as a weight and INDEX as a
    flag column.  Finally it restricts the read to a sub-range of rows.
    """
    get_aardvark()

    fits_file = os.path.join('data', 'Aardvark.fit')
    cfg = treecorr.read_config('Aardvark.params')
    to_rad = pi / 180.
    ntot = 390935    # total number of rows in the file
    probe = 46392    # a row whose g1, g2, k values are known

    # Default read: check the size and a few particular values.
    full = treecorr.Catalog(fits_file, cfg)
    numpy.testing.assert_equal(len(full.ra), ntot)
    numpy.testing.assert_equal(full.nobj, ntot)
    numpy.testing.assert_almost_equal(full.ra[0], 56.4195 * to_rad)
    numpy.testing.assert_almost_equal(full.ra[ntot - 1], 78.4782 * to_rad)
    numpy.testing.assert_almost_equal(full.dec[290333], 83.1579 * to_rad)
    numpy.testing.assert_almost_equal(full.g1[probe], 0.0005066675)
    numpy.testing.assert_almost_equal(full.g2[probe], -0.0001006742)
    numpy.testing.assert_almost_equal(full.k[probe], -0.0008628797)

    # The file has no x, y, or w columns, so reuse existing columns to
    # exercise those code paths as well.
    for key in ('ra_col', 'dec_col'):
        del cfg[key]
    cfg['x_col'] = 'RA'
    cfg['y_col'] = 'DEC'
    cfg['w_col'] = 'MU'
    cfg['flag_col'] = 'INDEX'
    cfg['ignore_flag'] = 64
    flat = treecorr.Catalog(fits_file, cfg)
    numpy.testing.assert_almost_equal(flat.x[ntot - 1], 78.4782, decimal=4)
    numpy.testing.assert_almost_equal(flat.y[290333], 83.1579, decimal=4)
    # INDEX = 1200379 has bit 64 set, so its weight is expected to be zeroed;
    # INDEX = 1200386 does not, so it keeps its MU value.
    numpy.testing.assert_almost_equal(flat.w[probe], 0.)
    numpy.testing.assert_almost_equal(flat.w[probe + 1], 0.9995946)

    # Restrict the read to rows first..last.  The expected length shows both
    # ends are inclusive, so row i here is row i + (first - 1) of the full read.
    first, last = 101, 50000
    cfg['first_row'] = first
    cfg['last_row'] = last
    part = treecorr.Catalog(fits_file, cfg)
    numpy.testing.assert_equal(len(part.x), last - first + 1)
    numpy.testing.assert_equal(part.nobj, sum(part.w != 0))
    shifted = probe - (first - 1)
    numpy.testing.assert_almost_equal(part.g1[shifted], 0.0005066675)
    numpy.testing.assert_almost_equal(part.g2[shifted], -0.0001006742)
    numpy.testing.assert_almost_equal(part.k[shifted], -0.0008628797)


def test_direct():
    """Build Catalogs directly from numpy arrays and check the stored values.

    Covers flat (x, y) positions and spherical (ra, dec) positions.  The
    ra/dec values are given in hours/degrees and are expected to come back
    scaled by treecorr.hours and treecorr.degrees respectively.
    """
    nobj = 5000
    numpy.random.seed(8675309)
    # Draw the columns in a fixed order so each gets the same random values
    # every run.
    x, y, ra, dec, w, g1, g2, k = [numpy.random.random_sample(nobj) for _ in range(8)]

    flat = treecorr.Catalog(x=x, y=y, w=w, g1=g1, g2=g2, k=k)
    flat_expected = [('x', x), ('y', y), ('w', w), ('g1', g1), ('g2', g2), ('k', k)]
    for attr, expected in flat_expected:
        numpy.testing.assert_almost_equal(getattr(flat, attr), expected)

    sphere = treecorr.Catalog(ra=ra, dec=dec, w=w, g1=g1, g2=g2, k=k,
                              ra_units='hours', dec_units='degrees')
    sphere_expected = [('ra', ra * treecorr.hours), ('dec', dec * treecorr.degrees),
                       ('w', w), ('g1', g1), ('g2', g2), ('k', k)]
    for attr, expected in sphere_expected:
        numpy.testing.assert_almost_equal(getattr(sphere, attr), expected)

def test_contiguous():
    """Check that Catalog accepts non-contiguous, non-float64 input arrays.

    This unit test comes from Melanie Simet, who discovered a bug in earlier
    versions of the code: the Catalog didn't correctly handle input arrays
    that were not contiguous in memory.  Each field of a structured array is
    a strided (non-contiguous) view, so it exercises exactly that case.  The
    test also checks that the input dtype doesn't have to be float.  The NG
    correlation computed from the mixed-dtype fields must exactly match the
    one computed from float copies of the same data.
    """

    # ('ra', None) leaves the ra field at numpy's default float type.
    # numpy.longdouble is used instead of numpy.float128: the latter only exists
    # where long double is 128 bits (it is missing e.g. on Windows), while
    # longdouble is the portable name for the platform's extended float type.
    source_data = numpy.array([
            (0.0380569697547, 0.0142782758818, 0.330845443464, -0.111049332655),
            (-0.0261291090735, 0.0863787933931, 0.122954685209, 0.40260430406),
            (0.125086697534, 0.0283621046495, -0.208159531309, 0.142491564101),
            (0.0457709426026, -0.0299249486373, -0.0406555089425, 0.24515956887),
            (-0.00338578248926, 0.0460291122935, 0.363057738173, -0.524536297555)],
            dtype=[('ra', None), ('dec', numpy.float64), ('g1', numpy.float32),
                   ('g2', numpy.longdouble)])

    config = {'min_sep': 0.05, 'max_sep': 0.2, 'sep_units': 'degrees', 'nbins': 5 }

    cat1 = treecorr.Catalog(ra=[0], dec=[0], ra_units='deg', dec_units='deg') # dumb lens
    cat2 = treecorr.Catalog(ra=source_data['ra'], ra_units='deg',
                            dec=source_data['dec'], dec_units='deg',
                            g1=source_data['g1'],
                            g2=source_data['g2'])
    cat2_float = treecorr.Catalog(ra=source_data['ra'].astype(float), ra_units='deg',
                                  dec=source_data['dec'].astype(float), dec_units='deg',
                                  g1=source_data['g1'].astype(float), 
                                  g2=source_data['g2'].astype(float))

    # A single-argument print(...) is valid both as a Python 2 statement and
    # as a Python 3 function call, so the module parses under either version.
    keys = ['ra', 'dec', 'g1', 'g2']
    print("dtypes of original arrays: %s" % ([source_data[key].dtype for key in keys],))
    print("dtypes of cat2 arrays: %s" % ([getattr(cat2, key).dtype for key in keys],))
    print("is original g2 array contiguous? %s" % (source_data['g2'].flags['C_CONTIGUOUS'],))
    print("is cat2.g2 array contiguous? %s" % (cat2.g2.flags['C_CONTIGUOUS'],))
    # The input really is non-contiguous, and Catalog must hand back a
    # contiguous copy.
    assert not source_data['g2'].flags['C_CONTIGUOUS']
    assert cat2.g2.flags['C_CONTIGUOUS']

    ng = treecorr.NGCorrelation(config)
    ng.process(cat1,cat2)
    ng_float = treecorr.NGCorrelation(config)
    ng_float.process(cat1,cat2_float)

    numpy.testing.assert_equal(ng.xi, ng_float.xi)

if __name__ == '__main__':
    # When run as a script, call each test in file order.
    for test_func in (test_ascii, test_fits, test_direct, test_contiguous):
        test_func()
