/*
 * Decompiled with CFR 0.152.
 */
package stallone.mc.sampling;

import stallone.api.algebra.Algebra;
import stallone.api.doubles.Doubles;
import stallone.api.doubles.IDoubleArray;
import stallone.api.doubles.IDoubleIterator;
import stallone.api.mc.ITransitionMatrixSampler;
import stallone.api.mc.MarkovModel;

/**
 * Skeleton for transition matrix samplers.
 *
 * <p>Holds the count matrix {@link #C}, the current transition matrix
 * {@link #T} and a cached log-likelihood of {@code T} given {@code C}.
 * Subclasses supply the actual sampling move through {@link #step()}.
 */
public abstract class TransitionMatrixSamplerAbstract implements ITransitionMatrixSampler {
    /** Current transition matrix; this same object is returned by {@link #sample(int)}. */
    protected IDoubleArray T;
    /** Count matrix passed to {@code init}; stored by reference, not copied. */
    protected IDoubleArray C;
    /** Log-likelihood cached by the last {@code init} or {@link #sample(int)} call. */
    protected double logLikelihood = 0.0;

    /**
     * Creates an uninitialized sampler; one of the {@code init} methods has to
     * be called before {@link #T} and {@link #C} are set.
     */
    public TransitionMatrixSamplerAbstract() {
    }

    /**
     * Creates a sampler and initializes it from the given counts.
     *
     * <p>NOTE(review): this ends up calling the overridable
     * {@link #init(IDoubleArray, IDoubleArray)} during construction, so a
     * subclass override runs before the subclass's own fields are initialized.
     *
     * @param counts count matrix
     */
    public TransitionMatrixSamplerAbstract(IDoubleArray counts) {
        this.init(counts);
    }

    /**
     * Creates a sampler and initializes it from the given counts and starting
     * transition matrix.
     *
     * <p>NOTE(review): calls the overridable
     * {@link #init(IDoubleArray, IDoubleArray)} during construction.
     *
     * @param counts count matrix
     * @param Tinit  starting transition matrix, or {@code null} to estimate one
     *               from {@code counts}
     */
    public TransitionMatrixSamplerAbstract(IDoubleArray counts, IDoubleArray Tinit) {
        this.init(counts, Tinit);
    }

    /**
     * Stores the count matrix, picks the starting transition matrix and caches
     * the corresponding log-likelihood.
     *
     * @param _C    count matrix; kept by reference
     * @param Tinit starting transition matrix, kept by reference; when
     *              {@code null}, the starting matrix is obtained from
     *              {@code MarkovModel.util.estimateT} applied to a copy of
     *              {@code _C} whose negative entries have been zeroed
     */
    @Override
    public void init(IDoubleArray _C, IDoubleArray Tinit) {
        this.C = _C;
        if (Tinit == null) {
            IDoubleArray cleanedCounts = TransitionMatrixSamplerAbstract.eraseNegatives(_C);
            this.T = MarkovModel.util.estimateT(cleanedCounts);
        } else {
            this.T = Tinit;
        }
        this.logLikelihood = MarkovModel.util.logLikelihood(this.T, this.C);
    }

    /**
     * Initializes the sampler from counts alone; equivalent to
     * {@code init(_C, null)}.
     *
     * @param _C count matrix
     */
    @Override
    public final void init(IDoubleArray _C) {
        this.init(_C, null);
    }

    /**
     * Returns a copy of {@code cin} in which every negative entry visited by
     * the copy's nonzero iterator is replaced with {@code 0.0}. The argument
     * itself is not written to by this method.
     *
     * @param cin source array
     * @return the modified copy
     */
    protected static IDoubleArray eraseNegatives(IDoubleArray cin) {
        IDoubleArray cleaned = cin.copy();
        for (IDoubleIterator entries = cleaned.nonzeroIterator(); entries.hasNext(); entries.advance()) {
            if (entries.get() < 0.0) {
                entries.set(0.0);
            }
        }
        return cleaned;
    }

