/*
 * Decompiled with CFR 0.152.
 */
package net.aclib.fanova.model;

import ca.ubc.cs.beta.aclib.algorithmrun.AlgorithmRun;
import ca.ubc.cs.beta.aclib.configspace.NormalizedRange;
import ca.ubc.cs.beta.aclib.configspace.ParamConfigurationSpace;
import ca.ubc.cs.beta.models.fastrf.RandomForest;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import net.aclib.fanova.eval.ModelEvaluation;
import net.aclib.fanova.model.RandomForestPreprocessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Functional ANOVA analysis of a random-forest predictor.
 *
 * <p>Decomposes each tree's variance into contributions of single parameters (main effects)
 * and, optionally, unordered parameter pairs (binary interaction effects). The decomposition
 * is averaged over all trees of the forest and logged. When requested, the class also writes
 * per-parameter marginal predictions, gnuplot scripts/PDFs and a LaTeX summary.
 */
public class FunctionalANOVARunner {
    private static final Logger log = LoggerFactory.getLogger(FunctionalANOVARunner.class);

    /** Number of effects listed in the "most important effects" summary. */
    private static final int NUM_MAX_EFFECTS_TO_PRINT = 30;

    /** Number of evenly spaced points in [0, 1] at which continuous-parameter marginals are plotted. */
    private static final int NUM_CONTINUOUS_PLOT_POINTS = 11;

    /** Orders variance-contribution entries from largest to smallest value. */
    private static final Comparator<Map.Entry<HashSet<Integer>, Double>> BY_VALUE_DESCENDING =
            new Comparator<Map.Entry<HashSet<Integer>, Double>>() {
                @Override
                public int compare(Map.Entry<HashSet<Integer>, Double> o1, Map.Entry<HashSet<Integer>, Double> o2) {
                    return o2.getValue().compareTo(o1.getValue());
                }
            };

    /**
     * Returns the entries of {@code varianceContributions} sorted by value, largest first.
     *
     * <p>The signature (raw {@code HashSet} key in the result) is kept for existing callers.
     * The list contains the map's own entry objects and is a fresh, mutable list.
     *
     * @param varianceContributions map from parameter-index set to its contribution
     * @return a new list of the map's entries in descending value order
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static ArrayList<Map.Entry<HashSet, Double>> sortedKeysByValue(HashMap<HashSet<Integer>, Double> varianceContributions) {
        ArrayList raw = sortEntriesByValueDescending(varianceContributions);
        // Safe: only the static type of the key differs (HashSet<Integer> vs raw HashSet).
        return (ArrayList<Map.Entry<HashSet, Double>>) raw;
    }

    /** Typed implementation behind {@link #sortedKeysByValue}. */
    private static ArrayList<Map.Entry<HashSet<Integer>, Double>> sortEntriesByValueDescending(Map<HashSet<Integer>, Double> contributions) {
        ArrayList<Map.Entry<HashSet<Integer>, Double>> entries =
                new ArrayList<Map.Entry<HashSet<Integer>, Double>>(contributions.entrySet());
        Collections.sort(entries, BY_VALUE_DESCENDING);
        return entries;
    }

