/*
 * Decompiled with CFR 0.152.
 */
package stallone.hmm;

import stallone.api.datasequence.IDataSequence;
import stallone.api.doubles.Doubles;
import stallone.api.doubles.IDoubleArray;
import stallone.api.hmm.IHMMHiddenVariables;
import stallone.api.ints.IIntArray;
import stallone.api.io.IO;
import stallone.doubles.PrimitiveDoubleTools;

/**
 * Per-time-step storage for the hidden-state quantities of a hidden Markov model
 * over one trajectory: alpha, beta, gamma and output-probability ("pout") matrices,
 * each with rows indexed by time step and columns indexed by hidden state, plus the
 * per-step normalization factors recorded for alpha and beta.
 *
 * <p>Buffers are allocated once for {@code capacity} time steps. The active trajectory
 * length can be reduced with {@link #setLength(int)}, so the same buffers can be reused
 * for shorter trajectories. Only the first {@link #size()} rows are treated as valid.
 *
 * <p>NOTE(review): presumably alpha/beta are the scaled forward/backward variables of
 * the forward-backward algorithm. Gamma is computed here as the row-normalized product
 * {@code alpha * beta}.
 *
 * <p>This class performs no synchronization. The row getters ({@link #getAlpha(int)},
 * {@link #getGamma(int)}, {@link #getPout(int)}) return references to internal rows,
 * not copies.
 */
public class HMMHiddenVariables implements IHMMHiddenVariables {
    /** Number of time steps allocated in every buffer. */
    private final int capacity;
    /** Number of time steps currently in use; validated against capacity by setLength. */
    private int length;
    /** Number of hidden states (columns of every matrix). */
    private final int nstates;
    private final double[][] alpha;
    private final double[][] beta;
    private final double[][] gamma;
    private final double[][] pout;
    /** Row sums of alpha recorded by {@link #normalizeAlpha(int)} before scaling. */
    private final double[] alphanorms;
    /** Row sums of beta recorded by {@link #normalizeBeta(int)} before scaling. */
    private final double[] betanorms;

    /**
     * Allocates zero-filled buffers for {@code ntimesteps} time steps and
     * {@code _nstates} hidden states. The active length starts at the full capacity.
     *
     * @param ntimesteps number of time steps to allocate (the capacity)
     * @param _nstates   number of hidden states
     */
    public HMMHiddenVariables(int ntimesteps, int _nstates) {
        this.capacity = ntimesteps;
        this.length = ntimesteps;
        this.nstates = _nstates;
        this.alpha = new double[this.capacity][this.nstates];
        this.beta = new double[this.capacity][this.nstates];
        this.gamma = new double[this.capacity][this.nstates];
        this.pout = new double[this.capacity][this.nstates];
        this.alphanorms = new double[this.capacity];
        this.betanorms = new double[this.capacity];
    }

    /**
     * Renders the full alpha, beta and gamma buffers (all {@code capacity} rows),
     * tab-separated within a row and newline-separated between rows, with a blank
     * line after each matrix.
     */
    @Override
    public String toString() {
        StringBuilder strbuf = new StringBuilder();
        strbuf.append(PrimitiveDoubleTools.toString(this.alpha, "\t", "\n")).append("\n\n");
        strbuf.append(PrimitiveDoubleTools.toString(this.beta, "\t", "\n")).append("\n\n");
        strbuf.append(PrimitiveDoubleTools.toString(this.gamma, "\t", "\n")).append("\n\n");
        return strbuf.toString();
    }

    /**
     * Sets the number of active time steps.
     *
     * @param l new length; must satisfy {@code 0 <= l <= capacity}, otherwise
     *          {@code IO.util.error} is invoked
     */
    public void setLength(int l) {
        if (l < 0 || l > this.capacity) {
            IO.util.error("Trying to set illegal length " + l + " in Hidden Variables. Capacity is " + this.capacity);
        }
        this.length = l;
    }

    /**
     * Fixes the hidden variables to a deterministic state path. For every active time
     * step {@code t}, alpha, beta and pout are set to the one-hot vector of the path's
     * state; gamma is then recomputed.
     *
     * <p>When {@code observation} is non-null, the trajectory length is
     * {@code observation.size()} and the path is indexed by the rounded observation
     * time {@code round(observation.getTime(t))}. Otherwise the length is
     * {@code path.size()} and the path is indexed by {@code t} directly.
     *
     * @param observation optional observation sequence providing length and time stamps; may be null
     * @param path        hidden state for each (possibly time-mapped) step
     */
    public void setPath(IDataSequence observation, IIntArray path) {
        int n = observation != null ? observation.size() : path.size();
        this.setLength(n);
        for (int t = 0; t < n; t++) {
            int ti = observation != null ? (int) Math.round(observation.getTime(t)) : t;
            // Looked up once per time step rather than once per state.
            int state = path.get(ti);
            for (int s = 0; s < this.nstates; s++) {
                double x = (s == state) ? 1.0 : 0.0;
                this.setAlpha(t, s, x);
                this.setBeta(t, s, x);
                this.setPout(t, s, x);
            }
        }
        this.updateGamma();
    }

