/*
 * Decompiled with CFR 0.152.
 */
package stallone.mc.tpt;

import stallone.api.algebra.Algebra;
import stallone.api.doubles.Doubles;
import stallone.api.doubles.IDoubleArray;
import stallone.api.ints.IIntArray;
import stallone.api.ints.Ints;
import stallone.api.mc.MarkovModel;
import stallone.api.mc.tpt.ICommittor;
import stallone.mc.StationaryDistribution;

/**
 * Computes forward and backward committor vectors for two state sets
 * {@code A} and {@code B} of a Markov model given either as a transition
 * matrix {@code T} or as a rate matrix {@code K}.
 *
 * <p>Internally both matrices are kept: when a transition matrix is supplied,
 * {@code K} is a copy of it with 1 subtracted from every diagonal element;
 * when a rate matrix is supplied, {@code T} is a copy of it with 1 added to
 * every diagonal element. Both committors are computed from {@code K}.
 *
 * <p>Results are cached and the cached arrays themselves are returned, so
 * callers should not modify them. Replacing the matrix discards both cached
 * committors; replacing the stationary distribution discards only the
 * backward committor.
 */
public class Committor implements ICommittor {
    private IDoubleArray T;
    private IDoubleArray K;
    /** Stationary distribution supplied by the caller, or {@code null} if none was supplied. */
    private IDoubleArray pi;
    private final IIntArray A;
    private final IIntArray B;
    private final IIntArray AB;
    /** Indexes of all states that are in neither {@code A} nor {@code B}. */
    private final IIntArray notAB;
    private IDoubleArray qforward = null;
    private IDoubleArray qbackward = null;

    /**
     * Creates a committor calculator from a transition matrix or a rate matrix.
     *
     * @param M  transition matrix or rate matrix; its row count defines the number of states
     * @param _A indexes of the states forming set A
     * @param _B indexes of the states forming set B
     * @throws IllegalArgumentException if {@code M} is recognised neither as a
     *         transition matrix nor as a rate matrix
     */
    public Committor(IDoubleArray M, IIntArray _A, IIntArray _B) {
        if (MarkovModel.util.isTransitionMatrix(M)) {
            this.setTransitionMatrix(M);
        } else if (MarkovModel.util.isRateMatrix(M)) {
            // The second test must be the rate-matrix check; repeating the
            // transition-matrix check made this branch unreachable.
            this.setRateMatrix(M);
        } else {
            throw new IllegalArgumentException("Trying to construct TPT with a matrix that is neither a transition nor a rate matrix");
        }
        this.A = _A;
        this.B = _B;
        this.AB = Ints.util.mergeToNew(this.A, this.B);
        this.notAB = Ints.util.removeValueToNew(Ints.create.arrayRange(M.rows()), this.AB);
    }

    /**
     * Creates a committor calculator without a matrix. A matrix must be
     * supplied through {@link #setTransitionMatrix(IDoubleArray)} or
     * {@link #setRateMatrix(IDoubleArray)} before a committor is requested.
     *
     * @param nstates total number of states
     * @param _A      indexes of the states forming set A
     * @param _B      indexes of the states forming set B
     */
    public Committor(int nstates, IIntArray _A, IIntArray _B) {
        this.A = _A;
        this.B = _B;
        this.AB = Ints.util.mergeToNew(this.A, this.B);
        this.notAB = Ints.util.removeValueToNew(Ints.create.arrayRange(nstates), this.AB);
    }

    /**
     * Sets the transition matrix and derives {@code K} as a copy of it with 1
     * subtracted from each diagonal element. Discards cached committors.
     *
     * @param _T the transition matrix; it is stored by reference, not copied
     */
    @Override
    public final void setTransitionMatrix(IDoubleArray _T) {
        this.qforward = null;
        this.qbackward = null;
        this.T = _T;
        this.K = _T.copy();
        for (int i = 0; i < this.T.rows(); i++) {
            this.K.set(i, i, this.K.get(i, i) - 1.0);
        }
    }

    /**
     * Sets the rate matrix and derives {@code T} as a copy of it with 1 added
     * to each diagonal element. Discards cached committors.
     *
     * @param _K the rate matrix; it is stored by reference, not copied
     */
    @Override
    public final void setRateMatrix(IDoubleArray _K) {
        this.qforward = null;
        this.qbackward = null;
        this.K = _K;
        this.T = _K.copy();
        for (int i = 0; i < this.T.rows(); i++) {
            this.T.set(i, i, this.T.get(i, i) + 1.0);
        }
    }