    /**
     * Runs the functional ANOVA decomposition and logs the results.
     *
     * <p>A marginal forest is obtained from
     * {@code ModelEvaluation.extractMarginalForest} and preprocessed. Then, for every tree with
     * non-zero total variance, each main effect (and, optionally, each pairwise interaction)
     * is expressed as a percentage of that tree's variance. The percentages are averaged with
     * weight {@code 1 / numberOfTrees}; trees with zero variance contribute nothing.
     *
     * @param existingForest forest passed to {@code extractMarginalForest}
     * @param testRuns passed to {@code extractMarginalForest}
     * @param configSpace configuration space supplying parameter names, categorical sizes,
     *     categorical value names and normalized ranges
     * @param rand passed to {@code extractMarginalForest}
     * @param compareToDef passed to {@code extractMarginalForest}
     * @param quantileToCompareTo passed to {@code extractMarginalForest}
     * @param computePairwiseInteraction whether to compute binary interaction effects
     * @param outputDir directory receiving the marginal, gnuplot and .tex files (only used
     *     when {@code plotMarginals} is true)
     * @param logModel not used by this method
     * @param plotMarginals whether to write marginal predictions, run gnuplot and write the
     *     LaTeX summary
     * @throws IOException if an output file cannot be created or gnuplot cannot be started
     * @throws InterruptedException if interrupted while waiting for gnuplot
     * @throws RuntimeException if a main-effect variance fraction is NaN
     */
    public static void decomposeVariance(RandomForest existingForest, List<AlgorithmRun> testRuns, ParamConfigurationSpace configSpace, Random rand, boolean compareToDef, double quantileToCompareTo, boolean computePairwiseInteraction, String outputDir, boolean logModel, boolean plotMarginals) throws IOException, InterruptedException {
        RandomForest forest = ModelEvaluation.extractMarginalForest(existingForest, testRuns, configSpace, rand, compareToDef, quantileToCompareTo);
        RandomForestPreprocessor.preprocessRandomForest(forest, configSpace);
        DecimalFormat decim = new DecimalFormat("#");
        DecimalFormat decim2 = new DecimalFormat("##.##");
        int[] categoricalSizes = configSpace.getCategoricalSize();
        int numDim = categoricalSizes.length;
        int numTrees = forest.Trees.length;

        // Forest-averaged percentage of variance explained, keyed by the set of parameter indices.
        HashMap<HashSet<Integer>, Double> totalFractionsExplained = new HashMap<HashSet<Integer>, Double>();
        // Per-dimension evaluation points and their weights; rebuilt for every tree.
        double[][] allObservations = new double[numDim][];
        double[][] allIntervalSizes = new double[numDim][];
        double timeForComputingMainEffects = 0.0;
        double timeForComputingBinaryEffects = 0.0;

        for (int numTree = 0; numTree < numTrees; numTree++) {
            double thisTreeTotalVariance = RandomForestPreprocessor.computeTotalVarianceOfRegressionTree(forest.Trees[numTree], configSpace);
            if (thisTreeTotalVariance == 0.0) {
                log.info("Tree " + numTree + " has no variance -> skipping.");
                continue;
            }
            log.info("Tree " + numTree + ": Total variance of predictor: " + thisTreeTotalVariance);

            HashMap<HashSet<Integer>, Double> thisTreeVarianceContributions = new HashMap<HashSet<Integer>, Double>();
            HashMap<Integer, Double> singleVarianceContributions = new HashMap<Integer, Double>();

            long start = System.nanoTime();
            fillObservationGrid(forest, numTree, categoricalSizes, allObservations, allIntervalSizes);
            double thisTreeFractionOfVarianceExplainedByMarginals = computeMainEffects(forest, numTree, configSpace,
                    thisTreeTotalVariance, allObservations, allIntervalSizes,
                    thisTreeVarianceContributions, singleVarianceContributions, decim);
            double thisTreeTimeForComputingMainEffects = (double) (System.nanoTime() - start) * 1.0E-9;
            timeForComputingMainEffects += thisTreeTimeForComputingMainEffects;
            log.info("\nTree " + numTree + ": " + "Fraction of variance explained by main effects in this tree: " + thisTreeFractionOfVarianceExplainedByMarginals + "%. Took a total of " + thisTreeTimeForComputingMainEffects + " seconds.");

            if (computePairwiseInteraction) {
                double thisTreeFractionOfVarianceExplainedByBinaries = computePairwiseEffects(forest, numTree, configSpace,
                        thisTreeTotalVariance, allObservations, allIntervalSizes,
                        thisTreeVarianceContributions, singleVarianceContributions, decim);
                // Elapsed time since 'start' minus the main-effect share.
                double thisTreeTimeForComputingBinaryEffects = (double) (System.nanoTime() - start) * 1.0E-9 - thisTreeTimeForComputingMainEffects;
                timeForComputingBinaryEffects += thisTreeTimeForComputingBinaryEffects;
                log.info("Tree " + numTree + ": " + "Fraction of variance explained by binary interaction effects this tree: " + thisTreeFractionOfVarianceExplainedByBinaries + "%. Took " + thisTreeTimeForComputingBinaryEffects + " seconds.");
            }

            accumulateTreeFractions(thisTreeVarianceContributions, thisTreeTotalVariance, numTree, numTrees, totalFractionsExplained);
        }

        // Single pass over the averaged fractions; iteration order matches the original three loops.
        double tmpFractionExplained = 0.0;
        double sumOfFractionsOfMarginals = 0.0;
        double sumOfFractionsOfBinaries = 0.0;
        for (Map.Entry<HashSet<Integer>, Double> entry : totalFractionsExplained.entrySet()) {
            double fraction = entry.getValue();
            tmpFractionExplained += fraction;
            int order = entry.getKey().size();
            if (order == 1) {
                sumOfFractionsOfMarginals += fraction;
            } else if (order == 2 && computePairwiseInteraction) {
                sumOfFractionsOfBinaries += fraction;
            }
        }
        log.debug(tmpFractionExplained + "%");
        log.info("\nSum of fractions of marginals: " + sumOfFractionsOfMarginals + "%");
        if (computePairwiseInteraction) {
            log.info("Sum of fractions of binaries: " + sumOfFractionsOfBinaries + "%");
        }
        log.info("Results for paper:   & " + decim.format(sumOfFractionsOfMarginals) + "\\% (" + decim.format(timeForComputingMainEffects) + "s) & " + decim.format(sumOfFractionsOfBinaries) + "\\% (" + decim.format(timeForComputingBinaryEffects) + "s)");

        List<String> paramNamesOrderByMarginalVarianceExplained = logMostImportantEffects(totalFractionsExplained, configSpace, decim2);

        if (plotMarginals) {
            log.info("Collecting and plotting marginals ...");
            writeMarginalPlots(forest, configSpace, categoricalSizes, outputDir);
            writeTexReport(outputDir, paramNamesOrderByMarginalVarianceExplained, configSpace, totalFractionsExplained, decim2);
        }
        log.info("Functional ANOVA finished successfully - exiting.");
    }

