/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.test;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.types.tests.TestSerializable;
import cc.mallet.util.ArrayUtils;
import cc.mallet.util.Maths;
import cc.mallet.util.Randoms;
import java.io.IOException;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * JUnit 3 tests for {@link LogTableFactor}.
 *
 * <p>The tests cover:
 * <ul>
 *   <li>arithmetic (multiply, divide, plusEquals) against both log-space and plain {@link TableFactor}s;</li>
 *   <li>entropy, extractMax, value/logValue lookups and slicing;</li>
 *   <li>sparse-matrix backed factors, sampling and recentering.</li>
 * </ul>
 */
public class TestLogTableFactor extends TestCase {

    public TestLogTableFactor(String name) {
        super(name);
    }

    /** A plain TableFactor multiplied by a LogTableFactor yields the elementwise product. */
    public void testTimesTableFactor() {
        Variable var = new Variable(4);
        double[] vals = new double[]{2.0, 4.0, 6.0, 8.0};
        double[] vals2 = new double[]{0.5, 0.5, 0.5, 0.5};
        double[] vals3 = new double[]{1.0, 2.0, 3.0, 4.0};
        TableFactor ans = new TableFactor(var, vals3);
        TableFactor ptl1 = new TableFactor(var, vals);
        LogTableFactor lptl2 = LogTableFactor.makeFromValues(var, vals2);
        ptl1.multiplyBy(lptl2);
        TestLogTableFactor.assertTrue(ans.almostEquals(ptl1));
    }

    /** plusEquals between two log factors adds their values elementwise (in value space). */
    public void testTblTblPlusEquals() {
        Variable var = new Variable(4);
        double[] vals = new double[]{2.0, 4.0, 6.0, 8.0};
        double[] vals2 = new double[]{0.25, 0.5, 0.75, 1.0};
        double[] vals3 = new double[]{2.25, 4.5, 6.75, 9.0};
        LogTableFactor ans = LogTableFactor.makeFromValues(var, vals3);
        LogTableFactor ptl1 = LogTableFactor.makeFromValues(var, vals);
        LogTableFactor ptl2 = LogTableFactor.makeFromValues(var, vals2);
        ptl1.plusEquals(ptl2);
        TestLogTableFactor.assertTrue(ans.almostEquals(ptl1));
    }

    /** multiplyBy gives the same product for every plain/log combination of operands. */
    public void testMultiplyByLogSpace() {
        Variable var = new Variable(4);
        double[] vals = new double[]{2.0, 4.0, 6.0, 8.0};
        double[] vals2 = new double[]{0.5, 0.5, 0.5, 0.5};
        double[] vals3 = new double[]{1.0, 2.0, 3.0, 4.0};
        TableFactor ans = new TableFactor(var, vals3);

        // plain * plain
        TableFactor ptl1 = new TableFactor(var, vals);
        TableFactor ptl2 = new TableFactor(var, vals2);
        ptl1.multiplyBy(ptl2);
        TestLogTableFactor.assertTrue(ans.almostEquals(ptl1));

        // plain * log
        TableFactor ptl3 = new TableFactor(var, vals);
        LogTableFactor ptl4 = LogTableFactor.makeFromValues(var, vals2);
        ptl3.multiplyBy(ptl4);
        TestLogTableFactor.assertTrue(ptl3.almostEquals(ptl1));

        // log * plain
        TableFactor ptl5 = new TableFactor(var, vals);
        LogTableFactor ptl6 = LogTableFactor.makeFromValues(var, vals2);
        ptl6.multiplyBy(ptl5);
        TestLogTableFactor.assertTrue(ptl6.almostEquals(ans));

        // log * log
        LogTableFactor ptl7 = LogTableFactor.makeFromValues(var, vals);
        LogTableFactor ptl8 = LogTableFactor.makeFromValues(var, vals2);
        ptl8.multiplyBy(ptl7);
        TestLogTableFactor.assertTrue(ptl8.almostEquals(ans));
    }

