/*
 * Decompiled with CFR 0.152.
 */
package stallone.hmm;

import stallone.api.datasequence.IDataSequence;
import stallone.api.doubles.Doubles;
import stallone.api.doubles.IDoubleArray;
import stallone.api.hmm.IHMMHiddenVariables;
import stallone.doubles.PrimitiveDoubleTools;
import stallone.hmm.HMMForwardModel;

/**
 * Accumulates a transition count matrix for a hidden Markov model from observation sequences
 * and their associated hidden-state variables.
 *
 * <p>Counts are summed over successive calls to {@link #addToEstimate} and are cleared only by
 * {@link #initialize()}. The counting strategy is selected with {@link #setCountMode(int)} and
 * defaults to {@link #MODE_BAUMWELCH}.
 *
 * <p>If event-based counting is enabled, each observation interval of duration
 * {@code dt = time(t+1) - time(t)} additionally contributes counts derived from {@code dt - 1}
 * to the diagonal (self-transition) entries of the matrix; see the individual modes for how
 * those counts are weighted.
 */
public class HMMCountMatrixEstimator {
    /** Count one transition between the most probable hidden states of consecutive time steps. */
    public static final int MODE_MAXPATH = 1;
    /** Count transitions along the Viterbi path. Not implemented yet. */
    public static final int MODE_VITERBI = 2;
    /** Count normalized expected transitions computed from the hidden variables' alpha/beta values. */
    public static final int MODE_BAUMWELCH = 3;

    /** Currently selected counting strategy; always one of the {@code MODE_*} constants. */
    private int countMode = MODE_BAUMWELCH;
    /** Whether observation time gaps add extra self-transition counts. */
    private final boolean countEvents;
    /** Model supplying the number of states and the transition probabilities. */
    private final HMMForwardModel model;
    /** Accumulated counts, indexed as {@code counts[from][to]}. */
    private double[][] counts;

    /**
     * Creates an estimator whose count matrix is all zeros and sized to the model's state count.
     *
     * @param eventBased whether observation time gaps add extra self-transition counts
     * @param _model model providing the number of states and the transition probabilities
     */
    public HMMCountMatrixEstimator(boolean eventBased, HMMForwardModel _model) {
        this.model = _model;
        this.countEvents = eventBased;
        this.counts = newCountMatrix();
    }

    /**
     * Selects the counting strategy used by subsequent {@link #addToEstimate} calls.
     * Counts already accumulated are left untouched.
     *
     * @param _countmode one of {@link #MODE_MAXPATH}, {@link #MODE_VITERBI}, {@link #MODE_BAUMWELCH}
     * @throws IllegalArgumentException if {@code _countmode} is not one of those constants
     */
    public void setCountMode(int _countmode) {
        if (_countmode != MODE_MAXPATH && _countmode != MODE_VITERBI && _countmode != MODE_BAUMWELCH) {
            throw new IllegalArgumentException("non-existing count mode");
        }
        this.countMode = _countmode;
    }

    /**
     * Discards all accumulated counts, replacing them with a fresh all-zero matrix sized to the
     * model's current number of states.
     */
    public void initialize() {
        this.counts = newCountMatrix();
    }

    /**
     * Adds the transition counts of one trajectory to the running estimate, using the current
     * count mode.
     *
     * @param obs observation sequence; its time stamps are read only when event counting is on
     * @param itraj trajectory index passed to the model's transition probabilities
     *              (used by {@link #MODE_BAUMWELCH} only)
     * @param hidden hidden-state variables computed for {@code obs}
     * @throws UnsupportedOperationException if the mode is {@link #MODE_VITERBI}
     */
    public void addToEstimate(IDataSequence obs, int itraj, IHMMHiddenVariables hidden) {
        switch (this.countMode) {
            case MODE_MAXPATH:
                this.addEstimateMaxPath(obs, hidden);
                break;
            case MODE_VITERBI:
                this.addEstimateViterbi(obs, hidden);
                break;
            case MODE_BAUMWELCH:
                this.addEstimateBaumWelch(obs, itraj, hidden);
                break;
            default:
                // Unreachable: setCountMode only accepts the MODE_* constants.
                throw new IllegalStateException("Should not be here");
        }
    }

    /** Allocates an all-zero {@code nStates x nStates} matrix for the current model. */
    private double[][] newCountMatrix() {
        final int nStates = this.model.getNStates();
        return new double[nStates][nStates];
    }

