/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.optimization.StochasticMinimizer;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.util.Pair;
import java.io.IOException;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Stochastic gradient descent with a per-coordinate (diagonal) scaling of the step.
 *
 * <p>Each step moves coordinate {@code i} by
 * {@code fixedGain * gainSchedule(k, 5 * numBatches) / diag[i] * grad[i]}. After every step the
 * diagonal {@code diag} is re-estimated from the latest parameter difference {@code s} and gradient
 * difference {@code y}. The update rule is either a constrained "minimum error" update
 * ({@code method == 0}) or a diagonal BFGS-style update ({@code method == 1}, the default).
 */
public class ScaledSGDMinimizer extends StochasticMinimizer {
    /**
     * Diagonal update rule: 0 = min-error update, 1 = diagonal BFGS update.
     * NOTE: static, so setting it through a constructor affects every instance.
     */
    private static int method = 1;
    /** Recent gradient differences, oldest first; takeStep caps the history at {@code pairMem} entries. */
    public List<double[]> yList = null;
    /** Recent parameter differences, oldest first; takeStep caps the history at {@code pairMem} entries. */
    public List<double[]> sList = null;
    /** Per-coordinate scale; the step for coordinate {@code i} is divided by {@code diag[i]}. */
    public double[] diag;
    /** Multiplier applied to every step (tuned by {@link #tuneFixedGain} / {@link #tune}). */
    private double fixedGain = 0.99;
    private double[] s;
    private double[] y;
    /** Maximum number of (s, y) pairs kept in {@link #sList} / {@link #yList}. */
    private static final int pairMem = 20;
    /** Min-error update: diagonal entries at or below {@code 1 / aMax} are reset to {@code 1 / gain}. */
    private final double aMax = 1000000.0;
    // NOTE(review): DecimalFormat is not thread-safe; this shared instance assumes single-threaded use.
    private static final NumberFormat nf = new DecimalFormat("0.000E0");

