/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import com.rapidminer.kobra.topicmodels.SamplersLDA;
import com.rapidminer.tools.RandomGenerator;
import gnu.trove.TIntArrayList;
import gnu.trove.map.hash.TDoubleIntHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import java.util.ArrayList;
import java.util.Random;

/**
 * Hierarchical LDA (hLDA) Gibbs sampler built on a nested Chinese restaurant process (nCRP).
 *
 * <p>Every document is assigned to a root-to-leaf path of {@link #numLevels} nodes in a topic
 * tree, and every token of the document is assigned to one level of that path. Each iteration of
 * {@link #GibbsSampling()} resamples the path of every document ({@link #samplePath(int, int)})
 * and then the level of every token ({@link #sampleTopics(int)}). {@link #wordDistribution()} and
 * {@link #documentDistribution()} flatten the resulting tree into one topic per (possibly pruned)
 * tree node.
 */
public class SamplersHLDA extends SamplersLDA {
    /** Counter used to hand out unique {@link NCRPNode#nodeID}s. */
    int totalNodes = 0;
    /** Depth of every document path; the root is level 0. */
    int numLevels = 3;
    /** nCRP concentration: relative weight of opening a new child at a node. */
    double gamma = 1.0;
    /** {@code levels[doc][token]} is the tree level the token is currently assigned to. */
    int[][] levels;
    NCRPNode rootNode;
    /** Scratch node written by {@link #init}; not read elsewhere in this class. */
    NCRPNode node;
    /** Leaf (level {@code numLevels - 1}) of each document's current path. */
    NCRPNode[] documentLeaves;
    /** Level-assignment smoothing; overwritten by {@link #init}. */
    double alpha = 0.1;
    /** Topic-word smoothing; overwritten by {@link #init}. */
    double eta = 0.1;
    /** {@code eta * numWords}; set by {@link #init}. */
    double etaSum = 0.0;
    /** Number of new paths created during the current iteration (diagnostic output). */
    int globalNew = 0;
    /** Running sum used by {@link #sampleNode}. */
    double globalSum = 0.0;
    /** {@code sequences[doc]} holds the (shuffled) word ids of the document's tokens. */
    int[][] sequences = null;
    /** Maps a node id to its dense topic index; filled by {@link #numerateTopics}. */
    TIntIntHashMap NodeIdToTopic = new TIntIntHashMap();
    /** Next dense topic index handed out by {@link #numerateTopics}. */
    int next = 0;
    /** Text rendering of the tree, appended to by {@link #printNodes()}. */
    public String tree = "";
    /** Nodes at this level or deeper take their parent's id in {@link #pruneTree}. */
    public int pruneLevel = 3;
    /** Node count computed by {@link #countNodes}. */
    int size = 0;

    /**
     * Collects the path from the root down to {@code leaf}.
     *
     * @param leaf a node at level {@code numLevels - 1}
     * @return array indexed by level, with the root at index 0 and {@code leaf} at the end
     */
    private NCRPNode[] pathTo(NCRPNode leaf) {
        NCRPNode[] path = new NCRPNode[this.numLevels];
        NCRPNode current = leaf;
        for (int level = this.numLevels - 1; level >= 0; --level) {
            path[level] = current;
            current = current.parent;
        }
        return path;
    }

    /**
     * Stores, for every node of the subtree, the log nCRP prior of choosing that node as a path
     * end point.
     *
     * <p>For an internal node this is the prior of descending to it and then opening a new child.
     * For a leaf it is only the prior of descending to it, because a leaf cannot open a new child
     * (see {@link NCRPNode#getNewLeaf()}).
     *
     * @param nodeWeights output map from node to log prior
     * @param node subtree root
     * @param weight accumulated log prior of descending from the root to {@code node}
     */
    public void calculateNCRP(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight) {
        double denominator = (double) node.customers + this.gamma;
        for (NCRPNode child : node.children) {
            this.calculateNCRP(nodeWeights, child, weight + Math.log((double) child.customers / denominator));
        }
        if (node.isLeaf()) {
            // Choosing a leaf re-uses an existing path; no "new child" term applies.
            nodeWeights.put(node, weight);
        } else {
            nodeWeights.put(node, weight + Math.log(this.gamma / denominator));
        }
    }

