/*
 * Decompiled with CFR 0.152.
 */
package org.palladiosimulator.dependability.ml.sensitivity.analysis;

import com.google.common.collect.Lists;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.palladiosimulator.dependability.ml.iterator.TrainingDataIterator;
import org.palladiosimulator.dependability.ml.model.InputData;
import org.palladiosimulator.dependability.ml.model.MLPredictionResult;
import org.palladiosimulator.dependability.ml.model.OutputData;
import org.palladiosimulator.dependability.ml.model.TrainedModel;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.MLAnalysisContext;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.MLSensitivityAnalysis;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.MLSensitivityAnalysisStrategy;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.SensitivityAggregations;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.SensitivityModel;
import org.palladiosimulator.dependability.ml.sensitivity.exception.MLSensitivityAnalysisException;
import org.palladiosimulator.dependability.ml.sensitivity.transformation.AnalysisTransformation;
import org.palladiosimulator.dependability.ml.sensitivity.transformation.PropertyMeasure;
import org.palladiosimulator.dependability.ml.sensitivity.transformation.SensitivityProperty;
import org.palladiosimulator.dependability.ml.util.Tuple;

/**
 * {@link MLSensitivityAnalysisStrategy} that derives sensitivity values by replaying the training
 * data through the trained model.
 *
 * <p>For every training sample, the measurable sensitivity properties of its input are computed,
 * the model's prediction for the sample is scored by a strategy-specific function, and the
 * (properties, score) pair is recorded in a {@link SensitivityAggregations}. Afterwards, property
 * values and ML sensitivity entries that never occurred in the training data are complemented with
 * default values before being written into the context's {@link SensitivityModel}.
 *
 * <p>Instances are obtained via {@link #accuracyBasedStrategy()} or
 * {@link #confidenceBasedStrategy()}.
 */
