/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * Gibbs sampler for LDA whose topic-word weights are periodically re-estimated with a
 * word-graph regularizer.
 *
 * <p>The word graph is read from the inherited {@code Phi} lists (presumably adjacency lists of
 * word ids — confirm against the parent class) together with {@link #graphWeights}, and is
 * materialized into the dense {@code numWords x numWords} matrix {@link #S}. Every tenth sweep
 * after the first fifth of the iterations and before {@code 0.9 * maxIter}, each topic's word
 * weights are recomputed from the current word-topic counts plus a graph term scaled by
 * {@link #nu}. Sweeps from {@code 0.9 * maxIter} on contribute (with probability 0.5 each) to
 * the averaged distributions returned by {@link #wordDistribution()} and
 * {@link #documentDistribution()}.
 */
public class SamplersLDAWordRegularize
extends SamplersLDAWordFeatures {
    /**
     * Edge weights parallel to the inherited {@code Phi} lists: {@code graphWeights[j].get(k)} is
     * stored at {@code S[j][Phi[j].get(k)]}.
     */
    TDoubleArrayList[] graphWeights;
    /** Weight of the graph term in the topic-word re-estimation. */
    public double nu = 1.0;
    /** Number of fixed-point iterations per topic in each regularization step. */
    int reg_iter = 100;
    /** Dense word-by-word graph matrix, rebuilt at the start of {@link #GibbsSampling()}. */
    Array2DRowRealMatrix S = new Array2DRowRealMatrix();
    /** Running sum of document-topic estimates, indexed [topic][document]. */
    double[][] phi = null;
    /** Running sum of topic-word estimates, indexed [topic][word]. */
    double[][] theta = null;
    /** Number of samples accumulated into {@link #phi} and {@link #theta}. */
    int numStats = 0;

    /**
     * Initializes the sampler state and assigns every token a uniformly random topic.
     *
     * <p>Any statistics accumulated by a previous run are discarded, so the sampler can be
     * re-initialized with different dimensions.
     *
     * @param docIds document id of each token
     * @param wordIds word id of each token; its length defines the number of tokens
     * @param numTopics number of topics
     * @param numWords vocabulary size
     * @param numDocs number of documents
     * @param iter number of sweeps performed by {@link #GibbsSampling()}
     * @param beta topic-word prior; also the default value of every {@code p_v} entry
     * @param alpha symmetric document-topic prior
     */
    public void init(int[] docIds, int[] wordIds, int numTopics, int numWords, int numDocs, int iter, double beta, double alpha) {
        this.maxIter = iter;
        this.BETA = beta;
        this.ALPHA = alpha;
        this.numTokens = wordIds.length;
        this.numTopics = numTopics;
        this.numDocs = numDocs;
        this.numWords = numWords;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[numWords * numTopics];
        this.doctopiccounts = new int[numDocs * numTopics];
        this.topiccounts = new int[numTopics];
        this.words = wordIds;
        this.docs = docIds;
        this.b = new double[numTopics * numWords];
        this.paras = new double[numTopics * numWords];
        this.parameters = new double[numTopics][numWords];
        // Drop statistics from a previous run; their shape may not match the new dimensions.
        this.theta = null;
        this.phi = null;
        this.numStats = 0;
        if (this.p_v == null) {
            this.p_v = new double[numWords];
            for (int w = 0; w < numWords; ++w) {
                this.p_v[w] = this.BETA;
            }
        }
        for (int w = 0; w < numWords; ++w) {
            for (int t = 0; t < numTopics; ++t) {
                this.paras[w * numTopics + t] = 2.0 * Math.random() * this.LAMBDA - this.LAMBDA;
                this.parameters[t][w] = 2.0 * Math.random() * this.LAMBDA - this.LAMBDA;
                // p_v holds one entry per word, so it is indexed by the word, not the topic.
                this.b[w * numTopics + t] = Math.exp(this.parameters[t][w]) * this.p_v[w];
            }
        }
        Random rng = new Random();
        for (int i = 0; i < this.numTokens; ++i) {
            int topic = rng.nextInt(numTopics);
            this.topics[i] = topic;
            this.wordtopiccounts[this.words[i] * numTopics + topic]++;
            this.doctopiccounts[this.docs[i] * numTopics + topic]++;
            this.topiccounts[topic]++;
        }
    }

    /**
     * Returns the topic assigned to each token (indexed by token id) in the last sweep of
     * {@link #GibbsSampling()}. The internal array is returned, not a copy.
     */
    @Override
    public int[] getTokenToTopic() {
        return this.tokenToTopic;
    }

    /** Does nothing in this class; sampling is performed by {@link #GibbsSampling()}. */
    public void train() {
    }

    /**
     * Stores the regularization flag in the inherited {@code reg} field.
     *
     * <p>NOTE(review): this class's {@link #GibbsSampling()} does not read {@code reg} — confirm
     * whether it is meant to toggle the graph regularizer.
     */
    @Override
    public void setReg(boolean r) {
        this.reg = r;
    }

    /**
     * Returns the sum of all entries of {@code v}.
     *
     * @param v values to add; must not be null
     * @return the sum, or 0.0 for an empty array
     */
    public double sum(double[] v) {
        double total = 0.0;
        for (double x : v) {
            total += x;
        }
        return total;
    }

    /**
     * Runs {@code maxIter} sweeps over all tokens, resampling each token's topic.
     *
     * <p>A token's topic is drawn with weight {@code PhiWt[word][t] * (docCount[doc][t] + ALPHA)},
     * where {@code PhiWt} is the words-by-topics weight matrix. {@code PhiWt} is the constant
     * {@code BETA} until the first regularization step, after which it is the graph-regularized
     * estimate from {@link #regularizedTopicWordMatrix()}.
     */
    @Override
    public void GibbsSampling() {
        int burnIn = (int) ((double) this.maxIter * 0.9);
        this.WBETA = (double) this.numWords * this.BETA;
        this.WBETAs = this.computePriorMassPerTopic();
        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];
        // Tokens are visited in one fixed pseudo-random order (seeded for reproducibility).
        TIntArrayList order = new TIntArrayList(this.numTokens);
        for (int i = 0; i < this.numTokens; ++i) {
            order.add(i);
        }
        order.shuffle(new Random(2000L));
        this.buildSimilarityMatrix();
        // NOTE(review): until the first regularization step every weight is the constant BETA,
        // so early sweeps ignore the word counts entirely — confirm this is intended.
        RealMatrix phiWt = new Array2DRowRealMatrix(this.numWords, this.numTopics).scalarAdd(this.BETA);
        for (int iter = 0; iter < this.maxIter; ++iter) {
            if (iter > this.maxIter / 5 && iter % 10 == 0 && iter < burnIn) {
                phiWt = this.regularizedTopicWordMatrix();
            }
            for (int pos = 0; pos < this.numTokens; ++pos) {
                int token = order.get(pos);
                // Indexed by token id (not shuffle position) so callers can map tokens to topics.
                this.tokenToTopic[token] = this.resampleToken(token, phiWt);
            }
            this.clampNegativeCounts();
            if (iter >= burnIn && Math.random() < 0.5) {
                this.updateDistributions();
            }
        }
    }

    /** Returns, for each topic, the sum of {@code b} over all words. */
    private double[] computePriorMassPerTopic() {
        double[] mass = new double[this.numTopics];
        for (int w = 0; w < this.numWords; ++w) {
            for (int t = 0; t < this.numTopics; ++t) {
                mass[t] += this.b[w * this.numTopics + t];
            }
        }
        return mass;
    }

    /** Returns {@code WBETAs} if sampling has set it, otherwise computes it from {@code b}. */
    private double[] priorMassPerTopic() {
        return this.WBETAs != null ? this.WBETAs : this.computePriorMassPerTopic();
    }

    /** Rebuilds {@link #S} from the {@code Phi} lists and {@link #graphWeights}. */
    private void buildSimilarityMatrix() {
        this.S = new Array2DRowRealMatrix(this.numWords, this.numWords);
        if (this.Phi == null || this.graphWeights == null) {
            // No graph: S stays all-zero and regularization falls back to the raw counts.
            return;
        }
        for (int j = 0; j < this.numWords; ++j) {
            for (int k = 0; k < this.Phi[j].size(); ++k) {
                int w = this.Phi[j].get(k);
                double weight = this.graphWeights[j].get(k);
                this.S.setEntry(j, w, weight);
            }
        }
    }

    /** Builds the words-by-topics weight matrix, regularizing each topic column with the graph. */
    private RealMatrix regularizedTopicWordMatrix() {
        RealMatrix result = new Array2DRowRealMatrix(this.numWords, this.numTopics);
        for (int t = 0; t < this.numTopics; ++t) {
            RealVector counts = new ArrayRealVector(this.numWords);
            for (int w = 0; w < this.numWords; ++w) {
                counts.setEntry(w, this.wordtopiccounts[w * this.numTopics + t]);
            }
            result.setColumnVector(t, this.regularizeTopic(counts));
        }
        return result;
    }

    /**
     * Fixed-point iteration {@code p <- normalize(counts + 2*nu * p .* (S p) / (p' S p))},
     * started from the smoothed, normalized counts.
     *
     * @param counts word counts of one topic
     * @return an L1-normalized weight vector over words
     */
    private RealVector regularizeTopic(RealVector counts) {
        // Smoothing keeps every entry positive so the first normalization is well defined.
        RealVector current = counts.mapAdd(0.001);
        current = current.mapDivide(current.getL1Norm());
        for (int it = 0; it < this.reg_iter; ++it) {
            RealVector num = this.S.operate(current);
            double den = current.dotProduct(num);
            RealVector next = counts;
            // A zero denominator means the graph contributes nothing; dividing would yield NaN.
            if (den != 0.0) {
                next = counts.add(current.ebeMultiply(num).mapMultiply(2.0 * this.nu).mapDivide(den));
            }
            double norm = next.getL1Norm();
            if (norm == 0.0) {
                // Empty topic and no graph contribution: keep the last well-defined estimate.
                break;
            }
            current = next.mapDivide(norm);
        }
        return current;
    }

    /**
     * Removes {@code token} from the counts, draws its new topic and adds it back.
     *
     * @param token token id
     * @param phiWt words-by-topics weight matrix
     * @return the newly sampled topic
     */
    private int resampleToken(int token, RealMatrix phiWt) {
        int word = this.words[token];
        int wordOffset = word * this.numTopics;
        int docOffset = this.docs[token] * this.numTopics;
        int old = this.topics[token];
        this.topiccounts[old]--;
        this.wordtopiccounts[wordOffset + old]--;
        this.doctopiccounts[docOffset + old]--;
        double total = 0.0;
        for (int t = 0; t < this.numTopics; ++t) {
            this.probs[t] = phiWt.getEntry(word, t) * ((double) this.doctopiccounts[docOffset + t] + this.ALPHA);
            total += this.probs[t];
        }
        double r = total * Math.random();
        int topic = 0;
        double cumulative = this.probs[0];
        // Bounded so floating-point rounding can never step past the last topic.
        while (r > cumulative && topic < this.numTopics - 1) {
            cumulative += this.probs[++topic];
        }
        this.topics[token] = topic;
        this.wordtopiccounts[wordOffset + topic]++;
        this.doctopiccounts[docOffset + topic]++;
        this.topiccounts[topic]++;
        return topic;
    }

    /** Resets any negative word-topic or document-topic count to zero. */
    private void clampNegativeCounts() {
        for (int i = 0; i < this.numWords * this.numTopics; ++i) {
            if (this.wordtopiccounts[i] < 0) {
                this.wordtopiccounts[i] = 0;
            }
        }
        for (int i = 0; i < this.numDocs * this.numTopics; ++i) {
            if (this.doctopiccounts[i] < 0) {
                this.doctopiccounts[i] = 0;
            }
        }
    }

    /**
     * Returns {@code paras} reshaped to [topic][word], read in topic-major order.
     *
     * <p>NOTE(review): {@link #init} fills {@code paras} word-major ({@code word * numTopics + topic});
     * the values are i.i.d. random there, so this only matters if another component writes
     * {@code paras} with a specific layout — confirm.
     */
    @Override
    public double[][] getBetas() {
        double[][] betas = new double[this.numTopics][this.numWords];
        for (int t = 0; t < this.numTopics; ++t) {
            for (int w = 0; w < this.numWords; ++w) {
                betas[t][w] = this.paras[t * this.numWords + w];
            }
        }
        return betas;
    }

    /**
     * Returns, for each word, the topic with the highest word-topic count (0 if all counts are
     * zero; the first topic wins ties).
     */
    @Override
    public int[] assignedTopicsToWords() {
        int[] res = new int[this.numWords];
        for (int w = 0; w < this.numWords; ++w) {
            res[w] = this.argmaxTopic(this.wordtopiccounts, w);
        }
        return res;
    }

    /**
     * Returns, for each word, the fraction of its tokens assigned to its most frequent topic
     * (0.0 for words with no tokens).
     */
    @Override
    public double[] assignedTopicsToWordsProbs() {
        double[] res = new double[this.numWords];
        for (int w = 0; w < this.numWords; ++w) {
            int offset = w * this.numTopics;
            double total = 0.0;
            int max = 0;
            for (int t = 0; t < this.numTopics; ++t) {
                int count = this.wordtopiccounts[offset + t];
                total += count;
                max = Math.max(max, count);
            }
            // The original computed the total but returned the raw count; normalize it.
            res[w] = total > 0.0 ? max / total : 0.0;
        }
        return res;
    }

    /**
     * Returns, for each document, the topic with the highest document-topic count (0 if all
     * counts are zero; the first topic wins ties).
     */
    @Override
    public int[] assignedTopicsToDocs() {
        int[] res = new int[this.numDocs];
        for (int d = 0; d < this.numDocs; ++d) {
            res[d] = this.argmaxTopic(this.doctopiccounts, d);
        }
        return res;
    }

    /** Returns the topic with the strictly largest positive count in {@code row}, or 0. */
    private int argmaxTopic(int[] counts, int row) {
        int offset = row * this.numTopics;
        int best = 0;
        int max = 0;
        for (int t = 0; t < this.numTopics; ++t) {
            if (counts[offset + t] > max) {
                max = counts[offset + t];
                best = t;
            }
        }
        return best;
    }

    /** Adds the current word and document estimates to the running sums and counts the sample. */
    @Override
    public void updateDistributions() {
        ++this.numStats;
        this.updateWordDistribution();
        this.updateDocumentDistribution();
    }

    /** Adds the current topic-word estimate to {@link #theta}, allocating it on first use. */
    @Override
    public void updateWordDistribution() {
        if (this.theta == null) {
            this.theta = new double[this.numTopics][this.numWords];
        }
        double[] mass = this.priorMassPerTopic();
        for (int w = 0; w < this.numWords; ++w) {
            for (int t = 0; t < this.numTopics; ++t) {
                this.theta[t][w] += this.wordTopicEstimate(w, t, mass);
            }
        }
    }

    /** Adds the current document-topic estimate to {@link #phi}, allocating it on first use. */
    @Override
    public void updateDocumentDistribution() {
        if (this.phi == null) {
            this.phi = new double[this.numTopics][this.numDocs];
        }
        for (int d = 0; d < this.numDocs; ++d) {
            int total = this.docTokenCount(d);
            for (int t = 0; t < this.numTopics; ++t) {
                this.phi[t][d] += this.docTopicEstimate(d, t, total);
            }
        }
    }

    /**
     * Returns the topic-word distribution [topic][word]: the average of the collected samples if
     * any were collected, otherwise the estimate from the current counts.
     */
    @Override
    public double[][] wordDistribution() {
        double[][] res = new double[this.numTopics][this.numWords];
        if (this.numStats > 0) {
            for (int w = 0; w < this.numWords; ++w) {
                for (int t = 0; t < this.numTopics; ++t) {
                    res[t][w] = this.theta[t][w] / (double) this.numStats;
                }
            }
            return res;
        }
        double[] mass = this.priorMassPerTopic();
        for (int w = 0; w < this.numWords; ++w) {
            for (int t = 0; t < this.numTopics; ++t) {
                res[t][w] = this.wordTopicEstimate(w, t, mass);
            }
        }
        return res;
    }

    /**
     * Returns the document-topic distribution [topic][document]: the average of the collected
     * samples if any were collected, otherwise the estimate from the current counts.
     */
    @Override
    public double[][] documentDistribution() {
        double[][] res = new double[this.numTopics][this.numDocs];
        if (this.numStats > 0) {
            for (int d = 0; d < this.numDocs; ++d) {
                for (int t = 0; t < this.numTopics; ++t) {
                    res[t][d] = this.phi[t][d] / (double) this.numStats;
                }
            }
            return res;
        }
        for (int d = 0; d < this.numDocs; ++d) {
            int total = this.docTokenCount(d);
            for (int t = 0; t < this.numTopics; ++t) {
                res[t][d] = this.docTopicEstimate(d, t, total);
            }
        }
        return res;
    }

    /** (count(word, topic) + b[word, topic]) / (count(topic) + mass[topic]). */
    private double wordTopicEstimate(int w, int t, double[] mass) {
        int idx = w * this.numTopics + t;
        return ((double) this.wordtopiccounts[idx] + this.b[idx]) / ((double) this.topiccounts[t] + mass[t]);
    }

    /** (count(doc, topic) + ALPHA) / (docTokens + numTopics * ALPHA). */
    private double docTopicEstimate(int d, int t, int docTokens) {
        return ((double) this.doctopiccounts[d * this.numTopics + t] + this.ALPHA)
                / ((double) docTokens + (double) this.numTopics * this.ALPHA);
    }

    /** Returns the number of tokens currently counted for document {@code d} across all topics. */
    private int docTokenCount(int d) {
        int total = 0;
        for (int t = 0; t < this.numTopics; ++t) {
            total += this.doctopiccounts[d * this.numTopics + t];
        }
        return total;
    }

    /** Empty entry point. */
    public static void main(String[] args) {
    }
}

