/*
 * Decompiled with CFR 0.152.
 */
package stallone.hmm.pmm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import stallone.api.API;
import stallone.api.datasequence.IDataSequence;
import stallone.api.doubles.IDoubleArray;
import stallone.api.hmm.IExpectationMaximization;
import stallone.api.hmm.IHMM;
import stallone.api.hmm.IHMMParameters;
import stallone.api.hmm.ParameterEstimationException;
import stallone.api.ints.IIntArray;
import stallone.doubles.PrimitiveDoubleTools;
import stallone.hmm.pmm.DiscreteTrajectorySimpleDataSequence;

/**
 * Two-stage estimator: first a Markov state model (MSM) is estimated from discrete
 * trajectories and coarse-grained to {@code nhidden} states (reported as "PCCA"), then an
 * expectation-maximization HMM estimation is run, initialized either from that coarse-grained
 * model or from matrices supplied via {@link #setInit(IDoubleArray, IDoubleArray)}.
 *
 * <p>Before {@link #estimate()} the data, the lag time {@code tau}, the sub-sampling
 * {@code timeshift} and the number of hidden states must be set; all three integers must be
 * positive. Progress is reported on {@code System.out}.
 */
public class NinjaEstimator
implements IExpectationMaximization {
    /** Message used when only one of the two user initialization matrices was provided. */
    private static final String PARTIAL_INIT_MESSAGE =
            "ABORTING NINJA: setInit requires both the transition matrix and the output probabilities (or neither).";

    // input data and settings (-1 means "not yet set")
    private List<IIntArray> dtraj;
    private int tau = -1;
    private int timeshift = -1;
    private int nhidden = -1;
    private double maxHMMLInc = -0.01;
    private int nIterHMMMax = 1000;
    // optional user-defined HMM initialization; never overwritten by the estimator itself
    private IDoubleArray initChi = null;
    private IDoubleArray initTC = null;
    // MSM / PCCA results
    private IDoubleArray msmC;
    private IDoubleArray msmT;
    private IDoubleArray msmpi;
    private IDoubleArray msmPi;
    private IDoubleArray msmCorr;
    private IDoubleArray msmChi;
    private IDoubleArray msmTC;
    private IDoubleArray msmTimescales;
    // HMM results
    private IHMM hmmEst;
    private IDoubleArray hmmChi;
    private IDoubleArray hmmTC;
    private IDoubleArray hmmpiC;
    private IDoubleArray hmmTimescales;
    private double[] hmmLikelihoodHistory;

    /**
     * Creates an estimator for the given discrete trajectories; tau, timeshift and the number
     * of hidden states still have to be set.
     */
    public NinjaEstimator(List<IIntArray> _dtraj) {
        this.dtraj = _dtraj;
    }

    /**
     * Creates an estimator with the given settings; the data still has to be set.
     *
     * @param _tau lag time
     * @param _timeshift offset step between the sub-sampled trajectories (must be positive)
     * @param _nhidden number of hidden states
     */
    public NinjaEstimator(int _tau, int _timeshift, int _nhidden) {
        this.tau = _tau;
        this.timeshift = _timeshift;
        this.nhidden = _nhidden;
    }

    /** Sets the discrete trajectories to estimate from. */
    public void setData(List<IIntArray> _dtraj) {
        this.dtraj = _dtraj;
    }

    /** Sets the lag time (must be positive before estimation). */
    public void setTau(int _tau) {
        this.tau = _tau;
    }

    /** Sets the offset step between sub-sampled trajectories (must be positive before estimation). */
    public void setTimeshift(int _timeshift) {
        this.timeshift = _timeshift;
    }

    /** Sets the number of hidden states (must be positive before estimation). */
    public void setNHiddenStates(int _nhidden) {
        this.nhidden = _nhidden;
    }

    /** Same as {@link #setData(List)}. */
    public void setDiscreteTrajectory(List<IIntArray> _dtraj) {
        this.dtraj = _dtraj;
    }

    /** Sets the maximum number of HMM EM iterations. */
    public void setNIterHMMMax(int _nIterHMMMax) {
        this.nIterHMMMax = _nIterHMMMax;
    }

    /** Sets the value passed to the EM algorithm as likelihood decrease tolerance. */
    public void setHMMLikelihoodMaxIncrease(double _maxHMMLInc) {
        this.maxHMMLInc = _maxHMMLInc;
    }

    /**
     * Sets a user-defined HMM initialization. Pass both matrices to use them, or both
     * {@code null} to fall back to the PCCA results of the MSM stage.
     *
     * @param TCInit initial hidden transition matrix
     * @param ChiInit initial output probabilities; column {@code i} is used as the output
     *     distribution of hidden state {@code i}
     */
    public void setInit(IDoubleArray TCInit, IDoubleArray ChiInit) {
        this.initTC = TCInit;
        this.initChi = ChiInit;
    }

    /** @return the MSM transition matrix, or {@code null} before {@link #estimateMSM()} */
    public IDoubleArray getMSMTransitionMatrix() {
        return this.msmT;
    }

    /** @return the MSM stationary distribution, or {@code null} before {@link #estimateMSM()} */
    public IDoubleArray getMSMStationaryDistribution() {
        return this.msmpi;
    }

    /** @return the MSM timescales, or {@code null} before {@link #estimateMSM()} */
    public IDoubleArray getMSMTimescales() {
        return this.msmTimescales;
    }

    /** @return the coarse-grained (PCCA) transition matrix, or {@code null} before {@link #estimateMSM()} */
    public IDoubleArray getPCCATransitionMatrix() {
        return this.msmTC;
    }

    /** @return the coarse-grained (PCCA) chi matrix, or {@code null} before {@link #estimateMSM()} */
    public IDoubleArray getPCCAOutputProbabilities() {
        return this.msmChi;
    }

    /** @return the estimated hidden transition matrix, or {@code null} before {@link #estimateHMM()} */
    public IDoubleArray getHMMTransitionMatrix() {
        return this.hmmTC;
    }

    /** @return the stationary distribution of the hidden transition matrix, or {@code null} before estimation */
    public IDoubleArray getHMMStationaryDistribution() {
        return this.hmmpiC;
    }

    /** @return output probabilities as (observable states x hidden states), or {@code null} before estimation */
    public IDoubleArray getHMMOutputProbabilities() {
        return this.hmmChi;
    }

    /** @return timescales of the hidden transition matrix, or {@code null} before estimation */
    public IDoubleArray getHMMTimescales() {
        return this.hmmTimescales;
    }

    /** @return the log-likelihood history of the last EM run, or {@code null} before estimation */
    public double[] getHMMLikelihoodHistory() {
        return this.hmmLikelihoodHistory;
    }

    /**
     * Splits one trajectory into sub-sampled sequences taken at lag {@code tau}, one for every
     * start offset 0, dt, 2*dt, ... below {@code tau}.
     *
     * @throws IllegalArgumentException if {@code dt <= 0}, since the start offset would never
     *     advance and the loop would not terminate
     */
    private List<IIntArray> subsamples(IIntArray traj, int dt) {
        if (dt <= 0) {
            throw new IllegalArgumentException("ABORTING NINJA: time shift must be positive, but is " + dt);
        }
        int n = traj.size();
        List<IIntArray> res = new ArrayList<IIntArray>();
        for (int start = 0; start < this.tau; start += dt) {
            // presumably the indexes start, start + tau, ... below n — arrayRange(start, end, step)
            IIntArray indexes = API.intsNew.arrayRange(start, n, this.tau);
            res.add(API.ints.subToNew(traj, indexes));
        }
        return res;
    }

    /** Applies {@link #subsamples(IIntArray, int)} to every trajectory and concatenates the results. */
    private List<IIntArray> subsamples(List<IIntArray> trajs, int dt) {
        List<IIntArray> res = new ArrayList<IIntArray>();
        for (IIntArray traj : trajs) {
            res.addAll(this.subsamples(traj, dt));
        }
        return res;
    }

    /** Wraps each discrete trajectory into a data sequence usable by the HMM estimator. */
    private List<IDataSequence> toObservation(List<IIntArray> dTrajectories) {
        List<IDataSequence> obs = new ArrayList<IDataSequence>(dTrajectories.size());
        for (IIntArray traj : dTrajectories) {
            obs.add(new DiscreteTrajectorySimpleDataSequence(traj));
        }
        return obs;
    }

    /**
     * Runs the discrete-output EM estimation starting from the given matrices and stores the
     * resulting HMM and its log-likelihood history.
     *
     * @param observations sequences to estimate from
     * @param TCInit initial hidden transition matrix (copied into a new dense array)
     * @param ChiInit initial output probabilities, one column per hidden state
     * @return the parameters of the estimated HMM
     */
    private IHMMParameters hmm(List<IDataSequence> observations, IDoubleArray TCInit, IDoubleArray ChiInit) throws ParameterEstimationException {
        System.out.println("HMM initialized with timescales: " + API.msm.timescales(TCInit, this.tau));
        IHMMParameters par0 = API.hmmNew.parameters(this.nhidden, true, true);
        // copy so that the EM run never operates on the caller's (or PCCA's) matrix instance
        IDoubleArray dense = API.doublesNew.array(TCInit.rows(), TCInit.columns());
        TCInit.copyInto(dense);
        par0.setTransitionMatrix(dense);
        for (int state = 0; state < this.nhidden; state++) {
            par0.setOutputParameters(state, ChiInit.viewColumn(state));
        }
        int nObservableStates = ChiInit.rows();
        double[] uniformPrior = new double[nObservableStates];
        Arrays.fill(uniformPrior, 1.0 / (double)uniformPrior.length);
        IExpectationMaximization em = API.hmmNew.emDiscrete(observations, par0, uniformPrior);
        em.setMaximumNumberOfStep(this.nIterHMMMax);
        em.setLikelihoodDecreaseTolerance(this.maxHMMLInc);
        String firstLength = observations.isEmpty() ? "0" : String.valueOf(observations.get(0).size());
        System.out.println(" running hmm on " + observations.size() + " x " + firstLength + " observations with maxIter " + this.nIterHMMMax);
        em.run();
        this.hmmLikelihoodHistory = em.getLogLikelihoodHistory();
        System.out.println(" hmm iterations: " + this.hmmLikelihoodHistory.length);
        System.out.println(" likelihood history: " + PrimitiveDoubleTools.toString(this.hmmLikelihoodHistory, "\n"));
        this.hmmEst = em.getHMM();
        return this.hmmEst.getParameters();
    }

    /**
     * Collects the output distributions of all hidden states into one matrix.
     *
     * @return matrix of shape (observable states x hidden states); column {@code i} holds the
     *     output parameters of hidden state {@code i}
     */
    private IDoubleArray getChi(IHMMParameters par) {
        int nObservableStates = par.getOutputParameters(0).size();
        IDoubleArray res = API.doublesNew.array(nObservableStates, this.nhidden);
        for (int state = 0; state < this.nhidden; state++) {
            IDoubleArray pout = par.getOutputParameters(state);
            for (int k = 0; k < nObservableStates; k++) {
                res.set(k, state, pout.get(k));
            }
        }
        return res;
    }

    /**
     * Picks the HMM starting point: the user matrices from {@link #setInit} when both are set,
     * otherwise the PCCA results of the last {@link #estimateMSM()} call. The init fields are
     * only read, so a later run re-derives the PCCA initialization instead of reusing a stale one.
     *
     * @return {@code {TCInit, ChiInit}}
     * @throws IllegalArgumentException if only one user matrix was set
     * @throws IllegalStateException if no user matrices were set and no MSM has been estimated
     */
    private IDoubleArray[] resolveHMMInitialization() {
        if (this.initTC != null && this.initChi != null) {
            return new IDoubleArray[]{this.initTC, this.initChi};
        }
        if (this.initTC != null || this.initChi != null) {
            throw new IllegalArgumentException(PARTIAL_INIT_MESSAGE);
        }
        if (this.msmTC == null || this.msmChi == null) {
            throw new IllegalStateException("ABORTING NINJA: no HMM initialization available. Call setInit(...) or estimateMSM() first.");
        }
        return new IDoubleArray[]{this.msmTC, this.msmChi};
    }

    /**
     * Throws if {@code value} is not a positive setting. A negative value is reported as
     * "not yet set" because -1 is the field default.
     */
    private static void requirePositive(int value, String what) {
        if (value < 0) {
            throw new IllegalArgumentException("ABORTING NINJA: " + what + " not yet set.");
        }
        if (value == 0) {
            throw new IllegalArgumentException("ABORTING NINJA: " + what + " must be positive.");
        }
    }

    /**
     * Estimates the count matrix, reversible transition matrix, stationary distribution,
     * timescales and the {@code nhidden}-state coarse-graining (PCCA) at lag {@code tau},
     * printing a summary.
     */
    public void estimateMSM() {
        this.msmC = API.msm.estimateC(this.dtraj, this.tau);
        this.msmT = API.msm.estimateTrev(this.msmC);
        this.msmpi = API.msm.stationaryDistribution(this.msmT);
        this.msmPi = API.doublesNew.diag(this.msmpi);
        this.msmCorr = API.alg.product(this.msmPi, this.msmT);
        this.msmTimescales = API.msm.timescales(this.msmT, this.tau);
        System.out.println("MSM timescales: \n" + API.doubles.subToNew(this.msmTimescales, 0, Math.min(5, this.msmTimescales.size())));
        IDoubleArray[] cg = API.msm.coarseGrain(this.msmT, this.nhidden);
        this.msmTC = cg[0];
        this.msmChi = cg[1];
        System.out.println("PCCA chi: ");
        System.out.println(this.msmChi);
        System.out.println("PCCA TC: ");
        System.out.println(this.msmTC);
        System.out.println("PCCA TC timescales: ");
        System.out.println(API.msm.timescales(this.msmTC, this.tau));
    }

    /**
     * Estimates the HMM from the sub-sampled trajectories, initialized from the user matrices
     * if set, otherwise from the last PCCA result.
     *
     * @throws IllegalArgumentException if only one user init matrix is set, or the time shift
     *     is not positive
     * @throws IllegalStateException if no initialization is available
     * @throws ParameterEstimationException if the EM estimation fails
     */
    public void estimateHMM() throws ParameterEstimationException {
        IDoubleArray[] init = this.resolveHMMInitialization();
        List<IIntArray> dTrajectories = this.subsamples(this.dtraj, this.timeshift);
        List<IDataSequence> observations = this.toObservation(dTrajectories);
        IHMMParameters hmmParameters = this.hmm(observations, init[0], init[1]);
        this.hmmTC = hmmParameters.getTransitionMatrix();
        this.hmmpiC = API.msm.stationaryDistribution(this.hmmTC);
        this.hmmChi = this.getChi(hmmParameters);
        this.hmmTimescales = API.msm.timescales(this.hmmTC, this.tau);
    }

    /**
     * Validates the settings, then runs {@link #estimateMSM()} followed by
     * {@link #estimateHMM()}, printing progress.
     *
     * @throws IllegalArgumentException if a setting is missing or not positive, no data is
     *     set, or only one user init matrix is set
     * @throws ParameterEstimationException if the EM estimation fails
     */
    public void estimate() throws ParameterEstimationException {
        requirePositive(this.nhidden, "number of hidden states");
        requirePositive(this.tau, "tau");
        requirePositive(this.timeshift, "time shift");
        if (this.dtraj == null || this.dtraj.isEmpty()) {
            throw new IllegalArgumentException("ABORTING NINJA: no discrete trajectories set.");
        }
        boolean userInit = this.initTC != null && this.initChi != null;
        if (!userInit && (this.initTC != null || this.initChi != null)) {
            throw new IllegalArgumentException(PARTIAL_INIT_MESSAGE);
        }
        System.out.println("====================================================");
        System.out.println(" Running Ninja estimation with tau = " + this.tau);
        System.out.println("====================================================");
        System.out.println();
        System.out.println("----------------------------------------------------");
        System.out.println(" MSM reference estimation");
        this.estimateMSM();
        if (!userInit) {
            // estimateHMM() falls back to the PCCA result just computed
            System.out.println("HMM initialization: Using PCCA results");
        } else {
            System.out.println("HMM initialization: Using user-defined matrices:");
            System.out.println(" init chi: ");
            System.out.println(this.initChi);
            System.out.println(" init TC: ");
            System.out.println(this.initTC);
        }
        System.out.println("----------------------------------------------------");
        System.out.println();
        System.out.println("----------------------------------------------------");
        System.out.println(" HMM estimation");
        this.estimateHMM();
        System.out.println("chi: \n" + this.hmmChi);
        System.out.println("TC: \n" + this.hmmTC);
        System.out.println("HMM timescales: \n" + this.hmmTimescales);
        System.out.println("----------------------------------------------------");
        System.out.println();
    }

    /** Not supported: initialization comes from the MSM stage or {@link #setInit}. */
    @Override
    public void setInitialParameters(IHMMParameters par) {
        throw new UnsupportedOperationException("Not supported. Initialization is done by MSM");
    }

    /** Not supported: initialization comes from the MSM stage or {@link #setInit}. */
    @Override
    public void setInitialPaths(List<IIntArray> paths) {
        throw new UnsupportedOperationException("Not supported. Initialization is done by MSM");
    }

    /** Same as {@link #setNIterHMMMax(int)}. */
    @Override
    public void setMaximumNumberOfStep(int nsteps) {
        this.nIterHMMMax = nsteps;
    }

    /** Same as {@link #setHMMLikelihoodMaxIncrease(double)}. */
    @Override
    public void setLikelihoodDecreaseTolerance(double _dectol) {
        this.maxHMMLInc = _dectol;
    }

    /** Same as {@link #getHMMLikelihoodHistory()}. */
    @Override
    public double[] getLogLikelihoodHistory() {
        return this.hmmLikelihoodHistory;
    }

    /** Same as {@link #estimate()}. */
    @Override
    public void run() throws ParameterEstimationException {
        this.estimate();
    }

    /** @return the estimated HMM, or {@code null} before {@link #estimateHMM()} */
    @Override
    public IHMM getHMM() {
        return this.hmmEst;
    }
}

