/*
 * Decompiled with CFR 0.152.
 */
package stallone.mc.estimator;

import stallone.api.algebra.Algebra;
import stallone.api.doubles.Doubles;
import stallone.api.doubles.IDoubleArray;
import stallone.api.doubles.IDoubleIterator;
import stallone.api.doubles.IDoubleList;
import stallone.api.mc.ITransitionMatrixEstimator;

/**
 * Estimates a reversible transition matrix whose stationary distribution is held fixed at a
 * given vector {@code pi}, by iteratively increasing the count log-likelihood.
 *
 * <p>The estimator works on the matrix {@code X} with {@code X(i,j) = pi(i) * T(i,j)}. {@code X}
 * is kept symmetric (detailed balance) and each row {@code i} sums to {@code pi(i)}, so the
 * transition matrix {@code T(i,j) = X(i,j) / pi(i)} is row-stochastic and reversible with respect
 * to {@code pi}. The log-likelihood {@code sum_ij C(i,j) * log(X(i,j) / pi(i))} is increased by
 * pairwise updates that move an amount {@code d} between the off-diagonal pair
 * {@code X(i,j) = X(j,i)} and the diagonals {@code X(i,i)}, {@code X(j,j)}.
 */
public final class TransitionMatrixEstimatorRevFixPi
implements ITransitionMatrixEstimator {
    /**
     * Minimum log-likelihood gain over {@code nIterPer1} iterations below which the iteration is
     * considered converged.
     */
    private static final double LIKELIHOOD_TOLERANCE = 1.0;

    /** Maximum number of recorded log-likelihood values (initial value plus iterations). */
    private int nIterMax = 1000000;
    /** Window length, in iterations, over which the likelihood gain is measured. */
    private int nIterPer1 = 1000;
    /** Log-likelihood history; after setCounts, element 0 is the value at the initial guess. */
    private IDoubleList logliks = Doubles.create.list(this.nIterMax);
    /** Count matrix. */
    private IDoubleArray C;
    /** Stationary distribution that is kept fixed during estimation. */
    private IDoubleArray pi = null;
    /** Symmetric working matrix with X(i,j) = pi(i) * T(i,j); null until counts are set. */
    private IDoubleArray X;
    /** Iterator over the nonzero entries of X, created once X has been initialised. */
    private IDoubleIterator itX;
    /** When true, each iteration's log-likelihood is printed to standard output. */
    private boolean verbose = false;

    /**
     * Creates an estimator for the given counts and fixed stationary distribution, and
     * initialises the working matrix immediately.
     *
     * @param _C count matrix
     * @param _pi stationary distribution to keep fixed
     */
    public TransitionMatrixEstimatorRevFixPi(IDoubleArray _C, IDoubleArray _pi) {
        this.pi = _pi;
        this.setCounts(_C);
    }

    /**
     * Creates an estimator for a fixed stationary distribution. {@link #setCounts} must be called
     * before {@link #estimate()} or {@link #getTransitionMatrix()}.
     *
     * @param _pi stationary distribution to keep fixed
     */
    public TransitionMatrixEstimatorRevFixPi(IDoubleArray _pi) {
        this.pi = _pi;
    }

    /**
     * Builds a feasible symmetric starting point X from the counts and creates {@link #itX}.
     */
    private void initX() {
        // Row-normalise a copy of the counts to obtain a naive (non-reversible) transition matrix.
        // Algebra.util.scale multiplies in place, so the factor must be 1/rowSum. Empty rows are
        // left at zero rather than producing 0 * (1/0) = NaN.
        IDoubleArray T = this.C.copy();
        for (int i = 0; i < T.rows(); i++) {
            IDoubleArray row = T.viewRow(i);
            double rowSum = Doubles.util.sum(row);
            if (rowSum > 0.0) {
                Algebra.util.scale(1.0 / rowSum, row);
            }
        }

        // Symmetrise: X(i,j) = X(j,i) = (pi_i T_ij + pi_j T_ji) / 2. Both entries are written so
        // that X is symmetric even when only one of C(i,j), C(j,i) is nonzero.
        this.X = this.C.create(this.C.rows(), this.C.columns());
        IDoubleIterator it = T.nonzeroIterator();
        while (it.hasNext()) {
            int i = it.row();
            int j = it.column();
            if (i != j) {
                double xij = 0.5 * (this.pi.get(i) * T.get(i, j) + this.pi.get(j) * T.get(j, i));
                if (Double.isNaN(xij)) {
                    System.out.println("NaN: " + i + " " + j);
                }
                this.X.set(i, j, xij);
                this.X.set(j, i, xij);
            }
            it.advance();
        }

        // Shrink the off-diagonal mass so that no row uses more than 90% of pi(i). This leaves a
        // strictly positive diagonal below.
        double o = 0.0;
        for (int i = 0; i < this.X.rows(); i++) {
            o = Math.max(o, Doubles.util.sum(this.X.viewRow(i)) / this.pi.get(i));
        }
        if (o > 0.9) {
            Algebra.util.scale(0.9 / o, this.X);
        }

        // The diagonal absorbs the remainder so that every row of X sums to pi(i).
        for (int i = 0; i < this.X.rows(); i++) {
            this.X.set(i, i, this.pi.get(i) - Doubles.util.sum(this.X.viewRow(i)));
        }
        this.itX = this.X.nonzeroIterator();
    }

    /**
     * Computes the count log-likelihood sum_ij C(i,j) * log(X(i,j) / pi(i)) over the nonzero
     * entries of X. Entries with X(i,j) <= 0 are skipped.
     */
    private double logL() {
        double ll = 0.0;
        this.itX.reset();
        while (this.itX.hasNext()) {
            int i = this.itX.row();
            int j = this.itX.column();
            double xij = this.X.get(i, j);
            if (xij > 0.0) {
                ll += this.C.get(i, j) * Math.log(xij / this.pi.get(i));
            }
            this.itX.advance();
        }
        return ll;
    }

    /**
     * Reports whether iteration should stop.
     *
     * <p>Iteration stops when the history has reached {@code nIterMax} entries, or when the
     * likelihood gained over the last {@code nIterPer1} iterations is at most
     * {@link #LIKELIHOOD_TOLERANCE}.
     */
    private boolean isConverged() {
        int n = this.logliks.size();
        if (n >= this.nIterMax) {
            return true;
        }
        if (n <= this.nIterPer1) {
            return false;
        }
        int last = n - 1;
        double dL = this.logliks.get(last) - this.logliks.get(last - this.nIterPer1);
        return dL <= LIKELIHOOD_TOLERANCE;
    }

    /**
     * Returns {@code c * log(x)} with the convention that a zero count contributes nothing,
     * so that 0 * log(0) evaluates to 0 instead of NaN.
     */
    private static double xlogx(double c, double x) {
        return c == 0.0 ? 0.0 : c * Math.log(x);
    }

    /**
     * Returns the part of the log-likelihood that depends on moving {@code d} from the diagonals
     * X(i,i), X(j,j) into the pair X(i,j) = X(j,i). Terms that do not depend on d are omitted.
     */
    private double dLL(int i, int j, double d) {
        return xlogx(this.C.get(i, i), this.X.get(i, i) - d)
                + xlogx(this.C.get(j, j), this.X.get(j, j) - d)
                + xlogx(this.C.get(i, j) + this.C.get(j, i), this.X.get(i, j) + d);
    }

    /**
     * Finds the shift d in [dmin, dmax] that maximises {@link #dLL} for the pair (i, j).
     *
     * <p>The candidates are 0, both interval ends, and the in-range stationary points of dLL.
     * Those stationary points are the roots (E -/+ sqrt(A - B)) / D of the quadratic obtained by
     * setting d(dLL)/dd to zero. If A - B is negative, the roots are NaN and fail the range
     * checks, so they are ignored.
     */
    private double opt(int i, int j, double dmin, double dmax) {
        double x_ii = this.X.get(i, i);
        double x_jj = this.X.get(j, j);
        double x_ij = this.X.get(i, j);
        double c_ii = this.C.get(i, i);
        double c_jj = this.C.get(j, j);
        double c_ij = this.C.get(i, j);
        double c_ji = this.C.get(j, i);
        double E = c_ij * x_ii + c_ji * x_ii + c_jj * x_ii - c_ii * x_ij - c_jj * x_ij + c_ii * x_jj + c_ij * x_jj + c_ji * x_jj;
        double A = Math.pow(-c_ij * x_ii - c_ji * x_ii - c_jj * x_ii + c_ii * x_ij + c_jj * x_ij - c_ii * x_jj - c_ij * x_jj - c_ji * x_jj, 2.0);
        double B = 4.0 * (c_ii + c_ij + c_ji + c_jj) * (-c_jj * x_ii * x_ij + c_ij * x_ii * x_jj + c_ji * x_ii * x_jj - c_ii * x_ij * x_jj);
        double D = 2.0 * (c_ii + c_ij + c_ji + c_jj);
        double d1 = (E - Math.sqrt(A - B)) / D;
        double d2 = (E + Math.sqrt(A - B)) / D;

        // Start from "no change"; only strictly better candidates replace it.
        double lbest = this.dLL(i, j, 0.0);
        double dbest = 0.0;
        double l = this.dLL(i, j, dmin);
        if (l > lbest) {
            lbest = l;
            dbest = dmin;
        }
        l = this.dLL(i, j, dmax);
        if (l > lbest) {
            lbest = l;
            dbest = dmax;
        }
        if (d1 >= dmin && d1 <= dmax) {
            l = this.dLL(i, j, d1);
            if (l > lbest) {
                lbest = l;
                dbest = d1;
            }
        }
        if (d2 >= dmin && d2 <= dmax) {
            l = this.dLL(i, j, d2);
            if (l > lbest) {
                dbest = d2;
            }
        }
        return dbest;
    }

    /**
     * Applies the best pairwise shift for (i, j).
     *
     * <p>The shift adds d to X(i,j) and X(j,i) and subtracts it from X(i,i) and X(j,j). This keeps
     * X symmetric and its row sums unchanged. The bounds dmin = -X(i,j) and
     * dmax = min(X(i,i), X(j,j)) keep all four entries non-negative.
     */
    private void optimizeElement(int i, int j) {
        double dmin = -this.X.get(i, j);
        double dmax = Math.min(this.X.get(i, i), this.X.get(j, j));
        double d = this.opt(i, j, dmin, dmax);
        this.X.set(i, i, this.X.get(i, i) - d);
        this.X.set(i, j, this.X.get(i, j) + d);
        this.X.set(j, i, this.X.get(j, i) + d);
        this.X.set(j, j, this.X.get(j, j) - d);
    }

    /**
     * Runs one sweep of pairwise updates over the upper-triangle nonzeros of X, then records the
     * resulting log-likelihood.
     */
    private void step() {
        this.itX.reset();
        while (this.itX.hasNext()) {
            int i = this.itX.row();
            int j = this.itX.column();
            if (i < j) {
                this.optimizeElement(i, j);
            }
            this.itX.advance();
        }
        double ll = this.logL();
        if (this.verbose) {
            System.out.println(String.valueOf(this.logliks.size() + 1) + "\t" + ll);
        }
        this.logliks.append(ll);
    }

    /** Throws if no counts have been provided yet (X is built by setCounts). */
    private void requireCounts() {
        if (this.X == null) {
            throw new IllegalStateException("No counts set: call setCounts(...) first");
        }
    }

    @Override
    public void setMaxIter(int nmax) {
        this.nIterMax = nmax;
    }

    @Override
    public void setConvergence(int niter) {
        this.nIterPer1 = niter;
    }

    /**
     * Sets the count matrix, rebuilds the starting point and resets the likelihood history to
     * the likelihood of that starting point.
     */
    @Override
    public void setCounts(IDoubleArray _C) {
        this.C = _C;
        this.initX();
        double ll = this.logL();
        this.logliks = Doubles.create.list(this.nIterMax);
        this.logliks.append(ll);
    }

    /**
     * Iterates until {@link #isConverged()} reports convergence.
     *
     * @throws IllegalStateException if no counts have been set
     */
    @Override
    public void estimate() {
        this.requireCounts();
        while (!this.isConverged()) {
            this.step();
        }
    }

    /**
     * Returns a new matrix T with T(i,j) = X(i,j) / pi(i).
     *
     * @throws IllegalStateException if no counts have been set
     */
    @Override
    public IDoubleArray getTransitionMatrix() {
        this.requireCounts();
        IDoubleArray T = this.X.create(this.X.rows(), this.X.columns());
        IDoubleIterator it = this.X.nonzeroIterator();
        while (it.hasNext()) {
            int i = it.row();
            int j = it.column();
            // X is symmetric, so the (j,i) entry is written when the iterator reaches it.
            T.set(i, j, this.X.get(i, j) / this.pi.get(i));
            it.advance();
        }
        return T;
    }

    @Override
    public double[] getLikelihoodHistory() {
        return this.logliks.getArray();
    }

    /** Returns the number of recorded log-likelihood values, including the initial one. */
    public int getIterations() {
        return this.logliks.size();
    }

    /** Enables or disables printing the log-likelihood after every iteration. */
    public void setVerbose(boolean _verbose) {
        this.verbose = _verbose;
    }
}