    /**
     * Fills, for every dimension, the points at which tree {@code treeIndex}'s marginals are
     * evaluated ({@code allObservations[dim]}) and the weight of each point
     * ({@code allIntervalSizes[dim]}).
     *
     * <p>Categorical dimensions ({@code categoricalSizes[dim] > 0}) use every value index
     * {@code 0..k-1} with weight {@code 1/k}. Continuous dimensions use the midpoints of the
     * intervals of [0, 1] delimited by the tree's split points on that dimension, each weighted
     * by its width. A dimension the tree never splits on gets the single interval [0, 1]
     * (midpoint 0.5) rather than no points. As a result, pairwise sums over it still reproduce
     * the partner's main effect, and the interaction comes out as 0 instead of negative.
     */
    private static void fillObservationGrid(RandomForest forest, int treeIndex, int[] categoricalSizes,
            double[][] allObservations, double[][] allIntervalSizes) {
        for (int dim = 0; dim < categoricalSizes.length; dim++) {
            int numVals = categoricalSizes[dim];
            if (numVals > 0) {
                allObservations[dim] = new double[numVals];
                allIntervalSizes[dim] = new double[numVals];
                for (int valIndex = 0; valIndex < numVals; valIndex++) {
                    allObservations[dim][valIndex] = valIndex;
                    allIntervalSizes[dim][valIndex] = 1.0 / (double) numVals;
                }
                continue;
            }
            ArrayList<Double> splitPoints = new ArrayList<Double>();
            for (int node = 0; node < forest.Trees[treeIndex].var.length; node++) {
                // The tree's variable ids are compared against dim + 1, i.e. they appear to be 1-based.
                if (forest.Trees[treeIndex].var[node] == dim + 1) {
                    splitPoints.add(forest.Trees[treeIndex].cut[node]);
                }
            }
            splitPoints.add(0.0);
            splitPoints.add(1.0);
            Collections.sort(splitPoints);
            int numIntervals = splitPoints.size() - 1;
            allObservations[dim] = new double[numIntervals];
            allIntervalSizes[dim] = new double[numIntervals];
            for (int lower = 0; lower < numIntervals; lower++) {
                double lo = splitPoints.get(lower);
                double hi = splitPoints.get(lower + 1);
                allObservations[dim][lower] = (lo + hi) / 2.0;
                allIntervalSizes[dim][lower] = hi - lo;
            }
        }
    }

