/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import com.rapidminer.kobra.topicmodels.Corpus;
import java.io.IOException;
import java.io.PrintWriter;

/**
 * Latent Dirichlet Allocation fitted with variational EM.
 *
 * <p>The E-step ({@link #lda_inference}) fits per-document variational parameters
 * {@code var_gamma} (Dirichlet over topics) and {@code phi} (per-word topic responsibilities).
 * {@link #docEstep} folds them into expected topic-term counts. The M-step ({@link #ldaMLE})
 * turns those counts into topic-term log-probabilities and can optionally re-estimate the
 * symmetric Dirichlet prior {@code alpha} by Newton's method.
 *
 * <p>Instances are mutable and not safe for concurrent use: {@code phi} and the sufficient
 * statistics are shared scratch state.
 */
public class LDA {
    /** Number of documents folded into the sufficient statistics during the current EM iteration. */
    int sNumDocs = 0;
    /** Sufficient statistic for alpha: sum over documents of sum_k (digamma(gamma_k) - digamma(sum gamma)). */
    double sAlpha = 0.0;
    /** Number of topics K. */
    int numTopics;
    /** Vocabulary size, as reported by {@code Corpus.getNumTerms()}. */
    int numTerms;
    /** Training corpus. */
    Corpus corpus;
    /** Symmetric Dirichlet prior on per-document topic proportions. */
    double alpha;
    /** {@code log_prob_w[k][w]}: log probability of term {@code w} under topic {@code k} (-100 stands in for log 0). */
    double[][] log_prob_w;
    /** {@code var_gamma[d][k]}: variational Dirichlet parameter of topic {@code k} for document {@code d}. */
    double[][] var_gamma;
    /**
     * {@code phi[n][k]}: responsibility of topic {@code k} for the n-th entry of
     * {@code corpus.getWordIdsInDoc(doc)}. It is sized for the longest document and reused
     * for every document.
     */
    double[][] phi;
    /** {@code class_word[k][w]}: expected count of term {@code w} assigned to topic {@code k}. */
    double[][] class_word;
    /** {@code class_total[k]}: row sums of {@code class_word}. */
    double[] class_total;
    /** EM stops once more than this many iterations have run. */
    int EM_MAX_ITER = 1000;
    /** Maximum variational iterations per document; {@code -1} means unbounded. */
    int VAR_MAX_ITER = 20;
    /** Relative corpus-likelihood change below which EM is considered converged. */
    double EM_CONVERGED = 1.0E-4;
    /** Whether {@link #runEM} re-estimates {@code alpha} in every M-step. */
    boolean ESTIMATE_ALPHA = true;
    /** Newton iterations for alpha stop once |d likelihood / d alpha| drops to this value. */
    private double NEWTON_THRESH = 1.0E-5;
    /** Upper bound on Newton iterations for alpha. */
    private int MAX_ALPHA_ITER = 1000;
    /** Relative per-document likelihood change below which variational inference stops. */
    double VAR_CONVERGED = 1.0E-6;

    /**
     * Fits a 4-topic model with initial alpha 0.5 and writes the topic-term log-probabilities
     * to the file {@code beta}.
     *
     * @param args optional overrides for the two {@code Corpus} constructor arguments:
     *     {@code args[0]} replaces {@code "ap.dat"} and {@code args[1]} replaces {@code "vocab.txt"}
     */
    public static void main(String[] args) {
        String first = args.length > 0 ? args[0] : "ap.dat";
        String second = args.length > 1 ? args[1] : "vocab.txt";
        Corpus c = new Corpus(first, second);
        LDA lda = new LDA(4, c, 0.5);
        lda.runEM();
        lda.storeBeta();
    }

    /**
     * Allocates all model state and draws random initial topic-term pseudo-counts.
     *
     * @param nT number of topics
     * @param c corpus to fit
     * @param initialAlpha starting value of the symmetric Dirichlet prior
     */
    public LDA(int nT, Corpus c, double initialAlpha) {
        System.out.println("topics\t\t: " + nT);
        System.out.println("initial alpha\t: " + initialAlpha);
        this.corpus = c;
        this.numTopics = nT;
        this.alpha = initialAlpha;
        this.sNumDocs = 0;
        int numDocs = c.getNumDocs();
        this.var_gamma = new double[numDocs][this.numTopics];
        int maxLength = this.corpus.maxCorpusLength();
        this.phi = new double[maxLength][this.numTopics];
        this.numTerms = this.corpus.getNumTerms();
        // New Java arrays are zero-filled; log_prob_w is overwritten by ldaMLE below.
        this.log_prob_w = new double[this.numTopics][this.numTerms];
        this.class_total = new double[this.numTopics];
        this.class_word = new double[this.numTopics][this.numTerms];
        // Random initialisation: a uniform 1/V pseudo-count plus Math.random() noise per cell.
        for (int k = 0; k < this.numTopics; ++k) {
            for (int w = 0; w < this.numTerms; ++w) {
                this.class_word[k][w] = 1.0 / (double) this.numTerms + this.ldaRand();
                this.class_total[k] += this.class_word[k][w];
            }
        }
        this.ldaMLE(false);
    }

