package ru.yandex.solomon.gateway.data;

import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import com.google.common.base.Joiner;
import com.google.common.base.Throwables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Try;
import ru.yandex.monlib.metrics.primitives.Rate;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.expression.analytics.GraphDataLoadRequest;
import ru.yandex.solomon.expression.analytics.GraphDataLoader;
import ru.yandex.solomon.expression.analytics.PreparedProgram;
import ru.yandex.solomon.expression.analytics.Program;
import ru.yandex.solomon.expression.compile.DeprOpts;
import ru.yandex.solomon.expression.expr.ProgramType;
import ru.yandex.solomon.expression.value.SelValue;
import ru.yandex.solomon.expression.version.SelVersion;
import ru.yandex.solomon.flags.FeatureFlag;
import ru.yandex.solomon.flags.FeatureFlagsHolder;
import ru.yandex.solomon.selfmon.counters.EnumMetrics;

/**
 * @author Vladimir Gordiychuk
 */
/**
 * Compiles data-request expressions into {@link ExpressionProgram}s.
 *
 * <p>The expression language version is chosen by {@link #getExpressionLanguageVersions}. When more
 * than one version is selected (lowest and latest), every version is compiled and evaluated; the
 * result of the first version is returned to the caller, while the remaining versions are only used
 * to detect, log and count discrepancies between language versions.
 *
 * @author Vladimir Gordiychuk
 */
public class ExpressionCompilerImpl implements ExpressionCompiler {
    private static final Logger expressionLogger = LoggerFactory.getLogger("expression");

    private final FeatureFlagsHolder featureFlagsHolder;
    /** Per-version count of programs that failed to compile or evaluate. */
    private final EnumMap<SelVersion, Rate> failedForVersion;
    /** Per-version count of programs that evaluated successfully. */
    private final EnumMap<SelVersion, Rate> completedForVersion;
    /** Incremented when several versions were run and produced different outcomes. */
    private final Rate versionsInconsistent;
    /** Incremented when several versions were run and agreed (including all of them failing). */
    private final Rate versionsConsistent;

    public ExpressionCompilerImpl(FeatureFlagsHolder featureFlagsHolder, MetricRegistry registry) {
        this.featureFlagsHolder = featureFlagsHolder;
        this.failedForVersion = EnumMetrics.rates(SelVersion.class, registry, "dataClient.versionedRequests.failed", "version");
        this.completedForVersion = EnumMetrics.rates(SelVersion.class, registry, "dataClient.versionedRequests.completed", "version");
        this.versionsInconsistent = registry.rate("dataClient.versionedRequests.inconsistent");
        this.versionsConsistent = registry.rate("dataClient.versionedRequests.consistent");
    }

    /**
     * Compiles the request's program for every applicable language version.
     *
     * @param request data request holding the program source, interval and version hints
     * @return a single-version program, or a multi-version program comparing several versions
     * @throws RuntimeException if compilation fails (for every version, in the multi-version case)
     */
    @Override
    public ExpressionProgram compile(DataRequest request) {
        List<SelVersion> versions = getExpressionLanguageVersions(request.getVersion(), request.getProjectId());
        if (versions.size() == 1) {
            return prepare(versions.get(0), request);
        }

        return prepare(versions, request);
    }

    private Program compile(SelVersion version, String source, DeprOpts deprOpts, boolean useNewSelectors) {
        return Program.fromSourceWithReturn(version, source, deprOpts, useNewSelectors).compile();
    }

    /**
     * Compiles and prepares the request's program for a single version.
     * Any failure is counted in {@link #failedForVersion} and rethrown unchecked.
     */
    private VersionedProgram prepare(SelVersion version, DataRequest request) {
        try {
            var compiled = compile(version, request.getProgram(), request.getDeprOpts(), request.isUseNewFormat());
            var prepared = compiled.prepare(request.getInterval());
            return new VersionedProgram(prepared);
        } catch (Throwable e) {
            failedForVersion.get(version).inc();
            throw propagate(e);
        }
    }

