/*
 * Decompiled with CFR 0.152.
 */
package stallone.api.mc;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import stallone.api.algebra.Algebra;
import stallone.api.algebra.IEigenvalueDecomposition;
import stallone.api.doubles.Doubles;
import stallone.api.doubles.IDoubleArray;
import stallone.api.doubles.IDoubleIterator;
import stallone.api.ints.IIntArray;
import stallone.api.ints.Ints;
import stallone.api.mc.ICountMatrixEstimator;
import stallone.api.mc.IDynamicalExpectations;
import stallone.api.mc.IDynamicalExpectationsSpectral;
import stallone.api.mc.ITransitionMatrixEstimator;
import stallone.api.mc.MarkovModel;
import stallone.api.mc.tpt.ICommittor;
import stallone.cluster.MilestoningFilter;
import stallone.graph.MatrixIntGraph;
import stallone.graph.connectivity.IntStrongConnectivity;
import stallone.mc.MarkovChain;
import stallone.mc.StationaryDistribution;
import stallone.mc.estimator.TransitionMatrixLikelihood;
import stallone.mc.pcca.MetastableSubspace;
import stallone.mc.pcca.PCCA;

public class MarkovModelUtilities {
    /**
     * Checks whether the graph induced by {@code P} has exactly one strongly connected component.
     *
     * @param P matrix whose sparsity pattern defines the graph (presumably an edge i->j for each
     *          nonzero P(i,j) - see {@code MatrixIntGraph})
     * @return {@code true} if exactly one strong component is found
     */
    public boolean isConnected(IDoubleArray P) {
        return this.connectedComponents(P).size() == 1;
    }

    /**
     * Computes the strongly connected components of the graph induced by {@code P}.
     *
     * @param P matrix whose sparsity pattern defines the graph
     * @return the strong components, each as an array of state indices
     */
    public List<IIntArray> connectedComponents(IDoubleArray P) {
        MatrixIntGraph graph = new MatrixIntGraph(P);
        IntStrongConnectivity connectivity = new IntStrongConnectivity(graph);
        connectivity.perform();
        return connectivity.getStrongComponents();
    }

    /**
     * Returns the largest strongly connected component of the graph induced by {@code P}.
     * On ties the component listed first wins.
     *
     * @param P matrix whose sparsity pattern defines the graph
     * @return the state indices of the largest strong component
     * @throws IndexOutOfBoundsException if no component is found
     */
    public IIntArray giantComponent(IDoubleArray P) {
        List<IIntArray> components = this.connectedComponents(P);
        IIntArray largest = components.get(0);
        int largestSize = largest.size();
        for (IIntArray component : components) {
            // Strict '>' keeps the earliest component among equally sized ones.
            if (component.size() > largestSize) {
                largest = component;
                largestSize = component.size();
            }
        }
        return largest;
    }

    /**
     * Estimates a count matrix from several trajectories using the sliding-window estimator.
     *
     * @param trajs discrete trajectories
     * @param lag   lag time in steps
     * @return the estimated count matrix
     */
    public IDoubleArray estimateC(Iterable<IIntArray> trajs, int lag) {
        ICountMatrixEstimator est = MarkovModel.create.createCountMatrixEstimatorSliding(trajs, lag);
        return est.estimate();
    }

    /**
     * Milestones each trajectory against the given cores, then estimates a sliding-window
     * count matrix from the filtered trajectories.
     *
     * @param trajs discrete trajectories
     * @param cores core sets handed to {@link MilestoningFilter}
     * @param lag   lag time in steps
     * @return the estimated count matrix
     */
    public IDoubleArray estimateCmilestoning(Iterable<IIntArray> trajs, Iterable<IIntArray> cores, int lag) {
        MilestoningFilter filter = new MilestoningFilter(cores);
        List<IIntArray> filtered = new ArrayList<IIntArray>();
        for (IIntArray traj : trajs) {
            filtered.add(filter.filter(traj));
        }
        return this.estimateC(filtered, lag);
    }

    /**
     * Milestoning count estimate in which every state 0..max (max over all trajectories)
     * forms its own singleton core.
     *
     * @param trajs discrete trajectories
     * @param lag   lag time in steps
     * @return the estimated count matrix
     */
    public IDoubleArray estimateCmilestoning(Iterable<IIntArray> trajs, int lag) {
        int max = 0;
        for (IIntArray traj : trajs) {
            max = Math.max(max, Ints.util.max(traj));
        }
        return this.estimateCmilestoning(trajs, singletonCores(max), lag);
    }

