/*
 * Decompiled with CFR 0.152.
 */
package stallone.hmm.pmm;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import stallone.api.API;
import stallone.api.cluster.IClustering;
import stallone.api.datasequence.IDataInput;
import stallone.api.datasequence.IDataSequence;
import stallone.api.datasequence.IDataSequenceLoader;
import stallone.api.doubles.IDoubleArray;
import stallone.api.doubles.IMetric;
import stallone.api.hmm.ParameterEstimationException;
import stallone.api.ints.IIntArray;
import stallone.api.ints.IIntList;
import stallone.coordinates.MinimalRMSDistance3D;
import stallone.doubles.EuclideanDistance;
import stallone.doubles.PrimitiveDoubleTools;
import stallone.hmm.pmm.MultiClusteringSplitMerge;
import stallone.hmm.pmm.NinjaEstimator;
import stallone.util.CommandLineParser;

public class AdaptiveDiscretization {
    private List<String> inputFiles;
    private IDataInput data;
    private int ninitclusters = 0;
    private IDataSequence initcenters;
    private IMetric metric;
    private int nsplit;
    private int tau = 1;
    private int timeshift = 1;
    private int nhidden;
    private int maxrefinementsteps = 100;
    private double requestedError = 0.01;
    private String outdir;
    private MultiClusteringSplitMerge mc;
    NinjaEstimator ninja;
    IDoubleArray msmC;
    IDoubleArray msmT;
    IDoubleArray msmpi;
    IDoubleArray msmPi;
    IDoubleArray msmCorr;
    IDoubleArray msmChi;
    IDoubleArray msmTC;
    IDoubleArray msmTimescales;
    IDoubleArray hmmChi;
    IDoubleArray hmmTC;
    IDoubleArray hmmpiC;
    IDoubleArray hmmTimescales;
    IDoubleArray errors;
    double etot;

    private IDoubleArray getStateMixing(IDoubleArray chi, IDoubleArray piObs, IDoubleArray piHidden) {
        IDoubleArray mix = chi.copy();
        int i = 0;
        while (i < chi.rows()) {
            int j = 0;
            while (j < chi.columns()) {
                mix.set(i, j, chi.get(i, j) * piHidden.get(j) / piObs.get(i));
                ++j;
            }
            ++i;
        }
        API.alg.normalizeRows(mix, 1);
        return mix;
    }

    private IIntArray statesWithOverlap(IDoubleArray mix, double requestedQuality) {
        IIntList res = API.intsNew.list(0);
        int i = 0;
        while (i < mix.rows()) {
            IDoubleArray pfrom = mix.viewRow(i);
            if (API.doubles.max(pfrom) < requestedQuality) {
                res.append(i);
            }
            ++i;
        }
        return res;
    }

    private IIntArray selectStatesToSplit(IDoubleArray errors, double etot) {
        IIntList splitStates = API.intsNew.list(0);
        if (etot <= this.requestedError) {
            return splitStates;
        }
        int nselect = Math.max(1, (int)Math.sqrt(errors.size()));
        IIntArray sortedIndexes = API.doubles.sortedIndexes(errors);
        return API.ints.subToNew(sortedIndexes, sortedIndexes.size() - nselect, sortedIndexes.size());
    }

    private void printStateQualities(IDoubleArray mix, IDoubleArray pi) {
        System.out.println("State qualities");
        int i = 0;
        while (i < mix.rows()) {
            IDoubleArray pfrom = mix.viewRow(i);
            double q = API.doubles.max(pfrom);
            System.out.println(String.valueOf(i) + "\t" + q + "\t" + pi.get(i));
            ++i;
        }
    }

