/*
 * Decompiled with CFR 0.152.
 */
package stallone.coordinates;

import java.io.FileNotFoundException;
import stallone.api.API;
import stallone.api.algebra.IEigenvalueDecomposition;
import stallone.api.coordinates.ITICA;
import stallone.api.datasequence.IDataInput;
import stallone.api.datasequence.IDataSequence;
import stallone.api.doubles.IDoubleArray;
import stallone.api.dynamics.IIntegratorThermostatted;
import stallone.api.potential.IEnergyModel;
import stallone.stat.RunningMomentsMultivariate;

public class TICA
implements ITICA {
    private int lag;
    RunningMomentsMultivariate moments;
    private int dimIn;
    private IDoubleArray CovTauSym;
    private IDoubleArray evalTICA;
    private IDoubleArray evecTICA;
    private int dimOut;

    public TICA(IDataInput _source, int _lag) {
        this.lag = _lag;
        this.init(_source.dimension());
        for (IDataSequence seq : _source.sequences()) {
            this.addData(seq);
        }
        this.computeTransform();
    }

    public TICA(IDataSequence _source, int _lag) {
        this.lag = _lag;
        this.init(_source.dimension());
        this.addData(_source);
        this.computeTransform();
    }

    public TICA(int _lag) {
        this.lag = _lag;
    }

    private final void init(int _dimIn) {
        this.dimIn = _dimIn;
        if (this.dimOut == 0) {
            this.dimOut = _dimIn;
        }
        this.moments = API.statNew.runningMomentsMultivar(this.dimIn, this.lag);
    }

    @Override
    public final void addData(IDataSequence data) {
        if (this.dimIn == 0) {
            this.init(data.dimension());
        }
        this.moments.addData(data);
    }

    @Override
    public final void computeTransform() {
        IDoubleArray Cov = this.moments.getCov();
        IEigenvalueDecomposition evd = API.alg.evd(Cov);
        IDoubleArray evalPCA = evd.getEvalNorm();
        IDoubleArray evecPCA = evd.getRightEigenvectorMatrix().viewReal();
        IDoubleArray S = API.doublesNew.array(evalPCA.size());
        int i = 0;
        while (i < S.size()) {
            S.set(i, 1.0 * Math.sqrt(evalPCA.get(i)));
            ++i;
        }
        IDoubleArray evecPCAscaled = API.alg.product(evecPCA, API.doublesNew.diag(S));
        this.CovTauSym = this.moments.getCovLagged();
        this.CovTauSym = API.alg.addWeightedToNew(0.5, this.CovTauSym, 0.5, API.alg.transposeToNew(this.CovTauSym));
        IDoubleArray pcCovTau = API.alg.product(API.alg.product(API.alg.transposeToNew(evecPCAscaled), this.CovTauSym), evecPCAscaled);
        IEigenvalueDecomposition evd2 = API.alg.evd(pcCovTau);
        this.evalTICA = evd2.getEvalNorm();
        this.evecTICA = API.alg.product(evecPCAscaled, evd2.getRightEigenvectorMatrix().viewReal());
    }

    @Override
    public IDoubleArray getMeanVector() {
        return this.moments.getMean();
    }

    @Override
    public IDoubleArray getCovarianceMatrix() {
        return this.moments.getCov();
    }

    @Override
    public IDoubleArray getCovarianceMatrixLagged() {
        return this.CovTauSym;
    }

    @Override
    public void setDimension(int d) {
        this.dimOut = d;
    }

    @Override
    public IDoubleArray getEigenvalues() {
        return this.evalTICA;
    }

    @Override
    public IDoubleArray getEigenvector(int i) {
        return this.evecTICA.viewColumn(i);
    }

    @Override
    public IDoubleArray getEigenvectorMatrix() {
        return this.evecTICA;
    }

    @Override
    public IDoubleArray transform(IDoubleArray x) {
        IDoubleArray out = API.doublesNew.array(this.dimOut);
        this.transform(x, out);
        return out;
    }

    @Override
    public void transform(IDoubleArray in, IDoubleArray out) {
        IDoubleArray x;
        if (in.columns() != 1) {
            in = API.doublesNew.array(in.getArray());
        }
        if ((x = API.alg.subtract(in, this.moments.getMean())).rows() > 1) {
            x = API.alg.transposeToNew(x);
        }
        IDoubleArray y = API.alg.product(x, this.evecTICA);
        int d = Math.min(in.size(), out.size());
        int i = 0;
        while (i < d) {
            out.set(i, y.get(i));
            ++i;
        }
    }

    @Override
    public int dimension() {
        return this.dimOut;
    }

    public static void main(String[] args) throws FileNotFoundException {
        IEnergyModel pot = API.potNew.multivariateFromExpression(new String[]{"x", "y"}, "1/4 x^4 - 1/2 x^2 + 1/2 y^2", "x^3-x", "y");
        IDoubleArray masses = API.doublesNew.arrayFrom(1.0, 1.0);
        double dt = 0.1;
        double gamma = 1.0;
        double kT = 0.2;
        IIntegratorThermostatted langevin = API.dynNew.langevinLeapFrog(pot, masses, 0.1, gamma, kT);
        IDoubleArray x0 = API.doublesNew.arrayFrom(0.0, 0.0);
        int nsteps = 100000;
        int nsave = 10;
        IDataSequence seq = API.dyn.run(x0, langevin, nsteps, nsave);
        int lag = 1;
        TICA tica = new TICA(lag);
        tica.addData(seq);
        tica.computeTransform();
        System.out.println("mean: \t" + API.doubles.toString(tica.getMeanVector(), "\t"));
        System.out.println("cov: \t" + API.doubles.toString(tica.getCovarianceMatrix(), "\t", "\n"));
        System.out.println("covTau: \t" + API.doubles.toString(tica.getCovarianceMatrixLagged(), "\t", "\n"));
        System.out.println();
        System.out.println("eval: \t" + API.doubles.toString(tica.getEigenvalues(), "\t"));
        System.out.println("evec1: \t" + API.doubles.toString(tica.getEigenvector(0), "\t"));
        System.out.println("evec2: \t" + API.doubles.toString(tica.getEigenvector(1), "\t"));
        tica.setDimension(1);
        IDoubleArray y1 = tica.transform(API.doublesNew.arrayFrom(2.0, 2.0));
        System.out.println("y1 = \t" + API.doubles.toString(y1, "\t"));
        IDoubleArray y2 = tica.transform(API.doublesNew.arrayFrom(4.0, 4.0));
        System.out.println("y2 = \t" + API.doubles.toString(y2, "\t"));
    }
}

