/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.stats.EMMA;

import java.util.ArrayList;
import java.util.Arrays;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrix;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrixFactory;
import net.maizegenetics.matrixalgebra.decomposition.EigenvalueDecomposition;
import net.maizegenetics.stats.linearmodels.LinearModelUtils;
import org.apache.log4j.Logger;

/**
 * Fits the mixed model {@code y = X*beta + Z*u + e} by restricted maximum likelihood (REML)
 * following the EMMA spectral approach.
 *
 * <p>The REML log-likelihood is expressed through the eigenvalues {@code lambda} of
 * {@code S*(Z*K*Z')*S} and the squared projections {@code eta2} of {@code y} onto the matching
 * eigenvectors. This makes it cheap to evaluate for any variance ratio
 * {@code delta = varResidual / varRandomEffect}.
 *
 * <p>Typical use:
 * <ol>
 * <li>construct;</li>
 * <li>call {@link #solve()};</li>
 * <li>optionally call {@link #calculateBlupsPredictedResiduals()};</li>
 * <li>read results through the getters.</li>
 * </ol>
 * Intermediate results are kept in mutable fields without synchronization.
 */
public class EMMAforDoubleMatrix {
    private static final Logger myLogger = Logger.getLogger(EMMAforDoubleMatrix.class);

    /** Maximum number of nested interval re-scans {@link #findMaximum} may start before giving up. */
    private static final int MAX_SUBINTERVAL_DEPTH = 3;

    // ---- model inputs ----
    protected DoubleMatrix y;
    protected double[] lambda;   // eigenvalues of S*A*S used by the likelihood (N - q of them)
    protected double[] eta2;     // squared projections of y onto the eigenvectors in U
    protected double c;          // constant term of the REML log-likelihood
    protected int N;
    protected int q;
    protected int Nran;
    protected int dfMarker = 0;
    protected DoubleMatrix X;
    protected DoubleMatrix Z = null;
    protected DoubleMatrix K;
    protected EigenvalueDecomposition eig;   // decomposition of S*A*S
    protected EigenvalueDecomposition eigA;  // decomposition of A = Z*K*Z' (before diagonal bending)
    protected DoubleMatrix U;

    // ---- results ----
    protected DoubleMatrix invH;
    protected DoubleMatrix invXHX;
    protected DoubleMatrix beta;
    protected DoubleMatrix Xbeta;
    protected double ssModel;
    protected double ssError;
    protected double SST;
    protected double Rsq;
    protected int dfModel;
    protected int dfError;
    protected double delta;
    protected double varResidual;
    protected double varRandomEffect;
    protected DoubleMatrix blup;
    protected DoubleMatrix pred;
    protected DoubleMatrix res;
    protected double lnLikelihood;

    // ---- delta search settings ----
    protected boolean findDelta = true;
    protected double lowerlimit = 1.0E-5;
    protected double upperlimit = 100000.0;
    protected int nregions = 100;
    protected double convergence = 1.0E-10;
    protected int maxiter = 50;
    /**
     * Retained for compatibility with subclasses. It is no longer used: the recursion depth of
     * the delta search is now passed explicitly, so this field always stays 0.
     */
    protected int subintervalCount = 0;

    /**
     * Creates a model whose random-effect incidence matrix is the identity. Delta is estimated
     * by REML in {@link #solve()}.
     *
     * @param y        response; a 1 x n row vector is transposed to n x 1
     * @param fixed    fixed-effect design matrix; must have full column rank
     * @param kin      kinship matrix of the random effect
     * @param nAlleles number of marker alleles; the last {@code nAlleles - 1} columns of
     *                 {@code fixed} are the ones tested by {@link #getMarkerFp()}
     * @throws IllegalArgumentException if {@code fixed} is not of full column rank
     */
    public EMMAforDoubleMatrix(DoubleMatrix y, DoubleMatrix fixed, DoubleMatrix kin, int nAlleles) {
        this(y, fixed, kin, nAlleles, Double.NaN);
    }

