#-----------------------------------------------------------------------------
# Copyright (c) 2008  Raymond L. Buvel
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#-----------------------------------------------------------------------------

'''Module for interpolating JPL ephemeris tables.

To use this module, create an instance of the class Ephemeris and call the
appropriate methods to extract the required information.  You may also want to
sub-class Ephemeris if the results are required in a different form.  See the
test program testEphem.py for an example.

This is a Pyrex/Cython extension module.  Ephemeris headers are read through
the companion module ephemUtils, and interpolated results are returned as
NumPy float64 arrays.
'''
__all__ = () # Don't export anything for "from ephemPy import *"

import ephemUtils

#-----------------------------------------------------------------------------
# Interface to Numpy arrays

cdef extern from 'numpy/arrayobject.h':

    # Only the `data` member is declared: this module uses it to reach the raw
    # buffer of the arrays it creates (PyArray_SimpleNew) or reads from disk.
    ctypedef struct PyArrayObject:
        char *data

    cdef enum NPY_TYPES:
        NPY_DOUBLE

    # NOTE(review): declared as int only so Pyrex/Cython treats npy_intp as an
    # integer type; the generated C refers to the header's own npy_intp
    # typedef, so variables declared npy_intp have the correct C width.
    ctypedef int npy_intp

    void import_array()
    # Allocates a new (uninitialized) C-contiguous array of the given shape.
    object PyArray_SimpleNew(int ndims, npy_intp* dims, NPY_TYPES type_num)


# The API requires this function to be called before
# using any Numpy facilities in an extension module.
import_array()

