/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.optimization.Evaluator;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.optimization.HasEvaluators;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.util.Timing;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Random;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Stochastic gradient descent minimizer that updates its weight vector in place.
 *
 * <p>The weights are stored as {@link #x} together with a scalar {@link #xscale}; the effective
 * weight vector is {@code x * xscale}. Each step passes {@code eta / xscale} as the gain to the
 * function's stochastic update and then multiplies {@code xscale} by
 * {@code 1 - eta * lambda * bSize}, which applies the quadratic-regularizer shrinkage without
 * touching every weight. The vector is folded back to unit scale when the scale becomes small,
 * before each evaluation, and at the end of optimization.
 *
 * <p>Before the main loop, {@link #tune} picks a step size on a sample of the data and derives
 * the offset {@link #t0} of the learning-rate schedule {@code eta_t = 1 / (lambda * t)}.
 *
 * <p>Only {@link AbstractStochasticCachingDiffUpdateFunction} objectives are supported. Instances
 * hold mutable optimization state without synchronization and should not be shared between
 * threads.
 *
 * @param <T> unused by this class; kept for source compatibility
 */
public class StochasticInPlaceMinimizer<T extends Function>
implements Minimizer<Function>,
HasEvaluators {
    /** Scale applied to {@link #x}; the effective weights are {@code x * xscale}. */
    protected double xscale;
    /** Squared L2 norm of the effective weights; refreshed after every pass. */
    protected double xnorm;
    /** Unscaled weight vector, updated in place during optimization. */
    protected double[] x;
    /** Offset of the schedule {@code eta_t = 1 / (lambda * t)}; set by {@link #tune}. */
    protected int t0;
    /** Prior width; {@code lambda} is computed as {@code 1 / (sigma * numberOfSamples)}. */
    protected double sigma = 1.0;
    /** Regularization weight per sample, computed at the start of {@code minimize}. */
    protected double lambda;
    /** When true, progress messages from {@link #say}/{@link #sayln} are suppressed. */
    protected boolean quiet = false;
    /** Number of passes over the data. */
    protected int numPasses = 50;
    /** Number of samples per stochastic update. */
    protected int bSize = 1;
    /** Number of samples {@link #tune} uses to choose the step size. */
    protected int tuningSamples = 1000;
    /** Random generator; not used within this class (NOTE(review): presumably for subclasses). */
    protected Random gen = new Random(1L);
    /** Wall-clock budget in milliseconds, checked after each pass. */
    protected long maxTime = Long.MAX_VALUE;
    /** Evaluate every this many passes; values {@code <= 0} disable evaluation. */
    private int evaluateIters = 0;
    private Evaluator[] evaluators;
    /** Pattern used to print lambda in {@link #getName()}. */
    private static final String LAMBDA_PATTERN = "0.000E0";

    /** Creates a minimizer with the default settings of the fields above. */
    public StochasticInPlaceMinimizer() {
    }

    /**
     * @param sigma prior width used to derive lambda
     * @param numPasses number of passes; negative keeps the default
     */
    public StochasticInPlaceMinimizer(double sigma, int numPasses) {
        this(sigma, numPasses, -1);
    }

    /**
     * @param sigma prior width used to derive lambda
     * @param numPasses number of passes; negative keeps the default
     * @param tuningSamples samples used for step-size tuning; non-positive keeps the default
     */
    public StochasticInPlaceMinimizer(double sigma, int numPasses, int tuningSamples) {
        this.sigma = sigma;
        this.applyNumPasses(numPasses);
        this.applyTuningSamples(tuningSamples);
    }

    /**
     * @param prior quadratic prior whose sigma is used
     * @param numPasses number of passes; negative keeps the default
     * @param batchSize samples per update; non-positive keeps the default
     * @param tuningSamples samples used for step-size tuning; non-positive keeps the default
     * @throws IllegalArgumentException if the prior is not of type QUADRATIC
     */
    public StochasticInPlaceMinimizer(LogPrior prior, int numPasses, int batchSize, int tuningSamples) {
        if (!LogPrior.LogPriorType.QUADRATIC.equals(prior.getType())) {
            throw new IllegalArgumentException("Unsupported prior type " + prior.getType());
        }
        this.sigma = prior.getSigma();
        this.applyNumPasses(numPasses);
        this.applyBatchSize(batchSize);
        this.applyTuningSamples(tuningSamples);
    }

    /** Accepts a non-negative pass count; otherwise keeps the current value and reports it. */
    private void applyNumPasses(int passes) {
        if (passes >= 0) {
            this.numPasses = passes;
        } else {
            this.sayln("  StochasticInPlaceMinimizer: numPasses=" + passes + ", defaulting to " + this.numPasses);
        }
    }

    /** Accepts a positive batch size; otherwise keeps the current value and reports it. */
    private void applyBatchSize(int batchSize) {
        if (batchSize > 0) {
            this.bSize = batchSize;
        } else {
            this.sayln("  StochasticInPlaceMinimizer: batchSize=" + batchSize + ", defaulting to " + this.bSize);
        }
    }

    /** Accepts a positive tuning sample count; otherwise keeps the current value and reports it. */
    private void applyTuningSamples(int samples) {
        if (samples > 0) {
            this.tuningSamples = samples;
        } else {
            this.sayln("  StochasticInPlaceMinimizer: tuneSampleSize=" + samples + ", defaulting to " + this.tuningSamples);
        }
    }

    /** Suppresses all further progress output. */
    public void shutUp() {
        this.quiet = true;
    }

    /** Returns a short descriptive name containing the batch size and lambda. */
    protected String getName() {
        // DecimalFormat is not thread-safe, so use a fresh instance instead of a shared static one.
        NumberFormat format = new DecimalFormat(LAMBDA_PATTERN);
        return "SGD_InPlace_b" + this.bSize + "_lambda" + format.format(this.lambda);
    }

    /**
     * @param iters run the evaluators every {@code iters} passes (and once at the end);
     *     {@code <= 0} disables evaluation
     * @param evaluators evaluators to run; may be null
     */
    @Override
    public void setEvaluators(int iters, Evaluator[] evaluators) {
        this.evaluateIters = iters;
        this.evaluators = evaluators;
    }

    /** Throws if any element of {@code vect} is NaN or infinite; {@code name} labels the message. */
    private void ensureFinite(double[] vect, String name) throws InvalidElementException {
        for (int i = 0; i < vect.length; ++i) {
            if (Double.isNaN(vect[i])) {
                throw new InvalidElementException("NAN found in " + name + " element " + i);
            }
            if (Double.isInfinite(vect[i])) {
                throw new InvalidElementException("Infinity found in " + name + " element " + i);
            }
        }
    }

    /** Hook invoked after tuning and before the first pass; does nothing here. */
    protected void init(AbstractStochasticCachingDiffUpdateFunction func) {
    }

    /**
     * Regularized objective on {@code sample}: the function value at the scaled weights plus
     * {@code 0.5 * sample.length * lambda * ||w * wscale||^2}.
     */
    public double getObjective(AbstractStochasticCachingDiffUpdateFunction function, double[] w, double wscale, int[] sample) {
        double wnorm = this.getNorm(w) * wscale * wscale;
        double obj = function.valueAt(w, wscale, sample);
        return obj + 0.5 * (double) sample.length * this.lambda * wnorm;
    }

    /**
     * Runs one pass of constant-step SGD over {@code sample}, starting from a copy of
     * {@code initial}, and returns the resulting regularized objective. {@code initial} is not
     * modified.
     */
    public double tryEta(AbstractStochasticCachingDiffUpdateFunction function, double[] initial, int[] sample, double eta) {
        int numBatches = sample.length / this.bSize;
        double[] w = new double[initial.length];
        System.arraycopy(initial, 0, w, 0, w.length);
        double wscale = 1.0;
        int[] sampleBatch = new int[this.bSize];
        int sampleIndex = 0;
        for (int batch = 0; batch < numBatches; ++batch) {
            for (int i = 0; i < this.bSize; ++i) {
                sampleBatch[i] = sample[(sampleIndex + i) % sample.length];
            }
            sampleIndex += this.bSize;
            double gain = eta / wscale;
            function.calculateStochasticUpdate(w, wscale, sampleBatch, gain);
            wscale *= 1.0 - eta * this.lambda * (double) this.bSize;
        }
        return this.getObjective(function, w, wscale, sample);
    }

    /**
     * Chooses the initial step size by trial runs ({@link #tryEta}) on {@code sampleSize}
     * samples obtained from {@code function.getSample}.
     *
     * <p>Phase 1 starts at {@code seta} and doubles the step size while it still lowers the
     * objective below its value at {@code initial}. Phase 2 restarts at {@code seta / 2} and keeps
     * halving until 10 improving trials have been seen across both phases, or until the step size
     * underflows to zero. The best step size found (1.0 if none improved) is divided by 2,
     * stored implicitly in {@link #t0}, and returned.
     *
     * @return the chosen step size
     */
    public double tune(AbstractStochasticCachingDiffUpdateFunction function, double[] initial, int sampleSize, double seta) {
        Timing timer = new Timing();
        int[] sample = function.getSample(sampleSize);
        double sobj = this.getObjective(function, initial, 1.0, sample);
        double besteta = 1.0;
        double bestobj = sobj;
        double eta = seta;
        int totest = 10;
        final double factor = 2.0;
        boolean phase2 = false;
        while (totest > 0 || !phase2) {
            double obj = this.tryEta(function, initial, sample, eta);
            boolean okay = obj < sobj;
            this.sayln("  Trying eta=" + eta + "  obj=" + obj + (okay ? "(possible)" : "(too large)"));
            if (okay) {
                --totest;
                if (obj < bestobj) {
                    bestobj = obj;
                    besteta = eta;
                }
            }
            if (!phase2) {
                if (okay) {
                    eta *= factor;
                } else {
                    phase2 = true;
                    eta = seta;
                }
            }
            if (phase2) {
                eta /= factor;
                // A zero step leaves the weights unchanged, so no later trial can improve on
                // sobj; without this the loop would never terminate (e.g. sample smaller than bSize).
                if (eta == 0.0) {
                    this.sayln("  Tuning: step size underflowed to zero; stopping search");
                    break;
                }
            }
        }
        besteta /= factor;
        this.t0 = this.initialTime(besteta);
        this.sayln("  Taking eta=" + besteta + " t0=" + this.t0);
        this.sayln("  Tuning completed in: " + Timing.toSecondsString(timer.report()) + " s");
        return besteta;
    }

    /**
     * Converts a step size into the schedule offset {@code 1 / (eta * lambda)}, clamped to
     * {@code (bSize, Integer.MAX_VALUE]}. The lower bound keeps the first shrink factor
     * {@code 1 - bSize / t} positive, so {@code xscale} never becomes zero or negative.
     */
    private int initialTime(double eta) {
        double t = 1.0 / (eta * this.lambda);
        double minT = this.bSize + 1.0;
        if (!(t >= minT)) { // also catches NaN
            t = minT;
        }
        return (int) Math.min(t, (double) Integer.MAX_VALUE);
    }

    /** Returns the squared L2 norm (sum of squares) of {@code w}. */
    private double getNorm(double[] w) {
        double norm = 0.0;
        for (int i = 0; i < w.length; ++i) {
            norm += w[i] * w[i];
        }
        return norm;
    }

    /** Multiplies {@link #x} by {@link #xscale} and resets the scale to 1. */
    private void rescale() {
        if (this.xscale == 1.0) {
            return;
        }
        for (int i = 0; i < this.x.length; ++i) {
            this.x[i] *= this.xscale;
        }
        this.xscale = 1.0;
    }

    /** Runs every registered evaluator on {@code x}; does nothing if none are set. */
    private void doEvaluation(double[] x) {
        if (this.evaluators == null) {
            return;
        }
        for (Evaluator eval : this.evaluators) {
            this.sayln("  Evaluating: " + eval.toString());
            eval.evaluate(x);
        }
    }

    /** Equivalent to {@code minimize(function, functionTolerance, initial, -1)}. */
    @Override
    public double[] minimize(Function function, double functionTolerance, double[] initial) {
        return this.minimize(function, functionTolerance, initial, -1);
    }

    /**
     * Tunes the step size, then runs up to {@link #numPasses} passes of SGD. {@code maxIterations}
     * is interpreted together with {@code numPasses} as a pass count (the larger is multiplied by
     * the number of batches per pass). {@code functionTolerance} is not used.
     *
     * @return the weight vector (owned by this minimizer); filled with NaN if it became non-finite
     * @throws UnsupportedOperationException if {@code f} is not an
     *     {@link AbstractStochasticCachingDiffUpdateFunction}, or if neither
     *     {@code maxIterations} nor {@code numPasses} is positive
     */
    @Override
    public double[] minimize(Function f, double functionTolerance, double[] initial, int maxIterations) {
        if (!(f instanceof AbstractStochasticCachingDiffUpdateFunction)) {
            throw new UnsupportedOperationException();
        }
        AbstractStochasticCachingDiffUpdateFunction function = (AbstractStochasticCachingDiffUpdateFunction) f;
        int totalSamples = function.dataDimension();
        int tuneSampleSize = Math.min(totalSamples, this.tuningSamples);
        if (tuneSampleSize < this.tuningSamples) {
            System.err.println("WARNING: Total number of samples=" + totalSamples + " is smaller than requested tuning sample size=" + this.tuningSamples + "!!!");
        }
        this.lambda = 1.0 / (this.sigma * (double) totalSamples);
        this.sayln("Using sigma=" + this.sigma + " lambda=" + this.lambda + " tuning sample size " + tuneSampleSize);
        this.tune(function, initial, tuneSampleSize, 0.1);
        this.x = new double[initial.length];
        System.arraycopy(initial, 0, this.x, 0, this.x.length);
        this.xscale = 1.0;
        this.xnorm = this.getNorm(this.x);
        int numBatches = totalSamples / this.bSize;
        this.init(function);
        if (maxIterations <= 0 && this.numPasses <= 0) {
            throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
        }
        maxIterations = Math.max(maxIterations, this.numPasses) * numBatches;
        this.sayln("       Batchsize of: " + this.bSize);
        this.sayln("       Data dimension of: " + totalSamples);
        this.sayln("       Batches per pass through data:  " + numBatches);
        this.sayln("       Number of passes is = " + this.numPasses);
        this.sayln("       Max iterations is = " + maxIterations);
        Timing total = new Timing();
        Timing current = new Timing();
        total.start();
        current.start();
        // long: t0 may be near Integer.MAX_VALUE, and an int counter would wrap to negative steps.
        long t = this.t0;
        int iters = 0;
        for (int pass = 0; pass < this.numPasses; ++pass) {
            if (pass > 0 && this.evaluateIters > 0 && pass % this.evaluateIters == 0) {
                this.rescale();
                this.doEvaluation(this.x);
            }
            double totalValue = 0.0;
            double lastValue = 0.0;
            this.say("Iter: " + iters + " pass " + pass + " batch 1 ... ");
            for (int batch = 0; batch < numBatches; ++batch) {
                ++iters;
                double eta = 1.0 / (this.lambda * (double) t);
                double gain = eta / this.xscale;
                lastValue = function.calculateStochasticUpdate(this.x, this.xscale, this.bSize, gain);
                totalValue += lastValue;
                // Regularizer shrinkage is applied to the scale rather than to every weight.
                this.xscale *= 1.0 - eta * this.lambda * (double) this.bSize;
                t += this.bSize;
            }
            if (this.xscale < 1.0E-6) {
                // Fold the scale back in before it loses precision.
                this.rescale();
            }
            try {
                this.ensureFinite(this.x, "x");
            } catch (InvalidElementException e) {
                System.err.println(e.toString());
                // Poison the result so callers cannot mistake it for a valid solution.
                for (int i = 0; i < this.x.length; ++i) {
                    this.x[i] = Double.NaN;
                }
                break;
            }
            this.xnorm = this.getNorm(this.x) * this.xscale * this.xscale;
            double loss = totalValue + 0.5 * this.xnorm * this.lambda * (double) totalSamples;
            this.say(String.valueOf(numBatches));
            this.say("[" + (double) total.report() / 1000.0 + " s ");
            this.say("{" + (double) current.restart() / 1000.0 + " s}] ");
            this.sayln(" " + lastValue + " " + totalValue + " " + loss);
            if (iters >= maxIterations) {
                this.sayln("Stochastic Optimization complete.  Stopped after max iterations");
                break;
            }
            if (total.report() >= this.maxTime) {
                this.sayln("Stochastic Optimization complete.  Stopped after max time");
                break;
            }
        }
        this.rescale();
        if (this.evaluateIters > 0) {
            this.doEvaluation(this.x);
        }
        this.sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");
        return this.x;
    }

    /** Prints {@code s} plus a newline to stderr unless {@link #quiet} is set. */
    protected void sayln(String s) {
        if (!this.quiet) {
            System.err.println(s);
        }
    }

    /** Prints {@code s} to stderr without a newline unless {@link #quiet} is set. */
    protected void say(String s) {
        if (!this.quiet) {
            System.err.print(s);
        }
    }

    /** Signals a NaN or infinite element in a vector checked by {@code ensureFinite}. */
    public static class InvalidElementException
    extends Exception {
        private static final long serialVersionUID = 1647150702529757545L;

        public InvalidElementException(String s) {
            super(s);
        }
    }
}

