/*
 * Decompiled with CFR 0.152.
 */
package stallone.hmm;

import java.util.Arrays;
import java.util.List;
import stallone.api.datasequence.IDataSequence;
import stallone.api.doubles.IDoubleArray;
import stallone.api.function.IParametricFunction;
import stallone.api.hmm.IExpectationMaximization;
import stallone.api.hmm.IHMM;
import stallone.api.hmm.IHMMHiddenVariables;
import stallone.api.hmm.IHMMParameters;
import stallone.api.hmm.ParameterEstimationException;
import stallone.api.ints.IIntArray;
import stallone.api.io.IO;
import stallone.api.mc.MarkovModel;
import stallone.api.stat.IParameterEstimator;
import stallone.doubles.PrimitiveDoubleTools;
import stallone.hmm.ForwardBackward;
import stallone.hmm.HMMCountMatrixEstimator;
import stallone.hmm.HMMForwardModel;
import stallone.hmm.HMMHiddenVariables;
import stallone.hmm.Viterbi;

/**
 * Expectation-maximization driver for a hidden Markov model.
 *
 * <p>Each iteration ({@link #emStep(List)}) obtains hidden variables for every
 * observation sequence, accumulates a transition count matrix and per-state
 * output-parameter estimates, and writes them back into the
 * {@link HMMForwardModel} unless the summed log-likelihood is NaN or has
 * dropped by more than the configured tolerance.
 *
 * <p>In save-memory mode a single {@link HMMHiddenVariables} buffer, sized for
 * the longest observation sequence, is reused for all trajectories; otherwise
 * one buffer per trajectory is kept.
 */
