/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.OptimizationException;
import com.rapidminer.kobra.opt.MyOrthantWiseLimitedMemoryBFGS;
import com.rapidminer.kobra.topicmodels.MyHashGompertzOptimizable;
import com.rapidminer.kobra.topicmodels.SamplersHLDA;
import com.rapidminer.tools.RandomGenerator;
import gnu.trove.TIntArrayList;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;

/**
 * Time-aware variant of the hierarchical LDA sampler {@link SamplersHLDA}.
 *
 * <p>Every document carries a timestamp ({@code times[doc]}). When paths and levels are sampled,
 * each node contributes a time factor:
 * <ul>
 *   <li>Nodes with {@code type == 1} use {@code getTimeWeight(time)}.</li>
 *   <li>All other nodes, and type-1 nodes whose weight is exactly zero, use the uniform density
 *       {@code 1 / maxTime}.</li>
 * </ul>
 *
 * <p>After every Gibbs sweep except the last, a Gompertz model is fitted to each node's collected
 * timestamps. The fitted model is kept ({@code type = 1}) only if it scores higher than the
 * uniform log-likelihood.
 */
public class SamplersGHLDA extends SamplersHLDA {
    /** Timestamp of each document, indexed by document id. */
    double[] times = null;
    /** Largest timestamp among documents with at least one token (set in init). */
    double maxTime = 0.0;
    /** Nodes collected in pre-order by {@link #numerateNodes} after every sweep. */
    ArrayList<SamplersHLDA.NCRPNode> allCurrentNodes = null;
    /** Gompertz parameters per enumerated node; entries alias the nodes' own arrays. */
    double[][] pGompertz = null;
    /** Lazily allocated per-document prediction array; see {@link #getPredictions()}. */
    double[] predictions = null;

    /**
     * Adds {@code prob} (a log-space term) to the accumulated weight of {@code start}.
     *
     * @param prob        value added to the node's current weight
     * @param start       node whose weight is adjusted
     * @param nodeWeights per-node log weights being accumulated
     */
    public void addProbs(double prob, SamplersHLDA.NCRPNode start, TObjectDoubleHashMap<SamplersHLDA.NCRPNode> nodeWeights) {
        nodeWeights.adjustValue(start, prob);
    }

    /**
     * Returns the time factor of {@code node} for a document with timestamp {@code time}.
     *
     * <p>Type-1 nodes use their own time weight. The uniform density {@code 1 / maxTime} is the
     * fallback for every other node, and for a type-1 node whose weight is exactly zero.
     */
    private double timePrior(SamplersHLDA.NCRPNode node, double time) {
        if (node.type == 1) {
            double weight = node.getTimeWeight(time);
            if (weight != 0.0) {
                return weight;
            }
        }
        return 1.0 / this.maxTime;
    }

