package org.palladiosimulator.dependability.ml.model.nn;

import com.google.common.collect.Lists;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.palladiosimulator.dependability.ml.exception.DependableMLException;
import org.palladiosimulator.dependability.ml.iterator.DirectoryBasedTrainingDataIterator;
import org.palladiosimulator.dependability.ml.model.InputData;
import org.palladiosimulator.dependability.ml.model.InputDataLabel;
import org.palladiosimulator.dependability.ml.model.MLPredictionResult;
import org.palladiosimulator.dependability.ml.model.OutputData;
import org.palladiosimulator.dependability.ml.model.TrainedModel;
import org.palladiosimulator.dependability.ml.model.access.HttpModelAccessor;
import org.palladiosimulator.dependability.ml.util.Tuple;

/* loaded from: input_file:org/palladiosimulator/dependability/ml/model/nn/MaskRCNN.class */
/**
 * {@link TrainedModel} adapter for a Mask R-CNN network that is loaded and queried through an
 * {@link HttpModelAccessor}.
 */
public class MaskRCNN implements TrainedModel {
    private static final String MODEL_NAME = "Mask R-CNN";
    private static final String CLASS_1_PREDICTION = "class1";
    private static final String CLASS_2_PREDICTION = "class2";
    private final HttpModelAccessor remoteTrainedModel = new HttpModelAccessor();

    /**
     * Source of the placeholder correctness flag produced by {@link #parsePredictionResult(List)}.
     * One instance is reused for the lifetime of the model; re-creating a {@code Random} seeded
     * with the current millisecond on every call made predictions issued within the same
     * millisecond receive identical flags.
     */
    private final Random random = new Random();

    /**
     * Training-data iterator that pairs each input file with a label file. Files whose name ends
     * with {@value #LABEL_POSTFIX} are treated as labels; all other files are treated as inputs.
     */
    public static class MaskRCNNTrainingDataIterator extends DirectoryBasedTrainingDataIterator {
        private static final String LABEL_POSTFIX = "label.txt";

        /**
         * @param file the training-data location handed to {@link DirectoryBasedTrainingDataIterator}
         */
        public MaskRCNNTrainingDataIterator(File file) {
            super(file);
        }

        /**
         * Splits {@code list} into input files and label files and pairs them up.
         *
         * <p>The i-th input file is paired with the i-th label file, and both keep the relative
         * order they had in {@code list}. NOTE(review): this is only correct if the caller supplies
         * the files in an order where inputs and their labels line up — confirm against
         * {@link DirectoryBasedTrainingDataIterator}.
         *
         * @param list all files found in the training-data location
         * @return an iterator over (image input, segmentation label) pairs
         */
        @Override
        protected Iterator<Tuple<InputData, InputDataLabel>> arrangeTrainingData(List<File> list) {
            // partitioningBy with the default toList downstream preserves encounter order.
            Map<Boolean, List<File>> partitioned =
                    list.stream().collect(Collectors.partitioningBy(isTrainingData()));
            List<File> inputFiles = partitioned.get(true);
            List<File> labelFiles = partitioned.get(false);
            if (inputFiles.size() != labelFiles.size()) {
                DependableMLException.throwWithMessage("The number of training and label data is unequal");
            }

            List<Tuple<InputData, InputDataLabel>> pairs = new ArrayList<>(inputFiles.size());
            for (int i = 0; i < inputFiles.size(); i++) {
                InputData input = new ImageInputData(inputFiles.get(i));
                InputDataLabel label = new ImageSegmentationLabel(labelFiles.get(i));
                pairs.add(Tuple.of(input, label));
            }
            return pairs.iterator();
        }

        /**
         * @return a predicate that accepts files whose name does not end with {@value #LABEL_POSTFIX}
         */
        private Predicate<File> isTrainingData() {
            return file -> !file.getName().endsWith(LABEL_POSTFIX);
        }
    }

    /**
     * @param file the training-data location
     * @return a new {@link MaskRCNNTrainingDataIterator} over {@code file}
     */
    @Override
    public MaskRCNNTrainingDataIterator getTrainingDataIteratorBy(File file) {
        return new MaskRCNNTrainingDataIterator(file);
    }

    /**
     * Loads the remote model at {@code uri}.
     *
     * <p>If the accessor reports that {@code uri} cannot be accessed, this method reports the
     * failure through {@link DependableMLException#throwWithMessage(String)} before attempting the load.
     *
     * @param uri location of the model
     */
    @Override
    public void loadModel(URI uri) {
        if (this.remoteTrainedModel.canNotAccess(uri)) {
            DependableMLException.throwWithMessage(String.format("The model %s could not be loaded.", uri));
        }
        this.remoteTrainedModel.load(uri);
    }

    /**
     * Queries the remote model with the input part of {@code tuple}; the label part is not used.
     *
     * @param tuple the (input, label) pair to predict on
     * @return the wrapped prediction result
     */
    @Override
    public MLPredictionResult makePrediction(Tuple<InputData, InputDataLabel> tuple) {
        return parsePredictionResult(this.remoteTrainedModel.query(tuple.getFirst()));
    }

    /**
     * Wraps the raw model outputs in an {@link MLPredictionResult}.
     *
     * <p>NOTE(review): the boolean passed to {@code MLPredictionResult} is drawn uniformly at
     * random and does not depend on {@code list}; this looks like a placeholder for a real
     * evaluation of the outputs.
     *
     * @param list outputs returned by the remote model
     * @return a result containing all of {@code list}
     */
    private MLPredictionResult parsePredictionResult(List<OutputData> list) {
        MLPredictionResult result = new MLPredictionResult(this.random.nextInt(2) == 1);
        result.getPredictions().addAll(list);
        return result;
    }

    /**
     * @return the human-readable model name
     */
    @Override
    public String getName() {
        return MODEL_NAME;
    }
}