    /**
     * Computes the weighted variance {@code E[m^2] - E[m]^2} of each single-parameter marginal
     * {@code m} of tree {@code treeIndex} over the evaluation grid. Each value is stored in
     * {@code treeContributions} (keyed by {@code {dim}}) and in {@code singleContributions}
     * (keyed by {@code dim}), and logged as a percentage of {@code treeTotalVariance}.
     *
     * @return the sum of the per-parameter percentages
     * @throws RuntimeException if a percentage is NaN
     */
    private static double computeMainEffects(RandomForest forest, int treeIndex, ParamConfigurationSpace configSpace,
            double treeTotalVariance, double[][] allObservations, double[][] allIntervalSizes,
            Map<HashSet<Integer>, Double> treeContributions, Map<Integer, Double> singleContributions,
            DecimalFormat decim) {
        double fractionSum = 0.0;
        for (int dim = 0; dim < allObservations.length; dim++) {
            int[] indicesOfObservations = new int[]{dim};
            double weightedSum = 0.0;
            double weightedSumOfSquares = 0.0;
            for (int valIndex = 0; valIndex < allObservations[dim].length; valIndex++) {
                double[] observations = new double[]{allObservations[dim][valIndex]};
                if (log.isDebugEnabled()) {
                    log.debug("Observations of Tree: " + treeIndex + " : " + dim + " : " + observations[0]);
                }
                double marg = forest.Trees[treeIndex].marginalPerformance(indicesOfObservations, observations);
                if (log.isDebugEnabled()) {
                    log.debug("Marg of Tree: " + treeIndex + " : " + dim + " : " + marg);
                }
                double intervalSize = allIntervalSizes[dim][valIndex];
                weightedSum += marg * intervalSize;
                weightedSumOfSquares += marg * marg * intervalSize;
            }
            double contribution = weightedSumOfSquares - weightedSum * weightedSum;
            if (log.isDebugEnabled()) {
                log.debug("MarginalVarianceContribution for tree " + treeIndex + " : " + contribution);
            }
            double fraction = contribution / treeTotalVariance * 100.0;
            if (Double.isNaN(fraction)) {
                throw new RuntimeException("ERROR - variance contributions is NaN.");
            }
            fractionSum += fraction;
            log.info("Tree " + treeIndex + ": " + decim.format(fraction) + "% of variance explained by parameter " + (String) configSpace.getParameterNames().get(dim));
            HashSet<Integer> set = new HashSet<Integer>();
            set.add(dim);
            treeContributions.put(set, contribution);
            singleContributions.put(dim, contribution);
        }
        return fractionSum;
    }

    /**
     * For every unordered parameter pair, computes the weighted variance of the two-parameter
     * marginal of tree {@code treeIndex} minus both main-effect variances. Each result is stored
     * in {@code treeContributions} keyed by {@code {dim1, dim2}}.
     *
     * @return the sum of the pairwise contributions as percentages of {@code treeTotalVariance}
     */
    private static double computePairwiseEffects(RandomForest forest, int treeIndex, ParamConfigurationSpace configSpace,
            double treeTotalVariance, double[][] allObservations, double[][] allIntervalSizes,
            Map<HashSet<Integer>, Double> treeContributions, Map<Integer, Double> singleContributions,
            DecimalFormat decim) {
        int numDim = allObservations.length;
        double fractionSum = 0.0;
        for (int dim1 = 0; dim1 < numDim; dim1++) {
            for (int dim2 = dim1 + 1; dim2 < numDim; dim2++) {
                String name1 = (String) configSpace.getParameterNames().get(dim1);
                String name2 = (String) configSpace.getParameterNames().get(dim2);
                int[] indicesOfObservations = new int[]{dim1, dim2};
                double weightedSum = 0.0;
                double weightedSumOfSquares = 0.0;
                for (int valIndex1 = 0; valIndex1 < allObservations[dim1].length; valIndex1++) {
                    for (int valIndex2 = 0; valIndex2 < allObservations[dim2].length; valIndex2++) {
                        double[] observations = new double[]{allObservations[dim1][valIndex1], allObservations[dim2][valIndex2]};
                        double intervalSize1 = allIntervalSizes[dim1][valIndex1];
                        double intervalSize2 = allIntervalSizes[dim2][valIndex2];
                        double marg = forest.Trees[treeIndex].marginalPerformance(indicesOfObservations, observations);
                        weightedSum += marg * intervalSize1 * intervalSize2;
                        weightedSumOfSquares += marg * marg * intervalSize1 * intervalSize2;
                        if (log.isTraceEnabled()) {
                            log.trace("Marginal for parameter " + name1 + " set to value " + observations[0] + " and parameter " + name2 + " set to value " + observations[1] + ": " + marg);
                        }
                    }
                }
                double contribution = weightedSumOfSquares - weightedSum * weightedSum;
                contribution -= singleContributions.get(dim1);
                if (log.isDebugEnabled()) {
                    log.debug("SingleVarianceContribtion" + dim1 + " : " + singleContributions.get(dim1));
                    log.debug("SingleVarianceContribtion" + dim2 + " : " + singleContributions.get(dim2));
                }
                contribution -= singleContributions.get(dim2);
                double fraction = contribution / treeTotalVariance * 100.0;
                fractionSum += fraction;
                if (log.isDebugEnabled()) {
                    log.debug(decim.format(fraction) + "% for contribution of parameters " + name1 + " & " + name2);
                }
                HashSet<Integer> set = new HashSet<Integer>();
                set.add(dim1);
                set.add(dim2);
                treeContributions.put(set, contribution);
            }
        }
        return fractionSum;
    }

