/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.model;

import dr.evomodel.treedatalikelihood.hmc.AbstractPrecisionGradient;
import dr.inference.model.AbstractTransformedCompoundMatrix;
import dr.inference.model.Parameter;
import dr.inference.model.TransformedMultivariateParameter;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.util.CorrelationToCholesky;
import dr.util.Transform;

/**
 * A symmetric matrix assembled from a diagonal parameter and an off-diagonal parameter.
 *
 * <p>When {@code asCorrelation} is true, the off-diagonal parameter is read as correlations.
 * Entry (i, j), i != j, is then {@code offDiagonal[k(i, j)] * sqrt(diagonal[i] * diagonal[j])}.
 * Otherwise entry (i, j) is the off-diagonal value itself.
 *
 * <p>When {@code isCholesky} is true (which requires {@code asCorrelation}), a
 * {@link CorrelationToCholesky} transform is passed to the superclass. Correlation gradients are
 * then pushed through that transform.
 */
public class CompoundSymmetricMatrix
extends AbstractTransformedCompoundMatrix {
    private final boolean asCorrelation;
    private final boolean isCholesky;

    /**
     * Creates the compound symmetric matrix.
     *
     * @param diagonals     parameter holding the diagonal entries; its dimension is used as the
     *                      size of the Cholesky transform
     * @param offDiagonals  parameter holding the off-diagonal entries (correlations when
     *                      {@code asCorrelation} is true)
     * @param asCorrelation whether {@code offDiagonals} are correlations to be scaled by the diagonal
     * @param isCholesky    whether a correlation-to-Cholesky transform is applied; only valid when
     *                      {@code asCorrelation} is true (checked with {@code assert})
     */
    public CompoundSymmetricMatrix(Parameter diagonals, Parameter offDiagonals, boolean asCorrelation, boolean isCholesky) {
        // NOTE(review): the trailing `true` presumably sets isStrictlyUpperTriangular in the
        // superclass -- confirm in AbstractTransformedCompoundMatrix.
        super(diagonals, offDiagonals, CompoundSymmetricMatrix.getTransformation(diagonals.getDimension(), isCholesky), true);
        // The Cholesky transform operates on correlations, so it cannot be used without them.
        assert (asCorrelation || !isCholesky) : "isCholesky requires asCorrelation";
        this.asCorrelation = asCorrelation;
        this.isCholesky = isCholesky;
    }

    /**
     * Returns a {@link CorrelationToCholesky} transform of the given dimension when
     * {@code isCholesky} is true, and {@code null} otherwise.
     */
    private static Transform.MultivariableTransform getTransformation(int dimension, boolean isCholesky) {
        return isCholesky ? new CorrelationToCholesky(dimension) : null;
    }

    @Override
    public String toString() {
        return this.toStringCompoundParameter(CompoundSymmetricMatrix.getVechDimension(this.dim));
    }

    /** Number of strictly upper-triangular entries of an {@code n x n} matrix: n(n-1)/2. */
    private static int getVechuDimension(int n) {
        return n * (n - 1) / 2;
    }

    /** Number of upper-triangular entries (including the diagonal) of an {@code n x n} matrix: n(n+1)/2. */
    private static int getVechDimension(int n) {
        return n * (n + 1) / 2;
    }

    /**
     * Returns entry ({@code row}, {@code col}) of the assembled matrix.
     *
     * <p>Off-diagonal entries come from the off-diagonal parameter, scaled by
     * {@code sqrt(diagonal[row] * diagonal[col])} when {@code asCorrelation} is true.
     */
    @Override
    public double getParameterValue(int row, int col) {
        if (row != col) {
            double offDiagonal = this.offDiagonalParameter.getParameterValue(this.getUpperTriangularIndex(row, col));
            if (this.asCorrelation) {
                return offDiagonal * Math.sqrt(this.diagonalParameter.getParameterValue(row) * this.diagonalParameter.getParameterValue(col));
            }
            return offDiagonal;
        }
        if (this.isStrictlyUpperTriangular) {
            return this.diagonalParameter.getParameterValue(row);
        }
        return this.diagonalParameter.getParameterValue(row) * this.offDiagonalParameter.getParameterValue(this.getUpperTriangularIndex(row, row));
    }

    /**
     * Returns the full {@code dim x dim} matrix.
     *
     * <p>Only the diagonal and upper triangle are evaluated. Each upper-triangle value is mirrored
     * into the lower triangle.
     */
    @Override
    public double[][] getParameterAsMatrix() {
        int size = this.dim;
        double[][] matrix = new double[size][size];
        for (int i = 0; i < size; ++i) {
            matrix[i][i] = this.getParameterValue(i, i);
            for (int j = i + 1; j < size; ++j) {
                double value = this.getParameterValue(i, j);
                matrix[i][j] = value;
                matrix[j][i] = value;
            }
        }
        return matrix;
    }

    @Override
    public boolean isConstrainedSymmetric() {
        return true;
    }

    /** Returns whether a correlation-to-Cholesky transform is applied to the off-diagonal parameter. */
    public boolean isCholesky() {
        return this.isCholesky;
    }

    /** Returns whether the off-diagonal parameter is interpreted as correlations. */
    public boolean asCorrelation() {
        return this.asCorrelation;
    }

    /**
     * Builds the correlation matrix implied by the current parameter values.
     *
     * <p>When {@code asCorrelation} is false, the off-diagonal values are divided by
     * {@code sqrt(diagonal[i] * diagonal[j])} to turn them into correlations.
     */
    private double[][] getCorrelationMatrix() {
        SymmetricMatrix correlation = SymmetricMatrix.compoundCorrelationSymmetricMatrix(this.offDiagonalParameter.getParameterValues(), this.dim);
        if (!this.asCorrelation) {
            for (int i = 0; i < this.dim; ++i) {
                for (int j = i + 1; j < this.dim; ++j) {
                    correlation.setSymmetric(i, j, correlation.component(i, j) / Math.sqrt(this.diagonalParameter.getParameterValue(i) * this.diagonalParameter.getParameterValue(j)));
                }
            }
        }
        return correlation.toComponents();
    }

    /**
     * Maps a gradient with respect to the full matrix onto the off-diagonal parameter.
     *
     * @param gradient row-major {@code dim x dim} gradient with respect to the matrix entries
     * @return gradient with respect to the off-diagonal parameter: the strictly upper triangle in
     *         row-major order, pushed through {@link #updateGradientCorrelation(double[])}
     */
    @Override
    public double[] updateGradientOffDiagonal(double[] gradient) {
        assert (gradient.length == this.dim * this.dim);
        double[] diagonal = this.diagonalParameter.getParameterValues();
        double[] correlationGradient = new double[CompoundSymmetricMatrix.getVechuDimension(this.dim)];
        int k = 0;
        for (int i = 0; i < this.dim - 1; ++i) {
            for (int j = i + 1; j < this.dim; ++j) {
                // The factor 2 accounts for both (i, j) and (j, i) depending on the same parameter.
                // NOTE(review): the sqrt(d_i * d_j) scaling matches the asCorrelation
                // parameterization; confirm it is intended when asCorrelation is false.
                correlationGradient[k] = 2.0 * gradient[i * this.dim + j] * Math.sqrt(diagonal[i] * diagonal[j]);
                ++k;
            }
        }
        return this.updateGradientCorrelation(correlationGradient);
    }

    /**
     * Scales every entry (i, j) of a row-major {@code dim x dim} gradient, including the diagonal,
     * by {@code sqrt(diagonal[i] * diagonal[j])}.
     *
     * @param gradient row-major {@code dim x dim} gradient
     * @return a new array of the same length with the scaled entries
     */
    public double[] updateGradientFullOffDiagonal(double[] gradient) {
        assert (gradient.length == this.dim * this.dim);
        double[] diagonal = this.diagonalParameter.getParameterValues();
        double[] scaled = new double[gradient.length];
        for (int i = 0; i < this.dim; ++i) {
            for (int j = 0; j < this.dim; ++j) {
                int index = i * this.dim + j;
                scaled[index] = gradient[index] * Math.sqrt(diagonal[i] * diagonal[j]);
            }
        }
        return scaled;
    }

    /**
     * Converts a gradient with respect to the correlations into one with respect to the stored
     * off-diagonal parameter.
     *
     * <p>Without the Cholesky transform the input is returned unchanged (same array). With it, the
     * gradient is pushed through {@link CorrelationToCholesky}. In that case the off-diagonal
     * parameter is cast to {@link TransformedMultivariateParameter}; this assumes the superclass
     * wrapped it when the transform was supplied.
     */
    public double[] updateGradientCorrelation(double[] gradient) {
        if (!this.isCholesky) {
            return gradient;
        }
        CorrelationToCholesky transform = new CorrelationToCholesky(this.dim);
        double[] untransformed = ((TransformedMultivariateParameter)this.offDiagonalParameter).getParameterUntransformedValues();
        return transform.updateGradientInverseUnWeightedLogDensity(gradient, untransformed, 0, gradient.length);
    }

    /**
     * Maps a gradient with respect to the full matrix onto the diagonal parameter.
     *
     * <p>Computes {@code sum_j G[i][j] * sqrt(d_j / d_i) * R[i][j]}, where R is the correlation
     * matrix from {@link #getCorrelationMatrix()}. The j == i term contributes {@code G[i][i]},
     * assuming R has a unit diagonal (TODO confirm for compoundCorrelationSymmetricMatrix).
     *
     * @param gradient row-major {@code dim x dim} gradient with respect to the matrix entries
     * @return gradient with respect to each diagonal entry
     */
    @Override
    public double[] updateGradientDiagonal(double[] gradient) {
        assert (gradient.length == this.dim * this.dim);
        double[] diagonal = this.diagonalParameter.getParameterValues();
        double[] correlation = AbstractPrecisionGradient.flatten(this.getCorrelationMatrix());
        double[] diagonalGradient = new double[this.dim];
        for (int i = 0; i < this.dim; ++i) {
            double sum = 0.0;
            for (int j = 0; j < this.dim; ++j) {
                sum += gradient[i * this.dim + j] * Math.sqrt(diagonal[j] / diagonal[i]) * correlation[i * this.dim + j];
            }
            diagonalGradient[i] = sum;
        }
        return diagonalGradient;
    }

    /** Returns the assembled matrix rendered via {@link WrappedMatrix.ArrayOfArray#toString()}. */
    @Override
    public String getReport() {
        return new WrappedMatrix.ArrayOfArray(this.getParameterAsMatrix()).toString();
    }
}

