#!/usr/bin/env python

"""
Performs Solar Rotational Tomography from the command line.
"""

def usage():
    """Write the command-line help text (``__usage__``) to standard output."""
    help_text = __usage__
    print(help_text)

# Help text printed by usage().
__usage__ = """Usage: srt [options] path [output]

Options:

  -h --help          Show this help message and exit.
  --config           Config file name (default: srt_default.cfg).
                     Command-line options override config file options.

  Data parameters:

  -b --bin           Bin factor of images.
  -s --time_step     Time step between two images of the same kind.
  -w --time_window   Time window in which to look for data.
  -i --instrument    Instrument name(s).
  -o --observatory   Observatory name(s).
  -d --detector      Detector name(s).

  Object parameters:

  --naxis            Object shape in pixels.
  --crpix            Position of the reference pixel in fraction of pixels.
  --cdelt            Size of a pixel in physical coordinates.
  --crval            Position of the reference pixel in physical coordinates.

  Masking parameters:

  --obj_rmin
  --obj_rmax
  --data_rmin
  --data_rmax
  -n --negative      Mask negative data values.

  Optimization options:

  --hyperparameters  Hyperparameters of the smoothness prior.
  --maxiter          Maximum number of iterations.
  --tol              Tolerance.

  Other options:

  --output           Output filename (default: srt.fts).

"""

options = "hb:s:w:i:o:d:n"

# "config=" and "output=" must be listed here: main() handles --config and
# --output, but getopt rejects any long option missing from this list.
long_options = ["help", "config=", "bin=", "time_step=", "time_window=",
                "instrument=", "observatory=", "detector=",
                "naxis=", "crpix=", "cdelt=", "crval=",
                "obj_rmin=", "obj_rmax=", "data_rmin=", "data_rmax=",
                "negative",
                "hyperparameters=", "maxiter=", "tol=",
                "output="]

def main():
    """Handle config file, options and perform computations accordingly.

    Steps:

    1. Parse ``sys.argv`` with getopt (``options`` / ``long_options``).
    2. On ``-h``/``--help`` print the help and exit; without a data path
       print the help and exit.
    3. Read defaults from the config file (``--config``, default
       ``srt_default.cfg``); command-line options then override them.
    4. Read and sort the data with ``secchi``, build the object cube from
       the object parameters, build the model with ``models.srt`` and solve
       it with ``lo.quadratic_optimization``.
    5. Write the reconstruction to the output file (second positional
       argument or ``--output``, default ``srt.fts``).

    Exits with status 2 on a command-line parsing error or when the config
    file cannot be read.
    """
    import getopt, sys
    try:
        import ConfigParser
    except ImportError:  # Python 3 renamed the module
        import configparser as ConfigParser
    import lo, siddon
    import fitsarray as fa
    import secchi, models

    # parse command line arguments
    try:
        opts, args = getopt.getopt(sys.argv[1:], options, long_options)
    except getopt.GetoptError as err:
        # print help information and exit:
        print(str(err))  # will print something like "option -a not recognized"
        usage()
        sys.exit(2)
    # defaults
    config_file = "srt_default.cfg"
    data_params = dict()
    mask_opt = dict()
    output = "srt.fts"
    # options that must be handled before the config file is read
    for o, a in opts:
        if o in ("-h", "--help"):
            # help must work even when no config file is available
            usage()
            sys.exit()
        elif o == "--config":
            config_file = a
    # parse arguments
    if len(args) == 0:
        usage()
        sys.exit()
    path = args[0]
    if len(args) > 1:
        output = args[1]
    # parse config file
    config = ConfigParser.RawConfigParser()
    # read() silently skips unreadable files and returns the list of files it
    # did read; without this check a missing file only shows up later as a
    # cryptic NoSectionError.
    if not config.read(config_file):
        print("Cannot read config file: %s" % config_file)
        sys.exit(2)
    instrume = parse_tuple(config.get("data", "instrument"))
    obsrvtry = parse_tuple(config.get("data", "observatory"))
    detector = parse_tuple(config.get("data", "detector"))
    data_params["bin_factor"] = config.getint("data", "bin")
    data_params["time_step"] = config.getfloat("data", "time_step")
    data_params["time_window"] = parse_tuple(config.get("data", "time_window"))
    naxis = parse_tuple_int(config.get("object", "naxis"))
    crpix = parse_tuple_float(config.get("object", "crpix"))
    cdelt = parse_tuple_float(config.get("object", "cdelt"))
    crval = parse_tuple_float(config.get("object", "crval"))
    mask_opt["obj_rmin"] = config.getfloat("masking", "obj_rmin")
    mask_opt["obj_rmax"] = config.getfloat("masking", "obj_rmax")
    mask_opt["data_rmin"] = config.getfloat("masking", "data_rmin")
    mask_opt["data_rmax"] = config.getfloat("masking", "data_rmax")
    mask_opt["mask_negative"] = config.getboolean("masking", "negative")
    hypers = parse_tuple_float(config.get("optimization", "hyperparameters"))
    maxiter = config.getint("optimization", "maxiter")
    tol = config.getfloat("optimization", "tol")
    # parse options (they override config file values)
    for o, a in opts:
        if o in ("-h", "--help", "--config"):
            pass  # already handled before reading the config file
        # data parameters
        elif o in ("-b", "--bin"):
            data_params["bin_factor"] = int(a)
        elif o in ("-s", "--time_step"):
            data_params["time_step"] = float(a)
        elif o in ("-w", "--time_window"):
            data_params["time_window"] = parse_tuple(a)
        elif o in ("-i", "--instrument"):
            instrume = parse_tuple(a)
        elif o in ("-o", "--observatory"):
            obsrvtry = parse_tuple(a)
        elif o in ("-d", "--detector"):
            detector = parse_tuple(a)
        # object parameters
        elif o == "--naxis":
            naxis = parse_tuple_int(a)
        elif o == "--crpix":
            crpix = parse_tuple_float(a)
        elif o == "--cdelt":
            cdelt = parse_tuple_float(a)
        elif o == "--crval":
            crval = parse_tuple_float(a)
        # masking parameters
        elif o == "--obj_rmin":
            mask_opt["obj_rmin"] = float(a)
        elif o == "--obj_rmax":
            mask_opt["obj_rmax"] = float(a)
        elif o == "--data_rmin":
            mask_opt["data_rmin"] = float(a)
        elif o == "--data_rmax":
            mask_opt["data_rmax"] = float(a)
        elif o in ("-n", "--negative"):
            mask_opt["mask_negative"] = True
        # optimization parameters
        elif o == "--hyperparameters":
            hypers = parse_tuple_float(a)
        elif o == "--maxiter":
            maxiter = int(a)
        elif o == "--tol":
            tol = float(a)
        # other parameters ("-o" is the observatory short option)
        elif o == "--output":
            output = a
        else:
            # explicit raise: an assert would vanish under "python -O"
            raise AssertionError("unhandled option: %s" % o)

    # data: one read per instrument/observatory/detector combination
    data = list()
    for instr in instrume:
        for obs in obsrvtry:
            for det in detector:
                data.append(secchi.read_data(path,
                                             instrume=instr,
                                             obsrvtry=obs,
                                             detector=det,
                                             **data_params))
    data = secchi.concatenate(data)
    data = secchi.sort_data_array(data)
    # enforce 64 bits data XXX
    data.header["BITPIX"][:] = -64
    # cube
    object_header = make_object_header(naxis, crpix, cdelt, crval)
    obj = fa.fitsarray_from_header(object_header)
    # model
    P, D, obj_mask, data_mask = models.srt(data, obj, **mask_opt)
    # inversion: only data elements where data_mask == 0 enter the fit
    b = data[data_mask == 0]
    sol = lo.quadratic_optimization(P, b, D, hypers, maxiter=maxiter, tol=tol)
    # reshape result: scatter the solution back where obj_mask == 0
    fsol = fa.zeros(obj.shape, header=object_header)
    fsol[obj_mask == 0] = sol.flatten()
    fsol.tofits(output)
    return None

