/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.extract;

import cc.mallet.extract.ExactMatchComparator;
import cc.mallet.extract.Extraction;
import cc.mallet.extract.ExtractionEvaluator;
import cc.mallet.extract.Field;
import cc.mallet.extract.FieldComparator;
import cc.mallet.extract.Record;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.MatrixOps;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.Iterator;

/**
 * Evaluates an {@link Extraction} by comparing, for every document, the extracted
 * {@link Record} against the target record, field by field. It prints per-label
 * precision, recall and F1, followed by micro- and macro-averaged totals.
 *
 * <p>A predicted field counts as correct when the target record contains a field with
 * the same label and that field accepts the predicted field's first value under the
 * configured {@link FieldComparator} (an {@link ExactMatchComparator} by default).
 */
public class PerDocumentF1Evaluator implements ExtractionEvaluator {
    private FieldComparator comparator = new ExactMatchComparator();
    private PrintStream errorOutputStream = null;

    /** Returns the comparator used to decide whether a predicted value matches the target. */
    public FieldComparator getComparator() {
        return this.comparator;
    }

    /** Sets the comparator used to decide whether a predicted value matches the target. */
    public void setComparator(FieldComparator comparator) {
        this.comparator = comparator;
    }

    /**
     * Returns the stream that receives a description of each incorrect prediction,
     * or {@code null} when error reporting is disabled (the initial state).
     */
    public PrintStream getErrorOutputStream() {
        return this.errorOutputStream;
    }

    /**
     * Sets the stream that receives a description of each incorrect prediction.
     *
     * @param errorOutputStream destination for error reports. A {@link PrintStream} is used
     *     as-is and any other stream is wrapped in one. {@code null} disables error
     *     reporting (previously this threw a {@link NullPointerException}).
     */
    public void setErrorOutputStream(OutputStream errorOutputStream) {
        if (errorOutputStream == null) {
            this.errorOutputStream = null;
        } else if (errorOutputStream instanceof PrintStream) {
            this.errorOutputStream = (PrintStream) errorOutputStream;
        } else {
            this.errorOutputStream = new PrintStream(errorOutputStream);
        }
    }

    /** Evaluates {@code extraction} and prints the report to {@link System#out}. */
    @Override
    public void evaluate(Extraction extraction) {
        this.evaluate(extraction, System.out);
    }

    /**
     * Evaluates {@code extraction} and prints the report, with an empty description,
     * to {@code out2}.
     */
    public void evaluate(Extraction extraction, PrintStream out2) {
        // Auto-flushing so each println reaches the underlying stream immediately.
        this.evaluate("", extraction, new PrintWriter(new OutputStreamWriter(out2), true));
    }

    /**
     * Evaluates {@code extraction} and prints the report, with an empty description,
     * to {@code out2}.
     */
    public void evaluate(Extraction extraction, PrintWriter out2) {
        this.evaluate("", extraction, out2);
    }

    /**
     * Evaluates {@code extraction} and prints a report headed by {@code description}.
     *
     * <p>The report has one tab-separated line per label in the extraction's label
     * alphabet (name, P, R, F1), then a micro-averaged P/R/F1 line computed from counts
     * pooled over all labels, then a macro-averaged F1 line. The macro average covers
     * only labels that were predicted or present in the targets at least once.
     *
     * @param description prefix for the report header line
     * @param extraction the extraction to score; it is expected to hold one record per
     *     document (checked with an {@code assert})
     * @param out2 destination for the report
     */
    public void evaluate(String description, Extraction extraction, PrintWriter out2) {
        int numDocs = extraction.getNumDocuments();
        assert numDocs == extraction.getNumRecords();
        LabelAlphabet dict = extraction.getLabelAlphabet();
        int numLabels = dict.size();
        int[] numCorr = new int[numLabels];
        int[] numPred = new int[numLabels];
        int[] numTrue = new int[numLabels];
        for (int docnum = 0; docnum < numDocs; ++docnum) {
            countDocument(extraction, docnum, numCorr, numPred, numTrue);
        }
        printReport(description, dict, numCorr, numPred, numTrue, out2);
    }