    /**
     * Prepares the program for each version, keeping per-version failures as {@link Try} values.
     * Only when every version fails is the first version's failure rethrown.
     */
    private MultiVersionedProgram prepare(List<SelVersion> versions, DataRequest request) {
        LinkedHashMap<SelVersion, Try<VersionedProgram>> programsByVersion = versions.stream()
                .collect(Collectors.toMap(
                        Function.identity(),
                        version -> Try.tryCatchException(() -> prepare(version, request)),
                        (l, r) -> r,
                        () -> new LinkedHashMap<>(versions.size())));
        if (programsByVersion.values().stream().allMatch(Try::isFailure)) {
            // Every version rejected the program: the versions agree. Count it the same way
            // compareResults does, i.e. only when more than one version was actually involved.
            if (programsByVersion.size() > 1) {
                versionsConsistent.inc();
            }
            throw propagate(programsByVersion.values().iterator().next().getThrowable());
        }

        return new MultiVersionedProgram(request, programsByVersion);
    }

    /**
     * Chooses the expression language version(s) to run, in priority order:
     * an explicitly requested version; a version selected by a project feature flag;
     * otherwise both the lowest and the latest version (a single one if they coincide).
     */
    private List<SelVersion> getExpressionLanguageVersions(@Nullable SelVersion version, String projectId) {
        // 1. If set explicitly, use it
        if (version != null) {
            return List.of(version);
        }
        // 2. Look if set by flags
        if (featureFlagsHolder.hasFlag(FeatureFlag.EXPRESSION_VECTORED_TYPES, projectId)) {
            return List.of(SelVersion.GROUP_LINES_RETURN_VECTOR_2);
        }
        if (featureFlagsHolder.hasFlag(FeatureFlag.EXPRESSION_LAST_VERSION, projectId)) {
            return List.of(SelVersion.MAX);
        }
        // 3. Assume lowest supported version, also compare with latest
        return (SelVersion.MIN == SelVersion.MAX) ? List.of(SelVersion.MAX) : List.of(SelVersion.MIN, SelVersion.MAX);
    }

    /**
     * Rethrows {@code t} unchanged if it is unchecked, otherwise wraps it in a {@link RuntimeException}.
     * Never returns normally; the return type lets callers write {@code throw propagate(t);} so the
     * compiler knows the branch terminates.
     */
    private static RuntimeException propagate(Throwable t) {
        Throwables.throwIfUnchecked(t);
        throw new RuntimeException(t);
    }

    /**
     * A program prepared for a single language version. Each evaluation is counted as completed or
     * failed for that version.
     */
    private class VersionedProgram implements ExpressionProgram {
        private final PreparedProgram prepared;

        public VersionedProgram(PreparedProgram prepared) {
            this.prepared = prepared;
        }

        @Override
        public ProgramType type() {
            return prepared.getProgramType();
        }

        @Override
        public Collection<GraphDataLoadRequest> loadRequests() {
            return prepared.getLoadRequests();
        }

        /**
         * Evaluates the program and returns the value bound to the expression's result variable.
         *
         * @throws RuntimeException if evaluation fails or the result variable is missing
         */
        @Nonnull
        @Override
        public SelValue evaluate(GraphDataLoader loader) {
            try {
                Map<String, SelValue> evalResultByVar = prepared.evaluate(loader, Collections.emptyMap());
                String resultVar = prepared.expressionToVarName();
                SelValue evalResult = evalResultByVar.get(resultVar);

                if (evalResult == null) {
                    throw new RuntimeException("failed to return value for expression: " + prepared.getSource());
                }
                completedForVersion.get(prepared.getVersion()).inc();
                return evalResult;
            } catch (Throwable e) {
                // Also covers the missing-result case thrown above.
                failedForVersion.get(prepared.getVersion()).inc();
                throw propagate(e);
            }
        }
    }