    /**
     * Creates a model whose random-effect incidence matrix is the identity.
     *
     * @param data     response; a 1 x n row vector is transposed to n x 1
     * @param fixed    fixed-effect design matrix; must have full column rank
     * @param kin      kinship matrix of the random effect
     * @param nAlleles number of marker alleles (marker df = {@code nAlleles - 1})
     * @param delta    fixed value of {@code varResidual / varRandomEffect}, or {@code NaN} to
     *                 estimate it by REML
     * @throws IllegalArgumentException if {@code fixed} is not of full column rank
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kin, int nAlleles, double delta) {
        this.setCommonInputs(data, fixed, kin, nAlleles, delta);
        this.Nran = this.K.numberOfRows();
        this.Z = DoubleMatrixFactory.DEFAULT.identity(this.Nran);
        this.init();
    }

    /**
     * Creates a model with an explicit random-effect incidence matrix.
     *
     * @param data     response; a 1 x n row vector is transposed to n x 1
     * @param fixed    fixed-effect design matrix; must have full column rank
     * @param kin      kinship matrix of the random effect
     * @param inZ      incidence matrix mapping random-effect levels to observations
     * @param nAlleles number of marker alleles (marker df = {@code nAlleles - 1})
     * @param delta    fixed value of {@code varResidual / varRandomEffect}, or {@code NaN} to
     *                 estimate it by REML
     * @throws IllegalArgumentException if {@code fixed} is not of full column rank
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kin, DoubleMatrix inZ, int nAlleles, double delta) {
        this.setCommonInputs(data, fixed, kin, nAlleles, delta);
        this.Z = inZ;
        this.Nran = this.Z.numberOfRows();
        this.init();
    }

    /** Validates the fixed-effect design and stores the inputs shared by both full constructors. */
    private void setCommonInputs(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kin, int nAlleles, double delta) {
        this.dfModel = fixed.numberOfColumns();
        if (fixed.columnRank() < this.dfModel) {
            throw new IllegalArgumentException("The fixed effect design matrix has less than full column rank. The analysis will not be run.");
        }
        if (!Double.isNaN(delta)) {
            this.delta = delta;
            this.findDelta = false;
        }
        this.y = asColumnVector(data);
        this.N = this.y.numberOfRows();
        this.X = fixed;
        this.q = this.X.numberOfColumns();
        this.K = kin;
        this.dfMarker = nAlleles - 1;
    }

    /** Returns {@code m} transposed when it is a 1 x n row vector with n > 1, otherwise {@code m}. */
    private static DoubleMatrix asColumnVector(DoubleMatrix m) {
        if (m.numberOfColumns() > 1 && m.numberOfRows() == 1) {
            return m.transpose();
        }
        return m;
    }

    /**
     * Computes the spectral quantities the likelihood needs: the constant {@code c}, the
     * decomposition {@code eigA} of {@code A = Z*K*Z'}, the eigenvectors {@code U}, and the
     * {@code N - q} eigenvalues {@code lambda} of {@code S*A*S}.
     *
     * <p>NOTE(review): this overridable method is called from the constructors, so a subclass
     * override runs before the subclass's own fields are initialised.
     */
    protected void init() {
        int nreml = this.N - this.q;
        // REML constant (n-q)*log((n-q)/(2*pi)) - (n-q). Floating-point division is required:
        // integer division truncated (n-q)/2 for odd n-q, which gives log(0) when n-q == 1.
        this.c = nreml * Math.log(nreml / (2.0 * Math.PI)) - nreml;
        this.lambda = new double[nreml];

        DoubleMatrix A = this.Z.mult(this.K).tcrossproduct(this.Z);
        // Decomposed before bending, so inverseH() works with the unshifted eigenvalues of A.
        this.eigA = A.getEigenvalueDecomposition();
        double[] eigenvalA = this.eigA.getEigenvalues();
        double minEigenvalue = eigenvalA[0];
        for (int i = 1; i < eigenvalA.length; ++i) {
            minEigenvalue = Math.min(minEigenvalue, eigenvalA[i]);
        }
        // If A is near-singular, shift its diagonal so every eigenvalue is at least 0.5.
        // The shift is subtracted back from the projected eigenvalues below.
        double bend = minEigenvalue < 0.01 ? 0.5 - minEigenvalue : 0.0;

        // Only element [2] of getXtXGM() is needed here.
        // NOTE(review): presumably the projection I - X(X'X)^-1 X' -- confirm.
        DoubleMatrix S = this.X.getXtXGM()[2];
        int dim = A.numberOfRows();
        for (int i = 0; i < dim; ++i) {
            A.set(i, i, A.get(i, i) + bend);
        }
        DoubleMatrix SAS = S.mult(A.mult(S));
        this.eig = SAS.getEigenvalueDecomposition();
        double[] eigenval = this.eig.getEigenvalues();

        // Order eigenpairs by decreasing |eigenvalue|. Only the first N - q enter the likelihood.
        int[] ndx = this.getSortedIndexofAbsoluteValues(eigenval);
        DoubleMatrix V = this.eig.getEigenvectors();
        // All columns are kept (in sorted order); lnlk() only reads the first N - q projections.
        this.U = V.getSelection(null, ndx);
        for (int i = 0; i < nreml; ++i) {
            this.lambda[i] = eigenval[ndx[i]] - bend;
        }
    }