    /**
     * Resamples the tree path of document {@code doc}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Detach the document from its current path.</li>
     *   <li>Score every candidate node: nested-CRP prior + word likelihood + time log-density.</li>
     *   <li>Draw a node. If it is not a leaf, extend it with {@code getNewLeaf()}.</li>
     *   <li>Re-add the document's word counts and timestamp along the new path.</li>
     * </ol>
     *
     * @param doc       document index
     * @param iteration current Gibbs iteration, forwarded to {@code calculateWordLikelihood}
     */
    @Override
    public void samplePath(int doc, int iteration) {
        double time = this.times[doc];

        // Remember the current path (root at index 0) before detaching the document from it.
        SamplersHLDA.NCRPNode[] path = new SamplersHLDA.NCRPNode[this.numLevels];
        SamplersHLDA.NCRPNode node = this.documentLeaves[doc];
        for (int level = this.numLevels - 1; level >= 0; --level) {
            path[level] = node;
            node = node.parent;
        }
        this.documentLeaves[doc].dropPath();

        // Nested-CRP prior weight for every node.
        TObjectDoubleHashMap<SamplersHLDA.NCRPNode> nodeWeights = new TObjectDoubleHashMap<SamplersHLDA.NCRPNode>();
        this.calculateNCRP(nodeWeights, this.rootNode, 0.0);

        // Count the document's word types per level and remove them from the old path.
        TIntIntHashMap[] typeCounts = new TIntIntHashMap[this.numLevels];
        for (int level = 0; level < this.numLevels; ++level) {
            typeCounts[level] = new TIntIntHashMap();
        }
        int[] docLevels = this.levels[doc];
        for (int token = 0; token < docLevels.length; ++token) {
            int level = docLevels[token];
            int type = this.sequences[doc][token];
            if (!typeCounts[level].containsKey(type)) {
                typeCounts[level].put(type, 1);
            } else {
                typeCounts[level].increment(type);
            }
            path[level].typeCounts[type]--;
            assert path[level].typeCounts[type] >= 0;
            path[level].totalTokens--;
            assert path[level].totalTokens >= 0;
        }

        // Log-likelihood of the document's words under an empty topic at each level >= 1.
        double[] newTopicWeights = new double[this.numLevels];
        for (int level = 1; level < this.numLevels; ++level) {
            int totalTokens = 0;
            for (int t : typeCounts[level].keys()) {
                int count = typeCounts[level].get(t);
                for (int i = 0; i < count; ++i) {
                    newTopicWeights[level] += Math.log((this.eta + i) / (this.etaSum + totalTokens));
                    ++totalTokens;
                }
            }
        }
        this.calculateWordLikelihood(nodeWeights, this.rootNode, 0.0, typeCounts, newTopicWeights, 0, iteration);

        // Add each candidate's time log-density for this document's timestamp.
        SamplersHLDA.NCRPNode[] nodes = nodeWeights.keys(new SamplersHLDA.NCRPNode[0]);
        for (SamplersHLDA.NCRPNode candidate : nodes) {
            this.addProbs(Math.log(this.timePrior(candidate, time)), candidate, nodeWeights);
        }

        // Shift log weights so the largest becomes exp(0) = 1. This avoids underflow before sampling.
        double max = Double.NEGATIVE_INFINITY;
        for (SamplersHLDA.NCRPNode candidate : nodes) {
            double w = nodeWeights.get(candidate);
            if (w > max) {
                max = w;
            }
        }
        double[] weights = new double[nodes.length];
        for (int i = 0; i < nodes.length; ++i) {
            weights[i] = Math.exp(nodeWeights.get(nodes[i]) - max);
        }

        node = nodes[RandomGenerator.getGlobalRandomGenerator().randomIndex(weights)];
        this.globalSum = 0.0;
        if (!node.isLeaf()) {
            // An internal node was drawn: extend it down to a new leaf.
            ++this.globalNew;
            node = node.getNewLeaf();
        }
        node.addPath();
        this.documentLeaves[doc] = node;

        // Re-add the document's word counts and timestamp along the new path, leaf to root.
        for (int level = this.numLevels - 1; level >= 0; --level) {
            node.addTime(time);
            TIntIntHashMap counts = typeCounts[level];
            for (int t : counts.keys()) {
                int c = counts.get(t);
                node.typeCounts[t] += c;
                node.totalTokens += c;
            }
            node = node.parent;
        }
    }

    /**
     * Resamples the level assignment of every token in document {@code doc}.
     *
     * <p>Each level's weight is the product of:
     * <ul>
     *   <li>the node's time factor;</li>
     *   <li>{@code (alpha + tokens at that level)};</li>
     *   <li>the node's smoothed word probability.</li>
     * </ul>
     *
     * @param doc document index
     */
    @Override
    public void sampleTopics(int doc) {
        int seqLen = this.sequences[doc].length;
        int[] docLevels = this.levels[doc];
        double time = this.times[doc];

        SamplersHLDA.NCRPNode[] path = new SamplersHLDA.NCRPNode[this.numLevels];
        SamplersHLDA.NCRPNode node = this.documentLeaves[doc];
        for (int level = this.numLevels - 1; level >= 0; --level) {
            path[level] = node;
            node = node.parent;
        }

        int[] levelCounts = new int[this.numLevels];
        for (int token = 0; token < seqLen; ++token) {
            levelCounts[docLevels[token]]++;
        }

        double[] levelWeights = new double[this.numLevels];
        for (int token = 0; token < seqLen; ++token) {
            int type = this.sequences[doc][token];

            // Remove the token from its current level.
            int oldLevel = docLevels[token];
            levelCounts[oldLevel]--;
            node = path[oldLevel];
            node.typeCounts[type]--;
            node.totalTokens--;

            double sum = 0.0;
            for (int level = 0; level < this.numLevels; ++level) {
                SamplersHLDA.NCRPNode candidate = path[level];
                levelWeights[level] = this.timePrior(candidate, time)
                        * (this.alpha + levelCounts[level])
                        * (this.eta + candidate.typeCounts[type])
                        / (this.etaSum + candidate.totalTokens);
                sum += levelWeights[level];
            }
            for (int level = 0; level < this.numLevels; ++level) {
                levelWeights[level] /= sum;
            }

            // Add the token at the newly drawn level.
            int newLevel = RandomGenerator.getGlobalRandomGenerator().randomIndex(levelWeights);
            docLevels[token] = newLevel;
            levelCounts[newLevel]++;
            node = path[newLevel];
            node.addTime(time);
            node.typeCounts[type]++;
            node.totalTokens++;
        }
    }

