package org.palladiosimulator.dependability.ml.sensitivity.api;

import com.google.common.collect.Sets;
import java.util.Collections;
import java.util.Iterator;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.palladiosimulator.dependability.ml.model.TrainedModel;
import org.palladiosimulator.dependability.ml.model.nn.MaskRCNN;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.MLAnalysisContext;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.MLSensitivityAnalysis;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.MLSensitivityAnalysisStrategy;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.ProbabilisticSensitivityModel;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.SensitivityModel;
import org.palladiosimulator.dependability.ml.sensitivity.analysis.TrainingDataBasedAnalysisStrategy;
import org.palladiosimulator.dependability.ml.sensitivity.transformation.PropertyMeasure;
import org.palladiosimulator.dependability.ml.sensitivity.transformation.property.ImageBrightness;
import tools.mdsd.probdist.api.entity.CategoricalValue;
import tools.mdsd.probdist.api.factory.IProbabilityDistributionFactory;

/* loaded from: input_file:org/palladiosimulator/dependability/ml/sensitivity/api/MLSensitivityAnalyser.class */
public class MLSensitivityAnalyser {

    /**
     * Registered analysis strategies. Pre-populated with the accuracy- and confidence-based
     * training-data strategies.
     */
    private static final Set<MLSensitivityAnalysisStrategy> ANALYSIS_STRATEGY_REGISTRY = Sets.newHashSet();

    /** Registered property measures. Pre-populated with {@link ImageBrightness}. */
    private static final Set<PropertyMeasure> PROPERTY_MEASURE_REGISTRY = Sets.newHashSet();

    /** Registered analysable models. Pre-populated with a {@link MaskRCNN} instance. */
    private static final Set<TrainedModel> ANALYSABLE_MODEL_REGISTRY = Sets.newHashSet();

    static {
        ANALYSIS_STRATEGY_REGISTRY.add(TrainingDataBasedAnalysisStrategy.accuracyBasedStrategy());
        ANALYSIS_STRATEGY_REGISTRY.add(TrainingDataBasedAnalysisStrategy.confidenceBasedStrategy());
        PROPERTY_MEASURE_REGISTRY.add(ImageBrightness.get());
        ANALYSABLE_MODEL_REGISTRY.add(new MaskRCNN());
    }

    /**
     * Looks up a registered model by name.
     *
     * <p>NOTE: the decompiler reported this method was originally {@code protected}.
     *
     * @param str the model name to match against {@link TrainedModel#getName()}
     * @return the first registered model whose name equals {@code str}, or empty if none matches
     */
    public static Optional<TrainedModel> findAnalysableModelWith(String str) {
        return ANALYSABLE_MODEL_REGISTRY.stream().filter(modelWith(str)).findFirst();
    }

    /**
     * Looks up a registered property measure by id.
     *
     * @param str the id to match against {@link PropertyMeasure#getId()}
     * @return the first registered measure whose id equals {@code str}, or empty if none matches
     */
    public static Optional<PropertyMeasure> findAnalysablePropertyMeasureWith(String str) {
        return PROPERTY_MEASURE_REGISTRY.stream().filter(propertyWith(str)).findFirst();
    }

    /**
     * Looks up a registered analysis strategy by name.
     *
     * <p>NOTE: the decompiler reported this method was originally {@code protected}.
     *
     * @param str the strategy name to match against {@link MLSensitivityAnalysisStrategy#getName()}
     * @return the first registered strategy whose name equals {@code str}, or empty if none matches
     */
    public static Optional<MLSensitivityAnalysisStrategy> findAnalysisStrategyWith(String str) {
        return ANALYSIS_STRATEGY_REGISTRY.stream().filter(strategyWith(str)).findFirst();
    }

    /**
     * Adds a property measure to the registry.
     *
     * @param propertyMeasure the measure to register; must not be {@code null}
     * @throws NullPointerException if {@code propertyMeasure} is {@code null}. Rejecting null
     *     here keeps the later stream lookups from failing with an NPE.
     */
    public static void registerAnalysableProperties(PropertyMeasure propertyMeasure) {
        PROPERTY_MEASURE_REGISTRY.add(Objects.requireNonNull(propertyMeasure, "propertyMeasure"));
    }

    /**
     * Adds a trained model to the registry.
     *
     * @param trainedModel the model to register; must not be {@code null}
     * @throws NullPointerException if {@code trainedModel} is {@code null}
     */
    public static void registerAnalysableModel(TrainedModel trainedModel) {
        ANALYSABLE_MODEL_REGISTRY.add(Objects.requireNonNull(trainedModel, "trainedModel"));
    }

    /**
     * Adds an analysis strategy to the registry.
     *
     * @param mLSensitivityAnalysisStrategy the strategy to register; must not be {@code null}
     * @throws NullPointerException if {@code mLSensitivityAnalysisStrategy} is {@code null}
     */
    public static void registerAnalysisStrategy(MLSensitivityAnalysisStrategy mLSensitivityAnalysisStrategy) {
        ANALYSIS_STRATEGY_REGISTRY.add(
                Objects.requireNonNull(mLSensitivityAnalysisStrategy, "mLSensitivityAnalysisStrategy"));
    }