    /**
     * Adds one tree's contributions, as percentages of that tree's total variance weighted by
     * {@code 1 / numTrees}, into {@code totalFractionsExplained}. Keys not yet present start
     * from 0, so earlier skipped (zero-variance) trees do not cause a null lookup.
     */
    private static void accumulateTreeFractions(Map<HashSet<Integer>, Double> treeContributions, double treeTotalVariance,
            int treeIndex, int numTrees, Map<HashSet<Integer>, Double> totalFractionsExplained) {
        double tmpThisTreeFractionExplained = 0.0;
        for (Map.Entry<HashSet<Integer>, Double> entry : treeContributions.entrySet()) {
            HashSet<Integer> indexSet = entry.getKey();
            Double previous = totalFractionsExplained.get(indexSet);
            double previousFractionExplained = previous == null ? 0.0 : previous;
            double thisFractionExplained = entry.getValue() / treeTotalVariance * 100.0;
            if (log.isDebugEnabled()) {
                log.debug("ThisTreeVarianceContributions of index" + indexSet.toString() + " for Tree" + treeIndex + " : " + entry.getValue());
                log.debug("ThisTreeTotalVariance for Tree" + treeIndex + " : " + treeTotalVariance);
            }
            tmpThisTreeFractionExplained += thisFractionExplained;
            double updated = previousFractionExplained + 1.0 / (double) numTrees * thisFractionExplained;
            totalFractionsExplained.put(indexSet, updated);
            if (log.isDebugEnabled()) {
                log.debug("TotalFractionExplained for Tree" + treeIndex + " : " + updated);
            }
        }
        log.debug(tmpThisTreeFractionExplained + "%");
    }

    /**
     * Logs the {@link #NUM_MAX_EFFECTS_TO_PRINT} largest effects and collects the names of all
     * main-effect parameters in descending order of explained variance.
     *
     * @return parameter names of all main effects, most important first
     */
    private static List<String> logMostImportantEffects(Map<HashSet<Integer>, Double> totalFractionsExplained,
            ParamConfigurationSpace configSpace, DecimalFormat decim2) {
        List<String> paramNamesInOrder = new ArrayList<String>();
        log.info("\n" + NUM_MAX_EFFECTS_TO_PRINT + " most important effects (out of main and binary interaction effects):");
        int rank = 0;
        for (Map.Entry<HashSet<Integer>, Double> entry : sortEntriesByValueDescending(totalFractionsExplained)) {
            boolean print = rank < NUM_MAX_EFFECTS_TO_PRINT;
            rank++;
            HashSet<Integer> set = entry.getKey();
            Integer[] varIndices = set.toArray(new Integer[set.size()]);
            if (varIndices.length == 1) {
                String parameterName = (String) configSpace.getParameterNames().get(varIndices[0].intValue());
                paramNamesInOrder.add(parameterName);
                if (print) {
                    log.info(decim2.format(entry.getValue()) + "% due to main effect: " + parameterName);
                }
            } else if (varIndices.length == 2 && print) {
                String varName1 = (String) configSpace.getParameterNames().get(varIndices[0].intValue());
                String varName2 = (String) configSpace.getParameterNames().get(varIndices[1].intValue());
                log.info(decim2.format(entry.getValue()) + "% due to interaction: " + varName1 + " x " + varName2);
            }
        }
        return paramNamesInOrder;
    }