#-----------------------------------------------------------------------------
cdef class Ephemeris:
    '''Interpolate JPL ephemeris tables.

    An ephemeris object is created with the following call.

    Ephemeris(ephemerisNumber)
        ephemerisNumber - numeric designation of the ephemeris to open.

    The following attributes are extracted from the header records of the
    selected ephemeris.

    title - a string containing the three lines in the title section.
    eStartTime - starting Julian day of the ephemeris
    eEndTime - ending Julian day of the ephemeris
    eTimeStep - time interval covered by each record
    constants - class instance containing the constants found in the header
        records.  For example, to get the number of kilometers in an
        Astronomical Unit, use constants.AU.
    dataStruct - array containing the structure parameters for a data
        record.  See the JPL documentation for details.
    numRecords - number of data records in the ephemeris file.

    Additional attributes:

    record - array containing the current data record
    rStartTime - starting Julian day of the current record
    rEndTime - ending Julian day of the current record
    refTime - reference time set by the user with the setRefTime method.  The
        time parameter to the other methods is relative to this value.
    arrayBytes - number of bytes in a data record.

    The following attributes are used to select the target.

    Note: the EARTH target is the Earth-Moon barycenter and the MOON is
    relative to the geocenter.

    MERCURY, VENUS, EARTH, MARS, JUPITER, SATURN, URANUS, NEPTUNE, PLUTO,
    MOON, SUN

    The following are required by the test program but must not be used as
    targets for any of the methods.

    SS_BARY, EM_BARY, NUTATIONS, LIBRATIONS

    The following methods are available.

    position(t, target)
        Interpolate the position vector of the target.

    state(t, target)
        Interpolate the state vector of the target.

    nutations(t)
        Interpolate the two nutation angles and their rates.

    librations(t)
        Interpolate the three libration angles and their rates.

    setRefTime(t)
        Set the reference time to the start of the record containing t.

    getRecord(t)
        Get the record corresponding to the specified time.
    '''

    # The following constants can be used for the target parameter of the
    # methods requiring a target.
    cdef readonly int MERCURY
    cdef readonly int VENUS
    cdef readonly int EARTH        # Earth-Moon Barycenter
    cdef readonly int MARS
    cdef readonly int JUPITER
    cdef readonly int SATURN
    cdef readonly int URANUS
    cdef readonly int NEPTUNE
    cdef readonly int PLUTO
    cdef readonly int MOON         # Relative to geocenter
    cdef readonly int SUN

    # The following are required by the test program but must not be used as
    # targets for any of the methods.
    cdef readonly int SS_BARY
    cdef readonly int EM_BARY
    cdef readonly int NUTATIONS
    cdef readonly int LIBRATIONS

    # Other attributes available to the user.
    cdef readonly object _efile
    cdef readonly int arrayBytes
    cdef readonly object title
    cdef readonly double eStartTime
    cdef readonly double eEndTime
    cdef readonly double eTimeStep
    cdef readonly object dataStruct
    cdef readonly object constants
    cdef readonly int numRecords
    cdef readonly object record
    cdef readonly double rStartTime
    cdef readonly double rEndTime
    cdef readonly double refTime

    # Private variables used to make the calculations more efficient.
    cdef double *_record     # raw buffer of self.record
    cdef long *_dataStruct   # raw buffer of self.dataStruct
    cdef double _Tc          # normalized time set by _getParms
    cdef double _dt          # granule length set by _getParms
    cdef double *_A          # first coefficient selected by _getParms


    def __init__(self, ephemerisNumber):
        # Initialize the constants
        self.MERCURY = 0
        self.VENUS = 1
        self.EARTH = 2
        self.MARS = 3
        self.JUPITER = 4
        self.SATURN = 5
        self.URANUS = 6
        self.NEPTUNE = 7
        self.PLUTO = 8
        self.MOON = 9
        self.SUN = 10
        self.SS_BARY = 11
        self.EM_BARY = 12
        self.NUTATIONS = 13
        self.LIBRATIONS = 14

        # Initialize the reference time so that the time parameter represents
        # the Julian day values in the ephemeris file.
        self.refTime = 0.0

        # Read the header and assign the result to instance attributes.
        hdr = ephemUtils.readHeader(ephemerisNumber)
        self.title = hdr[0]
        self.eStartTime, self.eEndTime, self.eTimeStep = hdr[1]
        self.dataStruct = hdr[2]
        self.constants = hdr[3]
        self.numRecords = hdr[4]
        self._efile = hdr[5]
        self.arrayBytes = hdr[6]

        # Read the first record to set the record times.
        _readRecord(self, 0)

        # Map the Numpy array to a C data structure.
        # NOTE(review): assumes dataStruct holds C longs - confirm in ephemUtils.
        self._dataStruct = <long*>(<PyArrayObject*>self.dataStruct).data


    def position(self, double t, int target):
        '''Interpolate the position vector of the target.

        Returns an array containing the position measured in kilometers.  For
        all targets except the Moon, the position is relative to the Solar
        System barycenter.  For the Moon, the position is relative to the
        geocenter.

        t - time in Julian days (relative to refTime) at which the position
            is desired.
        target - object for which the position is desired [0,...,10].

        Raises ValueError if target or t is out of range.
        '''
        cdef int N, i
        cdef npy_intp dims
        cdef object _pos
        cdef double *pos
        cdef double *A
        cdef double Tc

        if not (0 <= target <= 10):
            raise ValueError('target out of range')
        N = _getParms(self, t, target)
        A = self._A
        Tc = self._Tc
        dims = 3
        _pos = PyArray_SimpleNew(1, &dims, NPY_DOUBLE)
        pos = <double*>(<PyArrayObject*>_pos).data
        # Coefficients for x, y and z are stored consecutively, N apiece.
        for i from 0 <= i < 3:
            pos[i] = chebeval(&A[i*N], N, Tc)
        return _pos


    def state(self, double t, int target):
        '''Interpolate the state vector of the target

        Returns an array containing the state vector of the target.  The
        position is in the first three elements and is measured in kilometers.
        The velocity is in the last three elements and is measured in
        kilometers per Julian day.  For all targets except the Moon, the
        position is relative to the Solar System barycenter.  For the Moon,
        the position is relative to the geocenter.

        t - time in Julian days (relative to refTime) at which the state
            vector is desired.
        target - object for which the state vector is desired [0,...,10].

        Raises ValueError if target or t is out of range.
        '''
        cdef int N, i
        cdef npy_intp dims
        cdef object _PV
        cdef double *PV
        cdef double *A
        cdef double Tc,dt2

        if not (0 <= target <= 10):
            raise ValueError('target out of range')
        N = _getParms(self, t, target)
        A = self._A
        Tc = self._Tc
        # Chain rule: dTc/dt = 2/dt converts d/dTc into a per-day rate.
        dt2 = 2.0/self._dt
        dims = 6
        _PV = PyArray_SimpleNew(1, &dims, NPY_DOUBLE)
        PV = <double*>(<PyArrayObject*>_PV).data
        for i from 0 <= i < 3:
            PV[i] = chebeval(&A[i*N], N, Tc)
            PV[i+3] = chebder(&A[i*N], N, Tc)*dt2
        return _PV


    def nutations(self, double t):
        '''Interpolate nutations

        Returns a 4-element array: the two nutation angles followed by their
        rates of change per Julian day.

        t - time in Julian days (relative to refTime) at which the nutations
            are desired.

        Raises TypeError if the ephemeris has no nutation data and ValueError
        if t is out of range.
        '''
        cdef int N, i
        cdef npy_intp dims
        cdef object _NU
        cdef double *NU
        cdef double *A
        cdef double Tc,dt2

        # Entry 11 of the record structure describes the nutations.
        if self._dataStruct[11*3+1] < 2:
            raise TypeError('Ephemeris does not contain nutations')
        N = _getParms(self, t, 11)
        A = self._A
        Tc = self._Tc
        dt2 = 2.0/self._dt
        dims = 4
        _NU = PyArray_SimpleNew(1, &dims, NPY_DOUBLE)
        NU = <double*>(<PyArrayObject*>_NU).data
        for i from 0 <= i < 2:
            NU[i] = chebeval(&A[i*N], N, Tc)
            NU[i+2] = chebder(&A[i*N], N, Tc)*dt2
        return _NU


    def librations(self, double t):
        '''Interpolate librations

        Returns a 6-element array: the three libration angles followed by
        their rates of change per Julian day.

        t - time in Julian days (relative to refTime) at which the librations
            are desired.

        Raises TypeError if the ephemeris has no libration data and
        ValueError if t is out of range.
        '''
        cdef int N, i
        cdef npy_intp dims
        cdef object _LI
        cdef double *LI
        cdef double *A
        cdef double Tc,dt2

        # Entry 12 of the record structure describes the librations.
        if self._dataStruct[12*3+1] < 2:
            raise TypeError('Ephemeris does not contain librations')
        N = _getParms(self, t, 12)
        A = self._A
        Tc = self._Tc
        dt2 = 2.0/self._dt
        dims = 6
        _LI = PyArray_SimpleNew(1, &dims, NPY_DOUBLE)
        LI = <double*>(<PyArrayObject*>_LI).data
        for i from 0 <= i < 3:
            LI[i] = chebeval(&A[i*N], N, Tc)
            LI[i+3] = chebder(&A[i*N], N, Tc)*dt2
        return _LI


    def setRefTime(self, double t):
        '''Set the reference time to the start of the record containing t.

        If t == 0, the reference time is removed.

        Returns the difference from the value set as the reference time.

        t - absolute time in Julian days (not relative to the current refTime)

        Raises ValueError if t lies outside the ephemeris.
        '''
        cdef int num

        if t == 0:
            self.refTime = 0.0
            return 0.0

        if t < self.eStartTime or t > self.eEndTime:
            raise ValueError('Time out of range')

        num = <int>((t-self.eStartTime)/self.eTimeStep)
        # t == eEndTime maps one past the last data record; the last record
        # ends at eEndTime, so use it instead of reading past the data.
        if num >= self.numRecords:
            num = self.numRecords - 1
        _readRecord(self, num)
        self.refTime = self.rStartTime
        return t - self.refTime


    def getRecord(self, double t):
        '''Get the record corresponding to the specified time.

        The record is loaded from the file only if the current record does
        not already cover t.

        t - time in Julian days, relative to refTime

        Raises ValueError if t lies outside the ephemeris or no record
        covering t can be found.
        '''
        cdef int num

        t = t + self.refTime
        if t >= self.rStartTime and t <= self.rEndTime:
            return self.record

        if t < self.eStartTime or t > self.eEndTime:
            raise ValueError('Time out of range')

        num = <int>((t-self.eStartTime)/self.eTimeStep)
        # t == eEndTime maps one past the last data record; the last record
        # ends at eEndTime, so use it instead of reading past the data.
        if num >= self.numRecords:
            num = self.numRecords - 1
        _readRecord(self, num)

        # Roundoff in the division above can pick a neighbouring record when t
        # lies within rounding error of a record boundary; step once to fix it.
        if t > self.rEndTime and num + 1 < self.numRecords:
            num = num + 1
            _readRecord(self, num)
        elif t < self.rStartTime and num > 0:
            num = num - 1
            _readRecord(self, num)

        if t < self.rStartTime or t > self.rEndTime:
            raise ValueError('Invalid record')
        return self.record