    /**
     * Returns the registered property measures.
     *
     * @return an unmodifiable live view of the registry. Later registrations are visible through
     *     it, but it cannot be mutated directly; use {@link #registerAnalysableProperties}.
     */
    public static Set<PropertyMeasure> getAnalysablePropertyMeasures() {
        return Collections.unmodifiableSet(PROPERTY_MEASURE_REGISTRY);
    }

    /**
     * Returns the names of all registered models.
     *
     * @return a new set containing {@link TrainedModel#getName()} of each registered model
     */
    public static Set<String> getAnalysableModelNames() {
        return ANALYSABLE_MODEL_REGISTRY.stream()
                .map(TrainedModel::getName)
                .collect(Collectors.toSet());
    }

    /**
     * Returns the names of all registered analysis strategies.
     *
     * @return a new set containing {@link MLSensitivityAnalysisStrategy#getName()} of each
     *     registered strategy
     */
    public static Set<String> getAnalysisStrategyNames() {
        return ANALYSIS_STRATEGY_REGISTRY.stream()
                .map(MLSensitivityAnalysisStrategy::getName)
                .collect(Collectors.toSet());
    }

    /**
     * Runs a sensitivity analysis as described by the given configuration.
     *
     * <p>The method performs these steps:
     * <ol>
     *   <li>Builds an {@link MLSensitivityAnalysis} from the configured strategy and property
     *       measures.</li>
     *   <li>Builds an {@link MLAnalysisContext} for the configured model and training-data
     *       location, captured by a {@link ProbabilisticSensitivityModel}.</li>
     *   <li>Loads the model from the configured URI.</li>
     *   <li>Runs the analysis.</li>
     * </ol>
     *
     * @param sensitivityAnalysisConfig the analysis configuration; must not be {@code null}
     * @param iProbabilityDistributionFactory factory passed to
     *     {@link ProbabilisticSensitivityModel#createFrom}
     * @return the resulting sensitivity model
     * @throws NullPointerException if {@code sensitivityAnalysisConfig} is {@code null}
     */
    public static SensitivityModel analyse(SensitivityAnalysisConfig sensitivityAnalysisConfig,
            IProbabilityDistributionFactory<CategoricalValue> iProbabilityDistributionFactory) {
        Objects.requireNonNull(sensitivityAnalysisConfig, "sensitivityAnalysisConfig");

        MLSensitivityAnalysis.MLSensitivityAnalysisBuilder analysisBuilder = MLSensitivityAnalysis.newBuilder()
                .withSensitivityAnalysisStrategy(sensitivityAnalysisConfig.getAnalysisStrategy());
        for (PropertyMeasure propertyMeasure : sensitivityAnalysisConfig.getPropertyMeasures()) {
            analysisBuilder.addPropertyMeasure(propertyMeasure);
        }
        MLSensitivityAnalysis analysis = analysisBuilder.build();

        // The sensitivity model receives its own copy of the configured measures.
        MLAnalysisContext context = MLAnalysisContext.newBuilder()
                .analyseSensitivityOf(sensitivityAnalysisConfig.getTrainedModel())
                .trainedWith(sensitivityAnalysisConfig.getTrainDataLocation())
                .andCapturedBy(ProbabilisticSensitivityModel.createFrom(
                        Sets.newHashSet(sensitivityAnalysisConfig.getPropertyMeasures()),
                        iProbabilityDistributionFactory))
                .build();
        context.getMLModel().loadModel(sensitivityAnalysisConfig.getTrainedModelURI());

        return analysis.analyseSensitivity(context);
    }

    /**
     * Runs {@link #analyse} and saves the resulting model at the configured storing location.
     *
     * @param sensitivityAnalysisConfig the analysis configuration; must not be {@code null}
     * @param iProbabilityDistributionFactory factory passed through to {@link #analyse}
     */
    public static void analyseAndSave(SensitivityAnalysisConfig sensitivityAnalysisConfig,
            IProbabilityDistributionFactory<CategoricalValue> iProbabilityDistributionFactory) {
        analyse(sensitivityAnalysisConfig, iProbabilityDistributionFactory)
                .saveAt(sensitivityAnalysisConfig.getSensitivityModelStoringLocation());
    }

    /**
     * Returns a predicate matching measures whose id equals {@code str}. The comparison is
     * null-safe.
     */
    private static Predicate<PropertyMeasure> propertyWith(String str) {
        return propertyMeasure -> Objects.equals(propertyMeasure.getId(), str);
    }

    /**
     * Returns a predicate matching models whose name equals {@code str}. The comparison is
     * null-safe.
     */
    private static Predicate<TrainedModel> modelWith(String str) {
        return trainedModel -> Objects.equals(trainedModel.getName(), str);
    }

    /**
     * Returns a predicate matching strategies whose name equals {@code str}. The comparison is
     * null-safe.
     */
    private static Predicate<MLSensitivityAnalysisStrategy> strategyWith(String str) {
        return strategy -> Objects.equals(strategy.getName(), str);
    }
}
