/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.PriorPreconditioningProvider;
import dr.inference.model.Variable;
import dr.inference.operators.hmc.MassPreconditioningOptions;
import dr.inference.operators.hmc.SecantHessian;
import dr.math.AdaptableCovariance;
import dr.math.AdaptableVector;
import dr.math.MachineAccuracy;
import dr.math.MathUtils;
import dr.math.MultivariateFunction;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.RobustEigenDecomposition;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public interface MassPreconditioner {
    /**
     * Draws a momentum vector with which to start a new trajectory.
     *
     * @return a freshly drawn momentum vector
     */
    public WrappedVector drawInitialMomentum();

    /**
     * Returns one component of the velocity implied by a momentum vector under
     * the current inverse mass.
     *
     * @param var1 index of the velocity component to return
     * @param var2 full momentum vector
     * @return the requested velocity component
     */
    public double getVelocity(int var1, ReadableVector var2);

    /**
     * Offers a gradient/position pair to preconditioners that adapt from
     * trajectory information. Implementations may ignore either argument or both.
     *
     * @param var1 gradient vector
     * @param var2 position vector
     */
    public void storeSecant(ReadableVector var1, ReadableVector var2);

    /**
     * Recomputes the inverse mass from whatever information the implementation
     * has accumulated.
     */
    public void updateMass();

    /**
     * Returns the mass (as opposed to the inverse mass). Several implementations
     * do not support this and throw a RuntimeException instead.
     *
     * @return the mass vector
     */
    public WrappedVector getMass();

    /**
     * Feeds a position sample to variance-based adaptation. Several
     * implementations treat this as a no-op.
     *
     * @param var1 position sample
     */
    public void updateVariance(WrappedVector var1);

    /**
     * Returns the momentum after a collision between the listed coordinates.
     * Not every implementation supports this; unsupported cases throw a
     * RuntimeException.
     *
     * @param var1 indices of the colliding coordinates
     * @param var2 momentum before the collision
     * @return momentum after the collision
     */
    public ReadableVector doCollision(int[] var1, ReadableVector var2);

    /**
     * Returns the number of coordinates this preconditioner covers.
     *
     * @return the preconditioned dimension
     */
    public int getDimension();

    /**
     * Common base for preconditioners that are also models. The instance is
     * registered under the model name "MassPreconditioning", and its inverse mass
     * is held in a Parameter that subclasses allocate.
     */
    public static abstract class AbstractMassPreconditioning
    extends AbstractModel
    implements MassPreconditioner {
        protected final int dim;
        protected final Transform transform;
        protected Parameter inverseMass;
        private static final String PRECONDITIONING = "MassPreconditioning";
        protected static final String MASSNAME = "InverseMass";

        /**
         * @param dim       number of coordinates being preconditioned
         * @param transform optional transform used by subclasses (may be null)
         */
        protected AbstractMassPreconditioning(int dim, Transform transform) {
            super(PRECONDITIONING);
            this.transform = transform;
            this.dim = dim;
        }

        /** Establishes the starting inverse mass; the layout is subclass-specific. */
        protected abstract void initializeMass();

        /** Recomputes the inverse mass from the state the subclass keeps. */
        protected abstract void computeInverseMass();

        /** Forwards to {@link #computeInverseMass()}. */
        @Override
        public void updateMass() {
            computeInverseMass();
        }

        @Override
        public int getDimension() {
            return dim;
        }

        @Override
        public abstract void storeSecant(ReadableVector var1, ReadableVector var2);

        /**
         * Writes every entry of {@code inverseMassArray} into the inverse-mass
         * parameter through its regular (non-quiet) setter, one entry at a time.
         *
         * @param inverseMassArray values to copy; its length sets how many
         *                         entries are written
         */
        protected void setInverseMassFromArray(double[] inverseMassArray) {
            for (int slot = 0; slot < inverseMassArray.length; slot++) {
                inverseMass.setParameterValue(slot, inverseMassArray[slot]);
            }
        }

        // The model-lifecycle hooks below are intentionally left empty here.

        protected void handleModelChangedEvent(Model model, Object object, int index) {
        }

        protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        }

        protected void storeState() {
        }

        protected void restoreState() {
        }

        protected void acceptState() {
        }
    }

    public static class AdaptiveDiagonalPreconditioning
    extends DiagonalPreconditioning {
        private AdaptableVector.AdaptableVariance variance;
        private final int minimumUpdates;
        private final GradientWrtParameterProvider gradient;
        private final MassPreconditioningOptions options;

        AdaptiveDiagonalPreconditioning(int dim, GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options) {
            this(dim, gradient, transform, options, false);
        }

        AdaptiveDiagonalPreconditioning(int dim, GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options, boolean guessInitialMass) {
            super(dim, transform);
            this.variance = new AdaptableVector.AdaptableVariance(dim);
            this.options = options;
            this.minimumUpdates = options.preconditioningDelay();
            this.gradient = gradient;
            if (guessInitialMass) {
                this.setInitialMass();
            } else {
                super.initializeMass();
            }
        }

        @Override
        protected void initializeMass() {
        }

        private void setInitialMass() {
            double[] values = this.gradient.getParameter().getParameterValues();
            double[] storedValues = (double[])values.clone();
            int i = 0;
            while (i < this.dim) {
                this.gradient.getParameter().setParameterValueQuietly(i, values[i] + MachineAccuracy.SQRT_SQRT_EPSILON);
                ++i;
            }
            this.gradient.getParameter().fireParameterChangedEvent();
            double[] gradientPlus = this.gradient.getGradientLogDensity();
            int i2 = 0;
            while (i2 < this.dim) {
                this.gradient.getParameter().setParameterValueQuietly(i2, values[i2] - MachineAccuracy.SQRT_SQRT_EPSILON);
                ++i2;
            }
            this.gradient.getParameter().fireParameterChangedEvent();
            double[] gradientMinus = this.gradient.getGradientLogDensity();
            int i3 = 0;
            while (i3 < this.dim) {
                this.gradient.getParameter().setParameterValueQuietly(i3, values[i3]);
                ++i3;
            }
            this.gradient.getParameter().fireParameterChangedEvent();
            i3 = 0;
            while (i3 < this.dim) {
                values[i3] = Math.abs((gradientPlus[i3] - gradientMinus[i3]) / (2.0 * MachineAccuracy.SQRT_SQRT_EPSILON));
                this.gradient.getParameter().setParameterValueQuietly(i3, storedValues[i3]);
                ++i3;
            }
            this.gradient.getParameter().fireParameterChangedEvent();
            this.fillZeros(values);
            this.setInverseMassFromArray(this.normalizeVector((ReadableVector)new WrappedVector.Raw(values), this.dim));
        }

        private void fillZeros(double[] positives) {
            double sum = 0.0;
            double min = Double.POSITIVE_INFINITY;
            int i = 0;
            while (i < positives.length) {
                sum += positives[i];
                if (min > positives[i] && positives[i] > 0.0) {
                    min = positives[i];
                }
                ++i;
            }
            if (sum == 0.0) {
                Arrays.fill(positives, 1.0);
            } else {
                i = 0;
                while (i < positives.length) {
                    if (positives[i] == 0.0) {
                        positives[i] = min;
                    }
                    ++i;
                }
            }
        }

        @Override
        protected void computeInverseMass() {
            if (this.variance.getUpdateCount() > this.minimumUpdates) {
                double[] newVariance = this.variance.getVariance();
                this.setInverseMassFromArray(DiagonalHessianPreconditioning.boundMassInverse(newVariance, this.options.preconditioningEigenLowerBound(), this.options.preconditioningEigenUpperBound(), this.dim, DiagonalHessianPreconditioning.VarianceConverter.VARIANCE));
            }
        }

        @Override
        public void storeSecant(ReadableVector gradient, ReadableVector position) {
            this.variance.update(position);
        }

        @Override
        public void updateVariance(WrappedVector position) {
            this.variance.update((ReadableVector)position);
        }

        @Override
        public WrappedVector getMass() {
            double[] mass = new double[this.dim];
            int i = 0;
            while (i < this.dim) {
                mass[i] = 1.0 / this.inverseMass.getParameterValue(i);
                ++i;
            }
            return new WrappedVector.Raw(mass);
        }
    }

    /**
     * Full-matrix preconditioner whose inverse mass is derived from a running
     * average of empirical position covariances.
     *
     * <p>Adaptation takes effect once more than {@code preconditioningDelay}
     * covariance updates have been recorded. Each step then:
     * <ol>
     *   <li>flattens the current covariance and adds it to the running average;</li>
     *   <li>rescales a copy of that average so its trace equals dim;</li>
     *   <li>passes the copy through PDTransformMatrix.Default, which clamps and
     *       rescales the eigenvalues;</li>
     *   <li>writes the result as the inverse mass.</li>
     * </ol>
     */
    public static class AdaptiveFullHessianPreconditioning
    extends FullHessianPreconditioning {
        private final AdaptableCovariance adaptableCovariance;
        private final GradientWrtParameterProvider gradientProvider;
        private final AdaptableVector averageCovariance;
        private final double[] inverseMassBuffer;
        private final int minimumUpdates;
        // Log-likelihood as a MultivariateFunction of the parameter values. It is not
        // used inside this class and is kept because it is part of the protected API.
        protected MultivariateFunction numeric1 = new MultivariateFunction(){

            public double evaluate(double[] argument) {
                int i = 0;
                while (i < argument.length) {
                    gradientProvider.getParameter().setParameterValue(i, argument[i]);
                    ++i;
                }
                return gradientProvider.getLikelihood().getLogLikelihood();
            }

            public int getNumArguments() {
                return gradientProvider.getParameter().getDimension();
            }

            public double getLowerBound(int n) {
                return 0.0;
            }

            public double getUpperBound(int n) {
                return Double.POSITIVE_INFINITY;
            }
        };

        /**
         * @param gradientProvider     provider backing {@link #numeric1}
         * @param adaptableCovariance  source of empirical position covariances
         * @param transform            passed through to the superclass
         * @param dim                  number of coordinates (matrix side length)
         * @param preconditioningDelay number of covariance updates to wait before adapting
         */
        AdaptiveFullHessianPreconditioning(GradientWrtParameterProvider gradientProvider, AdaptableCovariance adaptableCovariance, Transform transform, int dim, int preconditioningDelay) {
            // No Hessian provider: this subclass adapts from sampled covariances instead.
            super(null, transform, dim);
            this.adaptableCovariance = adaptableCovariance;
            this.gradientProvider = gradientProvider;
            this.averageCovariance = new AdaptableVector.Default(dim * dim);
            this.inverseMassBuffer = new double[dim * dim];
            this.minimumUpdates = preconditioningDelay;
        }

        @Override
        protected void computeInverseMass() {
            if (this.adaptableCovariance.getUpdateCount() > this.minimumUpdates) {
                WrappedMatrix.ArrayOfArray covariance = (WrappedMatrix.ArrayOfArray)this.adaptableCovariance.getCovariance();
                double[] flatCovariance = new double[this.dim * this.dim];
                for (int i = 0; i < this.dim; ++i) {
                    System.arraycopy(covariance.getArrays()[i], 0, flatCovariance, i * this.dim, this.dim);
                }
                this.averageCovariance.update(new WrappedVector.Raw(flatCovariance));
                // normalizeCovariance scales its argument in place. The vector returned by
                // getMean() may share storage with the running average (presumably it wraps
                // it; TODO confirm), so normalise a private copy. Otherwise every adaptation
                // step would silently rescale the average itself.
                this.cacheAverageCovariance(this.normalizeCovariance(this.copyOfAverageCovariance()));
                this.setInverseMassFromArray(this.inverseMassBuffer);
            }
        }

        /** Returns a detached copy of the current running-average covariance. */
        private WrappedVector copyOfAverageCovariance() {
            ReadableVector mean = this.averageCovariance.getMean();
            double[] copy = new double[mean.getDim()];
            for (int i = 0; i < copy.length; ++i) {
                copy[i] = mean.get(i);
            }
            return new WrappedVector.Raw(copy);
        }

        /**
         * Scales a flattened dim x dim matrix in place so its trace equals dim.
         *
         * @return the same (mutated) vector
         */
        private ReadableVector normalizeCovariance(WrappedVector flatCovariance) {
            double sum = 0.0;
            for (int i = 0; i < this.dim; ++i) {
                sum += flatCovariance.get(i * this.dim + i);
            }
            double multiplier = (double)this.dim / sum;
            for (int i = 0; i < this.dim * this.dim; ++i) {
                flatCovariance.set(i, flatCovariance.get(i) * multiplier);
            }
            return flatCovariance;
        }

        /**
         * Negates the flattened matrix, runs it through PDTransformMatrix.Default,
         * and stores the row-major result in {@link #inverseMassBuffer}.
         */
        private void cacheAverageCovariance(ReadableVector mean) {
            double[][] tempVariance = new double[this.dim][this.dim];
            for (int i = 0; i < this.dim; ++i) {
                for (int j = 0; j < this.dim; ++j) {
                    tempVariance[i][j] = -mean.get(i * this.dim + j);
                }
            }
            double[] transformedVariance = FullHessianPreconditioning.PDTransformMatrix.Default.transformMatrix(tempVariance, this.dim);
            System.arraycopy(transformedVariance, 0, this.inverseMassBuffer, 0, this.dim * this.dim);
        }

        /** Only the position is used; it is added to the covariance estimate. */
        @Override
        public void storeSecant(ReadableVector gradient, ReadableVector position) {
            this.adaptableCovariance.update(position);
        }
    }

    /**
     * Preconditioner that concatenates several sub-preconditioners. Each one covers
     * a consecutive slice of the coordinates, in list order.
     */
    public static class CompoundPreconditioning
    implements MassPreconditioner {
        final int dim;
        final List<MassPreconditioner> preconditionerList;
        // Kept for compatibility with package-level code. getVelocity no longer caches
        // through these: the cache was only invalidated by updateMass(), so it returned
        // stale velocities whenever the momentum changed between calls.
        boolean velocityKnown = false;
        double[] velocity;

        /**
         * @param preconditionerList sub-preconditioners, in coordinate order; the total
         *                           dimension is the sum of their dimensions
         */
        CompoundPreconditioning(List<MassPreconditioner> preconditionerList) {
            int thisDim = 0;
            for (MassPreconditioner preconditioner : preconditionerList) {
                thisDim += preconditioner.getDimension();
            }
            this.dim = thisDim;
            this.preconditionerList = preconditionerList;
            this.velocity = new double[this.dim];
        }

        /** Draws momentum from each sub-preconditioner and concatenates the pieces. */
        @Override
        public WrappedVector drawInitialMomentum() {
            WrappedVector.Raw initialMomentum = new WrappedVector.Raw(new double[this.dim]);
            int currentIndex = 0;
            for (MassPreconditioner preconditioner : this.preconditionerList) {
                WrappedVector currentMomentum = preconditioner.drawInitialMomentum();
                int subDim = preconditioner.getDimension();
                for (int i = 0; i < subDim; ++i) {
                    initialMomentum.set(currentIndex + i, currentMomentum.get(i));
                }
                currentIndex += subDim;
            }
            return initialMomentum;
        }

        /**
         * Returns velocity component {@code index} for {@code momentum}. The value is
         * computed on every call by the sub-preconditioner that owns that coordinate,
         * using that sub-preconditioner's slice of the momentum.
         *
         * @throws IndexOutOfBoundsException if {@code index} is outside [0, dim)
         */
        @Override
        public double getVelocity(int index, ReadableVector momentum) {
            if (index < 0 || index >= this.dim) {
                throw new IndexOutOfBoundsException("Index " + index + " out of range for dimension " + this.dim);
            }
            int offset = 0;
            for (MassPreconditioner preconditioner : this.preconditionerList) {
                int subDim = preconditioner.getDimension();
                if (index < offset + subDim) {
                    return preconditioner.getVelocity(index - offset, this.extractSlice(momentum, offset, subDim));
                }
                offset += subDim;
            }
            // Unreachable while dim equals the sum of the sub-dimensions.
            throw new IllegalStateException("No sub-preconditioner covers index " + index);
        }

        /** Copies entries [offset, offset + length) of {@code source} into a new vector. */
        private ReadableVector extractSlice(ReadableVector source, int offset, int length) {
            WrappedVector.Raw slice = new WrappedVector.Raw(new double[length]);
            for (int i = 0; i < length; ++i) {
                slice.set(i, source.get(offset + i));
            }
            return slice;
        }

        /** Splits a full-length vector into one slice per sub-preconditioner. */
        private List<ReadableVector> separateVectors(ReadableVector rawVector) {
            List<ReadableVector> vectors = new ArrayList<ReadableVector>(this.preconditionerList.size());
            int currentIndex = 0;
            for (MassPreconditioner preconditioner : this.preconditionerList) {
                int subDim = preconditioner.getDimension();
                vectors.add(this.extractSlice(rawVector, currentIndex, subDim));
                currentIndex += subDim;
            }
            return vectors;
        }

        /**
         * Concatenates slices back into a single vector of length dim. This is the
         * inverse of separateVectors. The write offset now advances after each slice;
         * previously every slice was written at offset 0.
         */
        private ReadableVector combineVectors(List<ReadableVector> vectors) {
            WrappedVector.Raw combinedVector = new WrappedVector.Raw(new double[this.dim]);
            int currentIndex = 0;
            for (ReadableVector readableVector : vectors) {
                int length = readableVector.getDim();
                for (int i = 0; i < length; ++i) {
                    combinedVector.set(currentIndex + i, readableVector.get(i));
                }
                currentIndex += length;
            }
            return combinedVector;
        }

        /** Forwards each sub-preconditioner its slice of the gradient and position. */
        @Override
        public void storeSecant(ReadableVector gradient, ReadableVector position) {
            List<ReadableVector> separatedGradient = this.separateVectors(gradient);
            List<ReadableVector> separatedPosition = this.separateVectors(position);
            for (int i = 0; i < this.preconditionerList.size(); ++i) {
                this.preconditionerList.get(i).storeSecant(separatedGradient.get(i), separatedPosition.get(i));
            }
        }

        /** Updates every sub-preconditioner's mass. */
        @Override
        public void updateMass() {
            for (MassPreconditioner preconditioner : this.preconditionerList) {
                preconditioner.updateMass();
            }
            this.velocityKnown = false;
        }

        @Override
        public ReadableVector doCollision(int[] indices, ReadableVector momentum) {
            throw new RuntimeException("Not yet implemented!");
        }

        @Override
        public WrappedVector getMass() {
            throw new RuntimeException("Not yet implemented!");
        }

        /**
         * No-op. NOTE(review): this is not forwarded to the sub-preconditioners;
         * confirm whether their updateVariance should receive the corresponding slices.
         */
        @Override
        public void updateVariance(WrappedVector position) {
        }

        @Override
        public int getDimension() {
            return this.dim;
        }
    }

    /**
     * Diagonal preconditioner driven by the diagonal of the log-density Hessian.
     *
     * <p>Each update proceeds as follows:
     * <ol>
     *   <li>map the diagonal through the transform (if any);</li>
     *   <li>add it to a running (optionally limited-memory) average;</li>
     *   <li>convert the average to variances via -1/h;</li>
     *   <li>clamp the variances between the bound parameters;</li>
     *   <li>write the result as the inverse mass.</li>
     * </ol>
     */
    public static class DiagonalHessianPreconditioning
    extends DiagonalPreconditioning {
        protected final HessianWrtParameterProvider hessian;
        private final Parameter lowerBound;
        private final Parameter upperBound;

        /**
         * @param hessian    Hessian provider; its dimension sets the preconditioner size
         * @param transform  optional transform applied to the Hessian diagonal (may be null)
         * @param memorySize when positive, average only the most recent this-many updates;
         *                   otherwise average all updates
         * @param lowerBound entry 0 is the lower clamp on each inverse-mass value
         * @param upperBound entry 0 is the upper clamp on each inverse-mass value
         */
        DiagonalHessianPreconditioning(HessianWrtParameterProvider hessian, Transform transform, int memorySize, Parameter lowerBound, Parameter upperBound) {
            super(hessian.getDimension(), transform);
            this.hessian = hessian;
            if (memorySize > 0) {
                this.adaptiveDiagonal = new AdaptableVector.LimitedMemory(hessian.getDimension(), memorySize);
            } else {
                this.adaptiveDiagonal = new AdaptableVector.Default(hessian.getDimension());
            }
            this.lowerBound = lowerBound;
            this.upperBound = upperBound;
        }

        @Override
        protected void computeInverseMass() {
            double[] diagonal = this.hessian.getDiagonalHessianLogDensity();
            if (this.transform != null) {
                double[] rawValues = this.hessian.getParameter().getParameterValues();
                double[] rawGradient = this.hessian.getGradientLogDensity();
                diagonal = this.transform.updateDiagonalHessianLogDensity(diagonal, rawGradient, rawValues, 0, this.dim);
            }
            this.adaptiveDiagonal.update(new WrappedVector.Raw(diagonal));
            double[] runningMean = ((WrappedVector)this.adaptiveDiagonal.getMean()).getBuffer();
            this.setInverseMassFromArray(boundMassInverse(runningMean, this.lowerBound, this.upperBound, this.dim, VarianceConverter.HESSIAN));
        }

        /**
         * Turns a diagonal into a bounded inverse mass. The steps are:
         * <ol>
         *   <li>L1-normalise a copy to total dim;</li>
         *   <li>convert each entry with {@code varianceConverter};</li>
         *   <li>clamp each entry into [lowerBound[0], upperBound[0]];</li>
         *   <li>L1-normalise to dim again.</li>
         * </ol>
         * The input array is not modified.
         */
        public static double[] boundMassInverse(double[] diagonalHessian, Parameter lowerBound, Parameter upperBound, int dim, VarianceConverter varianceConverter) {
            double[] bounded = diagonalHessian.clone();
            normalizeL1(bounded, dim);
            for (int k = 0; k < dim; k++) {
                double converted = varianceConverter.convertVariance(bounded[k]);
                if (converted < lowerBound.getParameterValue(0)) {
                    converted = lowerBound.getParameterValue(0);
                } else if (converted > upperBound.getParameterValue(0)) {
                    converted = upperBound.getParameterValue(0);
                }
                bounded[k] = converted;
            }
            normalizeL1(bounded, dim);
            return bounded;
        }

        /** Scales {@code vector} in place so the sum of absolute values equals {@code norm}. */
        private static void normalizeL1(double[] vector, double norm) {
            double l1 = 0.0;
            for (double entry : vector) {
                l1 += Math.abs(entry);
            }
            double scale = norm / l1;
            for (int k = 0; k < vector.length; k++) {
                vector[k] *= scale;
            }
        }

        @Override
        public void storeSecant(ReadableVector gradient, ReadableVector position) {
        }

        @Override
        public void updateVariance(WrappedVector position) {
        }

        @Override
        public WrappedVector getMass() {
            throw new RuntimeException("Not yet implemented!");
        }

        /** Maps an averaged diagonal entry to a variance-scale value. */
        static enum VarianceConverter {
            /** Hessian entry h becomes -1/h. */
            HESSIAN {
                @Override
                double convertVariance(double input) {
                    return -1.0 / input;
                }
            },
            /** The entry is already a variance and passes through unchanged. */
            VARIANCE {
                @Override
                double convertVariance(double input) {
                    return input;
                }
            };

            abstract double convertVariance(double var1);
        }
    }

    /**
     * Base for preconditioners whose inverse mass is diagonal. It is stored as a
     * length-dim Parameter named "InverseMass".
     */
    public static abstract class DiagonalPreconditioning
    extends AbstractMassPreconditioning {
        protected AdaptableVector adaptiveDiagonal;

        /**
         * Allocates the inverse-mass parameter, calls initializeMass() (which subclasses
         * may override), and registers the parameter as a model variable.
         */
        protected DiagonalPreconditioning(int dim, Transform transform) {
            super(dim, transform);
            this.adaptiveDiagonal = new AdaptableVector.Default(dim);
            this.inverseMass = new Parameter.Default("InverseMass", dim);
            this.initializeMass();
            this.addVariable((Variable)this.inverseMass);
        }

        /** Starts from a vector of ones rescaled to sum to dim. */
        @Override
        protected void initializeMass() {
            double[] ones = new double[this.dim];
            Arrays.fill(ones, 1.0);
            this.setInverseMassFromArray(this.normalizeVector(new WrappedVector.Raw(ones), this.dim));
        }

        /**
         * Returns a new array holding {@code values} rescaled so its entries sum to
         * {@code targetSum}. The input is not modified.
         */
        protected double[] normalizeVector(ReadableVector values, double targetSum) {
            int length = values.getDim();
            double total = 0.0;
            for (int k = 0; k < length; k++) {
                total += values.get(k);
            }
            double factor = targetSum / total;
            double[] scaled = new double[length];
            for (int k = 0; k < length; k++) {
                scaled[k] = values.get(k) * factor;
            }
            return scaled;
        }

        /** Draws each momentum component as a standard normal times sqrt(1 / inverseMass[k]). */
        @Override
        public WrappedVector drawInitialMomentum() {
            double[] momentum = new double[this.dim];
            for (int k = 0; k < this.dim; k++) {
                double draw = MathUtils.nextGaussian();
                momentum[k] = draw * Math.sqrt(1.0 / this.inverseMass.getParameterValue(k));
            }
            return new WrappedVector.Raw(momentum);
        }

        /** velocity[i] = momentum[i] * inverseMass[i]. */
        @Override
        public double getVelocity(int i, ReadableVector momentum) {
            return momentum.get(i) * this.inverseMass.getParameterValue(i);
        }

        /**
         * Applies a 1-D elastic collision between two coordinates, written in terms of
         * their inverse masses a and b. All other components are copied unchanged into
         * a new vector.
         *
         * @throws RuntimeException unless exactly two indices are given
         */
        @Override
        public ReadableVector doCollision(int[] indices, ReadableVector momentum) {
            if (indices.length != 2) {
                throw new RuntimeException("Not implemented for more than two dimensions yet.");
            }
            int length = momentum.getDim();
            WrappedVector.Raw result = new WrappedVector.Raw(new double[length]);
            for (int k = 0; k < length; k++) {
                result.set(k, momentum.get(k));
            }
            int first = indices[0];
            int second = indices[1];
            double a = this.inverseMass.getParameterValue(first);
            double b = this.inverseMass.getParameterValue(second);
            double p = momentum.get(first);
            double q = momentum.get(second);
            double total = a + b;
            result.set(first, ((b - a) * p + 2.0 * b * q) / total);
            result.set(second, ((a - b) * q + 2.0 * a * p) / total);
            return result;
        }
    }

    public static class FullHessianPreconditioning
    extends HessianBased {
        FullHessianPreconditioning(HessianWrtParameterProvider hessian, Transform transform) {
            super(hessian, transform);
        }

        FullHessianPreconditioning(HessianWrtParameterProvider hessian, Transform transform, int dim) {
            super(hessian, transform, dim);
            this.inverseMass = new Parameter.Default("InverseMass", dim * dim);
            this.addVariable((Variable)this.inverseMass);
        }

        @Override
        protected void initializeMass() {
            double[] result = new double[this.dim * this.dim];
            int i = 0;
            while (i < this.dim) {
                result[i * this.dim + i] = 1.0;
                ++i;
            }
        }

        private double[] computeInverseMass(WrappedMatrix.ArrayOfArray hessianMatrix, GradientWrtParameterProvider gradientProvider, PDTransformMatrix pdTransformMatrix) {
            double[][] transformedHessian = hessianMatrix.getArrays();
            if (this.transform != null) {
                transformedHessian = this.transform.updateHessianLogDensity(transformedHessian, new double[this.dim][this.dim], gradientProvider.getGradientLogDensity(), gradientProvider.getParameter().getParameterValues(), 0, this.dim);
            }
            return pdTransformMatrix.transformMatrix(transformedHessian, this.dim);
        }

        @Override
        protected void computeInverseMass() {
            WrappedMatrix.ArrayOfArray hessianMatrix = new WrappedMatrix.ArrayOfArray(this.hessian.getHessianLogDensity());
            this.setInverseMassFromArray(this.computeInverseMass(hessianMatrix, (GradientWrtParameterProvider)this.hessian, PDTransformMatrix.Invert));
        }

        @Override
        public void storeSecant(ReadableVector gradient, ReadableVector position) {
        }

        @Override
        public void updateVariance(WrappedVector position) {
        }

        @Override
        public WrappedVector getMass() {
            throw new RuntimeException("Not yet implemented!");
        }

        @Override
        public WrappedVector drawInitialMomentum() {
            MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(new double[this.dim], FullHessianPreconditioning.toArray(this.inverseMass.getParameterValues(), this.dim, this.dim));
            return new WrappedVector.Raw(mvn.nextMultivariateNormal());
        }

        @Override
        public double getVelocity(int i, ReadableVector momentum) {
            double velocity = 0.0;
            int j = 0;
            while (j < this.dim) {
                velocity += this.inverseMass.getParameterValue(i * this.dim + j) * momentum.get(j);
                ++j;
            }
            return velocity;
        }

        private static double[][] toArray(double[] vector, int rowDim, int colDim) {
            double[][] array = new double[rowDim][];
            int row = 0;
            while (row < rowDim) {
                array[row] = new double[colDim];
                System.arraycopy(vector, colDim * row, array[row], 0, colDim);
                ++row;
            }
            return array;
        }

        static enum PDTransformMatrix {
            Invert("Transform inverse matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D eigenvalues) {
                    this.inverseNegateEigenvalues(eigenvalues);
                }
            }
            ,
            Default("Transform matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D eigenvalues) {
                    this.negateEigenvalues(eigenvalues);
                }
            }
            ,
            Negate("Transform negative matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D eigenvalues) {
                    this.negateEigenvalues(eigenvalues);
                }

                @Override
                protected void normalizeEigenvalues(DoubleMatrix1D eigenvalues) {
                    this.negateEigenvalues(eigenvalues);
                    this.boundEigenvalues(eigenvalues);
                    this.scaleEigenvalues(eigenvalues);
                }
            }
            ,
            NegateInvert("Transform negative inverse matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D eigenvalues) {
                    this.inverseNegateEigenvalues(eigenvalues);
                }

                @Override
                protected void normalizeEigenvalues(DoubleMatrix1D eigenvalues) {
                    this.negateEigenvalues(eigenvalues);
                    this.boundEigenvalues(eigenvalues);
                    this.scaleEigenvalues(eigenvalues);
                }
            };

            String desc;
            private static final double MIN_EIGENVALUE = -20.0;
            private static final double MAX_EIGENVALUE = -0.5;

            private PDTransformMatrix(String s) {
                this.desc = s;
            }

            public String toString() {
                return this.desc;
            }

            protected void boundEigenvalues(DoubleMatrix1D eigenvalues) {
                int i = 0;
                while (i < eigenvalues.cardinality()) {
                    if (eigenvalues.get(i) > -0.5) {
                        eigenvalues.set(i, -0.5);
                    } else if (eigenvalues.get(i) < -20.0) {
                        eigenvalues.set(i, -20.0);
                    }
                    ++i;
                }
            }

            protected void scaleEigenvalues(DoubleMatrix1D eigenvalues) {
                double sum = 0.0;
                int i = 0;
                while (i < eigenvalues.cardinality()) {
                    sum += eigenvalues.get(i);
                    ++i;
                }
                double mean = -sum / (double)eigenvalues.cardinality();
                int i2 = 0;
                while (i2 < eigenvalues.cardinality()) {
                    eigenvalues.set(i2, eigenvalues.get(i2) / mean);
                    ++i2;
                }
            }

            protected void normalizeEigenvalues(DoubleMatrix1D eigenvalues) {
                this.boundEigenvalues(eigenvalues);
                this.scaleEigenvalues(eigenvalues);
            }

            protected void inverseNegateEigenvalues(DoubleMatrix1D eigenvalues) {
                int i = 0;
                while (i < eigenvalues.cardinality()) {
                    eigenvalues.set(i, -1.0 / eigenvalues.get(i));
                    ++i;
                }
            }

            protected void negateEigenvalues(DoubleMatrix1D eigenvalues) {
                int i = 0;
                while (i < eigenvalues.cardinality()) {
                    eigenvalues.set(i, -eigenvalues.get(i));
                    ++i;
                }
            }

            /**
             * Rebuilds {@code inputMatrix} from its eigendecomposition with modified eigenvalues
             * and returns the result flattened in row-major order.
             * <p>
             * The real eigenvalues are first passed through {@code normalizeEigenvalues} and then
             * through {@code transformEigenvalues}. The result is
             * {@code V * diag(eigenvalues) * inverse(V)}, where {@code V} is the eigenvector
             * matrix of the decomposition.
             *
             * @param inputMatrix square matrix to decompose
             * @param dim number of rows and columns copied into the result; must not exceed the
             *            size of {@code inputMatrix}, otherwise array access fails
             * @return row-major array of length {@code dim * dim}
             */
            public double[] transformMatrix(double[][] inputMatrix, int dim) {
                final Algebra linearAlgebra = new Algebra();
                final DenseDoubleMatrix2D source = new DenseDoubleMatrix2D(inputMatrix);
                final RobustEigenDecomposition eigen = new RobustEigenDecomposition((DoubleMatrix2D) source);

                final DoubleMatrix1D lambda = eigen.getRealEigenvalues();
                this.normalizeEigenvalues(lambda);
                final DoubleMatrix2D eigenvectors = eigen.getV();
                this.transformEigenvalues(lambda);

                // Same evaluation order as V * diag(lambda) first, then inverse(V).
                final DoubleMatrix2D scaled = linearAlgebra.mult(eigenvectors, DoubleFactory2D.dense.diagonal(lambda));
                final double[][] rebuilt = linearAlgebra.mult(scaled, linearAlgebra.inverse(eigenvectors)).toArray();

                final double[] flattened = new double[dim * dim];
                for (int row = 0; row < dim; ++row) {
                    System.arraycopy(rebuilt[row], 0, flattened, row * dim, dim);
                }
                return flattened;
            }

            /**
             * Hook that {@code transformMatrix} applies to the already-normalized eigenvalues
             * before reconstructing the matrix from them.
             * <p>
             * The method returns nothing, and {@code transformMatrix} reuses the same vector
             * afterwards. Implementations must therefore modify the vector in place.
             * NOTE(review): implementations presumably delegate to
             * {@code inverseNegateEigenvalues} or {@code negateEigenvalues}; confirm in the
             * subclasses.
             */
            protected abstract void transformEigenvalues(DoubleMatrix1D var1);
        }
    }

    /**
     * Base class for preconditioners built around a {@link HessianWrtParameterProvider}.
     * <p>
     * The mass is initialized eagerly by calling {@code initializeMass()} from the constructor.
     */
    public static abstract class HessianBased
    extends AbstractMassPreconditioning {
        protected final HessianWrtParameterProvider hessian;

        /**
         * Creates a preconditioner whose dimension is {@code hessian.getDimension()}.
         *
         * @param hessian the Hessian provider
         * @param transform the parameter transform passed to the superclass
         */
        HessianBased(HessianWrtParameterProvider hessian, Transform transform) {
            this(hessian, transform, hessian.getDimension());
        }

        /**
         * Creates a preconditioner with an explicit dimension.
         *
         * @param hessian the Hessian provider
         * @param transform the parameter transform passed to the superclass
         * @param dim the dimension passed to the superclass
         */
        HessianBased(HessianWrtParameterProvider hessian, Transform transform, int dim) {
            super(dim, transform);
            this.hessian = hessian;
            // NOTE(review): initializeMass() runs during construction. If a subclass overrides it,
            // the subclass's own fields are not yet assigned when it executes.
            this.initializeMass();
        }

        /**
         * Collisions are not supported by Hessian-based preconditioners.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public ReadableVector doCollision(int[] indices, ReadableVector momentum) {
            throw new UnsupportedOperationException("Not yet implemented.");
        }
    }

    /**
     * Identity preconditioner with unit mass in every dimension, so velocity equals momentum.
     */
    public static class NoPreconditioning
    implements MassPreconditioner {
        final int dim;

        NoPreconditioning(int dim) {
            this.dim = dim;
        }

        /**
         * Draws each momentum component independently from {@code MathUtils.nextGaussian()}.
         */
        @Override
        public WrappedVector drawInitialMomentum() {
            double[] momentum = new double[this.dim];
            for (int i = 0; i < this.dim; ++i) {
                momentum[i] = MathUtils.nextGaussian();
            }
            return new WrappedVector.Raw(momentum);
        }

        /**
         * With unit mass, the velocity of component {@code i} equals its momentum.
         */
        @Override
        public double getVelocity(int i, ReadableVector momentum) {
            return momentum.get(i);
        }

        /** No-op: this preconditioner keeps no curvature information. */
        @Override
        public void storeSecant(ReadableVector gradient, ReadableVector position) {
        }

        /** No-op: the mass is fixed at one. */
        @Override
        public void updateMass() {
        }

        /**
         * Returns a copy of {@code momentum} with the two components named by {@code indices}
         * swapped. The input vector is left unchanged.
         *
         * @throws UnsupportedOperationException if {@code indices} does not hold exactly two entries
         */
        @Override
        public ReadableVector doCollision(int[] indices, ReadableVector momentum) {
            if (indices.length != 2) {
                throw new UnsupportedOperationException(
                        "Collisions are only implemented for exactly two dimensions, got " + indices.length + ".");
            }
            WrappedVector.Raw updatedMomentum = new WrappedVector.Raw(new double[momentum.getDim()]);
            for (int i = 0; i < momentum.getDim(); ++i) {
                updatedMomentum.set(i, momentum.get(i));
            }
            // Read both values from the original vector so the swap is not affected by the first write.
            updatedMomentum.set(indices[0], momentum.get(indices[1]));
            updatedMomentum.set(indices[1], momentum.get(indices[0]));
            return updatedMomentum;
        }

        /**
         * Returns a fresh vector of ones with length {@code dim}.
         */
        @Override
        public WrappedVector getMass() {
            double[] mass = new double[this.dim];
            Arrays.fill(mass, 1.0);
            return new WrappedVector.Raw(mass);
        }

        /** No-op: no variance is tracked. */
        @Override
        public void updateVariance(WrappedVector position) {
        }

        @Override
        public int getDimension() {
            return this.dim;
        }
    }

    /**
     * Diagonal preconditioner whose inverse mass for coordinate {@code i} is the square of
     * {@code priorDistribution.getStandardDeviation(i)}.
     */
    public static class PriorPreconditioner
    extends DiagonalPreconditioning {
        PriorPreconditioningProvider priorDistribution;

        /**
         * Creates the preconditioner and immediately fills the inverse mass from the prior.
         *
         * @param priorDistribution source of the per-coordinate standard deviations
         * @param transform the parameter transform passed to the superclass
         */
        public PriorPreconditioner(PriorPreconditioningProvider priorDistribution, Transform transform) {
            super(priorDistribution.getDimension(), transform);
            this.priorDistribution = priorDistribution;
            this.computeInverseMass();
        }

        /**
         * Convenience factory that builds a new, independent {@code PriorPreconditioner}.
         */
        public MassPreconditioner factory(PriorPreconditioningProvider priorDistribution, Transform transform) {
            return new PriorPreconditioner(priorDistribution, transform);
        }

        /**
         * Sets each inverse-mass entry to the prior variance, i.e. the squared standard deviation.
         */
        @Override
        protected void computeInverseMass() {
            final int dimension = this.priorDistribution.getDimension();
            for (int i = 0; i < dimension; ++i) {
                final double stDev = this.priorDistribution.getStandardDeviation(i);
                this.inverseMass.setParameterValue(i, stDev * stDev);
            }
        }

        /** No-op: the mass comes from the prior, not from observed secants. */
        @Override
        public void storeSecant(ReadableVector gradient, ReadableVector position) {
        }

        /** No-op: the mass comes from the prior, not from sampled positions. */
        @Override
        public void updateVariance(WrappedVector position) {
        }

        /**
         * Not supported for this preconditioner.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public WrappedVector getMass() {
            throw new UnsupportedOperationException("Not yet implemented!");
        }
    }

    /**
     * Full-Hessian preconditioner driven by a {@link SecantHessian}.
     * <p>
     * Gradient/position pairs passed to {@link #storeSecant} are forwarded to that object.
     */
    public static class Secant
    extends FullHessianPreconditioning {
        private final SecantHessian secantHessian;

        Secant(SecantHessian secantProvider, Transform transform) {
            super((HessianWrtParameterProvider) secantProvider, transform);
            this.secantHessian = secantProvider;
        }

        /**
         * Hands the gradient/position pair to the underlying {@link SecantHessian}.
         */
        @Override
        public void storeSecant(ReadableVector gradient, ReadableVector position) {
            secantHessian.storeSecant(gradient, position);
        }
    }

    public static enum Type {
        NONE("none"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options) {
                Parameter parameter = gradient.getParameter();
                int dim = parameter.getDimension();
                if (transform != null && transform instanceof Transform.MultivariableTransform) {
                    dim = ((Transform.MultivariableTransform)transform).getDimension();
                }
                return new NoPreconditioning(dim);
            }
        }
        ,
        DIAGONAL("diagonal"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options) {
                return new DiagonalHessianPreconditioning((HessianWrtParameterProvider)gradient, transform, options.preconditioningMemory(), options.preconditioningEigenLowerBound(), options.preconditioningEigenUpperBound());
            }
        }
        ,
        ADAPTIVE_DIAGONAL("adaptiveDiagonal"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options) {
                int dimension = transform instanceof Transform.MultivariableTransform ? ((Transform.MultivariableTransform)transform).getDimension() : gradient.getDimension();
                return new AdaptiveDiagonalPreconditioning(dimension, gradient, transform, options);
            }
        }
        ,
        PRIOR_DIAGONAL("priorDiagonal"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options) {
                if (!(gradient instanceof PriorPreconditioningProvider)) {
                    throw new RuntimeException("Gradient must be a PriorPreconditioningProvider for prior preconditioning!");
                }
                return new PriorPreconditioner((PriorPreconditioningProvider)gradient, transform);
            }
        }
        ,
        FULL("full"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options) {
                return new FullHessianPreconditioning((HessianWrtParameterProvider)gradient, transform);
            }
        }
        ,
        SECANT("secant"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options) {
                SecantHessian secantHessian = new SecantHessian(gradient, options.preconditioningMemory());
                return new Secant(secantHessian, transform);
            }
        }
        ,
        ADAPTIVE("adaptive"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradient, Transform transform, MassPreconditioningOptions options) {
                AdaptableCovariance adaptableCovariance = new AdaptableCovariance(gradient.getDimension());
                return new AdaptiveFullHessianPreconditioning(gradient, adaptableCovariance, transform, gradient.getDimension(), options.preconditioningDelay());
            }
        };

        private final String name;

        private Type(String name) {
            this.name = name;
        }

        public abstract MassPreconditioner factory(GradientWrtParameterProvider var1, Transform var2, MassPreconditioningOptions var3);

        public String getName() {
            return this.name;
        }

        public static Type parseFromString(String text) {
            Type[] typeArray = Type.values();
            int n = typeArray.length;
            int n2 = 0;
            while (n2 < n) {
                Type type = typeArray[n2];
                if (type.name.toLowerCase().compareToIgnoreCase(text) == 0) {
                    return type;
                }
                ++n2;
            }
            return NONE;
        }
    }
}

