/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.optimization.StochasticCalculateMethods;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Sanity checks for an {@link AbstractStochasticCachingDiffFunction}.
 *
 * <p>The data set is split into {@code numBatches} ordered batches of {@code testBatchSize}
 * examples each, where the batch size is the data dimension divided by its largest prime factor.
 * The tests compare per-batch computations against full-data ones:
 * <ul>
 *   <li>{@link #testSumOfBatches}: summed batch values/gradients vs. the full value/gradient;</li>
 *   <li>{@link #testDerivatives}: {@code HdotVAt} under the function's own calculate method vs.
 *       under {@code ExternalFiniteDifference};</li>
 *   <li>{@link #testConditionNumber}: ratio of the largest to smallest sampled |v . Hv|;</li>
 *   <li>{@link #getVariance} / {@link #testVariance}: spread of rescaled batch products around the
 *       full-data product.</li>
 * </ul>
 *
 * <p>Diagnostics go to {@code System.err}. Several failure paths call {@code System.exit(1)}.
 */
public class StochasticDiffFunctionTester {
    // NOTE(review): EPS, diffValue, diffGrad and maxGradDiff are not read anywhere in this class.
    private static double EPS = 1.0E-8;
    // When true, sayln() output is suppressed.
    private static boolean quiet = false;
    protected int testBatchSize;
    protected int numBatches;
    protected AbstractStochasticCachingDiffFunction thisFunc;
    double[] approxGrad;
    double[] fullGrad;
    double[] diff;
    double[] Hv;
    double[] HvFD;
    // NOTE(review): never assigned in this class; it is passed as the direction argument of
    // valueAt/derivativeAt/HdotVAt and is therefore null unless other package code sets it.
    double[] v;
    double[] curGrad;
    double[] gradFD;
    double diffNorm;
    double diffValue;
    double fullValue;
    double approxValue;
    double diffGrad;
    double maxGradDiff = 0.0;
    double maxHvDiff = 0.0;
    Random generator;
    private static NumberFormat nf = new DecimalFormat("00.0");

    /**
     * Wraps {@code function} for testing and derives the batch layout from its data dimension.
     * Exits the JVM if no usable batch size can be derived.
     *
     * @param function the function under test; must be an {@link AbstractStochasticCachingDiffFunction}
     * @throws UnsupportedOperationException if {@code function} is not an
     *     {@link AbstractStochasticCachingDiffFunction}
     */
    public StochasticDiffFunctionTester(Function function) {
        if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
            System.err.println("Attempt to test non stochastic function using StochasticDiffFunctionTester");
            throw new UnsupportedOperationException();
        }
        this.thisFunc = (AbstractStochasticCachingDiffFunction) function;
        this.generator = new Random(System.currentTimeMillis());
        int dataDim = this.thisFunc.dataDimension();
        this.testBatchSize = (int) getTestBatchSize(dataDim);
        // Check for non-positive sizes first so the modulus below can never divide by zero.
        if (this.testBatchSize <= 0 || this.testBatchSize > dataDim || dataDim % this.testBatchSize != 0) {
            System.err.println("Invalid testBatchSize found, testing aborted.  Data size: " + dataDim + " batchSize: " + this.testBatchSize);
            System.exit(1);
        }
        this.numBatches = dataDim / this.testBatchSize;
        this.sayln("StochasticDiffFunctionTester created with:");
        this.sayln("   data dimension  = " + dataDim);
        this.sayln("   batch size = " + this.testBatchSize);
        this.sayln("   number of batches = " + this.numBatches);
    }

    /** Prints {@code s} to {@code System.err} unless {@code quiet} is set. */
    private void sayln(String s) {
        if (!quiet) {
            System.err.println(s);
        }
    }

    /**
     * Factors {@code |N|} into primes by trial division.
     *
     * @param N the number to factor
     * @return an array whose element 0 is the number of prime factors found and whose elements
     *     {@code 1..count} are those factors in non-decreasing order (with multiplicity). The count
     *     is 0 for {@code |N| <= 1}, and also for {@code Long.MIN_VALUE}, whose absolute value
     *     overflows to a negative number.
     */
    private static long[] primeFactors(long N) {
        // 64 slots: index 0 holds the count, and a positive long has at most 62 prime factors (2^62).
        long[] fctr = new long[64];
        long n = Math.abs(N);
        int fctrIndex = 0;
        if (n > 0L) {
            while (n % 2L == 0L) {
                fctr[++fctrIndex] = 2L;
                n /= 2L;
            }
            while (n % 3L == 0L) {
                fctr[++fctrIndex] = 3L;
                n /= 3L;
            }
            // Remaining candidates have the form 6m-1 / 6m+1 (k and k+2). The test k <= n / k is
            // k * k <= n without the risk of overflow.
            for (long k = 5L; k <= n / k; k += 6L) {
                for (long dvsr = k; dvsr <= k + 2L; dvsr += 2L) {
                    while (n % dvsr == 0L) {
                        fctr[++fctrIndex] = dvsr;
                        n /= dvsr;
                    }
                }
            }
            // Whatever is left after trial division up to sqrt(n) is itself prime.
            if (n > 1L) {
                fctr[++fctrIndex] = n;
            }
        }
        fctr[0] = fctrIndex;
        return fctr;
    }

    /**
     * Chooses the test batch size: {@code size} divided by its largest prime factor, i.e. the
     * product of all prime factors except the last (largest) one. Exits the JVM when {@code size}
     * has at most one prime factor (a prime, 0 or 1), because that would force a batch size of 1.
     *
     * @param size the data dimension
     * @return the batch size to test with
     */
    private static long getTestBatchSize(long size) {
        long[] factors = primeFactors(size);
        long factorCount = factors[0];
        // A prime has exactly one prime factor, so the degenerate case is count <= 1 (not == 0).
        if (factorCount <= 1L) {
            System.err.println("Attempt to test function on data of prime dimension.  This would involve a batchSize of 1 and may take a very long time.");
            System.exit(1);
        }
        long testBatchSize = 1L;
        for (int f = 1; f < factorCount; ++f) {
            testBatchSize *= factors[f];
        }
        return testBatchSize;
    }

    /**
     * Checks that summing {@code valueAt}/{@code derivativeAt} over all ordered batches reproduces
     * the full-data value and gradient at {@code x}. The function's {@code sampleMethod} and
     * {@code method} are restored before returning, even if an exception is thrown.
     *
     * @param x the point at which to evaluate the function
     * @param functionTolerance the largest allowed {@code ArrayMath.norm_inf} of the gradient
     *     difference, and the largest allowed absolute value difference
     * @return true only if both the summed gradient and the summed value agree within tolerance
     */
    public boolean testSumOfBatches(double[] x, double functionTolerance) {
        System.err.println("Making sure that the sum of stochastic gradients equals the full gradient");
        AbstractStochasticCachingDiffFunction.SamplingMethod tmpSampleMethod = this.thisFunc.sampleMethod;
        StochasticCalculateMethods tmpMethod = this.thisFunc.method;
        // Ordered sampling, so that numBatches consecutive batches are the ones being summed.
        this.thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;
        try {
            if (this.thisFunc.method == StochasticCalculateMethods.NoneSpecified) {
                System.err.println("No calculate method has been specified");
            }
            this.approxValue = 0.0;
            this.approxGrad = new double[x.length];
            this.curGrad = new double[x.length];
            this.fullGrad = new double[x.length];
            for (int i = 0; i < this.numBatches; ++i) {
                double percent = 100.0 * (double) i / (double) this.numBatches;
                this.approxValue += this.thisFunc.valueAt(x, this.v, this.testBatchSize);
                // NOTE(review): presumably makes the following derivativeAt reuse the batch just
                // evaluated rather than drawing a new one; confirm in AbstractStochasticCachingDiffFunction.
                this.thisFunc.returnPreviousValues = true;
                System.arraycopy(this.thisFunc.derivativeAt(x, this.v, this.testBatchSize), 0, this.curGrad, 0, this.curGrad.length);
                this.approxGrad = ArrayMath.pairwiseAdd(this.approxGrad, this.curGrad);
                double norm = ArrayMath.norm(this.approxGrad);
                System.err.printf("%5.1f percent complete  %6.2f \n", percent, norm);
            }
            System.err.println("About to calculate the full derivative and value");
            System.arraycopy(this.thisFunc.derivativeAt(x), 0, this.fullGrad, 0, this.fullGrad.length);
            this.thisFunc.returnPreviousValues = true;
            this.fullValue = this.thisFunc.valueAt(x);

            this.diff = ArrayMath.pairwiseSubtract(this.fullGrad, this.approxGrad);
            boolean gradientsMatch;
            if (ArrayMath.norm_inf(this.diff) < functionTolerance) {
                this.sayln("");
                this.sayln("Success: sum of batch gradients equals full gradient");
                gradientsMatch = true;
            } else {
                this.diffNorm = ArrayMath.norm(this.diff);
                this.sayln("");
                this.sayln("Failure: sum of batch gradients minus full gradient has norm " + this.diffNorm);
                gradientsMatch = false;
            }

            double valueGap = Math.abs(this.approxValue - this.fullValue);
            boolean valuesMatch;
            if (valueGap < functionTolerance) {
                this.sayln("");
                this.sayln("Success: sum of batch values equals full value");
                valuesMatch = true;
            } else {
                this.sayln("");
                this.sayln("Failure: sum of batch values minus full value has norm " + valueGap);
                valuesMatch = false;
            }
            // The test passes only when both the gradient check and the value check pass.
            return gradientsMatch && valuesMatch;
        } finally {
            this.thisFunc.sampleMethod = tmpSampleMethod;
            this.thisFunc.method = tmpMethod;
        }
    }

    /**
     * For each ordered batch, compares {@code HdotVAt} computed with the function's configured
     * calculate method against {@code HdotVAt} computed from an {@code ExternalFiniteDifference}
     * gradient. The largest {@code ArrayMath.norm_inf} difference seen in this call is stored in
     * {@code maxHvDiff}. The function's {@code sampleMethod} and {@code method} are restored
     * before returning.
     *
     * @param x the point at which to evaluate
     * @param functionTolerance the largest allowed per-batch Hessian-product discrepancy
     * @return true if every batch's discrepancy is below {@code functionTolerance}
     */
    public boolean testDerivatives(double[] x, double functionTolerance) {
        boolean compareHess = true;
        System.err.println("Making sure that the stochastic derivatives are ok.");
        AbstractStochasticCachingDiffFunction.SamplingMethod tmpSampleMethod = this.thisFunc.sampleMethod;
        StochasticCalculateMethods tmpMethod = this.thisFunc.method;
        this.thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;
        try {
            if (this.thisFunc.method == StochasticCalculateMethods.NoneSpecified) {
                System.err.println("No calculate method has been specified");
            } else if (!this.thisFunc.method.calculatesHessianVectorProduct()) {
                // NOTE(review): compareHess is computed but never consulted; presumably the
                // Hessian comparison was meant to be skipped when it is false.
                compareHess = false;
            }
            this.approxValue = 0.0;
            this.approxGrad = new double[x.length];
            this.curGrad = new double[x.length];
            this.Hv = new double[x.length];
            // These are System.arraycopy destinations below, so they must be allocated here.
            this.gradFD = new double[x.length];
            this.HvFD = new double[x.length];
            // Reset so that a previous call's worst error cannot leak into this call's verdict.
            this.maxHvDiff = 0.0;
            for (int i = 0; i < this.numBatches; ++i) {
                double percent = 100.0 * (double) i / (double) this.numBatches;
                System.err.printf("%5.1f percent complete\n", percent);
                this.thisFunc.method = tmpMethod;
                System.arraycopy(this.thisFunc.HdotVAt(x, this.v, this.testBatchSize), 0, this.Hv, 0, this.Hv.length);
                this.thisFunc.method = StochasticCalculateMethods.ExternalFiniteDifference;
                System.arraycopy(this.thisFunc.derivativeAt(x, this.v, this.testBatchSize), 0, this.gradFD, 0, this.gradFD.length);
                this.thisFunc.recalculatePrevBatch = true;
                System.arraycopy(this.thisFunc.HdotVAt(x, this.v, this.gradFD, this.testBatchSize), 0, this.HvFD, 0, this.HvFD.length);
                double diffHv = ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(this.Hv, this.HvFD));
                if (diffHv > this.maxHvDiff) {
                    this.maxHvDiff = diffHv;
                }
            }
            if (this.maxHvDiff < functionTolerance) {
                this.sayln("");
                this.sayln("Success: Hessian approximations lined up");
                return true;
            }
            this.sayln("");
            this.sayln("Failure: Hessian approximation at somepoint was off by " + this.maxHvDiff);
            return false;
        } finally {
            this.thisFunc.sampleMethod = tmpSampleMethod;
            this.thisFunc.method = tmpMethod;
        }
    }

    /**
     * Estimates a condition number by sampling. For each of {@code samples} random (x, v) pairs,
     * with each coordinate drawn by {@code Random.nextDouble()}, it computes {@code v . Hv} using
     * finite-difference {@code HdotVAt} and tracks the largest and smallest absolute values. It
     * also reports whether negative, positive or exactly-zero products were seen. The function's
     * {@code method} is restored before returning.
     *
     * @param samples the number of random (x, v) pairs to evaluate
     * @return max |v . Hv| / min |v . Hv| over the samples (0.0 when {@code samples <= 0})
     */
    public double testConditionNumber(int samples) {
        double maxSeen = 0.0;
        // Start at +infinity so that the first sample sets the minimum; starting at 0 pinned it at 0.
        double minSeen = Double.POSITIVE_INFINITY;
        double[] thisV = new double[this.thisFunc.domainDimension()];
        double[] thisX = new double[thisV.length];
        this.gradFD = new double[thisV.length];
        this.HvFD = new double[thisV.length];
        boolean isNeg = false;
        boolean isPos = false;
        boolean isSemi = false;
        StochasticCalculateMethods tmpMethod = this.thisFunc.method;
        this.thisFunc.method = StochasticCalculateMethods.ExternalFiniteDifference;
        try {
            for (int j = 0; j < samples; ++j) {
                for (int i = 0; i < thisV.length; ++i) {
                    thisV[i] = this.generator.nextDouble();
                }
                for (int i = 0; i < thisX.length; ++i) {
                    thisX[i] = this.generator.nextDouble();
                }
                System.err.println("Evaluating Hessian Product");
                System.arraycopy(this.thisFunc.derivativeAt(thisX, thisV, this.testBatchSize), 0, this.gradFD, 0, this.gradFD.length);
                this.thisFunc.recalculatePrevBatch = true;
                System.arraycopy(this.thisFunc.HdotVAt(thisX, thisV, this.gradFD, this.testBatchSize), 0, this.HvFD, 0, this.HvFD.length);
                double thisVHV = ArrayMath.innerProduct(thisV, this.HvFD);
                double magnitude = Math.abs(thisVHV);
                if (magnitude > maxSeen) {
                    maxSeen = magnitude;
                }
                if (magnitude < minSeen) {
                    minSeen = magnitude;
                }
                if (thisVHV < 0.0) {
                    isNeg = true;
                }
                if (thisVHV > 0.0) {
                    isPos = true;
                }
                if (thisVHV == 0.0) {
                    isSemi = true;
                }
                System.err.println("It:" + j + "  C:" + maxSeen / minSeen + "N:" + isNeg + "P:" + isPos + "S:" + isSemi);
            }
        } finally {
            this.thisFunc.method = tmpMethod;
        }
        double conditionNumber = maxSeen / minSeen;
        System.out.println("Condition Number of: " + conditionNumber);
        System.out.println("Is negative: " + isNeg);
        System.out.println("Is positive: " + isPos);
        System.out.println("Is semi:     " + isSemi);
        return conditionNumber;
    }

    /** Same as {@link #getVariance(double[], int)} with the tester's own {@code testBatchSize}. */
    public double[] getVariance(double[] x) {
        return this.getVariance(x, this.testBatchSize);
    }

    /**
     * Measures how closely rescaled batch products {@code HdotVAt(x, x, grad, batchSize)}
     * approximate the full-data product at {@code x}. It draws 100 batches with
     * {@code RandomWithReplacement} sampling and scales each batch product by
     * {@code dataDimension / batchSize}. For each batch it records the similarity
     * {@code innerProduct(batch, full) / (norm(batch) * norm(full))} and the norm ratio
     * {@code norm(batch) / norm(full)}. The function's {@code sampleMethod} is restored before
     * returning.
     *
     * @param x the point (also used as the product direction)
     * @param batchSize the number of examples per sampled batch
     * @return {@code {mean similarity, sample variance of similarity, mean norm ratio,
     *     sample variance of norm ratio}}
     */
    public double[] getVariance(double[] x, int batchSize) {
        double[] fullHx = new double[this.thisFunc.domainDimension()];
        double[] thisHx = new double[x.length];
        double[] thisGrad = new double[x.length];
        AbstractStochasticCachingDiffFunction.SamplingMethod tmpSampleMethod = this.thisFunc.sampleMethod;
        try {
            int dataDim = this.thisFunc.dataDimension();
            this.thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;
            System.arraycopy(this.thisFunc.derivativeAt(x, x, dataDim), 0, thisGrad, 0, thisGrad.length);
            System.arraycopy(this.thisFunc.HdotVAt(x, x, thisGrad, dataDim), 0, fullHx, 0, fullHx.length);
            double fullNorm = ArrayMath.norm(fullHx);
            double hessScale = (double) dataDim / (double) batchSize;
            this.thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.RandomWithReplacement;
            int n = 100;
            double simMean = 0.0;
            double ratMean = 0.0;
            double simS = 0.0;
            double ratS = 0.0;
            for (int k = 1; k <= n; ++k) {
                System.arraycopy(this.thisFunc.derivativeAt(x, x, batchSize), 0, thisGrad, 0, thisGrad.length);
                System.arraycopy(this.thisFunc.HdotVAt(x, x, thisGrad, batchSize), 0, thisHx, 0, thisHx.length);
                ArrayMath.multiplyInPlace(thisHx, hessScale);
                double thisNorm = ArrayMath.norm(thisHx);
                double sim = ArrayMath.innerProduct(thisHx, fullHx) / (thisNorm * fullNorm);
                double rat = thisNorm / fullNorm;
                // Welford's online update: running mean plus running sum of squared deviations.
                double simDelta = sim - simMean;
                simMean += simDelta / (double) k;
                simS += simDelta * (sim - simMean);
                double ratDelta = rat - ratMean;
                ratMean += ratDelta / (double) k;
                ratS += ratDelta * (rat - ratMean);
            }
            return new double[]{simMean, simS / (double) (n - 1), ratMean, ratS / (double) (n - 1)};
        } finally {
            this.thisFunc.sampleMethod = tmpSampleMethod;
        }
    }

    /**
     * Runs {@link #getVariance(double[], int)} for a fixed ladder of batch sizes. Each result is
     * written as a CSV line to {@code var.out} and echoed to {@code System.err}.
     *
     * @param x the point at which to evaluate
     */
    public void testVariance(double[] x) {
        int[] batchSizes = new int[]{10, 20, 35, 50, 75, 150, 300, 500, 750, 1000, 5000, 10000};
        DecimalFormat nf = new DecimalFormat("0.000E0");
        PrintWriter file = openOrExit("var.out");
        try {
            for (int bSize : batchSizes) {
                double[] varResult = this.getVariance(x, bSize);
                file.println(bSize + "," + nf.format(varResult[0]) + "," + nf.format(varResult[1]) + "," + nf.format(varResult[2]) + "," + nf.format(varResult[3]));
                System.err.println("Batch size of: " + bSize + "   " + varResult[0] + "," + nf.format(varResult[1]) + "," + nf.format(varResult[2]) + "," + nf.format(varResult[3]));
            }
        } finally {
            file.close();
        }
    }

    /**
     * Writes each array in {@code thisList} as one line of {@code 0.000E0}-formatted values,
     * each followed by two spaces.
     *
     * @param thisList the rows to write
     * @param fileName the output file; it is overwritten
     */
    public void listToFile(List<double[]> thisList, String fileName) {
        DecimalFormat nf = new DecimalFormat("0.000E0");
        PrintWriter file = openOrExit(fileName);
        try {
            for (double[] element : thisList) {
                for (double val : element) {
                    file.print(nf.format(val) + "  ");
                }
                file.println("");
            }
        } finally {
            file.close();
        }
    }

    /**
     * Writes {@code thisArray} as a single line of {@code 0.000E0}-formatted values, each
     * followed by two spaces, with no trailing newline.
     *
     * @param thisArray the values to write
     * @param fileName the output file; it is overwritten
     */
    public void arrayToFile(double[] thisArray, String fileName) {
        DecimalFormat nf = new DecimalFormat("0.000E0");
        PrintWriter file = openOrExit(fileName);
        try {
            for (double element : thisArray) {
                file.print(nf.format(element) + "  ");
            }
        } finally {
            file.close();
        }
    }

    /**
     * Opens {@code fileName} for auto-flushing writing. If the file cannot be opened, reports the
     * error on {@code System.err} and exits the JVM with status 1.
     */
    private static PrintWriter openOrExit(String fileName) {
        try {
            return new PrintWriter(new FileOutputStream(fileName), true);
        } catch (IOException e) {
            System.err.println("Caught IOException outputing List to file: " + e.getMessage());
            System.exit(1);
            // Not reached: System.exit never returns normally. This only satisfies the compiler.
            throw new IllegalStateException("System.exit(1) returned", e);
        }
    }
}

