/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.OptimizationException;
import com.rapidminer.kobra.opt.MyOrthantWiseLimitedMemoryBFGS;
import com.rapidminer.kobra.topicmodels.MyDMROptimizable;
import com.rapidminer.kobra.topicmodels.SamplersLDA;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;

/**
 * Gibbs sampler for LDA in which the per-document topic prior is derived from
 * document features: for document {@code d} and topic {@code k} the prior weight
 * is {@code exp(features[d] . paras[k])}, where {@code paras} is re-estimated
 * every 100 sweeps with {@link MyOrthantWiseLimitedMemoryBFGS} on a
 * {@link MyDMROptimizable}.
 *
 * <p>{@link #features} and {@link #numFeatures} are not assigned anywhere in this
 * class; they must be set by the caller before {@link #GibbsSampling()} is invoked.
 */
public class SamplersDMRLDA
extends SamplersLDA {
    /** Number of tokens in each document; filled by {@code init}. */
    int[] docCounts = null;
    /** Feature vectors, indexed {@code [document][feature]}; read during sampling. */
    double[][] features;
    /** Number of entries read from each row of {@link #features}. */
    int numFeatures;
    /** Allocated in {@code init} with one slot per document; not otherwise used in this class. */
    double[] alphas = null;
    /** Seed recorded by {@code init} when a fixed seed is requested. */
    public int seed = 2000;
    /** Scratch buffer holding the unnormalised topic probabilities of the current token. */
    double[] probs;
    /** Topic assigned to each token, indexed by token position; {@code null} until sampling has run. */
    int[] tokenToTopic = null;
    /** Handed to {@code MyDMROptimizable.sigma} (presumably the prior's standard deviation - confirm there). */
    double sigma = 0.1;
    /** Second constructor argument of the optimizer (presumably the L1 weight - confirm there). */
    double lambda = 0.1;

    /**
     * Stores the corpus and hyper-parameters, allocates the count tables and assigns
     * every token a uniformly random initial topic.
     *
     * @param docIds    document index of every token
     * @param wordIds   word index of every token (same length as {@code docIds})
     * @param numTopics number of topics
     * @param numWords  vocabulary size
     * @param numDocs   number of documents
     * @param iter      number of Gibbs sweeps to run
     * @param beta      value stored in {@code BETA} (smoothing of the word/topic counts)
     * @param alpha     value stored in {@code ALPHA}; not read by the sampler in this class
     * @param locSeed   if {@code true}, the random generator is seeded with {@code seed}
     * @param seed      seed used when {@code locSeed} is {@code true}
     */
    @Override
    public void init(int[] docIds, int[] wordIds, int numTopics, int numWords, int numDocs, int iter, double beta, double alpha, boolean locSeed, int seed) {
        this.maxIter = iter;
        this.BETA = beta;
        this.ALPHA = alpha;
        this.numTokens = wordIds.length;
        this.numTopics = numTopics;
        this.numDocs = numDocs;
        this.numWords = numWords;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[numWords * numTopics];
        this.doctopiccounts = new int[numDocs * numTopics];
        this.topiccounts = new int[numTopics];
        this.alphas = new double[numDocs];
        this.docCounts = new int[numDocs];
        this.words = wordIds;
        this.docs = docIds;
        if (locSeed) {
            this.seed = seed;
            this.rn = new Random(seed);
        } else {
            this.rn = new Random();
        }
        for (int i = 0; i < wordIds.length; ++i) {
            int wi = this.words[i];
            int di = this.docs[i];
            int topic = this.rn.nextInt(numTopics);
            this.docCounts[di]++;
            this.topics[i] = topic;
            this.wordtopiccounts[wi * numTopics + topic]++;
            this.doctopiccounts[di * numTopics + topic]++;
            this.topiccounts[topic]++;
        }
    }

    /**
     * @return the topic of every token as recorded by the last call to
     *         {@link #GibbsSampling()}, or {@code null} if sampling has not run yet
     */
    @Override
    public int[] getTokenToTopic() {
        return this.tokenToTopic;
    }

    /**
     * Runs {@code maxIter} Gibbs sweeps over all tokens. Tokens are visited in a
     * random order that is drawn once before the first sweep. Every 100 sweeps
     * (including sweep 0) the feature weights are re-optimised against the current
     * document/topic counts. After the last sweep any negative entry of the
     * word/topic and document/topic tables is reset to zero.
     */
    @Override
    public void GibbsSampling() {
        this.WBETA = (double)this.numWords * this.BETA;
        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];
        TIntArrayList order = new TIntArrayList(this.numTokens);
        for (int t = 0; t < this.numTokens; ++t) {
            order.add(t);
        }
        double[] paras = new double[this.numFeatures * this.numTopics];
        order.shuffle(this.rn);
        // One buffer reused for every token instead of a fresh allocation per token.
        double[] alpha = new double[this.numTopics];
        int lastTopic = this.numTopics - 1;
        for (int iter = 0; iter < this.maxIter; ++iter) {
            if (iter % 100 == 0) {
                System.out.println("do optimization " + iter / 100);
                this.optimizeParameters(paras);
            }
            for (int ii = 0; ii < this.numTokens; ++ii) {
                int token = order.get(ii);
                int wi = this.words[token];
                int di = this.docs[token];
                this.computeAlpha(this.features[di], paras, alpha);

                // Remove the token's current assignment from all count tables.
                int topic = this.topics[token];
                int wioffset = wi * this.numTopics;
                int dioffset = di * this.numTopics;
                this.topiccounts[topic]--;
                this.wordtopiccounts[wioffset + topic]--;
                this.doctopiccounts[dioffset + topic]--;

                double totprob = 0.0;
                for (int j = 0; j < this.numTopics; ++j) {
                    this.probs[j] = ((double)this.wordtopiccounts[wioffset + j] + this.BETA) / ((double)this.topiccounts[j] + this.WBETA) * ((double)this.doctopiccounts[dioffset + j] + alpha[j]);
                    totprob += this.probs[j];
                }

                // Inverse-CDF draw. The bound on topic only matters when the
                // probabilities are non-finite (e.g. exp overflow); it prevents
                // the walk from running past the end of probs.
                double r = totprob * this.rn.nextDouble();
                topic = 0;
                double cumulative = this.probs[0];
                while (r > cumulative && topic < lastTopic) {
                    cumulative += this.probs[++topic];
                }

                this.topics[token] = topic;
                this.wordtopiccounts[wioffset + topic]++;
                this.doctopiccounts[dioffset + topic]++;
                this.topiccounts[topic]++;
                // Index by the token itself, not by its position in the shuffled
                // visiting order, so the result lines up with words/docs/topics.
                this.tokenToTopic[token] = topic;
            }
        }
        SamplersDMRLDA.clampNegativeCounts(this.wordtopiccounts, this.numWords * this.numTopics);
        SamplersDMRLDA.clampNegativeCounts(this.doctopiccounts, this.numDocs * this.numTopics);
    }

    /**
     * Fits the feature weights to the current document/topic counts and copies
     * them into {@code paras}.
     *
     * @param paras destination for the optimised weights, laid out as
     *              {@code [topic * numFeatures + feature]}
     */
    private void optimizeParameters(double[] paras) {
        MyDMROptimizable opt = new MyDMROptimizable(this.numFeatures, this.numTopics, this.numDocs);
        opt.n_d = this.docCounts;
        opt.n_td = this.doctopiccounts;
        opt.documentFeatures = this.features;
        opt.sigma = this.sigma;
        opt.init(this.rn);
        MyOrthantWiseLimitedMemoryBFGS optimizer = new MyOrthantWiseLimitedMemoryBFGS(opt, this.lambda);
        try {
            optimizer.optimize(1000);
        }
        catch (OptimizationException e) {
            // Deliberately best-effort: report the failure and continue with
            // whatever parameters the optimizable holds at this point.
            e.printStackTrace();
        }
        opt.getParameters(paras);
    }

    /**
     * Computes the prior weight of every topic for one document:
     * {@code alpha[k] = exp(sum_l f[l] * paras[k * numFeatures + l])}.
     *
     * @param f     feature vector of the document
     * @param paras feature weights, laid out as {@code [topic * numFeatures + feature]}
     * @param alpha output buffer of length {@code numTopics}
     */
    private void computeAlpha(double[] f, double[] paras, double[] alpha) {
        for (int k = 0; k < this.numTopics; ++k) {
            int base = k * this.numFeatures;
            double dot = 0.0;
            for (int l = 0; l < this.numFeatures; ++l) {
                // Accumulate over all features; a plain assignment here would
                // keep only the last feature's contribution.
                dot += f[l] * paras[base + l];
            }
            alpha[k] = Math.exp(dot);
        }
    }

    /** Resets every negative entry among the first {@code limit} elements of {@code counts} to zero. */
    private static void clampNegativeCounts(int[] counts, int limit) {
        for (int i = 0; i < limit; ++i) {
            if (counts[i] < 0) {
                counts[i] = 0;
            }
        }
    }

    /**
     * @return for every word, the index of the topic with the largest word/topic
     *         count (first such topic on ties; 0 if no count is positive)
     */
    @Override
    public int[] assignedTopicsToWords() {
        int[] res = new int[this.numWords];
        for (int i = 0; i < this.numWords; ++i) {
            int offset = i * this.numTopics;
            double max = 0.0;
            for (int j = 0; j < this.numTopics; ++j) {
                int count = this.wordtopiccounts[offset + j];
                if (max < (double)count) {
                    max = count;
                    res[i] = j;
                }
            }
        }
        return res;
    }

    /**
     * Returns, for every word, the largest word/topic count (0 if no count is positive).
     *
     * <p>NOTE(review): despite the method name the value is a raw count, not a
     * probability; it is not divided by the word's total count. Left unchanged
     * because callers may rely on the current values - confirm the intended contract.
     *
     * @return array of length {@code numWords} holding the maximal count per word
     */
    @Override
    public double[] assignedTopicsToWordsProbs() {
        double[] res = new double[this.numWords];
        for (int i = 0; i < this.numWords; ++i) {
            int offset = i * this.numTopics;
            double max = 0.0;
            for (int j = 0; j < this.numTopics; ++j) {
                int count = this.wordtopiccounts[offset + j];
                if (max < (double)count) {
                    max = count;
                    res[i] = count;
                }
            }
        }
        return res;
    }

    /**
     * @return for every document, the index of the topic with the largest
     *         document/topic count (first such topic on ties; 0 if no count is positive)
     */
    @Override
    public int[] assignedTopicsToDocs() {
        int[] res = new int[this.numDocs];
        for (int i = 0; i < this.numDocs; ++i) {
            int offset = i * this.numTopics;
            double max = 0.0;
            for (int j = 0; j < this.numTopics; ++j) {
                int count = this.doctopiccounts[offset + j];
                if (max < (double)count) {
                    max = count;
                    res[i] = j;
                }
            }
        }
        return res;
    }

    /** Intentionally empty entry point. */
    public static void main(String[] args) {
    }
}

