/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.Optimizable;
import gnu.trove.map.hash.TDoubleIntHashMap;
import org.apache.commons.math3.special.Beta;
import org.apache.commons.math3.special.Gamma;

/**
 * Mallet {@link Optimizable.ByGradientValue} that fits a Beta(alpha, beta) distribution to a
 * weighted sample by maximum likelihood.
 *
 * <p>The sample lives in {@link #vals}: each key is an observed value and the mapped int is how
 * many times it was observed. The optimizer moves the log-space parameters {@code a = log(alpha)}
 * and {@code b = log(beta)}, which keeps both shape parameters strictly positive without explicit
 * constraints. {@link #getValue()} returns the sample log-likelihood and
 * {@link #getValueGradient(double[])} returns its gradient with respect to {@code (a, b)}.
 */
public class MyHashBetaOptimizable
implements Optimizable.ByGradientValue {
    /** Shape parameter alpha = exp(a); refreshed by getValue() and getValueGradient(). */
    double alpha = 0.0;
    /** Shape parameter beta = exp(b); refreshed by getValue() and getValueGradient(). */
    double beta = 0.0;
    /** log(alpha); optimizer parameter 0. */
    double a = 0.0;
    /** log(beta); optimizer parameter 1. */
    double b = 0.0;
    /** Not read by this class; kept for external users of the field. */
    double[] values;
    /**
     * Observed value -> multiplicity. Must be set before evaluation. Values presumably lie in
     * (0, 1); 0 or 1 make the log terms -Infinity — TODO confirm callers guarantee this.
     */
    TDoubleIntHashMap vals = null;

    @Override
    public int getNumParameters() {
        return 2;
    }

    /** Copies {@code (a, b)} into {@code buffer[0..1]}. */
    @Override
    public void getParameters(double[] buffer) {
        buffer[0] = this.a;
        buffer[1] = this.b;
    }

    /** Returns {@code a} for index 0 and {@code b} for any other index. */
    @Override
    public double getParameter(int index) {
        return index == 0 ? this.a : this.b;
    }

    /** Sets {@code a = params[0]} and {@code b = params[1]}. */
    @Override
    public void setParameters(double[] params) {
        this.a = params[0];
        this.b = params[1];
    }

    /** Sets {@code a} for index 0 and {@code b} for any other index. */
    @Override
    public void setParameter(int index, double value) {
        if (index == 0) {
            this.a = value;
        } else {
            this.b = value;
        }
    }

    /**
     * Writes the gradient of {@link #getValue()} with respect to the log-space parameters
     * {@code (a, b)} into {@code buffer[0..1]}, and refreshes {@link #alpha}/{@link #beta}.
     *
     * <p>With L = -n*logB(alpha, beta) + (alpha-1)*S1 + (beta-1)*S2:
     * dL/dalpha = S1 - n*(psi(alpha) - psi(alpha+beta)), and likewise for beta. Because
     * alpha = exp(a), dL/da = alpha * dL/dalpha (chain rule); likewise dL/db = beta * dL/dbeta.
     *
     * @param buffer array of length at least 2 that receives the gradient
     * @throws IllegalStateException if {@link #vals} has not been set
     */
    @Override
    public void getValueGradient(double[] buffer) {
        this.alpha = Math.exp(this.a);
        this.beta = Math.exp(this.b);
        double[] stats = this.sufficientStatistics();
        double n = stats[0];
        double digammaSum = Gamma.digamma(this.alpha + this.beta);
        // Partial derivatives w.r.t. alpha/beta, each scaled by alpha/beta for the log-space chain rule.
        buffer[0] = this.alpha * (stats[1] - n * (Gamma.digamma(this.alpha) - digammaSum));
        buffer[1] = this.beta * (stats[2] - n * (Gamma.digamma(this.beta) - digammaSum));
    }

    /**
     * Returns the Beta log-likelihood of the weighted sample at alpha = exp(a), beta = exp(b),
     * and refreshes {@link #alpha}/{@link #beta}.
     *
     * @return -n*logB(alpha, beta) + (alpha-1)*sum(count*log x) + (beta-1)*sum(count*log(1-x))
     * @throws IllegalStateException if {@link #vals} has not been set
     */
    @Override
    public double getValue() {
        this.alpha = Math.exp(this.a);
        this.beta = Math.exp(this.b);
        double[] stats = this.sufficientStatistics();
        double n = stats[0];
        return -n * Beta.logBeta(this.alpha, this.beta)
                + (this.alpha - 1.0) * stats[1]
                + (this.beta - 1.0) * stats[2];
    }

    /**
     * Accumulates the sufficient statistics of the weighted sample in {@link #vals}.
     *
     * @return {@code {n, sum(count*log x), sum(count*log(1-x))}}, where n is the total count
     * @throws IllegalStateException if {@link #vals} has not been set
     */
    private double[] sufficientStatistics() {
        if (this.vals == null) {
            throw new IllegalStateException("vals must be set before evaluating the likelihood");
        }
        // Counts are summed in double so large multiplicities cannot overflow an int.
        double count = 0.0;
        double sumLogX = 0.0;
        double sumLog1mX = 0.0;
        for (double x : this.vals.keys()) {
            double weight = this.vals.get(x);
            count += weight;
            sumLogX += weight * Math.log(x);
            // log1p(-x) is more accurate than log(1 - x) when x is close to 0.
            sumLog1mX += weight * Math.log1p(-x);
        }
        return new double[] {count, sumLogX, sumLog1mX};
    }
}