#-----------------------------------------------------------------------------
import numpy
cdef object _Num_fromstring
cdef object _Num_Float64
_Num_Float64 = numpy.float64

def _numFromBytes(data, dtype):
    '''Return a new, writable array of `dtype` built from the raw bytes `data`.

    Replacement for the binary mode of numpy.fromstring, which is deprecated
    (and warns on every call).  frombuffer only creates a read-only view of
    `data`, so the result is copied to keep the old semantics: the caller gets
    a writable array that owns its own aligned memory.
    '''
    return numpy.frombuffer(data, dtype).copy()

# Kept under the historical name so existing callers need no change.
_Num_fromstring = _numFromBytes

cdef int _readRecord(Ephemeris self, int num) except -1:
    # Load data record `num` (0-based) from the ephemeris file into
    # self.record.  Also refresh self._record, self.rStartTime and
    # self.rEndTime from the first two values of that record.
    #
    # Declared `except -1` (rather than returning void) so I/O and conversion
    # errors propagate to the Python caller instead of being printed and
    # ignored.  A short read raises before the current record is replaced, so
    # self._record never points at an array too small to hold the record.
    cdef int arrayBytes
    cdef object efile
    cdef object data

    arrayBytes = self.arrayBytes
    efile = self._efile
    # Two record-sized blocks precede the data records (presumably the header
    # records), hence the +2 offset.
    efile.seek((num+2)*arrayBytes)
    data = efile.read(arrayBytes)
    if len(data) != arrayBytes:
        raise IOError('Unable to read ephemeris record %d' % num)
    self.record = _Num_fromstring(data, _Num_Float64)

    # Map the Numpy array to a C data structure.
    self._record = <double*>(<PyArrayObject*>self.record).data

    self.rStartTime = self._record[0]
    self.rEndTime = self._record[1]
    return 0

