#!/usr/bin/env python
"""Create pretty graphs for publication in a scientific journal.

This script expects data folders in the current directory.

Expected data folders:
    comparison-default_single_15
    comparison-mutate_chance_100_runs
    comparison-gather_proportion_100_runs
    comparison-strong_chance_100_runs
    comparison-complexity_4_point_mutation
    comparison-complexity_4_other_mutation
    comparison-complexity_4_sibling_distance
    comparison-complexity_4_strong_chance
    comparison-complexity_4_strong_factor
    comparison-complexity_4_sibling_distance_with_high_point_mutation

Outputs graphs as .pdf files in the current directory.
"""
import matplotlib
import math
matplotlib.use('Agg')  # for running over ssh
import matplotlib.pyplot as plt
import numpy as np
import os.path
import sys

###############################################################################
# User-editable layout settings.
# All widths are given in LaTeX pts; adjust them to change the figure sizes.
# Width used for single column figures.
FIGURE_WIDTH = 229.5
# Width used for images with three figures across the whole page.
FIGURE_WIDTH_SMALL = 150.08344
# Change this if the legends for batch graphs get cut off.
PLOT_TO_LEGEND_RATIO = 0.755
###############################################################################

# Conversion factor from pts to inches.
PTS_TO_INCHES = 1.0 / 72.27
GOLDEN_RATIO = (math.sqrt(5) - 1.0) / 2.0

# Widths in inches; each height is derived from its width via the golden
# ratio, and the small figures are additionally stretched by 1.5.
_FIGURE_WIDTH_INCHES = FIGURE_WIDTH * PTS_TO_INCHES
_FIGURE_WIDTH_SMALL_INCHES = FIGURE_WIDTH_SMALL * PTS_TO_INCHES

FIGURE_SIZE = (
    _FIGURE_WIDTH_INCHES,
    _FIGURE_WIDTH_INCHES * GOLDEN_RATIO,
)
FIGURE_SIZE_SMALL = (
    _FIGURE_WIDTH_SMALL_INCHES,
    _FIGURE_WIDTH_SMALL_INCHES * GOLDEN_RATIO * 1.5,
)

# Global matplotlib settings applied to every figure this script creates.
params = {
    'axes.labelsize': 10,
    # 'font.size' replaces the deprecated 'text.fontsize' alias, which newer
    # matplotlib releases reject as an unknown rc parameter.
    'font.size': 10,
    'legend.fontsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    #'text.usetex': True,
    'figure.figsize': FIGURE_SIZE,
}
matplotlib.rcParams.update(params)

def graph_representitive_run(analyzer_dir):
    """Create graphs from a representative run using default settings.

    Each metric is read from a file in ``analyzer_dir`` with ``np.loadtxt``.
    Graphs are saved as .pdf's in the current directory and every figure is
    closed once it has been written.

    analyzer_dir -- directory containing the metric files, e.g.
                    ``complexity_avg`` and ``percent_functional-1``.
    """
    # Layout shared by all the single column graphs.
    tight = {'pad': 0.12, 'w_pad': 0.0, 'h_pad': 0.0}

    def load_data(metric):
        return np.loadtxt(os.path.join(analyzer_dir, metric), ndmin=2)

    def finish(fig, filename, **layout):
        # Lay out, save and release the figure so figures do not pile up.
        fig.tight_layout(**layout)
        fig.savefig(filename)
        plt.close(fig)

    def single_line_graph(data, ylabel, filename):
        fig = plt.figure()
        plot_single_line(data, ylabel=ylabel)
        finish(fig, filename, **tight)

    complexity_data = load_data('complexity_avg')
    single_line_graph(complexity_data, 'Complexity', 'part_1_complexity.pdf')
    single_line_graph(load_data('fitness_avg'), 'Fitness',
                      'part_1_fitness.pdf')
    single_line_graph(load_data('genome_length_avg'), 'Genome length',
                      'part_1_genome_length.pdf')

    irr_data = load_data('irreducible_complexity_avg')
    fig = plt.figure()
    plot_complexities(complexity_data, irr_data)
    finish(fig, 'part_1_complexities.pdf', **tight)

    single_line_graph(irr_data, 'Interlocking complexity',
                      'part_1_interlocking_complexity.pdf')

    # plot of small and large complexes for poster
    func_small = load_data('percent_functional-1')
    func_large = load_data('percent_functional-5')
    fig = plt.figure()
    plt.plot(func_small, label='Small proteins', color='#BF4040')
    plt.plot(func_large, label='Large complexes', color='#40BFBF')
    plt.ylabel('Percent functional')
    plt.xlabel('Generation number')
    plt.legend(loc='best', fancybox=True)
    finish(fig, 'poster_complexes2.pdf')

    # The per-complex graphs use the small figure size. The rc parameter is
    # 'figure.figsize'; 'figure.size' is not a valid key.
    matplotlib.rcParams.update({'figure.figsize': FIGURE_SIZE_SMALL,
                                'axes.labelsize': 9.5,
                                'ytick.labelsize': 8})
    try:
        for x in range(1, 6):
            data = load_data('percent_functional-{}'.format(x))
            fig = plt.figure(figsize=FIGURE_SIZE_SMALL)
            plot_single_line(
                data, ylabel='% Functional (Complexes-{})'.format(x))
            plt.ylim([0.0, 1.0])
            finish(fig, 'part_1_functional_{}.pdf'.format(x),
                   pad=0.15, w_pad=0.0, h_pad=0.0)
    finally:
        # Restore the single column settings even if a graph failed.
        matplotlib.rcParams.update({'figure.figsize': FIGURE_SIZE,
                                    'axes.labelsize': 10,
                                    'ytick.labelsize': 8})

