/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.factorAnalysis;

import dr.evomodel.continuous.GibbsSampleFromTreeInterface;
import dr.inference.model.LatentFactorModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.PathDependent;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.SymmetricMatrix;

/**
 * Gibbs operator that redraws columns of the factor matrix of a
 * {@link LatentFactorModel}.
 *
 * <p>For a column {@code c} the draw comes from a multivariate normal whose
 * precision and mean are assembled as follows, where {@code L} is the loadings
 * matrix, {@code E} the diagonal of the column-precision matrix, {@code y} the
 * column {@code c} of the scaled data and {@code t} the path parameter:
 * <pre>
 *   precision = treePrec + t * sum over non-missing rows k of L[k,.]' E[k] L[k,.]
 *   mean      = precision^-1 * (treePrec * treeMean + t * sum over non-missing rows k of y[k] E[k] L[k,.])
 * </pre>
 * Rows are treated as missing when the model's missing indicator (if any)
 * holds exactly {@code 1.0} at the corresponding flat index.
 */
public class FactorTreeGibbsOperator
extends SimpleMCMCOperator
implements PathDependent,
GibbsOperator {
    private final LatentFactorModel lfm;
    /** Scales the data terms, and the tree/working-tree blend when a working tree is present. Starts at 1.0. */
    private double pathParameter = 1.0;
    private final GibbsSampleFromTreeInterface tree;
    /** Optional second tree blended in through {@link #pathParameter}; the only constructor leaves it {@code null}. */
    private final GibbsSampleFromTreeInterface workingTree;
    private final MatrixParameterInterface factors;
    private final MatrixParameterInterface errorPrec;
    /** {@code true}: update one randomly chosen column per call; {@code false}: sweep all columns. */
    private final boolean randomScan;
    /** May be {@code null}, in which case no entry is treated as missing. */
    private final Parameter missingIndicator;

    /**
     * @param d operator weight, passed to {@code setWeight}
     * @param latentFactorModel model supplying factors, loadings, column precision, scaled data and missing indicator
     * @param gibbsSampleFromTreeInterface source of the per-column precision factor and conditional mean
     * @param bl random-scan flag; it is auto-unboxed, so it must not be {@code null}
     */
    public FactorTreeGibbsOperator(double d, LatentFactorModel latentFactorModel, GibbsSampleFromTreeInterface gibbsSampleFromTreeInterface, Boolean bl) {
        this.setWeight(d);
        this.tree = gibbsSampleFromTreeInterface;
        this.lfm = latentFactorModel;
        this.factors = latentFactorModel.getFactors();
        this.errorPrec = latentFactorModel.getColumnPrecision();
        this.randomScan = bl;
        this.workingTree = null;
        this.missingIndicator = latentFactorModel.getMissingIndicator();
    }

    @Override
    public String getOperatorName() {
        return "Factor Tree Gibbs Operator";
    }

    /**
     * Redraws either one uniformly chosen column (random scan) or every column
     * in index order.
     *
     * @return always {@code 0.0}
     */
    @Override
    public double doOperation() {
        final int columnCount = this.factors.getColumnDimension();
        if (this.randomScan) {
            this.sampleColumn(MathUtils.nextInt(columnCount));
        } else {
            for (int column = 0; column < columnCount; ++column) {
                this.sampleColumn(column);
            }
        }
        return 0.0;
    }

    /** Draws a fresh value for one factor column and writes it into the factor matrix. */
    private void sampleColumn(int column) {
        double[] draw = (double[])this.getMVN(column).nextRandom();
        final int rowCount = this.factors.getRowDimension();
        for (int row = 0; row < rowCount; ++row) {
            this.factors.setParameterValue(row, column, draw[row]);
        }
    }

    /** Tells whether the missing indicator (if present) flags the given flat index with exactly 1.0. */
    private boolean isMissing(int index) {
        return this.missingIndicator != null && this.missingIndicator.getParameterValue(index) == 1.0;
    }

    /**
     * Builds the sampling distribution for column {@code n}.
     *
     * @param n column index of the factor matrix
     * @return distribution constructed from the mean and precision of that column
     */
    MultivariateNormalDistribution getMVN(int n) {
        double[][] dArray = this.getPrecision(n);
        double[] dArray2 = this.getMean(n, dArray);
        return new MultivariateNormalDistribution(dArray2, dArray);
    }

    /**
     * Computes the precision for column {@code n}: the tree precision plus the
     * path-scaled contribution of every non-missing loadings row.
     *
     * @param n column index
     * @return symmetric precision matrix (a freshly allocated array)
     */
    double[][] getPrecision(int n) {
        double[][] precision = this.getTreePrec(n);
        final MatrixParameterInterface loadings = this.lfm.getLoadings();
        final int loadingRows = loadings.getRowDimension();
        final int loadingCols = loadings.getColumnDimension();
        // Flat offset of column n inside the missing indicator.
        final int missingOffset = n * loadingRows;
        for (int i = 0; i < loadingCols; ++i) {
            // Only the upper triangle is accumulated; the lower one is mirrored below.
            for (int j = i; j < loadingCols; ++j) {
                for (int row = 0; row < loadingRows; ++row) {
                    if (this.isMissing(missingOffset + row)) continue;
                    precision[i][j] += loadings.getParameterValue(row, i) * this.errorPrec.getParameterValue(row, row) * loadings.getParameterValue(row, j) * this.pathParameter;
                }
                precision[j][i] = precision[i][j];
            }
        }
        return precision;
    }

    /**
     * Computes the mean for column {@code n} given its precision.
     *
     * @param n column index
     * @param dArray precision matrix for that column, as returned by {@link #getPrecision(int)}
     * @return inverse of {@code dArray} applied to the precision-weighted sum of
     *         the tree mean and the non-missing data
     */
    double[] getMean(int n, double[][] dArray) {
        SymmetricMatrix variance = new SymmetricMatrix(dArray).inverse();
        final MatrixParameterInterface loadings = this.lfm.getLoadings();
        final int loadingRows = loadings.getRowDimension();
        final int loadingCols = loadings.getColumnDimension();
        double[] weightedSum = new double[loadingCols];
        double[] treeMean = this.getTreeMean(n);
        double[][] treePrec = this.getTreePrec(n);
        // Full matrix-vector product: the tree precision is diagonal only when no
        // working tree is blended in, so the off-diagonal terms must not be dropped.
        for (int i = 0; i < weightedSum.length; ++i) {
            for (int j = 0; j < weightedSum.length; ++j) {
                weightedSum[i] += treePrec[i][j] * treeMean[j];
            }
        }
        final MatrixParameterInterface data = this.lfm.getScaledData();
        final int missingOffset = n * data.getRowDimension();
        for (int row = 0; row < loadingRows; ++row) {
            // Missingness depends on the row only, so test it once per row.
            if (this.isMissing(missingOffset + row)) continue;
            double weightedDatum = data.getParameterValue(row, n) * this.errorPrec.getParameterValue(row, row);
            for (int col = 0; col < loadingCols; ++col) {
                weightedSum[col] += weightedDatum * loadings.getParameterValue(row, col) * this.pathParameter;
            }
        }
        double[] mean = new double[weightedSum.length];
        for (int i = 0; i < mean.length; ++i) {
            for (int j = 0; j < mean.length; ++j) {
                mean[i] += variance.component(i, j) * weightedSum[j];
            }
        }
        return mean;
    }

    /**
     * Returns the tree contribution to the precision of column {@code n}: the
     * tree's precision factor on the diagonal, linearly blended with the
     * working tree's conditional precision when a working tree is present.
     *
     * @param n column index
     * @return freshly allocated square matrix sized by the factor row dimension
     */
    public double[][] getTreePrec(int n) {
        double precisionFactor = this.tree.getPrecisionFactor(n);
        final int dim = this.factors.getRowDimension();
        double[][] treePrec = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            treePrec[i][i] = precisionFactor;
        }
        if (this.workingTree != null) {
            double[][] workingPrec = this.workingTree.getConditionalPrecision(n);
            final double workingWeight = 1.0 - this.pathParameter;
            for (int i = 0; i < treePrec.length; ++i) {
                for (int j = 0; j < treePrec.length; ++j) {
                    treePrec[i][j] = treePrec[i][j] * this.pathParameter + workingPrec[i][j] * workingWeight;
                }
            }
        }
        return treePrec;
    }

    /**
     * Returns the tree's conditional mean for column {@code n}, linearly
     * blended with the working tree's conditional mean when a working tree is
     * present.
     *
     * @param n column index
     * @return the tree's own array when there is no working tree, otherwise a new array
     */
    public double[] getTreeMean(int n) {
        double[] treeMean = this.tree.getConditionalMean(n);
        if (this.workingTree == null) {
            return treeMean;
        }
        double[] workingMean = this.workingTree.getConditionalMean(n);
        final double workingWeight = 1.0 - this.pathParameter;
        // Blend into a fresh array: the array handed out by the tree may be
        // owned by it, so it is not overwritten here.
        double[] blended = new double[treeMean.length];
        for (int i = 0; i < blended.length; ++i) {
            blended[i] = treeMean[i] * this.pathParameter + workingMean[i] * workingWeight;
        }
        return blended;
    }

    /**
     * @param d new path parameter; used as a multiplicative weight in the precision and mean computations
     */
    @Override
    public void setPathParameter(double d) {
        this.pathParameter = d;
    }
}