#-----------------------------------------------------------------------------
cdef int _getParms(Ephemeris self, double t, int target) except -1:
    # Prepare a Chebyshev evaluation for data entry `target` (0-10 bodies,
    # 11 nutations, 12 librations) at time t, where t is relative to
    # self.refTime.
    #
    # Side effects: loads the record containing t (via self.getRecord) and
    # sets
    #   self._A  - pointer to the first coefficient of the target's first
    #              component in the selected granule,
    #   self._Tc - normalized time within that granule, in [-1, 1],
    #   self._dt - length of the granule in days.
    # Returns N, the number of coefficients per component.
    #
    # Declared `except -1` so a ValueError from getRecord (time out of range)
    # reaches the caller.  Without it the exception was swallowed and the
    # caller went on to interpolate from a stale or NULL self._A.
    cdef double timeStep, dt, t0
    cdef int C, N, G, i

    self.getRecord(t)
    timeStep = self.eTimeStep
    t0 = self.rStartTime - self.refTime

    # Get structure parameters for the specified target: coefficient offset,
    # coefficients per component and number of granules per record.
    C = self._dataStruct[3*target]
    N = self._dataStruct[3*target+1]
    G = self._dataStruct[3*target+2]

    # An absent entry would otherwise divide by zero or evaluate an empty set.
    if N < 1 or G < 1:
        raise ValueError('Ephemeris does not contain data for target %d' % target)

    if G == 1:
        dt = timeStep
        self._Tc = 2.0*(t-t0)/dt - 1.0
    else:
        dt = timeStep/G  # Time step per granule
        i = <int>((t-t0)/dt)
        if i >= G:
            i = G-1  # t at the record's end point belongs to the last granule

        if target == 11:
            # Nutations only have two components per granule
            C = C + i*2*N
        else:
            C = C + i*3*N
        self._Tc = 2.0*((t-t0)-i*dt)/dt - 1.0

    self._dt = dt
    self._A = &self._record[C]
    return N

#-----------------------------------------------------------------------------
cdef double chebeval(double *coef, int N, double x):
    # Evaluate the Chebyshev series  sum_{k=0}^{N-1} coef[k]*T_k(x)  using
    # Clenshaw's recurrence (see Numerical Recipes for a discussion).
    #
    # coef - pointer to N consecutive series coefficients
    # N    - number of coefficients; an empty series (N < 1) evaluates to 0
    #        without reading coef (previously coef[0] was read regardless)
    # x    - evaluation point, normally the normalized time in [-1, 1]
    cdef double x2, d, dd, tmp
    cdef int k

    if N < 1:
        return 0.0

    x2 = 2.0*x
    d = 0.0
    dd = 0.0
    k = N-1
    while k >= 1:
        # b_k = 2x*b_{k+1} - b_{k+2} + c_k ; explicit temp keeps this pure C.
        tmp = d
        d = x2*d - dd + coef[k]
        dd = tmp
        k = k-1
    return x*d - dd + coef[0]

#-----------------------------------------------------------------------------
cdef double chebder(double *coef, int N, double x):
    # Evaluate the derivative with respect to x of the Chebyshev series
    # sum_{k=0}^{N-1} coef[k]*T_k(x).  Since dT_k/dx = k*U_{k-1}(x), this is
    # sum_{k=1}^{N-1} k*coef[k]*U_{k-1}(x), evaluated with Clenshaw's
    # recurrence for Chebyshev polynomials of the second kind.
    #
    # The result is per unit of x; callers scale it by dx/dt themselves.
    # A constant or empty series (N < 2) has zero derivative.  That case is
    # returned directly instead of reading coef[1], which would lie outside
    # the series (previously it returned that out-of-range coefficient).
    cdef double x2, d, dd, tmp
    cdef int k

    if N < 2:
        return 0.0

    x2 = 2.0*x
    d = 0.0
    dd = 0.0
    k = N-1
    while k >= 2:
        # b_{k-1} = 2x*b_k - b_{k+1} + k*c_k
        tmp = d
        d = x2*d - dd + k*coef[k]
        dd = tmp
        k = k-1
    # U_0 = 1, U_1 = 2x:  result = 2x*b_1 - b_2 + 1*c_1
    return x2*d - dd + coef[1]

