/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.mfs;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.Value;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.features.weighting.AbstractWeighting;
import com.rapidminer.operator.mfs.Util;
import com.rapidminer.operator.mrmr.MRMRFunctions;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.RandomGenerator;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Random;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.SingularMatrixException;

public class RecursiveConditionalCorrelationWeighting
extends AbstractWeighting {
    public static final String PARAMETER_BLOCKSIZE = "blocksize";
    public static final String PARAMETER_USE_ENSEMBLE_CORRELATION = "use_ensemble_correlation";
    public static final String PARAMETER_ENSEMBLE_SIZE = "ensemble_size";
    public static final String PARAMETER_RANDOM_REPETITIONS = "repetitions";
    public static final String PARAMETER_RESULT_COMBINATION = "recursive_result_combination";
    public static final String PARAMETER_THRESHOLD = "threshold";
    public static final String PARAMETER_ELIMINATION = "elimination";
    public static final String PARAMETER_K = "k";
    private double iteration;

    public RecursiveConditionalCorrelationWeighting(OperatorDescription description) {
        super(description);
        this.addValue((Value)new ValueDouble("iteration", "The number of the current iteration."){

            public double getDoubleValue() {
                return RecursiveConditionalCorrelationWeighting.this.iteration;
            }
        });
    }

    public AttributeWeights calculateWeights(ExampleSet exampleSet) throws OperatorException {
        AttributeWeights weights;
        block11: {
            int id;
            Iterator it;
            LinkedList<Integer> nextQueue;
            boolean recursiveResultCombination;
            boolean elimination;
            int ensembleSize;
            boolean useEnsemble;
            int repetitions;
            int blocksize;
            LinkedList<Integer> queue;
            double[] CorWithLabel;
            String[] attributeNames;
            Attributes attributes;
            int i;
            int k;
            block10: {
                if (this.getParameterAsDouble(PARAMETER_THRESHOLD) > 0.0) {
                    return this.graphEstimate(exampleSet);
                }
                int p = exampleSet.getAttributes().size();
                k = this.getParameterAsInt(PARAMETER_K);
                i = 0;
                attributes = exampleSet.getAttributes();
                Attribute label = attributes.getLabel();
                attributeNames = new String[p];
                CorWithLabel = new double[p];
                Iterator attribute = attributes.iterator();
                i = 0;
                while (attribute.hasNext()) {
                    Attribute att = (Attribute)attribute.next();
                    attributeNames[i] = att.getName();
                    CorWithLabel[i++] = MRMRFunctions.Correlation(exampleSet, att, label);
                }
                queue = new LinkedList<Integer>();
                for (i = 0; i < p; ++i) {
                    queue.add(i);
                }
                weights = new AttributeWeights(exampleSet);
                for (i = 0; i < p; ++i) {
                    weights.setWeight(attributeNames[i], 0.0);
                }
                blocksize = this.getParameterAsInt(PARAMETER_BLOCKSIZE);
                repetitions = this.getParameterAsInt(PARAMETER_RANDOM_REPETITIONS);
                useEnsemble = this.getParameterAsBoolean(PARAMETER_USE_ENSEMBLE_CORRELATION);
                ensembleSize = this.getParameterAsInt(PARAMETER_ENSEMBLE_SIZE);
                elimination = this.getParameterAsBoolean(PARAMETER_ELIMINATION);
                recursiveResultCombination = this.getParameterAsBoolean(PARAMETER_RESULT_COMBINATION);
                if (repetitions != 1) break block10;
                this.recursivelyShrinkQueue(queue, exampleSet, k, blocksize, attributes, attributeNames, CorWithLabel, useEnsemble, ensembleSize, elimination);
                while (!queue.isEmpty()) {
                    weights.setWeight(attributeNames[queue.poll()], 1.0);
                }
                break block11;
            }
            for (i = 0; i < repetitions; ++i) {
                this.iteration += 1.0;
                nextQueue = new LinkedList();
                it = queue.iterator();
                while (it.hasNext()) {
                    nextQueue.add((Integer)it.next());
                }
                this.recursivelyShrinkQueue(nextQueue, exampleSet, k, blocksize, attributes, attributeNames, CorWithLabel, useEnsemble, ensembleSize, elimination);
                while (!nextQueue.isEmpty()) {
                    id = nextQueue.poll();
                    weights.setWeight(attributeNames[id], weights.getWeight(attributeNames[id]) + 1.0);
                }
                Collections.shuffle(queue, (Random)RandomGenerator.getRandomGenerator((Operator)this));
            }
            if (!recursiveResultCombination) break block11;
            nextQueue = new LinkedList<Integer>();
            it = queue.iterator();
            while (it.hasNext()) {
                id = (Integer)it.next();
                if (!(weights.getWeight(attributeNames[id]) > 0.0)) continue;
                nextQueue.add(id);
                weights.setWeight(attributeNames[id], 0.0);
            }
            this.recursivelyShrinkQueue(nextQueue, exampleSet, k, blocksize, attributes, attributeNames, CorWithLabel, useEnsemble, ensembleSize, elimination);
            while (!nextQueue.isEmpty()) {
                weights.setWeight(attributeNames[nextQueue.poll()], 1.0);
            }
        }
        return weights;
    }

    private void recursivelyShrinkQueue(Queue<Integer> queue, ExampleSet exampleSet, int k, int blocksize, Attributes attributes, String[] attributeNames, double[] CorWithLabel, boolean useEnsemble, int ensembleSize, boolean elimination) throws OperatorException {
        if (queue.size() == 1) {
            return;
        }
        if (queue.size() < blocksize) {
            blocksize = queue.size();
        }
        int singularMatricesCounter = 0;
        RealMatrix corrmat = MatrixUtils.createRealIdentityMatrix(blocksize + 1);
        int[] ii = new int[blocksize + 1];
        double temp = 0.0;
        while (queue.size() > k) {
            int i;
            RealMatrix s;
            int i2;
            if (blocksize > queue.size()) {
                blocksize = queue.size();
                corrmat = MatrixUtils.createRealIdentityMatrix(blocksize + 1);
            }
            for (i2 = 1; i2 <= blocksize; ++i2) {
                ii[i2] = queue.remove();
            }
            for (i2 = 1; i2 <= blocksize; ++i2) {
                corrmat.setEntry(0, i2, CorWithLabel[ii[i2]]);
                corrmat.setEntry(i2, 0, CorWithLabel[ii[i2]]);
                for (int j = i2 + 1; j <= blocksize; ++j) {
                    temp = useEnsemble ? MRMRFunctions.Correlation(exampleSet, attributes.get(attributeNames[ii[i2]]), attributes.get(attributeNames[ii[j]]), ensembleSize) : MRMRFunctions.Correlation(exampleSet, attributes.get(attributeNames[ii[i2]]), attributes.get(attributeNames[ii[j]]));
                    corrmat.setEntry(i2, j, temp);
                    corrmat.setEntry(j, i2, temp);
                }
            }
            try {
                s = new LUDecompositionImpl(corrmat).getSolver().getInverse();
            }
            catch (SingularMatrixException ex) {
                for (i = 1; i <= blocksize; ++i) {
                    queue.add(ii[i]);
                }
                this.getLogger().warning("The Correlation-Matrix was singular. All " + blocksize + " features were put back into the queue.");
                if (++singularMatricesCounter <= attributes.size()) continue;
                throw new OperatorException("Too many (" + singularMatricesCounter + ") singular matrices occured. Killing process.");
            }
            double[] cond_corr = new double[blocksize];
            for (i = 1; i <= blocksize; ++i) {
                cond_corr[i - 1] = Math.abs(s.getEntry(0, i) / Math.sqrt(s.getEntry(0, 0) * s.getEntry(i, i)));
            }
            if (elimination) {
                int imin = Util.minIndex(cond_corr) + 1;
                for (int i3 = 1; i3 <= blocksize; ++i3) {
                    if (i3 == imin) continue;
                    queue.add(ii[i3]);
                }
                continue;
            }
            int imax = Util.maxIndex(cond_corr) + 1;
            queue.add(ii[imax]);
        }
    }

    public AttributeWeights graphEstimate(ExampleSet exampleSet) throws OperatorException {
        int p = exampleSet.getAttributes().size();
        int i = 0;
        Attributes attributes = exampleSet.getAttributes();
        Attribute label = attributes.getLabel();
        String[] attributeNames = new String[p];
        double[] CorWithLabel = new double[p];
        Iterator attribute = attributes.iterator();
        i = 0;
        while (attribute.hasNext()) {
            Attribute att = (Attribute)attribute.next();
            attributeNames[i] = att.getName();
            CorWithLabel[i++] = MRMRFunctions.Correlation(exampleSet, att, label);
        }
        LinkedList<Integer> queue = new LinkedList<Integer>();
        for (i = 0; i < p; ++i) {
            queue.add(i);
        }
        AttributeWeights weights = new AttributeWeights(exampleSet);
        for (i = 0; i < p; ++i) {
            weights.setWeight(attributeNames[i], 0.0);
        }
        int blocksize = this.getParameterAsInt(PARAMETER_BLOCKSIZE);
        boolean useEnsemble = this.getParameterAsBoolean(PARAMETER_USE_ENSEMBLE_CORRELATION);
        int ensembleSize = this.getParameterAsInt(PARAMETER_ENSEMBLE_SIZE);
        double threshold = this.getParameterAsDouble(PARAMETER_THRESHOLD);
        int singularMatricesCounter = 0;
        RealMatrix corrmat = MatrixUtils.createRealIdentityMatrix(blocksize + 1);
        int[] ii = new int[blocksize + 1];
        double temp = 0.0;
        int roundsWithoutDeletion = 0;
        do {
            RealMatrix s;
            if (blocksize > queue.size()) {
                blocksize = queue.size();
                corrmat = MatrixUtils.createRealIdentityMatrix(blocksize + 1);
            }
            for (i = 1; i <= blocksize; ++i) {
                ii[i] = (Integer)queue.remove();
            }
            for (i = 1; i <= blocksize; ++i) {
                corrmat.setEntry(0, i, CorWithLabel[ii[i]]);
                corrmat.setEntry(i, 0, CorWithLabel[ii[i]]);
                for (int j = i + 1; j <= blocksize; ++j) {
                    temp = useEnsemble ? MRMRFunctions.Correlation(exampleSet, attributes.get(attributeNames[ii[i]]), attributes.get(attributeNames[ii[j]]), ensembleSize) : MRMRFunctions.Correlation(exampleSet, attributes.get(attributeNames[ii[i]]), attributes.get(attributeNames[ii[j]]));
                    corrmat.setEntry(i, j, temp);
                    corrmat.setEntry(j, i, temp);
                }
            }
            try {
                s = new LUDecompositionImpl(corrmat).getSolver().getInverse();
            }
            catch (SingularMatrixException ex) {
                for (i = 1; i <= blocksize; ++i) {
                    queue.add(ii[i]);
                }
                this.getLogger().warning("The Correlation-Matrix was singular. All " + blocksize + " features were put back into the queue.");
                if (++singularMatricesCounter <= attributes.size()) continue;
                throw new OperatorException("Too many (" + singularMatricesCounter + ") singular matrices occured. Killing process.");
            }
            double[] cond_corr = new double[blocksize];
            for (i = 1; i <= blocksize; ++i) {
                cond_corr[i - 1] = Math.abs(s.getEntry(0, i) / Math.sqrt(s.getEntry(0, 0) * s.getEntry(i, i)));
            }
            int imin = Util.minIndex(cond_corr) + 1;
            if (Math.abs(cond_corr[imin - 1]) > threshold) {
                imin = -1;
                ++roundsWithoutDeletion;
            } else {
                roundsWithoutDeletion = 0;
            }
            for (i = 1; i <= blocksize; ++i) {
                if (i == imin) continue;
                queue.add(ii[i]);
            }
        } while (roundsWithoutDeletion < queue.size());
        while (!queue.isEmpty()) {
            weights.setWeight(attributeNames[(Integer)queue.poll()], 1.0);
        }
        return weights;
    }

    public List<ParameterType> getParameterTypes() {
        List list = super.getParameterTypes();
        list.add(new ParameterTypeInt(PARAMETER_K, "Number of features to select", 1, Integer.MAX_VALUE, 10));
        list.add(new ParameterTypeInt(PARAMETER_BLOCKSIZE, "Number of features for conditional covariance estimation.", 2, Integer.MAX_VALUE, 3));
        list.add(new ParameterTypeInt(PARAMETER_RANDOM_REPETITIONS, "Number of randomised repetitions", 1, Integer.MAX_VALUE, 1));
        list.add(new ParameterTypeBoolean(PARAMETER_RESULT_COMBINATION, "Recursive result combination instead of averaged sets. Only active for random repetitions", true));
        list.add(new ParameterTypeBoolean(PARAMETER_ELIMINATION, "elimination: only the features with the smallest cond_corr per block is removed. Otherwise only the feature with the highest is kept.", true));
        list.add(new ParameterTypeBoolean(PARAMETER_USE_ENSEMBLE_CORRELATION, "Stabilize correlation computation with ensemble", false));
        list.add(new ParameterTypeInt(PARAMETER_ENSEMBLE_SIZE, "Size of the correlation ensemble (not the number of randomly initialized repetitions)", 1, Integer.MAX_VALUE, 10));
        list.add(new ParameterTypeDouble(PARAMETER_THRESHOLD, "threshold for graph estimation. Performs neighbourhood estimation if > 0.", 0.0, Double.MAX_VALUE, 0.0));
        return list;
    }

    public boolean supportsCapability(OperatorCapability capability) {
        if (capability == OperatorCapability.BINOMINAL_LABEL || capability == OperatorCapability.NUMERICAL_ATTRIBUTES || capability == OperatorCapability.NUMERICAL_LABEL) {
            return true;
        }
        return capability != OperatorCapability.BINOMINAL_ATTRIBUTES && capability != OperatorCapability.POLYNOMINAL_LABEL && capability != OperatorCapability.POLYNOMINAL_ATTRIBUTES;
    }
}