def plot_single_line(data, xlabel='Generation number', ylabel='',
                     color='black'):
    """Draw ``data`` as a line on the current axes and label both axes.

    data   -- values handed directly to ``plt.plot``
    xlabel -- text for the x axis
    ylabel -- text for the y axis
    color  -- line color
    """
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.plot(data, color=color)

def plot_complexities(complexity_data, interlocking_complexity_data):
    """Draw complexity and interlocking complexity on the current axes.

    Complexity is drawn in black and interlocking complexity in grey; a
    legend identifies the two lines. Only the x axis receives a label.
    """
    series = (
        (complexity_data, 'black', 'Complexity'),
        (interlocking_complexity_data, '0.5', 'Interlocking complexity'),
    )
    for values, shade, name in series:
        plt.plot(values, color=shade, label=name)
    plt.xlabel('Generation number')
    plt.legend(loc='best', fancybox=True)

def graph_batch_comparison(analyzer_dir, legend_labels, ignore=0,
                           legend_title=''):
    """Plot complexity, irr complexity, and genome length of a comparison.

    Ignores data from the last ``ignore`` simulations (the last ``ignore``
    columns of each loaded file). With ``ignore=0`` every column is plotted.

    analyzer_dir  -- directory containing ``complexity_avg``,
                     ``irreducible_complexity_avg`` and ``genome_length_avg``
    legend_labels -- one label per plotted column
    legend_title  -- passed on to ``plot_lines``; its lower-cased words,
                     joined with underscores, also form part of the output
                     file names ``part_2_<title>_<metric>.pdf``
    """
    def load_data(name, ignore=ignore):
        data = np.loadtxt(os.path.join(analyzer_dir, name), ndmin=2)
        # A plain ``[:, :-ignore]`` would drop every column for ignore == 0,
        # because ``-0`` is ``0``.
        if ignore:
            data = data[:, :-ignore]
        return data

    fname_part = '_'.join(legend_title.lower().split())

    # (metric file, y axis label, output file name suffix)
    graphs = (
        ('complexity_avg', 'Complexity', 'complexity'),
        ('irreducible_complexity_avg', 'Interlocking complexity',
         'inter_complexity'),
        ('genome_length_avg', 'Genome length', 'genome_length'),
    )
    for metric, ylabel, fname_suffix in graphs:
        data = load_data(metric)
        fig = plt.figure()
        ax = plt.subplot(111)
        plot_lines(data, legend_labels, ylabel=ylabel,
                   legend_title=legend_title, ax=ax)
        plt.savefig('part_2_{}_{}.pdf'.format(fname_part, fname_suffix))
        # Release the figure so repeated calls do not accumulate figures.
        plt.close(fig)