    /**
     * Adds to every node's entry in {@code nodeWeights} the log likelihood of the document's words
     * if the document's path ended at (or branched off from) that node.
     *
     * <p>The score of a node combines three parts:
     * <ul>
     *   <li>the words of all levels above it, under the ancestors' topics ({@code weight});</li>
     *   <li>the words of its own level, under its own topic;</li>
     *   <li>the words of every deeper level, under brand-new empty topics.</li>
     * </ul>
     *
     * @param nodeWeights map already containing every node of the subtree (see {@link #calculateNCRP})
     * @param node subtree root
     * @param weight summed log likelihood of the levels above {@code node}
     * @param typeCounts per level, word id to the document's token count at that level
     * @param newTopicWeights per level, log likelihood of that level's words under a new empty topic
     * @param level level of {@code node}
     * @param iteration current Gibbs iteration (not used)
     */
    public void calculateWordLikelihood(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight, TIntIntHashMap[] typeCounts, double[] newTopicWeights, int level, int iteration) {
        double nodeWeight = 0.0;
        TIntIntHashMap counts = typeCounts[level];
        int tokensSoFar = 0;
        for (int type : counts.keys()) {
            int count = counts.get(type);
            for (int i = 0; i < count; ++i) {
                // Sequential Dirichlet-multinomial (Polya urn) predictive probability.
                nodeWeight += Math.log((this.eta + (double) node.typeCounts[type] + (double) i)
                        / (this.etaSum + (double) node.totalTokens + (double) tokensSoFar));
                ++tokensSoFar;
            }
        }
        // Likelihood of every level from the root down to and including this node.
        double pathWeight = weight + nodeWeight;
        for (NCRPNode child : node.children) {
            this.calculateWordLikelihood(nodeWeights, child, pathWeight, typeCounts, newTopicWeights, level + 1, iteration);
        }
        // Choosing this node means all deeper levels get brand-new topics.
        for (int deeper = level + 1; deeper < this.numLevels; ++deeper) {
            pathWeight += newTopicWeights[deeper];
        }
        nodeWeights.adjustValue(node, pathWeight);
    }

    /**
     * Adds {@code weight} to every node of the subtree, stopping at nodes absent from
     * {@code nodeWeights}.
     */
    public void propagateTopicWeight(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight) {
        if (!nodeWeights.containsKey(node)) {
            return;
        }
        for (NCRPNode child : node.children) {
            this.propagateTopicWeight(nodeWeights, child, weight);
        }
        nodeWeights.adjustValue(node, weight);
    }

