/*
 * Decompiled with CFR 0.152.
 */
package org.palladiosimulator.envdyn.api.entity.bn;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.palladiosimulator.envdyn.api.entity.ProbabilisticModel;
import org.palladiosimulator.envdyn.api.entity.bn.BayesianNetwork;
import org.palladiosimulator.envdyn.api.entity.bn.ConditionalInputValueUtil;
import org.palladiosimulator.envdyn.api.entity.bn.InputValue;
import org.palladiosimulator.envdyn.api.entity.bn.ProbabilityDistributionHandler;
import org.palladiosimulator.envdyn.api.exception.EnvironmentalDynamicsException;
import org.palladiosimulator.envdyn.api.util.InductiveDynamicBehaviourQuerying;
import org.palladiosimulator.envdyn.api.util.TemplateDefinitionsQuerying;
import org.palladiosimulator.envdyn.environment.dynamicmodel.DynamicBehaviourExtension;
import org.palladiosimulator.envdyn.environment.dynamicmodel.InterTimeSliceInduction;
import org.palladiosimulator.envdyn.environment.dynamicmodel.IntraTimeSliceInduction;
import org.palladiosimulator.envdyn.environment.staticmodel.GroundRandomVariable;
import org.palladiosimulator.envdyn.environment.staticmodel.LocalProbabilisticNetwork;
import org.palladiosimulator.envdyn.environment.templatevariable.TemplateVariable;
import tools.mdsd.probdist.api.builder.ProbabilityDistributionBuilder;
import tools.mdsd.probdist.api.entity.Conditionable;
import tools.mdsd.probdist.api.entity.ConditionableProbabilityDistribution;
import tools.mdsd.probdist.api.entity.ProbabilityDistributionFunction;
import tools.mdsd.probdist.api.entity.Value;
import tools.mdsd.probdist.api.factory.IProbabilityDistributionFactory;
import tools.mdsd.probdist.api.random.ISeedProvider;
import tools.mdsd.probdist.distributionfunction.Domain;
import tools.mdsd.probdist.distributionfunction.ProbabilityDistribution;
import tools.mdsd.probdist.distributiontype.ProbabilityDistributionSkeleton;

/**
 * Dynamic Bayesian network whose outcomes are {@link Trajectory trajectories}: sequences of time
 * slices, each holding one {@link InputValue} per ground random variable.
 *
 * <p>Time slice 0 is governed by the initial {@link BayesianNetwork}. Later slices are governed by
 * the {@link DynamicBehaviourExtension}: inter-time-slice inductions use conditional distributions
 * built from the dynamics and conditioned on the preceding slice, while intra-time-slice inductions
 * are looked up in the initial network given a history of values.
 *
 * <p>Instances are stateful: {@link #given(List)} and every sampling call replace the conditionals
 * that the next sample is drawn from.
 */
