/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.OptimizationException;
import com.rapidminer.kobra.opt.MyOrthantWiseLimitedMemoryBFGS;
import com.rapidminer.kobra.topicmodels.MyHashGompertzOptimizable;
import com.rapidminer.kobra.topicmodels.SamplersHLDA;
import com.rapidminer.tools.RandomGenerator;
import gnu.trove.TIntArrayList;
import gnu.trove.map.hash.TDoubleIntHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import java.util.ArrayList;
import java.util.Random;

/**
 * Hierarchical topic sampler (built on {@link SamplersHLDA}) in which every tree node carries a
 * two-parameter time density ({@code pGompertz}). A token's level assignment is weighted by that
 * density evaluated at the document's timestamp.
 *
 * <p>After each Gibbs sweep, the density parameters of every node are re-fitted by
 * {@link MyOrthantWiseLimitedMemoryBFGS}. The fit uses the histogram of document times
 * ({@code hsValues}) collected for that node during the sweep.
 */
public class SamplersGompertzHLDA extends SamplersHLDA {
    /** Per-document timestamp, indexed by document id. */
    double[] times;
    /** Fitted density parameters per enumerated node, rebuilt after every sweep (see getPi). */
    double[][] pGompertz;
    /** Not read or written by this class. */
    double[][] pi;
    /** Nodes of the current tree in pre-order, as collected by numerateNodes. */
    ArrayList<SamplersHLDA.NCRPNode> allCurrentNodes = null;
    /** Node count produced by countNodes. */
    int size = 0;
    /** Lazily allocated by getPredictions; nothing in this class fills it. */
    double[] predictions = null;

    /**
     * Evaluates the shifted Gompertz density
     * {@code f(x) = b * exp(-(b*x + a*exp(-b*x))) * (1 + a*(1 - exp(-b*x)))}.
     *
     * @param x point at which to evaluate (here: a document time)
     * @param a shape parameter
     * @param b scale parameter
     * @return the density value at {@code x}
     */
    public static double dist(double x, double a, double b) {
        // exp(-b*x) appears twice in the formula; computing it once gives the identical value.
        double decay = Math.exp(-b * x);
        return b * Math.exp(-(b * x + a * decay)) * (1.0 + a * (1.0 - decay));
    }

    /**
     * Resamples the level of every token of {@code doc} along the document's current path.
     *
     * <p>The weight of a level is {@code (alpha + docCount) * (eta + wordCount) / (etaSum + nodeTokens)},
     * multiplied by the node's time density at {@code times[doc]}. Every sampled assignment also
     * records the document time in the chosen node's {@code hsValues} histogram.
     *
     * @param doc index of the document to resample
     */
    @Override
    public void sampleTopics(int doc) {
        final int[] sequence = this.sequences[doc];
        final int seqLen = sequence.length;
        final int[] docLevels = this.levels[doc];
        final double time = this.times[doc];

        // Walk from the document's leaf up to the root to recover its path.
        SamplersHLDA.NCRPNode[] path = new SamplersHLDA.NCRPNode[this.numLevels];
        SamplersHLDA.NCRPNode node = this.documentLeaves[doc];
        for (int level = this.numLevels - 1; level >= 0; --level) {
            path[level] = node;
            node = node.parent;
        }

        // The time factor depends only on the document time and each node's parameters.
        // Neither changes inside this method, so evaluate it once per level rather than per token.
        double[] timeDensity = new double[this.numLevels];
        for (int level = 0; level < this.numLevels; ++level) {
            double[] params = path[level].pGompertz;
            timeDensity[level] = SamplersGompertzHLDA.dist(time, params[0], params[1]);
        }

        int[] levelCounts = new int[this.numLevels];
        for (int token = 0; token < seqLen; ++token) {
            levelCounts[docLevels[token]]++;
        }

        double[] levelWeights = new double[this.numLevels];
        for (int token = 0; token < seqLen; ++token) {
            int type = sequence[token];

            // Remove the token's current assignment from all counts.
            int oldLevel = docLevels[token];
            levelCounts[oldLevel]--;
            node = path[oldLevel];
            node.typeCounts[type]--;
            node.totalTokens--;

            double sum = 0.0;
            for (int level = 0; level < this.numLevels; ++level) {
                levelWeights[level] = (this.alpha + (double) levelCounts[level])
                        * (this.eta + (double) path[level].typeCounts[type])
                        / (this.etaSum + (double) path[level].totalTokens)
                        * timeDensity[level];
                sum += levelWeights[level];
            }
            if (sum > 0.0) {
                for (int level = 0; level < this.numLevels; ++level) {
                    levelWeights[level] /= sum;
                }
            } else {
                // Every weight underflowed to 0 (or one is NaN). Normalising would yield NaNs,
                // so fall back to choosing a level uniformly.
                double uniform = 1.0 / this.numLevels;
                for (int level = 0; level < this.numLevels; ++level) {
                    levelWeights[level] = uniform;
                }
            }

            // NOTE(review): this uses RapidMiner's global generator, not the seeded this.rn,
            // so level sampling is not controlled by the seed passed to init.
            int newLevel = RandomGenerator.getGlobalRandomGenerator().randomIndex(levelWeights);
            docLevels[token] = newLevel;
            levelCounts[newLevel]++;
            node = path[newLevel];
            node.typeCounts[type]++;
            node.totalTokens++;

            // Record one more occurrence of this document time at the chosen node.
            node.hsValues.adjustOrPutValue(time, 1, 1);
        }
    }

