/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.regression;

import cc.mallet.regression.LinearRegression;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.InvertedIndex;
import java.io.File;
import java.text.NumberFormat;

/**
 * Fits an L1-regularized (lasso) linear regression by cyclic coordinate descent.
 *
 * <p>The constructor minimizes
 * {@code 1/2 * sum_i (y_i - b0 - sum_j beta_j * x_ij)^2 + l1Weight * sum_j |beta_j|}
 * over an unpenalized intercept {@code b0} and the feature weights {@code beta_j}. It writes the
 * result into the parameter array of a {@link LinearRegression} built over the training data's
 * alphabet. The trailing precision entry of that array is not modified by this class.
 *
 * <p>For every coordinate the class maintains a "scaled residual": the value that coordinate
 * would take if it were optimized without the penalty, holding all other coordinates fixed. For
 * feature {@code j} this is {@code sum_i x_ij * (r_i + beta_j * x_ij) / sum_i x_ij^2}, where
 * {@code r_i} is the current residual. For the intercept it is the mean of {@code r_i + b0}.
 * Each coordinate update reads this value and applies it (soft-thresholded for features). It then
 * incrementally adjusts the scaled residuals of the coordinates that share instances with it.
 */
public class CoordinateDescent {
    /** A sweep whose summed absolute parameter change falls below this value ends the fit. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-4;
    /** Number of non-converged sweeps between progress lines printed to standard output. */
    private static final int REPORT_INTERVAL = 100;

    LinearRegression regression;
    // Obtained from regression.getParameters(). Layout: features at [0, interceptIndex), then the
    // intercept, then the precision. NOTE(review): toString() relies on regression.predict()
    // seeing these values, i.e. presumably this is the live array, not a copy — confirm.
    double[] parameters;
    InstanceList trainingData;
    // Per-coordinate unpenalized optimum (see class comment); indexed like parameters, minus precision.
    double[] scaledResiduals;
    // The L1 penalty weight passed to the constructor.
    double tuningConstant;
    // Per-feature sum of squared values over all instances.
    double[] sumSquaredX;
    // Per-feature soft-threshold, tuningConstant / sumSquaredX[j], in scaled-residual units.
    double[] scaledThresholds;
    InvertedIndex featureIndex;
    int interceptIndex;
    int precisionIndex;
    // parameters.length - 1: number of features plus the intercept.
    int dimension;
    NumberFormat formatter;
    /** Per-feature sum of values over all instances; used to propagate intercept changes. */
    private double[] sumX;

    /**
     * Fits the model to {@code data}. The whole optimization runs inside the constructor.
     *
     * <p>Progress (the summed absolute parameter change) is printed to standard output every
     * {@value #REPORT_INTERVAL} sweeps until convergence.
     *
     * @param data training instances whose data are {@link FeatureVector}s and whose targets are
     *     {@link Double}s
     * @param l1Weight non-negative L1 penalty weight; {@code 0} disables the penalty
     * @throws IllegalArgumentException if {@code data} is empty or {@code l1Weight} is negative
     *     or NaN
     * @throws IllegalStateException if a sweep produces a non-finite parameter change (for example
     *     because a target or feature value is NaN or infinite)
     */
    public CoordinateDescent(InstanceList data, double l1Weight) {
        // Written as !(x >= 0) so that NaN is rejected as well.
        if (!(l1Weight >= 0.0)) {
            throw new IllegalArgumentException("l1Weight must be non-negative, got " + l1Weight);
        }
        // An empty list would make the intercept 0/0 = NaN, and the convergence test would
        // then never succeed.
        if (data.size() == 0) {
            throw new IllegalArgumentException("Cannot fit a regression to an empty InstanceList");
        }
        this.tuningConstant = l1Weight;
        this.trainingData = data;
        this.regression = new LinearRegression(data.getDataAlphabet());
        this.parameters = this.regression.getParameters();
        this.interceptIndex = this.parameters.length - 2;
        this.precisionIndex = this.parameters.length - 1;
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(3);
        this.dimension = this.parameters.length - 1;
        this.scaledResiduals = new double[this.dimension];
        this.sumSquaredX = new double[this.dimension];
        this.sumX = new double[this.dimension];
        this.scaledThresholds = new double[this.dimension];
        this.featureIndex = new InvertedIndex(data);

        computeSufficientStatistics();
        fit();
    }

