/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.Optimizable;
import gnu.trove.list.array.TIntArrayList;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.math3.special.Gamma;

/**
 * MALLET {@link Optimizable.ByGradientValue} objective over per-topic word log-weights.
 *
 * <p>Parameter {@code theta = parameters[topic][word]} defines the Dirichlet pseudo-count
 * {@code beta = exp(theta) * p_v[word]}. {@link #getValue()} returns
 *
 * <pre>
 *   ( sum_topic [ logGamma(S_t) - logGamma(S_t + n_k[t]) ]
 *   + sum_{topic, word : n_kv > 0} [ logGamma(beta + n_kv) - logGamma(beta) ]
 *   - sum theta^2 ) / numTokens,          where S_t = sum_word beta[t][word],
 * </pre>
 *
 * i.e. the Dirichlet-multinomial log likelihood of the topic-word counts (up to terms that do
 * not depend on theta) minus an unweighted L2 penalty, divided by {@link #numTokens}.
 * {@link #getValueGradient(double[])} returns the gradient of that value.
 *
 * <p>The flat parameter vector of the {@code Optimizable} interface is row-major:
 * index {@code topic * v + word}. {@link #n_k} and {@link #n_kv} must be assigned before
 * {@link #getValue()} or any gradient method is called.
 *
 * <p>When {@link #currentGroup} is not -1, {@link #getValueGradient(double[])} only fills the
 * entries of topic {@link #currentK} for the word ids in {@code groups[currentGroup]}; all other
 * entries are zero.
 */
