/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.Scorer;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Triple;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Precision, recall and F1 statistics for a multi-class classifier with one designated
 * "negative" label.
 *
 * <p>For each label, {@link #score(Classifier, GeneralDataset)} counts true positives, false
 * positives and false negatives. A prediction of the negative label never counts as a false
 * positive. A gold instance of the negative label never counts as a false negative. The
 * negative label never accumulates true positives. The overall (micro-averaged) figures sum the
 * counts over every label except the negative one.
 *
 * <p>When a ratio has a zero denominator (no predictions, or no gold instances), precision or
 * recall is reported as 1.0. F1 is reported as 0.0 when precision and recall are both 0.
 *
 * @param <L> the label type
 */
public class MultiClassPrecisionRecallStats<L> implements Scorer<L> {
    /** Per-label true-positive counts, indexed by {@link #labelIndex}; null until scored. */
    protected int[] tpCount;
    /** Per-label false-positive counts, indexed by {@link #labelIndex}; null until scored. */
    protected int[] fpCount;
    /** Per-label false-negative counts, indexed by {@link #labelIndex}; null until scored. */
    protected int[] fnCount;
    /** Labels from the scored dataset plus those reported by the classifier; null until scored. */
    protected Index<L> labelIndex;
    /** The label excluded from precision/recall counting. */
    protected L negLabel;
    /** Position of {@link #negLabel} in {@link #labelIndex}, or -1 if absent or not yet scored. */
    protected int negIndex = -1;

    /**
     * Creates the statistics and immediately scores {@code classifier} on {@code data}.
     *
     * @param classifier the classifier to evaluate
     * @param data the labelled evaluation data
     * @param negLabel the label excluded from the counts
     */
    public <F> MultiClassPrecisionRecallStats(Classifier<L, F> classifier, GeneralDataset<L, F> data, L negLabel) {
        this.negLabel = negLabel;
        // NOTE(review): this calls an overridable method from a constructor. It is kept for
        // compatibility; a subclass overriding score() would run before its own fields exist.
        this.score(classifier, data);
    }

    /**
     * Creates empty statistics; call one of the {@code score} methods before querying them.
     *
     * @param negLabel the label excluded from the counts
     */
    public MultiClassPrecisionRecallStats(L negLabel) {
        this.negLabel = negLabel;
    }

    /**
     * Scores a probabilistic classifier by delegating to {@link #score(Classifier, GeneralDataset)}.
     *
     * @return the overall F1, as returned by {@link #getFMeasure()}
     */
    @Override
    public <F> double score(ProbabilisticClassifier<L, F> classifier, GeneralDataset<L, F> data) {
        // The cast selects the Classifier overload; without it this method would call itself.
        return this.score((Classifier<L, F>) classifier, data);
    }

    /**
     * Classifies every datum in {@code data} and recomputes all counts from scratch.
     *
     * <p>The label index is rebuilt from the dataset's labels followed by
     * {@code classifier.labels()}. The fields are only replaced once the tally has completed.
     *
     * @param classifier the classifier to evaluate
     * @param data the labelled evaluation data
     * @return the overall F1, as returned by {@link #getFMeasure()}
     */
    @Override
    public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
        int size = data.size();
        Index<L> dataLabels = data.labelIndex();
        int[] labelsArr = data.getLabelsArray();

        // Classify everything first, then resolve each gold label id to its label object.
        List<L> guesses = new ArrayList<L>(size);
        for (int i = 0; i < size; ++i) {
            RVFDatum<L, F> d = data.getRVFDatum(i);
            guesses.add(classifier.classOf(d));
        }
        List<L> golds = new ArrayList<L>(size);
        for (int i = 0; i < size; ++i) {
            golds.add(dataLabels.get(labelsArr[i]));
        }

        // The classifier may know labels that do not occur in the data, and vice versa.
        Index<L> index = new HashIndex<L>();
        index.addAll(dataLabels.objectsList());
        index.addAll(classifier.labels());

        int numClasses = index.size();
        int[] tp = new int[numClasses];
        int[] fp = new int[numClasses];
        int[] fn = new int[numClasses];
        int neg = index.indexOf(negLabel);

        for (int i = 0; i < size; ++i) {
            int guessIndex = index.indexOf(guesses.get(i));
            int trueIndex = index.indexOf(golds.get(i));
            if (guessIndex == trueIndex) {
                if (guessIndex != neg) {
                    tp[guessIndex]++;
                }
            } else {
                if (guessIndex != neg) {
                    fp[guessIndex]++;
                }
                if (trueIndex != neg) {
                    fn[trueIndex]++;
                }
            }
        }

        // Publish the new state only after the tally succeeded.
        this.labelIndex = index;
        this.tpCount = tp;
        this.fpCount = fp;
        this.fnCount = fn;
        this.negIndex = neg;
        return this.getFMeasure();
    }

    /**
     * Returns precision for one label together with its raw counts.
     *
     * @param label a label present in the scored label index
     * @return (precision, true positives, false positives); precision is 1.0 when both counts are 0
     * @throws IllegalStateException if no scoring has happened yet
     * @throws IllegalArgumentException if {@code label} is not in the label index
     */
    public Triple<Double, Integer, Integer> getPrecisionInfo(L label) {
        int i = this.positionOf(label);
        return new Triple<Double, Integer, Integer>(ratioOrOne(this.tpCount[i], this.fpCount[i]), this.tpCount[i], this.fpCount[i]);
    }

    /**
     * Returns precision for one label.
     *
     * @see #getPrecisionInfo(Object)
     */
    public double getPrecision(L label) {
        return this.getPrecisionInfo(label).first();
    }

    /**
     * Returns micro-averaged precision over all non-negative labels, with the summed counts.
     *
     * @return (precision, true positives, false positives); precision is 1.0 when both sums are 0
     * @throws IllegalStateException if no scoring has happened yet
     */
    public Triple<Double, Integer, Integer> getPrecisionInfo() {
        this.requireScored();
        int tp = 0;
        int fp = 0;
        for (int i = 0; i < this.labelIndex.size(); ++i) {
            if (i == this.negIndex) {
                continue;
            }
            tp += this.tpCount[i];
            fp += this.fpCount[i];
        }
        return new Triple<Double, Integer, Integer>(ratioOrOne(tp, fp), tp, fp);
    }

    /**
     * Returns micro-averaged precision over all non-negative labels.
     *
     * @see #getPrecisionInfo()
     */
    public double getPrecision() {
        return this.getPrecisionInfo().first();
    }

    /** Formats overall precision as {@code "value  (tp/tp+fp)"} with at most {@code numDigits} decimals. */
    public String getPrecisionDescription(int numDigits) {
        return describe(this.getPrecisionInfo(), numDigits);
    }

    /** Formats one label's precision as {@code "value  (tp/tp+fp)"} with at most {@code numDigits} decimals. */
    public String getPrecisionDescription(int numDigits, L label) {
        return describe(this.getPrecisionInfo(label), numDigits);
    }

    /**
     * Returns recall for one label together with its raw counts.
     *
     * @param label a label present in the scored label index
     * @return (recall, true positives, false negatives); recall is 1.0 when both counts are 0
     * @throws IllegalStateException if no scoring has happened yet
     * @throws IllegalArgumentException if {@code label} is not in the label index
     */
    public Triple<Double, Integer, Integer> getRecallInfo(L label) {
        int i = this.positionOf(label);
        return new Triple<Double, Integer, Integer>(ratioOrOne(this.tpCount[i], this.fnCount[i]), this.tpCount[i], this.fnCount[i]);
    }

    /**
     * Returns recall for one label.
     *
     * @see #getRecallInfo(Object)
     */
    public double getRecall(L label) {
        return this.getRecallInfo(label).first();
    }

    /**
     * Returns micro-averaged recall over all non-negative labels, with the summed counts.
     *
     * @return (recall, true positives, false negatives); recall is 1.0 when both sums are 0
     * @throws IllegalStateException if no scoring has happened yet
     */
    public Triple<Double, Integer, Integer> getRecallInfo() {
        this.requireScored();
        int tp = 0;
        int fn = 0;
        for (int i = 0; i < this.labelIndex.size(); ++i) {
            if (i == this.negIndex) {
                continue;
            }
            tp += this.tpCount[i];
            fn += this.fnCount[i];
        }
        return new Triple<Double, Integer, Integer>(ratioOrOne(tp, fn), tp, fn);
    }

    /**
     * Returns micro-averaged recall over all non-negative labels.
     *
     * @see #getRecallInfo()
     */
    public double getRecall() {
        return this.getRecallInfo().first();
    }

    /** Formats overall recall as {@code "value  (tp/tp+fn)"} with at most {@code numDigits} decimals. */
    public String getRecallDescription(int numDigits) {
        return describe(this.getRecallInfo(), numDigits);
    }

    /** Formats one label's recall as {@code "value  (tp/tp+fn)"} with at most {@code numDigits} decimals. */
    public String getRecallDescription(int numDigits, L label) {
        return describe(this.getRecallInfo(label), numDigits);
    }

    /**
     * Returns F1 (harmonic mean of precision and recall) for one label; 0.0 when both are 0.
     */
    public double getFMeasure(L label) {
        return harmonicMean(this.getPrecision(label), this.getRecall(label));
    }

    /**
     * Returns micro-averaged F1 over all non-negative labels; 0.0 when precision and recall are both 0.
     */
    public double getFMeasure() {
        return harmonicMean(this.getPrecision(), this.getRecall());
    }

    /** Formats overall F1 with at most {@code numDigits} decimals. */
    public String getF1Description(int numDigits) {
        return formatter(numDigits).format(this.getFMeasure());
    }

    /** Formats one label's F1 with at most {@code numDigits} decimals. */
    public String getF1Description(int numDigits, L label) {
        return formatter(numDigits).format(this.getFMeasure(label));
    }

    /**
     * Builds a multi-line report with precision, recall and F1 for every non-negative label,
     * followed by the overall figures.
     *
     * @param numDigits maximum number of fraction digits in each figure
     * @throws IllegalStateException if no scoring has happened yet
     */
    @Override
    public String getDescription(int numDigits) {
        this.requireScored();
        StringBuilder sb = new StringBuilder();
        sb.append("--- PR Stats ---").append("\n");
        for (L label : this.labelIndex) {
            if (label.equals(this.negLabel)) {
                continue;
            }
            sb.append("** ").append(label.toString()).append(" **\n");
            sb.append("\tPrec:   ").append(this.getPrecisionDescription(numDigits, label)).append("\n");
            sb.append("\tRecall: ").append(this.getRecallDescription(numDigits, label)).append("\n");
            sb.append("\tF1:     ").append(this.getF1Description(numDigits, label)).append("\n");
        }
        sb.append("** Overall **\n");
        sb.append("\tPrec:   ").append(this.getPrecisionDescription(numDigits)).append("\n");
        sb.append("\tRecall: ").append(this.getRecallDescription(numDigits)).append("\n");
        sb.append("\tF1:     ").append(this.getF1Description(numDigits));
        return sb.toString();
    }

    /** Fails fast with a clear message when the statistics are queried before scoring. */
    private void requireScored() {
        if (this.labelIndex == null || this.tpCount == null) {
            throw new IllegalStateException("score() has not been called");
        }
    }

    /** Returns the counter position of {@code label}; rejects labels outside the index. */
    private int positionOf(L label) {
        this.requireScored();
        int i = this.labelIndex.indexOf(label);
        if (i < 0) {
            throw new IllegalArgumentException("Unknown label: " + label);
        }
        return i;
    }

    /** Returns hits / (hits + misses), or 1.0 when the denominator is zero. */
    private static double ratioOrOne(int hits, int misses) {
        int total = hits + misses;
        return total == 0 ? 1.0 : (double) hits / (double) total;
    }

    /** Harmonic mean of p and r; 0.0 when both are 0, avoiding a 0/0 NaN. */
    private static double harmonicMean(double p, double r) {
        double sum = p + r;
        return sum == 0.0 ? 0.0 : 2.0 * p * r / sum;
    }

    /** Creates a default-locale number formatter limited to {@code numDigits} fraction digits. */
    private static NumberFormat formatter(int numDigits) {
        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(numDigits);
        return nf;
    }

    /** Renders a (ratio, hits, misses) triple as {@code "ratio  (hits/hits+misses)"}. */
    private static String describe(Triple<Double, Integer, Integer> info, int numDigits) {
        return formatter(numDigits).format(info.first()) + "  (" + info.second() + "/" + (info.second() + info.third()) + ")";
    }
}