    /**
     * Returns the indices of {@code values} ordered by decreasing absolute value. The sort is
     * stable, so ties keep their original order.
     */
    private int[] getSortedIndexofAbsoluteValues(double[] values) {
        int n = values.length;
        AbsoluteValueEntry[] entries = new AbsoluteValueEntry[n];
        for (int i = 0; i < n; ++i) {
            entries[i] = new AbsoluteValueEntry(i, values[i]);
        }
        Arrays.sort(entries);
        int[] index = new int[n];
        for (int i = 0; i < n; ++i) {
            index[i] = entries[i].order;
        }
        return index;
    }

    /**
     * Fits the model. Steps:
     * <ol>
     * <li>estimates delta when it was not supplied;</li>
     * <li>computes the REML log-likelihood, {@code inv(H)}, the fixed-effect estimates and
     * their covariance factor {@code inv(X' inv(H) X)};</li>
     * <li>computes the variance components, with {@code varResidual = varRandomEffect * delta}.</li>
     * </ol>
     * Afterwards {@code dfModel} is {@code q - 1} and {@code dfError} is {@code N - q}.
     */
    public void solve() {
        DoubleMatrix eta = this.U.crossproduct(this.y);
        int nrows = eta.numberOfRows();
        this.eta2 = new double[nrows];
        for (int i = 0; i < nrows; ++i) {
            double e = eta.get(i, 0);
            this.eta2[i] = e * e;
        }
        if (this.findDelta) {
            double[] interval = new double[]{this.lowerlimit, this.upperlimit};
            this.delta = this.findDeltaInInterval(interval, 0);
        }
        this.lnLikelihood = this.lnlk(this.delta);
        this.invH = this.inverseH(this.delta);
        this.beta = this.calculateBeta();
        double genvar = this.getGenvar(this.beta);
        this.dfModel = this.q - 1;
        this.dfError = this.N - this.q;
        this.varResidual = genvar * this.delta;
        this.varRandomEffect = genvar;
    }

    /** Computes BLUPs, predicted values and residuals. Requires a prior call to {@link #solve()}. */
    public void calculateBlupsPredictedResiduals() {
        this.blup = this.calculateBLUP();
        this.pred = this.calculatePred();
        this.res = this.calculateRes();
    }

    /**
     * Returns the delta in {@code interval} with the highest REML log-likelihood. Procedure:
     * <ol>
     * <li>scan a log-spaced grid of {@code nregions} points;</li>
     * <li>take the best grid point as the starting candidate;</li>
     * <li>refine every bracket where the derivative changes from positive to non-positive with
     * Newton's method, keeping a refined value only if its likelihood beats the candidate.</li>
     * </ol>
     *
     * @param interval {lower, upper} bounds of the search (both positive)
     * @param depth    number of enclosing re-scans triggered by {@link #findMaximum}
     */
    private double findDeltaInInterval(double[] interval, int depth) {
        double[][] grid = this.scanlnlk(interval[0], interval[1]);
        double[][] brackets = this.findSignChanges(grid);

        double[] best = new double[]{Double.NaN, Double.NaN, Double.NaN};
        for (double[] point : grid) {
            if (Double.isNaN(best[1]) || (!Double.isNaN(point[1]) && point[1] > best[1])) {
                best = point;
            }
        }
        double bestDelta = best[0];
        double bestLk = best[1];

        for (double[] bracket : brackets) {
            double candidate = this.findMaximum(bracket, depth);
            if (Double.isNaN(candidate)) {
                continue;
            }
            double candidateLk = this.lnlk(candidate);
            if (!Double.isNaN(candidateLk) && candidateLk > bestLk) {
                bestDelta = candidate;
                bestLk = candidateLk;
            }
        }
        return bestDelta;
    }