    /**
     * Runs the same program under several language versions and compares the outcomes.
     * The caller always receives the outcome of the first version in iteration order.
     */
    private class MultiVersionedProgram implements ExpressionProgram {
        private final DataRequest request;
        private final LinkedHashMap<SelVersion, Try<VersionedProgram>> programByVersion;
        // Union of load requests of all successfully compiled versions, so every version gets its data.
        private final Set<GraphDataLoadRequest> loadRequests;
        // Taken from the first successfully compiled version.
        private final ProgramType type;

        public MultiVersionedProgram(DataRequest request, LinkedHashMap<SelVersion, Try<VersionedProgram>> programByVersion) {
            this.request = request;
            this.programByVersion = programByVersion;
            this.loadRequests = programByVersion.values()
                    .stream()
                    .filter(Try::isSuccess)
                    .flatMap(v -> v.get().loadRequests().stream())
                    .collect(Collectors.toSet());
            this.type = programByVersion.values()
                    .stream()
                    .filter(Try::isSuccess)
                    .map(v -> v.get().type())
                    .findFirst()
                    .orElseThrow(() -> new RuntimeException("Unknown program type"));
        }

        @Override
        public ProgramType type() {
            return type;
        }

        @Override
        public Collection<GraphDataLoadRequest> loadRequests() {
            return loadRequests;
        }

        /**
         * Evaluates every successfully compiled version (compile failures stay failures) and
         * returns the first version's result after comparing all of them.
         */
        @Nonnull
        @Override
        public SelValue evaluate(GraphDataLoader loader) {
            LinkedHashMap<SelVersion, Try<SelValue>> results = programByVersion.entrySet().stream()
                    .collect(Collectors.toMap(
                            Map.Entry::getKey,
                            entry -> entry.getValue().mapCatchException(program -> program.evaluate(loader)),
                            (l, r) -> r,
                            () -> new LinkedHashMap<>(programByVersion.size())));

            return compareResults(results);
        }

        /**
         * Logs failed versions, counts whether the versions agreed, and returns (or rethrows)
         * the outcome of the first version.
         */
        private SelValue compareResults(LinkedHashMap<SelVersion, Try<SelValue>> responses) {
            if (responses.values().stream().allMatch(Try::isFailure)) {
                if (responses.size() > 1) {
                    versionsConsistent.inc();
                }
                throw propagate(responses.values().iterator().next().getThrowable());
            }

            // Only partial fail, some versions succeeded - log all that have failed
            responses.forEach((version, result) -> {
                if (result.isFailure()) {
                    expressionLogger.error("Processing data request {} failed for version {}",
                            request, version, result.getThrowable());
                }
            });

            // Failures map to Optional.empty(), so "failed vs succeeded" also counts as a difference.
            Map<Optional<SelValue>, EnumSet<SelVersion>> versionsByBody = responses.entrySet().stream()
                    .collect(Collectors.groupingBy(
                            verAndTry -> Optional.ofNullable(verAndTry.getValue().getOrElse(null)),
                            Collector.of(
                                    () -> EnumSet.noneOf(SelVersion.class),
                                    (set, e) -> set.add(e.getKey()),
                                    (left, right) -> { left.addAll(right); return left; })
                    ));

            if (versionsByBody.size() != 1) {
                versionsInconsistent.inc();
                String message = "Different expression language versions gave different results for request\n" +
                        request + ". The results were: \n" +
                        Joiner.on('\n').withKeyValueSeparator(" for versions ").join(versionsByBody);
                expressionLogger.error(message);
            } else if (responses.size() > 1) {
                versionsConsistent.inc();
            }

            Try<SelValue> firstResult = responses.values().iterator().next();
            if (firstResult.isSuccess()) {
                return firstResult.get();
            }
            throw propagate(firstResult.getThrowable());
        }
    }
}
