#! /usr/bin/env python
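"""
Make an animated gif movie of a functional timeseries.

Each volume of the timeseries is rendered as a row of slices; optionally an
artifact overlay (deviation from a reference volume in units of temporal
standard deviation) and a progressively drawn plot of a text file are
appended to every frame.

The work is done by external command-line programs that must be on the
PATH: FSL (bet, fslstats, fslsplit, fslmaths, fslroi, fslval, slicer,
overlay, pngappend, fsl_tsplot), ImageMagick's convert, and whirlgif.
"""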
import os
import re
import sys
import shutil
import tempfile
import argparse
import subprocess
from glob import glob
import numpy as np

def main():
    """Build an animated gif from the timeseries named on the command line.

    Configuration is read from the module-level ``args`` namespace created
    by argparse.  Every intermediate file is written to a fresh temporary
    directory; the process changes into that directory while it works and
    always changes back afterwards.  The temporary directory is deleted
    unless ``args.clean`` is false.

    A ``RuntimeError`` raised by ``runcmd`` (an external program wrote
    something unexpected to stderr) terminates the script via ``sys.exit``
    with the error text.
    """
    # Create a temp directory to operate in
    tempdir = tempfile.mkdtemp()
    if args.verbose:
        print("Temp Directory:  %s" % tempdir)

    # Remember the starting directory so it can be restored at the end
    origdir = os.getcwd()

    try:
        # Resolve every user-supplied path *before* changing directory,
        # otherwise relative paths would be interpreted inside the tempdir
        timeseries = os.path.abspath(args.timeseries)
        if args.verbose:
            print("Input timeseries:  %s" % timeseries)

        parfile = None
        if args.plot is not None:
            parfile = os.path.abspath(args.plot)
            if args.verbose:
                print("Input parfile:  %s" % parfile)

        # Figure out what the final gif will be called.  This has to come
        # after ``timeseries`` is known because it provides the default name.
        if args.out is None:
            outfile = timeseries + ".gif"
        else:
            outfile = os.path.abspath(args.out)

        # Move into the temp dir
        os.chdir(tempdir)

        # Skullstrip the timeseries (or just copy it under the same name so
        # the rest of the pipeline does not need to care)
        if args.bet:
            runcmd(["bet", timeseries, "ts_bet", "-F"])
        else:
            runcmd(["cp", timeseries, "ts_bet.nii.gz"])

        # Get a robust intensity range for the timeseries; the two numbers
        # are passed straight through to slicer's -i option below
        thresh = runcmd(["fslstats", "ts_bet", "-r"], return_stdout=True).strip()

        # Split the timeseries into one file per volume
        runcmd(["fslsplit", "ts_bet", "ts-"])

        # Get a sorted list of the split timeseries volumes
        raw_frames = sorted(glob("ts-*"))

        # Make a title png showing the input path
        title = timeseries
        titlepng = "title.png"
        runcmd(["convert -size 750x30 xc:black -font helvetica -pointsize 18",
                "-gravity Center -fill white -draw \"text 0,0 '%s'\" %s"
                % (title, titlepng)])

        # Make a png of each raw frame
        raw_pngs = []
        for frame in raw_frames:
            rawpng = frame + ".png"
            runcmd(["slicer", frame, "-i", thresh, "-A", "750", rawpng])
            raw_pngs.append(rawpng)

        movie_pngs = []
        if args.ref is not None:
            # Write a temporal stdev image
            runcmd(["fslmaths", "ts_bet", "-Tstd", "std"])

            # Build the command that writes the reference image
            if args.ref in ["mean", "median"]:
                cmd = ["fslmaths", "ts_bet", "-T%s" % args.ref, "ref"]
            else:
                if args.ref == "first":
                    idx = "0"
                elif args.ref == "middle":
                    length = runcmd(["fslval", timeseries, "dim4"],
                                    return_stdout=True)
                    # Floor division keeps this an integer index on both
                    # Python 2 and Python 3
                    idx = str((int(length) - 1) // 2)
                else:
                    # Anything else is passed to fslroi as the volume index
                    idx = args.ref
                cmd = ["fslroi", timeseries, "ref", idx, "1"]

            # Write the reference image
            runcmd(cmd)

            # Create the art images, overlay, and slice them
            for i, frame in enumerate(raw_frames):
                art = "art%04d" % i
                art_overlay = "art_overlay%04d" % i
                artpng = art_overlay + ".png"
                movie_png = "movie_%04d.png" % i

                # (frame - ref) / std
                runcmd(["fslmaths", frame, "-sub", "ref", "-div", "std", art])
                # Overlay positive and negative deviations on the frame
                runcmd(["overlay", "0", "0", art, frame, "-a",
                        art, args.min, args.max,
                        art, "-" + args.min, "-" + args.max,
                        art_overlay])
                runcmd(["slicer", art_overlay, "-A", "750", artpng])
                # Stack title, raw slices and overlay slices vertically
                runcmd(["pngappend", titlepng, "-", raw_pngs[i], "-", artpng,
                        movie_png])
                movie_pngs.append(movie_png)
        else:
            # Otherwise, just title each raw clip
            for i, raw_png in enumerate(raw_pngs):
                movie_png = "movie_%04d.png" % i
                runcmd(["pngappend", titlepng, "-", raw_png, movie_png])
                movie_pngs.append(movie_png)

        # Plot a timeseries underneath each frame
        if parfile is not None:
            # Get an array of the full timeseries
            full_ts = np.genfromtxt(parfile)
            # Normalize each column (maybe)
            if args.normplot:
                full_ts = (full_ts - full_ts.mean(axis=0)) / full_ts.std(axis=0)
            # Pad the front with zeros in case the text array is shorter
            # than the movie (e.g. it's relative displacement)
            n_missing = len(movie_pngs) - full_ts.shape[0]
            if n_missing > 0:
                padding = np.zeros((n_missing,) + full_ts.shape[1:])
                full_ts = np.concatenate((padding, full_ts))

            # Fix the y-axis across frames so the plot is stationary.  Use
            # the signed extremes: taking them from the absolute values
            # would put ymin at or above zero and clip negative data.
            ymin = full_ts.min()
            ymax = full_ts.max()

            for i, movie_png in enumerate(movie_pngs):
                # Timeseries so far: real values up to and including this
                # frame, zeros afterwards
                sofar_ts = np.zeros(full_ts.shape)
                sofar_ts[:i + 1] = full_ts[:i + 1]

                # Write the timeseries text file for this frame
                sofar_par = "parfile_%04d.txt" % i
                np.savetxt(sofar_par, sofar_ts)

                # Write the timeseries plot for this frame
                plot_png = "plot_%04d.png" % i
                runcmd(["fsl_tsplot", "-i", sofar_par, "-t", args.plot,
                        "--ymin=%.3f" % ymin, "--ymax=%.3f" % ymax,
                        "-w", "750", "-o", plot_png])

                # Append the plot to the movie frame (in place)
                runcmd(["pngappend", movie_png, "-", plot_png, movie_png])

        # Convert the movie frames to .gif
        movie_gifs = []
        for png in movie_pngs:
            gif = png[:-4] + ".gif"
            runcmd(["convert", png, gif])
            movie_gifs.append(gif)

        # Turn the frames into an animated .gif; outfile is already absolute
        cmd = ["whirlgif", "-o", outfile, "-loop", "100"]
        cmd.extend(movie_gifs)
        runcmd(cmd)

        # NOTE: opening the result in a browser used to be sketched here but
        # was hard-disabled and relied on an option the parser never defined.

    except RuntimeError as e:
        # Catch a runtime error and exit politely
        sys.exit(e)
    finally:
        # Regardless, go back to the original directory
        os.chdir(origdir)
        # Probably clean up the temp dir, unless we're debugging
        if args.clean:
            shutil.rmtree(tempdir)
        else:
            print("Tempdir %s not removed" % tempdir)

def runcmd(cmdline, return_stdout=False):
    """Run a command through the shell and optionally return its stdout.

    Parameters
    ----------
    cmdline : list of str
        Command fragments.  They are joined with single spaces and the
        resulting string is handed to the shell, so a fragment may itself
        contain several tokens or shell quoting.
    return_stdout : bool
        If true, return the text the command wrote to stdout.

    Returns
    -------
    str or None
        The command's stdout when ``return_stdout`` is true, else None.

    Raises
    ------
    RuntimeError
        If the command wrote anything to stderr, unless that output
        contains "WARNING" or starts with "whirlgif".

    The command line and its stdout are echoed when the module-level
    ``args.verbose`` flag is set.
    """
    cmd = " ".join(cmdline)
    if args.verbose:
        print(cmd)

    # universal_newlines=True makes the pipes text-mode, so the string
    # checks below (and callers that splice the output into later command
    # lines) work on Python 3 as well as Python 2.  The child inherits the
    # current environment and working directory by default.
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            shell=True,
                            universal_newlines=True)

    stdout, stderr = proc.communicate()

    # Some of these programs are annoying about stderr: tolerate output
    # that is only a warning, and whatever whirlgif reports
    benign = "WARNING" in stderr or stderr.startswith("whirlgif")
    if stderr and not benign:
        raise RuntimeError(stderr)

    if args.verbose:
        print(stdout)
    if return_stdout:
        return stdout
    return None

if __name__ == "__main__":
    # Command-line interface.  The parsed namespace is stored in the
    # module-level name ``args``, which main() and runcmd() read directly.
    parser = argparse.ArgumentParser()
    # The timeseries is the one input the script cannot work without, so
    # let argparse report its absence instead of failing later on a None path
    parser.add_argument("-ts", dest="timeseries", required=True,
                        help="functional timeseries")
    parser.add_argument("-ref", metavar="REF_TYPE",
                        help="reference type for artifact visualization")
    # Thresholds stay strings because they are spliced into command lines
    parser.add_argument("-min", default="2",
                        help="minimum threshold for artifact visualization (def: 2)")
    parser.add_argument("-max", default="5",
                        help="saturation point for artifact visualization (def: 5)")
    parser.add_argument("-plot", metavar="TEXT_FILE",
                        help="text file to generate plot with")
    parser.add_argument("-normplot", action="store_true",
                        help="normalize plot file columns")
    parser.add_argument("-nobet", dest="bet", action="store_false",
                        help="do not run BET on the timeseries")
    parser.add_argument("-out", help="gif file to write")
    parser.add_argument("-verbose", action="store_true", help="verbose output")
    parser.add_argument("-noclean", dest="clean", action="store_false",
                        help="do not delete temporary directory")

    # With no arguments at all, show the help text rather than an error
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args()
    main()