    /**
     * Evaluates the REML log-likelihood at {@code delta}:
     * {@code (c - (N-q)*log(sum(eta2/(lambda+delta))) - sum(log(lambda+delta))) / 2}.
     *
     * @return the log-likelihood, or {@code NaN} if any {@code lambda[i] + delta} is negative
     */
    private double lnlk(double delta) {
        double sumWeighted = 0.0;
        double sumLog = 0.0;
        int n = this.N - this.q;
        for (int i = 0; i < n; ++i) {
            double val = this.lambda[i] + delta;
            if (val < 0.0) {
                return Double.NaN;
            }
            sumWeighted += this.eta2[i] / val;
            sumLog += Math.log(val);
        }
        return (this.c - (double) n * Math.log(sumWeighted) - sumLog) / 2.0;
    }

    /** Returns the first derivative of {@link #lnlk(double)} with respect to delta. */
    private double d1lnlk(double delta) {
        double term1 = 0.0;
        double term2 = 0.0;
        double term3 = 0.0;
        int n = this.N - this.q;
        for (int i = 0; i < n; ++i) {
            double val = 1.0 / (this.lambda[i] + delta);
            double val2 = this.eta2[i] * val;
            term1 += val2;
            term2 += val2 * val;
            term3 += val;
        }
        return (double) n * term2 / term1 / 2.0 - term3 / 2.0;
    }

    /**
     * Evaluates the likelihood and its derivative at {@code nregions} points evenly spaced in
     * log10 between {@code lower} and {@code upper}, inclusive.
     *
     * @return one row per point: {delta, lnlk(delta), d1lnlk(delta)}
     */
    private double[][] scanlnlk(double lower, double upper) {
        double[][] result = new double[this.nregions][3];
        double logUpper = Math.log10(upper);
        double logLower = Math.log10(lower);
        double incr = (logUpper - logLower) / (double) (this.nregions - 1);
        for (int i = 0; i < this.nregions; ++i) {
            double d = Math.pow(10.0, logLower + (double) i * incr);
            result[i][0] = d;
            result[i][1] = this.lnlk(d);
            result[i][2] = this.d1lnlk(d);
        }
        return result;
    }

    /**
     * Finds adjacent grid points where the derivative goes from positive to non-positive and
     * the left point has a defined likelihood. Each such pair brackets a local maximum.
     *
     * @return one row {leftDelta, rightDelta} per bracket
     */
    private double[][] findSignChanges(double[][] scan) {
        ArrayList<double[]> changes = new ArrayList<double[]>();
        for (int i = 0; i + 1 < scan.length; ++i) {
            boolean rising = scan[i][2] > 0.0;
            boolean notRisingNext = scan[i + 1][2] <= 0.0;
            if (rising && notRisingNext && !Double.isNaN(scan[i][1])) {
                changes.add(new double[]{scan[i][0], scan[i + 1][0]});
            }
        }
        return changes.toArray(new double[changes.size()][]);
    }