    /** Resets the sufficient statistics before a new E-step pass over the corpus. */
    void zeroInitialize() {
        for (int k = 0; k < this.numTopics; ++k) {
            this.class_total[k] = 0.0;
            // A freshly allocated array is already all zeros.
            this.class_word[k] = new double[this.numTerms];
        }
        this.sNumDocs = 0;
        this.sAlpha = 0.0;
    }

    /**
     * Runs variational EM over the whole corpus.
     *
     * <p>At least three iterations run (provided {@code EM_MAX_ITER >= 2}). After that, EM
     * continues while the relative likelihood change is negative or above
     * {@code EM_CONVERGED}. It stops once more than {@code EM_MAX_ITER} iterations have run.
     * Whenever the likelihood decreases, the per-document variational iteration limit is doubled.
     */
    public void runEM() {
        int iteration = 0;
        double likelihoodOld = 0.0;
        double converged = 1.0;
        while ((converged < 0.0 || converged > this.EM_CONVERGED || iteration <= 2)
                && iteration <= this.EM_MAX_ITER) {
            ++iteration;
            System.out.println("**** em iteration " + iteration + " ****");
            double likelihood = 0.0;
            this.zeroInitialize();
            for (int d = 0; d < this.corpus.getNumDocs(); ++d) {
                if (d % 1000 == 0) {
                    System.out.println("document " + d);
                }
                likelihood += this.docEstep(d);
            }
            this.ldaMLE(this.ESTIMATE_ALPHA);
            // On the first pass likelihoodOld is 0, so this is +/-Infinity (or NaN). The
            // "iteration <= 2" guard keeps the loop going regardless.
            converged = (likelihoodOld - likelihood) / likelihoodOld;
            // The likelihood went down, so give inference more room. Keep -1 ("unbounded")
            // intact and saturate rather than overflow into a negative limit.
            if (converged < 0.0 && this.VAR_MAX_ITER > 0) {
                this.VAR_MAX_ITER = this.VAR_MAX_ITER > Integer.MAX_VALUE / 2
                        ? Integer.MAX_VALUE
                        : this.VAR_MAX_ITER * 2;
            }
            likelihoodOld = likelihood;
            System.out.println("Likelihood = " + likelihood + ", delta = " + converged);
        }
    }

    /**
     * Runs inference on one document and adds its contribution to the sufficient statistics.
     *
     * @param doc document index
     * @return the document's variational likelihood bound
     */
    double docEstep(int doc) {
        double likelihood = this.lda_inference(doc);
        double[] gamma = this.var_gamma[doc];
        double gammaSum = 0.0;
        for (int k = 0; k < this.numTopics; ++k) {
            gammaSum += gamma[k];
            this.sAlpha += this.digamma(gamma[k]);
        }
        this.sAlpha -= (double) this.numTopics * this.digamma(gammaSum);
        int n = 0;
        for (int word : this.corpus.getWordIdsInDoc(doc)) {
            double count = this.corpus.getWordFreqInDoc(doc, word);
            for (int k = 0; k < this.numTopics; ++k) {
                double expected = count * this.phi[n][k];
                this.class_word[k][word] += expected;
                this.class_total[k] += expected;
            }
            ++n;
        }
        ++this.sNumDocs;
        return likelihood;
    }

    /**
     * Maximises the alpha likelihood term by Newton's method in log(alpha), starting at 100.
     *
     * @param ss alpha sufficient statistic ({@code sAlpha})
     * @param D number of documents
     * @param K number of topics
     * @return the optimised alpha
     */
    private double opt_alpha(double ss, int D, int K) {
        double initA = 100.0;
        double logA = Math.log(initA);
        double df;
        int iter = 0;
        do {
            ++iter;
            double a = Math.exp(logA);
            // NaN compares unequal to everything (including Double.NaN), so use isNaN.
            if (Double.isNaN(a)) {
                initA *= 10.0;
                System.out.println("warning : alpha is nan; new init = " + initA);
                a = initA;
                logA = Math.log(a);
            }
            double f = this.alhood(a, ss, D, K);
            df = this.d_alhood(a, ss, D, K);
            double d2f = this.d2_alhood(a, D, K);
            // Newton step on u = log(a): dL/du = a*df and d2L/du2 = a*df + a^2*d2f.
            logA -= df / (d2f * a + df);
            System.out.println("alpha maximization : " + f + "   " + df);
        } while (Math.abs(df) > this.NEWTON_THRESH && iter < this.MAX_ALPHA_ITER);
        return Math.exp(logA);
    }

