/*
 * Decompiled with CFR 0.152.
 */
package net.aclib.fanova.model;

import ca.ubc.cs.beta.aclib.algorithmrun.AlgorithmRun;
import ca.ubc.cs.beta.aclib.configspace.ParamConfigurationSpace;
import ca.ubc.cs.beta.models.fastrf.RandomForest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import net.aclib.fanova.eval.ModelEvaluation;
import net.aclib.fanova.model.RandomForestPreprocessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Functional ANOVA decomposition of the variance of a random forest's marginal predictions.
 *
 * <p>For every tree, the constructor discretises each input dimension into the cells that tree can
 * distinguish:
 * <ul>
 *   <li>categorical dimensions get one cell per value index, each of weight {@code 1/numVals};</li>
 *   <li>continuous dimensions get the intervals between the tree's sorted split points on that
 *       dimension, bounded by 0.0 and 1.0, each represented by its midpoint and weighted by its
 *       width.</li>
 * </ul>
 *
 * <p>{@link #getMarginal(int)} and {@link #getPairwiseMarginal(int, int)} compute, per tree, the
 * weighted variance of the tree's marginal prediction over that grid, expressed as a percentage of
 * the tree's total variance. They then average these percentages over the trees.
 *
 * <p>Trees whose total predictor variance is exactly 0.0 have no defined "fraction of variance
 * explained". They are skipped and excluded from those averages. Results are memoised per
 * dimension set in plain {@link HashMap}s without synchronisation.
 */
public class FunctionalANOVAVarianceDecompose {
    private static final Logger log = LoggerFactory.getLogger(FunctionalANOVAVarianceDecompose.class);

    /** The marginal forest the decomposition is computed on. */
    private final RandomForest forest;
    /** Number of input dimensions, taken from {@code configSpace.getCategoricalSize().length}. */
    private final int numDim;
    /** Per tree and dimension: the representative value of each grid cell (null for skipped trees). */
    private final double[][][] allObservations;
    /** Per tree and dimension: the weight (size) of each grid cell, parallel to {@link #allObservations}. */
    private final double[][][] allIntervalSizes;
    /** Total variance of each tree's predictor; 0.0 marks a tree that is skipped. */
    private final double[] thisTreeTotalVariance;
    /** Number of trees with non-zero total variance, i.e. the trees that enter the averages. */
    private final int numInformativeTrees;
    /** Per tree: dimension index to that dimension's unnormalised single variance contribution. */
    private final List<HashMap<Integer, Double>> singleVarianceContributions;
    /** Memoised forest-averaged percentages of variance explained, keyed by the set of dimensions. */
    private final HashMap<HashSet<Integer>, Double> totalFractionsExplained = new HashMap<HashSet<Integer>, Double>();

    /**
     * Builds the marginal forest and precomputes, for every tree with non-zero variance, the grid of
     * cells each dimension is evaluated on.
     *
     * @param existingForest      forest handed to {@code ModelEvaluation.extractMarginalForest}
     * @param testRuns            runs handed to {@code ModelEvaluation.extractMarginalForest}
     * @param configSpace         configuration space; {@code getCategoricalSize()[d] > 0} marks
     *                            dimension {@code d} as categorical with that many values
     * @param rand                random source handed to {@code ModelEvaluation.extractMarginalForest}
     * @param compareToDef        flag handed to {@code ModelEvaluation.extractMarginalForest}
     * @param quantileToCompareTo value handed to {@code ModelEvaluation.extractMarginalForest}
     * @param logModel            currently not used by this constructor
     * @throws IOException          propagated from the forest extraction / preprocessing
     * @throws InterruptedException propagated from the forest extraction / preprocessing
     */
    public FunctionalANOVAVarianceDecompose(RandomForest existingForest, List<AlgorithmRun> testRuns, ParamConfigurationSpace configSpace, Random rand, boolean compareToDef, double quantileToCompareTo, boolean logModel) throws IOException, InterruptedException {
        this.forest = ModelEvaluation.extractMarginalForest(existingForest, testRuns, configSpace, rand, compareToDef, quantileToCompareTo);
        RandomForestPreprocessor.preprocessRandomForest(this.forest, configSpace);

        int numTrees = this.forest.Trees.length;
        this.numDim = configSpace.getCategoricalSize().length;
        this.allObservations = new double[numTrees][this.numDim][];
        this.allIntervalSizes = new double[numTrees][this.numDim][];
        this.thisTreeTotalVariance = new double[numTrees];
        // Sized by Trees.length (not forest.numTrees) so it matches every per-tree loop below.
        this.singleVarianceContributions = new ArrayList<HashMap<Integer, Double>>(numTrees);

        int informative = 0;
        for (int t = 0; t < numTrees; t++) {
            this.singleVarianceContributions.add(new HashMap<Integer, Double>());
            double totalVariance = RandomForestPreprocessor.computeTotalVarianceOfRegressionTree(this.forest.Trees[t], configSpace);
            this.thisTreeTotalVariance[t] = totalVariance;
            if (totalVariance == 0.0) {
                // Grids stay null for this tree; the marginal computations skip it via isInformative.
                log.info("Tree {} has no variance -> skipping.", t);
                continue;
            }
            log.info("Tree {}: Total variance of predictor: {}", t, totalVariance);
            informative++;
            for (int dim = 0; dim < this.numDim; dim++) {
                int numVals = configSpace.getCategoricalSize()[dim];
                if (numVals > 0) {
                    buildCategoricalGrid(t, dim, numVals);
                } else {
                    buildContinuousGrid(t, dim);
                }
            }
        }
        this.numInformativeTrees = informative;
    }

    /** Fills the grid of categorical dimension {@code dim} for tree {@code t}: cell v has value v and weight 1/numVals. */
    private void buildCategoricalGrid(int t, int dim, int numVals) {
        double[] observations = new double[numVals];
        double[] sizes = new double[numVals];
        for (int v = 0; v < numVals; v++) {
            observations[v] = v;
            sizes[v] = 1.0 / (double) numVals;
        }
        this.allObservations[t][dim] = observations;
        this.allIntervalSizes[t][dim] = sizes;
    }

    /**
     * Fills the grid of continuous dimension {@code dim} for tree {@code t} from the tree's split
     * points on that dimension.
     *
     * <p>Cells are the intervals between consecutive sorted cut points, bounded by 0.0 and 1.0. The
     * bounds presumably correspond to the normalised parameter range; confirm against
     * RandomForestPreprocessor. Each cell is represented by its midpoint and weighted by its width.
     * A dimension the tree never splits on gets an empty grid and therefore contributes no variance.
     */
    private void buildContinuousGrid(int t, int dim) {
        ArrayList<Double> splitPoints = new ArrayList<Double>();
        // The tree marks a node that splits on dimension `dim` with var == dim + 1.
        for (int node = 0; node < this.forest.Trees[t].var.length; node++) {
            if (this.forest.Trees[t].var[node] == dim + 1) {
                splitPoints.add(this.forest.Trees[t].cut[node]);
            }
        }
        if (splitPoints.isEmpty()) {
            this.allObservations[t][dim] = new double[0];
            this.allIntervalSizes[t][dim] = new double[0];
            return;
        }
        splitPoints.add(0.0);
        splitPoints.add(1.0);
        Collections.sort(splitPoints);

        int numCells = splitPoints.size() - 1;
        double[] observations = new double[numCells];
        double[] sizes = new double[numCells];
        for (int i = 0; i < numCells; i++) {
            double lower = splitPoints.get(i);
            double upper = splitPoints.get(i + 1);
            observations[i] = (lower + upper) / 2.0;
            sizes[i] = upper - lower;
        }
        this.allObservations[t][dim] = observations;
        this.allIntervalSizes[t][dim] = sizes;
    }

    /** Returns whether tree {@code t} has non-zero total variance, i.e. takes part in the decomposition. */
    private boolean isInformative(int t) {
        return this.thisTreeTotalVariance[t] != 0.0;
    }

    /** Rejects dimension indices outside {@code [0, numDim)} with a descriptive message. */
    private void checkDimension(int dim) {
        if (dim < 0 || dim >= this.numDim) {
            throw new IllegalArgumentException("Dimension " + dim + " out of range [0, " + this.numDim + ").");
        }
    }

    /**
     * Returns the percentage of variance explained by dimension {@code dim} alone, averaged over
     * the trees with non-zero variance. The result is memoised. Returns 0.0 when no tree has
     * non-zero variance.
     *
     * <p>Side effect: records each informative tree's unnormalised single contribution, which
     * {@link #getPairwiseMarginal(int, int)} subtracts later.
     *
     * @throws IllegalArgumentException if {@code dim} is not a valid dimension index
     * @throws RuntimeException         if a tree's contribution evaluates to NaN
     */
    public double getMarginal(int dim) {
        checkDimension(dim);
        HashSet<Integer> set = new HashSet<Integer>();
        set.add(dim);
        Double cached = this.totalFractionsExplained.get(set);
        if (cached != null) {
            return cached;
        }

        int[] indicesOfObservations = new int[]{dim};
        // Accumulate locally so an exception part-way through never leaves a partial value cached.
        double total = 0.0;
        for (int t = 0; t < this.forest.Trees.length; t++) {
            if (!isInformative(t)) {
                continue;
            }
            double[] values = this.allObservations[t][dim];
            double[] sizes = this.allIntervalSizes[t][dim];
            double weightedSum = 0.0;
            double weightedSumOfSquares = 0.0;
            for (int v = 0; v < values.length; v++) {
                double marg = this.forest.Trees[t].marginalPerformance(indicesOfObservations, new double[]{values[v]});
                weightedSum += marg * sizes[v];
                weightedSumOfSquares += marg * marg * sizes[v];
            }
            // Weighted variance over the grid: E[m^2] - E[m]^2 (cell weights per dimension sum to 1).
            double contribution = weightedSumOfSquares - weightedSum * weightedSum;
            double fraction = contribution / this.thisTreeTotalVariance[t] * 100.0;
            if (Double.isNaN(fraction)) {
                throw new RuntimeException("ERROR - variance contributions is NaN.");
            }
            this.singleVarianceContributions.get(t).put(dim, contribution);
            total += 1.0 / (double) this.numInformativeTrees * fraction;
        }
        this.totalFractionsExplained.put(set, total);
        return total;
    }

    /**
     * Returns the percentage of variance explained by the interaction of {@code dim1} and
     * {@code dim2}, averaged over the trees with non-zero variance. The interaction is the joint
     * marginal variance minus both single-dimension contributions. The result is memoised. Returns
     * 0.0 when no tree has non-zero variance.
     *
     * @throws IllegalArgumentException if either index is invalid or {@code dim1 == dim2}. The
     *         latter previously overwrote the cached single marginal of that dimension.
     */
    public double getPairwiseMarginal(int dim1, int dim2) {
        checkDimension(dim1);
        checkDimension(dim2);
        if (dim1 == dim2) {
            throw new IllegalArgumentException("Pairwise marginal needs two distinct dimensions, got " + dim1 + " twice.");
        }
        HashSet<Integer> set = new HashSet<Integer>();
        set.add(dim1);
        set.add(dim2);
        Double cached = this.totalFractionsExplained.get(set);
        if (cached != null) {
            return cached;
        }

        // Make sure the single contributions are recorded for every informative tree (memoised, so cheap if done).
        getMarginal(dim1);
        getMarginal(dim2);

        int[] indicesOfObservations = new int[]{dim1, dim2};
        double total = 0.0;
        for (int t = 0; t < this.forest.Trees.length; t++) {
            if (!isInformative(t)) {
                continue;
            }
            double[] values1 = this.allObservations[t][dim1];
            double[] sizes1 = this.allIntervalSizes[t][dim1];
            double[] values2 = this.allObservations[t][dim2];
            double[] sizes2 = this.allIntervalSizes[t][dim2];
            double weightedSum = 0.0;
            double weightedSumOfSquares = 0.0;
            for (int v1 = 0; v1 < values1.length; v1++) {
                for (int v2 = 0; v2 < values2.length; v2++) {
                    double[] observations = new double[]{values1[v1], values2[v2]};
                    double marg = this.forest.Trees[t].marginalPerformance(indicesOfObservations, observations);
                    weightedSum += marg * sizes1[v1] * sizes2[v2];
                    weightedSumOfSquares += marg * marg * sizes1[v1] * sizes2[v2];
                }
            }
            double contribution = weightedSumOfSquares - weightedSum * weightedSum;
            contribution -= this.singleVarianceContributions.get(t).get(dim1).doubleValue();
            contribution -= this.singleVarianceContributions.get(t).get(dim2).doubleValue();
            double fraction = contribution / this.thisTreeTotalVariance[t] * 100.0;
            total += 1.0 / (double) this.numInformativeTrees * fraction;
        }
        this.totalFractionsExplained.put(set, total);
        return total;
    }

    /**
     * Returns {mean, population standard deviation} over all trees of the marginal prediction at
     * {@code dim = valueToPredict}.
     */
    public double[] getMarginalForValue(int dim, double valueToPredict) {
        return meanAndStdOverTrees(new int[]{dim}, new double[]{valueToPredict});
    }

    /**
     * Returns {mean, population standard deviation} over all trees of the marginal prediction at
     * {@code dim1 = valueToPredict1, dim2 = valueToPredict2}.
     */
    public double[] getMarginalForValuePair(int dim1, int dim2, double valueToPredict1, double valueToPredict2) {
        return meanAndStdOverTrees(new int[]{dim1, dim2}, new double[]{valueToPredict1, valueToPredict2});
    }

    /**
     * Evaluates every tree's marginal prediction at the given point and returns {mean, population
     * standard deviation}. All trees are included, even zero-variance ones, since they are still
     * valid (constant) predictors. An empty forest yields {0.0, 0.0}.
     */
    private double[] meanAndStdOverTrees(int[] indicesOfObservations, double[] observations) {
        int numTrees = this.forest.Trees.length;
        double[] margs = new double[numTrees];
        for (int t = 0; t < numTrees; t++) {
            margs[t] = this.forest.Trees[t].marginalPerformance(indicesOfObservations, observations);
        }
        double avg = 0.0;
        for (double marg : margs) {
            avg += marg / (double) numTrees;
        }
        double variance = 0.0;
        for (double marg : margs) {
            variance += Math.pow(marg - avg, 2.0) / (double) numTrees;
        }
        return new double[]{avg, Math.sqrt(variance)};
    }
}

