/*
 * Decompiled with CFR 0.152.
 */
package ca.ubc.cs.beta.aclib.model.builder;

import ca.ubc.cs.beta.aclib.misc.model.SMACRandomForestHelper;
import ca.ubc.cs.beta.aclib.misc.watch.StopWatch;
import ca.ubc.cs.beta.aclib.model.builder.ModelBuilder;
import ca.ubc.cs.beta.aclib.model.data.SanitizedModelData;
import ca.ubc.cs.beta.aclib.options.RandomForestOptions;
import ca.ubc.cs.beta.models.fastrf.RandomForest;
import ca.ubc.cs.beta.models.fastrf.RegtreeBuildParams;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a {@link RandomForest} from {@link SanitizedModelData} and, when
 * {@code rfConfig.preprocessMarginal} is set, a preprocessed copy of that forest.
 *
 * <p>All work happens in the constructor; the results are exposed through
 * {@link #getRandomForest()} and {@link #getPreparedRandomForest()}.
 */
public class BasicModelBuilder
implements ModelBuilder {
    protected final RandomForest forest;
    /** Preprocessed forest, or {@code null} when {@code rfConfig.preprocessMarginal} is false. */
    protected final RandomForest preprocessedForest;
    private final Logger log = LoggerFactory.getLogger(this.getClass());

    /**
     * Builds a forest using all data points (no subsampling).
     *
     * @param smd      sanitized training data
     * @param rfConfig random forest options
     * @param rand     source of randomness passed to the build parameters
     */
    public BasicModelBuilder(SanitizedModelData smd, RandomForestOptions rfConfig, Random rand) {
        this(smd, rfConfig, 1.0, rand);
    }

    /**
     * Builds a forest, optionally fitting each tree on a random subsample of the data.
     *
     * @param smd                 sanitized training data; its arrays are not modified
     * @param rfConfig            random forest options
     * @param subsamplePercentage fraction of data points drawn (with replacement) per tree when
     *                            below 1.0; ignored when {@code rfConfig.fullTreeBootstrap} is set
     * @param rand                source of randomness passed to the build parameters
     */
    public BasicModelBuilder(SanitizedModelData smd, RandomForestOptions rfConfig, double subsamplePercentage, Random rand) {
        double[][] features = smd.getPCAFeatures();
        double[][] configs = smd.getConfigs();
        double[] responseValues = smd.getResponseValues();
        int[] categoricalSize = smd.getCategoricalSize();
        int[][] condParents = smd.getCondParents();
        int[][][] condParentVals = smd.getCondParentVals();
        int numTrees = rfConfig.numTrees;
        // Work on a copy so the caller's index arrays are never shifted (possibly repeatedly).
        int[][] thetaInstIdxs = toZeroBasedCopy(smd.getThetaInstIdxs());

        RegtreeBuildParams buildParams = SMACRandomForestHelper.getRandomForestBuildParams(rfConfig, features[0].length, categoricalSize, condParents, condParentVals, rand);
        this.log.trace("Building Random Forest with {} data points ", responseValues.length);

        int totalPoints = responseValues.length;
        StopWatch sw = new StopWatch();
        if (rfConfig.fullTreeBootstrap) {
            // Every tree receives each data point exactly once.
            int[][] dataIdxs = identityIndices(numTrees, totalPoints);
            sw.start();
            this.forest = RandomForest.learnModel(numTrees, configs, features, thetaInstIdxs, responseValues, dataIdxs, buildParams);
        } else if (subsamplePercentage < 1.0) {
            int sampleSize = (int) (subsamplePercentage * (double) totalPoints);
            this.log.trace("Subsampling {} points out of {} total", sampleSize, totalPoints);
            // Draw from the WHOLE dataset; drawing from [0, sampleSize) would only ever use the first points.
            int[][] dataIdxs = sampleIndices(numTrees, sampleSize, totalPoints, buildParams.random);
            sw.start();
            this.forest = RandomForest.learnModel(numTrees, configs, features, thetaInstIdxs, responseValues, dataIdxs, buildParams);
        } else {
            sw.start();
            this.forest = RandomForest.learnModel(numTrees, configs, features, thetaInstIdxs, responseValues, buildParams);
        }
        this.log.debug("Building Random Forest took {} seconds ", (double) sw.stop() / 1000.0);

        if (rfConfig.preprocessMarginal) {
            this.log.trace("Preprocessing marginal for Random Forest");
            this.preprocessedForest = RandomForest.preprocessForest(this.forest, features);
        } else {
            this.preprocessedForest = null;
        }
    }

    /**
     * Returns a copy of {@code idxs} with columns 0 and 1 of every row decremented by one.
     * NOTE(review): assumes the source rows hold 1-based indices in their first two entries,
     * as the original in-place decrement implied; confirm against SanitizedModelData.
     */
    private static int[][] toZeroBasedCopy(int[][] idxs) {
        int[][] copy = new int[idxs.length][];
        for (int i = 0; i < idxs.length; ++i) {
            int[] row = idxs[i].clone();
            row[0] -= 1;
            row[1] -= 1;
            copy[i] = row;
        }
        return copy;
    }

    /** Returns {@code numTrees} rows, each containing {@code 0 .. numPoints-1} in order. */
    private static int[][] identityIndices(int numTrees, int numPoints) {
        int[][] idxs = new int[numTrees][numPoints];
        for (int[] treeIdxs : idxs) {
            for (int j = 0; j < numPoints; ++j) {
                treeIdxs[j] = j;
            }
        }
        return idxs;
    }

    /** Returns {@code numTrees} rows of {@code sampleSize} indices drawn uniformly with replacement from {@code [0, populationSize)}. */
    private static int[][] sampleIndices(int numTrees, int sampleSize, int populationSize, Random random) {
        int[][] idxs = new int[numTrees][sampleSize];
        for (int[] treeIdxs : idxs) {
            for (int j = 0; j < sampleSize; ++j) {
                treeIdxs[j] = random.nextInt(populationSize);
            }
        }
        return idxs;
    }

    /** Returns the forest built by the constructor. */
    @Override
    public RandomForest getRandomForest() {
        return this.forest;
    }

    /** Returns the preprocessed forest, or {@code null} if marginal preprocessing was disabled. */
    @Override
    public RandomForest getPreparedRandomForest() {
        return this.preprocessedForest;
    }
}

