package org.palladiosimulator.dependability.ml.sensitivity.analysis;

import com.google.common.collect.Lists;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.palladiosimulator.dependability.ml.iterator.TrainingDataIterator;
import org.palladiosimulator.dependability.ml.model.InputData;
import org.palladiosimulator.dependability.ml.model.MLPredictionResult;
import org.palladiosimulator.dependability.ml.model.TrainedModel;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.SensitivityAggregations;
import org.palladiosimulator.dependability.ml.sensitivity.exception.MLSensitivityAnalysisException;
import org.palladiosimulator.dependability.ml.sensitivity.transformation.PropertyMeasure;
import org.palladiosimulator.dependability.ml.sensitivity.transformation.SensitivityProperty;
import org.palladiosimulator.dependability.ml.util.Tuple;

/* loaded from: input_file:org/palladiosimulator/dependability/ml/sensitivity/analysis/TrainingDataBasedAnalysisStrategy.class */
/**
 * Sensitivity analysis strategy that derives sensitivity values by replaying the training data of
 * the analysed ML model.
 *
 * <p>For every training sample, the measurable properties of the input are computed through
 * {@link MLSensitivityAnalysis#getAnalysisTransformation()}, the model makes a prediction for the
 * sample, and a strategy-specific score of that prediction is recorded in a
 * {@link SensitivityAggregations}. The aggregated values are afterwards written into the
 * sensitivity model of the analysis context. Property values and ML sensitivity entries that never
 * occurred in the training data are complemented with default values.
 *
 * <p>Instances are obtained via {@link #accuracyBasedStrategy()} or
 * {@link #confidenceBasedStrategy()}.
 */
public final class TrainingDataBasedAnalysisStrategy implements MLSensitivityAnalysisStrategy {

    private static final String CONFIDENCE_STRATEGY_NAME = "Confidence based training data analysis strategy";
    private static final String ACCURACY_STRATEGY_NAME = "Accuracy based training data analysis strategy";

    /** Value inserted for a measurable property that never occurred in the training data. */
    private static final double ZERO_PROBABILITY = 0.0d;

    /** Value inserted for an ML sensitivity entry that never occurred in the training data. */
    private static final double MAX_ENTROPY_PROBABILITY = 0.5d;

    private final String strategyName;

    /** Maps the prediction result of a single training sample to the score recorded for it. */
    private final Function<MLPredictionResult, Double> incrementalUpdate;

    private TrainingDataBasedAnalysisStrategy(Function<MLPredictionResult, Double> function, String str) {
        this.incrementalUpdate = function;
        this.strategyName = str;
    }

    /**
     * Creates a strategy that scores each training sample with {@code 1.0} if the model's
     * prediction is the expected result and {@code 0.0} otherwise.
     *
     * @return a new accuracy based strategy
     */
    public static TrainingDataBasedAnalysisStrategy accuracyBasedStrategy() {
        return new TrainingDataBasedAnalysisStrategy(
                result -> result.isExpectedResult() ? 1.0d : 0.0d, ACCURACY_STRATEGY_NAME);
    }

    /**
     * Creates a strategy that scores each training sample with the mean confidence of all
     * predictions the model made for it.
     *
     * @return a new confidence based strategy
     */
    public static TrainingDataBasedAnalysisStrategy confidenceBasedStrategy() {
        return new TrainingDataBasedAnalysisStrategy(
                TrainingDataBasedAnalysisStrategy::meanPredictionConfidence, CONFIDENCE_STRATEGY_NAME);
    }

    /**
     * Computes the arithmetic mean of the prediction confidences contained in {@code result}.
     *
     * @param result the prediction result of a single training sample
     * @return the mean confidence, or {@code 0.0} if the result contains no predictions (instead
     *         of failing with a {@code NoSuchElementException} as a bare {@code reduce().get()} would)
     */
    private static double meanPredictionConfidence(MLPredictionResult result) {
        return result.getPredictions().stream()
                .mapToDouble(prediction -> prediction.getPredictionConfidence())
                .average()
                .orElse(0.0d);
    }

    @Override
    public String getName() {
        return this.strategyName;
    }

    /**
     * Aggregates sensitivity values over the training data of the context's model and writes them
     * into the context's sensitivity model.
     *
     * @param mLAnalysisContext the analysis context providing model, training data and sensitivity model
     * @return the sensitivity model of the context, updated in place
     */
    @Override
    public SensitivityModel analyseSensitivity(MLAnalysisContext mLAnalysisContext) {
        SensitivityAggregations aggregations = computeAggregatedSensitivityValues(mLAnalysisContext);
        return complementSensitivityModel(mLAnalysisContext.getSensitivityModel(), aggregations);
    }

    /**
     * Iterates over all training samples, lets the model predict each one and records the
     * strategy score together with the sample's measurable properties.
     */
    private SensitivityAggregations computeAggregatedSensitivityValues(MLAnalysisContext context) {
        TrainedModel model = context.getMLModel();
        TrainingDataIterator samples = model.getTrainingDataIteratorBy(context.getTrainingData());
        SensitivityAggregations aggregations = new SensitivityAggregations();
        while (samples.hasNext()) {
            // NOTE(review): Tuple is used raw here because its type parameters are not visible
            // from this file; the first component is expected to be the InputData of the sample.
            Tuple sample = samples.next();
            InputData input = (InputData) sample.getFirst();
            // Properties are computed before the prediction is made, as in the original ordering.
            aggregations.record(
                    MLSensitivityAnalysis.getAnalysisTransformation().computeMeasurableProperties(input),
                    this.incrementalUpdate.apply(model.makePrediction(sample)));
        }
        return aggregations;
    }