    /**
     * Initializes the sampler state and draws an initial path and level for every token.
     *
     * @param docIds    document index of each token (parallel to {@code wordIds})
     * @param wordIds   word index of each token
     * @param times     timestamp per document, indexed by document id
     * @param numTopics initial topic count used to size the flat count arrays
     * @param numWords  vocabulary size
     * @param numDocs   number of documents
     * @param iter      number of Gibbs sweeps run by {@link #GibbsSampling()}
     * @param beta      word smoothing, also used as {@code eta}
     * @param alpha     level smoothing
     * @param locSeed   if true, seed the local {@code rn} generator with {@code seed}
     * @param seed      seed used when {@code locSeed} is true
     */
    public void init(int[] docIds, int[] wordIds, double[] times, int numTopics, int numWords, int numDocs, int iter, double beta, double alpha, boolean locSeed, int seed) {
        this.times = times;
        this.maxIter = iter;
        this.BETA = beta;
        this.ALPHA = alpha;
        this.alpha = alpha;
        this.eta = beta;
        this.etaSum = (double) numWords * beta;
        this.numTokens = wordIds.length;
        this.numTopics = numTopics;
        this.numDocs = numDocs;
        this.numWords = numWords;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[numWords * numTopics];
        this.doctopiccounts = new int[numDocs * numTopics];
        this.topiccounts = new int[numTopics];
        this.words = wordIds;
        this.docs = docIds;
        if (locSeed) {
            this.seed = seed;
            this.rn = new Random(seed);
        } else {
            this.rn = new Random();
        }
        this.levels = new int[numDocs][];
        this.documentLeaves = new SamplersHLDA.NCRPNode[numDocs];
        this.sequences = new int[numDocs][];
        SamplersHLDA.NCRPNode[] path = new SamplersHLDA.NCRPNode[this.numLevels];
        this.rootNode = new SamplersHLDA.NCRPNode(this, numWords);

        // Group token word ids by document.
        TIntArrayList[] tmpTokens = new TIntArrayList[numDocs];
        for (int d = 0; d < numDocs; ++d) {
            tmpTokens[d] = new TIntArrayList();
        }
        for (int i = 0; i < wordIds.length; ++i) {
            tmpTokens[this.docs[i]].add(this.words[i]);
        }
        for (int d = 0; d < numDocs; ++d) {
            this.sequences[d] = tmpTokens[d].toNativeArray();
        }

        // Draw an initial path per document and a uniformly random level per token.
        for (int doc = 0; doc < numDocs; ++doc) {
            int seqLen = this.sequences[doc].length;
            path[0] = this.rootNode;
            ++this.rootNode.customers;
            for (int level = 1; level < this.numLevels; ++level) {
                path[level] = path[level - 1].select();
                ++path[level].customers;
            }
            this.node = path[this.numLevels - 1];
            this.levels[doc] = new int[seqLen];
            this.documentLeaves[doc] = this.node;
            for (int token = 0; token < seqLen; ++token) {
                int type = this.sequences[doc][token];
                int level = this.rn.nextInt(this.numLevels);
                this.levels[doc][token] = level;
                this.node = path[level];
                this.node.addTime(times[doc]);
                if (times[doc] > this.maxTime) {
                    this.maxTime = times[doc];
                }
                ++this.node.totalTokens;
                this.node.typeCounts[type]++;
            }
        }
        System.out.println(this.totalNodes);
    }

    /** Appends {@code node} and all its descendants, in pre-order, to {@link #allCurrentNodes}. */
    public void numerateNodes(SamplersHLDA.NCRPNode node) {
        this.allCurrentNodes.add(node);
        for (SamplersHLDA.NCRPNode child : node.children) {
            this.numerateNodes(child);
        }
    }