    /**
     * Returns the points at which a parameter's marginal is plotted. For a categorical parameter
     * with {@code k > 0} values these are the value indices {@code 0..k-1}; otherwise they are
     * {@link #NUM_CONTINUOUS_PLOT_POINTS} evenly spaced points in [0, 1].
     */
    private static double[] plotGrid(int numCategoricalValues) {
        if (numCategoricalValues > 0) {
            double[] values = new double[numCategoricalValues];
            for (int i = 0; i < numCategoricalValues; i++) {
                values[i] = i;
            }
            return values;
        }
        int numVals = NUM_CONTINUOUS_PLOT_POINTS;
        double[] values = new double[numVals];
        for (int i = 0; i < numVals; i++) {
            values[i] = ((double) i + 0.0) / (double) (numVals - 1);
        }
        return values;
    }

    /**
     * For every parameter, writes the forest-averaged marginal prediction (mean and population
     * standard deviation over all trees) at each plot point. Output goes to
     * {@code <outputDir>/allSingleMarginals.txt} and {@code <outputDir>/<param>.marginals}.
     * It then writes a gnuplot script per parameter and runs gnuplot on it. All files are
     * closed even if an error occurs.
     */
    private static void writeMarginalPlots(RandomForest forest, ParamConfigurationSpace configSpace,
            int[] categoricalSizes, String outputDir) throws IOException, InterruptedException {
        String allSingleMarginalsOutputFile = outputDir + "/allSingleMarginals.txt";
        PrintWriter writer = new PrintWriter(allSingleMarginalsOutputFile, "UTF-8");
        try {
            for (int dim = 0; dim < categoricalSizes.length; dim++) {
                String parameterName = (String) configSpace.getParameterNames().get(dim);
                String parameterFileName = parameterName.replace('/', '_');
                boolean categorical = categoricalSizes[dim] > 0;
                double[] valuesToPredictFor = plotGrid(categoricalSizes[dim]);
                List<?> categoricalValues = categorical ? (List<?>) configSpace.getValuesMap().get(parameterName) : null;
                NormalizedRange normalizedRange = categorical ? null : (NormalizedRange) configSpace.getNormalizedRangeMap().get(parameterName);

                String singleMarginalOutputFile = outputDir + "/" + parameterFileName + ".marginals";
                PrintWriter singleWriter = new PrintWriter(singleMarginalOutputFile, "UTF-8");
                try {
                    int[] indicesOfObservations = new int[]{dim};
                    for (int valIndex = 0; valIndex < valuesToPredictFor.length; valIndex++) {
                        double[] observations = new double[]{valuesToPredictFor[valIndex]};
                        double[] margs = new double[forest.Trees.length];
                        for (int t = 0; t < margs.length; t++) {
                            margs[t] = forest.Trees[t].marginalPerformance(indicesOfObservations, observations);
                        }
                        double avg = 0.0;
                        for (double marg : margs) {
                            avg += marg / (double) margs.length;
                        }
                        double std = 0.0;
                        for (double marg : margs) {
                            std += Math.pow(marg - avg, 2.0) / (double) margs.length;
                        }
                        std = Math.sqrt(std);
                        if (categorical) {
                            String valueName = (String) categoricalValues.get(valIndex);
                            writer.println(valueName + ": " + avg + " +/- " + std);
                            singleWriter.println(valIndex + " " + valueName + " " + avg + " " + std);
                        } else {
                            writer.println(valuesToPredictFor[valIndex] + ": " + avg + " +/- " + std);
                            singleWriter.println(valIndex + " " + normalizedRange.unnormalizeValue(valuesToPredictFor[valIndex]) + " " + avg + " " + std);
                        }
                    }
                } finally {
                    singleWriter.close();
                }
                writer.println();

                String gnuplotFile = outputDir + "/" + parameterFileName + ".gnuplot";
                String pdfFilename = outputDir + "/" + parameterFileName + "-marginals.pdf";
                writeGnuplotScript(gnuplotFile, pdfFilename, singleMarginalOutputFile);
                runGnuplot(gnuplotFile);
            }
        } finally {
            writer.close();
        }
        log.info("\nSingle marginal predictions written to file " + allSingleMarginalsOutputFile);
    }