    /** divideBy gives the same quotient for every plain/log combination of operands. */
    public void testDivideByLogSpace() {
        Variable var = new Variable(4);
        double[] vals = new double[]{2.0, 4.0, 6.0, 8.0};
        double[] vals2 = new double[]{0.5, 0.5, 0.5, 0.5};
        double[] vals3 = new double[]{4.0, 8.0, 12.0, 16.0};
        TableFactor ans = new TableFactor(var, vals3);

        // plain / plain
        TableFactor ptl1 = new TableFactor(var, vals);
        TableFactor ptl2 = new TableFactor(var, vals2);
        ptl1.divideBy(ptl2);
        TestLogTableFactor.assertTrue(ans.almostEquals(ptl1));

        // plain / log
        TableFactor ptl3 = new TableFactor(var, vals);
        LogTableFactor ptl4 = LogTableFactor.makeFromValues(var, vals2);
        ptl3.divideBy(ptl4);
        TestLogTableFactor.assertTrue(ptl3.almostEquals(ans));

        // log / plain
        LogTableFactor ptl5 = LogTableFactor.makeFromValues(var, vals);
        TableFactor ptl6 = new TableFactor(var, vals2);
        ptl5.divideBy(ptl6);
        TestLogTableFactor.assertTrue(ptl5.almostEquals(ans));

        // log / log
        LogTableFactor ptl7 = LogTableFactor.makeFromValues(var, vals);
        LogTableFactor ptl8 = LogTableFactor.makeFromValues(var, vals2);
        ptl7.divideBy(ptl8);
        TestLogTableFactor.assertTrue(ptl7.almostEquals(ans));
    }

    /** Entropy of the distribution (0.3, 0.7) is about 0.61086 for both factor representations. */
    public void testEntropyLogSpace() {
        Variable v1 = new Variable(2);
        TableFactor ptl = new TableFactor(v1, new double[]{0.3, 0.7});
        double entropy = ptl.entropy();
        TestLogTableFactor.assertEquals(0.61086, entropy, 0.001);
        LogTableFactor ptl2 = LogTableFactor.makeFromValues(v1, new double[]{0.3, 0.7});
        double entropy2 = ptl2.entropy();
        TestLogTableFactor.assertEquals(0.61086, entropy2, 0.001);
    }

    /**
     * Round-trips a factor through Java serialization and checks values and marginals.
     *
     * <p>Disabled: the name does not start with "test", so JUnit 3 does not run it.
     */
    public void ignoreTestSerialization() throws IOException, ClassNotFoundException {
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(3);
        Variable[] vars = new Variable[]{v1, v2};
        double[] vals = new double[]{2.0, 4.0, 6.0, 3.0, 5.0, 7.0};
        LogTableFactor ptl = LogTableFactor.makeFromLogValues(vars, vals);
        LogTableFactor ptl2 = (LogTableFactor)TestSerializable.cloneViaSerialization(ptl);

        VarSet varset1 = ptl.varSet();
        VarSet varset2 = ptl2.varSet();
        // NOTE(review): contains() is handed a whole VarSet here.
        // Presumably the intent is that the deserialized variables are distinct objects; confirm.
        TestLogTableFactor.assertTrue(!varset1.contains(varset2));

        this.comparePotentialValues(ptl, ptl2);
        LogTableFactor marg1 = (LogTableFactor)ptl.marginalize(v1);
        LogTableFactor marg2 = (LogTableFactor)ptl2.marginalize(ptl2.findVariable(v1.getLabel()));
        this.comparePotentialValues(marg1, marg2);
    }

    /**
     * Asserts that two factors hold the same values at corresponding positions, walking each
     * with its own assignment iterator. Both iterators must also run out at the same time.
     *
     * @param ptl  the factor under test
     * @param ptl2 the factor holding the expected values
     */
    private void comparePotentialValues(LogTableFactor ptl, LogTableFactor ptl2) {
        AssignmentIterator it1 = ptl.assignmentIterator();
        AssignmentIterator it2 = ptl2.assignmentIterator();
        while (it1.hasNext()) {
            TestLogTableFactor.assertTrue("FAILURE: second factor has fewer assignments than " + ptl, it2.hasNext());
            // Compare against the second factor. A tolerance is used because value() is computed
            // from the stored log values rather than stored directly.
            TestLogTableFactor.assertEquals(ptl.value(it1), ptl2.value(it2), 1.0E-5);
            it1.advance();
            it2.advance();
        }
        TestLogTableFactor.assertTrue("FAILURE: second factor has more assignments than " + ptl, !it2.hasNext());
    }