    /**
     * Initialises counts, token sequences and a random initial tree.
     *
     * <p>The call order of the random generator matches the historical implementation, so seeded
     * runs draw the same sequence as before.
     *
     * @param docIds document index of each token
     * @param wordIds word index of each token
     * @param times timestamp of each document, indexed by document
     * @param numTopics number of flat topics used for the initial random topic assignment
     * @param numWords vocabulary size
     * @param numDocs number of documents
     * @param iter number of Gibbs sweeps performed by GibbsSampling
     * @param beta word smoothing (also used as {@code eta})
     * @param alpha level smoothing
     * @param locSeed whether to seed the random generator with {@code seed}
     * @param seed seed used when {@code locSeed} is true
     */
    public void init(int[] docIds, int[] wordIds, double[] times, int numTopics, int numWords, int numDocs,
            int iter, double beta, double alpha, boolean locSeed, int seed) {
        this.times = times;
        this.maxIter = iter;
        this.BETA = beta;
        this.ALPHA = alpha;
        this.alpha = alpha;
        this.eta = beta;
        this.etaSum = (double) numWords * beta;
        this.numTokens = wordIds.length;
        this.numTopics = numTopics;
        this.numDocs = numDocs;
        this.numWords = numWords;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[numWords * numTopics];
        this.doctopiccounts = new int[numDocs * numTopics];
        this.topiccounts = new int[numTopics];
        this.words = wordIds;
        this.docs = docIds;
        if (locSeed) {
            this.seed = seed;
            this.rn = new Random(seed);
        } else {
            this.rn = new Random();
        }
        this.levels = new int[numDocs][];
        this.documentLeaves = new SamplersHLDA.NCRPNode[numDocs];
        this.sequences = new int[numDocs][];
        // assumes this.numLevels has already been configured on the superclass.
        SamplersHLDA.NCRPNode[] path = new SamplersHLDA.NCRPNode[this.numLevels];
        this.rootNode = new SamplersHLDA.NCRPNode(this, numWords);

        // Group tokens by document and draw a random flat topic for every token.
        TIntArrayList[] tmpTokens = new TIntArrayList[numDocs];
        for (int d = 0; d < numDocs; ++d) {
            tmpTokens[d] = new TIntArrayList();
        }
        for (int i = 0; i < wordIds.length; ++i) {
            int wi = this.words[i];
            int di = this.docs[i];
            tmpTokens[di].add(wi);
            int topic = this.rn.nextInt(numTopics);
            this.topics[i] = topic;
            this.wordtopiccounts[wi * numTopics + topic]++;
            this.doctopiccounts[di * numTopics + topic]++;
            this.topiccounts[topic]++;
        }
        for (int d = 0; d < numDocs; ++d) {
            tmpTokens[d].shuffle(this.rn);
            this.sequences[d] = tmpTokens[d].toNativeArray();
        }

        // Give every document a root-to-leaf path and a random level for each of its tokens.
        for (int doc = 0; doc < numDocs; ++doc) {
            int seqLen = this.sequences[doc].length;
            path[0] = this.rootNode;
            ++this.rootNode.customers;
            for (int level = 1; level < this.numLevels; ++level) {
                path[level] = path[level - 1].select();
                ++path[level].customers;
            }
            this.node = path[this.numLevels - 1];
            this.levels[doc] = new int[seqLen];
            this.documentLeaves[doc] = this.node;
            for (int token = 0; token < seqLen; ++token) {
                int type = this.sequences[doc][token];
                this.levels[doc][token] = this.rn.nextInt(this.numLevels);
                this.node = path[this.levels[doc][token]];
                ++this.node.totalTokens;
                this.node.typeCounts[type]++;
            }
        }
        this.allCurrentNodes = new ArrayList<SamplersHLDA.NCRPNode>();
        this.NodeIdToTopic = new TIntIntHashMap();
        // Diagnostic output kept from the original implementation.
        System.out.println(this.totalNodes);
    }

    /**
     * Adds the number of nodes in the subtree rooted at {@code node} to {@code size}.
     *
     * @param node subtree root
     */
    @Override
    public void countNodes(SamplersHLDA.NCRPNode node) {
        ++this.size;
        for (SamplersHLDA.NCRPNode child : node.children) {
            this.countNodes(child);
        }
    }