    /**
     * Writes the aggregated property and ML sensitivity values into {@code sensitivityModel},
     * complementing values that were not observed in the training data first.
     *
     * @return the same {@code sensitivityModel} instance that was passed in
     */
    private SensitivityModel complementSensitivityModel(SensitivityModel sensitivityModel,
            SensitivityAggregations sensitivityAggregations) {
        for (String propertyId : sensitivityAggregations.getMeasurablePropertyIds()) {
            Map<PropertyMeasure.MeasurableSensitivityProperty, Double> propertyValues =
                    sensitivityAggregations.getPropertySensitivityValues(propertyId);
            complementPropertyValuesIfNecessary(propertyValues);
            sensitivityModel.setSensitivityValues(toSensitivityPropertyValues(propertyValues));
        }

        Map<SensitivityAggregations.MLSensitivityEntry, Double> mlSensitivityValues =
                sensitivityAggregations.getMLSensitivityValues();
        complementSensitivityValuesIfNecessary(mlSensitivityValues);
        sensitivityModel.setMLSensitivityValues(mlSensitivityValues);
        return sensitivityModel;
    }

    /** Re-keys the given values by {@link SensitivityProperty} without a raw-type cast. */
    private static Map<SensitivityProperty, Double> toSensitivityPropertyValues(
            Map<PropertyMeasure.MeasurableSensitivityProperty, Double> values) {
        return values.entrySet().stream()
                .collect(Collectors.toMap(entry -> (SensitivityProperty) entry.getKey(), Map.Entry::getValue));
    }

    /**
     * Adds a zero-probability entry to {@code map} for every value of the property's measurable
     * space that is not yet present. The property measure is looked up via an arbitrary key of
     * {@code map}; an empty map is left untouched because no measure can be determined from it.
     */
    private void complementPropertyValuesIfNecessary(Map<PropertyMeasure.MeasurableSensitivityProperty, Double> map) {
        if (map.isEmpty()) {
            return;
        }
        PropertyMeasure propertyMeasure = retrievePropertyMeasureBy(reduceToSingleProperty(map));
        for (PropertyMeasure.MeasurableSensitivityProperty property : propertyMeasure.getMeasurablePropertySpace()) {
            if (containsNoPropertyWith(property, map.keySet())) {
                enrichWithZeroProbability(map, property);
            }
        }
    }

    /**
     * Adds a max-entropy entry to {@code map} for every combination of the analysis
     * transformation's measurable space that is not yet present.
     */
    private void complementSensitivityValuesIfNecessary(Map<SensitivityAggregations.MLSensitivityEntry, Double> map) {
        for (List<PropertyMeasure.MeasurableSensitivityProperty> combination
                : MLSensitivityAnalysis.getAnalysisTransformation().computeMeasurableSpace()) {
            SensitivityAggregations.MLSensitivityEntry entry =
                    SensitivityAggregations.MLSensitivityEntry.from(Lists.newArrayList(combination));
            if (containsNoPropertyWith(entry, map.keySet())) {
                enrichWithMaxEntropy(map, entry);
            }
        }
    }

    private boolean containsNoPropertyWith(PropertyMeasure.MeasurableSensitivityProperty property,
            Set<PropertyMeasure.MeasurableSensitivityProperty> set) {
        // NOTE(review): deliberately an equals()-based linear scan rather than Set.contains();
        // presumably guards against a hashCode() inconsistent with equals() — verify before changing.
        return set.stream().noneMatch(candidate -> candidate.equals(property));
    }

    private boolean containsNoPropertyWith(SensitivityAggregations.MLSensitivityEntry entry,
            Set<SensitivityAggregations.MLSensitivityEntry> set) {
        return !set.contains(entry);
    }

    private void enrichWithZeroProbability(Map<PropertyMeasure.MeasurableSensitivityProperty, Double> map,
            PropertyMeasure.MeasurableSensitivityProperty property) {
        map.put(property, ZERO_PROBABILITY);
    }

    private void enrichWithMaxEntropy(Map<SensitivityAggregations.MLSensitivityEntry, Double> map,
            SensitivityAggregations.MLSensitivityEntry entry) {
        map.put(entry, MAX_ENTROPY_PROBABILITY);
    }

    /** Returns an arbitrary key of {@code map}; callers must ensure the map is non-empty. */
    private PropertyMeasure.MeasurableSensitivityProperty reduceToSingleProperty(
            Map<PropertyMeasure.MeasurableSensitivityProperty, Double> map) {
        return map.keySet().iterator().next();
    }

    /**
     * Looks up the property measure that owns {@code property}.
     *
     * @throws MLSensitivityAnalysisException (via the supplier) if no such property measure exists
     */
    private PropertyMeasure retrievePropertyMeasureBy(PropertyMeasure.MeasurableSensitivityProperty property) {
        return MLSensitivityAnalysis.getAnalysisTransformation()
                .findPropertyMeasureWith(property.getId())
                .orElseThrow(MLSensitivityAnalysisException.supplierWithMessage(
                        String.format("There is no property measure for property %s", property.getId())));
    }
}
