/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.mfs;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.similarity.DistanceMeasure;
import com.rapidminer.tools.math.similarity.numerical.EuclideanDistance;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

public class SAMModel
extends PredictionModel {
    private static final long serialVersionUID = -424389057824068747L;
    protected Example C1;
    protected Example C2;
    protected DistanceMeasure distFunc = new EuclideanDistance();
    protected Map<Attribute, Double> M1;
    protected Map<Attribute, Double> M2;
    protected String posLabelString = "";
    protected String negLabelString = "";
    protected double posLabelId = 0.0;
    protected double negLabelId = 1.0;

    public SAMModel(ExampleSet exampleSet) {
        super(exampleSet);
        int p = exampleSet.getAttributes().size();
        int n = exampleSet.size();
        int n1 = 0;
        int n2 = 0;
        this.posLabelId = exampleSet.getAttributes().getLabel().getMapping().getPositiveIndex();
        this.negLabelId = exampleSet.getAttributes().getLabel().getMapping().getNegativeIndex();
        this.posLabelString = exampleSet.getAttributes().getLabel().getMapping().getPositiveString();
        this.negLabelString = exampleSet.getAttributes().getLabel().getMapping().getNegativeString();
        double[] labels = new double[n];
        this.M1 = new HashMap<Attribute, Double>(p);
        this.M2 = new HashMap<Attribute, Double>(p);
        Iterator attIter = exampleSet.getAttributes().allAttributes();
        while (attIter.hasNext()) {
            Attribute attribute = (Attribute)attIter.next();
            this.M1.put(attribute, 0.0);
            this.M2.put(attribute, 0.0);
        }
        Iterator nit = exampleSet.iterator();
        int i = 0;
        double tmpLabel = 0.0;
        while (nit.hasNext()) {
            labels[i] = tmpLabel = ((Example)nit.next()).getLabel();
            if (tmpLabel == this.posLabelId) {
                ++n1;
            } else if (tmpLabel == this.negLabelId) {
                ++n2;
            }
            ++i;
        }
        this.log("n1=" + n1 + ", n2=" + n2);
        if (n1 + n2 != n) {
            this.log("Summe der Klassengr\u00c3\u00b6\u00c3\u0178en ungleich Anzahl Beispiele.");
        }
        if (n1 == 0 || n2 == 0) {
            this.log("Eine Klasse mit 0 Elementen");
        }
        Map<Attribute, Double> M = null;
        for (Example e : exampleSet) {
            M = e.getLabel() == this.posLabelId ? this.M1 : this.M2;
            for (Attribute attribute : exampleSet.getAttributes()) {
                M.put(attribute, M.get(attribute) + e.getValue(attribute));
            }
        }
        for (Attribute attribute : exampleSet.getAttributes()) {
            this.M1.put(attribute, this.M1.get(attribute) / (double)n1);
            this.M2.put(attribute, this.M2.get(attribute) / (double)n2);
        }
    }

    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) {
        Attributes attributes = exampleSet.getAttributes();
        for (Example example : exampleSet) {
            double d1 = 0.0;
            double d2 = 0.0;
            for (Attribute a : attributes) {
                double val = example.getNumericalValue(a);
                double temp = val - this.M1.get(a);
                d1 += temp * temp;
                temp = val - this.M2.get(a);
                d2 += temp * temp;
            }
            if (d1 < d2) {
                example.setPredictedLabel(this.posLabelId);
            } else {
                example.setPredictedLabel(this.negLabelId);
            }
            example.setConfidence(this.posLabelString, d1 / (d1 + d2));
            example.setConfidence(this.negLabelString, d2 / (d1 + d2));
        }
        return exampleSet;
    }

    public boolean isUpdatable() {
        return false;
    }

    public String toString() {
        StringBuffer buffer = new StringBuffer();
        buffer.append("SAM model for two classes");
        buffer.append(Tools.getLineSeparator());
        buffer.append(Tools.getLineSeparator());
        buffer.append("Centroid of positive class: ");
        buffer.append(Tools.getLineSeparator());
        buffer.append(this.M1.toString());
        buffer.append(Tools.getLineSeparator());
        buffer.append(Tools.getLineSeparator());
        buffer.append("Centroid of negative class: ");
        buffer.append(Tools.getLineSeparator());
        buffer.append(this.M2.toString());
        return buffer.toString();
    }
}