    /**
     * Fixes the hidden variables to the given state path, using {@code path.size()}
     * as the trajectory length.
     *
     * @param path hidden state for each time step
     */
    public void setPath(IIntArray path) {
        this.setPath(null, path);
    }

    /** @return the number of time steps allocated */
    public int getCapacity() {
        return this.capacity;
    }

    /** @return the number of active time steps */
    @Override
    public int size() {
        return this.length;
    }

    /** @return the number of hidden states */
    @Override
    public int nStates() {
        return this.nstates;
    }

    /**
     * Reports an error via {@code IO.util.error} if {@code t} lies outside the
     * allocated buffer or beyond the active length.
     */
    private void checkIndex(int t) {
        if (t < 0 || t >= this.capacity) {
            IO.util.error("Accessing illegal time index " + t + " in Hidden Variables. We have " + this.length + " timesteps available.");
        }
        if (t >= this.length) {
            IO.util.error("Accessing illegal time index " + t + " in Hidden Variables. Capacity is " + this.capacity + ", but only " + this.length + " timesteps are accessible in this trajectory.");
        }
    }

    /** Reports an error via {@code IO.util.error} if {@code s} is not a valid state index. */
    private void checkState(int s) {
        if (s < 0 || s >= this.nstates) {
            IO.util.error("Accessing illegal state index " + s + " in Hidden Variables. We have " + this.nstates + " states available.");
        }
    }

    /** Validates a (time, state) index pair: the time index first, then the state index. */
    private void checkIndex(int t, int s) {
        this.checkIndex(t);
        this.checkState(s);
    }

    /** Sets {@code alpha[t][s] = x}. */
    public void setAlpha(int t, int s, double x) {
        this.checkIndex(t, s);
        this.alpha[t][s] = x;
    }

    /** Adds {@code x} to {@code alpha[t][s]}. */
    public void addAlpha(int t, int s, double x) {
        this.checkIndex(t, s);
        this.alpha[t][s] += x;
    }

    /** Sets {@code beta[t][s] = x}. */
    public void setBeta(int t, int s, double x) {
        this.checkIndex(t, s);
        this.beta[t][s] = x;
    }

    /** Adds {@code x} to {@code beta[t][s]}. */
    public void addBeta(int t, int s, double x) {
        this.checkIndex(t, s);
        this.beta[t][s] += x;
    }

    /**
     * Recomputes gamma for every active time step as the element-wise product
     * {@code alpha[t] * beta[t]}, scaled so that each row sums to one.
     *
     * <p>Each gamma row is replaced by a new array. References previously obtained
     * from {@link #getGamma(int)} therefore no longer track the stored values.
     * NOTE(review): a row whose product sums to zero is scaled by {@code 1/0}, which
     * presumably produces NaN entries; no guard is applied.
     */
    public void updateGamma() {
        for (int t = 0; t < this.length; t++) {
            double[] row = this.gamma[t];
            for (int i = 0; i < this.nstates; i++) {
                row[i] = this.alpha[t][i] * this.beta[t][i];
            }
            this.gamma[t] = PrimitiveDoubleTools.multiply(1.0 / PrimitiveDoubleTools.sum(row), row);
        }
    }

    /**
     * @return for each active time step, the index of the state with the largest gamma value
     */
    @Override
    public int[] getMaxPath() {
        int[] traj = new int[this.length];
        for (int t = 0; t < this.length; t++) {
            traj[t] = PrimitiveDoubleTools.maxIndex(this.gamma[t]);
        }
        return traj;
    }

    /** Sets {@code pout[t][s] = x}. */
    public void setPout(int t, int s, double x) {
        this.checkIndex(t, s);
        this.pout[t][s] = x;
    }

    /**
     * Checks whether the output-probability row at time {@code t} is usable.
     *
     * @param t time index; this method does not validate it against the active length
     * @return true iff the row sum of {@code pout[t]} is strictly positive (and hence not NaN)
     */
    public boolean checkPout(int t) {
        double sum = PrimitiveDoubleTools.sum(this.pout[t]);
        // Any comparison with NaN is false, so this rejects both non-positive and NaN sums.
        return sum > 0.0;
    }

    /** @return {@code alpha[t][s]} after validating both indices */
    @Override
    public double getAlpha(int t, int s) {
        this.checkIndex(t, s);
        return this.alpha[t][s];
    }