    /**
     * Resamples the tree path of one document.
     *
     * <p>The document is removed from its current path (customers and word counts). Every node
     * is then scored by nCRP prior plus word likelihood, and one node is drawn. If the drawn node
     * is internal it is extended to a fresh leaf. Finally the document is added back along the
     * new path.
     *
     * @param doc document index
     * @param iteration current Gibbs iteration (passed through, not used)
     */
    public void samplePath(int doc, int iteration) {
        NCRPNode[] path = this.pathTo(this.documentLeaves[doc]);
        this.documentLeaves[doc].dropPath();

        TObjectDoubleHashMap<NCRPNode> nodeWeights = new TObjectDoubleHashMap<NCRPNode>();
        this.calculateNCRP(nodeWeights, this.rootNode, 0.0);

        // Count the document's words per level and remove them from the old path.
        TIntIntHashMap[] typeCounts = new TIntIntHashMap[this.numLevels];
        for (int level = 0; level < this.numLevels; ++level) {
            typeCounts[level] = new TIntIntHashMap();
        }
        int[] docLevels = this.levels[doc];
        int[] docTokens = this.sequences[doc];
        for (int token = 0; token < docLevels.length; ++token) {
            int level = docLevels[token];
            int type = docTokens[token];
            if (typeCounts[level].containsKey(type)) {
                typeCounts[level].increment(type);
            } else {
                typeCounts[level].put(type, 1);
            }
            path[level].typeCounts[type]--;
            assert path[level].typeCounts[type] >= 0;
            path[level].totalTokens--;
            assert path[level].totalTokens >= 0;
        }

        // Log likelihood of each non-root level's words under a new, empty topic.
        double[] newTopicWeights = new double[this.numLevels];
        for (int level = 1; level < this.numLevels; ++level) {
            int tokensSoFar = 0;
            for (int type : typeCounts[level].keys()) {
                int count = typeCounts[level].get(type);
                for (int i = 0; i < count; ++i) {
                    newTopicWeights[level] += Math.log((this.eta + (double) i) / (this.etaSum + (double) tokensSoFar));
                    ++tokensSoFar;
                }
            }
        }

        this.calculateWordLikelihood(nodeWeights, this.rootNode, 0.0, typeCounts, newTopicWeights, 0, iteration);

        // Convert log weights into probabilities; subtracting the max avoids underflow.
        NCRPNode[] nodes = nodeWeights.keys(new NCRPNode[0]);
        double max = Double.NEGATIVE_INFINITY;
        for (NCRPNode candidate : nodes) {
            double logWeight = nodeWeights.get(candidate);
            if (logWeight > max) {
                max = logWeight;
            }
        }
        double[] weights = new double[nodes.length];
        double sum = 0.0;
        for (int i = 0; i < nodes.length; ++i) {
            weights[i] = Math.exp(nodeWeights.get(nodes[i]) - max);
            sum += weights[i];
        }
        // Normalize like sampleTopics/select do before calling randomIndex.
        for (int i = 0; i < weights.length; ++i) {
            weights[i] /= sum;
        }

        NCRPNode chosen = nodes[RandomGenerator.getGlobalRandomGenerator().randomIndex(weights)];
        this.globalSum = 0.0;
        if (!chosen.isLeaf()) {
            ++this.globalNew;
            chosen = chosen.getNewLeaf();
        }
        chosen.addPath();
        this.documentLeaves[doc] = chosen;

        // Add the document's words to the new path, walking from the leaf up to the root.
        NCRPNode target = chosen;
        for (int level = this.numLevels - 1; level >= 0; --level) {
            TIntIntHashMap counts = typeCounts[level];
            for (int type : counts.keys()) {
                int count = counts.get(type);
                target.typeCounts[type] += count;
                target.totalTokens += count;
            }
            target = target.parent;
        }
    }

    /**
     * Depth-first walk that adds each visited node's weight to {@link #globalSum}.
     *
     * @return the first non-root node at which the running sum reaches {@code r}, or
     *     {@code null} if none does
     */
    public NCRPNode sampleNode(NCRPNode start, TObjectDoubleHashMap<NCRPNode> nodeWeights, double r) {
        this.globalSum += nodeWeights.get(start);
        if (r <= this.globalSum && start.level > 0) {
            return start;
        }
        for (NCRPNode child : start.children) {
            NCRPNode res = this.sampleNode(child, nodeWeights, r);
            if (res != null) {
                return res;
            }
        }
        return null;
    }

    /**
     * Resamples the level of every token of one document, keeping the document's path fixed.
     *
     * @param doc document index
     */
    public void sampleTopics(int doc) {
        int[] tokens = this.sequences[doc];
        int seqLen = tokens.length;
        int[] docLevels = this.levels[doc];
        NCRPNode[] path = this.pathTo(this.documentLeaves[doc]);

        int[] levelCounts = new int[this.numLevels];
        for (int token = 0; token < seqLen; ++token) {
            levelCounts[docLevels[token]]++;
        }

        double[] levelWeights = new double[this.numLevels];
        for (int token = 0; token < seqLen; ++token) {
            int type = tokens[token];
            int oldLevel = docLevels[token];

            // Remove the token from its current level.
            levelCounts[oldLevel]--;
            path[oldLevel].typeCounts[type]--;
            path[oldLevel].totalTokens--;

            double sum = 0.0;
            for (int level = 0; level < this.numLevels; ++level) {
                levelWeights[level] = (this.alpha + (double) levelCounts[level])
                        * (this.eta + (double) path[level].typeCounts[type])
                        / (this.etaSum + (double) path[level].totalTokens);
                sum += levelWeights[level];
            }
            for (int level = 0; level < this.numLevels; ++level) {
                levelWeights[level] /= sum;
            }

            // Draw a new level and add the token back there.
            int newLevel = RandomGenerator.getGlobalRandomGenerator().randomIndex(levelWeights);
            docLevels[token] = newLevel;
            levelCounts[newLevel]++;
            path[newLevel].typeCounts[type]++;
            path[newLevel].totalTokens++;
        }
    }