    /** extractMax over a 2x2 factor keeps only the requested variable and the per-column maxima. */
    public void testExtractMaxLogSpace() {
        Variable[] vars = new Variable[]{new Variable(2), new Variable(2)};
        LogTableFactor ptl = LogTableFactor.makeFromValues(vars, new double[]{1.0, 2.0, 3.0, 4.0});
        LogTableFactor ptl2 = (LogTableFactor)ptl.extractMax(vars[1]);
        TestLogTableFactor.assertEquals("FAILURE: Potential has too many vars.\n  " + ptl2, 1, ptl2.varSet().size());
        TestLogTableFactor.assertTrue("FAILURE: Potential does not contain " + vars[1] + ":\n  " + ptl2, ptl2.varSet().contains(vars[1]));
        double[] expected = new double[]{3.0, 4.0};
        TestLogTableFactor.assertTrue("FAILURE: Potential has incorrect values.  Expected " + ArrayUtils.toString(expected) + " but was " + ptl2, Maths.almostEquals(ptl2.toValueArray(), expected, 1.0E-5));
    }

    /**
     * value/logValue agree across Assignment, iterator and index lookups.
     * Checked for factors built from raw values and from log values.
     */
    public void testLogValue() {
        Variable[] vars = new Variable[]{new Variable(2), new Variable(2)};

        LogTableFactor ptl = LogTableFactor.makeFromValues(vars, new double[]{1.0, 2.0, 3.0, 4.0});
        Assignment assn = new Assignment(vars, new int[vars.length]);
        TestLogTableFactor.assertEquals(0.0, ptl.logValue(assn), 1.0E-5);
        TestLogTableFactor.assertEquals(0.0, ptl.logValue(ptl.assignmentIterator()), 1.0E-5);
        TestLogTableFactor.assertEquals(0.0, ptl.logValue(0), 1.0E-5);
        TestLogTableFactor.assertEquals(1.0, ptl.value(assn), 1.0E-5);
        TestLogTableFactor.assertEquals(1.0, ptl.value(ptl.assignmentIterator()), 1.0E-5);
        TestLogTableFactor.assertEquals(1.0, ptl.value(0), 1.0E-5);

        LogTableFactor ptl2 = LogTableFactor.makeFromLogValues(vars, new double[]{0.0, Math.log(2.0), Math.log(3.0), Math.log(4.0)});
        Assignment assn2 = new Assignment(vars, new int[vars.length]);
        TestLogTableFactor.assertEquals(0.0, ptl2.logValue(assn2), 1.0E-5);
        TestLogTableFactor.assertEquals(0.0, ptl2.logValue(ptl2.assignmentIterator()), 1.0E-5);
        TestLogTableFactor.assertEquals(0.0, ptl2.logValue(0), 1.0E-5);
        TestLogTableFactor.assertEquals(1.0, ptl2.value(assn2), 1.0E-5);
        TestLogTableFactor.assertEquals(1.0, ptl2.value(ptl2.assignmentIterator()), 1.0E-5);
        TestLogTableFactor.assertEquals(1.0, ptl2.value(0), 1.0E-5);
    }

    /** Slicing a two-variable factor on v1 = 0 leaves the row {1, 4} over v2. */
    public void testOneVarSlice() {
        // log values of {1, 4, 2, 6}
        double[] vals = new double[]{0.0, 1.3862943611198906, 0.6931471805599453, 1.791759469228055};
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        Variable[] vars = new Variable[]{v1, v2};
        LogTableFactor ptl = LogTableFactor.makeFromLogValues(vars, vals);
        Assignment assn = new Assignment(v1, 0);
        LogTableFactor sliced = (LogTableFactor)ptl.slice(assn);
        LogTableFactor expected = LogTableFactor.makeFromValues(v2, new double[]{1.0, 4.0});
        this.comparePotentialValues(sliced, expected);
        TestLogTableFactor.assertEquals(1, assn.varSet().size());
    }

    /** Slicing on the last of three variables keeps the even-indexed entries. */
    public void testTwoVarSlice() {
        double[] vals = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        Variable v3 = new Variable(2);
        Variable[] vars = new Variable[]{v1, v2, v3};
        LogTableFactor ptl = LogTableFactor.makeFromValues(vars, vals);
        Assignment assn = new Assignment(v3, 0);
        LogTableFactor sliced = (LogTableFactor)ptl.slice(assn);
        LogTableFactor expected = LogTableFactor.makeFromValues(new Variable[]{v1, v2}, new double[]{0.0, 2.0, 4.0, 6.0});
        this.comparePotentialValues(sliced, expected);
    }

