/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.bahsic;

import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.Kernel;
import java.util.Vector;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;

/**
 * Estimates the Hilbert-Schmidt Independence Criterion (HSIC) between a data matrix and a fixed
 * target matrix {@code y}.
 *
 * <p>The kernel matrix {@code L} of {@code y} (under {@code kernely}) is computed once in the
 * constructor, together with its row, column and total sums. Callers then supply kernel matrices
 * {@code K} for the other variable to {@link #biasedHSIC} or {@link #unbiasedHSIC}.
 */
public class HSICEstimator {
    /** Kernel supplied for the X side; stored here but not read by any method of this class. */
    Kernel kernelx;
    /** Kernel used to build {@link #lMatrix} from {@link #y}. */
    Kernel kernely;
    /** Kernel matrix of {@code y}; its diagonal is zeroed when constructed with {@code biased == false}. */
    RealMatrix lMatrix;
    /** Target data; one example per row. */
    RealMatrix y;
    /** Number of rows (examples) in {@link #y}; also the order of {@link #lMatrix}. */
    int yrows;
    /** Number of columns in {@link #y}. */
    int ycols;
    /** Sum of all entries of {@link #lMatrix}. */
    double ltotalSum;
    /** {@code lrowSums.get(i)} is the sum of row {@code i} of {@link #lMatrix}. */
    Vector<Double> lrowSums = new Vector<Double>();
    /** {@code lcolSums.get(j)} is the sum of column {@code j} of {@link #lMatrix}. */
    Vector<Double> lcolSums = new Vector<Double>();

    /**
     * Builds the estimator and precomputes the kernel matrix of {@code y}.
     *
     * @param kernelx kernel for the X side (stored only)
     * @param kernely kernel applied to the rows of {@code y} to build {@code L}
     * @param y target data, one example per row
     * @param biased if {@code false}, the diagonal of {@code L} is set to zero
     */
    public HSICEstimator(Kernel kernelx, Kernel kernely, RealMatrix y, boolean biased) {
        this.kernelx = kernelx;
        this.kernely = kernely;
        this.y = y;
        this.yrows = y.getRowDimension();
        this.ycols = y.getColumnDimension();
        // Every column of y takes part in the kernel evaluation.
        Vector<Boolean> selectedIndices = new Vector<Boolean>(this.ycols);
        for (int j = 0; j < this.ycols; ++j) {
            selectedIndices.add(Boolean.TRUE);
        }
        this.lMatrix = this.computeKernelMatrix(y, kernely, selectedIndices);
        if (!biased) {
            // unbiasedHSIC works with zero-diagonal kernel matrices.
            for (int i = 0; i < this.yrows; ++i) {
                this.lMatrix.setEntry(i, i, 0.0);
            }
        }
        this.ltotalSum = 0.0;
        for (int i = 0; i < this.yrows; ++i) {
            double rowSum = 0.0;
            double colSum = 0.0;
            for (int j = 0; j < this.yrows; ++j) {
                rowSum += this.lMatrix.getEntry(i, j);
                colSum += this.lMatrix.getEntry(j, i);
            }
            this.ltotalSum += rowSum;
            this.lrowSums.add(rowSum);
            this.lcolSums.add(colSum);
        }
    }