    /**
     * Initializes model state from a flat token list.
     *
     * <p>Groups tokens by document and shuffles each document's tokens with {@code this.rn}.
     * Draws an initial nCRP path per document via {@link NCRPNode#select()}. Assigns every token
     * a uniformly random level using {@code this.rn}.
     *
     * @param docIds document index of each token
     * @param wordIds word id of each token
     * @param beta used as {@code eta}; {@code etaSum = numWords * beta}
     * @param alpha level-assignment smoothing
     * @param locSeed if true, {@code this.rn} is seeded with {@code seed}
     */
    @Override
    public void init(int[] docIds, int[] wordIds, int numTopics, int numWords, int numDocs, int iter, double beta, double alpha, boolean locSeed, int seed) {
        this.maxIter = iter;
        this.BETA = beta;
        this.ALPHA = alpha;
        this.alpha = alpha;
        this.eta = beta;
        this.etaSum = (double) numWords * beta;
        this.numTokens = wordIds.length;
        this.numTopics = numTopics;
        this.numDocs = numDocs;
        this.numWords = numWords;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[numWords * numTopics];
        this.doctopiccounts = new int[numDocs * numTopics];
        this.topiccounts = new int[numTopics];
        this.words = wordIds;
        this.docs = docIds;
        if (locSeed) {
            this.seed = seed;
            this.rn = new Random(seed);
        } else {
            this.rn = new Random();
        }
        this.levels = new int[numDocs][];
        this.documentLeaves = new NCRPNode[numDocs];
        this.sequences = new int[numDocs][];
        NCRPNode[] path = new NCRPNode[this.numLevels];
        this.rootNode = new NCRPNode(numWords);

        // Group word ids by document, then shuffle each document's token order.
        TIntArrayList[] tmpTokens = new TIntArrayList[numDocs];
        for (int d = 0; d < numDocs; ++d) {
            tmpTokens[d] = new TIntArrayList();
        }
        for (int i = 0; i < wordIds.length; ++i) {
            tmpTokens[this.docs[i]].add(this.words[i]);
        }
        for (int d = 0; d < numDocs; ++d) {
            tmpTokens[d].shuffle(this.rn);
            this.sequences[d] = tmpTokens[d].toNativeArray();
        }

        for (int doc = 0; doc < numDocs; ++doc) {
            int seqLen = this.sequences[doc].length;
            path[0] = this.rootNode;
            ++this.rootNode.customers;
            for (int level = 1; level < this.numLevels; ++level) {
                path[level] = path[level - 1].select();
                ++path[level].customers;
            }
            this.node = path[this.numLevels - 1];
            this.levels[doc] = new int[seqLen];
            this.documentLeaves[doc] = this.node;
            for (int token = 0; token < seqLen; ++token) {
                int type = this.sequences[doc][token];
                int level = this.rn.nextInt(this.numLevels);
                this.levels[doc][token] = level;
                this.node = path[level];
                ++this.node.totalTokens;
                ++this.node.typeCounts[type];
            }
        }
        System.out.println(this.totalNodes);
    }

