/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import com.rapidminer.kobra.opt.MyOrthantWiseLimitedMemoryBFGS;
import com.rapidminer.kobra.topicmodels.MySparseGroupOptimizable;
import com.rapidminer.kobra.topicmodels.MySparseGroupWordFeatOptimizable;
import com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;

public class SamplersLDAGroupWordFeatures
extends SamplersLDAWordFeatures {
    double[] y_v;
    double[] y_kv;
    double[] probs;
    int[] tokenToTopic = null;
    int[][] groups = null;
    double a = 0.01;
    boolean prox = false;
    double[][] phi = null;
    double[][] theta = null;
    int numStats = 0;

    @Override
    public int[] getTokenToTopic() {
        return this.tokenToTopic;
    }

    public void train() {
    }

    public double s(double l, double p) {
        double tmp = Math.abs(l) - p;
        tmp = tmp < 0.0 ? 0.0 : tmp;
        return Math.signum(l) * tmp;
    }

    public double[] S(double[] dl, double p) {
        double[] tmp = new double[dl.length];
        for (int i = 0; i < tmp.length; ++i) {
            tmp[i] = this.s(dl[i], p);
        }
        return tmp;
    }

    @Override
    public void GibbsSampling() {
        int burnIn = (int)((double)this.maxIter * 0.9);
        int topic = 0;
        this.WBETA = (double)this.numWords * this.BETA;
        this.WBETAs = new double[this.numTopics];
        for (int i = 0; i < this.numTopics; ++i) {
            this.WBETAs[i] = 0.0;
            for (int j = 0; j < this.numWords; ++j) {
                int n = i;
                this.WBETAs[n] = this.WBETAs[n] + this.b[j * this.numTopics + i];
            }
        }
        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];
        TIntArrayList ll = new TIntArrayList(this.numTokens);
        for (int i = 0; i < this.numTokens; ++i) {
            ll.add(i);
        }
        ll.shuffle(new Random(2000L));
        for (int iter = 0; iter < this.maxIter; ++iter) {
            int i;
            if (iter > this.maxIter / 5 && iter % 10 == 0 && iter < burnIn) {
                MySparseGroupWordFeatOptimizable optimizable = new MySparseGroupWordFeatOptimizable(this.numTopics, this.numWords, ll.size(), this.rn);
                if (this.prox) {
                    optimizable = new MySparseGroupOptimizable(this.numTopics, this.numWords);
                }
                optimizable.Phi = this.Phi;
                optimizable.a = this.a;
                optimizable.groups = this.groups;
                if (this.p_v != null) {
                    optimizable.p_v = this.p_v;
                }
                optimizable.n_kv = this.wordtopiccounts;
                optimizable.n_k = this.topiccounts;
                optimizable.lambda = this.LAMBDA;
                optimizable.sigma = this.LAMBDA;
                this.parameters = optimizable.parameters;
                Optimizer optimizer = null;
                boolean converged = false;
                try {
                    TIntArrayList tt = new TIntArrayList();
                    for (int t = 0; t < this.numTopics; ++t) {
                        tt.add(t);
                    }
                    TIntArrayList gg = new TIntArrayList();
                    for (int i2 = 0; i2 < this.groups.length; ++i2) {
                        gg.add(i2);
                    }
                    for (int outer = 0; outer < 1; ++outer) {
                        for (int ti = 0; ti < this.numTopics; ++ti) {
                            int t = tt.get(ti);
                            gg.shuffle(new Random());
                            for (int id = 0; id < this.groups.length; ++id) {
                                int k;
                                int i3 = gg.get(id);
                                if (this.prox) {
                                    optimizable.currentGroup = i3;
                                    optimizable.currentK = t;
                                    try {
                                        optimizable.init(i3, t, this.rn);
                                        optimizer = new MyOrthantWiseLimitedMemoryBFGS(optimizable, this.GAMMA);
                                        ((MyOrthantWiseLimitedMemoryBFGS)optimizer).prox = true;
                                        ((MyOrthantWiseLimitedMemoryBFGS)optimizer).gl = true;
                                        ((MyOrthantWiseLimitedMemoryBFGS)optimizer).lambda = this.a;
                                        converged = optimizer.optimize(1);
                                    }
                                    catch (OptimizationException e) {
                                        e.printStackTrace();
                                    }
                                    continue;
                                }
                                double[] dl = new double[this.numTopics * this.numWords];
                                optimizable.getValueGradient(dl, i3, t);
                                double norm = 0.0;
                                for (k = 0; k < dl.length; ++k) {
                                    norm = dl[k] * dl[k];
                                }
                                if ((norm = Math.sqrt(norm)) < this.a * Math.sqrt(this.groups[i3].length) && norm >= 0.0) {
                                    this.paras = new double[this.numTopics * this.numWords];
                                    optimizable.getParameters(this.paras);
                                    for (k = 0; k < this.numTopics; ++k) {
                                        for (int j = 0; j < this.numWords; ++j) {
                                            this.parameters[k][j] = this.paras[k * this.numWords + j];
                                        }
                                    }
                                    for (int j = 0; j < this.groups[i3].length; ++j) {
                                        this.parameters[t][this.groups[i3][j]] = 0.0;
                                    }
                                    optimizable.parameters = this.parameters;
                                    continue;
                                }
                                optimizable.currentGroup = i3;
                                optimizable.currentK = t;
                                try {
                                    optimizable.init(i3, t, this.rn);
                                    optimizer = !this.reg ? new LimitedMemoryBFGS(optimizable) : new MyOrthantWiseLimitedMemoryBFGS(optimizable, this.GAMMA);
                                    converged = optimizer.optimize(1);
                                    continue;
                                }
                                catch (OptimizationException e) {
                                    e.printStackTrace();
                                }
                            }
                        }
                    }
                }
                catch (OptimizationException e) {
                    e.printStackTrace();
                }
                this.paras = new double[this.numTopics * this.numWords];
                this.parameters = optimizable.parameters;
                this.WBETA = 0.0;
                for (int i4 = 0; i4 < this.numTopics; ++i4) {
                    this.WBETAs[i4] = 0.0;
                }
                int numFeatures = this.numWords;
                for (int i5 = 0; i5 < this.numTopics; ++i5) {
                    for (int j = 0; j < this.numWords; ++j) {
                        this.paras[i5 * this.numWords + j] = this.parameters[i5][j];
                        this.b[j * this.numTopics + i5] = this.p_v != null ? Math.exp(this.parameters[i5][j]) * this.p_v[j] : Math.exp(this.parameters[i5][j]) * 1.0;
                        int n = i5;
                        this.WBETAs[n] = this.WBETAs[n] + this.b[j * this.numTopics + i5];
                    }
                }
                System.out.println("Convergence: " + converged);
            }
            for (int ii = 0; ii < this.numTokens; ++ii) {
                int i6 = ll.get(ii);
                int wi = this.words[i6];
                int di = this.docs[i6];
                int n = topic = this.topics[i6];
                this.topiccounts[n] = this.topiccounts[n] - 1;
                int wioffset = wi * this.numTopics;
                int dioffset = di * this.numTopics;
                int n2 = wioffset + topic;
                this.wordtopiccounts[n2] = this.wordtopiccounts[n2] - 1;
                int n3 = dioffset + topic;
                this.doctopiccounts[n3] = this.doctopiccounts[n3] - 1;
                double totprob = 0.0;
                for (int j = 0; j < this.numTopics; ++j) {
                    this.probs[j] = ((double)this.wordtopiccounts[wioffset + j] + this.b[wi * this.numTopics + j]) / ((double)this.topiccounts[j] + this.WBETAs[j]) * ((double)this.doctopiccounts[dioffset + j] + this.ALPHA);
                    totprob += this.probs[j];
                }
                double r = totprob * this.rn.nextDouble();
                topic = 0;
                for (double max = this.probs[0]; r > max; max += this.probs[++topic]) {
                }
                this.topics[i6] = topic;
                int n4 = wioffset + topic;
                this.wordtopiccounts[n4] = this.wordtopiccounts[n4] + 1;
                int n5 = dioffset + topic;
                this.doctopiccounts[n5] = this.doctopiccounts[n5] + 1;
                int n6 = topic;
                this.topiccounts[n6] = this.topiccounts[n6] + 1;
                this.tokenToTopic[ii] = topic;
            }
            for (i = 0; i < this.numWords * this.numTopics; ++i) {
                if (this.wordtopiccounts[i] >= 0) continue;
                this.wordtopiccounts[i] = 0;
            }
            for (i = 0; i < this.numDocs * this.numTopics; ++i) {
                if (this.doctopiccounts[i] >= 0) continue;
                this.doctopiccounts[i] = 0;
            }
            if (iter < burnIn || !(this.rn.nextDouble() < 0.5)) continue;
            this.updateDistributions();
        }
    }

    @Override
    public double[][] getBetas() {
        double[][] betas = new double[this.numTopics][this.numWords];
        for (int i = 0; i < this.numTopics; ++i) {
            for (int j = 0; j < this.numWords; ++j) {
                betas[i][j] = this.paras[i * this.numWords + j];
            }
        }
        return betas;
    }

    @Override
    public int[] assignedTopicsToWords() {
        int[] res = new int[this.numWords];
        double max = 0.0;
        for (int i = 0; i < this.numWords; ++i) {
            max = 0.0;
            for (int j = 0; j < this.numTopics; ++j) {
                if (!(max < (double)this.wordtopiccounts[i * this.numTopics + j])) continue;
                max = this.wordtopiccounts[i * this.numTopics + j];
                res[i] = j;
            }
        }
        return res;
    }

    @Override
    public double[] assignedTopicsToWordsProbs() {
        double[] res = new double[this.numWords];
        double max = 0.0;
        for (int i = 0; i < this.numWords; ++i) {
            max = 0.0;
            double s = 0.0;
            for (int j = 0; j < this.numTopics; ++j) {
                s += (double)this.wordtopiccounts[i * this.numTopics + j];
                if (!(max < (double)this.wordtopiccounts[i * this.numTopics + j])) continue;
                max = this.wordtopiccounts[i * this.numTopics + j];
                res[i] = this.wordtopiccounts[i * this.numTopics + j];
            }
        }
        return res;
    }

    @Override
    public int[] assignedTopicsToDocs() {
        int[] res = new int[this.numDocs];
        double max = 0.0;
        for (int i = 0; i < this.numDocs; ++i) {
            max = 0.0;
            for (int j = 0; j < this.numTopics; ++j) {
                if (!(max < (double)this.doctopiccounts[i * this.numTopics + j])) continue;
                max = this.doctopiccounts[i * this.numTopics + j];
                res[i] = j;
            }
        }
        return res;
    }

    @Override
    public void updateDistributions() {
        ++this.numStats;
        this.updateWordDistribution();
        this.updateDocumentDistribution();
    }

    @Override
    public void updateWordDistribution() {
        if (this.theta == null) {
            this.theta = new double[this.numTopics][this.numWords];
        }
        for (int i = 0; i < this.numWords; ++i) {
            for (int j = 0; j < this.numTopics; ++j) {
                double[] dArray = this.theta[j];
                int n = i;
                dArray[n] = dArray[n] + ((double)this.wordtopiccounts[i * this.numTopics + j] + this.b[i * this.numTopics + j]) / ((double)this.topiccounts[j] + this.WBETAs[j]);
            }
        }
    }

    @Override
    public void updateDocumentDistribution() {
        if (this.phi == null) {
            this.phi = new double[this.numTopics][this.numDocs];
        }
        for (int i = 0; i < this.numDocs; ++i) {
            int j;
            int docCounts = 0;
            for (j = 0; j < this.numTopics; ++j) {
                docCounts += this.doctopiccounts[i * this.numTopics + j];
            }
            for (j = 0; j < this.numTopics; ++j) {
                double[] dArray = this.phi[j];
                int n = i;
                dArray[n] = dArray[n] + ((double)this.doctopiccounts[i * this.numTopics + j] + this.ALPHA) / ((double)docCounts + (double)this.numTopics * this.ALPHA);
            }
        }
    }

    @Override
    public double[][] wordDistribution() {
        double[][] res = new double[this.numTopics][this.numWords];
        if (this.numStats > 0) {
            for (int i = 0; i < this.numWords; ++i) {
                for (int j = 0; j < this.numTopics; ++j) {
                    res[j][i] = this.theta[j][i] / (double)this.numStats;
                }
            }
            return res;
        }
        for (int i = 0; i < this.numWords; ++i) {
            for (int j = 0; j < this.numTopics; ++j) {
                res[j][i] = ((double)this.wordtopiccounts[i * this.numTopics + j] + this.b[i * this.numTopics + j]) / ((double)this.topiccounts[j] + this.WBETAs[j]);
            }
        }
        return res;
    }

    @Override
    public double[][] documentDistribution() {
        double[][] res = new double[this.numTopics][this.numDocs];
        if (this.numStats > 0) {
            for (int i = 0; i < this.numDocs; ++i) {
                for (int j = 0; j < this.numTopics; ++j) {
                    res[j][i] = this.phi[j][i] / (double)this.numStats;
                }
            }
            return res;
        }
        for (int i = 0; i < this.numDocs; ++i) {
            int j;
            int docCounts = 0;
            for (j = 0; j < this.numTopics; ++j) {
                docCounts += this.doctopiccounts[i * this.numTopics + j];
            }
            for (j = 0; j < this.numTopics; ++j) {
                res[j][i] = ((double)this.doctopiccounts[i * this.numTopics + j] + this.ALPHA) / ((double)docCounts + (double)this.numTopics * this.ALPHA);
            }
        }
        return res;
    }

    public static void main(String[] args) {
    }
}