def parse_tuple(my_str):
    """
    Parse input parameters which can be tuples.

    Accepts a bare value ("a"), a comma-separated list ("a, b"), or a list
    wrapped in parentheses or square brackets ("(a, b)", "[a, b]").
    Whitespace around the whole value and around each element is ignored,
    and a single trailing comma ("(a,)") is tolerated.

    Returns a list of strings with at least one element (an empty input
    gives [""]).
    """
    # strip outer whitespace first so brackets are found even in padded
    # values such as " (1, 2) "
    my_str = my_str.strip()
    # remove parenthesis / square brackets
    my_str = my_str.rstrip(")").lstrip("(")
    my_str = my_str.rstrip("]").lstrip("[")
    # split tuple elements if any and strip whitespace around each one
    str_list = [s.strip() for s in my_str.split(",")]
    # tolerate a trailing comma, as in Python's one-element tuple "(a,)"
    if len(str_list) > 1 and str_list[-1] == "":
        str_list.pop()
    return str_list

def parse_tuple_int(my_str):
    """
    Split a tuple-like string with parse_tuple and convert each item to int.
    """
    items = parse_tuple(my_str)
    return list(map(int, items))

def parse_tuple_float(my_str):
    """
    Split a tuple-like string with parse_tuple and convert each item to float.
    """
    return list(map(float, parse_tuple(my_str)))

def make_object_header(naxis, crpix, cdelt, crval):
    """
    Build the header dict (FITS keyword names) describing the object cube.

    Parameters
    ----------
    naxis : sequence of int
        Number of pixels along each axis; its length sets NAXIS.
    crpix, cdelt, crval : sequence of float
        Per-axis CRPIXn, CDELTn and CRVALn values. Each must have exactly
        one entry per axis of naxis.

    Returns
    -------
    dict
        Keys NAXIS, BITPIX (always -64) and, for each axis n (1-based),
        NAXISn, CRPIXn, CDELTn and CRVALn.

    Raises
    ------
    ValueError
        If crpix, cdelt or crval does not have one entry per axis.
    """
    ndim = len(naxis)
    # a length mismatch is a configuration error; report it clearly instead
    # of raising an opaque IndexError or silently dropping extra values
    for name, values in (("crpix", crpix), ("cdelt", cdelt), ("crval", crval)):
        if len(values) != ndim:
            raise ValueError("%s has %d values but naxis has %d axes"
                             % (name, len(values), ndim))
    header = dict()
    header["NAXIS"] = ndim
    header["BITPIX"] = -64
    # range (not xrange) keeps this valid on both Python 2 and Python 3
    for i in range(ndim):
        suffix = str(i + 1)
        header["NAXIS" + suffix] = naxis[i]
        header["CRPIX" + suffix] = crpix[i]
        header["CDELT" + suffix] = cdelt[i]
        header["CRVAL" + suffix] = crval[i]
    return header

if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()