    /**
     * M-step: converts the expected counts into topic-term log-probabilities.
     *
     * @param estimate_alpha when true, also re-estimates {@code alpha} from {@code sAlpha}
     */
    public void ldaMLE(boolean estimate_alpha) {
        for (int k = 0; k < this.numTopics; ++k) {
            double logTotal = Math.log(this.class_total[k]);
            for (int w = 0; w < this.numTerms; ++w) {
                double count = this.class_word[k][w];
                // -100 is used in place of log(0) for terms with no mass in topic k.
                this.log_prob_w[k][w] = count > 0.0 ? Math.log(count) - logTotal : -100.0;
            }
        }
        if (estimate_alpha) {
            this.alpha = this.opt_alpha(this.sAlpha, this.sNumDocs, this.numTopics);
            System.out.println("new alpha = " + this.alpha);
        }
    }

    /** Returns a uniform random number in [0, 1) for initialisation noise. */
    private double ldaRand() {
        return Math.random();
    }

    /**
     * Writes {@code log_prob_w} to the file {@code beta} in UTF-8: one line per topic, each
     * value preceded by a space. Failures are reported on stdout, not thrown.
     */
    public void storeBeta() {
        try (PrintWriter writer = new PrintWriter("beta", "UTF-8")) {
            for (int k = 0; k < this.numTopics; ++k) {
                for (int w = 0; w < this.numTerms; ++w) {
                    writer.print(" " + this.log_prob_w[k][w]);
                }
                writer.println();
            }
            // PrintWriter swallows I/O errors; surface them instead of leaving a truncated file unnoticed.
            if (writer.checkError()) {
                System.out.println("error writing beta");
            }
        } catch (IOException e) {
            System.out.println(e);
        }
    }

    /**
     * Computes log(exp(log_a) + exp(log_b)) without overflow by factoring out the larger term.
     */
    double log_sum(double log_a, double log_b) {
        return log_a < log_b
                ? log_b + Math.log(1.0 + Math.exp(log_a - log_b))
                : log_a + Math.log(1.0 + Math.exp(log_b - log_a));
    }

    /**
     * Coordinate-ascent variational inference for one document. It updates
     * {@code var_gamma[doc]} and the first rows of {@code phi} in place.
     *
     * @param doc document index
     * @return the likelihood bound after the last iteration
     */
    double lda_inference(int doc) {
        double converged = 1.0;
        double likelihood = 0.0;
        double likelihoodOld = 0.0;
        double[] oldphi = new double[this.numTopics];
        double[] digammaGam = new double[this.numTopics];
        double[] gamma = this.var_gamma[doc];
        int length = this.corpus.getDocumentLength(doc);
        double initialGamma = this.alpha + (double) this.corpus.getTotal(doc) / (double) this.numTopics;
        for (int k = 0; k < this.numTopics; ++k) {
            gamma[k] = initialGamma;
            digammaGam[k] = this.digamma(gamma[k]);
            for (int n = 0; n < length; ++n) {
                this.phi[n][k] = 1.0 / (double) this.numTopics;
            }
        }
        int varIter = 0;
        while (converged > this.VAR_CONVERGED
                && (varIter < this.VAR_MAX_ITER || this.VAR_MAX_ITER == -1)) {
            ++varIter;
            int n = 0;
            for (int word : this.corpus.getWordIdsInDoc(doc)) {
                double count = this.corpus.getWordFreqInDoc(doc, word);
                double[] phiN = this.phi[n];
                // First pass: unnormalised log-phi; phisum is its log-sum-exp over topics.
                double phisum = 0.0;
                for (int k = 0; k < this.numTopics; ++k) {
                    oldphi[k] = phiN[k];
                    phiN[k] = digammaGam[k] + this.log_prob_w[k][word];
                    phisum = k > 0 ? this.log_sum(phisum, phiN[k]) : phiN[k];
                }
                // Second pass: normalise phi and apply the incremental gamma update.
                for (int k = 0; k < this.numTopics; ++k) {
                    phiN[k] = Math.exp(phiN[k] - phisum);
                    gamma[k] += count * (phiN[k] - oldphi[k]);
                    digammaGam[k] = this.digamma(gamma[k]);
                }
                ++n;
            }
            likelihood = this.compute_likelihood(doc);
            // On the first pass likelihoodOld is 0, so this is +/-Infinity (or NaN).
            converged = (likelihoodOld - likelihood) / likelihoodOld;
            likelihoodOld = likelihood;
        }
        return likelihood;
    }