    /**
     * Invokes {@link #step()} {@code steps} times, refreshes the cached
     * log-likelihood and returns the current transition matrix.
     *
     * <p>The boolean returned by {@code step()} is ignored here. For
     * {@code steps <= 0} no step is taken, but the log-likelihood is still
     * recomputed.
     *
     * @param steps number of sampling steps to perform
     * @return the internal matrix {@link #T} itself, not a copy
     */
    @Override
    public IDoubleArray sample(int steps) {
        for (int done = 0; done < steps; ++done) {
            this.step();
        }
        this.logLikelihood = MarkovModel.util.logLikelihood(this.T, this.C);
        return this.T;
    }

    /**
     * Clamps {@code T(i, j)} into {@code [0, 1]}: values below zero are set to
     * {@code 0.0} and values above one are set to {@code 1.0}. Any other value,
     * including NaN, is left untouched.
     *
     * @param i row index
     * @param j column index
     */
    protected void ensureValidElement(int i, int j) {
        double value = this.T.get(i, j);
        if (value < 0.0) {
            this.T.set(i, j, 0.0);
        } else if (value > 1.0) {
            this.T.set(i, j, 1.0);
        }
    }

    /**
     * Tells whether {@code T(i, j)} is neither below {@code 0.0} nor above
     * {@code 1.0}.
     *
     * <p>NOTE(review): both comparisons are false for NaN, so a NaN element is
     * reported as valid. The comparisons are deliberately written in negated
     * form to keep that behaviour; confirm whether it is intended.
     *
     * @param i row index
     * @param j column index
     * @return {@code false} if the element is {@code < 0.0} or {@code > 1.0},
     *         {@code true} otherwise
     */
    protected boolean isElementValid(int i, int j) {
        double value = this.T.get(i, j);
        return !(value < 0.0) && !(value > 1.0);
    }

    /**
     * Sums {@code |mu(i) * T(i, j) - mu(j) * T(j, i)|} over all
     * {@code i < T.rows()} and {@code j < T.columns()}.
     *
     * <p>Because {@code T(j, i)} and {@code mu(j)} are read for every column
     * index, this presumably expects a square {@code T} and a {@code mu} with
     * one entry per state (to be confirmed against callers).
     *
     * @param mu weight vector indexed by state
     * @return the accumulated absolute difference
     */
    protected double computeDetailedBalanceError(IDoubleArray mu) {
        double total = 0.0;
        for (int i = 0; i < this.T.rows(); ++i) {
            for (int j = 0; j < this.T.columns(); ++j) {
                double forward = mu.get(i) * this.T.get(i, j);
                double backward = mu.get(j) * this.T.get(j, i);
                total += Math.abs(forward - backward);
            }
        }
        return total;
    }

    /**
     * Scales the view of row {@code i} of {@link #T} by the reciprocal of the
     * row's sum.
     *
     * <p>NOTE(review): there is no guard for a zero row sum; in that case the
     * scale factor is infinite. Presumably the scaling writes through the row
     * view into {@code T}, since the result of the call is discarded; confirm
     * against {@code Algebra.util.scale}.
     *
     * @param i row index
     */
    protected void ensureValidRow(int i) {
        IDoubleArray row = this.T.viewRow(i);
        double rowSum = Doubles.util.sum(row);
        Algebra.util.scale(1.0 / rowSum, row);
    }

    /**
     * Performs a single sampling move; implemented by subclasses.
     *
     * @return a flag defined by the implementing sampler; {@link #sample(int)}
     *         does not inspect it
     */
    protected abstract boolean step();

    /**
     * Returns the cached log-likelihood, as last computed by {@code init} or
     * {@link #sample(int)} ({@code 0.0} before either has run).
     *
     * @return the cached log-likelihood value
     */
    @Override
    public double logLikelihood() {
        return this.logLikelihood;
    }
}