    /** Writes a gnuplot script that renders {@code dataFile} as a bar chart with error bars into {@code pdfFilename}. */
    private static void writeGnuplotScript(String gnuplotFile, String pdfFilename, String dataFile) throws IOException {
        PrintWriter script = new PrintWriter(gnuplotFile, "UTF-8");
        try {
            script.println("set xtic rotate by -45");
            script.println("set terminal pdf");
            script.println("set output '" + pdfFilename + "'");
            script.println("set boxwidth 0.75");
            script.println("set style fill solid");
            script.println("plot '" + dataFile + "' using 3:xtic(2) with boxes fillstyle solid notitle, \\");
            script.println("'" + dataFile + "' using 1:3:4 with boxerrorbars fillstyle empty lc rgb 'black' notitle");
        } finally {
            script.close();
        }
    }

    /**
     * Runs {@code gnuplot <gnuplotFile>} and waits for it to finish. The path is passed as a
     * single argument, so paths containing spaces work. The process output is drained so the
     * child cannot block on a full pipe. A non-zero exit code is logged as a warning.
     */
    private static void runGnuplot(String gnuplotFile) throws IOException, InterruptedException {
        ProcessBuilder builder = new ProcessBuilder("gnuplot", gnuplotFile);
        builder.redirectErrorStream(true);
        Process process = builder.start();
        process.getOutputStream().close();
        java.io.InputStream output = process.getInputStream();
        try {
            byte[] buffer = new byte[4096];
            while (output.read(buffer) != -1) {
                // Output is intentionally discarded; only the exit code is checked.
            }
        } finally {
            output.close();
        }
        int exitCode = process.waitFor();
        if (exitCode != 0) {
            log.warn("gnuplot exited with code " + exitCode + " for script " + gnuplotFile);
        }
    }

    /**
     * Writes {@code <outputDir>/allSingleMarginals.tex}. The document contains one figure per
     * main-effect parameter, in the given order. A {@code \clearpage} follows the 1st, 5th,
     * 9th, ... figure.
     */
    private static void writeTexReport(String outputDir, List<String> paramNamesOrderByMarginalVarianceExplained,
            ParamConfigurationSpace configSpace, Map<HashSet<Integer>, Double> totalFractionsExplained,
            DecimalFormat decim2) throws IOException {
        log.info("Generating .tex file ...");
        String texFile = outputDir + "/allSingleMarginals.tex";
        PrintWriter writer = new PrintWriter(texFile, "UTF-8");
        try {
            writer.println("\\documentclass[letterpaper]{article}");
            writer.println("\\usepackage{times}");
            writer.println("\\usepackage{graphicx}");
            writer.println("\\usepackage{epsfig}");
            writer.println("\\usepackage{subfigure}");
            writer.println("\\usepackage{lscape}");
            writer.println("\\begin{document}");
            writer.println("\\title{Functional ANOVA Analysis}");
            writer.println("\\maketitle");
            writer.println("\\textbf{When using parts of this document, please cite the functional ANOVA paper (the reference will be available soon; in the meantime please ask Frank for details).}");
            int count = 0;
            for (String parameterName : paramNamesOrderByMarginalVarianceExplained) {
                String parameterFileName = parameterName.replace('/', '_');
                // NOTE(review): keys were built from getParameterNames() indices; this lookup assumes
                // getParameterNamesInAuthorativeOrder() uses the same order -- confirm.
                int indexOfParameter = configSpace.getParameterNamesInAuthorativeOrder().indexOf(parameterName);
                HashSet<Integer> set = new HashSet<Integer>();
                set.add(indexOfParameter);
                double fractionExplainedByThisMarginal = totalFractionsExplained.get(set);
                writer.println("\\begin{figure}[tbp]");
                writer.println("\\begin{center}");
                String singleMarginalOutputFile = outputDir + "/" + parameterFileName + "-marginals.pdf";
                writer.println("\\includegraphics{" + singleMarginalOutputFile + "}");
                writer.println("\\caption{Marginal predictions for parameter " + parameterName.replace("_", "\\_") + ". This marginal explains " + decim2.format(fractionExplainedByThisMarginal) + "\\% of the predictor's total variance.  \\label{fig:" + parameterName + "}}");
                writer.println("\\end{center}");
                writer.println("\\end{figure}");
                if (count++ % 4 == 0) {
                    writer.println("\\clearpage");
                }
            }
            writer.println("\\end{document}");
        } finally {
            writer.close();
        }
        String cmd = "pdflatex " + texFile + " > pdflatex-output.txt";
        log.info("Need to call: " + cmd);
    }
}