    /**
     * Adds one document's predicted, correct and true field counts, indexed by label
     * index, to the given arrays.
     */
    private void countDocument(Extraction extraction, int docnum,
                               int[] numCorr, int[] numPred, int[] numTrue) {
        Record extracted = extraction.getRecord(docnum);
        Record target = extraction.getTargetRecord(docnum);

        Iterator<?> it = extracted.fieldsIterator();
        while (it.hasNext()) {
            Field predField = (Field) it.next();
            Label name = predField.getName();
            Field trueField = target.getField(name);
            int idx = name.getIndex();
            numPred[idx]++;
            if (predField.numValues() > 1) {
                System.err.println("Warning: Field " + predField + " has more than one extracted value. Picking arbitrarily...");
            }
            // Only the first extracted value is scored.
            // NOTE(review): assumes every extracted field has at least one value; value(0) on an
            // empty field is not guarded here — confirm against Field's contract.
            if (trueField != null && trueField.isValue(predField.value(0), this.comparator)) {
                numCorr[idx]++;
            } else if (this.errorOutputStream != null) {
                reportError(extraction, docnum, predField, trueField);
            }
        }

        it = target.fieldsIterator();
        while (it.hasNext()) {
            Field trueField = (Field) it.next();
            numTrue[trueField.getName().getIndex()]++;
        }
    }

    /** Writes a description of one incorrect prediction to the error output stream. */
    private void reportError(Extraction extraction, int docnum, Field predField, Field trueField) {
        this.errorOutputStream.println("Error in extraction! Document " + extraction.getDocumentExtraction(docnum).getName());
        this.errorOutputStream.println("Predicted " + predField);
        this.errorOutputStream.println("True " + trueField);
        this.errorOutputStream.println();
    }

    /** Prints the per-label table and the micro/macro-averaged summary lines. */
    private static void printReport(String description, LabelAlphabet dict,
                                    int[] numCorr, int[] numPred, int[] numTrue,
                                    PrintWriter out2) {
        DecimalFormat f = new DecimalFormat("0.####");
        double totalF1 = 0.0;
        int totalFields = 0;
        out2.println(description + " per-document F1");
        out2.println("Name\tP\tR\tF1");
        for (int i = 0; i < numCorr.length; ++i) {
            double p = precisionOf(numCorr[i], numPred[i]);
            double r = recallOf(numCorr[i], numTrue[i]);
            double f1 = harmonicMean(p, r);
            // Labels never predicted and never present would otherwise dilute the macro average.
            if (numPred[i] > 0 || numTrue[i] > 0) {
                totalF1 += f1;
                ++totalFields;
            }
            Label name = dict.lookupLabel(i);
            out2.println(name + "\t" + f.format(p) + "\t" + f.format(r) + "\t" + f.format(f1));
        }

        int totalCorr = MatrixOps.sum(numCorr);
        int totalPred = MatrixOps.sum(numPred);
        int totalTrue = MatrixOps.sum(numTrue);
        // Same zero-count conventions as the per-label rows, so empty totals print numbers, not NaN.
        double microP = precisionOf(totalCorr, totalPred);
        double microR = recallOf(totalCorr, totalTrue);
        double microF1 = harmonicMean(microP, microR);
        out2.println("OVERALL (micro-averaged) P=" + f.format(microP) + " R=" + f.format(microR) + " F1=" + f.format(microF1));

        double macroF1 = totalFields == 0 ? 0.0 : totalF1 / totalFields;
        out2.println("OVERALL (macro-averaged) F1=" + f.format(macroF1));
        out2.println();
    }

    /** Returns correct / predicted, or 0 when nothing was predicted. */
    private static double precisionOf(int correct, int predicted) {
        return predicted == 0 ? 0.0 : (double) correct / (double) predicted;
    }

    /** Returns correct / true, or 1 when there was nothing to find. */
    private static double recallOf(int correct, int actual) {
        return actual == 0 ? 1.0 : (double) correct / (double) actual;
    }

    /** Returns the F1 score 2PR/(P+R), or 0 when both P and R are 0. */
    private static double harmonicMean(double p, double r) {
        return p + r == 0.0 ? 0.0 : 2.0 * p * r / (p + r);
    }
}

