/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import com.rapidminer.kobra.topicmodels.MyWordFeatOptimizable;
import org.apache.commons.math3.special.Gamma;

/**
 * Gradient-based optimization target over the {@code k x v} log-weight matrix
 * {@code parameters} inherited from {@link MyWordFeatOptimizable}.
 *
 * <p>For row {@code i} and column {@code j} the effective pseudo-count is
 * {@code beta[i][j] = exp(parameters[i][j]) * p_v[j]}, and {@code S_i = sum_j beta[i][j]}.
 * {@link #getValue()} returns
 *
 * <pre>
 *   sum_i [ logGamma(S_i) - logGamma(S_i + n_k[i]) ]
 * + sum_{i,j : n_kv[j*k+i] > 0} [ logGamma(beta[i][j] + n_kv[j*k+i]) - logGamma(beta[i][j]) ]
 * </pre>
 *
 * which has the form of a Dirichlet-multinomial log marginal likelihood of the counts
 * (omitting terms that do not depend on {@code parameters}).
 *
 * <p>NOTE(review): despite "Reg" in the class name, no regularization term is computed here
 * (the decompiled original only carried an always-zero placeholder). Confirm whether a prior on
 * {@code parameters} was intended.
 *
 * <p>NOTE(review): {@code Gamma.logGamma}/{@code Gamma.digamma} are not finite for non-positive
 * arguments, so presumably every {@code p_v[j]} is expected to be strictly positive — confirm.
 */
public class MyWordFeatRegOptimizable extends MyWordFeatOptimizable {

    /**
     * Creates the optimizable; both arguments are forwarded unchanged to the superclass.
     *
     * @param k number of rows of {@code parameters} (presumably topics — confirm)
     * @param v number of columns of {@code parameters} (presumably vocabulary words — confirm)
     */
    public MyWordFeatRegOptimizable(int k, int v) {
        super(k, v);
    }

    /**
     * Writes the gradient of {@link #getValue()} with respect to every
     * {@code parameters[i][j]} into {@code buffer} at flattened index {@code i * v + j}.
     *
     * <p>Because {@code d beta[i][j] / d parameters[i][j] = beta[i][j]}, each entry is
     *
     * <pre>
     *   beta[i][j] * ( digamma(S_i) - digamma(S_i + n_k[i])
     *                + [n_kv[j*k+i] > 0] * (digamma(beta[i][j] + n_kv[j*k+i]) - digamma(beta[i][j])) )
     * </pre>
     *
     * <p>Entries of {@code buffer} at index {@code k * v} and beyond are set to zero.
     *
     * @param buffer output array; must have length at least {@code k * v}
     * @throws ArrayIndexOutOfBoundsException if {@code buffer} is shorter than {@code k * v}
     */
    @Override
    public void getValueGradient(double[] buffer) {
        for (int idx = 0; idx < buffer.length; ++idx) {
            buffer[idx] = 0.0;
        }
        // Reused per row to avoid recomputing exp(parameters[i][j]) * p_v[j].
        double[] beta = new double[this.v];
        for (int i = 0; i < this.k; ++i) {
            double sumBeta = 0.0;
            for (int j = 0; j < this.v; ++j) {
                beta[j] = Math.exp(this.parameters[i][j]) * this.p_v[j];
                sumBeta += beta[j];
            }
            // Row-level factor d/dS_i of [logGamma(S_i) - logGamma(S_i + n_k[i])]; shared by all j.
            // The original applied the chain-rule factor beta to only one digamma (precedence bug).
            double rowTerm = Gamma.digamma(sumBeta) - Gamma.digamma(sumBeta + (double) this.n_k[i]);
            for (int j = 0; j < this.v; ++j) {
                double factor = rowTerm;
                double count = this.n_kv[j * this.k + i];
                // Zero counts contribute logGamma(b) - logGamma(b) = 0, so they are skipped,
                // mirroring the same guard in getValue().
                if (count > 0) {
                    factor += Gamma.digamma(beta[j] + count) - Gamma.digamma(beta[j]);
                }
                buffer[i * this.v + j] = factor * beta[j];
            }
        }
    }

    /**
     * Evaluates the objective described in the class documentation at the current
     * {@code parameters}.
     *
     * @return the objective value
     */
    @Override
    public double getValue() {
        double rowPart = 0.0;
        double cellPart = 0.0;
        for (int i = 0; i < this.k; ++i) {
            double sumBeta = 0.0;
            for (int j = 0; j < this.v; ++j) {
                sumBeta += Math.exp(this.parameters[i][j]) * this.p_v[j];
            }
            rowPart += Gamma.logGamma(sumBeta) - Gamma.logGamma(sumBeta + (double) this.n_k[i]);
        }
        for (int i = 0; i < this.k; ++i) {
            for (int j = 0; j < this.v; ++j) {
                double count = this.n_kv[j * this.k + i];
                // A zero count contributes exactly 0, so skip the logGamma evaluations.
                if (count > 0) {
                    double b = Math.exp(this.parameters[i][j]) * this.p_v[j];
                    cellPart += Gamma.logGamma(b + count) - Gamma.logGamma(b);
                }
            }
        }
        return rowPart + cellPart;
    }
}