    /** Number of feature weights, which occupy parameter indices {@code [0, numFeatures())}. */
    private int numFeatures() {
        return this.dimension - 1;
    }

    /**
     * Computes the per-feature sums and the initial scaled residuals for all-zero parameters.
     *
     * <p>With all parameters zero, each residual equals its target. The intercept's value is
     * therefore the mean target, and feature {@code j}'s value is
     * {@code sum x_ij * y_i / sum x_ij^2}.
     */
    private void computeSufficientStatistics() {
        for (Instance instance : this.trainingData) {
            FeatureVector predictors = (FeatureVector) instance.getData();
            double y = (Double) instance.getTarget();
            this.scaledResiduals[this.interceptIndex] += y;
            for (int loc = 0; loc < predictors.numLocations(); loc++) {
                int index = predictors.indexAtLocation(loc);
                double value = predictors.valueAtLocation(loc);
                this.scaledResiduals[index] += y * value;
                this.sumX[index] += value;
                this.sumSquaredX[index] += value * value;
            }
        }
        this.scaledResiduals[this.interceptIndex] /= (double) this.trainingData.size();
        for (int index = 0; index < numFeatures(); index++) {
            // Features that are zero everywhere keep scaled residual 0 and threshold 0. They are
            // skipped during fitting, which avoids 0/0 here.
            if (this.sumSquaredX[index] > 0.0) {
                this.scaledResiduals[index] /= this.sumSquaredX[index];
                this.scaledThresholds[index] = this.tuningConstant / this.sumSquaredX[index];
            }
        }
    }

    /**
     * Runs full sweeps (intercept, then every feature) until the summed absolute change of one
     * sweep drops below {@link #CONVERGENCE_TOLERANCE}.
     */
    private void fit() {
        int iteration = 0;
        while (true) {
            double totalDiff = updateIntercept();
            for (int index = 0; index < numFeatures(); index++) {
                totalDiff += updateFeature(index);
            }
            // A NaN/infinite change can never satisfy the tolerance; fail instead of looping forever.
            if (Double.isNaN(totalDiff) || Double.isInfinite(totalDiff)) {
                throw new IllegalStateException(
                        "Coordinate descent produced a non-finite update: " + totalDiff);
            }
            if (totalDiff < CONVERGENCE_TOLERANCE) {
                return;
            }
            if (++iteration % REPORT_INTERVAL == 0) {
                System.out.println(totalDiff);
            }
        }
    }

    /**
     * Sets the intercept to its exact coordinate-wise optimum and updates the feature scaled
     * residuals to match.
     *
     * @return the absolute change of the intercept
     */
    private double updateIntercept() {
        double newIntercept = this.scaledResiduals[this.interceptIndex];
        // diff = old - new. Every residual y_i - prediction_i grows by diff. The intercept's own
        // scaled residual (mean of r_i + b0) is unchanged.
        double diff = this.parameters[this.interceptIndex] - newIntercept;
        this.parameters[this.interceptIndex] = newIntercept;
        if (diff != 0.0) {
            // Summed over all instances, feature j's scaled residual moves by
            // sum_i x_ij * diff / sumSquaredX[j] = sumX[j] * diff / sumSquaredX[j].
            for (int index = 0; index < numFeatures(); index++) {
                if (this.sumSquaredX[index] > 0.0) {
                    this.scaledResiduals[index] += this.sumX[index] * diff / this.sumSquaredX[index];
                }
            }
        }
        return Math.abs(diff);
    }

