package tools.descartes.librede.bayesplusplus;

import com.sun.jna.Pointer;
import org.apache.log4j.Logger;
import tools.descartes.librede.algorithm.AbstractEstimationAlgorithm;
import tools.descartes.librede.bayesplusplus.backend.BayesPlusPlusLibrary;
import tools.descartes.librede.bayesplusplus.backend.FCallback;
import tools.descartes.librede.bayesplusplus.backend.HCallback;
import tools.descartes.librede.configuration.ResourceDemand;
import tools.descartes.librede.exceptions.EstimationException;
import tools.descartes.librede.exceptions.InitializationException;
import tools.descartes.librede.linalg.LinAlg;
import tools.descartes.librede.linalg.Matrix;
import tools.descartes.librede.linalg.MatrixFunction;
import tools.descartes.librede.linalg.Vector;
import tools.descartes.librede.linalg.VectorFunction;
import tools.descartes.librede.models.EstimationProblem;
import tools.descartes.librede.models.State;
import tools.descartes.librede.models.observation.OutputFunction;
import tools.descartes.librede.models.state.constraints.IStateBoundsConstraint;
import tools.descartes.librede.models.variables.OutputVariable;
import tools.descartes.librede.nativehelper.NativeHelper;
import tools.descartes.librede.registry.Component;
import tools.descartes.librede.registry.ParameterDefinition;
import tools.descartes.librede.repository.IRepositoryCursor;

/**
 * Estimation algorithm that runs an extended Kalman filter implemented by the native Bayes++
 * library, accessed through {@link BayesPlusPlusLibrary}.
 *
 * <p>The prediction step evaluates the state model of the estimation problem ({@link FFunction}).
 * The observation step evaluates its output functions ({@link HFunction}). After every update the
 * native state is clamped into damped state bounds. The current estimate is then pushed into a
 * window of recent estimates, and {@link #estimate()} returns {@code LinAlg.nanmean} of that
 * window.
 */
@Component(displayName = "Extended Kalman Filter")
public class ExtendedKalmanFilter extends AbstractEstimationAlgorithm {
    private static final Logger log = Logger.getLogger(ExtendedKalmanFilter.class);

    private int stateSize;
    private int outputSize;
    private Vector stateNoiseCovariance;
    private Matrix stateNoiseCoupling;
    private Vector observeNoise;
    /** Anchor of the damped state bounds used by {@link #truncateState(Vector)}. */
    private Vector lastEstimate;
    /** Most recent estimates, newest in row 0; rows not yet written are NaN. */
    private Matrix estimates;
    private Vector lowerStateBounds;
    private Vector upperStateBounds;
    // JNA requires callback objects to stay referenced for as long as native code may invoke them;
    // these fields hold those references.
    private FFunction fcallback;
    private HFunction hcallback;
    /** Reusable native buffer sized for one state vector. */
    private Pointer stateBuffer;
    /** Reusable native buffer sized for one stateSize x stateSize matrix. */
    private Pointer stateCovarianceBuffer;

    @ParameterDefinition(name = "StateNoiseCovariance", label = "State Noise Covariance", defaultValue = "1.0")
    private double stateNoiseCovarianceConstant = 1.0d;

    @ParameterDefinition(name = "StateNoiseCoupling", label = "State Noise Coupling", defaultValue = "1.0")
    private double stateNoiseCouplingConstant = 1.0d;

    @ParameterDefinition(name = "ObserveNoiseCovariance", label = "Observe Noise Covariance", defaultValue = "0.0001")
    private double observeNoiseCovarianceConstant = 1.0E-4d;

    @ParameterDefinition(name = "BoundsFactor", label = "Bounds factor", defaultValue = "0.9")
    private double boundsFactor = 0.9d;

    @ParameterDefinition(name = "InitialBoundsDistance", label = "Initial bounds distance", defaultValue = "1e-4")
    private double initialBoundsDistance = 1.0E-4d;

