/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.test;

import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.PottsTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * Unit tests for {@link PottsTableFactor}: slicing the factor on its parameter variable,
 * and the gradient of the log-factor with respect to that parameter ({@code sumGradLog}).
 */
public class TestPottsFactor
extends TestCase {
    /** Absolute tolerance used when comparing computed gradients. */
    private static final double TOLERANCE = 1.0E-5;

    /** Factor under test, built over {@link #vars} and parameterised by {@link #alpha}. */
    private PottsTableFactor factor;
    /** Parameter variable of the Potts factor. */
    private Variable alpha;
    /** The two 2-outcome variables the factor is defined over. */
    private VarSet vars;

    public TestPottsFactor(String name) {
        super(name);
    }

    /** Builds a suite containing every {@code test*} method of this class. */
    public static TestSuite suite() {
        return new TestSuite(TestPottsFactor.class);
    }

    @Override
    protected void setUp() throws Exception {
        // NOTE(review): -1 outcomes presumably marks alpha as a continuous variable — confirm
        // against Variable's constructor documentation.
        this.alpha = new Variable(-1);
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        this.vars = new HashVarSet(new Variable[]{v1, v2});
        this.factor = new PottsTableFactor(this.vars, this.alpha);
    }

    /** Fixing alpha = 1 must yield a table factor over {@link #vars} with the expected weights. */
    public void testSlice() {
        Assignment assn = new Assignment(alpha, 1.0);
        Factor sliced = factor.slice(assn);

        assertTrue("slice should produce an AbstractTableFactor, got " + sliced.getClass().getName(),
                sliced instanceof AbstractTableFactor);
        assertTrue("sliced factor should range over the original variables",
                sliced.varSet().equals(vars));

        // Expected table for alpha = 1: entries 0 and 3 are 1.0, entries 1 and 2 are exp(-1)
        // (presumably the agreeing vs. disagreeing joint configurations — confirm ordering).
        TableFactor expected = new TableFactor(vars, new double[]{1.0, Math.exp(-1.0), Math.exp(-1.0), 1.0});
        assertTrue("sliced factor does not match the expected Potts table",
                sliced.almostEquals(expected));
    }

    /** Gradient of log-factor w.r.t. alpha, weighted by a distribution over exactly {@link #vars}. */
    public void testSumGradLog() {
        Assignment alphaAssn = new Assignment(alpha, 1.0);
        TableFactor q = new TableFactor(vars, new double[]{0.4, 0.1, 0.3, 0.2});

        double grad = factor.sumGradLog(q, alpha, alphaAssn);

        // -0.4 is the negated q-mass on entries 1 and 2 (0.1 + 0.3), the entries that
        // testSlice expects to take the value exp(-alpha).
        assertEquals(-0.4, grad, TOLERANCE);
    }

    /**
     * Same as {@link #testSumGradLog()}, but q also ranges over an extra variable outside the
     * factor's scope; the expected gradient is unchanged.
     */
    public void testSumGradLog2() {
        Assignment alphaAssn = new Assignment(alpha, 1.0);
        TableFactor q1 = new TableFactor(vars, new double[]{0.4, 0.1, 0.3, 0.2});
        // The extra variable's table sums to 1, so q still sums to 1 overall.
        TableFactor q2 = new TableFactor(new Variable(2), new double[]{0.7, 0.3});
        Factor q = q1.multiply(q2);

        double grad = factor.sumGradLog(q, alpha, alphaAssn);

        assertEquals(-0.4, grad, TOLERANCE);
    }

    /**
     * Runs the named test methods, or the whole suite when no names are given.
     *
     * @param args optional test-method names to run
     */
    public static void main(String[] args) {
        TestSuite theSuite;
        if (args.length > 0) {
            theSuite = new TestSuite();
            for (String testName : args) {
                theSuite.addTest(new TestPottsFactor(testName));
            }
        } else {
            theSuite = suite();
        }
        TestRunner.run(theSuite);
    }
}