    private void estimate() throws ParameterEstimationException {
        IIntArray dtraj = this.mc.getCurrentDiscreteTrajectory();
        ArrayList<IIntArray> dtrajs = new ArrayList<IIntArray>();
        dtrajs.add(dtraj);
        this.ninja = new NinjaEstimator(dtrajs);
        this.ninja.setNHiddenStates(this.nhidden);
        this.ninja.setTau(this.tau);
        this.ninja.setTimeshift(this.timeshift);
        this.ninja.estimate();
        this.msmT = this.ninja.getMSMTransitionMatrix();
        this.msmpi = this.ninja.getMSMStationaryDistribution();
        this.msmPi = API.doublesNew.diag(this.msmpi);
        this.msmCorr = API.alg.product(this.msmPi, this.msmT);
        this.msmTimescales = this.ninja.getMSMTimescales();
        this.hmmTC = this.ninja.getHMMTransitionMatrix();
        this.hmmpiC = this.ninja.getHMMStationaryDistribution();
        this.hmmChi = this.ninja.getHMMOutputProbabilities();
        this.hmmTimescales = this.ninja.getHMMTimescales();
    }

    private void estimateErrorByPureness() {
        IDoubleArray chinorm = this.hmmChi.copy();
        API.alg.normalizeRows(chinorm, 1);
        this.errors = API.doublesNew.array(chinorm.rows());
        int i = 0;
        while (i < chinorm.rows()) {
            IDoubleArray row = chinorm.viewRow(i);
            int imax = API.doubles.maxIndex(row);
            row.set(imax, row.get(imax) - 1.0);
            this.errors.set(i, this.msmPi.get(i) * API.alg.norm(row));
            ++i;
        }
        this.etot = API.doubles.sum(this.errors);
    }

    private void estimateErrorByDetectability() {
        this.errors = API.doublesNew.array(this.hmmChi.rows());
        double qtot = 0.0;
        int i = 0;
        while (i < this.errors.size()) {
            double q1 = 0.0;
            double q2 = 0.0;
            int j = 0;
            while (j < this.hmmChi.columns()) {
                q1 += Math.pow(this.hmmpiC.get(j) * this.hmmChi.get(i, j), 2.0);
                q2 += this.hmmpiC.get(j) * this.hmmChi.get(i, j);
                ++j;
            }
            qtot += q1 / q2;
            this.errors.set(i, this.msmpi.get(i) * (1.0 - q1 / (q2 * this.msmpi.get(i))));
            ++i;
        }
        this.etot = 1.0 - qtot;
    }

    private void estimateError() {
        this.estimateErrorByDetectability();
        int i = 0;
        while (i < this.errors.size()) {
            System.out.println(String.valueOf(i) + "\t" + this.msmpi.get(i) + "\t" + this.errors.get(i));
            ++i;
        }
        System.out.println("n = " + this.msmpi.size() + "\t etot = " + this.etot);
    }

    private boolean split() {
        IIntArray badStates = this.selectStatesToSplit(this.errors, this.etot);
        if (badStates.size() == 0) {
            System.out.println("No overlapping states left. DONE!");
            return false;
        }
        System.out.println("splitting states: " + badStates);
        System.out.println();
        boolean couldsplit = this.mc.considerSplit(badStates);
        if (!couldsplit) {
            System.out.println("A split was requested but couldn not be executed. DONE!");
            return false;
        }
        this.mc.accept();
        return true;
    }

    private ArrayList<IIntArray> mergeGroupsByPurity() {
        IIntArray mergeCandidates = API.intsNew.arrayRange(this.hmmChi.rows());
        IDoubleArray pComeFrom = this.hmmChi.copy();
        API.alg.normalizeRows(pComeFrom, 1);
        System.out.println(" COME-FROM array:");
        API.doubles.print(pComeFrom, "\t", "\n");
        ArrayList<IIntArray> groups = new ArrayList<IIntArray>();
        int i = 0;
        while (i < this.nhidden) {
            groups.add(API.intsNew.list(0));
            ++i;
        }
        i = 0;
        while (i < mergeCandidates.size()) {
            int s = mergeCandidates.get(i);
            int g = PrimitiveDoubleTools.maxIndex(pComeFrom.getRow(s));
            if (pComeFrom.get(i, g) > 0.99) {
                ((IIntList)groups.get(g)).append(s);
            }
            ++i;
        }
        i = groups.size() - 1;
        while (i >= 0) {
            if (groups.get(i).size() <= 1) {
                System.out.println("removing merge group " + i + " with size " + groups.get(i).size());
                groups.remove(i);
            } else {
                int j = 0;
                while (j < groups.size()) {
                    IIntArray group = groups.get(j);
                    System.out.println("- merge group " + j + ": " + API.ints.toString(group, "", ","));
                    IDoubleArray groupChi = pComeFrom.view(group.getArray(), API.intsNew.arrayRange(0, this.nhidden).getArray());
                    API.doubles.print(groupChi, "\t", "\n");
                    System.out.println();
                    ++j;
                }
            }
            --i;
        }
        return groups;
    }