    /**
     * Returns the smoothed topic-word distribution, {@code res[topic][word]}, with one topic per
     * numbered tree node.
     *
     * <p>NOTE(review): a document whose leaf was detached by {@link #removeUnusedNodes} maps to a
     * node id that {@link #numerateTopics} never numbered. {@code NodeIdToTopic.get} then
     * presumably returns Trove's no-entry value (0), silently counting those tokens as topic 0.
     */
    @Override
    public double[][] wordDistribution() {
        this.numerateTopics(this.rootNode);
        this.numTopics = this.NodeIdToTopic.size();
        double[][] res = new double[this.numTopics][this.numWords];
        this.topiccounts = new int[this.numTopics];
        this.wordtopiccounts = new int[this.numWords * this.numTopics];
        for (int doc = 0; doc < this.numDocs; ++doc) {
            NCRPNode[] path = this.pathTo(this.documentLeaves[doc]);
            int[] docLevels = this.levels[doc];
            int[] tokens = this.sequences[doc];
            for (int j = 0; j < tokens.length; ++j) {
                int topic = this.NodeIdToTopic.get(path[docLevels[j]].nodeID);
                this.wordtopiccounts[tokens[j] * this.numTopics + topic]++;
                this.topiccounts[topic]++;
            }
        }
        for (int word = 0; word < this.numWords; ++word) {
            for (int topic = 0; topic < this.numTopics; ++topic) {
                res[topic][word] = ((double) this.wordtopiccounts[word * this.numTopics + topic] + this.BETA)
                        / ((double) this.topiccounts[topic] + (double) this.numWords * this.BETA);
            }
        }
        return res;
    }

    /** Appends a text rendering of the whole tree to {@link #tree}. */
    public void printNodes() {
        this.printNode(this.rootNode, 0);
    }

    /**
     * Appends a rendering of the subtree rooted at {@code node} to {@link #tree}. Each node
     * produces one line: two spaces per indent step, then "totalTokens/customers topicIndex".
     */
    public void printNode(NCRPNode node, int indent) {
        StringBuilder out = new StringBuilder();
        this.appendNode(out, node, indent);
        this.tree = this.tree + out;
    }

    /** Recursive helper for {@link #printNode}; writes into a single builder. */
    private void appendNode(StringBuilder out, NCRPNode node, int indent) {
        for (int i = 0; i < indent; ++i) {
            out.append("  ");
        }
        out.append(node.totalTokens).append('/').append(node.customers).append(' ');
        out.append(this.NodeIdToTopic.get(node.nodeID));
        out.append('\n');
        for (NCRPNode child : node.children) {
            this.appendNode(out, child, indent + 1);
        }
    }

    /**
     * Gives every node at level {@link #pruneLevel} or deeper the id of its parent. The walk is
     * top-down, so whole deep subtrees collapse onto their ancestor at {@code pruneLevel - 1}.
     */
    public void pruneTree(NCRPNode node) {
        for (NCRPNode child : node.children) {
            if (child.level >= this.pruneLevel) {
                child.nodeID = node.nodeID;
            }
            this.pruneTree(child);
        }
    }

    /**
     * Detaches every child holding at most one token, together with its subtree.
     *
     * <p>NOTE(review): this does not touch {@link #documentLeaves}. Documents whose path ran
     * through a detached node keep pointing at it; their path is only replaced at the next
     * {@link #samplePath}.
     */
    public void removeUnusedNodes(NCRPNode node) {
        // Iterate over a snapshot because node.remove mutates node.children.
        NCRPNode[] snapshot = node.children.toArray(new NCRPNode[0]);
        for (NCRPNode child : snapshot) {
            if (child.totalTokens <= 1) {
                node.remove(child);
            } else {
                this.removeUnusedNodes(child);
            }
        }
    }

    /** Assigns consecutive topic indices (pre-order) to node ids not yet in {@link #NodeIdToTopic}. */
    public void numerateTopics(NCRPNode node) {
        if (!this.NodeIdToTopic.contains(node.nodeID)) {
            this.NodeIdToTopic.put(node.nodeID, this.next);
            ++this.next;
        }
        for (NCRPNode child : node.children) {
            this.numerateTopics(child);
        }
    }