    /**
     * Appends the subtree rooted at {@code node} to {@code allCurrentNodes} in pre-order. It also
     * copies every child's {@code hsValues} entries into its parent.
     *
     * <p>NOTE(review): {@code putAll} replaces the parent's count for a timestamp both maps contain;
     * it does not add the two counts. In GibbsSampling this runs after the previous sweep has cleared
     * every histogram, so the merge appears to have no effect. Confirm the intent before relying on it.
     *
     * @param node subtree root
     */
    public void numerateNodes(SamplersHLDA.NCRPNode node) {
        this.allCurrentNodes.add(node);
        for (SamplersHLDA.NCRPNode child : node.children) {
            this.numerateNodes(child);
            node.hsValues.putAll(child.hsValues);
        }
    }

    /**
     * Runs {@code maxIter} Gibbs sweeps. Each sweep performs these steps:
     * <ol>
     *   <li>resample every document's path;</li>
     *   <li>enumerate the tree;</li>
     *   <li>resample every token's level;</li>
     *   <li>re-fit each node's time-density parameters from the times it collected;</li>
     *   <li>clear the per-node time histograms.</li>
     * </ol>
     * Note that {@code numTopics} is overwritten with the current number of tree nodes.
     */
    @Override
    public void GibbsSampling() {
        for (int iter = 0; iter < this.maxIter; ++iter) {
            for (int doc = 0; doc < this.numDocs; ++doc) {
                this.samplePath(doc, iter);
            }
            this.next = 0;
            this.allCurrentNodes = new ArrayList<SamplersHLDA.NCRPNode>();
            this.numerateNodes(this.rootNode);
            this.numTopics = this.allCurrentNodes.size();
            for (int doc = 0; doc < this.numDocs; ++doc) {
                this.sampleTopics(doc);
            }

            this.pGompertz = new double[this.numTopics][2];
            for (int i = 0; i < this.numTopics; ++i) {
                SamplersHLDA.NCRPNode node = this.allCurrentNodes.get(i);
                MyHashGompertzOptimizable opt = new MyHashGompertzOptimizable();
                opt.vals = node.hsValues;
                // Random starting values, drawn in a fixed order so seeded runs stay reproducible.
                opt.alpha = this.rn.nextDouble();
                opt.beta = this.rn.nextDouble();
                opt.a = this.rn.nextDouble();
                opt.b = this.rn.nextDouble();
                MyOrthantWiseLimitedMemoryBFGS optimizer = new MyOrthantWiseLimitedMemoryBFGS(opt);
                try {
                    optimizer.optimize(100);
                } catch (OptimizationException e) {
                    // Best effort: keep whatever parameters the optimizer reached before failing.
                    e.printStackTrace();
                }
                double[] paras = new double[2];
                opt.getParameters(paras);
                // The optimiser's parameters are exponentiated before being stored on the node.
                node.pGompertz[0] = Math.exp(paras[0]);
                node.pGompertz[1] = Math.exp(paras[1]);
                // Row i aliases the node's own array (no copy).
                this.pGompertz[i] = node.pGompertz;
                // Reset the time histogram so the next sweep collects fresh counts.
                opt.vals.clear();
            }
            this.size = 0;
            this.countNodes(this.rootNode);
            // Diagnostic output kept from the original implementation.
            System.out.println(this.size);
        }
    }

    /**
     * Returns the fitted density parameters from the last sweep.
     *
     * @return one {@code {a, b}} row per enumerated node; the rows alias the nodes' arrays
     */
    public double[][] getPi() {
        return this.pGompertz;
    }

    /**
     * Returns the prediction array, allocating a zero-filled one of length {@code numDocs} on first use.
     * Nothing visible in this class writes to it.
     *
     * @return the per-document prediction array
     */
    public double[] getPredictions() {
        if (this.predictions == null) {
            this.predictions = new double[this.numDocs];
        }
        return this.predictions;
    }

    /** Tree node that additionally carries a time histogram and time-density parameters. */
    class GNCRPNode
    extends SamplersHLDA.NCRPNode {
        public GNCRPNode(SamplersHLDA.NCRPNode parent, int dimensions, int level) {
            super(SamplersGompertzHLDA.this, parent, dimensions, level);
            this.customers = 0;
            this.parent = parent;
            this.children = new ArrayList();
            this.level = level;
            this.totalTokens = 0;
            this.typeCounts = new int[dimensions];
            this.nodeID = SamplersGompertzHLDA.this.totalNodes++;
            this.hsValues = new TDoubleIntHashMap();
            // Draw initial parameters from the sampler's (possibly seeded) generator so that runs
            // with locSeed are reproducible. Math.random() is only a fallback for use before init.
            Random r = SamplersGompertzHLDA.this.rn;
            this.pGompertz = r != null
                    ? new double[]{r.nextDouble(), r.nextDouble()}
                    : new double[]{Math.random(), Math.random()};
        }

        public GNCRPNode(int dimensions) {
            super(SamplersGompertzHLDA.this, dimensions);
        }
    }
}