    /** True once all native objects have been created and the filter has been seeded. */
    private boolean initialized = false;
    private Pointer nativeObservationModel = null;
    private Pointer nativeStateModel = null;
    private Pointer nativeScheme = null;

    /**
     * Native callback for the state transition f(x).
     *
     * <p>It steps the state model once from the state passed in by native code, returns the
     * resulting state vector, and, as a side effect, pushes the step's Jacobian to the native
     * predict model via {@code set_Fx}.
     */
    private class FFunction implements FCallback {
        private final Pointer stateBuffer;
        private final Pointer jacobiBuffer;

        private FFunction() {
            this.stateBuffer = NativeHelper.allocateDoubleArray(stateSize);
            this.jacobiBuffer = NativeHelper.allocateDoubleArray(stateSize * stateSize);
        }

        @Override
        public Pointer execute(Pointer pointer) {
            // Do not let exceptions escape into native code (same policy as HFunction); on failure
            // the buffer keeps its previous contents.
            try {
                State current = new State(getStateModel(), NativeHelper.nativeVector(stateSize, pointer), 1);
                State next = getStateModel().step(current);
                NativeHelper.toNative(this.jacobiBuffer, next.getStateJacobiMatrix());
                BayesPlusPlusLibrary.set_Fx(nativeStateModel, this.jacobiBuffer, stateSize);
                NativeHelper.toNative(this.stateBuffer, next.getVector());
            } catch (Throwable th) {
                log.error("Error evaluating F(x).", th);
            }
            return this.stateBuffer;
        }
    }

    /**
     * Native callback for the observation function h(x).
     *
     * <p>It evaluates every output function that has data at the state passed in by native code.
     * It returns the output values and, as a side effect, pushes the first-order partial
     * derivatives (the Jacobian) to the native observation model via {@code set_Hx}. Outputs
     * without data keep value 0 and a zero Jacobian row.
     */
    private class HFunction implements HCallback {
        private final Pointer outputBuffer;
        private final Pointer jacobiBuffer;

        private HFunction() {
            this.outputBuffer = NativeHelper.allocateDoubleArray(outputSize);
            this.jacobiBuffer = NativeHelper.allocateDoubleArray(outputSize * stateSize);
        }

        @Override
        public Pointer execute(Pointer pointer) {
            double[] outputs = new double[outputSize];
            double[][] jacobian = new double[outputSize][stateSize];
            // Exceptions must not escape into native code; on failure the (partially filled)
            // outputs are still returned.
            try {
                State state = new State(getStateModel(), NativeHelper.nativeVector(stateSize, pointer), 1);
                // Derivative order per state variable; exactly one entry is set to 1 at a time.
                int[] orders = new int[stateSize];
                for (int row = 0; row < outputSize; row++) {
                    OutputFunction function = getObservationModel().getOutputFunction(row);
                    if (!function.hasData()) {
                        continue;
                    }
                    OutputVariable output = function.getCalculatedOutput(state);
                    outputs[row] = output.getDerivativeStructure().getValue();
                    for (int col = 0; col < stateSize; col++) {
                        orders[col] = 1;
                        jacobian[row][col] = output.getDerivativeStructure().getPartialDerivative(orders);
                        orders[col] = 0;
                    }
                }
                NativeHelper.toNative(this.jacobiBuffer, LinAlg.matrix(jacobian));
                BayesPlusPlusLibrary.set_Hx(nativeObservationModel, this.jacobiBuffer, stateSize, outputSize);
            } catch (Throwable th) {
                log.error("Error evaluating H(x).", th);
            }
            NativeHelper.toNative(this.outputBuffer, LinAlg.vector(outputs));
            return this.outputBuffer;
        }
    }