    /**
     * Sets the stationary distribution used for the backward committor and
     * discards the cached backward committor. The forward committor does not
     * use the stationary distribution and stays cached.
     *
     * @param _pi the stationary distribution; {@code null} makes the next
     *            backward committor computation derive it from {@code T}
     */
    @Override
    public void setStationaryDistribution(IDoubleArray _pi) {
        this.qbackward = null;
        this.pi = _pi;
    }

    /**
     * Returns the forward committor: 0 on every state of A, 1 on every state
     * of B, and on the remaining states the result of
     * {@code Algebra.util.solve(U, v)}, where {@code U} is {@code K}
     * restricted to the remaining states and {@code v[i]} is the negated sum
     * of {@code K[i, b]} over all {@code b} in B.
     *
     * @return the cached forward committor, one entry per state
     * @throws IllegalStateException if no matrix has been set
     */
    @Override
    public IDoubleArray forwardCommittor() {
        if (this.qforward != null) {
            return this.qforward;
        }
        this.requireMatrix();

        int[] intermediate = this.notAB.getArray();
        IDoubleArray U = this.K.view(intermediate, intermediate);

        // Right-hand side: v[i] = -sum_{b in B} K[notAB[i], b].
        IDoubleArray v = Doubles.create.array(this.notAB.size());
        for (int i = 0; i < v.size(); i++) {
            int state = this.notAB.get(i);
            for (int k = 0; k < this.B.size(); k++) {
                v.set(i, v.get(i) - this.K.get(state, this.B.get(k)));
            }
        }

        IDoubleArray qI = Algebra.util.solve(U, v);
        this.qforward = this.assemble(0.0, 1.0, qI);
        return this.qforward;
    }

    /**
     * Returns the backward committor: 1 on every state of A, 0 on every state
     * of B, and on the remaining states the result of
     * {@code Algebra.util.solve(U, v)}, where
     * {@code U[i][j] = pi[j] * K[j, i] / pi[i]} over the remaining states and
     * {@code v[i]} is the negated sum of {@code pi[a] * K[a, i] / pi[i]} over
     * all {@code a} in A.
     *
     * <p>If no stationary distribution was supplied, it is derived from the
     * current {@code T} for this computation only. It is deliberately not
     * stored, so that a later change of matrix cannot reuse a distribution
     * that belonged to the previous matrix.
     *
     * @return the cached backward committor, one entry per state
     * @throws IllegalStateException if no matrix has been set
     */
    @Override
    public IDoubleArray backwardCommittor() {
        if (this.qbackward != null) {
            return this.qbackward;
        }
        this.requireMatrix();

        IDoubleArray stat = this.pi != null ? this.pi : StationaryDistribution.calculate(this.T);

        int n = this.notAB.size();
        IDoubleArray U = Doubles.create.array(n, n);
        for (int i = 0; i < U.rows(); i++) {
            int si = this.notAB.get(i);
            for (int j = 0; j < U.columns(); j++) {
                int sj = this.notAB.get(j);
                U.set(i, j, stat.get(sj) * this.K.get(sj, si) / stat.get(si));
            }
        }

        // Right-hand side: v[i] = -sum_{a in A} pi[a] * K[a, notAB[i]] / pi[notAB[i]].
        IDoubleArray v = Doubles.create.array(n);
        for (int i = 0; i < v.size(); i++) {
            int si = this.notAB.get(i);
            for (int k = 0; k < this.A.size(); k++) {
                int a = this.A.get(k);
                v.set(i, v.get(i) - stat.get(a) * this.K.get(a, si) / stat.get(si));
            }
        }

        IDoubleArray qI = Algebra.util.solve(U, v);
        this.qbackward = this.assemble(1.0, 0.0, qI);
        return this.qbackward;
    }

    /** Fails with a descriptive message when neither matrix setter has been called yet. */
    private void requireMatrix() {
        if (this.K == null || this.T == null) {
            throw new IllegalStateException("No transition or rate matrix has been set; call setTransitionMatrix or setRateMatrix first");
        }
    }

    /**
     * Builds a full-length vector holding {@code valueOnA} on A,
     * {@code valueOnB} on B and the entries of {@code interior} on the
     * remaining states, written in that order.
     */
    private IDoubleArray assemble(double valueOnA, double valueOnB, IDoubleArray interior) {
        IDoubleArray q = Doubles.create.array(this.K.rows());
        for (int i = 0; i < this.A.size(); i++) {
            q.set(this.A.get(i), valueOnA);
        }
        for (int i = 0; i < this.B.size(); i++) {
            q.set(this.B.get(i), valueOnB);
        }
        for (int i = 0; i < this.notAB.size(); i++) {
            q.set(this.notAB.get(i), interior.get(i));
        }
        return q;
    }
}

