/*
 * Decompiled with CFR 0.152.
 */
package stallone.mc.sampling;

import cern.jet.random.Beta;
import cern.jet.random.engine.MersenneTwister;
import stallone.api.API;
import stallone.api.doubles.IDoubleArray;
import stallone.api.doubles.IDoubleIterator;
import stallone.api.mc.IReversibleSamplingStep;
import stallone.mc.sampling.TransitionMatrixSamplingTools;
import stallone.util.MathTools;

/**
 * Reversible transition-matrix sampling step that rescales all off-diagonal
 * entries of a single row by a common factor.
 *
 * <p>For a selected row {@code i}, a new off-diagonal mass {@code x} is drawn from
 * {@code Beta(alpha_i, beta_i)} with {@code beta_i = C(i,i) + 1}. Each off-diagonal
 * entry is multiplied by {@code a = x / (1 - T(i,i))}, and the diagonal is set so
 * that the row sums to one. The weight vector is updated via {@code u(i) += log(a)}
 * and {@code mu(i) = exp(-u(i))}, i.e. {@code mu(i)} is divided by {@code a}. This
 * keeps {@code mu(i) * T(i,j)} unchanged for every {@code j != i}.
 */
public class Step_Rev_Row_Beta
implements IReversibleSamplingStep {
    /** Number of states (rows of the count matrix). */
    private int n;
    /** Count matrix. */
    private IDoubleArray C;
    /** Per-row number of counted entries (nonzero entries with value {@code >= -1}). */
    private int[] dof;
    /** Per-row sum of the counted entries. */
    private double[] Csum;
    /** Weight vector, kept equal to {@code exp(-u)} element-wise. */
    private IDoubleArray mu;
    /** Log-space weights, {@code u = -log(mu)}, shifted to stay near zero. */
    private IDoubleArray u;
    /** Transition matrix being sampled (modified in place). */
    private IDoubleArray T;
    /** Per-row Beta distribution for the new off-diagonal mass. */
    private Beta[] rowDistribution;
    /** Scratch copy of the row being modified, used to undo a rejected update. */
    private double[] backupRow;
    int nprop = 0;
    int nacc = 0;

    /**
     * Binds the step to the given matrices and precomputes the per-row Beta distributions.
     *
     * @param _C count matrix
     * @param _T transition matrix, updated in place by {@link #step()}
     * @param _mu weight vector, updated in place by {@link #step()}; entries are log-transformed
     * @throws IllegalArgumentException if no row has a positive diagonal count and at least
     *     two degrees of freedom (such a matrix could never be sampled by {@link #step()})
     */
    @Override
    public void init(IDoubleArray _C, IDoubleArray _T, IDoubleArray _mu) {
        this.n = _C.rows();
        this.C = _C;
        this.T = _T;
        this.mu = _mu;
        this.backupRow = new double[this.n];

        this.u = API.doublesNew.array(this.n);
        for (int i = 0; i < this.n; i++) {
            this.u.set(i, -Math.log(this.mu.get(i)));
        }

        // Count entries with value >= -1 per row and accumulate their sum.
        this.dof = new int[this.n];
        this.Csum = new double[this.n];
        for (IDoubleIterator it = this.C.nonzeroIterator(); it.hasNext(); it.advance()) {
            int row = it.row();
            double count = this.C.get(row, it.column());
            if (count >= -1.0) {
                this.dof[row]++;
                this.Csum[row] += count;
            }
        }

        // One engine shared by all rows. MersenneTwister() always starts from the same
        // fixed seed, so a separate default-constructed engine per row would make every
        // row consume an identical uniform stream and correlate the row draws.
        MersenneTwister engine = new MersenneTwister();
        this.rowDistribution = new Beta[this.n];
        for (int i = 0; i < this.n; i++) {
            double diagonal = this.C.get(i, i);
            double alpha = this.Csum[i] + (double) this.dof[i] - diagonal - 1.0;
            double beta = diagonal + 1.0;
            this.rowDistribution[i] = new Beta(alpha, beta, engine);
        }

        if (!this.checkCounts()) {
            throw new IllegalArgumentException("This Matrix cannot be sampled reversibly as it has no row with positive diagonal counts and at least 2 degrees of freedom");
        }
    }

    /**
     * Returns whether at least one row can be sampled by {@link #step()}.
     *
     * <p>Uses the same criterion as the row selection in {@link #step()}. Otherwise
     * validation could pass while {@code step()} loops forever looking for a row.
     */
    protected final boolean checkCounts() {
        for (int i = 0; i < this.C.rows(); i++) {
            if (this.isSampleable(i)) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether row {@code i} has a positive diagonal count and at least two degrees of freedom. */
    private boolean isSampleable(int i) {
        return this.C.get(i, i) > 0.0 && this.dof[i] >= 2;
    }

    /** Copies row {@code row} of T into the backup buffer. */
    private void backupRow(int row) {
        for (int k = 0; k < this.n; k++) {
            this.backupRow[k] = this.T.get(row, k);
        }
    }

    /** Writes the backup buffer back into row {@code row} of T. */
    private void restoreRow(int row) {
        for (int k = 0; k < this.n; k++) {
            this.T.set(row, k, this.backupRow[k]);
        }
    }

    /**
     * Resamples the off-diagonal mass of row {@code i} and updates {@code u}/{@code mu}.
     *
     * <p>The row is left unchanged in two cases:
     * <ul>
     *   <li>the scaling factor is not a positive finite number (e.g. the row has no
     *       off-diagonal mass, or the Beta draw is exactly zero);</li>
     *   <li>the updated row fails {@code TransitionMatrixSamplingTools.isRowIn01}.</li>
     * </ul>
     *
     * @param i row index
     */
    public void sampleRow(int i) {
        double x = this.rowDistribution[i].nextDouble();
        double a = x / (1.0 - this.T.get(i, i));
        // A zero, negative, infinite or NaN factor would write NaN/zero into the row and
        // push u(i) to +/-Infinity, corrupting mu. Skip the update in that case.
        if (!(a > 0.0) || Double.isInfinite(a)) {
            return;
        }

        this.backupRow(i);
        double offDiagonalSum = 0.0;
        for (int k = 0; k < this.T.columns(); k++) {
            if (k != i) {
                double scaled = this.T.get(i, k) * a;
                this.T.set(i, k, scaled);
                offDiagonalSum += scaled;
            }
        }
        this.T.set(i, i, 1.0 - offDiagonalSum);

        if (!TransitionMatrixSamplingTools.isRowIn01(this.T, i)) {
            this.restoreRow(i);
            return;
        }

        this.u.set(i, this.u.get(i) + Math.log(a));
        this.mu.set(i, Math.exp(-this.u.get(i)));
        // Re-center u when its minimum drifts away from 0 so exp(-u) stays well scaled;
        // this rescales all of mu by a common factor.
        double minU = API.doubles.min(this.u);
        if (Math.abs(minU) > 1.0) {
            API.alg.addTo(this.u, -minU);
            for (int k = 0; k < this.n; k++) {
                this.mu.set(k, Math.exp(-this.u.get(k)));
            }
        }
    }

    /**
     * Picks a uniformly random sampleable row and resamples it.
     *
     * <p>The step is always counted and reported as accepted, even if
     * {@link #sampleRow(int)} left the row unchanged.
     *
     * @return always {@code true}
     */
    @Override
    public boolean step() {
        int i;
        do {
            i = MathTools.randomInt(0, this.T.rows());
        } while (!this.isSampleable(i));
        this.sampleRow(i);
        ++this.nprop;
        ++this.nacc;
        return true;
    }

    /** Returns {@code {proposed, accepted}} step counts. */
    public int[] getStepCount() {
        return new int[]{this.nprop, this.nacc};
    }
}