    /**
     * Locates a root of the likelihood derivative inside {@code interval} using Newton's method,
     * starting from the lower bound.
     *
     * <p>If an iterate leaves the interval, the interval is re-scanned on a finer grid. At most
     * {@link #MAX_SUBINTERVAL_DEPTH} nested re-scans are allowed; beyond that, {@code NaN} is
     * returned. If {@code maxiter} iterations pass without convergence, the last iterate is
     * returned; callers only accept it if it improves the likelihood.
     *
     * @param interval {lower, upper} bracket containing a sign change of the derivative
     * @param depth    number of enclosing re-scans already started
     */
    private double findMaximum(double[] interval, int depth) {
        double delta = interval[0];
        int n = this.N - this.q;
        for (int iteration = 0; iteration < this.maxiter; ++iteration) {
            double A = 0.0;
            double B = 0.0;
            double C = 0.0;
            double D = 0.0;
            double E = 0.0;
            for (int i = 0; i < n; ++i) {
                double val = this.lambda[i] + delta;
                double val2 = val * val;
                double val3 = val2 * val;
                A += this.eta2[i] / val;
                B += this.eta2[i] / val2;
                C += this.eta2[i] / val3;
                D += 1.0 / val;
                E += 1.0 / val2;
            }
            // d1 is twice the derivative of lnlk; d2 is the derivative of d1 with respect to delta.
            double d1 = (double) n * B / A - D;
            boolean converged = Math.abs(d1) < this.convergence;
            if (!converged) {
                double d2 = E + (double) n * (B * B - 2.0 * A * C) / A / A;
                delta -= d1 / d2;
            }
            if (delta < interval[0] || delta > interval[1]) {
                // Newton left the bracket: retry with a finer scan, bounded by the depth limit.
                if (depth + 1 > MAX_SUBINTERVAL_DEPTH) {
                    return Double.NaN;
                }
                return this.findDeltaInInterval(interval, depth + 1);
            }
            if (converged) {
                break;
            }
        }
        return delta;
    }

    /**
     * Returns {@code inv(A + delta*I)} for {@code A = Z*K*Z'}, built from the eigen-decomposition
     * of A as {@code V * diag(1/(eigenvalue + delta)) * V'}.
     */
    private DoubleMatrix inverseH(double delta) {
        DoubleMatrix V = this.eigA.getEigenvectors();
        // Work on a copy so the decomposition's eigenvalue matrix is never modified, even if
        // getEigenvalueMatrix() returns shared state (solve() can run more than once).
        DoubleMatrix D = this.eigA.getEigenvalueMatrix().copy();
        int n = D.numberOfRows();
        for (int i = 0; i < n; ++i) {
            D.set(i, i, 1.0 / (D.get(i, i) + delta));
        }
        return V.mult(D.tcrossproduct(V));
    }

    /**
     * Computes the generalized least squares estimate {@code inv(X' inv(H) X) X' inv(H) y} and
     * stores {@code inv(X' inv(H) X)} in {@code invXHX}.
     */
    private DoubleMatrix calculateBeta() {
        DoubleMatrix XtH = this.X.crossproduct(this.invH);
        this.invXHX = XtH.mult(this.X).inverse();
        return this.invXHX.mult(XtH.mult(this.y));
    }

    /** Computes the BLUPs {@code K Z' inv(H) (y - X beta)} and stores {@code X beta} in {@code Xbeta}. */
    private DoubleMatrix calculateBLUP() {
        this.Xbeta = this.X.mult(this.beta);
        DoubleMatrix centered = this.y.minus(this.Xbeta);
        // tcrossproduct(Z) multiplies by Z', avoiding an explicit transpose.
        DoubleMatrix KZtInvH = this.K.tcrossproduct(this.Z).mult(this.invH);
        return KZtInvH.mult(centered);
    }

    /** Computes predicted values {@code X beta + Z blup}. */
    private DoubleMatrix calculatePred() {
        this.Xbeta = this.X.mult(this.beta);
        DoubleMatrix Zu = this.Z.mult(this.blup);
        return this.Xbeta.plus(Zu);
    }

    /** Computes residuals {@code y - pred}. */
    private DoubleMatrix calculateRes() {
        return this.y.minus(this.pred);
    }

    /** Returns {@code (y - X b)' inv(H) (y - X b) / (N - q)}, the estimate of the random-effect variance. */
    private double getGenvar(DoubleMatrix b) {
        DoubleMatrix residual = this.y.copy();
        residual.minusEquals(this.X.mult(b));
        return residual.crossproduct(this.invH.mult(residual)).get(0, 0) / (double) (this.N - this.q);
    }

    /** @return marker degrees of freedom ({@code nAlleles - 1}) */
    public int getDfMarker() {
        return this.dfMarker;
    }

    /** @return fixed-effect estimates, or {@code null} before {@link #solve()} */
    public DoubleMatrix getBeta() {
        return this.beta;
    }

