/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.test;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.BetaFactor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.ModelReader;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
import gnu.trove.TDoubleArrayList;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * JUnit 3 tests for {@link BetaFactor}: its variable set, point evaluation, sampling (with the
 * three- and five-argument constructors), and slicing of a factor graph read by
 * {@link ModelReader}.
 */
public class TestBetaFactor extends TestCase {
    /** Model text for {@link ModelReader}: two continuous variables, each with a Beta factor. */
    static String mdlstr = "VAR u1 u2 : continuous\nu1 ~ Beta 0.2 0.7\nu2 ~ Beta 1.0 0.3\n";

    /**
     * Creates a test case that runs the test method with the given name.
     *
     * @param name name of the test method to run
     */
    public TestBetaFactor(String name) {
        super(name);
    }

    /** A factor built over a single variable reports exactly that one variable. */
    public void testVarSet() {
        final Variable v = new Variable(-1);
        final BetaFactor factor = new BetaFactor(v, 0.5, 0.5);
        assertEquals(1, factor.varSet().size());
        assertTrue(factor.varSet().contains(v));
    }

    /** Point evaluation: 0.94321 is the Beta(1.0, 1.2) density at x = 0.7. */
    public void testValue() {
        final Variable v = new Variable(-1);
        final BetaFactor factor = new BetaFactor(v, 1.0, 1.2);
        assertEquals(0.94321, factor.value(new Assignment(v, 0.7)), 1.0E-5);
    }

    /**
     * The empirical mean of 100000 seeded draws from Beta(0.7, 0.5) should lie within 0.01 of the
     * analytic mean 0.7 / (0.7 + 0.5).
     */
    public void testSample() {
        final Variable v = new Variable(-1);
        final BetaFactor factor = new BetaFactor(v, 0.7, 0.5);
        assertEquals(0.5833333333333334, sampleMean(factor, v, 2343, 100000), 0.01);
    }

    /**
     * Sampling with the five-argument constructor. The expected mean 5.92 agrees (within the 0.01
     * tolerance) with 3.0 + 5.0 * 0.7 / 1.2 ~= 5.9167, i.e. Beta(0.7, 0.5) draws rescaled onto
     * [3.0, 8.0]. NOTE(review): presumably the last two arguments are the support bounds; confirm
     * against BetaFactor.
     */
    public void testSample2() {
        final Variable v = new Variable(-1);
        final BetaFactor factor = new BetaFactor(v, 0.7, 0.5, 3.0, 8.0);
        assertEquals(5.92, sampleMean(factor, v, 2343, 100000), 0.01);
    }

    /**
     * Reads {@link #mdlstr}, slices the graph at u1 = 0.25, u2 = 0.85, and checks that the sliced
     * graph has two factors and the expected value under an empty assignment. The constant
     * 0.6708463722 matches the Beta(0.2, 0.7) density at 0.25 times the Beta(1.0, 0.3) density
     * at 0.85.
     *
     * @throws IOException if the model text cannot be read
     */
    public void testSliceInFg() throws IOException {
        final FactorGraph model =
                new ModelReader().readModel(new BufferedReader(new StringReader(mdlstr)));
        final Variable u1 = model.findVariable("u1");
        final Variable u2 = model.findVariable("u2");
        final Assignment observed =
                new Assignment(new Variable[] {u1, u2}, new double[] {0.25, 0.85});
        final FactorGraph sliced = (FactorGraph) model.slice(observed);
        assertEquals(2, sliced.factors().size());
        assertEquals(0.6708463722, sliced.value(new Assignment()), 1.0E-5);
    }

    /**
     * Draws {@code count} samples from {@code factor} with a {@link Randoms} seeded by
     * {@code seed}, reads the value each sample assigns to {@code v}, and returns their mean.
     */
    private static double sampleMean(BetaFactor factor, Variable v, int seed, int count) {
        final Randoms rng = new Randoms(seed);
        final TDoubleArrayList draws = new TDoubleArrayList();
        for (int n = 0; n < count; n++) {
            draws.add(factor.sample(rng).getDouble(v));
        }
        return MatrixOps.mean(draws.toNativeArray());
    }

    /** Builds a suite containing every {@code test*} method of this class. */
    public static TestSuite suite() {
        return new TestSuite(TestBetaFactor.class);
    }

    /**
     * Runs the named test methods, or the whole suite when no names are given.
     *
     * @param args optional test-method names to run
     */
    public static void main(String[] args) {
        final TestSuite toRun;
        if (args.length == 0) {
            toRun = suite();
        } else {
            toRun = new TestSuite();
            for (String testName : args) {
                toRun.addTest(new TestBetaFactor(testName));
            }
        }
        TestRunner.run(toRun);
    }
}