    /**
     * Evaluates the variational likelihood bound for one document using the current
     * {@code var_gamma[doc]}, {@code phi}, {@code alpha} and {@code log_prob_w}.
     *
     * @param doc document index
     * @return the likelihood bound
     */
    double compute_likelihood(int doc) {
        double[] gamma = this.var_gamma[doc];
        double[] dig = new double[this.numTopics];
        double gammaSum = 0.0;
        for (int k = 0; k < this.numTopics; ++k) {
            dig[k] = this.digamma(gamma[k]);
            gammaSum += gamma[k];
        }
        double digsum = this.digamma(gammaSum);
        double likelihood = this.lgamma(this.alpha * (double) this.numTopics)
                - (double) this.numTopics * this.lgamma(this.alpha)
                - this.lgamma(gammaSum);
        for (int k = 0; k < this.numTopics; ++k) {
            // E_q[log theta_k] under the variational Dirichlet.
            double eLogTheta = dig[k] - digsum;
            likelihood += (this.alpha - 1.0) * eLogTheta + this.lgamma(gamma[k])
                    - (gamma[k] - 1.0) * eLogTheta;
            int n = 0;
            for (int word : this.corpus.getWordIdsInDoc(doc)) {
                double p = this.phi[n][k];
                if (p > 0.0) {
                    likelihood += (double) this.corpus.getWordFreqInDoc(doc, word)
                            * (p * (eLogTheta - Math.log(p) + this.log_prob_w[k][word]));
                }
                ++n;
            }
        }
        return likelihood;
    }

    /** Alpha-dependent part of the corpus likelihood bound. */
    private double alhood(double a, double ss, int D, int K) {
        return (double) D * (this.lgamma((double) K * a) - (double) K * this.lgamma(a)) + (a - 1.0) * ss;
    }

    /** First derivative of {@link #alhood} with respect to {@code a}. */
    private double d_alhood(double a, double ss, int D, int K) {
        return (double) D * ((double) K * this.digamma((double) K * a) - (double) K * this.digamma(a)) + ss;
    }

    /** Second derivative of {@link #alhood} with respect to {@code a}. */
    private double d2_alhood(double a, int D, int K) {
        // Widen before multiplying so K*K cannot overflow int.
        return (double) D * ((double) K * K * this.trigamma((double) K * a) - (double) K * this.trigamma(a));
    }

    /**
     * Trigamma function psi'(x) for x > 0. It uses the asymptotic series at x + 6, then the
     * recurrence psi'(x) = psi'(x + 6) + sum_{i=0}^{5} 1 / (x + i)^2.
     */
    private double trigamma(double x) {
        x += 6.0;
        double p = 1.0 / (x * x);
        p = (((((0.075757575757576 * p - 0.033333333333333) * p + 0.0238095238095238) * p
                - 0.033333333333333) * p + 0.166666666666667) * p + 1.0) / x + 0.5 * p;
        for (int i = 0; i < 6; ++i) {
            x -= 1.0;
            p += 1.0 / (x * x);
        }
        return p;
    }

    /**
     * Digamma function psi(x) for x > 0. It uses the asymptotic series at x + 6, then the
     * recurrence psi(x) = psi(x + 6) - sum_{i=0}^{5} 1 / (x + i).
     */
    private double digamma(double x) {
        x += 6.0;
        double p = 1.0 / (x * x);
        // Coefficients 1/240, 1/252, 1/120, 1/12 (1/252 corrected from a digit transposition).
        p = (((0.004166666666667 * p - 0.003968253968254) * p + 0.008333333333333) * p - 0.083333333333333) * p;
        return p + Math.log(x) - 0.5 / x - 1.0 / (x - 1.0) - 1.0 / (x - 2.0) - 1.0 / (x - 3.0)
                - 1.0 / (x - 4.0) - 1.0 / (x - 5.0) - 1.0 / (x - 6.0);
    }

    /**
     * Log-gamma function log Gamma(x) for x > 0. It uses Stirling's series at x + 6, then the
     * recurrence log Gamma(x) = log Gamma(x + 6) - sum_{i=0}^{5} log(x + i).
     */
    private double lgamma(double x) {
        // Shift BEFORE forming 1/x^2: the series correction must use the shifted argument,
        // otherwise it is badly wrong for small x (e.g. small alpha).
        x += 6.0;
        double z = 1.0 / (x * x);
        double correction = (((-5.95238095238E-4 * z + 7.93650793651E-4) * z - 0.002777777777778) * z
                + 0.083333333333333) / x;
        // 0.918938533204673 = 0.5 * log(2 * pi).
        return (x - 0.5) * Math.log(x) - x + 0.918938533204673 + correction
                - Math.log(x - 1.0) - Math.log(x - 2.0) - Math.log(x - 3.0)
                - Math.log(x - 4.0) - Math.log(x - 5.0) - Math.log(x - 6.0);
    }
}