    /** @return the internal alpha row at time {@code t} (not a copy) */
    public double[] getAlpha(int t) {
        this.checkIndex(t);
        return this.alpha[t];
    }

    /** @return the alpha row sum recorded by the last {@link #normalizeAlpha(int)} call for {@code t} */
    public double getAlphaNorm(int t) {
        this.checkIndex(t);
        return this.alphanorms[t];
    }

    /** @return {@code beta[t][s]} after validating both indices */
    @Override
    public double getBeta(int t, int s) {
        this.checkIndex(t, s);
        return this.beta[t][s];
    }

    /** @return the internal gamma row at time {@code t} (not a copy) */
    public double[] getGamma(int t) {
        this.checkIndex(t);
        return this.gamma[t];
    }

    /** @return {@code gamma[t][s]} after validating both indices */
    @Override
    public double getGamma(int t, int s) {
        this.checkIndex(t, s);
        return this.gamma[t][s];
    }

    /** @return the index of the state with the largest gamma value at time {@code t} */
    @Override
    public int mostProbableState(int t) {
        this.checkIndex(t);
        return PrimitiveDoubleTools.maxIndex(this.gamma[t]);
    }

    /**
     * @return for each state, the sum of its gamma values over all active time steps
     */
    public double[] getTotalStateCounts() {
        double[] res = new double[this.nstates];
        // Row-major traversal; each res[i] still accumulates over t in increasing order.
        for (int t = 0; t < this.length; t++) {
            double[] row = this.gamma[t];
            for (int i = 0; i < this.nstates; i++) {
                res[i] += row[i];
            }
        }
        return res;
    }

    /** @return {@code pout[t][s]} after validating both indices */
    @Override
    public double getPout(int t, int s) {
        this.checkIndex(t, s);
        return this.pout[t][s];
    }

    /** @return the internal pout row at time {@code t} (not a copy) */
    public double[] getPout(int t) {
        this.checkIndex(t);
        return this.pout[t];
    }

    /**
     * Records the sum of alpha row {@code t} in {@code alphanorms[t]} and replaces the
     * row with a copy scaled to sum to one.
     *
     * @throws RuntimeException if the row sum is 0 or NaN
     */
    public void normalizeAlpha(int t) {
        this.checkIndex(t);
        this.alphanorms[t] = PrimitiveDoubleTools.sum(this.alpha[t]);
        if (this.alphanorms[t] == 0.0 || Double.isNaN(this.alphanorms[t])) {
            throw new RuntimeException("sum of Alpha Variables 0 or NaN. Cannot estimate pathway:\nt = " + t + "\n" + "a = " + PrimitiveDoubleTools.toString(this.alpha[t], ", ") + " \n");
        }
        this.alpha[t] = PrimitiveDoubleTools.multiply(1.0 / this.alphanorms[t], this.alpha[t]);
    }

    /**
     * Returns the sum of {@code log(alphanorms[t])} over the active time steps
     * {@code 0 <= t < size()}.
     *
     * <p>Only the active length is used. Slots beyond it hold stale norms from a
     * previous, longer trajectory or zeros, and including them would corrupt the
     * result (for example, produce {@code -Infinity}). This assumes
     * {@link #normalizeAlpha(int)} has been called for every active time step.
     *
     * @return the accumulated log normalization factors
     */
    @Override
    public double logLikelihood() {
        double logL = 0.0;
        for (int t = 0; t < this.length; t++) {
            logL += Math.log(this.alphanorms[t]);
        }
        return logL;
    }

    /**
     * Records the sum of beta row {@code t} in {@code betanorms[t]} and replaces the
     * row with a copy scaled to sum to one.
     *
     * @throws RuntimeException if the row sum is 0 or NaN
     */
    public void normalizeBeta(int t) {
        this.checkIndex(t);
        this.betanorms[t] = PrimitiveDoubleTools.sum(this.beta[t]);
        if (this.betanorms[t] == 0.0 || Double.isNaN(this.betanorms[t])) {
            throw new RuntimeException("sum of Beta Variables 0 or NaN. Cannot estimate pathway");
        }
        this.beta[t] = PrimitiveDoubleTools.multiply(1.0 / this.betanorms[t], this.beta[t]);
    }

    /**
     * Returns the gamma time series of state {@code s} over the active time steps only,
     * consistent with {@link #size()} and {@link #getMaxPath()}.
     *
     * @param s state index
     * @return a new array of length {@code size()} holding {@code gamma[t][s]}
     */
    @Override
    public IDoubleArray getGammaByState(int s) {
        this.checkState(s);
        double[] column = new double[this.length];
        for (int t = 0; t < this.length; t++) {
            column[t] = this.gamma[t][s];
        }
        return Doubles.create.array(column);
    }
}