def plot_lines(data, legend_labels, ylabel='', xlabel='Generation number',
               legend_title='', ax=None):
    """Plots multiple lines with a legend to the right of the plot.

    data          -- 2-D array; each column is drawn as one line
    legend_labels -- sequence with at least one label per column
    ylabel        -- text for the y axis
    xlabel        -- text for the x axis
    legend_title  -- accepted for interface compatibility; the legend is
                     drawn with an empty title
    ax            -- axes to draw on; defaults to the current axes

    Lines are drawn in increasingly light shades of grey (0.05 per column)
    and cycle through four line styles.
    """
    if ax is None:
        # Must be a real Axes object: the pyplot module itself offers no
        # get_position()/set_position().
        ax = plt.gca()
    line_styles = ("solid", "dashed", "dashdot", "dotted")
    for i, column in enumerate(data.T):
        ax.plot(column, color=str(0.0 + 0.05*i),
                linestyle=line_styles[i % len(line_styles)],
                label=str(legend_labels[i]))
    # Label the axes that were drawn on rather than pyplot's current axes.
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    # Tighten the layout first, then shrink the axes horizontally so the
    # legend fits beside the plot.
    ax.get_figure().tight_layout(pad=0.2, w_pad=0.0, h_pad=0.0)
    box = ax.get_position()
    ax.set_position([box.x0, box.y0,
                     box.width * PLOT_TO_LEGEND_RATIO, box.height])
    ax.legend(fancybox=True, title='',
              bbox_to_anchor=(1, 0.5), loc='center left', ncol=1)

def graph_gens_until(analyzer_dir, bar_labels, threshold_index=0, value=1.0,
                     metric='complexity_avg', mean_aggregated=0.6,
                     varied_option='N/A', num_bars=None,
                     reverse_axis=False, suffix=''):
    """Create a bar graph of the generations needed to reach a threshold.

    Reads ``gens_until_<metric>``, ``gens_until_<metric>.blanks``,
    ``gens_until_<metric>.stddev`` and ``num_runs`` from ``analyzer_dir`` and
    saves ``part_3_<varied_option>_<metric>_<value><suffix>.pdf`` in the
    current directory.

    bar_labels      -- tick label for each bar
    threshold_index -- row of the data files to plot
    value           -- only used in the output file name
    metric          -- metric name used to build the input file names
    mean_aggregated -- a bar is drawn as "Threshold not always reached" when
                       its blanks count is at least
                       ``int(num_runs * (1 - mean_aggregated) / 2)``
    varied_option   -- x axis label, also used in the output file name
    num_bars        -- if given, only the first ``num_bars`` columns are used
    reverse_axis    -- if true, columns and labels are plotted in reverse
    suffix          -- appended to the output file name

    Bars whose threshold was reached are drawn with their standard deviation
    as error bars; the other bars are drawn in grey up to the top of the
    y axis.
    """
    metric_filename = 'gens_until_{}'.format(metric)

    def load(name, **kwargs):
        return np.loadtxt(os.path.join(analyzer_dir, name), **kwargs)

    data = load(metric_filename, ndmin=2)
    blanks = load('{}.blanks'.format(metric_filename), ndmin=2)
    stddev = load('{}.stddev'.format(metric_filename), ndmin=2)
    # ndmin=1 keeps num_runs sliceable even when the file has a single entry.
    num_runs = load('num_runs', dtype=float, ndmin=1)
    if num_bars is not None:
        data = data[:, :num_bars]
        stddev = stddev[:, :num_bars]
        blanks = blanks[:, :num_bars]
        num_runs = num_runs[:num_bars]
    if reverse_axis:
        data = data[..., ::-1]
        stddev = stddev[..., ::-1]
        blanks = blanks[..., ::-1]
        num_runs = num_runs[::-1]
        bar_labels = bar_labels[::-1]

    gens_until_threshold = data[threshold_index]
    threshold_blanks = blanks[threshold_index]
    threshold_stddev = stddev[threshold_index]

    fig = plt.figure()
    bar_x_cords = np.arange(data.shape[1])
    width = 0.4
    plt.ylim(0.0, 1.0)
    failed_threshold = (1.0 - mean_aggregated) / 2.0
    bad_mask = threshold_blanks >= (num_runs * failed_threshold).astype(int)
    good_mask = ~bad_mask
    bad_bar_x_cords = bar_x_cords[bad_mask]
    good_bar_x_cords = bar_x_cords[good_mask]
    # edgecolor is explicit so the white bars keep a visible outline on
    # matplotlib versions that no longer outline patches by default.
    if len(good_bar_x_cords):
        plt.bar(good_bar_x_cords, gens_until_threshold[good_mask],
                width, yerr=threshold_stddev[good_mask], color='1.0',
                edgecolor='black', ecolor='0.3',
                label='Threshold always reached', align='center')
    # Rescale to the good bars so the bad bars can be drawn up to the top of
    # the resulting y axis.
    plt.axis('auto')
    if len(bad_bar_x_cords):
        bad_heights = np.ones(len(bad_bar_x_cords)) * int(plt.axis()[3])
        plt.bar(bad_bar_x_cords, bad_heights,
                width, color='0.5', edgecolor='black',
                label='Threshold not always reached', align='center')
    plt.ylabel('Generations')
    plt.xlabel(varied_option)
    # Bars use align='center', so the ticks belong on the bar x positions.
    plt.xticks(bar_x_cords, bar_labels)
    plt.subplots_adjust(bottom=0.1)
    if len(bad_bar_x_cords) and len(good_bar_x_cords):
        # only create the legend if both types of bars are present
        leg = plt.legend(loc='best', fancybox=True)
        leg.get_frame().set_alpha(0.8)
    fig.tight_layout(pad=0.12, w_pad=0.0, h_pad=0.0)
    comparison_atribute = '_'.join([x.lower() for x in varied_option.split()])
    plt.savefig('part_3_{}_{}_{}{}.pdf'.format(comparison_atribute,
                metric, value, suffix))
    # Release the figure so repeated calls do not accumulate figures.
    plt.close(fig)