    /**
     * Searches for a good {@code fixedGain}. Starting from {@code fixedStart}, it runs a time-limited
     * minimization from a copy of {@code initial} and divides the candidate gain by 1.2 while the final
     * objective value keeps improving.
     *
     * @param function the objective; must be an {@link AbstractStochasticCachingDiffFunction}
     * @param initial starting point; not modified (each trial works on a copy)
     * @param msPerTest time budget stored in {@code maxTime} for each trial
     * @param fixedStart first {@code fixedGain} value to try
     * @return the {@code fixedGain} that gave the lowest final objective, or 0.0 if no trial produced
     *     a value below +Infinity
     * @throws UnsupportedOperationException if {@code function} is not an
     *     {@link AbstractStochasticCachingDiffFunction}
     */
    public double tuneFixedGain(Function function, double[] initial, long msPerTest, double fixedStart) {
        this.maxTime = msPerTest;
        if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
            throw new UnsupportedOperationException();
        }
        AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function;
        double[] xtest = new double[initial.length];
        double factor = 1.2;
        double fOpt = 0.0;
        double min = Double.POSITIVE_INFINITY;
        double prev = Double.POSITIVE_INFINITY;
        double f = fixedStart;
        int it = 1;
        boolean toContinue = true;
        while (toContinue) {
            System.arraycopy(initial, 0, xtest, 0, initial.length);
            System.err.println("");
            this.fixedGain = f;
            System.err.println("Testing with batchsize: " + StochasticMinimizer.bSize + "    gain:  " + StochasticMinimizer.gain + "  fixedGain:  " + nf.format(this.fixedGain));
            this.numPasses = 10000;
            this.minimize(function, 1.0E-100, xtest);
            double result = dfunction.valueAt(xtest);
            if (it == 1) {
                f /= factor;
            }
            if (result < min) {
                min = result;
                fOpt = this.fixedGain;
                f /= factor;
                prev = result;
            } else if (result < prev) {
                f /= factor;
                prev = result;
            } else {
                // No improvement: the result is worse, equal, or NaN. Stop here. Otherwise, on an equal
                // or NaN result, f would never change and the same trial would repeat forever.
                toContinue = false;
            }
            ++it;
            System.err.println("");
            System.err.println("Final value is: " + nf.format(result));
            System.err.println("Optimal so far is:  fixedgain: " + fOpt);
        }
        return fOpt;
    }

    /**
     * Runs two rounds of tuning. Each round tunes {@code fixedGain}, then the shared {@code gain},
     * then the shared batch size, using the inherited tuning helpers.
     *
     * @return the tuned (batch size, gain) pair
     */
    @Override
    public Pair<Integer, Double> tune(Function function, double[] initial, long msPerTest) {
        this.quiet = true;
        for (int round = 0; round < 2; ++round) {
            this.fixedGain = this.tuneDouble(function, initial, msPerTest, new setFixedGain(this), 0.1, 1.0);
            StochasticMinimizer.gain = this.tuneGain(function, initial, msPerTest, 1.0E-7, 1.0);
            StochasticMinimizer.bSize = this.tuneBatch(function, initial, msPerTest, 1);
            System.err.println("Results:  fixedGain: " + nf.format(this.fixedGain) + "  gain: " + nf.format(StochasticMinimizer.gain) + "  batch " + StochasticMinimizer.bSize);
        }
        return new Pair<Integer, Double>(StochasticMinimizer.bSize, StochasticMinimizer.gain);
    }

    /** Suppresses progress output by setting {@code quiet}. */
    @Override
    public void shutUp() {
        this.quiet = true;
    }

    /** Sets the batch size. NOTE: {@code bSize} is a static field shared by all minimizers. */
    public void setBatchSize(int batchSize) {
        bSize = batchSize;
    }

    /** Configures batch size, initial gain, pass count and iteration-output flag from {@code flags}. */
    public ScaledSGDMinimizer(SeqClassifierFlags flags) {
        StochasticMinimizer.bSize = flags.stochasticBatchSize;
        StochasticMinimizer.gain = flags.initialGain;
        this.numPasses = flags.SGDPasses;
        this.outputIterationsToFile = flags.outputIterationsToFile;
    }

    /** Uses the BFGS diagonal update ({@code method = 1}) and no iteration output file. */
    public ScaledSGDMinimizer(double SGDGain, int batchSize, int sgdPasses) {
        this(SGDGain, batchSize, sgdPasses, 1, false);
    }

    /** Like the five-argument constructor, with no iteration output file. */
    public ScaledSGDMinimizer(double SGDGain, int batchSize, int sgdPasses, int method) {
        this(SGDGain, batchSize, sgdPasses, method, false);
    }

    /**
     * @param SGDGain initial gain (stored in the shared static {@code gain})
     * @param batchSize batch size (stored in the shared static {@code bSize})
     * @param sgdPasses number of passes
     * @param method diagonal update rule: 0 = min-error, 1 = BFGS; stored statically, so it affects
     *     all instances
     * @param outputToFile whether to write iterations to a file
     */
    public ScaledSGDMinimizer(double SGDGain, int batchSize, int sgdPasses, int method, boolean outputToFile) {
        StochasticMinimizer.bSize = batchSize;
        StochasticMinimizer.gain = SGDGain;
        this.numPasses = sgdPasses;
        ScaledSGDMinimizer.method = method;
        this.outputIterationsToFile = outputToFile;
    }

    /** Uses 50 passes and the BFGS diagonal update. */
    public ScaledSGDMinimizer(double SGDGain, int batchSize) {
        this(SGDGain, batchSize, 50);
    }

    /** Sets the time budget; {@code max} must be non-null (it is unboxed). */
    public void setMaxTime(Long max) {
        this.maxTime = max;
    }

    /** Returns a name encoding batch size, gain and fixedGain (the gains are scaled by 1000 and truncated). */
    @Override
    public String getName() {
        int g = (int) (gain * 1000.0);
        int f = (int) (this.fixedGain * 1000.0);
        return "ScaledSGD" + bSize + "_g" + g + "_f" + f;
    }

    /**
     * Takes one diagonally scaled gradient step into {@code newX}. It then records the new (s, y) pair
     * and updates {@code diag} from it.
     */
    @Override
    protected void takeStep(AbstractStochasticCachingDiffFunction dfunction) {
        double schedule = this.gainSchedule(this.k, 5 * this.numBatches);
        for (int i = 0; i < this.x.length; ++i) {
            double thisGain = this.fixedGain * schedule / this.diag[i];
            this.newX[i] = this.x[i] - thisGain * this.grad[i];
        }
        this.say(" A ");
        if (pairMem > 0 && this.sList.size() >= pairMem) {
            // History is full: drop the oldest pair and recycle its y buffer.
            this.sList.remove(0);
            this.y = this.yList.remove(0);
        } else {
            this.y = new double[this.x.length];
        }
        this.s = ArrayMath.pairwiseSubtract(this.newX, this.x);
        // Ask the function to re-use the previous batch, so that y is a difference of gradients on the same data.
        dfunction.recalculatePrevBatch = true;
        System.arraycopy(dfunction.derivativeAt(this.newX, bSize), 0, this.y, 0, this.grad.length);
        ArrayMath.pairwiseSubtractInPlace(this.y, this.newGrad);
        this.sList.add(this.s);
        this.yList.add(this.y);
        this.updateDiag(this.diag, this.s, this.y);
    }

    /**
     * Allocates {@code diag} with every entry {@code fixedGain / gain}, so each coordinate's first
     * step size is {@code gain * gainSchedule(...)}. It also creates empty (s, y) histories.
     */
    @Override
    protected void init(AbstractStochasticCachingDiffFunction func) {
        this.diag = new double[this.x.length];
        this.memory = 1;
        for (int i = 0; i < this.x.length; ++i) {
            this.diag[i] = this.fixedGain / gain;
        }
        this.sList = new ArrayList<double[]>();
        this.yList = new ArrayList<double[]>();
    }

    /** Dispatches to the update rule selected by {@code method}; any other value leaves diag unchanged. */
    private void updateDiag(double[] diag, double[] s, double[] y) {
        if (method == 0) {
            this.updateDiagMinErr(diag, s, y);
        } else if (method == 1) {
            this.updateDiagBFGS(diag, s, y);
        }
    }

    /**
     * Diagonal BFGS-style update:
     * {@code d_i <- (1 - d_i s_i^2 / sDs) d_i + y_i^2 / sy}. The update is applied only if every new
     * entry is finite and non-negative; otherwise diag is left unchanged and "!" is logged.
     */
    private void updateDiagBFGS(double[] diag, double[] s, double[] y) {
        double sDs = 0.0;
        double sy = 0.0;
        for (int i = 0; i < s.length; ++i) {
            sDs += s[i] * diag[i] * s[i];
            sy += s[i] * y[i];
        }
        this.say("B");
        double[] newDiag = new double[s.length];
        boolean accept = true;
        for (int i = 0; i < s.length; ++i) {
            newDiag[i] = (1.0 - diag[i] * s[i] * s[i] / sDs) * diag[i] + y[i] * y[i] / sy;
            // A zero sDs or sy yields NaN/Infinity; accepting those would corrupt every later step.
            if (newDiag[i] < 0.0 || Double.isNaN(newDiag[i]) || Double.isInfinite(newDiag[i])) {
                accept = false;
                break;
            }
        }
        if (accept) {
            System.arraycopy(newDiag, 0, diag, 0, s.length);
        } else {
            this.say("!");
        }
    }

    /**
     * "Minimum error" diagonal update. It solves for a Lagrange multiplier {@code lamStar} (by root
     * finding on {@link lagrange}) that limits how far diag may move, then sets
     * {@code d_i <- (|y_i s_i| + 2 lamStar d_i) / (s_i^2 + 1e-8 + 2 lamStar)}. Entries that end up
     * at or below {@code 1 / aMax} are reset to {@code 1 / gain}.
     */
    private void updateDiagMinErr(double[] diag, double[] s, double[] y) {
        double low = 0.0;
        double high = 0.0;
        for (int i = 0; i < s.length; ++i) {
            double tmp = s[i] * (y[i] - diag[i]);
            high += tmp * tmp;
        }
        this.say("M");
        // The allowed change shrinks as the iteration count k grows.
        double alpha = Math.sqrt(ArrayMath.norm(y) / ArrayMath.norm(s)) * Math.sqrt(50.0 / (50.0 + (double) this.k));
        alpha *= Math.sqrt(ArrayMath.average(diag));
        this.say(" alpha " + nf.format(alpha));
        high = Math.sqrt(high) / (2.0 * alpha);
        lagrange func = new lagrange(s, y, diag, alpha);
        double lamStar;
        if (func.apply(low) > 0.0) {
            lamStar = this.getRoot(func, low, high);
        } else {
            lamStar = 0.0;
            this.say(" * ");
        }
        for (int i = 0; i < s.length; ++i) {
            diag[i] = (Math.abs(y[i] * s[i]) + 2.0 * lamStar * diag[i]) / (s[i] * s[i] + 1.0E-8 + 2.0 * lamStar);
            if (diag[i] <= 1.0 / this.aMax) {
                diag[i] = 1.0 / gain;
            }
        }
    }

    /**
     * Finds a root of {@code func} in {@code [lower, upper]} by skewed bisection. The search assumes
     * {@code func(lower) >= 0 >= func(upper)} and only logs a warning if that does not hold. It stops
     * when {@code |func(mid)| <= 1e-8}, or after at most 101 bisection steps.
     *
     * @return the last midpoint evaluated
     */
    private double getRoot(edu.stanford.nlp.util.Function<Double, Double> func, double lower, double upper) {
        final double TOL = 1.0E-8;
        final double skew = 0.4;
        final int maxIterations = 100;
        double mid = 0.5 * (lower + upper);
        if (func.apply(upper) > 0.0 || func.apply(lower) < 0.0) {
            this.say("LOWER AND UPPER SUPPLIED TO GET ROOT DO NOT BOUND THE ROOT.");
        }
        double fval = func.apply(mid);
        int count = 0;
        while (Math.abs(fval) > TOL) {
            ++count;
            if (fval > 0.0) {
                lower = mid;
            } else if (fval < 0.0) {
                upper = mid;
            }
            mid = skew * lower + (1.0 - skew) * upper;
            fval = func.apply(mid);
            // Without this cap a non-converging search would never terminate.
            if (count > maxIterations) {
                break;
            }
        }
        this.say("   " + nf.format(mid) + "  f" + nf.format(fval));
        return mid;
    }

    /** Serializes {@code weights} (with no diagonal) to {@code serializePath}; failures are logged, not thrown. */
    public void serializeWeights(String serializePath, double[] weights) {
        this.serializeWeights(serializePath, weights, null);
    }

    /**
     * Serializes {@code weights} and {@code diag} to {@code serializePath} as a {@link weight}.
     * Failures are logged to stderr, not thrown.
     */
    public void serializeWeights(String serializePath, double[] weights, double[] diag) {
        System.err.println("Serializing weights to " + serializePath + "...");
        try {
            weight out2 = new weight(weights, diag);
            IOUtils.writeObjectToFile((Object) out2, serializePath);
        }
        catch (Exception e) {
            // Best-effort: report and continue.
            System.err.println("Error serializing to " + serializePath);
            e.printStackTrace();
        }
    }

    /** Loads the weight vector from a file written by {@link #serializeWeights}. */
    public double[] getWeights(String loadPath) throws IOException, ClassCastException, ClassNotFoundException {
        return readWeight(loadPath).w;
    }

    /** Loads the diagonal from a file written by {@link #serializeWeights}; may be null. */
    public double[] getDiag(String loadPath) throws IOException, ClassCastException, ClassNotFoundException {
        return readWeight(loadPath).d;
    }

    /** Reads and casts the serialized {@link weight} object at {@code loadPath}. */
    private static weight readWeight(String loadPath) throws IOException, ClassCastException, ClassNotFoundException {
        System.err.println("Loading weights from " + loadPath + "...");
        return (weight) IOUtils.readObjectFromFile(loadPath);
    }

    /**
     * Serializable container for a weight vector {@code w} and an optional diagonal {@code d}.
     * It is static so that serializing it does not drag along the (non-serializable) enclosing minimizer.
     */
    public static class weight implements Serializable {
        public double[] w;
        public double[] d;
        private static final long serialVersionUID = 814182172645533781L;

        public weight(double[] wt) {
            this.w = wt;
        }

        public weight(double[] wt, double[] di) {
            this.w = wt;
            this.d = di;
        }
    }

    /**
     * Constraint function for the min-error update:
     * {@code f(lam) = sum_i ((y_i s_i + 2 lam d_i) / (s_i^2 + 2 lam) - d_i)^2 - a^2}.
     */
    class lagrange implements edu.stanford.nlp.util.Function<Double, Double> {
        double[] s;
        double[] y;
        double[] d;
        double a;

        public lagrange(double[] s, double[] y, double[] d, double a) {
            this.s = s;
            this.y = y;
            this.d = d;
            this.a = a;
        }

        @Override
        public Double apply(Double lam) {
            double l = lam;
            double val = 0.0;
            for (int i = 0; i < this.s.length; ++i) {
                double diff = (this.y[i] * this.s[i] + 2.0 * l * this.d[i]) / (this.s[i] * this.s[i] + 2.0 * l) - this.d[i];
                val += diff * diff;
            }
            return val - this.a * this.a;
        }
    }

    /** Property setter used by {@link #tune} to write {@code fixedGain} on a given minimizer. */
    private static class setFixedGain implements StochasticMinimizer.PropertySetter<Double> {
        ScaledSGDMinimizer parent = null;

        public setFixedGain(ScaledSGDMinimizer min) {
            this.parent = min;
        }

        @Override
        public void set(Double in) {
            this.parent.fixedGain = in;
        }
    }
}

