/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/**
 * A linear classifier that keeps one weight row per label.
 *
 * <p>Each row holds one weight per feature, followed by a final bias weight. {@link #classify}
 * scores a label as the dot product of the instance's feature vector with that label's row,
 * plus the bias. It then divides every score by the sum of all scores.
 *
 * <p>NOTE(review): the weights are presumably produced by a Balanced Winnow trainer elsewhere in
 * this package; that code is not visible here.
 */
public class BalancedWinnow extends Classifier implements Serializable {
    /** Weight matrix indexed {@code [label][feature]}; the last column of each row is the bias. */
    double[][] m_weights;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Creates a classifier over a deep copy of {@code weights}.
     *
     * @param dataPipe the pipe that produced the training instances
     * @param weights  one row per label; each row's last entry is that label's bias weight.
     *                 The array is copied, so later changes by the caller do not affect this
     *                 classifier.
     */
    public BalancedWinnow(Pipe dataPipe, double[][] weights) {
        super(dataPipe);
        this.m_weights = deepCopy(weights);
    }

    /**
     * Returns a deep copy of the weight matrix.
     *
     * @return a fresh {@code [label][feature]} array that the caller may modify freely
     */
    public double[][] getWeights() {
        return deepCopy(this.m_weights);
    }

    /**
     * Scores every label for {@code instance} and normalizes the scores by their total.
     *
     * <p>Feature indices that fall at or beyond a row's bias column did not exist when the weights
     * were built, so they are ignored. Labels with no weight row (added to the label alphabet
     * later) keep a score of 0. If the scores sum to exactly 0, they are returned unnormalized
     * rather than divided by zero.
     *
     * @param instance an instance whose data is a {@link FeatureVector}
     * @return the classification holding the per-label scores
     */
    @Override
    public Classification classify(Instance instance) {
        int numClasses = this.getLabelAlphabet().size();
        double[] scores = new double[numClasses];
        FeatureVector fv = (FeatureVector) instance.getData();
        // The feature vector must share the dictionary the weights were built against.
        assert (this.instancePipe == null || fv.getAlphabet() == this.instancePipe.getDataAlphabet());
        int fvSize = fv.numLocations();
        // Labels beyond the trained rows have no weights; their score stays 0.
        int scoredClasses = Math.min(numClasses, this.m_weights.length);
        double sum = 0.0;
        for (int ci = 0; ci < scoredClasses; ++ci) {
            double[] row = this.m_weights[ci];
            // Read the bias from the row itself, not from the (possibly grown) alphabet size.
            int biasIndex = row.length - 1;
            double score = 0.0;
            for (int loc = 0; loc < fvSize; ++loc) {
                int fi = fv.indexAtLocation(loc);
                // Features unseen at training time must never read the bias slot or run off the row.
                if (fi < biasIndex) {
                    score += fv.valueAtLocation(loc) * row[fi];
                }
            }
            score += row[biasIndex];
            scores[ci] = score;
            sum += score;
        }
        if (sum != 0.0) {
            MatrixOps.timesEquals(scores, 1.0 / sum);
        }
        return new Classification(instance, this, new LabelVector(this.getLabelAlphabet(), scores));
    }

    /** Returns a row-by-row copy of {@code source}; rows may differ in length, and the matrix may be empty. */
    private static double[][] deepCopy(double[][] source) {
        double[][] copy = new double[source.length][];
        for (int i = 0; i < source.length; ++i) {
            copy[i] = source[i].clone();
        }
        return copy;
    }

    /** Writes the format version, then the instance pipe, then the weight matrix. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.getInstancePipe());
        out.writeObject(this.m_weights);
    }

    /**
     * Reads the fields in the order {@link #writeObject} wrote them.
     *
     * @throws ClassNotFoundException if the stream's format version is not
     *                                {@code CURRENT_SERIAL_VERSION}
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched BalancedWinnow versions: wanted "
                    + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe) in.readObject();
        this.m_weights = (double[][]) in.readObject();
    }
}