def main():
    """Check that all data folders exist, then create every graph.

    Exits with status 1 if an expected data folder is missing from the
    current directory.
    """
    def analyzer(data_dir):
        # Each data folder keeps its analysis output in this subfolder.
        return os.path.join(data_dir, 'comparative-analyzer')

    # check that all data is present
    required_dirs = [
        'comparison-default_single_15',
        'comparison-mutate_chance_100_runs',
        'comparison-gather_proportion_100_runs',
        'comparison-strong_chance_100_runs',
        'comparison-complexity_4_point_mutation',
        'comparison-complexity_4_other_mutation',
        'comparison-complexity_4_sibling_distance',
        'comparison-complexity_4_strong_chance',
        'comparison-complexity_4_strong_factor',
        'comparison-complexity_4_sibling_distance_with_high_point_mutation',
    ]
    # print(...) with a single argument behaves the same on Python 2 and 3.
    for d in required_dirs:
        if not os.path.exists(d):
            print('Error: Missing data directory: {}'.format(d))
            print('Make sure all simulation data is in the current directory')
            sys.exit(1)
    # graph plots for part 1
    graph_representitive_run(analyzer('comparison-default_single_15'))
    # graph plots for part 2
    graph_batch_comparison(
        analyzer('comparison-gather_proportion_100_runs'),
        [1, 0.3, 0.1, 0.03, 0.01, 0.003],
        legend_title='Gather speed',
        ignore=1)
    graph_batch_comparison(
        analyzer('comparison-mutate_chance_100_runs'),
        [0.03, 0.01, 0.003, 0.001, '3e-4'],
        ignore=2,
        legend_title='Non-point mutation rates')
    graph_batch_comparison(
        analyzer('comparison-strong_chance_100_runs'),
        [1, 0.3, 0.1, 0.03, 0.01, 0.003],
        ignore=1,
        legend_title='Functional probability')
    # graph plots for part 3
    graph_gens_until(
        analyzer('comparison-complexity_4_point_mutation'),
        bar_labels=['1e-5', '3e-5', '1e-4', '3e-4',
                    '1e-3', '3e-3', '0.01', '0.03'],
        threshold_index=3, value=4,
        varied_option='Point mutation rate',
        num_bars=8)
    graph_gens_until(
        analyzer('comparison-complexity_4_other_mutation'),
        bar_labels=['1e-4', '3e-4', '1e-3', '3e-3', '0.01', '0.03'],
        threshold_index=3, value=4,
        varied_option='Non-point mutation rate')
    graph_gens_until(
        analyzer('comparison-complexity_4_sibling_distance'),
        bar_labels=['7', '6', '5', '4', '3', '2', '1'],
        threshold_index=3, value=4,
        varied_option='Family difference threshold',
        reverse_axis=True)
    graph_gens_until(
        analyzer('comparison-complexity_4_strong_chance'),
        bar_labels=['1e-3', '3e-3', '0.01', '0.03', '0.1', '0.3', '1.0'],
        threshold_index=3, value=4,
        varied_option='Protein functional probability')
    graph_gens_until(
        analyzer('comparison-complexity_4_strong_factor'),
        bar_labels=['100', '50', '20', '10', '5', '2'],
        threshold_index=3, value=4,
        varied_option='Protein functional strength',
        reverse_axis=True)
    graph_gens_until(
        analyzer('comparison-complexity_4_sibling_distance'
                 '_with_high_point_mutation'),
        bar_labels=['7', '6', '5', '4', '3', '2', '1'],
        threshold_index=3, value=4,
        varied_option='Family difference threshold',
        reverse_axis=True, suffix="_high_point_mutation")
    print('To adjust graph sizes, edit the variables at the top of this '
          'script.')

if __name__ == '__main__':
    main()