    /** @return {@code q - 1} after {@link #solve()}; the number of fixed-effect columns before it */
    public int getDfModel() {
        return this.dfModel;
    }

    /** @return {@code N - q} after {@link #solve()} */
    public int getDfError() {
        return this.dfError;
    }

    /** @return the supplied or estimated ratio {@code varResidual / varRandomEffect} */
    public double getDelta() {
        return this.delta;
    }

    /** @return {@code inv(Z*K*Z' + delta*I)} computed by {@link #solve()} */
    public DoubleMatrix getInvH() {
        return this.invH;
    }

    /** @return residual variance estimate */
    public double getVarRes() {
        return this.varResidual;
    }

    /** @return random-effect variance estimate */
    public double getVarRan() {
        return this.varRandomEffect;
    }

    /** @return BLUPs, available after {@link #calculateBlupsPredictedResiduals()} */
    public DoubleMatrix getBlup() {
        return this.blup;
    }

    /** @return predicted values, available after {@link #calculateBlupsPredictedResiduals()} */
    public DoubleMatrix getPred() {
        return this.pred;
    }

    /** @return residuals, available after {@link #calculateBlupsPredictedResiduals()} */
    public DoubleMatrix getRes() {
        return this.res;
    }

    /** @return REML log-likelihood at the final delta */
    public double getLnLikelihood() {
        return this.lnLikelihood;
    }

    /**
     * Tests the marker effects, i.e. the last {@code dfMarker} fixed-effect coefficients, with
     * an F statistic. Requires a prior call to {@link #solve()}.
     *
     * @return {@code {F, p}}. The p-value is {@code NaN} if the F test throws. When
     *         {@code dfMarker < 1} a three-element all-NaN array is returned; this length
     *         difference is kept for compatibility with existing callers.
     */
    public double[] getMarkerFp() {
        if (this.dfMarker < 1) {
            return new double[]{Double.NaN, Double.NaN, Double.NaN};
        }
        int nparm = this.beta.numberOfRows();
        int firstmarker = nparm - this.dfMarker;
        // M selects the marker coefficients out of beta.
        DoubleMatrix M = DoubleMatrixFactory.DEFAULT.make(this.dfMarker, nparm);
        for (int i = 0; i < this.dfMarker; ++i) {
            M.set(i, i + firstmarker, 1.0);
        }
        DoubleMatrix MB = M.mult(this.beta);
        DoubleMatrix invMiM = M.mult(this.invXHX.tcrossproduct(M));
        invMiM.invert();
        double F = MB.crossproduct(invMiM.mult(MB)).get(0, 0);
        F /= this.varRandomEffect;
        F /= (double) this.dfMarker;
        double p;
        try {
            p = LinearModelUtils.Ftest(F, this.dfMarker, this.N - this.q);
        } catch (Exception e) {
            // Report NaN as before, but keep the cause visible instead of discarding it.
            myLogger.debug("F test failed for F = " + F + " (df " + this.dfMarker + ", " + (this.N - this.q) + "); p set to NaN", e);
            p = Double.NaN;
        }
        return new double[]{F, p};
    }

    /**
     * Refits the model for a new response while reusing the decompositions computed at
     * construction. A 1 x n row vector is transposed to n x 1, as in the constructors.
     * NOTE(review): assumes {@code y} has the same observations, in the same order, as the
     * original response.
     */
    public void solveWithNewData(DoubleMatrix y) {
        this.y = asColumnVector(y);
        this.solve();
    }

    /**
     * Pairs an array position with the absolute value stored there, ordered by decreasing
     * absolute value. Comparisons involving NaN report equality, as the original ordering did.
     */
    private static final class AbsoluteValueEntry implements Comparable<AbsoluteValueEntry> {
        final int order;
        final double absvalue;

        AbsoluteValueEntry(int order, double value) {
            this.order = order;
            this.absvalue = Math.abs(value);
        }

        @Override
        public int compareTo(AbsoluteValueEntry other) {
            if (this.absvalue < other.absvalue) {
                return 1;
            }
            if (this.absvalue > other.absvalue) {
                return -1;
            }
            return 0;
        }
    }
}