    /**
     * Finds the node in the subtree of {@code node} at which the cumulative weight reaches {@code r}.
     *
     * <p>Nodes are visited in pre-order, and each node's weight is added to a running total that
     * starts at {@code sum}. A {@code null} {@code sum} is treated as 0. The running total is shared
     * across siblings, so every node in the subtree can be selected, not just the first branch.
     *
     * @param node    root of the subtree to search
     * @param r       target cumulative weight
     * @param sum     weight already accumulated before {@code node}
     * @param weights weight per node; missing nodes contribute whatever {@code weights.get} returns
     * @return the first node whose cumulative weight is {@code >= r}, or {@code null} if the whole
     *         subtree's weight is smaller than {@code r}
     */
    public SamplersHLDA.NCRPNode sampleNode(SamplersHLDA.NCRPNode node, double r, Double sum, TObjectDoubleHashMap<SamplersHLDA.NCRPNode> weights) {
        double[] cumulative = {sum == null ? 0.0 : sum.doubleValue()};
        return this.sampleNodeRec(node, r, cumulative, weights);
    }

    /** Pre-order cumulative search behind {@link #sampleNode}; {@code cumulative[0]} is the running total. */
    private SamplersHLDA.NCRPNode sampleNodeRec(SamplersHLDA.NCRPNode node, double r, double[] cumulative, TObjectDoubleHashMap<SamplersHLDA.NCRPNode> weights) {
        cumulative[0] += weights.get(node);
        if (r <= cumulative[0]) {
            return node;
        }
        for (SamplersHLDA.NCRPNode child : node.children) {
            SamplersHLDA.NCRPNode hit = this.sampleNodeRec(child, r, cumulative, weights);
            if (hit != null) {
                return hit;
            }
        }
        return null;
    }

    /**
     * Runs {@code maxIter} Gibbs sweeps.
     *
     * <p>Each sweep:
     * <ol>
     *   <li>Resamples all document paths, then all token levels.</li>
     *   <li>Prunes unused nodes and re-enumerates the tree.</li>
     *   <li>Refits each node's time model, except on the final sweep.</li>
     * </ol>
     */
    @Override
    public void GibbsSampling() {
        for (int iter = 0; iter < this.maxIter; ++iter) {
            for (int doc = 0; doc < this.numDocs; ++doc) {
                this.samplePath(doc, iter);
            }
            for (int doc = 0; doc < this.numDocs; ++doc) {
                this.sampleTopics(doc);
            }
            this.allCurrentNodes = new ArrayList<SamplersHLDA.NCRPNode>();
            this.removeUnusedNodes(this.rootNode);
            this.numerateNodes(this.rootNode);
            this.numTopics = this.allCurrentNodes.size();
            this.pGompertz = new double[this.numTopics][2];
            // Fitting clears each node's hsValues. Skipping the fit on the final sweep leaves the
            // collected timestamps in place for getAssignedTimes().
            if (iter < this.maxIter - 1) {
                for (int i = 0; i < this.numTopics; ++i) {
                    this.fitTimeModel(i, this.allCurrentNodes.get(i));
                }
            }
            this.size = 0;
            this.countNodes(this.rootNode);
            System.out.println(iter + ":" + this.size + ":" + this.globalNew);
            this.globalNew = 0;
        }
    }

    /**
     * Fits a Gompertz time model to {@code node}'s timestamp histogram ({@code hsValues}).
     *
     * <p>The fit is kept ({@code type = 1}) unless the uniform log-likelihood
     * {@code -n * log(maxTime)} is larger. In that case the parameters are zeroed and
     * {@code type = 0}. The node's {@code hsValues} are cleared afterwards.
     *
     * @param index position of {@code node} in {@link #allCurrentNodes}
     * @param node  node to fit
     */
    private void fitTimeModel(int index, SamplersHLDA.NCRPNode node) {
        MyHashGompertzOptimizable opt = new MyHashGompertzOptimizable();
        opt.vals = node.hsValues;
        // Random starting point. The four draws keep their original order in the rn sequence.
        opt.alpha = this.rn.nextDouble();
        opt.beta = this.rn.nextDouble();
        opt.a = this.rn.nextDouble();
        opt.b = this.rn.nextDouble();
        MyOrthantWiseLimitedMemoryBFGS optimizer = new MyOrthantWiseLimitedMemoryBFGS(opt);
        try {
            optimizer.optimize(100);
        } catch (OptimizationException e) {
            // Best effort: keep whatever parameters the optimizer reached before failing.
            e.printStackTrace();
        }
        double[] paras = new double[2];
        opt.getParameters(paras);
        // The stored parameters are exp() of the optimizer's values (presumably log-parameterised).
        node.pGompertz[0] = Math.exp(paras[0]);
        node.pGompertz[1] = Math.exp(paras[1]);
        // NOTE(review): getValue2() is assumed to be the Gompertz log-likelihood; confirm.
        double gompertzScore = opt.getValue2();

        int numObservations = 0;
        for (double t : node.hsValues.keys()) {
            numObservations += node.hsValues.get(t);
        }
        double uniformLogLik = (double) -numObservations * Math.log(this.maxTime);

        if (uniformLogLik > gompertzScore) {
            node.pGompertz[0] = 0.0;
            node.pGompertz[1] = 0.0;
            node.type = 0;
            node.changes[0]++;
        } else {
            node.type = 1;
            node.changes[1]++;
        }
        this.pGompertz[index] = node.pGompertz;
        // opt.vals aliases node.hsValues, so this clears the node's collected timestamps.
        opt.vals.clear();
    }