    /**
     * Creates the native covariance scheme and seeds it with the given state.
     *
     * @param vector initial state estimate; its squared entries form the diagonal initial covariance
     * @throws EstimationException if the native scheme cannot be created or initialized
     */
    private void initNativeKalmanFilter(Vector vector) throws EstimationException {
        this.nativeScheme = BayesPlusPlusLibrary.create_covariance_scheme(this.stateSize);
        if (this.nativeScheme == null) {
            throw new EstimationException("Could not create kalman filter: " + BayesPlusPlusLibrary.get_last_error());
        }
        NativeHelper.toNative(this.stateBuffer, vector);
        NativeHelper.toNative(this.stateCovarianceBuffer, getInitialStateCovariance(vector));
        if (BayesPlusPlusLibrary.init_kalman(this.nativeScheme, this.stateBuffer, this.stateCovarianceBuffer, this.stateSize) == -1) {
            throw new EstimationException("Could not initialize kalman filter: " + BayesPlusPlusLibrary.get_last_error());
        }
    }

    /**
     * Creates the native linearized predict model backed by {@link FFunction}, then sets its
     * constant noise covariance (q) and a diagonal noise coupling matrix (G).
     *
     * @throws EstimationException if the native model cannot be created
     */
    private void initNativeStateModel() throws EstimationException {
        this.fcallback = new FFunction();
        this.nativeStateModel = BayesPlusPlusLibrary.create_linrz_predict_model(this.stateSize, this.stateSize, this.fcallback);
        if (this.nativeStateModel == null) {
            throw new EstimationException("Error creating state model: " + BayesPlusPlusLibrary.get_last_error());
        }
        this.stateNoiseCovariance = LinAlg.vector(this.stateSize, new VectorFunction() {
            @Override
            public double cell(int i) {
                return stateNoiseCovarianceConstant;
            }
        });
        NativeHelper.toNative(this.stateBuffer, this.stateNoiseCovariance);
        BayesPlusPlusLibrary.set_q(this.nativeStateModel, this.stateBuffer, this.stateSize);

        Pointer couplingBuffer = NativeHelper.allocateDoubleArray(this.stateSize * this.stateSize);
        this.stateNoiseCoupling = LinAlg.matrix(this.stateSize, this.stateSize, new MatrixFunction() {
            @Override
            public double cell(int row, int col) {
                return row == col ? stateNoiseCouplingConstant : 0.0d;
            }
        });
        NativeHelper.toNative(couplingBuffer, this.stateNoiseCoupling);
        BayesPlusPlusLibrary.set_G(this.nativeStateModel, couplingBuffer, this.stateSize);
    }

    /**
     * Creates the native linearized uncorrelated observation model backed by {@link HFunction}
     * and sets its constant observation noise (Zv).
     *
     * @throws EstimationException if the native model cannot be created
     */
    private void initNativeObservationModel() throws EstimationException {
        this.hcallback = new HFunction();
        this.nativeObservationModel = BayesPlusPlusLibrary.create_linrz_uncorrelated_observe_model(this.stateSize, this.outputSize, this.hcallback);
        if (this.nativeObservationModel == null) {
            throw new EstimationException("Error creating observation model: " + BayesPlusPlusLibrary.get_last_error());
        }
        Pointer noiseBuffer = NativeHelper.allocateDoubleArray(this.outputSize);
        this.observeNoise = LinAlg.vector(this.outputSize, new VectorFunction() {
            @Override
            public double cell(int i) {
                return observeNoiseCovarianceConstant;
            }
        });
        NativeHelper.toNative(noiseBuffer, this.observeNoise);
        BayesPlusPlusLibrary.set_Zv(this.nativeObservationModel, noiseBuffer, this.outputSize);
    }

    /** Runs the native prediction step; fails with the native error message on -1. */
    private void predict() throws EstimationException {
        if (BayesPlusPlusLibrary.predict(this.nativeScheme, this.nativeStateModel) == -1) {
            throw new EstimationException("Error in prediction phase: " + BayesPlusPlusLibrary.get_last_error());
        }
    }