    private ArrayList<IIntArray> mergeGroupsBySimilarity() {
        double maxdist = 0.01;
        IIntArray mergeCandidates = API.intsNew.arrayRange(this.hmmChi.rows());
        IDoubleArray pComeFrom = this.hmmChi.copy();
        API.alg.normalizeRows(pComeFrom, 1);
        System.out.println(" COME-FROM array:");
        API.doubles.print(pComeFrom, "\t", "\n");
        ArrayList<IIntList> groups = new ArrayList<IIntList>();
        groups.add(API.intsNew.listFrom(mergeCandidates.get(0)));
        int i = 1;
        while (i < mergeCandidates.size()) {
            int s = mergeCandidates.get(i);
            IDoubleArray scf = pComeFrom.viewRow(s);
            double[] distances = new double[groups.size()];
            int j = 0;
            while (j < distances.length) {
                IDoubleArray ocf = pComeFrom.viewRow(((IIntList)groups.get(j)).get(0));
                distances[j] = API.alg.norm(API.alg.subtract(scf, ocf));
                System.out.println(String.valueOf(i) + " " + j + ":");
                System.out.println(scf);
                System.out.println(ocf);
                System.out.println(distances[j]);
                ++j;
            }
            if (PrimitiveDoubleTools.min(distances) < maxdist) {
                ((IIntList)groups.get(PrimitiveDoubleTools.minIndex(distances))).append(s);
            } else {
                groups.add(API.intsNew.listFrom(s));
            }
            ++i;
        }
        i = groups.size() - 1;
        while (i >= 0) {
            if (((IIntList)groups.get(i)).size() <= 1) {
                System.out.println("removing merge group " + i + " with size " + ((IIntList)groups.get(i)).size());
                groups.remove(i);
            } else {
                int j = 0;
                while (j < groups.size()) {
                    IIntArray group = (IIntArray)groups.get(j);
                    System.out.println("- merge group " + j + ": " + API.ints.toString(group, "", ","));
                    IDoubleArray groupChi = pComeFrom.view(group.getArray(), API.intsNew.arrayRange(0, this.nhidden).getArray());
                    API.doubles.print(groupChi, "\t", "\n");
                    System.out.println();
                    ++j;
                }
            }
            --i;
        }
        ArrayList<IIntArray> res = new ArrayList<IIntArray>();
        res.addAll(groups);
        return res;
    }

    private void merge() {
        System.out.println("\nMERGE step");
        ArrayList<IIntArray> groups = this.mergeGroupsBySimilarity();
        if (groups.isEmpty()) {
            System.out.println("MERGE: Nothing to do!\n");
            return;
        }
        this.mc.considerMerge(groups);
        this.mc.accept();
    }