    /**
     * Soft-thresholds feature {@code index} to its penalized coordinate-wise optimum. It then
     * propagates the change to the intercept and to the other features present in the same
     * instances.
     *
     * @return the absolute change of the feature's weight
     */
    private double updateFeature(int index) {
        if (this.sumSquaredX[index] == 0.0) {
            // The feature is zero in every instance and cannot change the loss; keep it at zero.
            return 0.0;
        }
        double unpenalized = this.scaledResiduals[index];
        double threshold = this.scaledThresholds[index];
        double newValue;
        if (unpenalized > threshold) {
            newValue = unpenalized - threshold;
        } else if (unpenalized < -threshold) {
            newValue = unpenalized + threshold;
        } else {
            // Inside the dead zone the lasso optimum is exactly zero, not the previous value.
            newValue = 0.0;
        }
        double diff = this.parameters[index] - newValue;
        this.parameters[index] = newValue;
        if (diff == 0.0) {
            return 0.0;
        }
        double numInstances = (double) this.trainingData.size();
        for (Object o : this.featureIndex.getInstancesWithFeature(index)) {
            FeatureVector predictors = (FeatureVector) ((Instance) o).getData();
            double value = featureValue(predictors, index);
            // This instance's residual grows by value * diff.
            this.scaledResiduals[this.interceptIndex] += value * diff / numInstances;
            for (int loc = 0; loc < predictors.numLocations(); loc++) {
                int otherIndex = predictors.indexAtLocation(loc);
                // The feature's own scaled residual is unaffected. All-zero features only ever
                // meet value 0 here and would otherwise divide 0 by 0.
                if (otherIndex == index || this.sumSquaredX[otherIndex] == 0.0) {
                    continue;
                }
                double otherValue = predictors.valueAtLocation(loc);
                this.scaledResiduals[otherIndex] +=
                        value * otherValue * diff / this.sumSquaredX[otherIndex];
            }
        }
        return Math.abs(diff);
    }

    /** Returns the value of feature {@code index} in {@code predictors}, or 0 if it is absent. */
    private static double featureValue(FeatureVector predictors, int index) {
        for (int loc = 0; loc < predictors.numLocations(); loc++) {
            if (predictors.indexAtLocation(loc) == index) {
                return predictors.valueAtLocation(loc);
            }
        }
        return 0.0;
    }

    /**
     * Returns a report with one line per parameter and the training SSE.
     *
     * <p>The first line is the intercept ("(Int)"). Each feature follows, labeled with its
     * alphabet entry. The last line is the sum of squared errors of {@code regression.predict}
     * over the training data.
     */
    @Override
    public String toString() {
        double sumSquaredError = 0.0;
        for (Instance instance : this.trainingData) {
            double prediction = this.regression.predict(instance);
            double residual = (Double) instance.getTarget() - prediction;
            sumSquaredError += residual * residual;
        }
        StringBuilder out = new StringBuilder();
        out.append("(Int)\t")
                .append(this.formatter.format(this.parameters[this.interceptIndex]))
                .append("\n");
        for (int index = 0; index < numFeatures(); index++) {
            out.append(this.trainingData.getDataAlphabet().lookupObject(index)).append("\t");
            out.append(this.formatter.format(this.parameters[index])).append("\n");
        }
        out.append("SSE: ").append(this.formatter.format(sumSquaredError)).append("\n");
        return out.toString();
    }

    /**
     * Command-line entry point: {@code CoordinateDescent <instance-list-file> <l1-weight>}.
     *
     * <p>Loads the instance list with {@code InstanceList.load}, fits the model and prints the
     * report produced by {@link #toString()}.
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println("Usage: CoordinateDescent <instance-list-file> <l1-weight>");
            System.exit(1);
        }
        InstanceList data = InstanceList.load(new File(args[0]));
        CoordinateDescent trainer = new CoordinateDescent(data, Double.parseDouble(args[1]));
        System.out.println(trainer);
    }
}