public final class TrainingDataBasedAnalysisStrategy
implements MLSensitivityAnalysisStrategy {
    private static final String CONFIDENCE_STRATEGY_NAME = "Confidence based training data analysis strategy";
    private static final String ACCURACY_STRATEGY_NAME = "Accuracy based training data analysis strategy";

    /** Value recorded for a measurable property value that never occurred in the training data. */
    private static final double UNOBSERVED_PROPERTY_VALUE = 0.0;

    /**
     * Value recorded for an ML sensitivity entry that never occurred in the training data;
     * presumably the maximum-entropy (uninformed) value for a binary outcome — TODO confirm.
     */
    private static final double UNOBSERVED_ML_SENSITIVITY_VALUE = 0.5;

    private final String strategyName;

    /** Scores a single prediction result; the scores are what gets aggregated per property. */
    private final Function<MLPredictionResult, Double> incrementalUpdate;

    private TrainingDataBasedAnalysisStrategy(Function<MLPredictionResult, Double> incrementalUpdate, String strategyName) {
        this.incrementalUpdate = incrementalUpdate;
        this.strategyName = strategyName;
    }

    /**
     * Creates a strategy that scores each training sample with {@code 1.0} if the model's prediction
     * is the expected result and {@code 0.0} otherwise.
     *
     * @return the accuracy-based strategy
     */
    public static TrainingDataBasedAnalysisStrategy accuracyBasedStrategy() {
        return new TrainingDataBasedAnalysisStrategy(r -> r.isExpectedResult() ? 1.0 : 0.0, ACCURACY_STRATEGY_NAME);
    }

    /**
     * Creates a strategy that scores each training sample with the mean prediction confidence over
     * all predictions contained in the model's prediction result.
     *
     * <p>The scoring function throws {@link IllegalArgumentException} if a prediction result
     * contains no predictions, since no mean confidence can be computed for it.
     *
     * @return the confidence-based strategy
     */
    public static TrainingDataBasedAnalysisStrategy confidenceBasedStrategy() {
        return new TrainingDataBasedAnalysisStrategy(r -> r.getPredictions().stream()
                .mapToDouble(OutputData::getPredictionConfidence)
                .average()
                .orElseThrow(() -> new IllegalArgumentException(
                        "Cannot compute a mean prediction confidence: the prediction result contains no predictions")),
                CONFIDENCE_STRATEGY_NAME);
    }

    @Override
    public String getName() {
        return this.strategyName;
    }

    /**
     * Computes sensitivity values from the context's training data and writes them into the
     * context's sensitivity model.
     *
     * @param context the analysis context providing the model, training data and sensitivity model
     * @return the sensitivity model from {@code context}, complemented with the computed values
     */
    @Override
    public SensitivityModel analyseSensitivity(MLAnalysisContext context) {
        SensitivityModel sensitivityModel = context.getSensitivityModel();
        SensitivityAggregations sensitivityValues = this.computeAggregatedSensitivityValues(context);
        return this.complementSensitivityModel(sensitivityModel, sensitivityValues);
    }

    /** Iterates the training data, scores each prediction and records it against the sample's properties. */
    private SensitivityAggregations computeAggregatedSensitivityValues(MLAnalysisContext context) {
        TrainedModel mlModel = context.getMLModel();
        TrainingDataIterator dataIterator = mlModel.getTrainingDataIteratorBy(context.getTrainingData());
        // Loop-invariant: fetch the transformation once rather than once per training sample.
        AnalysisTransformation transformation = MLSensitivityAnalysis.getAnalysisTransformation();
        SensitivityAggregations sensitivityAggregations = new SensitivityAggregations();
        while (dataIterator.hasNext()) {
            Tuple dataTuple = dataIterator.next();
            // The first tuple element is the model input from which the properties are measured.
            Set<PropertyMeasure.MeasurableSensitivityProperty> properties =
                    transformation.computeMeasurableProperties((InputData) dataTuple.getFirst());
            // Accuracy or confidence, depending on the strategy.
            Double sampleScore = this.incrementalUpdate.apply(mlModel.makePrediction(dataTuple));
            sensitivityAggregations.record(properties, sampleScore);
        }
        return sensitivityAggregations;
    }

    /**
     * Writes the aggregated property and ML sensitivity values into {@code sensitivityModel},
     * after filling in defaults for values that were not observed in the training data.
     */
    private SensitivityModel complementSensitivityModel(SensitivityModel sensitivityModel, SensitivityAggregations sensitivityAggregations) {
        for (String each : sensitivityAggregations.getMeasurablePropertyIds()) {
            Map<PropertyMeasure.MeasurableSensitivityProperty, Double> sensitivityValues = sensitivityAggregations.getPropertySensitivityValues(each);
            this.complementPropertyValuesIfNecessary(sensitivityValues);
            Map<SensitivityProperty, Double> result = sensitivityValues.entrySet().stream()
                    .collect(Collectors.toMap(e -> (SensitivityProperty) e.getKey(), Map.Entry::getValue));
            // NOTE(review): invoked once per property id; confirm setSensitivityValues merges
            // rather than replaces previously set values.
            sensitivityModel.setSensitivityValues(result);
        }
        Map<SensitivityAggregations.MLSensitivityEntry, Double> mlSensitivityValues = sensitivityAggregations.getMLSensitivityValues();
        this.complementSensitivityValuesIfNecessary(mlSensitivityValues);
        sensitivityModel.setMLSensitivityValues(mlSensitivityValues);
        return sensitivityModel;
    }

    /**
     * Adds every value of the property's measurable space that is missing from
     * {@code sensitivityValues} with {@link #UNOBSERVED_PROPERTY_VALUE}. Mutates the given map.
     */
    private void complementPropertyValuesIfNecessary(Map<PropertyMeasure.MeasurableSensitivityProperty, Double> sensitivityValues) {
        if (sensitivityValues.isEmpty()) {
            // No recorded property to derive the property measure from; nothing can be complemented.
            return;
        }
        PropertyMeasure measure = this.retrievePropertyMeasureBy(this.reduceToSingleProperty(sensitivityValues));
        for (PropertyMeasure.MeasurableSensitivityProperty each : measure.getMeasurablePropertySpace()) {
            if (this.containsNoPropertyWith(each, sensitivityValues.keySet())) {
                this.enrichWithZeroProbability(sensitivityValues, each);
            }
        }
    }

    /**
     * Adds every combination of the transformation's measurable space that is missing from
     * {@code mlSensitivityValues} with {@link #UNOBSERVED_ML_SENSITIVITY_VALUE}. Mutates the given map.
     */
    private void complementSensitivityValuesIfNecessary(Map<SensitivityAggregations.MLSensitivityEntry, Double> mlSensitivityValues) {
        AnalysisTransformation transformation = MLSensitivityAnalysis.getAnalysisTransformation();
        for (List<PropertyMeasure.MeasurableSensitivityProperty> each : transformation.computeMeasurableSpace()) {
            SensitivityAggregations.MLSensitivityEntry entry = SensitivityAggregations.MLSensitivityEntry.from(Lists.newArrayList(each));
            if (this.containsNoPropertyWith(entry, mlSensitivityValues.keySet())) {
                this.enrichWithMaxEntropy(mlSensitivityValues, entry);
            }
        }
    }

    /** Returns {@code true} if no element of {@code recordedProperties} equals {@code property}. */
    private boolean containsNoPropertyWith(PropertyMeasure.MeasurableSensitivityProperty property, Set<PropertyMeasure.MeasurableSensitivityProperty> recordedProperties) {
        return recordedProperties.stream().noneMatch(prop -> prop.equals(property));
    }

    /** Returns {@code true} if {@code entry} is not contained in {@code mlSensitivityValues}. */
    private boolean containsNoPropertyWith(SensitivityAggregations.MLSensitivityEntry entry, Set<SensitivityAggregations.MLSensitivityEntry> mlSensitivityValues) {
        return !mlSensitivityValues.contains(entry);
    }

    private void enrichWithZeroProbability(Map<PropertyMeasure.MeasurableSensitivityProperty, Double> recordedProperties, PropertyMeasure.MeasurableSensitivityProperty unobservedProperty) {
        recordedProperties.put(unobservedProperty, UNOBSERVED_PROPERTY_VALUE);
    }

    private void enrichWithMaxEntropy(Map<SensitivityAggregations.MLSensitivityEntry, Double> mlSensitivityValues, SensitivityAggregations.MLSensitivityEntry unobservedEntry) {
        mlSensitivityValues.put(unobservedEntry, UNOBSERVED_ML_SENSITIVITY_VALUE);
    }

    /**
     * Picks an arbitrary recorded property as a representative; all keys are assumed to share the
     * same property measure (they were grouped under one property id). Requires a non-empty map.
     */
    private PropertyMeasure.MeasurableSensitivityProperty reduceToSingleProperty(Map<PropertyMeasure.MeasurableSensitivityProperty, Double> sensitivityValues) {
        return sensitivityValues.keySet().iterator().next();
    }

    /**
     * Looks up the property measure for {@code property}'s id.
     *
     * @throws MLSensitivityAnalysisException (via its supplier) if no measure exists for the id
     */
    private PropertyMeasure retrievePropertyMeasureBy(PropertyMeasure.MeasurableSensitivityProperty property) {
        return MLSensitivityAnalysis.getAnalysisTransformation().findPropertyMeasureWith(property.getId()).orElseThrow(MLSensitivityAnalysisException.supplierWithMessage(String.format("There is no property measure for property %s", property.getId())));
    }
}