    /**
     * Computes {@code H L H} by explicit matrix multiplication.
     *
     * <p>{@code H} is the centering matrix with {@code 1 - 1/n} on the diagonal and {@code -1/n}
     * elsewhere, where {@code n = yrows}. Costs two dense n-by-n multiplications; see
     * {@link #computeHLHFast()} for an O(n^2) equivalent.
     *
     * @return the doubly-centered kernel matrix {@code H L H}
     * @throws OperatorException never thrown here; declared for API compatibility
     */
    public RealMatrix computeHLH() throws OperatorException {
        int n = this.yrows;
        double offDiagonal = -1.0 / (double) n;
        RealMatrix hMatrix = MatrixUtils.createRealMatrix(n, n);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                hMatrix.setEntry(i, j, i == j ? 1.0 + offDiagonal : offDiagonal);
            }
        }
        return hMatrix.multiply(this.lMatrix.multiply(hMatrix));
    }

    /**
     * Computes {@code H L H} from the precomputed sums of {@code L} without any matrix product.
     *
     * <p>Uses the identity {@code (HLH)_ij = L_ij - (r_i + c_j)/n + S/n^2}. Here {@code r_i} and
     * {@code c_j} are the row/column sums of {@code L} and {@code S} is its total sum.
     *
     * @return the doubly-centered kernel matrix {@code H L H}
     * @throws OperatorException never thrown here; declared for API compatibility
     */
    public RealMatrix computeHLHFast() throws OperatorException {
        int n = this.yrows;
        double invN = 1.0 / (double) n;
        // Evaluate n*n in double so very large n cannot overflow int.
        double totalTerm = this.ltotalSum * (1.0 / ((double) n * (double) n));
        RealMatrix result = MatrixUtils.createRealMatrix(n, n);
        for (int i = 0; i < n; ++i) {
            double rowSum = this.lrowSums.get(i);
            for (int j = 0; j < n; ++j) {
                double outerTerm = (rowSum + this.lcolSums.get(j)) * invN;
                result.setEntry(i, j, this.lMatrix.getEntry(i, j) - outerTerm + totalTerm);
            }
        }
        return result;
    }

    /**
     * Builds the symmetric kernel matrix {@code K_ij = kernel(x_i, x_j)} over the rows of {@code x}.
     *
     * <p>The first row argument always uses all column indices of {@code x}. The second uses the
     * indices produced by {@link #getYIndices(Vector)} from {@code selectedIndices}. Only the lower
     * triangle is evaluated; values are mirrored into the upper triangle.
     *
     * @param x data matrix, one example per row
     * @param kernel kernel function to evaluate
     * @param selectedIndices per-column selection flags used for the second argument's indices
     * @return an {@code xrows}-by-{@code xrows} kernel matrix
     */
    public RealMatrix computeKernelMatrix(RealMatrix x, Kernel kernel, Vector<Boolean> selectedIndices) {
        int xrows = x.getRowDimension();
        int xcols = x.getColumnDimension();
        int[] xIndex = new int[xcols];
        for (int c = 0; c < xcols; ++c) {
            xIndex[c] = c;
        }
        int[] yIndex = this.getYIndices(selectedIndices);
        RealMatrix kernelMatrix = MatrixUtils.createRealMatrix(xrows, xrows);
        for (int i = 0; i < xrows; ++i) {
            // Row i is loop-invariant for the inner loop; fetch its copy once.
            // NOTE(review): assumes calculate_K does not modify its array arguments.
            double[] rowI = x.getRow(i);
            for (int j = 0; j <= i; ++j) {
                double kvalue = kernel.calculate_K(xIndex, rowI, yIndex, x.getRow(j));
                kernelMatrix.setEntry(i, j, kvalue);
                kernelMatrix.setEntry(j, i, kvalue);
            }
        }
        return kernelMatrix;
    }

    /**
     * Computes the biased HSIC estimate {@code trace(K * HLH) / (m - 1)^2}.
     *
     * @param kMatrix square kernel matrix of the other variable
     * @param hlhMatrix the centered matrix {@code H L H} (e.g. from {@link #computeHLHFast()})
     * @return the biased HSIC estimate
     * @throws OperatorException if the matrices are not square of equal order, or if there are
     *     fewer than two examples (the normalizer {@code (m - 1)^2} would be zero)
     */
    public double biasedHSIC(RealMatrix kMatrix, RealMatrix hlhMatrix) throws OperatorException {
        int m = kMatrix.getRowDimension();
        if (m != kMatrix.getColumnDimension() || m != hlhMatrix.getRowDimension() || m != hlhMatrix.getColumnDimension()) {
            throw new OperatorException("Dimensions do not match");
        }
        if (m < 2) {
            throw new OperatorException("Biased HSIC requires at least two examples");
        }
        double mMinusOne = (double) m - 1.0;
        return kMatrix.multiply(hlhMatrix).getTrace() / (mMinusOne * mMinusOne);
    }

    /**
     * Computes the unbiased HSIC estimate from zero-diagonal kernel matrices {@code K~} and
     * {@code L~}.
     *
     * <p>The estimate is
     * {@code [tr(K~L~) + (1'K~1)(1'L~1)/((m-1)(m-2)) - 2/(m-2) * 1'K~L~1] / (m(m-3))}.
     * The caller's {@code kMatrix} is not modified; a zero-diagonal copy is used instead. If this
     * estimator was constructed with {@code biased == true}, a zero-diagonal copy of {@code L} is
     * derived here as well, so the result is correct either way.
     *
     * @param kMatrix square kernel matrix of the other variable, of order {@code yrows}
     * @return the unbiased HSIC estimate
     * @throws OperatorException if {@code kMatrix} is not square of order {@code yrows}, or if
     *     there are fewer than four examples (the normalizer {@code m(m-3)} would be non-positive)
     */
    public double unbiasedHSIC(RealMatrix kMatrix) throws OperatorException {
        int m = kMatrix.getRowDimension();
        if (m != kMatrix.getColumnDimension() || m != this.yrows) {
            throw new OperatorException("Dimensions do not match");
        }
        if (m < 4) {
            throw new OperatorException("Unbiased HSIC requires at least four examples");
        }
        // Zero the diagonal on a copy so the caller's matrix is left untouched.
        RealMatrix kTilde = kMatrix.copy();
        for (int i = 0; i < m; ++i) {
            kTilde.setEntry(i, i, 0.0);
        }
        // lMatrix already has a zero diagonal when built with biased == false; otherwise derive one.
        RealMatrix lTilde = this.lMatrix;
        double lTotal = this.ltotalSum;
        if (!hasZeroDiagonal(lTilde)) {
            lTilde = lTilde.copy();
            for (int i = 0; i < m; ++i) {
                lTotal -= lTilde.getEntry(i, i);
                lTilde.setEntry(i, i, 0.0);
            }
        }
        RealMatrix klMatrix = kTilde.multiply(lTilde);
        double kltrace = klMatrix.getTrace();
        double ktotalSum = 0.0;
        double kltotalSum = 0.0;
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < m; ++j) {
                ktotalSum += kTilde.getEntry(i, j);
                kltotalSum += klMatrix.getEntry(i, j);
            }
        }
        double md = (double) m;
        return (kltrace + ktotalSum * lTotal / ((md - 1.0) * (md - 2.0)) - 2.0 * kltotalSum / (md - 2.0)) / (md * (md - 3.0));
    }

    /** Returns true when every diagonal entry of the square matrix {@code a} is exactly zero. */
    private static boolean hasZeroDiagonal(RealMatrix a) {
        int n = a.getRowDimension();
        for (int i = 0; i < n; ++i) {
            if (a.getEntry(i, i) != 0.0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the outer product {@code x x^T} of a column vector.
     *
     * @param x an n-by-1 matrix
     * @return the symmetric n-by-n matrix with entries {@code x_i * x_j}
     * @throws OperatorException if {@code x} has more or fewer than one column
     */
    public RealMatrix outerProduct(RealMatrix x) throws OperatorException {
        if (x.getColumnDimension() != 1) {
            throw new OperatorException("Vector expected, not a matrix");
        }
        int rows = x.getRowDimension();
        RealMatrix outer = MatrixUtils.createRealMatrix(rows, rows);
        for (int i = 0; i < rows; ++i) {
            double xi = x.getEntry(i, 0);
            for (int j = 0; j <= i; ++j) {
                double xj = x.getEntry(j, 0);
                outer.setEntry(i, j, xi * xj);
                outer.setEntry(j, i, xj * xi);
            }
        }
        return outer;
    }

    /**
     * Maps per-column selection flags to the index array passed to {@code Kernel.calculate_K}.
     *
     * <p>Selected position {@code j} maps to {@code j}. Unselected positions map to the
     * out-of-range value {@code size + 1} (presumably so the kernel treats them as non-matching;
     * NOTE(review): confirm against the jmysvm {@code Kernel} implementation).
     *
     * @param selectedIndices selection flag per column; must not contain {@code null}
     * @return index array of the same length as {@code selectedIndices}
     */
    public int[] getYIndices(Vector<Boolean> selectedIndices) {
        int size = selectedIndices.size();
        int unselected = size + 1;
        int[] yIndex = new int[size];
        for (int j = 0; j < size; ++j) {
            yIndex[j] = selectedIndices.get(j) ? j : unselected;
        }
        return yIndex;
    }
}