    public boolean parseArguments(String[] args) throws FileNotFoundException, IOException {
        CommandLineParser parser = new CommandLineParser();
        parser.addStringArrayCommand("i", true);
        parser.addIntCommand("ninitcluster", false);
        parser.addStringCommand("initcenters", false);
        parser.addStringCommand("metric", false, "euclidean", new String[]{"euclidean", "minrmsd"});
        parser.addIntCommand("nsplit", true);
        parser.addCommand("tau", true);
        parser.addIntArgument("tau", true);
        parser.addIntArgument("tau", true);
        parser.addIntCommand("nhidden", true);
        parser.addDoubleCommand("requestederror", false);
        parser.addIntCommand("maxrefinementsteps", false);
        parser.addStringCommand("o", true);
        if (!parser.parse(args)) {
            return false;
        }
        String[] ifiles = parser.getStringArray("i");
        this.inputFiles = new ArrayList<String>();
        int i = 0;
        while (i < ifiles.length) {
            this.inputFiles.add(ifiles[i]);
            ++i;
        }
        System.out.println("reading input data ... ");
        IDataSequenceLoader loader = API.dataNew.multiSequenceLoader(this.inputFiles);
        this.data = loader.loadAll();
        System.out.println(" done. size: " + this.data.getSequence(0).size() + " x " + this.data.getSequence(0).dimension());
        String metricstring = parser.getString("metric");
        if (metricstring.equalsIgnoreCase("euclidean")) {
            this.metric = new EuclideanDistance();
        }
        if (metricstring.equalsIgnoreCase("minrmsd")) {
            this.metric = new MinimalRMSDistance3D(this.data.getSequence(0).dimension() / 3);
        }
        if (parser.hasCommand("ninitcluster")) {
            this.ninitclusters = parser.getInt("ninitcluster");
        } else {
            String initcenterfile = parser.getString("initcenters");
            this.initcenters = API.dataNew.reader(initcenterfile).load();
        }
        this.nsplit = parser.getInt("nsplit");
        this.tau = parser.getInt("tau", 0);
        this.timeshift = parser.getInt("tau", 1);
        this.nhidden = parser.getInt("nhidden");
        this.requestedError = parser.getDouble("requestederror");
        this.maxrefinementsteps = parser.getInt("maxrefinementsteps");
        this.outdir = parser.getString("o");
        System.out.println("read all input, continuing");
        return true;
    }

    public static String getUsageString() {
        return "\n=======================================\n AdaptiveClustering\n=======================================\nUsage: \n\nMandatory input and output options: \n -i  <trajectory>+\n -nsplit <number of new clusters per split>\n -tau <lag time> <timeshift>\n -nhidden <number of hidden states>\n\n -o <out-dir>\n\nAny of: \n -ninitcluster <number of initial clusters>\n -initcenters <initial centers>\n [-metric <minrmsd|euclidean>]\n\nOptional: \n [-requestederror <error-threshold, default = 0.01>]\n [-maxrefinementsteps <number of maximum refinement steps>]\n\n";
    }

    public static void main(String[] args) throws FileNotFoundException, IOException, ParameterEstimationException {
        AdaptiveDiscretization cmd;
        if (args.length == 0) {
            System.out.println(AdaptiveDiscretization.getUsageString());
            System.exit(0);
        }
        if (!(cmd = new AdaptiveDiscretization()).parseArguments(args)) {
            System.out.println(AdaptiveDiscretization.getUsageString());
            System.exit(0);
        }
        if (cmd.ninitclusters > 0) {
            IClustering clustering1 = API.clusterNew.kmeans(cmd.ninitclusters, 10);
            clustering1.setMetric(cmd.metric);
            IClustering clustering2 = API.clusterNew.random(cmd.nsplit);
            clustering2.setMetric(cmd.metric);
            cmd.mc = new MultiClusteringSplitMerge(cmd.data.getSequence(0), clustering1, clustering2);
        } else {
            IClustering clustering2 = API.clusterNew.randomCompact(cmd.nsplit, 10);
            clustering2.setMetric(cmd.metric);
            cmd.mc = new MultiClusteringSplitMerge(cmd.data.getSequence(0), cmd.initcenters, cmd.metric, clustering2);
        }
        cmd.estimate();
        cmd.estimateError();
        int n = 0;
        while (n < cmd.maxrefinementsteps && cmd.etot > cmd.requestedError) {
            boolean couldSplit = cmd.split();
            cmd.estimate();
            cmd.estimateError();
            cmd.merge();
            cmd.estimate();
            cmd.estimateError();
            if (n == cmd.maxrefinementsteps - 1) {
                System.out.println("Number of iterations expired. DONE!");
            }
            if (cmd.etot <= cmd.requestedError) {
                System.out.println("Prescribed error level of " + cmd.requestedError + " reached. DONE!");
            }
            ++n;
        }
        IIntArray dtrajFinal = cmd.mc.getCurrentDiscreteTrajectory();
        API.intseq.writeIntSequence(dtrajFinal, String.valueOf(cmd.outdir) + "/dtraj.dat");
    }
}