    /**
     * Appends a text line for {@code node} to {@code tree}, then recurses into its children.
     *
     * <p>The line is indented by {@code indent}, and contains the token and customer counts, the
     * type-change counters, and the 1-based topic id.
     */
    @Override
    public void printNode(SamplersHLDA.NCRPNode node, int indent) {
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < indent; ++i) {
            line.append("  ");
        }
        line.append(node.totalTokens + "/" + node.customers + " " + node.changes[0] + ":" + node.changes[1]);
        line.append(" topic: " + (this.NodeIdToTopic.get(node.nodeID) + 1));
        this.tree = this.tree + line.toString() + "\n";
        for (SamplersHLDA.NCRPNode child : node.children) {
            this.printNode(child, indent + 1);
        }
    }

    /**
     * Returns the Gompertz parameters per topic.
     *
     * <p>Only nodes above {@code pruneLevel} are considered. For each topic, the first such node
     * in {@link #allCurrentNodes} order supplies the parameters.
     *
     * <p>NOTE(review): the topic ids are used directly as array indices, which assumes they are
     * exactly {@code 0..res.length-1}; confirm against {@code numerateTopics}.
     */
    public double[][] getPi() {
        this.pruneTree(this.rootNode);
        this.numerateTopics(this.rootNode);
        this.numTopics = this.allCurrentNodes.size();
        HashMap<Integer, double[]> map = new HashMap<Integer, double[]>();
        for (int i = 0; i < this.numTopics; ++i) {
            SamplersHLDA.NCRPNode child = this.allCurrentNodes.get(i);
            if (child.level >= this.pruneLevel) {
                continue;
            }
            int topic = this.NodeIdToTopic.get(child.nodeID);
            if (!map.containsKey(topic)) {
                map.put(topic, child.pGompertz);
            }
        }
        double[][] res = new double[map.size()][];
        for (Integer k : map.keySet()) {
            res[k.intValue()] = map.get(k);
        }
        return res;
    }

    /**
     * Returns the per-document prediction array, allocating it (all zeros) on first use.
     * Nothing in this class writes into it.
     */
    public double[] getPredictions() {
        if (this.predictions == null) {
            this.predictions = new double[this.numDocs];
        }
        return this.predictions;
    }

    /**
     * Returns, per topic, the timestamps recorded in the {@code hsValues} of its nodes.
     *
     * <p>Each timestamp is repeated as many times as it was counted.
     *
     * <p>NOTE(review): as in {@link #getPi()}, topic ids are assumed to be exactly
     * {@code 0..res.length-1}.
     */
    public TDoubleArrayList[] getAssignedTimes() {
        this.pruneTree(this.rootNode);
        this.numerateTopics(this.rootNode);
        // Iterate the list itself (as getPi does) rather than a possibly stale numTopics.
        this.numTopics = this.allCurrentNodes.size();
        HashMap<Integer, TDoubleArrayList> map = new HashMap<Integer, TDoubleArrayList>();
        for (int i = 0; i < this.numTopics; ++i) {
            SamplersHLDA.NCRPNode child = this.allCurrentNodes.get(i);
            int topic = this.NodeIdToTopic.get(child.nodeID);
            TDoubleArrayList bucket = map.get(topic);
            if (bucket == null) {
                bucket = new TDoubleArrayList();
                map.put(topic, bucket);
            }
            for (double t : child.hsValues.keys()) {
                int count = child.hsValues.get(t);
                for (int c = 0; c < count; ++c) {
                    bucket.add(t);
                }
            }
        }
        TDoubleArrayList[] res = new TDoubleArrayList[map.size()];
        for (Integer k : map.keySet()) {
            res[k.intValue()] = map.get(k);
        }
        return res;
    }
}

