/*
 * Decompiled with CFR 0.152.
 */
package org.palladiosimulator.dependability.ml.model.nn;

import com.google.common.collect.Lists;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.palladiosimulator.dependability.ml.exception.DependableMLException;
import org.palladiosimulator.dependability.ml.iterator.DirectoryBasedTrainingDataIterator;
import org.palladiosimulator.dependability.ml.model.InputData;
import org.palladiosimulator.dependability.ml.model.InputDataLabel;
import org.palladiosimulator.dependability.ml.model.MLPredictionResult;
import org.palladiosimulator.dependability.ml.model.OutputData;
import org.palladiosimulator.dependability.ml.model.TrainedModel;
import org.palladiosimulator.dependability.ml.model.access.HttpModelAccessor;
import org.palladiosimulator.dependability.ml.model.nn.ImageInputData;
import org.palladiosimulator.dependability.ml.model.nn.ImageSegmentationLabel;
import org.palladiosimulator.dependability.ml.util.Tuple;

/**
 * {@link TrainedModel} adapter for a Mask R-CNN model that is loaded and queried remotely through
 * an {@link HttpModelAccessor}.
 */
public class MaskRCNN implements TrainedModel {
    private static final String MODEL_NAME = "Mask R-CNN";

    /**
     * Shared source of randomness for {@link #parsePredictionResult(List)}. A single instance is
     * used instead of a fresh {@code Random} seeded with {@code System.currentTimeMillis()} per
     * call, which yields identical values for all calls made within the same millisecond.
     * {@code java.util.Random} instances are documented as thread-safe.
     */
    private static final Random RANDOM = new Random();

    private final HttpModelAccessor remoteTrainedModel = new HttpModelAccessor();

    /**
     * Creates an iterator over the training data stored in the given directory.
     *
     * @param dataLocation directory containing image files and their label files
     * @return a {@link MaskRCNNTrainingDataIterator} over {@code dataLocation}
     */
    @Override
    public MaskRCNNTrainingDataIterator getTrainingDataIteratorBy(File dataLocation) {
        return new MaskRCNNTrainingDataIterator(dataLocation);
    }

    /**
     * Loads the remote model identified by {@code modelURI}.
     *
     * <p>If the accessor reports that the URI cannot be accessed, a {@link DependableMLException}
     * is raised via {@code DependableMLException.throwWithMessage} (presumably this always throws;
     * confirm against that helper).
     *
     * @param modelURI location of the model to load
     */
    @Override
    public void loadModel(URI modelURI) {
        if (this.remoteTrainedModel.canNotAccess(modelURI)) {
            DependableMLException.throwWithMessage(String.format("The model %s could not be loaded.", modelURI));
        }
        this.remoteTrainedModel.load(modelURI);
    }

    /**
     * Queries the remote model with the input part of {@code dataTuple} and wraps the returned
     * outputs in an {@link MLPredictionResult}.
     *
     * <p>NOTE(review): the label ({@code dataTuple.getSecond()}) is not used. The expectation flag
     * of the result is chosen at random (see {@link #parsePredictionResult(List)}), which looks
     * like a placeholder. Confirm whether it should be derived from the label.
     *
     * @param dataTuple input data and its label
     * @return the prediction result containing all outputs returned by the remote model
     */
    @Override
    public MLPredictionResult makePrediction(Tuple<InputData, InputDataLabel> dataTuple) {
        List<OutputData> prediction = this.remoteTrainedModel.query(dataTuple.getFirst());
        return this.parsePredictionResult(prediction);
    }

    /**
     * Wraps the given predictions in a new {@link MLPredictionResult}, constructed with a
     * randomly chosen boolean.
     *
     * @param predictions outputs returned by the remote model
     * @return a result whose prediction list contains all of {@code predictions}
     */
    private MLPredictionResult parsePredictionResult(List<OutputData> predictions) {
        boolean randomizedExpectation = RANDOM.nextBoolean();
        MLPredictionResult result = new MLPredictionResult(randomizedExpectation);
        result.getPredictions().addAll(predictions);
        return result;
    }

    /**
     * @return the display name of this model, {@code "Mask R-CNN"}
     */
    @Override
    public String getName() {
        return MODEL_NAME;
    }

    /**
     * Training data iterator that splits the files of a directory into image files and label
     * files. Any file whose name ends with {@code "label.txt"} is a label. It then pairs each image
     * with a label.
     */
    public static class MaskRCNNTrainingDataIterator extends DirectoryBasedTrainingDataIterator {
        private static final String LABEL_POSTFIX = "label.txt";

        /**
         * @param trainingDataLocation directory containing image and label files
         */
        public MaskRCNNTrainingDataIterator(File trainingDataLocation) {
            super(trainingDataLocation);
        }

        /**
         * Pairs every training (non-label) file with a label file.
         *
         * @param trainData all files of the training data directory
         * @return an iterator over (image input, segmentation label) tuples
         */
        @Override
        protected Iterator<Tuple<InputData, InputDataLabel>> arrangeTrainingData(List<File> trainData) {
            // partitioningBy always yields entries for both true and false, so neither get() is null.
            Map<Boolean, List<File>> partitionedData = trainData.stream()
                    .collect(Collectors.partitioningBy(this.isTrainingData()));
            List<File> trainDataSplit = partitionedData.get(true);
            List<File> labelDataSplit = partitionedData.get(false);
            if (trainDataSplit.size() != labelDataSplit.size()) {
                DependableMLException.throwWithMessage("The number of training and label data is unequal");
            }
            // Pairing is purely positional: the i-th training file gets the i-th label file.
            // NOTE(review): this relies on trainData listing each image and its label in the same
            // relative order. Confirm that the superclass guarantees a suitable ordering.
            List<Tuple<InputData, InputDataLabel>> arrangedData = new ArrayList<>(trainDataSplit.size());
            for (int i = 0; i < trainDataSplit.size(); i++) {
                InputData inputData = new ImageInputData(trainDataSplit.get(i));
                InputDataLabel label = new ImageSegmentationLabel(labelDataSplit.get(i));
                arrangedData.add(Tuple.of(inputData, label));
            }
            return arrangedData.iterator();
        }

        /**
         * @return a predicate that accepts files whose name does not end with {@code "label.txt"}
         */
        private Predicate<File> isTrainingData() {
            return file -> !file.getName().endsWith(LABEL_POSTFIX);
        }
    }
}