    /**
     * Returns the smoothed document-topic distribution, {@code res[topic][doc]}.
     *
     * <p>Before computing it, the method prunes the tree, numbers the topics and appends the tree
     * rendering to {@link #tree}.
     */
    @Override
    public double[][] documentDistribution() {
        this.pruneTree(this.rootNode);
        this.numerateTopics(this.rootNode);
        this.printNodes();
        this.numTopics = this.NodeIdToTopic.size();
        double[][] res = new double[this.numTopics][this.numDocs];
        this.doctopiccounts = new int[this.numDocs * this.numTopics];
        for (int doc = 0; doc < this.numDocs; ++doc) {
            NCRPNode[] path = this.pathTo(this.documentLeaves[doc]);
            int[] docLevels = this.levels[doc];
            int seqLength = this.sequences[doc].length;
            for (int j = 0; j < seqLength; ++j) {
                int topic = this.NodeIdToTopic.get(path[docLevels[j]].nodeID);
                this.doctopiccounts[doc * this.numTopics + topic]++;
            }
        }
        for (int doc = 0; doc < this.numDocs; ++doc) {
            int docCounts = 0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                docCounts += this.doctopiccounts[doc * this.numTopics + topic];
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                res[topic][doc] = ((double) this.doctopiccounts[doc * this.numTopics + topic] + this.ALPHA)
                        / ((double) docCounts + (double) this.numTopics * this.ALPHA);
            }
        }
        return res;
    }

    /** Adds the number of nodes in the subtree to {@link #size}. */
    public void countNodes(NCRPNode node) {
        ++this.size;
        for (NCRPNode child : node.children) {
            this.countNodes(child);
        }
    }

    /**
     * Runs {@code maxIter} Gibbs iterations. Each iteration:
     * <ul>
     *   <li>resamples all paths, then all token levels;</li>
     *   <li>removes sparse nodes and counts the remaining ones;</li>
     *   <li>prints "iteration:nodeCount:newPaths".</li>
     * </ul>
     */
    @Override
    public void GibbsSampling() {
        for (int iter = 0; iter < this.maxIter; ++iter) {
            for (int doc = 0; doc < this.numDocs; ++doc) {
                this.samplePath(doc, iter);
            }
            for (int doc = 0; doc < this.numDocs; ++doc) {
                this.sampleTopics(doc);
            }
            this.removeUnusedNodes(this.rootNode);
            this.size = 0;
            this.countNodes(this.rootNode);
            System.out.println(iter + ":" + this.size + ":" + this.globalNew);
            this.globalNew = 0;
        }
    }

    /** A node (topic) of the nCRP tree. */
    class NCRPNode {
        /** Random 0/1 flag drawn at construction; not read in this class. */
        int type = 0;
        public int[] changes = new int[2];
        /** Number of documents whose path passes through this node. */
        int customers = 0;
        ArrayList<NCRPNode> children;
        /** Parent node; {@code null} for the root. */
        NCRPNode parent;
        /** Depth in the tree; the root is level 0. */
        int level;
        /** Number of tokens assigned to this node. */
        int totalTokens;
        /** Per word id, the number of tokens of that word assigned to this node. */
        int[] typeCounts;
        public int nodeID;
        /** Per time value, how many {@link #addTime} calls reached this node. */
        TDoubleIntHashMap hsValues = null;
        /** Parameters (a, b) passed to {@link #dist} by {@link #getTimeWeight}. */
        double[] pGompertz;
        public Random rn = null;

        public NCRPNode(NCRPNode parent, int dimensions, int level) {
            this.type = (int) Math.round(RandomGenerator.getGlobalRandomGenerator().nextDouble());
            this.parent = parent;
            this.children = new ArrayList<NCRPNode>();
            this.level = level;
            this.totalTokens = 0;
            this.typeCounts = new int[dimensions];
            this.nodeID = SamplersHLDA.this.totalNodes++;
            this.hsValues = new TDoubleIntHashMap();
            this.pGompertz = new double[]{RandomGenerator.getGlobalRandomGenerator().nextDouble(), RandomGenerator.getGlobalRandomGenerator().nextDouble()};
        }

        /** Creates a root node for a vocabulary of {@code dimensions} words. */
        public NCRPNode(int dimensions) {
            this(null, dimensions, 0);
        }

        /**
         * Evaluates {@code b * exp(-(b*x + a*exp(-b*x))) * (1 + a*(1 - exp(-b*x)))}.
         * NOTE(review): presumably a Gompertz-type time density, given the field name; confirm.
         */
        public double dist(double x, double a, double b) {
            return b * Math.exp(-(b * x + a * Math.exp(-b * x))) * (1.0 + a * (1.0 - Math.exp(-b * x)));
        }

        /** Evaluates {@link #dist} at {@code t} with this node's {@link #pGompertz} parameters. */
        public double getTimeWeight(double t) {
            return this.dist(t, this.pGompertz[0], this.pGompertz[1]);
        }

        /** Increments the count of {@code t} in this node and every ancestor. */
        public void addTime(double t) {
            for (NCRPNode current = this; current != null; current = current.parent) {
                int count = current.hsValues.contains(t) ? current.hsValues.get(t) : 0;
                current.hsValues.put(t, count + 1);
            }
        }

        /** Creates, attaches and returns a new child one level below this node. */
        public NCRPNode addChild() {
            NCRPNode child = new NCRPNode(this, this.typeCounts.length, this.level + 1);
            this.children.add(child);
            return child;
        }

        /** Returns whether this node is at the deepest level, {@code numLevels - 1}. */
        public boolean isLeaf() {
            return this.level == SamplersHLDA.this.numLevels - 1;
        }

        /**
         * Extends a chain of new children down to the leaf level and returns the new leaf.
         * Returns {@code this} if this node is already a leaf.
         */
        public NCRPNode getNewLeaf() {
            NCRPNode current = this;
            for (int l = this.level; l < SamplersHLDA.this.numLevels - 1; ++l) {
                current = current.addChild();
            }
            return current;
        }

        /**
         * Removes one customer from this leaf and every ancestor. Each node left with no
         * customers is detached from its parent. The root has no parent and is never detached,
         * which avoids a NullPointerException when the last document leaves the tree.
         */
        public void dropPath() {
            NCRPNode current = this;
            for (int l = 0; l < SamplersHLDA.this.numLevels; ++l) {
                if (l > 0) {
                    current = current.parent;
                }
                --current.customers;
                if (current.customers == 0 && current.parent != null) {
                    current.parent.remove(current);
                }
            }
        }

        /** Detaches {@code child} from this node's children (its parent pointer is kept). */
        public void remove(NCRPNode node) {
            this.children.remove(node);
        }

        /** Adds one customer to this leaf and every ancestor. */
        public void addPath() {
            NCRPNode current = this;
            ++current.customers;
            for (int l = 1; l < SamplersHLDA.this.numLevels; ++l) {
                current = current.parent;
                ++current.customers;
            }
        }

        /**
         * Draws an existing child with weight {@code child.customers / (gamma + customers)}.
         * The weights are passed to {@code randomIndex} unnormalized.
         *
         * @return the chosen child, or {@code null} if {@code randomIndex} returns a negative index
         */
        public NCRPNode selectExisting() {
            double[] weights = new double[this.children.size()];
            int i = 0;
            for (NCRPNode child : this.children) {
                weights[i] = (double) child.customers / (SamplersHLDA.this.gamma + (double) this.customers);
                ++i;
            }
            int choice = RandomGenerator.getGlobalRandomGenerator().randomIndex(weights);
            if (choice < 0) {
                return null;
            }
            return this.children.get(choice);
        }

        /**
         * nCRP step: returns a new child with probability proportional to {@code gamma}, or an
         * existing child with probability proportional to its customer count.
         */
        public NCRPNode select() {
            double denominator = SamplersHLDA.this.gamma + (double) this.customers;
            double[] weights = new double[this.children.size() + 1];
            weights[0] = SamplersHLDA.this.gamma / denominator;
            double sum = weights[0];
            int i = 1;
            for (NCRPNode child : this.children) {
                weights[i] = (double) child.customers / denominator;
                sum += weights[i];
                ++i;
            }
            for (int j = 0; j < weights.length; ++j) {
                weights[j] /= sum;
            }
            int choice = RandomGenerator.getGlobalRandomGenerator().randomIndex(weights);
            if (choice == 0) {
                return this.addChild();
            }
            return this.children.get(choice - 1);
        }
    }
}