    /**
     * Estimates a count matrix from a single trajectory using the sliding-window estimator.
     *
     * @param traj discrete trajectory
     * @param lag  lag time in steps
     * @return the estimated count matrix
     */
    public IDoubleArray estimateC(IIntArray traj, int lag) {
        ICountMatrixEstimator est = MarkovModel.create.createCountMatrixEstimatorSliding(traj, lag);
        return est.estimate();
    }

    /**
     * Milestones a single trajectory against the given cores, then estimates a sliding-window
     * count matrix from it.
     *
     * @param traj  discrete trajectory
     * @param cores core sets handed to {@link MilestoningFilter}
     * @param lag   lag time in steps
     * @return the estimated count matrix
     */
    public IDoubleArray estimateCmilestoning(IIntArray traj, Iterable<IIntArray> cores, int lag) {
        MilestoningFilter filter = new MilestoningFilter(cores);
        return this.estimateC(filter.filter(traj), lag);
    }

    /**
     * Milestoning count estimate for one trajectory in which every state 0..max(traj) forms
     * its own singleton core.
     *
     * @param traj discrete trajectory
     * @param lag  lag time in steps
     * @return the estimated count matrix
     */
    public IDoubleArray estimateCmilestoning(IIntArray traj, int lag) {
        return this.estimateCmilestoning(traj, singletonCores(Ints.util.max(traj)), lag);
    }

    /**
     * Builds one singleton core {state} for every state in 0..maxState (inclusive).
     * Yields an empty list if maxState is negative.
     */
    private static List<IIntArray> singletonCores(int maxState) {
        List<IIntArray> cores = new ArrayList<IIntArray>();
        for (int state = 0; state <= maxState; state++) {
            cores.add(Ints.create.arrayFrom(state));
        }
        return cores;
    }

    /**
     * Estimates a count matrix from several trajectories using the stepping estimator.
     *
     * @param trajs discrete trajectories
     * @param lag   lag time in steps
     * @return the estimated count matrix
     */
    public IDoubleArray estimateCstepping(Iterable<IIntArray> trajs, int lag) {
        ICountMatrixEstimator est = MarkovModel.create.createCountMatrixEstimatorStepping(trajs, lag);
        return est.estimate();
    }

    /**
     * Estimates a count matrix from a single trajectory using the stepping estimator.
     *
     * @param traj discrete trajectory
     * @param lag  lag time in steps
     * @return the estimated count matrix
     */
    public IDoubleArray estimateCstepping(IIntArray traj, int lag) {
        ICountMatrixEstimator est = MarkovModel.create.createCountMatrixEstimatorStepping(traj, lag);
        return est.estimate();
    }

    /** Delegates to {@link TransitionMatrixLikelihood#logLikelihood}. */
    public double logLikelihood(IDoubleArray T, IDoubleArray C) {
        return TransitionMatrixLikelihood.logLikelihood(T, C);
    }

    /** Delegates to {@link TransitionMatrixLikelihood#logLikelihoodCorrelationMatrix}. */
    public double logLikelihoodCorrelationMatrix(IDoubleArray corr, IDoubleArray C) {
        return TransitionMatrixLikelihood.logLikelihoodCorrelationMatrix(corr, C);
    }

    /**
     * Checks that every nonzero entry of {@code T} lies in [0, 1] and that every row sums to 1
     * within 1e-6. The first violation found is reported on {@code System.out}.
     *
     * @param T candidate transition matrix
     * @return {@code true} if both conditions hold
     */
    public boolean isTransitionMatrix(IDoubleArray T) {
        for (IDoubleIterator it = T.nonzeroIterator(); it.hasNext(); it.advance()) {
            double Tij = it.get();
            if (Tij < 0.0 || Tij > 1.0) {
                System.out.println("Invalid Element: " + it.row() + ", " + it.column());
                return false;
            }
        }
        for (int i = 0; i < T.rows(); i++) {
            double diff = Math.abs(Doubles.util.sum(T.viewRow(i)) - 1.0);
            if (diff > 1.0E-6) {
                System.out.println("Invalid Row sum difference from 1 in row " + i + ": " + diff);
                return false;
            }
        }
        return true;
    }