public class MySparseGroupWordFeatOptimizable
implements Optimizable.ByGradientValue {
    /** Group that {@link #getValueGradient(double[])} is restricted to, or -1 for all entries. */
    int currentGroup = -1;
    /** Topic whose group entries receive a gradient when {@link #currentGroup} is not -1. */
    int currentK = 0;
    /** Word-id groups: {@code groups[g]} lists the word ids belonging to group g. */
    int[][] groups = null;
    /** The word ids {@code 0 .. v-1}. */
    int[] all = null;
    /** Divisor applied to the value and to the gradient of {@link #getValueGradient(double[])}. */
    int numTokens = 1;
    /** Number of topics. */
    int k = 10;
    /** Vocabulary size. */
    int v = 1928;
    /** Log-weights, indexed {@code parameters[topic][word]}. */
    double[][] parameters = null;
    /** Not read by this class. */
    double[] m = null;
    /** Not read by this class. */
    double[] b = null;
    /** Not read by this class. */
    double a = 0.5;
    /** Copy of the last gradient produced by {@link #getValueGradient(double[])}. */
    double[] lastBuffer;
    /** Not read by this class. */
    double[][] documentFeatures = null;
    /** Not read by this class. */
    public TIntArrayList[] Phi;
    /** Base measure over words; the constructors set it to the uniform distribution 1/v. */
    public double[] p_v;
    /** Not read by this class: the L2 penalty in {@link #getValue()} has a fixed weight of 1. */
    public double lambda = 1.0;
    /** Random initial weights are drawn uniformly from {@code [-sigma, sigma)}. */
    public double sigma = 2.0;
    /** Per-topic totals added to the Dirichlet sum (presumably tokens per topic), {@code n_k[topic]}. */
    public int[] n_k;
    /** Word-topic counts, indexed {@code n_kv[word * k + topic]}. */
    public int[] n_kv;
    /** Value cached by the last full evaluation of {@link #getValue()}. */
    double last = 0.0;

    /**
     * Creates an objective with {@code numTokens = 1}; weights are drawn with {@link Math#random()}.
     *
     * @param k number of topics
     * @param v vocabulary size
     */
    public MySparseGroupWordFeatOptimizable(int k, int v) {
        this(k, v, 1, null);
    }

    /**
     * Creates an objective whose weights are drawn with {@link Math#random()}.
     *
     * @param k   number of topics
     * @param v   vocabulary size
     * @param num divisor for the value and gradient ({@link #numTokens})
     */
    public MySparseGroupWordFeatOptimizable(int k, int v, int num) {
        this(k, v, num, null);
    }

    /**
     * Creates an objective with a uniform base measure and random weights in
     * {@code [-sigma, sigma)}.
     *
     * @param k   number of topics
     * @param v   vocabulary size
     * @param num divisor for the value and gradient ({@link #numTokens})
     * @param rn  source of the initial weights; {@code null} falls back to {@link Math#random()}
     */
    public MySparseGroupWordFeatOptimizable(int k, int v, int num, Random rn) {
        this.numTokens = num;
        this.k = k;
        this.v = v;
        this.parameters = new double[k][v];
        this.p_v = new double[v];
        this.all = new int[v];
        for (int word = 0; word < v; ++word) {
            this.p_v[word] = 1.0 / (double) v;
            this.all[word] = word;
        }
        for (int topic = 0; topic < k; ++topic) {
            for (int word = 0; word < v; ++word) {
                this.parameters[topic][word] = this.randomWeight(rn);
            }
        }
    }

    /**
     * Re-draws, with {@link Math#random()}, the weights of {@code topic} for the words in
     * {@code groups[group]}.
     */
    public void init(int group, int topic) {
        this.randomizeGroup(group, topic, null);
    }

    /** Re-draws, with {@code rn}, the weights of {@code topic} for the words in {@code groups[group]}. */
    public void init(int group, int topic, Random rn) {
        this.randomizeGroup(group, topic, rn);
    }

    /** Shared body of both {@code init} overloads; a null {@code rn} means {@link Math#random()}. */
    private void randomizeGroup(int group, int topic, Random rn) {
        for (int word : this.groups[group]) {
            this.parameters[topic][word] = this.randomWeight(rn);
        }
    }

    /** Returns a uniform draw from {@code [-sigma, sigma)}; a null {@code rn} uses {@link Math#random()}. */
    private double randomWeight(Random rn) {
        double u = rn == null ? Math.random() : rn.nextDouble();
        return 2.0 * u * this.sigma - this.sigma;
    }

    @Override
    public int getNumParameters() {
        return this.k * this.v;
    }

    /** Copies the weights into {@code buffer} in row-major order ({@code topic * v + word}). */
    @Override
    public void getParameters(double[] buffer) {
        for (int topic = 0; topic < this.k; ++topic) {
            System.arraycopy(this.parameters[topic], 0, buffer, topic * this.v, this.v);
        }
    }

    @Override
    public double getParameter(int index) {
        return this.parameters[index / this.v][index % this.v];
    }

    /** Loads the weights from {@code params}, laid out row-major ({@code topic * v + word}). */
    @Override
    public void setParameters(double[] params) {
        for (int topic = 0; topic < this.k; ++topic) {
            System.arraycopy(params, topic * this.v, this.parameters[topic], 0, this.v);
        }
    }

    @Override
    public void setParameter(int index, double value) {
        this.parameters[index / this.v][index % this.v] = value;
    }

    /**
     * Returns {@code sum_word exp(row[word]) * p_v[word]}, the Dirichlet total for one topic.
     */
    private double betaSum(double[] row) {
        double sum = 0.0;
        for (int word = 0; word < this.v; ++word) {
            sum += Math.exp(row[word]) * this.p_v[word];
        }
        return sum;
    }

    /**
     * Adds to {@code buffer[index]} the derivative, with respect to {@code theta}, of the
     * negated log-likelihood terms for one (topic, word) entry (excluding the L2 term).
     *
     * @param sumBeta the topic's Dirichlet total, as returned by {@link #betaSum(double[])}
     */
    private void addLikelihoodGradient(double[] buffer, int index, int topic, int word,
                                       double theta, double sumBeta) {
        double expTheta = Math.exp(theta);
        buffer[index] += (Gamma.digamma(sumBeta + (double) this.n_k[topic]) - Gamma.digamma(sumBeta))
                * expTheta * this.p_v[word];
        int count = this.n_kv[word * this.k + topic];
        if (count > 0) {
            double beta = expTheta * this.p_v[word];
            buffer[index] += (Gamma.digamma(beta) - Gamma.digamma(beta + (double) count))
                    * expTheta * this.p_v[word];
        }
    }

    /**
     * Writes into {@code buffer} the gradient (same sign convention as
     * {@link #getValueGradient(double[])}) evaluated at a copy of the weights in which the
     * entries of topic {@code kLoc} for the words in {@code groups[group]} are zero.
     * {@link #parameters} itself is not modified and {@link #lastBuffer} is not updated.
     *
     * <p>The result is divided by {@code k * v}, not by {@link #numTokens}.
     * NOTE(review): confirm the {@code k * v} normaliser is intended.
     *
     * @param buffer output, length at least {@code k * v}; any entries beyond that are zeroed
     * @param group  index into {@link #groups}
     * @param kLoc   topic whose group entries are zeroed before evaluation
     */
    public void getValueGradient(double[] buffer, int group, int kLoc) {
        double[][] weights = new double[this.k][];
        for (int topic = 0; topic < this.k; ++topic) {
            weights[topic] = this.parameters[topic].clone();
        }
        for (int word : this.groups[group]) {
            weights[kLoc][word] = 0.0;
        }
        Arrays.fill(buffer, 0.0);
        double[] sumBetas = new double[this.k];
        for (int topic = 0; topic < this.k; ++topic) {
            sumBetas[topic] = this.betaSum(weights[topic]);
        }
        for (int word = 0; word < this.v; ++word) {
            for (int topic = 0; topic < this.k; ++topic) {
                int index = topic * this.v + word;
                buffer[index] = 2.0 * weights[topic][word];
                this.addLikelihoodGradient(buffer, index, topic, word, weights[topic][word], sumBetas[topic]);
            }
        }
        double scale = (double) (this.k * this.v);
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = -buffer[i] / scale;
        }
    }

    /**
     * Writes the gradient of {@link #getValue()} into {@code buffer} (row-major,
     * {@code topic * v + word}) and caches a copy in {@link #lastBuffer}.
     *
     * <p>If {@link #currentGroup} is not -1, only topic {@link #currentK}'s entries for the words in
     * {@code groups[currentGroup]} are filled. If those weights are all zero, the previous
     * gradient ({@link #lastBuffer}) is copied into {@code buffer} instead; when there is none,
     * {@code buffer} is left zeroed.
     */
    @Override
    public void getValueGradient(double[] buffer) {
        Arrays.fill(buffer, 0.0);
        if (this.currentGroup == -1) {
            this.addFullGradient(buffer);
        } else if (!this.addGroupGradient(buffer)) {
            // Copy rather than reassign the parameter: reassignment is invisible to the caller.
            if (this.lastBuffer != null) {
                System.arraycopy(this.lastBuffer, 0, buffer, 0,
                        Math.min(this.lastBuffer.length, buffer.length));
            }
            return;
        }
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = -buffer[i] / (double) this.numTokens;
        }
        this.lastBuffer = buffer.clone();
    }

    /** Adds the un-negated, unnormalised gradient of every (topic, word) entry to {@code buffer}. */
    private void addFullGradient(double[] buffer) {
        double[] sumBetas = new double[this.k];
        for (int topic = 0; topic < this.k; ++topic) {
            sumBetas[topic] = this.betaSum(this.parameters[topic]);
        }
        for (int word = 0; word < this.v; ++word) {
            for (int topic = 0; topic < this.k; ++topic) {
                int index = topic * this.v + word;
                double theta = this.parameters[topic][word];
                buffer[index] += 2.0 * theta;
                this.addLikelihoodGradient(buffer, index, topic, word, theta, sumBetas[topic]);
            }
        }
    }

    /**
     * Adds the un-negated, unnormalised gradient for topic {@link #currentK} over the words of
     * {@code groups[currentGroup]}.
     *
     * @return false, leaving {@code buffer} untouched, if all of those weights are zero
     */
    private boolean addGroupGradient(double[] buffer) {
        int topic = this.currentK;
        int[] ids = this.groups[this.currentGroup];
        double[] row = this.parameters[topic];
        double norm = 0.0;
        for (int word : ids) {
            norm += row[word] * row[word];
        }
        if (norm == 0.0) {
            return false;
        }
        double sumBeta = this.betaSum(row);
        for (int word : ids) {
            buffer[topic * this.v + word] += 2.0 * row[word];
        }
        for (int word : ids) {
            this.addLikelihoodGradient(buffer, topic * this.v + word, topic, word, row[word], sumBeta);
        }
        return true;
    }

    /**
     * Returns {@code (logLikelihood - sum(theta^2)) / numTokens} (see the class comment) and
     * caches it in {@link #last}.
     *
     * <p>If the penalty {@code sum(theta^2)} is zero, NaN or infinite, the cached {@link #last}
     * value is returned without re-evaluation.
     */
    @Override
    public double getValue() {
        double penalty = 0.0;
        for (int topic = 0; topic < this.k; ++topic) {
            for (int word = 0; word < this.v; ++word) {
                double theta = this.parameters[topic][word];
                penalty += theta * theta;
            }
        }
        if (penalty == 0.0 || Double.isNaN(penalty) || Double.isInfinite(penalty)) {
            return this.last;
        }
        double topicTerms = 0.0;
        for (int topic = 0; topic < this.k; ++topic) {
            double sumBeta = this.betaSum(this.parameters[topic]);
            topicTerms += Gamma.logGamma(sumBeta + (double) this.n_k[topic]) - Gamma.logGamma(sumBeta);
        }
        double wordTerms = 0.0;
        for (int topic = 0; topic < this.k; ++topic) {
            for (int word = 0; word < this.v; ++word) {
                int count = this.n_kv[word * this.k + topic];
                // Entries with zero count contribute logGamma(beta) - logGamma(beta) = 0.
                if (count > 0) {
                    double beta = Math.exp(this.parameters[topic][word]) * this.p_v[word];
                    wordTerms += Gamma.logGamma(beta) - Gamma.logGamma(beta + (double) count);
                }
            }
        }
        this.last = -(topicTerms + wordTerms + penalty) / (double) this.numTokens;
        return this.last;
    }
}