public class EM
implements IExpectationMaximization,
IHMM {
    private final List<IDataSequence> obs;
    private final HMMForwardModel model;
    private final HMMCountMatrixEstimator countMatrixEstimator;
    // One estimator per hidden state, each a copy of the estimator passed to the constructor.
    private final IParameterEstimator[] outputModelEstimators;
    private final ForwardBackward trajEstimator;
    // Best log-likelihood accepted so far by run(); package-private as in the original.
    double logLikelihood = Double.NEGATIVE_INFINITY;
    private List<IIntArray> initPaths;
    // Length 1 in save-memory mode, otherwise one entry per observation sequence.
    private HMMHiddenVariables[] hidden = null;
    private boolean saveMemory = false;
    private int nStepsMax = 1;
    private double dectol = 0.1;
    private double[] likelihoods;

    /**
     * Builds the forward model, the path estimator, the hidden-variable
     * buffers and the per-state output estimators.
     *
     * @param _obs observation sequences; must be non-null and non-empty
     * @param eventBased passed to the forward model and the count matrix estimator
     * @param nstates number of hidden states requested from the forward model
     * @param reversible passed to the forward model
     * @param _fOut output function passed to the forward model
     * @param _outputModelEstimator prototype estimator; copied once per hidden state
     * @param _saveMemory if true, a single shared hidden-variable buffer is used
     * @throws IllegalArgumentException if {@code _obs} is null or empty
     */
    public EM(List<IDataSequence> _obs, boolean eventBased, int nstates, boolean reversible, IParametricFunction _fOut, IParameterEstimator _outputModelEstimator, boolean _saveMemory) {
        if (_obs == null) {
            throw new IllegalArgumentException("Observation is null");
        }
        if (_obs.size() == 0) {
            throw new IllegalArgumentException("Observation has zero Elements");
        }
        this.obs = _obs;
        this.model = new HMMForwardModel(_obs, eventBased, nstates, reversible, _fOut);
        this.saveMemory = _saveMemory;
        this.trajEstimator = new ForwardBackward(this.model);
        this.hidden = this.allocateHidden(_saveMemory);
        this.countMatrixEstimator = new HMMCountMatrixEstimator(eventBased, this.model);
        this.outputModelEstimators = new IParameterEstimator[this.model.getNStates()];
        for (int state = 0; state < this.outputModelEstimators.length; ++state) {
            this.outputModelEstimators[state] = _outputModelEstimator.copy();
        }
    }

    /**
     * Allocates the hidden-variable buffers for the requested memory mode.
     *
     * @param shared if true, returns a one-element array whose buffer is sized
     *        for the longest observation sequence; otherwise returns one
     *        buffer per observation sequence, each sized for its own sequence
     */
    private HMMHiddenVariables[] allocateHidden(boolean shared) {
        int nStates = this.model.getNStates();
        if (shared) {
            // The shared buffer is later resized via setLength() for every
            // trajectory, so it must be created for the longest of ALL of them.
            int maxsize = 0;
            for (IDataSequence sequence : this.obs) {
                maxsize = Math.max(maxsize, sequence.size());
            }
            return new HMMHiddenVariables[]{new HMMHiddenVariables(maxsize, nStates)};
        }
        HMMHiddenVariables[] buffers = new HMMHiddenVariables[this.obs.size()];
        for (int i = 0; i < buffers.length; ++i) {
            buffers[i] = new HMMHiddenVariables(this.obs.get(i).size(), nStates);
        }
        return buffers;
    }

    /** Replaces the parameters of the underlying forward model. */
    @Override
    public void setInitialParameters(IHMMParameters _par) {
        this.model.setParameters(_par);
    }

    /**
     * Stores initial hidden paths and seeds the model's transition counts from
     * them via {@code MarkovModel.util.estimateC(paths, 1)}. The stored paths
     * are consumed by one extra step at the start of {@link #run()}.
     */
    @Override
    public void setInitialPaths(List<IIntArray> _initPaths) {
        this.initPaths = _initPaths;
        this.model.setTransitionCounts(MarkovModel.util.estimateC(this.initPaths, 1));
    }

    /**
     * Switches between the shared-buffer and per-trajectory memory modes.
     *
     * <p>When the mode actually changes, the hidden-variable buffers are
     * re-allocated so their layout matches the new mode; previously computed
     * hidden variables are discarded in that case.
     */
    public void setSaveMemory(boolean sm) {
        if (sm != this.saveMemory) {
            this.hidden = this.allocateHidden(sm);
        }
        this.saveMemory = sm;
    }

    /** Forwards the count mode to the count matrix estimator. */
    public void setCountMode(int mode) {
        this.countMatrixEstimator.setCountMode(mode);
    }

    /** Sets the maximum number of EM iterations performed by {@link #run()} (default 1). */
    @Override
    public void setMaximumNumberOfStep(int nsteps) {
        this.nStepsMax = nsteps;
    }

    /**
     * Sets how far the log-likelihood may fall below the best value seen so
     * far before an iteration is rejected (default 0.1).
     */
    @Override
    public void setLikelihoodDecreaseTolerance(double _dectol) {
        this.dectol = _dectol;
    }

    /**
     * Runs up to {@code nStepsMax} EM iterations.
     *
     * <p>If initial paths were set, one step using those paths is performed
     * first and its likelihood is not recorded. The loop stops early when an
     * iteration's log-likelihood falls more than {@code dectol} below the best
     * value so far (history is truncated before that iteration) or when it is
     * NaN (the whole history is overwritten with negative infinity).
     *
     * @throws ParameterEstimationException propagated from {@link #emStep(List)}
     */
    @Override
    public void run() throws ParameterEstimationException {
        double[] res = new double[this.nStepsMax];
        this.logLikelihood = Double.NEGATIVE_INFINITY;
        if (this.initPaths != null) {
            this.emStep(this.initPaths);
        }
        for (int n = 0; n < this.nStepsMax; ++n) {
            res[n] = this.emStep(null);
            // A NaN fails this comparison and is handled by the NaN check below.
            if (res[n] + this.dectol < this.logLikelihood) {
                System.out.println(" next likelihood = " + res[n] + " exiting.");
                this.likelihoods = PrimitiveDoubleTools.subarray(res, 0, n);
                return;
            }
            if (res[n] > this.logLikelihood) {
                this.logLikelihood = res[n];
            }
            if (Double.isNaN(res[n])) {
                System.out.println("NaN in likelihood from E-Step");
                Arrays.fill(res, Double.NEGATIVE_INFINITY);
                this.likelihoods = res;
                return;
            }
        }
        this.likelihoods = res;
    }

    /**
     * Performs one EM iteration over all observation sequences.
     *
     * @param _initPaths if null, hidden variables are computed by the path
     *        estimator; otherwise they are set from these paths (together
     *        with the observation when the model is event based)
     * @return the sum of the per-trajectory log-likelihoods; may be NaN, in
     *         which case accumulation stops at the offending trajectory
     * @throws ParameterEstimationException propagated from the path estimator
     *         or the parameter estimators
     */
    private double emStep(List<IIntArray> _initPaths) throws ParameterEstimationException {
        double res = 0.0;
        this.countMatrixEstimator.initialize();
        for (IParameterEstimator estimator : this.outputModelEstimators) {
            estimator.initialize();
        }
        for (int i = 0; i < this.obs.size(); ++i) {
            IDataSequence sequence = this.obs.get(i);
            HMMHiddenVariables hiddenCur;
            if (this.saveMemory) {
                // Reuse the single shared buffer, resized to this trajectory.
                this.hidden[0].setLength(sequence.size());
                hiddenCur = this.hidden[0];
            } else {
                hiddenCur = this.hidden[i];
            }
            if (_initPaths == null) {
                this.trajEstimator.computePath(i, hiddenCur);
            } else if (this.model.isEventBased()) {
                hiddenCur.setPath(sequence, _initPaths.get(i));
            } else {
                hiddenCur.setPath(_initPaths.get(i));
            }
            res += hiddenCur.logLikelihood();
            if (Double.isNaN(res)) {
                break;
            }
            this.countMatrixEstimator.addToEstimate(sequence, i, hiddenCur);
            for (int state = 0; state < this.outputModelEstimators.length; ++state) {
                this.outputModelEstimators[state].addToEstimate(sequence, hiddenCur.getGammaByState(state));
            }
        }
        // Only accept the new estimates when the likelihood is a number and has
        // not dropped by more than the tolerance below the best value so far.
        if (!(res < this.logLikelihood - this.dectol) && !Double.isNaN(res)) {
            IDoubleArray C = this.countMatrixEstimator.getEstimate();
            this.model.setTransitionCounts(C);
            for (int state = 0; state < this.outputModelEstimators.length; ++state) {
                this.model.setOutputParameters(state, this.outputModelEstimators[state].getEstimate());
            }
        }
        return res;
    }

    /**
     * Returns the hidden variables of trajectory {@code itraj}.
     *
     * <p>In save-memory mode they are recomputed into the shared buffer on
     * every call, so the returned object is overwritten by the next call or
     * EM step. If recomputation fails, the error message is reported through
     * {@code IO.util.error} and the shared buffer is still returned.
     */
    @Override
    public IHMMHiddenVariables getHidden(int itraj) {
        if (this.saveMemory) {
            this.hidden[0].setLength(this.obs.get(itraj).size());
            try {
                this.trajEstimator.computePath(itraj, this.hidden[0]);
            }
            catch (ParameterEstimationException ex) {
                IO.util.error(ex.getMessage());
            }
            return this.hidden[0];
        }
        return this.hidden[itraj];
    }

    /** Returns the current parameters of the underlying forward model. */
    @Override
    public IHMMParameters getParameters() {
        return this.model.getParameters();
    }

    /** Returns the best log-likelihood accepted by the last {@link #run()}. */
    @Override
    public double getLogLikelihood() {
        return this.logLikelihood;
    }

    /** Returns the paths produced by a {@link Viterbi} instance built on the current model. */
    @Override
    public List<IIntArray> viterbi() {
        Viterbi v = new Viterbi(this.model);
        return v.getPaths();
    }

    /** Returns this object, which also serves as the estimated HMM. */
    @Override
    public IHMM getHMM() {
        return this;
    }

    /** Returns the log-likelihood history of the last {@link #run()}, or null if it has not run. */
    @Override
    public double[] getLogLikelihoodHistory() {
        return this.likelihoods;
    }

    /** Returns the transition matrix of the current model parameters. */
    @Override
    public IDoubleArray getTransitionMatrix() {
        return this.model.getParameters().getTransitionMatrix();
    }

    /** Returns the output parameter matrix of the current model parameters. */
    @Override
    public IDoubleArray getOutputParameters() {
        return this.model.getParameters().getOutputParameterMatrix();
    }
}