    /**
     * Checks the rate-matrix sign pattern on nonzero entries and the row sums. Diagonal entries
     * must be <= 0, off-diagonal entries must be >= 0, and every row must sum to 0 within 1e-6.
     *
     * @param K candidate rate matrix
     * @return {@code true} if all conditions hold
     */
    public boolean isRateMatrix(IDoubleArray K) {
        for (IDoubleIterator it = K.nonzeroIterator(); it.hasNext(); it.advance()) {
            int row = it.row();
            int col = it.column();
            double kij = it.get();
            if (row == col && kij > 0.0) {
                return false;
            }
            if (row != col && kij < 0.0) {
                return false;
            }
        }
        for (int i = 0; i < K.rows(); i++) {
            if (Math.abs(Doubles.util.sum(K.viewRow(i))) > 1.0E-6) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks detailed balance of {@code T} against its own stationary distribution.
     *
     * @param T transition matrix
     * @return {@code true} if detailed balance holds within tolerance
     */
    public boolean isReversible(IDoubleArray T) {
        return this.isReversible(T, this.stationaryDistribution(T));
    }

    /**
     * Checks detailed balance pi_i T_ij == pi_j T_ji for every nonzero T_ij. The check uses a
     * relative tolerance of 1e-6 and skips pairs whose total flux fij + fji is at most 1e-10.
     *
     * @param T  transition matrix
     * @param pi distribution to test against
     * @return {@code true} if no pair violates detailed balance
     */
    public boolean isReversible(IDoubleArray T, IDoubleArray pi) {
        for (IDoubleIterator it = T.nonzeroIterator(); it.hasNext(); it.advance()) {
            int i = it.row();
            int j = it.column();
            double fij = pi.get(i) * it.get();
            double fji = pi.get(j) * T.get(j, i);
            double total = fij + fji;
            if (total > 1.0E-10 && Math.abs((fij - fji) / total) > 1.0E-6) {
                return false;
            }
        }
        return true;
    }

    /**
     * Estimates a transition matrix from counts with the nonreversible estimator.
     *
     * @param counts count matrix
     * @return the estimated transition matrix
     */
    public IDoubleArray estimateT(IDoubleArray counts) {
        ITransitionMatrixEstimator est = MarkovModel.create.createTransitionMatrixEstimatorNonrev();
        est.setCounts(counts);
        est.estimate();
        return est.getTransitionMatrix();
    }

    /**
     * Estimates a transition matrix from counts with the reversible estimator.
     *
     * @param counts count matrix
     * @return the estimated transition matrix
     */
    public IDoubleArray estimateTrev(IDoubleArray counts) {
        ITransitionMatrixEstimator est = MarkovModel.create.createTransitionMatrixEstimatorRev();
        est.setCounts(counts);
        est.estimate();
        return est.getTransitionMatrix();
    }

    /**
     * Estimates a transition matrix from counts with the reversible estimator, configured with
     * the given stationary distribution.
     *
     * @param counts  count matrix
     * @param piFixed stationary distribution passed to the estimator factory
     * @return the estimated transition matrix
     */
    public IDoubleArray estimateTrev(IDoubleArray counts, IDoubleArray piFixed) {
        ITransitionMatrixEstimator est = MarkovModel.create.createTransitionMatrixEstimatorRev(piFixed);
        est.setCounts(counts);
        est.estimate();
        return est.getTransitionMatrix();
    }

    /** Delegates to {@link StationaryDistribution#calculate}. */
    public IDoubleArray stationaryDistribution(IDoubleArray T) {
        return StationaryDistribution.calculate(T);
    }

    /** Delegates to {@link StationaryDistribution#calculateReversible}. */
    public IDoubleArray stationaryDistributionRevQuick(IDoubleArray T) {
        return StationaryDistribution.calculateReversible(T);
    }

    /**
     * Computes implied timescales t_i = -tau / ln|lambda_{i+1}| for all but the first entry of
     * the eigenvalue magnitudes. This assumes {@code getEvalNorm()} lists the stationary
     * eigenvalue first; that assumption is implied by the indexing but should be confirmed.
     *
     * @param T   transition matrix
     * @param tau lag time the matrix was estimated at
     * @return n-1 timescales; eigenvalue magnitudes >= 1 yield {@code +Infinity}
     * @throws IllegalArgumentException if {@code T} fails {@link #isTransitionMatrix}
     */
    public IDoubleArray timescales(IDoubleArray T, double tau) {
        if (!this.isTransitionMatrix(T)) {
            throw new IllegalArgumentException("Trying to calculate timescales of a matrix that is not a transition matrix");
        }
        IEigenvalueDecomposition evd = Algebra.util.evd(T);
        IDoubleArray ev = evd.getEvalNorm();
        IDoubleArray timescales = Doubles.create.array(ev.size() - 1);
        for (int i = 0; i < timescales.size(); i++) {
            double lambda = ev.get(i + 1);
            // Math.log(1.0) is +0.0, so -tau / log(1.0) would be -Infinity. A further unit
            // eigenvalue (e.g. a disconnected chain), or one rounded slightly above 1, means
            // an infinitely slow process, so report +Infinity instead.
            timescales.set(i, lambda >= 1.0 ? Double.POSITIVE_INFINITY : -tau / Math.log(lambda));
        }
        return timescales;
    }

    /**
     * Computes implied timescales as a function of lag time. For each lag in {@code lagtimes},
     * counts are estimated with {@code Cest}, a transition matrix is estimated with {@code Test},
     * and its timescales form one row of the result.
     *
     * @param dtraj       discrete trajectories; added to {@code Cest} once per call
     * @param Cest        count matrix estimator (its lag is overwritten per row)
     * @param Test        transition matrix estimator
     * @param ntimescales maximum number of timescales kept per row; values <= 0 keep all
     * @param lagtimes    lag times, one result row each
     * @return matrix with one row of timescales per lag time
     */
    public IDoubleArray timescales(Iterable<IIntArray> dtraj, ICountMatrixEstimator Cest, ITransitionMatrixEstimator Test, int ntimescales, IIntArray lagtimes) {
        double[][] res = new double[lagtimes.size()][];
        // NOTE(review): this mutates the caller's estimator. Reusing the same Cest across calls
        // presumably accumulates dtraj repeatedly - confirm against ICountMatrixEstimator.
        Cest.addInput(dtraj);
        for (int i = 0; i < lagtimes.size(); i++) {
            int tau = lagtimes.get(i);
            Cest.setLag(tau);
            IDoubleArray C = Cest.estimate();
            Test.setCounts(C);
            Test.estimate();
            IDoubleArray T = Test.getTransitionMatrix();
            double[] all = this.timescales(T, tau).getArray();
            // Honor the requested number of timescales (previously ignored).
            res[i] = (ntimescales > 0 && ntimescales < all.length) ? Arrays.copyOf(all, ntimescales) : all;
        }
        return Doubles.create.matrix(res);
    }

    /**
     * Assigns each state to one of {@code nstates} metastable sets via PCCA.
     *
     * @param M       transition or rate matrix handed to PCCA
     * @param nstates number of metastable sets
     * @return crisp cluster assignment from {@code PCCA.getClusters()}
     */
    public IIntArray metastableStates(IDoubleArray M, int nstates) {
        PCCA pcca = MarkovModel.create.createPCCA(M, nstates);
        return pcca.getClusters();
    }

    /**
     * Computes fuzzy memberships of each state to {@code nstates} metastable sets via PCCA.
     *
     * @param M       transition or rate matrix handed to PCCA
     * @param nstates number of metastable sets
     * @return memberships from {@code PCCA.getFuzzy()}
     */
    public IDoubleArray metastableMemberships(IDoubleArray M, int nstates) {
        PCCA pcca = MarkovModel.create.createPCCA(M, nstates);
        return pcca.getFuzzy();
    }

    /**
     * Coarse-grains {@code T} onto {@code nstates} states using {@link MetastableSubspace}.
     *
     * @param T       transition matrix
     * @param nstates number of coarse states
     * @return {@code [coarse-grained transition matrix, observation probabilities]}
     */
    public IDoubleArray[] coarseGrain(IDoubleArray T, int nstates) {
        MetastableSubspace subspace = new MetastableSubspace(T);
        subspace.coarseGrain(nstates);
        return new IDoubleArray[]{subspace.getCoarseGrainedTransitionMatrix(), subspace.getObservationProbabilities()};
    }

    /**
     * Computes the forward committor between sets {@code A} and {@code B}.
     *
     * @param M transition or rate matrix
     * @param A source set
     * @param B target set
     * @return forward committor values
     */
    public IDoubleArray forwardCommittor(IDoubleArray M, IIntArray A, IIntArray B) {
        ICommittor comm = MarkovModel.create.createCommittor(M, A, B);
        return comm.forwardCommittor();
    }

    /**
     * Computes the backward committor between sets {@code A} and {@code B}.
     *
     * @param M transition or rate matrix
     * @param A source set
     * @param B target set
     * @return backward committor values
     */
    public IDoubleArray backwardCommittor(IDoubleArray M, IIntArray A, IIntArray B) {
        ICommittor comm = MarkovModel.create.createCommittor(M, A, B);
        return comm.backwardCommittor();
    }

    /**
     * Autocorrelation of {@code observable}: the same as
     * {@link #correlation(IDoubleArray, IDoubleArray, IDoubleArray, IDoubleArray)} with both
     * observables equal.
     */
    public IDoubleArray autocorrelation(IDoubleArray M, IDoubleArray observable, IDoubleArray timepoints) {
        return this.correlation(M, observable, observable, timepoints);
    }

    /**
     * Evaluates the correlation of two observables at each of the given time points.
     *
     * @param M           transition or rate matrix
     * @param observable1 first observable
     * @param observable2 second observable
     * @param timepoints  time points to evaluate at
     * @return one correlation value per time point, in order
     */
    public IDoubleArray correlation(IDoubleArray M, IDoubleArray observable1, IDoubleArray observable2, IDoubleArray timepoints) {
        IDynamicalExpectations dexp = MarkovModel.create.createDynamicalExpectations(M);
        IDoubleArray res = Doubles.create.array(timepoints.size());
        for (int i = 0; i < res.size(); i++) {
            res.set(i, dexp.calculateCorrelation(observable1, observable2, timepoints.get(i)));
        }
        return res;
    }

    /**
     * Evaluates the expectation of {@code observable} after starting from {@code pi0}, at each
     * of the given time points.
     *
     * @param M          transition or rate matrix
     * @param pi0        initial distribution
     * @param observable observable
     * @param timepoints time points to evaluate at
     * @return one expectation value per time point, in order
     */
    public IDoubleArray perturbationExpectation(IDoubleArray M, IDoubleArray pi0, IDoubleArray observable, IDoubleArray timepoints) {
        IDynamicalExpectations dexp = MarkovModel.create.createDynamicalExpectations(M);
        IDoubleArray res = Doubles.create.array(timepoints.size());
        for (int i = 0; i < res.size(); i++) {
            res.set(i, dexp.calculatePerturbationExpectation(pi0, observable, timepoints.get(i)));
        }
        return res;
    }

    /**
     * Spectral fingerprint of the autocorrelation of {@code observable}: the same as
     * {@link #fingerprintCorrelation} with both observables equal.
     */
    public IDoubleArray fingerprintAutocorrelation(IDoubleArray M, IDoubleArray observable) {
        return this.fingerprintCorrelation(M, observable, observable);
    }

    /**
     * Computes the spectral fingerprint of the correlation of two observables.
     *
     * @return matrix with the fingerprint timescales and amplitudes merged as columns
     */
    public IDoubleArray fingerprintCorrelation(IDoubleArray M, IDoubleArray observable1, IDoubleArray observable2) {
        IDynamicalExpectationsSpectral dexp = MarkovModel.create.createDynamicalFingerprint(M);
        dexp.calculateCorrelation(observable1, observable2);
        return Doubles.util.mergeColumns(dexp.getTimescales(), dexp.getAmplitudes());
    }

    /**
     * Computes the spectral fingerprint of a perturbation (relaxation) expectation.
     *
     * @return matrix with the fingerprint timescales and amplitudes merged as columns
     */
    public IDoubleArray fingerprintPerturbation(IDoubleArray M, IDoubleArray p0, IDoubleArray observable) {
        IDynamicalExpectationsSpectral dexp = MarkovModel.create.createDynamicalFingerprint(M);
        dexp.calculatePerturbationExpectation(p0, observable);
        return Doubles.util.mergeColumns(dexp.getTimescales(), dexp.getAmplitudes());
    }

    /**
     * Samples a random trajectory from a {@link MarkovChain} over {@code T} that starts in
     * state {@code s}.
     *
     * @param length trajectory length, as interpreted by {@code MarkovChain.randomTrajectory}
     */
    public IIntArray trajectory(IDoubleArray T, int s, int length) {
        MarkovChain chain = new MarkovChain(T);
        chain.setStartingState(s);
        return chain.randomTrajectory(length);
    }

    /**
     * Samples a random trajectory from a {@link MarkovChain} over {@code T} that starts in
     * state {@code s}. Termination is delegated to
     * {@code MarkovChain.randomTrajectoryToState(endStates)}.
     */
    public IIntArray trajectoryToState(IDoubleArray T, int s, int[] endStates) {
        MarkovChain chain = new MarkovChain(T);
        chain.setStartingState(s);
        return chain.randomTrajectoryToState(endStates);
    }
}

