/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.pam;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import org.jfree.data.statistics.Statistics;

/**
 * Nearest shrunken centroid (PAM) classification model.
 *
 * <p>Construction computes, from the training data:
 * <ul>
 * <li>the overall centroid;</li>
 * <li>one centroid per class;</li>
 * <li>a pooled within-class standard deviation per attribute;</li>
 * <li>standardized discriminant scores per class and attribute.</li>
 * </ul>
 * The scores are soft-thresholded by {@code shrinkage}, and the class centroids are rebuilt
 * from the shrunken scores. Prediction assigns each example to the class whose shrunken
 * centroid has the smallest standardized squared distance, corrected by {@code -2 log(prior)}.
 */
public class PAMModel
extends PredictionModel {
    private static final long serialVersionUID = 7383427652043162098L;
    /** Number of training examples. */
    protected int exampleSetSize;
    /** Number of regular attributes. */
    protected int attributeSize;
    /** Number of distinct label values seen in the training data. */
    protected int classSize;
    /** Amount subtracted from the absolute discriminant scores (soft threshold). */
    protected double shrinkage;
    /** Median of {@link #classSD}; added to every standard deviation in the denominators. */
    protected double medianSD;
    /** Training example set; {@link #toString()} reads its attribute names. */
    protected ExampleSet exampleSet;
    /** Per attribute: number of classes whose shrunken discriminant score is non-zero. */
    protected AttributeWeights weights;
    /** Per label value: the shrunken discriminant score of every attribute. */
    protected Map<String, AttributeWeights> classWeights;
    /** Mean of every attribute over all training examples. */
    protected Vector<Double> overallCentroid;
    /** Pooled within-class standard deviation of every attribute. */
    protected Vector<Double> classSD;
    /** Per class: standard-error factor {@code sqrt(1/n_k + 1/n)}. */
    protected Map<String, Double> standardErrorComponent;
    /** Per class: number of training examples. */
    protected Map<String, Integer> classFrequency;
    /** Per class: centroid vector; holds the shrunken centroids once construction finishes. */
    protected Map<String, Vector<Double>> classCentroid;
    /** Per class: discriminant scores; shrunken in place once construction finishes. */
    protected Map<String, Vector<Double>> discriminantScore;

    /**
     * Trains the model on the given example set.
     *
     * @param exampleSet training data; regular attributes are read as numerical values and
     *        the label as a nominal value
     * @param shrinkage soft-threshold applied to the discriminant scores
     * @throws OperatorException if the example set does not contain more examples than
     *         distinct classes (the pooled standard deviation divides by {@code n - K})
     */
    public PAMModel(ExampleSet exampleSet, double shrinkage) throws OperatorException {
        super(exampleSet);
        this.exampleSet = exampleSet;
        this.shrinkage = shrinkage;
        this.exampleSetSize = exampleSet.size();
        this.attributeSize = exampleSet.getAttributes().size();
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();

        this.overallCentroid = this.computeOverallCentroid(exampleSet);
        this.computeClassCentroids(exampleSet, labelAttribute);
        // Checked once, before any division by (n - K), instead of once per attribute.
        if (this.exampleSetSize <= this.classSize) {
            throw new OperatorException("Number of examples in ExampleSet is not greater than number of different classes");
        }

        this.standardErrorComponent = new HashMap<String, Double>();
        for (Map.Entry<String, Integer> entry : this.classFrequency.entrySet()) {
            int frequency = entry.getValue().intValue();
            // NOTE(review): the pamr reference implementation appears to use sqrt(1/n_k - 1/n);
            // this code uses "+". Confirm which variant is intended.
            this.standardErrorComponent.put(entry.getKey(),
                    Math.sqrt(1.0 / (double) frequency + 1.0 / (double) this.exampleSetSize));
        }

        this.classSD = this.computePooledStandardDeviations(exampleSet, labelAttribute);
        this.medianSD = Statistics.calculateMedian(this.classSD);
        this.discriminantScore = this.computeDiscriminantScores();
        // Also soft-thresholds this.discriminantScore in place before rebuilding the centroids.
        this.classCentroid = this.shrinkCentroids(this.classCentroid);
        this.computeWeights(exampleSet, labelAttribute);
    }

    /** Returns a new vector holding {@code size} zeros. */
    private static Vector<Double> zeroVector(int size) {
        Vector<Double> vector = new Vector<Double>(size);
        for (int i = 0; i < size; ++i) {
            vector.add(0.0);
        }
        return vector;
    }

    /** Computes the mean of every regular attribute over all examples. */
    private Vector<Double> computeOverallCentroid(ExampleSet exampleSet) {
        Vector<Double> centroid = zeroVector(this.attributeSize);
        for (Example example : exampleSet) {
            int i = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                centroid.set(i, centroid.get(i) + example.getNumericalValue(attribute));
                ++i;
            }
        }
        for (int i = 0; i < this.attributeSize; ++i) {
            centroid.set(i, centroid.get(i) / (double) this.exampleSetSize);
        }
        return centroid;
    }

    /** Fills {@code classFrequency}, {@code classCentroid} and {@code classSize} from the examples. */
    private void computeClassCentroids(ExampleSet exampleSet, Attribute labelAttribute) {
        this.classFrequency = new HashMap<String, Integer>();
        this.classCentroid = new HashMap<String, Vector<Double>>();
        this.classSize = 0;
        for (Example example : exampleSet) {
            String label = example.getNominalValue(labelAttribute);
            Vector<Double> centroid = this.classCentroid.get(label);
            if (centroid == null) {
                centroid = zeroVector(this.attributeSize);
                this.classCentroid.put(label, centroid);
                this.classFrequency.put(label, 0);
                ++this.classSize;
            }
            int i = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                centroid.set(i, centroid.get(i) + example.getNumericalValue(attribute));
                ++i;
            }
            this.classFrequency.put(label, this.classFrequency.get(label) + 1);
        }
        for (Map.Entry<String, Vector<Double>> entry : this.classCentroid.entrySet()) {
            Vector<Double> centroid = entry.getValue();
            double frequency = this.classFrequency.get(entry.getKey()).intValue();
            for (int i = 0; i < this.attributeSize; ++i) {
                centroid.set(i, centroid.get(i) / frequency);
            }
        }
    }

    /** Computes, per attribute, sqrt(sum of squared deviations from the own class centroid / (n - K)). */
    private Vector<Double> computePooledStandardDeviations(ExampleSet exampleSet, Attribute labelAttribute) {
        Vector<Double> standardDeviations = new Vector<Double>(this.attributeSize);
        int i = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            double sumOfSquares = 0.0;
            for (Example example : exampleSet) {
                double mean = this.classCentroid.get(example.getNominalValue(labelAttribute)).get(i);
                sumOfSquares += Math.pow(example.getNumericalValue(attribute) - mean, 2.0);
            }
            standardDeviations.add(Math.sqrt(sumOfSquares / (double) (this.exampleSetSize - this.classSize)));
            ++i;
        }
        return standardDeviations;
    }

    /** Computes (classCentroid - overallCentroid) / (m_k * (s_i + s_0)) per class and attribute. */
    private Map<String, Vector<Double>> computeDiscriminantScores() {
        Map<String, Vector<Double>> scores = new HashMap<String, Vector<Double>>();
        for (Map.Entry<String, Vector<Double>> entry : this.classCentroid.entrySet()) {
            String label = entry.getKey();
            Vector<Double> centroid = entry.getValue();
            double errorComponent = this.standardErrorComponent.get(label);
            Vector<Double> scoreVector = new Vector<Double>(this.attributeSize);
            for (int i = 0; i < this.attributeSize; ++i) {
                double numerator = centroid.get(i) - this.overallCentroid.get(i);
                double denominator = errorComponent * (this.classSD.get(i) + this.medianSD);
                scoreVector.add(numerator / denominator);
            }
            scores.put(label, scoreVector);
        }
        return scores;
    }

    /** Builds the overall relevance weights and one weight set per label value from the shrunken scores. */
    private void computeWeights(ExampleSet exampleSet, Attribute labelAttribute) {
        this.weights = new AttributeWeights(exampleSet);
        this.classWeights = new HashMap<String, AttributeWeights>();
        List<String> labelValues = labelAttribute.getMapping().getValues();
        for (String label : labelValues) {
            AttributeWeights labelWeights = new AttributeWeights(exampleSet);
            labelWeights.setSource("PAM - " + label);
            this.classWeights.put(label, labelWeights);
        }
        int i = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            int relevantClasses = 0;
            for (Map.Entry<String, Vector<Double>> entry : this.discriminantScore.entrySet()) {
                double score = entry.getValue().get(i);
                // Any non-zero shrunken score moves the class centroid away from the overall
                // centroid, whether positive or negative; the NaN check is implicit in "> 0".
                if (Math.abs(score) > 0.0) {
                    ++relevantClasses;
                }
                this.classWeights.get(entry.getKey()).setWeight(attribute.getName(), score);
            }
            this.weights.setWeight(attribute.getName(), (double) relevantClasses);
            ++i;
        }
    }

    /**
     * Shrinks {@link #discriminantScore} and rebuilds the given centroids from the result.
     *
     * <p>Side effects: {@link #discriminantScore} is soft-thresholded in place via
     * {@link #shrinkScores(Map)}, and every vector in {@code centroids} is overwritten in place
     * with {@code overallCentroid + m_k * (s_i + s_0) * shrunkenScore}. Calling this twice
     * shrinks the scores twice.
     *
     * @param centroids per-class centroid vectors; every key must also be present in
     *        {@link #discriminantScore} and {@link #standardErrorComponent}
     * @return the same {@code centroids} map instance
     */
    public Map<String, Vector<Double>> shrinkCentroids(Map<String, Vector<Double>> centroids) {
        this.discriminantScore = this.shrinkScores(this.discriminantScore);
        for (Map.Entry<String, Vector<Double>> entry : centroids.entrySet()) {
            String labelString = entry.getKey();
            Vector<Double> centroidsVector = entry.getValue();
            Vector<Double> scoresVector = this.discriminantScore.get(labelString);
            double errorComponent = this.standardErrorComponent.get(labelString);
            for (int i = 0; i < this.attributeSize; ++i) {
                double value = this.overallCentroid.get(i)
                        + errorComponent * (this.classSD.get(i) + this.medianSD) * scoresVector.get(i);
                centroidsVector.set(i, value);
            }
        }
        return centroids;
    }

    /**
     * Soft-thresholds every score in place: {@code sign(d) * max(|d| - shrinkage, 0)}.
     *
     * @param scores per-class score vectors of length {@link #attributeSize}; modified in place
     * @return the same {@code scores} map instance
     */
    public Map<String, Vector<Double>> shrinkScores(Map<String, Vector<Double>> scores) {
        for (Vector<Double> scoresVector : scores.values()) {
            for (int i = 0; i < this.attributeSize; ++i) {
                double value = scoresVector.get(i);
                int sign = value >= 0.0 ? 1 : -1;
                value = Math.abs(value) - this.shrinkage;
                value = value < 0.0 ? 0.0 : (double) sign * value;
                scoresVector.set(i, value);
            }
        }
        return scores;
    }

    /** Returns, per attribute, the number of classes with a non-zero shrunken discriminant score. */
    public AttributeWeights getWeights() {
        return this.weights;
    }

    /** Returns one weight set per label value, holding each attribute's shrunken discriminant score. */
    public IOObjectCollection<AttributeWeights> getClassWeights() {
        AttributeWeights[] a = this.classWeights.values().toArray(new AttributeWeights[this.classWeights.size()]);
        return new IOObjectCollection<AttributeWeights>(a);
    }

    /**
     * Sets {@code predictedLabel} on every example to the class minimizing
     * {@code sum_i (x_i - c_ik)^2 / (s_i + s_0)^2 - 2 log(n_k / n)}.
     *
     * <p>Attributes are matched to centroid entries by iteration position, so the example set
     * is expected to have its regular attributes in the training order.
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        for (Example example : exampleSet) {
            String predictedLabelString = null;
            double minScore = 0.0;
            boolean first = true;
            for (Map.Entry<String, Vector<Double>> entry : this.classCentroid.entrySet()) {
                String labelString = entry.getKey();
                Vector<Double> classCentroidVector = entry.getValue();
                double score = 0.0;
                int i = 0;
                for (Attribute attribute : exampleSet.getAttributes()) {
                    double numerator = Math.pow(example.getNumericalValue(attribute) - classCentroidVector.get(i), 2.0);
                    double denominator = Math.pow(this.classSD.get(i) + this.medianSD, 2.0);
                    score += numerator / denominator;
                    ++i;
                }
                double classProbability = (double) this.classFrequency.get(labelString).intValue() / (double) this.exampleSetSize;
                score -= 2.0 * Math.log(classProbability);
                if (first || score < minScore) {
                    minScore = score;
                    predictedLabelString = labelString;
                    first = false;
                }
            }
            example.setValue(predictedLabel, predictedLabelString);
        }
        return exampleSet;
    }

    /** Lists every attribute with a positive relevance weight, followed by the total count. */
    @Override
    public String toString() {
        String newline = Tools.getLineSeparator();
        StringBuilder result = new StringBuilder();
        result.append(super.toString()).append(newline).append(newline);
        result.append("Used attributes:").append(newline);
        result.append("Attribute\tNumber of classes to which relevant").append(newline);
        int count = 0;
        for (Attribute attribute : this.exampleSet.getAttributes()) {
            double weight = this.weights.getWeight(attribute.getName());
            if (!(weight > 0.0)) {
                continue;
            }
            result.append(attribute.getName()).append('\t').append(weight).append(newline);
            ++count;
        }
        result.append("Total number of attributes used: ").append(count).append(newline);
        return result.toString();
    }
}