    /** Slicing on the last of four variables keeps the even-indexed entries. */
    public void testMultiVarSlice() {
        double[] vals = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0};
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        Variable v3 = new Variable(2);
        Variable v4 = new Variable(2);
        Variable[] vars = new Variable[]{v1, v2, v3, v4};
        LogTableFactor ptl = LogTableFactor.makeFromValues(vars, vals);
        Assignment assn = new Assignment(v4, 0);
        LogTableFactor sliced = (LogTableFactor)ptl.slice(assn);
        LogTableFactor expected = LogTableFactor.makeFromValues(new Variable[]{v1, v2, v3}, new double[]{0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0});
        this.comparePotentialValues(sliced, expected);
    }

    /**
     * On a sparse factor, the factor's own iterator starts at the first stored index (1).
     * Through the full VarSet iterator, the missing index 0 reads as value 0 / log value -inf.
     */
    public void testSparseValueAndLogValue() {
        Variable[] vars = new Variable[]{new Variable(2), new Variable(2)};
        int[] szs = new int[]{2, 2};
        int[] idxs1 = new int[]{1, 3};
        double[] vals1 = new double[]{4.0, 8.0};
        LogTableFactor ptl1 = LogTableFactor.makeFromMatrix(vars, new SparseMatrixn(szs, idxs1, vals1));

        AssignmentIterator it = ptl1.assignmentIterator();
        TestLogTableFactor.assertEquals(1, it.indexOfCurrentAssn());
        TestLogTableFactor.assertEquals(Math.log(4.0), ptl1.logValue(it), 1.0E-5);
        TestLogTableFactor.assertEquals(Math.log(4.0), ptl1.logValue(it.assignment()), 1.0E-5);
        TestLogTableFactor.assertEquals(4.0, ptl1.value(it), 1.0E-5);
        TestLogTableFactor.assertEquals(4.0, ptl1.value(it.assignment()), 1.0E-5);

        it = ptl1.varSet().assignmentIterator();
        TestLogTableFactor.assertEquals(0, it.indexOfCurrentAssn());
        TestLogTableFactor.assertEquals(Double.NEGATIVE_INFINITY, ptl1.logValue(it), 1.0E-5);
        TestLogTableFactor.assertEquals(Double.NEGATIVE_INFINITY, ptl1.logValue(it.assignment()), 1.0E-5);
        TestLogTableFactor.assertEquals(0.0, ptl1.value(it), 1.0E-5);
        TestLogTableFactor.assertEquals(0.0, ptl1.value(it.assignment()), 1.0E-5);
    }

    /** Sparse * sparse: entries missing from either operand come out as 0. */
    public void testSparseMultiplyLogSpace() {
        Variable[] vars = new Variable[]{new Variable(2), new Variable(2)};
        int[] szs = new int[]{2, 2};
        int[] idxs1 = new int[]{0, 1, 3};
        double[] vals1 = new double[]{2.0, 4.0, 8.0};
        int[] idxs2 = new int[]{0, 3};
        double[] vals2 = new double[]{0.5, 0.5};
        double[] vals3 = new double[]{1.0, 0.0, 4.0};
        LogTableFactor ptl1 = LogTableFactor.makeFromMatrix(vars, new SparseMatrixn(szs, idxs1, vals1));
        LogTableFactor ptl2 = LogTableFactor.makeFromMatrix(vars, new SparseMatrixn(szs, idxs2, vals2));
        LogTableFactor ans = LogTableFactor.makeFromMatrix(vars, new SparseMatrixn(szs, idxs1, vals3));
        Factor ptl3 = ptl1.multiply(ptl2);
        TestLogTableFactor.assertTrue("Test failed! Expected: " + ans + " Actual: " + ptl3, ans.almostEquals(ptl3));
    }

    /** Sparse / sparse: the expected table has 0 where the divisor has no stored entry (index 1). */
    public void testSparseDivideLogSpace() {
        Variable[] vars = new Variable[]{new Variable(2), new Variable(2)};
        int[] szs = new int[]{2, 2};
        int[] idxs1 = new int[]{0, 1, 3};
        double[] vals1 = new double[]{2.0, 4.0, 8.0};
        int[] idxs2 = new int[]{0, 3};
        double[] vals2 = new double[]{0.5, 0.5};
        double[] vals3 = new double[]{4.0, 0.0, 16.0};
        LogTableFactor ptl1 = LogTableFactor.makeFromMatrix(vars, new SparseMatrixn(szs, idxs1, vals1));
        LogTableFactor ptl2 = LogTableFactor.makeFromMatrix(vars, new SparseMatrixn(szs, idxs2, vals2));
        LogTableFactor ans = LogTableFactor.makeFromMatrix(vars, new SparseMatrixn(szs, idxs1, vals3));
        ptl1.divideBy(ptl2);
        TestLogTableFactor.assertTrue("Test failed! Expected: " + ans + " Actual: " + ptl1, ans.almostEquals(ptl1));
    }

    /** Marginalizing a sparse 2x2 factor onto its first variable sums each row. */
    public void testSparseMarginalizeLogSpace() {
        Variable[] vars = new Variable[]{new Variable(2), new Variable(2)};
        int[] szs = new int[]{2, 2};
        int[] idxs1 = new int[]{0, 1, 3};
        double[] vals1 = new double[]{2.0, 4.0, 8.0};
        LogTableFactor ptl1 = LogTableFactor.makeFromMatrix(vars, new SparseMatrixn(szs, idxs1, vals1));
        LogTableFactor ans = LogTableFactor.makeFromValues(vars[0], new double[]{6.0, 8.0});
        Factor ptl2 = ptl1.marginalize(vars[0]);
        TestLogTableFactor.assertTrue("Test failed! Expected: " + ans + " Actual: " + ptl2 + " Orig: " + ptl1, ans.almostEquals(ptl2));
    }

    /** With log values {-30, 0}, a seeded sample lands on index 1. */
    public void testLogSample() {
        Variable v = new Variable(2);
        double[] vals = new double[]{-30.0, 0.0};
        LogTableFactor ptl = LogTableFactor.makeFromLogValues(v, vals);
        int idx = ptl.sampleLocation(new Randoms(43));
        TestLogTableFactor.assertEquals(1, idx);
    }

    /**
     * plusEquals(0.1) adds the constant in value space.
     * The expected log values are log(exp(v) + 0.1), i.e. log of {0.1, 1.1, 2.1, 3.1}.
     */
    public void testPlusEquals() {
        Variable var = new Variable(4);
        double[] vals = new double[]{Double.NEGATIVE_INFINITY, 0.0, 0.6931471805599453, 1.0986122886681098};
        LogTableFactor factor = LogTableFactor.makeFromLogValues(var, vals);
        factor.plusEquals(0.1);
        double[] expected = new double[]{-2.3025850929940455, 0.09531017980432493, 0.7419373447293773, 1.1314021114911006};
        LogTableFactor ans = LogTableFactor.makeFromLogValues(var, expected);
        TestLogTableFactor.assertTrue("Error: expected " + ans.dumpToString() + " but was " + factor.dumpToString(), factor.almostEquals(ans));
    }

    /** After recenter(), the largest log value is 0 and the others are shifted by the same amount. */
    public void testRecenter() {
        Variable var = new Variable(4);
        double[] vals = new double[]{2.0, 4.0, 6.0, 8.0};
        LogTableFactor ltbl1 = LogTableFactor.makeFromValues(var, vals);
        ltbl1.recenter();
        double[] expected = new double[]{Math.log(0.25), Math.log(0.5), Math.log(0.75), 0.0};
        LogTableFactor ans = LogTableFactor.makeFromLogValues(var, expected);
        TestLogTableFactor.assertTrue("Error: expected " + ans.dumpToString() + " but was " + ltbl1.dumpToString(), ans.almostEquals(ltbl1));
    }

    /** recenter() with a tied maximum produces no NaNs and shifts all entries by that maximum. */
    public void testRecenter2() {
        Variable var = new Variable(4);
        double[] vals = new double[]{0.0, 1.4, 1.4, 0.0};
        LogTableFactor ltbl1 = LogTableFactor.makeFromLogValues(var, vals);
        ltbl1.recenter();
        double[] expected = new double[]{-1.4, 0.0, 0.0, -1.4};
        LogTableFactor ans = LogTableFactor.makeFromLogValues(var, expected);
        TestLogTableFactor.assertTrue(!ltbl1.isNaN());
        TestLogTableFactor.assertTrue("Error: expected " + ans.dumpToString() + " but was " + ltbl1.dumpToString(), ans.almostEquals(ltbl1));
    }

    /** Builds a suite containing every {@code test*} method of this class. */
    public static Test suite() {
        return new TestSuite(TestLogTableFactor.class);
    }

    /**
     * Runs the named test methods, or the whole suite when no names are given.
     *
     * @param args optional test-method names to run
     */
    public static void main(String[] args) throws Throwable {
        TestSuite theSuite;
        if (args.length > 0) {
            theSuite = new TestSuite();
            for (String testName : args) {
                theSuite.addTest(new TestLogTableFactor(testName));
            }
        } else {
            theSuite = (TestSuite)TestLogTableFactor.suite();
        }
        TestRunner.run(theSuite);
    }
}

