/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.lasso;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
import java.util.TreeSet;
import java.util.Vector;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;

/**
 * Linear prediction model fitted with Least Angle Regression (LARS), optionally with the LASSO
 * modification: an active variable is dropped from the active set in the step after its
 * coefficient would cross zero.
 *
 * <p>The model has no intercept: a prediction is the dot product of the example's attribute
 * values with {@link #weights}. For a non-numerical label the training targets are mapped to
 * {@code 2 * (label - 0.5)} and predictions are thresholded at zero (values 1.0 / 0.0).
 * NOTE(review): this assumes a binominal label whose internal values are 0 and 1 — confirm.
 *
 * <p>The whole coefficient path is retained in {@link #beta} (entry 0 is the all-zero start),
 * together with the step length of each step in {@link #gamma_A} and the maximal absolute
 * correlation at the start of each step in {@link #C}.
 */
public class LARSModel
extends PredictionModel {
    private static final long serialVersionUID = -6112829333480866927L;
    /** Weights of the selected coefficient vector, keyed by attribute name. */
    protected AttributeWeights weights;
    /** Coefficient path: entry 0 is all zeros, then one entry per step (plus one interpolated entry if the L1 threshold was hit). */
    protected Vector<RealVector> beta;
    /** Step length gamma taken in each completed step. */
    protected Vector<Double> gamma_A;
    /** Maximal absolute value of X^T (Y - current fit) at the start of each step. */
    protected Vector<Double> C;
    /** Not read or written anywhere in this class (the model has no intercept). */
    protected double avgY;
    /** Bound on the L1 norm of beta; values <= 0 disable the bound. */
    protected double threshold;
    /** Whether the LASSO modification was used during fitting. */
    protected boolean lasso;
    /** Index into {@link #beta} of the coefficient vector currently used for {@link #weights}. */
    private int solution;

    /**
     * Fits plain LARS (no LASSO) with an L1 threshold of 1.0 and epsilon 1.0E-4.
     *
     * @param exampleSet training data
     * @throws OperatorException if fitting fails
     */
    LARSModel(ExampleSet exampleSet) throws OperatorException {
        this(exampleSet, false, 1.0, 1.0E-4);
    }

    /**
     * Fits the coefficient path with LARS (or LARS-LASSO) and selects the weights.
     *
     * @param exampleSet training data; every attribute returned by {@code getAttributes()} becomes one column
     * @param doLasso    whether to apply the LASSO modification (drop variables whose coefficient reaches zero)
     * @param threshold  if {@code > 0}, stop as soon as the L1 norm of beta exceeds it and interpolate
     *                   between the last two path entries around the threshold
     * @param epsilon    tolerance used both for joining the active set and as the minimal step length
     * @throws OperatorException if the entry sum of the inverse Gram matrix of the active set is not positive,
     *                           or if the superclass constructor fails
     */
    LARSModel(ExampleSet exampleSet, boolean doLasso, double threshold, double epsilon) throws OperatorException {
        super(exampleSet);
        this.threshold = threshold;
        this.lasso = doLasso;
        int m = exampleSet.getAttributes().size();
        int n = exampleSet.size();

        // Design matrix X (n examples x m attributes), filled column by column.
        RealMatrix X = MatrixUtils.createRealMatrix(n, m);
        int col = 0;
        for (Attribute a : exampleSet.getAttributes()) {
            int row = 0;
            for (Example e : exampleSet) {
                X.setEntry(row++, col, e.getNumericalValue(a));
            }
            ++col;
        }

        // Target vector Y; non-numerical labels are mapped to +/-1 (assumes internal values 0 and 1).
        RealMatrix Y = MatrixUtils.createRealMatrix(n, 1);
        boolean numericalLabel = exampleSet.getAttributes().getLabel().isNumerical();
        int yRow = 0;
        for (Example e : exampleSet) {
            double label = e.getLabel();
            Y.setEntry(yRow++, 0, numericalLabel ? label : 2.0 * (label - 0.5));
        }

        double[] s = new double[m];                       // sign of each active variable's correlation
        TreeSet<Integer> Aleph = new TreeSet<Integer>();   // active set (column indices of X)
        int iterations = 0;
        int j_tilde = -1;                                  // LASSO: variable to drop in the next step, -1 if none
        double maxL1 = 0.0;
        double gamma_hat;
        Vector<RealMatrix> nu_hat = new Vector<RealMatrix>(); // fitted values after each step
        this.beta = new Vector<RealVector>();
        this.C = new Vector<Double>();
        this.gamma_A = new Vector<Double>();
        nu_hat.add(MatrixUtils.createRealMatrix(n, 1));
        RealVector beta_prev = MatrixUtils.createRealVector(new double[m]);
        this.beta.add(beta_prev);
        do {
            // Correlations of all columns with the current residual Y - nu_hat.
            RealMatrix c_hat = X.transpose().multiply(Y.subtract(nu_hat.lastElement()));
            double C_hat = 0.0;
            for (int j = 0; j < m; ++j) {
                double abs = Math.abs(c_hat.getEntry(j, 0));
                if (abs > C_hat) {
                    C_hat = abs;
                }
            }
            this.C.add(C_hat);

            if (doLasso && j_tilde >= 0) {
                // LASSO step: the variable whose coefficient hit zero leaves; nothing joins this step.
                Aleph.remove(j_tilde);
            } else {
                for (int j = 0; j < m; ++j) {
                    if (Math.abs(c_hat.getEntry(j, 0)) >= C_hat - epsilon) {
                        Aleph.add(j);
                    }
                }
            }
            int asize = Aleph.size();
            for (Integer J : Aleph) {
                s[J.intValue()] = Math.signum(c_hat.getEntry(J, 0));
            }
            if (Aleph.isEmpty()) break;

            // X_A: active columns of X, each multiplied by the sign of its correlation.
            RealMatrix X_A = MatrixUtils.createRealMatrix(n, asize);
            int k = 0;
            for (Integer act : Aleph) {
                for (int i = 0; i < n; ++i) {
                    X_A.setEntry(i, k, X.getEntry(i, act) * s[act]);
                }
                ++k;
            }
            RealMatrix G_A = X_A.transpose().multiply(X_A);
            RealMatrix G_A_inv = new LUDecompositionImpl(G_A).getSolver().getInverse();
            double A_A = 0.0;
            for (int i = 0; i < asize; ++i) {
                for (int j = 0; j < asize; ++j) {
                    A_A += G_A_inv.getEntry(i, j);
                }
            }
            if (A_A <= 0.0) {
                throw new OperatorException("Sum of inverse G_A matrix is not positive.");
            }
            A_A = 1.0 / Math.sqrt(A_A);

            // w_A = A_A * (row sums of G_A^-1); u_A = X_A w_A is the direction of this step.
            RealMatrix w_A = MatrixUtils.createRealMatrix(asize, 1);
            for (int i = 0; i < asize; ++i) {
                double rowSum = 0.0;
                for (int j = 0; j < asize; ++j) {
                    rowSum += G_A_inv.getEntry(i, j);
                }
                w_A.setEntry(i, 0, rowSum * A_A);
            }
            RealMatrix u_A = X_A.multiply(w_A);
            RealMatrix a = X.transpose().multiply(u_A);

            // Smallest positive step at which another variable's correlation reaches the maximum.
            // Once all variables are active, they are all included in this search.
            gamma_hat = Double.POSITIVE_INFINITY;
            for (int j = 0; j < m; ++j) {
                if (Aleph.contains(j) && Aleph.size() != m) continue;
                double t1 = (C_hat - c_hat.getEntry(j, 0)) / (A_A - a.getEntry(j, 0));
                double t2 = (C_hat + c_hat.getEntry(j, 0)) / (A_A + a.getEntry(j, 0));
                if (t1 > 0.0 && t1 < gamma_hat) {
                    gamma_hat = t1;
                }
                if (t2 > 0.0 && t2 < gamma_hat) {
                    gamma_hat = t2;
                }
            }

            if (doLasso) {
                // Step length at which each active coefficient would reach zero (0 for inactive ones).
                double[] gamma = new double[m];
                int activePos = 0;
                for (int j = 0; j < m; ++j) {
                    if (Aleph.contains(j)) {
                        gamma[j] = -1.0 * beta_prev.getEntry(j) / (s[j] * w_A.getEntry(activePos++, 0));
                    }
                }
                double gamma_tilde = Double.POSITIVE_INFINITY;
                j_tilde = -1;
                for (int j = 0; j < m; ++j) {
                    if (gamma[j] > 0.0 && gamma[j] < gamma_tilde) {
                        gamma_tilde = gamma[j];
                        j_tilde = j;
                    }
                }
                // Shorten the step if a sign change comes first; that variable is dropped next step.
                if (gamma_tilde < gamma_hat && j_tilde >= 0) {
                    gamma_hat = gamma_tilde;
                } else {
                    j_tilde = -1;
                }
            }

            this.gamma_A.add(gamma_hat);
            nu_hat.add(nu_hat.lastElement().add(u_A.scalarMultiply(gamma_hat)));
            ArrayRealVector beta_A = new ArrayRealVector(m);
            int wPos = 0;
            for (int j = 0; j < m; ++j) {
                if (!Aleph.contains(j)) continue;
                beta_A.setEntry(j, beta_prev.getEntry(j) + w_A.getEntry(wPos, 0) * gamma_hat * s[j]);
                ++wPos;
            }
            this.beta.add(beta_A);
            beta_prev = beta_A;
            maxL1 = beta_A.getL1Norm();
            // The counter must advance in LASSO mode as well; only plain LARS is capped at min(n, m) steps.
        } while (gamma_hat > epsilon && !Double.isInfinite(gamma_hat)
                && ((++iterations <= n && iterations <= m) || this.lasso)
                && !(threshold > 0.0 && threshold < maxL1));

        this.solution = iterations;
        if (threshold > 0.0 && threshold < maxL1) {
            // Take the last path entry (index >= 1, below `iterations`) whose L1 norm is within the
            // threshold, then insert a linear interpolation towards the following entry, weighted
            // by how much of the L1-norm gap the threshold covers.
            this.solution = 0;
            for (int i = 1; i < iterations; ++i) {
                if (this.beta.get(i).getL1Norm() <= threshold) {
                    this.solution = i;
                }
            }
            RealVector lower = this.beta.get(this.solution);
            RealVector upper = this.beta.get(this.solution + 1);
            double fac = (threshold - lower.getL1Norm()) / (upper.getL1Norm() - lower.getL1Norm());
            RealVector beta_t = lower.add(upper.subtract(lower).mapMultiply(fac));
            this.beta.add(this.solution + 1, beta_t);
            ++this.solution;
        }

        this.weights = new AttributeWeights(exampleSet);
        RealVector chosen = this.beta.get(this.solution);
        int idx = 0;
        for (Attribute att : exampleSet.getAttributes()) {
            this.weights.setWeight(att.getName(), chosen.getEntry(idx++));
        }
    }

    /**
     * Re-selects which coefficient vector of the stored path is used as the model weights.
     *
     * <ul>
     *   <li>{@code numFeatures > 0}: the last path entry (index >= 1) with at most {@code numFeatures}
     *       non-zero coefficients; {@link #threshold} is set to that entry's L1 norm. Index 0 is used
     *       if no entry qualifies.</li>
     *   <li>otherwise, {@code threshold > 0}: the last path entry (index >= 1) whose L1 norm is
     *       {@code <= threshold}, or index 0 if none qualifies.</li>
     *   <li>otherwise: the last entry of the path.</li>
     * </ul>
     *
     * The weights object is cloned before being modified, so previously returned weights are not changed.
     * NOTE(review): assumes {@code getAttributeNames()} iterates in the same order as the columns used
     * during fitting — confirm against AttributeWeights.
     *
     * @param threshold   L1 bound used when {@code numFeatures <= 0}
     * @param numFeatures maximal number of non-zero coefficients; {@code <= 0} disables this criterion
     */
    public void changeModel(double threshold, int numFeatures) {
        // Default to the last path entry (the previous default of beta.size() was out of range).
        int chosen = this.beta.size() - 1;
        if (numFeatures > 0) {
            chosen = 0;
            for (int i = 1; i < this.beta.size(); ++i) {
                RealVector candidate = this.beta.get(i);
                // L1 norm of the signum vector = number of non-zero coefficients.
                if (candidate.mapSignum().getL1Norm() <= (double) numFeatures) {
                    chosen = i;
                    this.threshold = candidate.getL1Norm();
                }
            }
        } else {
            this.threshold = threshold;
            if (threshold > 0.0) {
                chosen = 0;
                for (int i = 1; i < this.beta.size(); ++i) {
                    if (this.beta.get(i).getL1Norm() <= threshold) {
                        chosen = i;
                    }
                }
            }
        }
        // Keep toString() consistent with the weights actually in use.
        this.solution = chosen;
        this.weights = (AttributeWeights) this.weights.clone();
        RealVector selected = this.beta.get(chosen);
        int j = 0;
        for (String att : this.weights.getAttributeNames()) {
            this.weights.setWeight(att, selected.getEntry(j++));
        }
    }

    /**
     * Describes the model: number of steps, selected path entry, and the weighted sum.
     */
    @Override
    public String toString() {
        String nl = Tools.getLineSeparator();
        StringBuilder result = new StringBuilder();
        result.append(super.toString()).append(nl).append(nl);
        result.append(this.beta.size() - 1)
              .append(" iterations in total, using BETA of iteration no. ")
              .append(this.solution)
              .append(nl).append(nl);
        for (String s : this.weights.getAttributeNames()) {
            result.append(s).append(" * ").append(this.weights.getWeight(s)).append(" + ").append(nl);
        }
        result.append("0 (bias)");
        return result.toString();
    }

    /** Returns the currently selected attribute weights (the live object, not a copy). */
    public AttributeWeights getWeights() {
        return this.weights;
    }

    /** Returns the number of weights, or 0 if no weights have been set. */
    public int size() {
        return this.weights == null ? 0 : this.weights.getSize();
    }

    /** Returns the stored coefficient path (the live vector, not a copy). */
    public Vector<RealVector> getBeta() {
        return this.beta;
    }

    /** Returns the step lengths of all completed steps (the live vector, not a copy). */
    public Vector<Double> getGamma() {
        return this.gamma_A;
    }

    /** Returns the maximal absolute correlation at the start of each step (the live vector, not a copy). */
    public Vector<Double> getC() {
        return this.C;
    }

    /**
     * Writes sum(value * weight) over all attributes into {@code predictedLabel}.
     * For a non-numerical label the score is mapped to 1.0 if positive and 0.0 otherwise.
     *
     * @param exampleSet     examples to score (modified in place and returned)
     * @param predictedLabel attribute receiving the prediction
     * @return {@code exampleSet}
     * @throws OperatorException declared for the PredictionModel contract; not thrown here directly
     */
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        boolean nominalLabel = !exampleSet.getAttributes().getLabel().isNumerical();
        for (Example e : exampleSet) {
            double pred = 0.0;
            for (Attribute a : exampleSet.getAttributes()) {
                pred += e.getValue(a) * this.weights.getWeight(a.getName());
            }
            if (nominalLabel) {
                pred = pred > 0.0 ? 1.0 : 0.0;
            }
            e.setValue(predictedLabel, pred);
        }
        return exampleSet;
    }
}