    /**
     * Counts one transition {@code s(t) -> s(t+1)} per time step, where {@code s(t)} is the most
     * probable hidden state at {@code t}. With event counting, {@code dt - 1} is then added to
     * {@code counts[s(t)][s(t)]} for every interval.
     */
    private void addEstimateMaxPath(IDataSequence obs, IHMMHiddenVariables hidden) {
        final int lastStart = hidden.size() - 1;
        for (int t = 0; t < lastStart; t++) {
            int from = hidden.mostProbableState(t);
            int to = hidden.mostProbableState(t + 1);
            this.counts[from][to] += 1.0;
        }
        // Kept as a separate pass (rather than merged into the loop above) so the
        // floating-point accumulation order on the diagonal is unchanged.
        if (this.countEvents) {
            for (int t = 0; t < lastStart; t++) {
                int state = hidden.mostProbableState(t);
                double dt = obs.getTime(t + 1) - obs.getTime(t);
                this.counts[state][state] += dt - 1.0;
            }
        }
    }

    /** Viterbi-path counting is not available yet. */
    private void addEstimateViterbi(IDataSequence obs, IHMMHiddenVariables hidden) {
        throw new UnsupportedOperationException("Viterbi is not implemented yet.");
    }

    /**
     * Computes the expected transition counts between {@code time1} and {@code time1 + 1}.
     *
     * <p>Each entry is {@code alpha(time1, i) * ptrans(itraj, time1, i, j) * pout(time1+1, j)
     * * beta(time1+1, j)}, and the matrix is then scaled so its entries sum to 1. With event
     * counting, {@code (dt - 1) / 2 * (alpha(time1, i) + beta(time1+1, i))} is added to each
     * diagonal entry {@code [i][i]} after normalization.
     *
     * @param obs observation sequence; its time stamps are read only when event counting is on
     * @param itraj trajectory index passed to the model's transition probabilities
     * @param time1 start time of the transition; {@code time1 + 1} must be a valid index in {@code hidden}
     * @param hidden hidden-state variables for {@code obs}
     * @return a new matrix with the same dimensions as the accumulated count matrix
     * @throws IllegalStateException if the unnormalized matrix sums to zero, NaN or infinity,
     *         in which case normalizing would silently fill the counts with NaN
     */
    public double[][] baumWelchTransition(IDataSequence obs, int itraj, int time1, IHMMHiddenVariables hidden) {
        final int time2 = time1 + 1;
        final int nStates = hidden.nStates();
        double[][] Ct = new double[this.counts.length][this.counts[0].length];
        for (int i = 0; i < nStates; i++) {
            final double alpha = hidden.getAlpha(time1, i);
            for (int j = 0; j < nStates; j++) {
                Ct[i][j] = alpha * this.model.getPtrans(itraj, time1, i, j)
                        * hidden.getPout(time2, j) * hidden.getBeta(time2, j);
            }
        }
        double norm = PrimitiveDoubleTools.sum(Ct);
        // Without this guard a zero/NaN/infinite norm turns Ct into NaNs, which would then be
        // added to the accumulated counts and corrupt them permanently.
        if (!(norm > 0.0) || Double.isInfinite(norm)) {
            throw new IllegalStateException("Cannot normalize Baum-Welch transition counts at time "
                    + time1 + ": sum of unnormalized terms is " + norm);
        }
        Ct = PrimitiveDoubleTools.multiply(1.0 / norm, Ct);
        if (this.countEvents) {
            double dt = obs.getTime(time2) - obs.getTime(time1);
            // The dt - 1 extra self-transition counts are split evenly between the start and end
            // of the interval.
            double halfExtra = (dt - 1.0) / 2.0;
            // NOTE(review): raw alpha/beta values weight these counts rather than posterior state
            // probabilities; confirm hidden's alpha/beta are scaled such that this is intended.
            for (int i = 0; i < Ct.length; i++) {
                Ct[i][i] += halfExtra * hidden.getAlpha(time1, i) + halfExtra * hidden.getBeta(time2, i);
            }
        }
        return Ct;
    }

    /** Adds the expected transition counts of every consecutive time pair in {@code hidden}. */
    private void addEstimateBaumWelch(IDataSequence obs, int itraj, IHMMHiddenVariables hidden) {
        final int lastStart = hidden.size() - 1;
        for (int t = 0; t < lastStart; t++) {
            this.counts = PrimitiveDoubleTools.add(this.counts, this.baumWelchTransition(obs, itraj, t, hidden));
        }
    }

    /**
     * Returns the accumulated count matrix as an {@link IDoubleArray} created from the internal
     * storage via {@code Doubles.create.array}. Whether that call copies or wraps the internal
     * array is determined by the factory, not by this class.
     *
     * @return the current count estimate
     */
    public IDoubleArray getEstimate() {
        return Doubles.create.array(this.counts);
    }
}

