package org.palladiosimulator.simexp.markovian.evaluation;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.palladiosimulator.simexp.distribution.function.ProbabilityMassFunction;
import org.palladiosimulator.simexp.markovian.access.MarkovModelAccessor;
import org.palladiosimulator.simexp.markovian.model.factory.MarkovModelFactory;
import org.palladiosimulator.simexp.markovian.model.markovmodel.markoventity.MarkovModel;
import org.palladiosimulator.simexp.markovian.model.markovmodel.markoventity.State;
import org.palladiosimulator.simexp.markovian.model.markovmodel.markoventity.Transition;
import org.palladiosimulator.simexp.markovian.model.markovmodel.samplemodel.Sample;
import org.palladiosimulator.simexp.markovian.model.markovmodel.samplemodel.SampleModel;
import org.palladiosimulator.simexp.markovian.model.markovmodel.samplemodel.Trajectory;
import org.palladiosimulator.simexp.markovian.type.MarkovianResult;

/**
 * Evaluates sampled trajectories against a {@link MarkovModel}.
 *
 * <p>For each trajectory the evaluator computes:
 * <ul>
 *   <li>the probability of the observed sample path, i.e. the probability of the first sample's
 *       current state under the initial state distribution multiplied by the probability of every
 *       transition (current state to next state) of the path as found in the Markov model;</li>
 *   <li>the sum of the rewards of all samples of the path.</li>
 * </ul>
 *
 * @param <A> the action type of the Markov model
 */
public class SampleModelEvaluator<A> {
    private final TransitionCache<A> transitionCache = new TransitionCache<>();
    private final MarkovModelAccessor<A, Double> modelAccessor;
    private final ProbabilityMassFunction<State> initialStateDist;

    /**
     * Bounded memo of the transitions already resolved for a sample.
     *
     * <p>Entries are keyed by the sample object itself, so lookups use the sample's
     * {@code equals}/{@code hashCode}. Keying by the bare hash code would let two distinct samples
     * with colliding hash codes be served each other's transition. Once {@code CACHE_SIZE} (100)
     * entries are stored, further insertions are ignored; existing entries are never evicted.
     *
     * @param <A> the action type of the Markov model
     */
    public static class TransitionCache<A> {
        private static final int CACHE_SIZE = 100;
        private final Map<Sample<A, Double>, Transition<A>> cache = new HashMap<>();

        /**
         * Looks up the transition previously stored for {@code sample}.
         *
         * @param sample the sample to look up
         * @return the cached transition, or an empty {@link Optional} if none is stored
         */
        public Optional<Transition<A>> findTransition(Sample<A, Double> sample) {
            return Optional.ofNullable(cache.get(sample));
        }

        /**
         * Stores {@code transition} for {@code sample}, unless the cache is already full.
         *
         * @param sample the sample used as key
         * @param transition the transition resolved for the sample
         */
        public void put(Sample<A, Double> sample, Transition<A> transition) {
            if (isCacheFull()) {
                return;
            }
            cache.put(sample, transition);
        }

        private boolean isCacheFull() {
            // '>=' rather than '==' so the guard stays correct even if the size ever overshoots.
            return cache.size() >= CACHE_SIZE;
        }
    }

    private SampleModelEvaluator(MarkovModel<A, Double> markovModel, ProbabilityMassFunction<State> probabilityMassFunction) {
        this.modelAccessor = MarkovModelAccessor.of(markovModel);
        this.initialStateDist = probabilityMassFunction;
    }

    /**
     * Creates an evaluator for the given Markov model and initial state distribution.
     *
     * @param markovModel the model whose transitions provide the step probabilities
     * @param probabilityMassFunction the distribution used for the probability of a path's first state
     * @param <A> the action type of the Markov model
     * @return a new evaluator
     */
    public static <A> SampleModelEvaluator<A> of(MarkovModel<A, Double> markovModel, ProbabilityMassFunction<State> probabilityMassFunction) {
        return new SampleModelEvaluator<>(markovModel, probabilityMassFunction);
    }

    /**
     * Evaluates every trajectory of {@code sampleModel}.
     *
     * @param sampleModel the sample model whose trajectories are evaluated
     * @return one result per trajectory, in the order returned by {@code getTrajectories()}
     */
    public List<MarkovianResult<A, Double>> evaluate(SampleModel<A, Double> sampleModel) {
        return sampleModel.getTrajectories().stream()
                .map(this::evaluate)
                .collect(Collectors.toList());
    }

    /**
     * Evaluates a single trajectory.
     *
     * @param trajectory the trajectory to evaluate; must contain at least one sample
     * @return the path probability and the accumulated reward of the trajectory
     * @throws IllegalArgumentException if the trajectory has no samples
     * @throws IllegalStateException if a transition of the path does not exist in the Markov model
     */
    public MarkovianResult<A, Double> evaluate(Trajectory<A, Double> trajectory) {
        double accumulatedReward = 0.0d;
        double pathProbability = computeInitial(trajectory);
        for (Sample<A, Double> sample : trajectory.getSamplePath()) {
            accumulatedReward += sample.getReward().getValue();
            pathProbability *= getProbability(sample);
        }
        return MarkovianResult.of(trajectory)
                .withProbability(pathProbability)
                .andReward(new MarkovModelFactory().createRewardSignal(Double.valueOf(accumulatedReward)))
                .build();
    }

    /**
     * Returns the probability of the trajectory's first state under the initial state distribution.
     *
     * @throws IllegalArgumentException if the trajectory has no samples
     */
    private double computeInitial(Trajectory<A, Double> trajectory) {
        List<Sample<A, Double>> samplePath = trajectory.getSamplePath();
        if (samplePath.isEmpty()) {
            throw new IllegalArgumentException("Cannot evaluate an empty trajectory: it has no initial state.");
        }
        return initialStateDist.probability(ProbabilityMassFunction.Sample.of(samplePath.get(0).getCurrent()));
    }

    /** Returns the probability of the sample's transition, consulting the cache first. */
    private double getProbability(Sample<A, Double> sample) {
        // orElseGet (not orElse): the model query must run only on a cache miss.
        return transitionCache.findTransition(sample)
                .orElseGet(() -> queryMarkovModelAndCacheResult(sample))
                .getProbability();
    }

    /**
     * Looks up the transition from the sample's current state to its next state in the Markov
     * model and stores it in the cache.
     *
     * @throws IllegalStateException if the Markov model has no such transition
     */
    private Transition<A> queryMarkovModelAndCacheResult(Sample<A, Double> sample) {
        Transition<A> transition = modelAccessor.findTransition(sample.getCurrent(), sample.getNext())
                .orElseThrow(() -> new IllegalStateException(String.format(
                        "No transition from state %s to state %s found in the Markov model.",
                        sample.getCurrent(), sample.getNext())));
        transitionCache.put(sample, transition);
        return transition;
    }
}