    /**
     * Runs the native observation step with the given observed output.
     *
     * @param vector observed output values
     * @throws EstimationException if the native call reports an error
     */
    private void observe(Vector vector) throws EstimationException {
        Pointer observationBuffer = NativeHelper.allocateDoubleArray(vector.rows());
        NativeHelper.toNative(observationBuffer, vector);
        if (BayesPlusPlusLibrary.observe(this.nativeScheme, this.nativeObservationModel, observationBuffer, vector.rows()) == -1) {
            throw new EstimationException("Error in observation phase: " + BayesPlusPlusLibrary.get_last_error());
        }
    }

    /**
     * Runs the native update step, then clamps the resulting state into the damped bounds and
     * writes the clamped state back into the native scheme.
     */
    private void updateState() throws EstimationException {
        if (BayesPlusPlusLibrary.update(this.nativeScheme) == -1) {
            throw new EstimationException("Error in update phase: " + BayesPlusPlusLibrary.get_last_error());
        }
        BayesPlusPlusLibrary.get_x(this.nativeScheme, this.stateBuffer);
        NativeHelper.toNative(this.stateBuffer, truncateState(NativeHelper.nativeVector(this.stateSize, this.stateBuffer)));
        BayesPlusPlusLibrary.set_x(this.nativeScheme, this.stateBuffer, this.stateSize);
    }

    /**
     * Clamps each entry of {@code vector} into damped bounds.
     *
     * <p>Each bound is computed as {@code boundsFactor * bound + (1 - boundsFactor) * lastEstimate}.
     *
     * <p>NOTE(review): {@code lastEstimate} is only assigned once, during initialization, so these
     * bounds are anchored to the initial estimate rather than the previous one. Confirm this is
     * intended.
     */
    private Vector truncateState(final Vector vector) {
        final Vector lower = this.lowerStateBounds.times(this.boundsFactor).plus(this.lastEstimate.times(1.0d - this.boundsFactor));
        final Vector upper = this.upperStateBounds.times(this.boundsFactor).plus(this.lastEstimate.times(1.0d - this.boundsFactor));
        return LinAlg.vector(this.stateSize, new VectorFunction() {
            @Override
            public double cell(int i) {
                return Math.min(upper.get(i), Math.max(lower.get(i), vector.get(i)));
            }
        });
    }

    /**
     * Reads the current native state into {@link #stateBuffer} and returns it as a vector.
     * NOTE(review): if {@code nativeVector} wraps rather than copies, the result aliases the
     * shared buffer. Confirm before holding on to it.
     */
    private Vector getCurrentEstimate() {
        BayesPlusPlusLibrary.get_x(this.nativeScheme, this.stateBuffer);
        return NativeHelper.nativeVector(this.stateSize, this.stateBuffer);
    }

    /** Builds a diagonal covariance whose diagonal holds the squared entries of {@code vector}. */
    private Matrix getInitialStateCovariance(final Vector vector) {
        return LinAlg.matrix(this.stateSize, this.stateSize, new MatrixFunction() {
            @Override
            public double cell(int row, int col) {
                if (row != col) {
                    return 0.0d;
                }
                double value = vector.get(row);
                return value * value;
            }
        });
    }

    @Override
    protected void finalize() throws Throwable {
        try {
            destroy();
        } finally {
            super.finalize();
        }
    }

    /**
     * Sizes the native buffers and the estimate window.
     *
     * @param estimationWindow number of recent estimates kept (rows of the window matrix)
     */
    @Override
    public void initialize(EstimationProblem estimationProblem, IRepositoryCursor iRepositoryCursor, int estimationWindow) throws InitializationException {
        super.initialize(estimationProblem, iRepositoryCursor, estimationWindow);
        this.stateSize = estimationProblem.getStateModel().getStateSize();
        this.outputSize = estimationProblem.getObservationModel().getOutputSize();
        this.stateBuffer = NativeHelper.allocateDoubleArray(this.stateSize);
        this.stateCovarianceBuffer = NativeHelper.allocateDoubleArray(this.stateSize * this.stateSize);
        this.estimates = LinAlg.matrix(estimationWindow, this.stateSize, Double.NaN);
    }