public class DynamicBayesianNetwork<I extends Value<?>>
extends ProbabilityDistributionFunction<Trajectory<I>>
implements ProbabilisticModel<Trajectory<I>>,
Conditionable<I> {
    /** Index of the last slice drawn by {@link #sample()}; 0 yields a trajectory holding only slice 0. */
    private static final int SINGLE_TIME_SLICE = 0;
    private final InductiveDynamicBehaviourQuerying dynBehaviourQuery;
    private final DynamicBehaviourExtension dynamics;
    private final BayesianNetwork<I> initialDistribution;
    private final TemporalProbabilityHandler probHandler;
    /** Values the next sample is conditioned on; when empty, the initial network is sampled instead. */
    private final List<ConditionalInputValue<I>> conditionals;
    private final ConditionalInputValueUtil<I> conditionalInputValueUtil = new ConditionalInputValueUtil<>();

    /**
     * Creates a dynamic Bayesian network.
     *
     * @param distSkeleton skeleton handed to {@link ProbabilityDistributionFunction}
     * @param initialDistribution network governing time slice 0 and resolving intra-time-slice inductions
     * @param dynamics the inter- and intra-time-slice inductions describing the temporal behaviour
     * @param probabilityDistributionFactory factory used to build the conditional distributions of
     *     ground variables that have an extending induction
     */
    public DynamicBayesianNetwork(ProbabilityDistributionSkeleton distSkeleton, BayesianNetwork<I> initialDistribution, DynamicBehaviourExtension dynamics, IProbabilityDistributionFactory<I> probabilityDistributionFactory) {
        super(distSkeleton);
        this.dynamics = dynamics;
        this.dynBehaviourQuery = InductiveDynamicBehaviourQuerying.create(dynamics);
        this.initialDistribution = initialDistribution;
        // Created last: the handler reads initialDistribution and dynBehaviourQuery of this instance.
        this.probHandler = new TemporalProbabilityHandler(probabilityDistributionFactory);
        this.conditionals = new ArrayList<>();
    }

    /**
     * Returns the joint probability of the given trajectory.
     *
     * @param value the trajectory to evaluate
     * @return the product of the initial and transition probabilities of all slices
     * @see #unrollForProbability(Trajectory)
     */
    public Double probability(Trajectory<I> value) {
        return this.unrollForProbability(value);
    }

    /**
     * Initializes the initial network and every local distribution used by the dynamics. May be
     * called only once.
     *
     * @param seedProvider optional seed provider forwarded to every initialized distribution
     * @throws IllegalStateException if this network has already been initialized
     */
    public void init(Optional<ISeedProvider> seedProvider) {
        if (this.initialized) {
            throw new IllegalStateException("already initialized");
        }
        this.initialized = true;
        this.initialDistribution.init(seedProvider);
        for (InterTimeSliceInduction induction : this.dynBehaviourQuery.getInterTimeSliceInductions()) {
            this.probHandler.getCPD(induction.getAppliedGroundVariable()).init(seedProvider);
        }
        for (IntraTimeSliceInduction induction : this.dynBehaviourQuery.getIntraTimeSliceInductions()) {
            this.getCPDFromInitial(induction, this.conditionals).init(seedProvider);
        }
    }

    /**
     * Draws a single-slice trajectory. The drawn slice becomes the conditioning state of this
     * network, so repeated calls advance it one slice at a time.
     *
     * @return a trajectory holding only time slice 0
     * @throws IllegalStateException if {@link #init(Optional)} has not been called
     */
    public Trajectory<I> sample() {
        if (!this.initialized) {
            throw new IllegalStateException("not initialized");
        }
        return this.unrollForSampling(SINGLE_TIME_SLICE);
    }

    @Override
    public Double infer(List<Trajectory<I>> inputs) {
        throw new UnsupportedOperationException("The method is not implemented yet.");
    }

    @Override
    public void learn(List<Trajectory<I>> trainingData) {
        throw new UnsupportedOperationException("The method is not implemented yet.");
    }

    /**
     * Replaces the conditioning state with the given conditionals.
     *
     * @param conditionals conditionals to condition the next sample on; every element must be a
     *     {@link ConditionalInputValue}
     * @return this network, for chaining
     * @throws IllegalArgumentException if an element is not a {@link ConditionalInputValue}
     */
    public DynamicBayesianNetwork<I> given(List<Conditionable.Conditional<I>> conditionals) {
        this.checkValidity(conditionals);
        this.setConditionals(this.asConditionalInputValues(conditionals));
        return this;
    }

    /** @return the network governing time slice 0 */
    public BayesianNetwork<I> getBayesianNetwork() {
        return this.initialDistribution;
    }

    /** @return the dynamic behaviour this network was created with */
    public DynamicBehaviourExtension getDynamics() {
        return this.dynamics;
    }

    /** Replaces the contents of the conditioning state (the field itself is final). */
    private void setConditionals(List<ConditionalInputValue<I>> conditionals) {
        this.conditionals.clear();
        this.conditionals.addAll(conditionals);
    }

    /**
     * Computes the probability of a trajectory: the initial-network probability of slice 0
     * multiplied by, for every later slice, its transition probability given the preceding slice.
     *
     * @param traj the trajectory to evaluate
     * @return the joint probability; 1.0 if the trajectory contains no slice
     */
    public Double unrollForProbability(Trajectory<I> traj) {
        double probability = 1.0;
        for (int timeSlice = 0; traj.inTimeRange(timeSlice); timeSlice++) {
            List<InputValue<I>> current = traj.valueAtTime(timeSlice);
            if (timeSlice == 0) {
                probability *= this.calculateInitialProbability(current);
            } else {
                probability *= this.calculateProbability(current, traj.valueAtTime(timeSlice - 1));
            }
        }
        return probability;
    }

    /**
     * Samples a trajectory covering slices {@code 0..timeSlices} (inclusive). Every slice is drawn
     * conditioned on the current conditioning state, which is then replaced by the drawn slice; the
     * last drawn slice therefore remains the conditioning state afterwards.
     *
     * @param timeSlices index of the last slice to draw
     * @return the sampled trajectory
     */
    public Trajectory<I> unrollForSampling(int timeSlices) {
        Trajectory<I> samplePath = Trajectory.create(timeSlices);
        for (int i = 0; samplePath.inTimeRange(i); i++) {
            List<InputValue<I>> sample = this.sampleNext();
            this.setConditionals(this.conditionalInputValueUtil.toConditionalInputs(sample));
            samplePath.append(sample);
        }
        return samplePath;
    }

    /** Probability of slice 0 under the initial network. */
    private double calculateInitialProbability(List<InputValue<I>> inputs) {
        return this.initialDistribution.probability(inputs);
    }

    /**
     * Transition probability of {@code current} given {@code last}: the product over all
     * inter-time-slice inductions (conditioned on matching values of {@code last}) and all
     * intra-time-slice inductions (looked up in the initial network given {@code current}).
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private double calculateProbability(List<InputValue<I>> current, List<InputValue<I>> last) {
        // Loop-invariant conversions, computed once instead of once per induction.
        List<ConditionalInputValue<I>> lastAsConditionals = this.conditionalInputValueUtil.toConditionalInputs(last);
        List<ConditionalInputValue<I>> currentAsConditionals = this.conditionalInputValueUtil.toConditionalInputs(current);
        double probability = 1.0;
        for (InterTimeSliceInduction induction : this.dynBehaviourQuery.getInterTimeSliceInductions()) {
            GroundRandomVariable variable = induction.getAppliedGroundVariable();
            List<Conditionable.Conditional<I>> resolved = this.resolveConditionals(induction, lastAsConditionals);
            ConditionableProbabilityDistribution<I> givenCPD = (ConditionableProbabilityDistribution) this.probHandler.getCPD(variable).given(resolved);
            I value = this.conditionalInputValueUtil.getInputValue(variable, current).getValue();
            probability *= givenCPD.probability(value).doubleValue();
        }
        for (IntraTimeSliceInduction induction : this.dynBehaviourQuery.getIntraTimeSliceInductions()) {
            ConditionableProbabilityDistribution<I> localCPD = this.getCPDFromInitial(induction, currentAsConditionals);
            I value = this.conditionalInputValueUtil.getInputValue(induction.getAppliedGroundVariable(), current).getValue();
            probability *= localCPD.probability(value).doubleValue();
        }
        return probability;
    }

    /** Draws the next slice: from the initial network if there is no conditioning state yet. */
    private List<InputValue<I>> sampleNext() {
        if (this.conditionals.isEmpty()) {
            return this.initialDistribution.sample();
        }
        return this.sampleNextGiven(this.conditionals);
    }

    /**
     * Draws one slice given the conditioning state: one value per inter-time-slice induction
     * followed by one value per intra-time-slice induction.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private List<InputValue<I>> sampleNextGiven(List<ConditionalInputValue<I>> conditionals) {
        List<InputValue<I>> sample = new ArrayList<>();
        for (InterTimeSliceInduction induction : this.dynBehaviourQuery.getInterTimeSliceInductions()) {
            List<Conditionable.Conditional<I>> resolved = this.resolveConditionals(induction, conditionals);
            GroundRandomVariable variable = induction.getAppliedGroundVariable();
            ConditionableProbabilityDistribution<I> givenCPD = (ConditionableProbabilityDistribution) this.probHandler.getCPD(variable).given(resolved);
            sample.add(InputValue.create(givenCPD.sample(), variable));
        }
        for (IntraTimeSliceInduction induction : this.dynBehaviourQuery.getIntraTimeSliceInductions()) {
            // NOTE(review): sampling conditions intra-slice variables on the previous slice
            // (`conditionals`), whereas calculateProbability conditions them on the current slice.
            // Confirm which one is intended.
            ConditionableProbabilityDistribution<I> localCPD = this.getCPDFromInitial(induction, conditionals);
            GroundRandomVariable variable = induction.getAppliedGroundVariable();
            sample.add(InputValue.create(localCPD.sample(), variable));
        }
        return sample;
    }

    /**
     * Looks up the distribution of the induction's variable in the initial network, using the
     * values in {@code conditionals} as history.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private ConditionableProbabilityDistribution<I> getCPDFromInitial(IntraTimeSliceInduction induction, List<ConditionalInputValue<I>> conditionals) {
        List<InputValue<I>> history = this.conditionalInputValueUtil.toInputValues(conditionals);
        ProbabilityDistributionFunction<I> pdf = this.initialDistribution.getPDF(induction.getAppliedGroundVariable(), history);
        return (ConditionableProbabilityDistribution) pdf;
    }

    /**
     * Selects the conditionals relevant for an inter-time-slice induction. These are the ones whose
     * variable instantiates a source template of the induction's temporal structure and that share
     * at least one applied object with the induction's variable.
     */
    private List<Conditionable.Conditional<I>> resolveConditionals(InterTimeSliceInduction induction, List<ConditionalInputValue<I>> conditionals) {
        Set<TemplateVariable> interfaceVarParents = this.getInterfaceVariableParents(induction);
        GroundRandomVariable target = induction.getAppliedGroundVariable();
        List<ConditionalInputValue<I>> resolved = new ArrayList<>();
        for (ConditionalInputValue<I> each : conditionals) {
            TemplateVariable instantiated = each.getGroundVariable().getInstantiatedTemplate();
            if (TemplateDefinitionsQuerying.contains(instantiated, interfaceVarParents)
                    && this.shareSameContext(target, each.getGroundVariable())) {
                resolved.add(each);
            }
        }
        return this.conditionalInputValueUtil.asConditionals(resolved);
    }

    /** @return {@code true} if both variables have at least one applied object in common */
    private boolean shareSameContext(GroundRandomVariable parent, GroundRandomVariable child) {
        Set<Object> parentContext = new LinkedHashSet<>(parent.getAppliedObjects());
        Set<Object> childContext = new LinkedHashSet<>(child.getAppliedObjects());
        return !Sets.intersection(parentContext, childContext).isEmpty();
    }

    /** @return the source templates of the induction's temporal structure, in encounter order */
    private Set<TemplateVariable> getInterfaceVariableParents(InterTimeSliceInduction induction) {
        return induction.getTemporalStructure().stream()
                .map(InductiveDynamicBehaviourQuerying::getSource)
                .collect(Collectors.toCollection(LinkedHashSet::new));
    }

    /**
     * Ensures every conditional carries its ground variable.
     *
     * @throws IllegalArgumentException if an element is not a {@link ConditionalInputValue}
     */
    private void checkValidity(List<Conditionable.Conditional<I>> conditionals) {
        // allMatch on an empty stream is true, so an empty list is accepted.
        if (!conditionals.stream().allMatch(ConditionalInputValue.class::isInstance)) {
            throw new IllegalArgumentException("The conditionals cannot applied in the dynamic bayesian network context.");
        }
    }

    /** Downcasts each conditional; {@link #given(List)} validates the elements with checkValidity first. */
    private List<ConditionalInputValue<I>> asConditionalInputValues(List<Conditionable.Conditional<I>> conditionals) {
        return conditionals.stream()
                .map(each -> (ConditionalInputValue<I>) each)
                .collect(Collectors.toList());
    }

    /**
     * A {@link Conditionable.Conditional} that also records the ground random variable it belongs to.
     */
    public static class ConditionalInputValue<I extends Value<?>>
    extends Conditionable.Conditional<I> {
        private final GroundRandomVariable variable;

        private ConditionalInputValue(Domain valueSpace, I value, GroundRandomVariable variable) {
            super(valueSpace, value);
            this.variable = variable;
        }

        /**
         * Wraps the value space and value of {@code conditional} together with {@code variable}.
         *
         * @param conditional the conditional to copy value space and value from
         * @param variable the ground random variable the value belongs to
         * @return the new conditional input value
         */
        public static <I extends Value<?>> ConditionalInputValue<I> create(Conditionable.Conditional<I> conditional, GroundRandomVariable variable) {
            return new ConditionalInputValue<>(conditional.getValueSpace(), conditional.getValue(), variable);
        }

        /** @return the ground random variable this value belongs to */
        public GroundRandomVariable getGroundVariable() {
            return this.variable;
        }
    }

    /**
     * Builds and caches a conditional distribution for each ground variable of the initial
     * network's local networks that has an extending induction in the dynamics.
     */
    private class TemporalProbabilityHandler
    extends ProbabilityDistributionHandler<I> {
        private final IProbabilityDistributionFactory<I> probabilityDistributionFactory;

        public TemporalProbabilityHandler(IProbabilityDistributionFactory<I> probabilityDistributionFactory) {
            this.probabilityDistributionFactory = probabilityDistributionFactory;
        }

        @Override
        protected void initialize() {
            List<LocalProbabilisticNetwork> localProbabilisticNetworks = DynamicBayesianNetwork.this.initialDistribution.getLocalProbabilisticNetworks();
            localProbabilisticNetworks.forEach(this::createAndCache);
        }

        /** Caches a CPD for every variable of {@code localNetwork} that an induction extends. */
        private void createAndCache(LocalProbabilisticNetwork localNetwork) {
            for (GroundRandomVariable each : localNetwork.getGroundRandomVariables()) {
                DynamicBehaviourExtension unused = null; // placeholder removed below
                DynamicBayesianNetwork.this.dynBehaviourQuery.findInductionExtending(each)
                        .ifPresent(induction -> this.createAndCacheCPD(each, induction.getDescriptiveModel().getDistributionFunction()));
            }
        }

        /** Builds a conditional distribution from {@code distribution} and caches it for {@code variable}. */
        @SuppressWarnings({"rawtypes", "unchecked"})
        private void createAndCacheCPD(GroundRandomVariable variable, ProbabilityDistribution distribution) {
            ProbabilityDistributionBuilder probabilityDistributionBuilder = this.probabilityDistributionFactory.getProbabilityDistributionBuilder();
            ProbabilityDistributionFunction pdf = probabilityDistributionBuilder.withStructure(distribution).asConditionalProbabilityDistribution().build();
            this.cache(variable, pdf);
        }

        /** @return the cached conditional distribution of {@code variable} */
        @SuppressWarnings({"rawtypes", "unchecked"})
        public ConditionableProbabilityDistribution<I> getCPD(GroundRandomVariable variable) {
            return (ConditionableProbabilityDistribution) this.getPDF(variable);
        }
    }

    /**
     * A sequence of time slices {@code 0..trajLength} (inclusive), each mapped to the input values
     * of that slice.
     */
    public static class Trajectory<I extends Value<?>> {
        private final int trajLength;
        private final Map<Integer, List<InputValue<I>>> samplePath;

        private Trajectory(int trajLength, Map<Integer, List<InputValue<I>>> samplePath) {
            this.trajLength = trajLength;
            this.samplePath = samplePath;
        }

        /**
         * Creates a trajectory from fully known slices.
         *
         * @param timeSlices index of the last time slice
         * @param samples the input values of each slice; must contain {@code timeSlices + 1} entries
         * @return the trajectory holding all given slices
         * @throws IllegalArgumentException if {@code samples.size() != timeSlices + 1}
         */
        public static <I extends Value<?>> Trajectory<I> create(int timeSlices, List<List<InputValue<I>>> samples) {
            if (timeSlices != samples.size() - 1) {
                throw new IllegalArgumentException("The number of time slices must match the input sequence size.");
            }
            Map<Integer, List<InputValue<I>>> samplePath = new LinkedHashMap<>();
            // Slice indices are inclusive of timeSlices (see inTimeRange), so the last sample is kept too.
            for (int i = 0; i <= timeSlices; i++) {
                samplePath.put(i, samples.get(i));
            }
            return new Trajectory<>(timeSlices, samplePath);
        }

        /**
         * Creates a trajectory whose slices {@code 0..timeSlices} are all empty, ready to be filled
         * with {@link #append(List)}.
         *
         * @param timeSlices index of the last time slice
         * @return the empty trajectory
         */
        public static <I extends Value<?>> Trajectory<I> create(int timeSlices) {
            Map<Integer, List<InputValue<I>>> samplePath = new LinkedHashMap<>();
            for (int i = 0; i <= timeSlices; i++) {
                samplePath.put(i, new ArrayList<>());
            }
            return new Trajectory<>(timeSlices, samplePath);
        }

        /**
         * @param timeSlice index of the slice to read
         * @return the input values stored for {@code timeSlice}
         * @throws IllegalArgumentException if {@code timeSlice} is outside the stored slices
         */
        public List<InputValue<I>> valueAtTime(int timeSlice) {
            if (timeSlice < 0 || timeSlice >= this.samplePath.size()) {
                throw new IllegalArgumentException("The time slice is not in the range of the trajectory.");
            }
            return this.samplePath.get(timeSlice);
        }

        /** @return {@code true} if {@code timeSlice} lies in {@code 0..trajLength} */
        public boolean inTimeRange(int timeSlice) {
            return timeSlice >= 0 && timeSlice <= this.trajLength;
        }

        /**
         * Stores {@code samples} in the first slice that is still empty.
         *
         * @param samples the input values of the next slice
         * @throws EnvironmentalDynamicsException if every slice is already filled
         */
        public void append(List<InputValue<I>> samples) {
            if (this.maxTrajSizeNotReached()) {
                // NOTE(review): an empty `samples` list leaves its slot looking empty, so the next
                // append overwrites it.
                this.samplePath.put(this.calculateTimeSlice(), samples);
            }
        }

        /** @return the key of the first empty slice */
        private Integer calculateTimeSlice() {
            for (Map.Entry<Integer, List<InputValue<I>>> entry : this.samplePath.entrySet()) {
                if (entry.getValue().isEmpty()) {
                    return entry.getKey();
                }
            }
            throw new EnvironmentalDynamicsException("The max trajectory size is reached");
        }

        // NOTE(review): both factories allocate trajLength + 1 slots, so this is always true; the
        // effective "full" guard is the exception thrown by calculateTimeSlice.
        private boolean maxTrajSizeNotReached() {
            return this.samplePath.size() >= this.trajLength;
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            for (Map.Entry<Integer, List<InputValue<I>>> entry : this.samplePath.entrySet()) {
                builder.append(this.stringifyTimeSlice(entry.getKey(), entry.getValue()));
                builder.append("\n");
            }
            return builder.toString();
        }

        /** Renders one slice; empty slices render as {@code samples: []} instead of failing. */
        private String stringifyTimeSlice(int timeSlice, List<InputValue<I>> values) {
            String stringValues = values.stream()
                    .map(each -> String.format("(Variable: %s, Value: %s)", each.getVariable().getEntityName(), each.getValue().toString()))
                    .collect(Collectors.joining(","));
            return String.format("Time slice: %d, samples: [%s]", timeSlice, stringValues);
        }
    }
}

