/*
 * Decompiled with CFR 0.152.
 */
package org.palladiosimulator.simexp.markovian.exploration;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.palladiosimulator.simexp.distribution.factory.ProbabilityDistributionFactory;
import org.palladiosimulator.simexp.distribution.function.ProbabilityMassFunction;
import org.palladiosimulator.simexp.markovian.activity.BasePolicy;
import org.palladiosimulator.simexp.markovian.exploration.RandomizedStrategy;
import org.palladiosimulator.simexp.markovian.model.markovmodel.markoventity.State;
import org.palladiosimulator.simexp.markovian.model.markovmodel.markoventity.Transition;
import org.palladiosimulator.simexp.markovian.util.MarkovianUtil;

/**
 * Policy that picks between the most probable transition and a randomly
 * chosen alternative, weighted by a tunable {@code epsilon}.
 *
 * <p>In this implementation {@code epsilon} is the probability mass given to
 * the most probable transition; the remaining {@code 1 - epsilon} leads to a
 * random pick among all other transitions.
 * NOTE(review): textbook epsilon-greedy usually uses epsilon as the
 * exploration probability; here it weights the greedy choice. Confirm this
 * is the intended convention.
 *
 * <p>{@code epsilon} is always kept inside {@code [0, 1]} so that both
 * weights handed to the probability mass function are non-negative.
 *
 * @param <A> type parameter of the transitions this policy selects from
 */
public class EpsilonGreedyStrategy<A> implements BasePolicy<Transition<A>> {
    private static final String STRATEGY_ID = "EpsilonGreedy";
    private static final double DEFAULT_INIT_EPSILON = 0.01;
    private static final int EXPLORATION_FACTOR = 100;

    /** Weight of the most probable transition; invariant: clamped to [0, 1] on every write. */
    private double epsilon = DEFAULT_INIT_EPSILON;
    /** Maps an iteration number to the epsilon to use from then on. */
    private Function<Integer, Double> epsilonAdjustementLaw = this.getDefaultEpsilonAdjustementLaw();

    /**
     * Sets the weight of the most probable transition.
     *
     * @param epsilon desired weight; values below 0 are stored as 0 and
     *                values above 1 are stored as 1
     */
    public void setEpsilon(double epsilon) {
        this.epsilon = clampToProbability(epsilon);
    }

    /**
     * Replaces the law used by {@link #adjust(int)} to recompute epsilon.
     *
     * @param epsilonAdjustementLaw function from iteration number to epsilon
     */
    public void setEpsilonAdjustementLaw(Function<Integer, Double> epsilonAdjustementLaw) {
        this.epsilonAdjustementLaw = epsilonAdjustementLaw;
    }

    /**
     * Selects one transition out of {@code options}.
     *
     * <p>A two-outcome distribution is sampled: the most probable transition
     * with weight {@code epsilon}, and a {@code null} placeholder standing for
     * "any other transition" with weight {@code 1 - epsilon}. If the placeholder
     * is drawn, one of the remaining transitions is chosen randomly.
     *
     * @param source  the current state; not read by this implementation
     * @param options candidate transitions
     * @return the most probable transition, or a randomly selected other one;
     *         the most probable transition is also returned when there is no
     *         other candidate
     */
    @Override
    public Transition<A> select(State source, Set<Transition<A>> options) {
        TransitionHelper<A> transHelper = new TransitionHelper<A>(options);
        Transition<A> max = transHelper.getMostProbableTransition();

        ProbabilityMassFunction.Sample maxSample = ProbabilityMassFunction.Sample.of(max, this.epsilon);
        ProbabilityMassFunction.Sample otherSamples = ProbabilityMassFunction.Sample.of(null, 1.0 - this.epsilon);
        LinkedHashSet<ProbabilityMassFunction.Sample> samples = new LinkedHashSet<ProbabilityMassFunction.Sample>(Arrays.asList(maxSample, otherSamples));
        ProbabilityMassFunction pmfOver = ProbabilityDistributionFactory.INSTANCE.pmfOver(samples);
        ProbabilityMassFunction.Sample result = (ProbabilityMassFunction.Sample)pmfOver.drawSample();

        // Identity comparison is intentional: the sample carries the very same reference.
        if (result.getValue() == max) {
            return max;
        }

        Set<Transition<A>> others = transHelper.filterAllExcept(max);
        if (others.isEmpty()) {
            // Nothing else to pick from; do not hand an empty set to the random strategy.
            return max;
        }
        return this.selectRandomly(others);
    }

    /**
     * Recomputes epsilon for the given iteration using the current adjustment
     * law. The law's result is clamped to {@code [0, 1]}.
     *
     * @param numberOfIteration iteration number passed to the adjustment law
     */
    public void adjust(int numberOfIteration) {
        this.epsilon = clampToProbability(this.epsilonAdjustementLaw.apply(numberOfIteration));
    }

    /** Delegates to a fresh {@link RandomizedStrategy} to pick one of the given transitions. */
    private Transition<A> selectRandomly(Set<Transition<A>> transitions) {
        RandomizedStrategy<Transition<A>> randomizedStrategy = new RandomizedStrategy<Transition<A>>();
        return randomizedStrategy.select((State)null, transitions);
    }

    /**
     * Default law: {@code exp(iteration - EXPLORATION_FACTOR)}, capped at 1.
     * The uncapped expression exceeds 1 for every iteration above
     * {@code EXPLORATION_FACTOR}, which is not a valid probability.
     */
    private Function<Integer, Double> getDefaultEpsilonAdjustementLaw() {
        // Subtract in double to avoid int overflow for extreme iteration numbers.
        return iteration -> Math.min(1.0, Math.exp((double)iteration - EXPLORATION_FACTOR));
    }

    /** Limits a value to the closed interval [0, 1]; NaN is passed through unchanged. */
    private static double clampToProbability(double value) {
        return Math.min(1.0, Math.max(0.0, value));
    }

    /**
     * @return the constant identifier of this strategy, {@code "EpsilonGreedy"}
     */
    @Override
    public String getId() {
        return STRATEGY_ID;
    }

    /** Small wrapper offering queries over a fixed set of transitions. */
    private static class TransitionHelper<A> {
        private final Set<Transition<A>> transitions;

        public TransitionHelper(Set<Transition<A>> transitions) {
            this.transitions = transitions;
        }

        /** Returns whatever {@code MarkovianUtil.maxTransition} yields for the wrapped set. */
        public Transition<A> getMostProbableTransition() {
            return MarkovianUtil.maxTransition(this.transitions);
        }

        /**
         * Returns a new insertion-ordered set holding every wrapped transition
         * that is not {@code equals} to the given one.
         */
        public Set<Transition<A>> filterAllExcept(Transition<A> transition) {
            return this.transitions.stream()
                    .filter(t -> !t.equals(transition))
                    .collect(Collectors.toCollection(LinkedHashSet::new));
        }
    }
}