    /**
     * Performs one predict / observe / update cycle and records the resulting estimate.
     *
     * <p>On the first call with a non-empty initial state, the native models and scheme are
     * created. If that fails, any native objects already created are released again. While no
     * initial state is available the call is a no-op, so null native handles are never passed to
     * the native library.
     */
    public void update() throws EstimationException {
        updateStateBounds();
        if (!this.initialized) {
            getStateModel().step((State) null);
            Vector initialEstimate = truncateInitialState(getStateModel().getInitialState());
            if (initialEstimate.isEmpty()) {
                return;
            }
            boolean created = false;
            try {
                initNativeStateModel();
                initNativeObservationModel();
                initNativeKalmanFilter(initialEstimate);
                created = true;
            } finally {
                if (!created) {
                    // Release partially created native objects so a retry does not leak them.
                    destroy();
                }
            }
            this.lastEstimate = initialEstimate;
            this.initialized = true;
        }
        predict();
        observe(getObservationModel().getObservedOutput());
        updateState();
        this.estimates = this.estimates.circshift(1).setRow(0, getCurrentEstimate());
    }

    /**
     * Moves each entry of {@code vector} to at least {@code initialBoundsDistance} inside the
     * current state bounds; entries already inside (and NaN entries) are kept unchanged.
     */
    private Vector truncateInitialState(final Vector vector) {
        return LinAlg.vector(vector.rows(), new VectorFunction() {
            @Override
            public double cell(int i) {
                double value = vector.get(i);
                double upperLimit = upperStateBounds.get(i) - initialBoundsDistance;
                double lowerLimit = lowerStateBounds.get(i) + initialBoundsDistance;
                if (upperLimit < value) {
                    return upperLimit;
                }
                if (lowerLimit > value) {
                    return lowerLimit;
                }
                return value;
            }
        });
    }

    /**
     * Recomputes the state bounds: by default [0, +inf) for every state variable, tightened by
     * every {@link IStateBoundsConstraint} of the state model. Other constraint kinds are ignored.
     */
    private void updateStateBounds() {
        this.upperStateBounds = LinAlg.matrix(this.stateSize, 1, Double.POSITIVE_INFINITY);
        this.lowerStateBounds = LinAlg.matrix(this.stateSize, 1, 0.0d);
        // Iterate generically: the model may hold constraint kinds other than bounds constraints.
        for (Object constraint : getStateModel().getConstraints()) {
            if (!(constraint instanceof IStateBoundsConstraint)) {
                continue;
            }
            IStateBoundsConstraint bounds = (IStateBoundsConstraint) constraint;
            ResourceDemand variable = bounds.getStateVariable();
            int index = getStateModel().getStateVariableIndex(variable.getResource(), variable.getService());
            this.upperStateBounds = this.upperStateBounds.set(index, Math.min(this.upperStateBounds.get(index), bounds.getUpperBound()));
            this.lowerStateBounds = this.lowerStateBounds.set(index, Math.max(this.lowerStateBounds.get(index), bounds.getLowerBound()));
        }
    }

    /** Returns {@code LinAlg.nanmean} over the window of recent estimates. */
    public Vector estimate() throws EstimationException {
        return LinAlg.nanmean(this.estimates);
    }

    /**
     * Releases every native object that has been created (each one independently, so partial
     * initialization is cleaned up too) and marks the filter as uninitialized. Safe to call
     * repeatedly.
     */
    public void destroy() {
        if (this.nativeScheme != null) {
            BayesPlusPlusLibrary.dispose_covariance_scheme(this.nativeScheme);
            this.nativeScheme = null;
        }
        if (this.nativeObservationModel != null) {
            BayesPlusPlusLibrary.dispose_linrz_uncorrelated_observe_model(this.nativeObservationModel);
            this.nativeObservationModel = null;
        }
        if (this.nativeStateModel != null) {
            BayesPlusPlusLibrary.dispose_linrz_predict_model(this.nativeStateModel);
            this.nativeStateModel = null;
        }
        this.initialized = false;
    }
}
