package ru.yandex.metabase.client.impl;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import org.slf4j.Logger;

import ru.yandex.solomon.metabase.api.protobuf.CreateManyResponse;
import ru.yandex.solomon.metabase.api.protobuf.CreateOneResponse;
import ru.yandex.solomon.metabase.api.protobuf.DeleteManyResponse;
import ru.yandex.solomon.metabase.api.protobuf.EMetabaseStatusCode;
import ru.yandex.solomon.metabase.api.protobuf.FindRequest;
import ru.yandex.solomon.metabase.api.protobuf.FindResponse;
import ru.yandex.solomon.metabase.api.protobuf.Metric;
import ru.yandex.solomon.metabase.api.protobuf.MetricNamesRequest;
import ru.yandex.solomon.metabase.api.protobuf.MetricNamesResponse;
import ru.yandex.solomon.metabase.api.protobuf.ResolveManyResponse;
import ru.yandex.solomon.metabase.api.protobuf.ResolveOneResponse;
import ru.yandex.solomon.metabase.api.protobuf.TLabelNamesResponse;
import ru.yandex.solomon.metabase.api.protobuf.TLabelValuesRequest;
import ru.yandex.solomon.metabase.api.protobuf.TLabelValuesResponse;
import ru.yandex.solomon.metabase.api.protobuf.TResolveLogsResponse;
import ru.yandex.solomon.metabase.api.protobuf.TUniqueLabelsResponse;
import ru.yandex.solomon.util.collection.Nullables;
import ru.yandex.solomon.util.labelStats.LabelStatsConverter;

/**
 * @author Egor Litvinenko
 */
class MetabaseResponses {

    private static final Map<Class, BiFunction<EMetabaseStatusCode, String, ?>> STATUS_AND_MESSAGE_RESPONSE = Map.ofEntries(
            Map.entry(ResolveManyResponse.class,
                    (code, message) -> ResolveManyResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(ResolveOneResponse.class,
                    (code, message) -> ResolveOneResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(CreateManyResponse.class,
                    (code, message) -> CreateManyResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(CreateOneResponse.class,
                    (code, message) -> CreateOneResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(FindResponse.class,
                    (code, message) -> FindResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(DeleteManyResponse.class,
                    (code, message) -> DeleteManyResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(TResolveLogsResponse.class,
                    (code, message) -> TResolveLogsResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(MetricNamesResponse.class,
                    (code, message) -> MetricNamesResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(TLabelValuesResponse.class,
                    (code, message) -> TLabelValuesResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(TLabelNamesResponse.class,
                    (code, message) -> TLabelNamesResponse.newBuilder().setStatus(code).setStatusMessage(message).build()),
            Map.entry(TUniqueLabelsResponse.class,
                    (code, message) -> TUniqueLabelsResponse.newBuilder().setStatus(code).setStatusMessage(message).build())
    );

    private static final Map<Class, Function<EMetabaseStatusCode, ?>> STATUS_RESPONSE = Map.ofEntries(
            Map.entry(ResolveManyResponse.class,
                    code -> ResolveManyResponse.newBuilder().setStatus(code).build()),
            Map.entry(ResolveOneResponse.class,
                    code -> ResolveOneResponse.newBuilder().setStatus(code).build()),
            Map.entry(CreateManyResponse.class,
                    code -> CreateManyResponse.newBuilder().setStatus(code).build()),
            Map.entry(CreateOneResponse.class,
                    code -> CreateOneResponse.newBuilder().setStatus(code).build()),
            Map.entry(FindResponse.class,
                    code -> FindResponse.newBuilder().setStatus(code).build()),
            Map.entry(DeleteManyResponse.class,
                    code -> DeleteManyResponse.newBuilder().setStatus(code).build()),
            Map.entry(TResolveLogsResponse.class,
                    code -> TResolveLogsResponse.newBuilder().setStatus(code).build()),
            Map.entry(MetricNamesResponse.class,
                    code -> MetricNamesResponse.newBuilder().setStatus(code).build()),
            Map.entry(TLabelValuesResponse.class,
                    code -> TLabelValuesResponse.newBuilder().setStatus(code).build()),
            Map.entry(TLabelNamesResponse.class,
                    code -> TLabelNamesResponse.newBuilder().setStatus(code).build()),
            Map.entry(TUniqueLabelsResponse.class,
                    code -> TUniqueLabelsResponse.newBuilder().setStatus(code).build())
    );

    private static final Map<Class, Function<String, ?>> SHARD_NOT_FOUND_RESPONSE =
            STATUS_AND_MESSAGE_RESPONSE.entrySet().stream()
            .collect(Collectors.toUnmodifiableMap(
                    Map.Entry::getKey,
                    entry -> (message -> entry.getValue().apply(EMetabaseStatusCode.SHARD_NOT_FOUND, message))
            ));

    private static final Map<Class, Supplier<?>> SHARD_NOT_FOUND_DEFAULT_RESPONSE =
            STATUS_AND_MESSAGE_RESPONSE.entrySet().stream()
            .collect(Collectors.toUnmodifiableMap(
                    Map.Entry::getKey,
                    entry -> (() -> entry.getValue().apply(EMetabaseStatusCode.SHARD_NOT_FOUND, "Shard not found by " +
                            "common labels"))
            ));

    private static final Map<Class, Supplier<?>> EMPTY_RESPONSE =
            STATUS_RESPONSE.entrySet().stream()
                    .collect(Collectors.toUnmodifiableMap(
                            Map.Entry::getKey,
                            entry -> (() -> entry.getValue().apply(EMetabaseStatusCode.OK))
                    ));

    private static final Map<Class<?>, Predicate<?>> STATUS_IS_OK_TEST = Map.ofEntries(
            Map.entry(ResolveManyResponse.class,
                    value -> ((ResolveManyResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(ResolveOneResponse.class,
                    value -> ((ResolveOneResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(CreateManyResponse.class,
                    value -> ((CreateManyResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(CreateOneResponse.class,
                    value -> ((CreateOneResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(FindResponse.class,
                    value -> ((FindResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(DeleteManyResponse.class,
                    value -> ((DeleteManyResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(TResolveLogsResponse.class,
                    value -> ((TResolveLogsResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(MetricNamesResponse.class,
                    value -> ((MetricNamesResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(TLabelValuesResponse.class,
                    value -> ((TLabelValuesResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(TLabelNamesResponse.class,
                    value -> ((TLabelNamesResponse) value).getStatus() == EMetabaseStatusCode.OK),
            Map.entry(TUniqueLabelsResponse.class,
                    value -> ((TUniqueLabelsResponse) value).getStatus() == EMetabaseStatusCode.OK)

    );

    private static <T> BiFunction<EMetabaseStatusCode, String, T> responseWithStatusAndMessageFactory(Class<T> clazz) {
        return (BiFunction<EMetabaseStatusCode, String, T>) STATUS_AND_MESSAGE_RESPONSE.get(clazz);
    }

    static <T> T exceptionResponse(Class<T> clazz, Throwable throwable, Logger logger) {
        BiFunction<EMetabaseStatusCode, String, T> exceptionResponseFactory = responseWithStatusAndMessageFactory(clazz);
        if (throwable instanceof StatusRuntimeException clientException) {
            // exception thrown by client itself
            Status status = clientException.getStatus();
            return exceptionResponseFactory.apply(GrpcStatusMapping.toMetabaseStatusCode(status.getCode()), Nullables.orEmpty(status.getDescription()));
        } else if (throwable instanceof MetabaseResponseException clientException) {
            return exceptionResponseFactory.apply(clientException.getCode(), clientException.getMessage());
        } else {
            logger.error("Unhandled exception", throwable);
            return exceptionResponseFactory.apply(EMetabaseStatusCode.INTERNAL_ERROR, Nullables.orEmpty(throwable.getMessage()));
        }
    }

    static RuntimeException createException(String message, EMetabaseStatusCode code) {
        return new MetabaseResponseException(message, code);
    }

    private static class MetabaseResponseException extends RuntimeException {
        private final EMetabaseStatusCode code;
        public MetabaseResponseException(String message, EMetabaseStatusCode code) {
            super(message);
            this.code = code;
        }

        public EMetabaseStatusCode getCode() {
            return code;
        }
    }

    static <T> CompletableFuture<T> completedException(
            Class<T> clazz,
            EMetabaseStatusCode code,
            String message)
    {
        return CompletableFuture.completedFuture((T) STATUS_AND_MESSAGE_RESPONSE.get(clazz).apply(code, message));
    }

    static <T> Function<String, T> shardNotFoundForOne(Class<T> clazz) {
        return (Function<String, T>) SHARD_NOT_FOUND_RESPONSE.get(clazz);
    }

    static <T> Supplier<T> shardNotFoundForMany(Class<T> clazz) {
        return (Supplier<T>) SHARD_NOT_FOUND_DEFAULT_RESPONSE.get(clazz);
    }

    static <T> Supplier<T> emptyOkResponse(Class<T> clazz) {
        return (Supplier<T>) EMPTY_RESPONSE.get(clazz);
    }

    static <T> Predicate<T> statusIsOk(Class<T> clazz) {
        return (Predicate<T>) STATUS_IS_OK_TEST.get(clazz);
    }

    static ResolveManyResponse reduce(ResolveManyResponse left, ResolveManyResponse right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        if (left.getStatus() != EMetabaseStatusCode.OK) {
            return left;
        }
        if (right.getStatus() != EMetabaseStatusCode.OK) {
            return right;
        }
        // different nodes of metabase partitions can not return the same metrics
        return ResolveManyResponse.newBuilder()
                .setStatus(EMetabaseStatusCode.OK)
                .addAllMetrics(left.getMetricsList())
                .addAllMetrics(right.getMetricsList())
                .build();
    }

    static CreateOneResponse reduce(CreateOneResponse left, CreateOneResponse right) {
        if (left == null) {
            return right;
        }
        return left;
    }

    static ResolveOneResponse reduce(ResolveOneResponse left, ResolveOneResponse right) {
        if (left == null) {
            return right;
        }
        return left;
    }

    static CreateManyResponse reduce(CreateManyResponse left, CreateManyResponse right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        if (left.getStatus() != EMetabaseStatusCode.OK) {
            return left;
        }
        if (right.getStatus() != EMetabaseStatusCode.OK) {
            return right;
        }
        return CreateManyResponse.newBuilder()
                .setStatus(EMetabaseStatusCode.OK)
                .addAllMetrics(left.getMetricsList())
                .addAllMetrics(right.getMetricsList())
                .build();
    }

    static DeleteManyResponse reduce(DeleteManyResponse left, DeleteManyResponse right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        if (left.getStatus() != EMetabaseStatusCode.OK) {
            return left;
        }
        if (right.getStatus() != EMetabaseStatusCode.OK) {
            return right;
        }
        return DeleteManyResponse.newBuilder()
                .setStatus(EMetabaseStatusCode.OK)
                .addAllMetrics(left.getMetricsList())
                .addAllMetrics(right.getMetricsList())
                .build();
    }

    static BinaryOperator<FindResponse> reduce(FindRequest request) {
        final int metricLimit = request.getSliceOptions().getLimit() > 0 ? request.getSliceOptions().getLimit() : Integer.MAX_VALUE;
        return (left, right) -> {
            if (left == null) {
                return right;
            }
            if (right == null) {
                return left;
            }
            if (left.getStatus() != EMetabaseStatusCode.OK) {
                return left;
            }
            if (right.getStatus() != EMetabaseStatusCode.OK) {
                return right;
            }
            List<Metric> metrics = new ArrayList<>(Math.min(metricLimit,
                    left.getMetricsList().size() + right.getMetricsList().size()));
            if (left.getMetricsCount() <= metricLimit) {
                metrics.addAll(left.getMetricsList());
            } else {
                metrics.addAll(left.getMetricsList().subList(0, metricLimit));
            }
            if (metrics.size() + right.getMetricsCount() <= metricLimit) {
                metrics.addAll(right.getMetricsList());
            } else if (metrics.size() < metricLimit) {
                metrics.addAll(right.getMetricsList().subList(0, metricLimit - metrics.size()));
            }
            int totalCount = left.getTotalCount() + right.getTotalCount();
            return FindResponse.newBuilder()
                    .setStatus(EMetabaseStatusCode.OK)
                    .addAllMetrics(metrics)
                    .setTotalCount(totalCount)
                    .build();
        };
    }

    static BinaryOperator<MetricNamesResponse> reduce(MetricNamesRequest request) {
        final int metricLimit = request.getLimit() > 0 ? request.getLimit() : Integer.MAX_VALUE;
        return (left, right) -> {
            if (left == null) {
                return right;
            }
            if (right == null) {
                return left;
            }
            if (left.getStatus() != EMetabaseStatusCode.OK) {
                return left;
            }
            if (right.getStatus() != EMetabaseStatusCode.OK) {
                return right;
            }
            if (left.getNamesCount() < metricLimit) {
                var lStat = LabelStatsConverter.fromProto(left);
                var rStat = LabelStatsConverter.fromProto(right);
                lStat.combine(rStat);
                lStat.limit(metricLimit);
                return LabelStatsConverter.toProto(lStat)
                        .toBuilder()
                        .setStatus(EMetabaseStatusCode.OK).build();
            } else {
                return left;
            }
        };
    }

    static BinaryOperator<TLabelValuesResponse> reduce(TLabelValuesRequest request) {
        final int metricLimit = request.getLimit() > 0 ? request.getLimit() : Integer.MAX_VALUE;
        return (left, right) -> {
            if (left == null) {
                return right;
            }
            if (right == null) {
                return left;
            }
            if (left.getStatus() != EMetabaseStatusCode.OK) {
                return left;
            }
            if (right.getStatus() != EMetabaseStatusCode.OK) {
                return right;
            }
            var lStat = LabelStatsConverter.fromProto(left);
            var rStat = LabelStatsConverter.fromProto(right);
            lStat.combine(rStat);
            lStat.limit(metricLimit);
            return LabelStatsConverter.toProto(lStat)
                    .toBuilder()
                    .setStatus(EMetabaseStatusCode.OK).build();
        };
    }

    static TLabelNamesResponse reduce(TLabelNamesResponse left, TLabelNamesResponse right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        if (left.getStatus() != EMetabaseStatusCode.OK) {
            return left;
        }
        if (right.getStatus() != EMetabaseStatusCode.OK) {
            return right;
        }
        var names = Stream.concat(left.getNamesList().stream(), right.getNamesList().stream())
                .distinct()
                .collect(Collectors.toList());
        return TLabelNamesResponse.newBuilder()
                .setStatus(EMetabaseStatusCode.OK)
                .addAllNames(names)
                .build();
    }

    static TUniqueLabelsResponse reduce(TUniqueLabelsResponse left, TUniqueLabelsResponse right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        if (left.getStatus() != EMetabaseStatusCode.OK) {
            return left;
        }
        if (right.getStatus() != EMetabaseStatusCode.OK) {
            return right;
        }
        var names = Stream.concat(left.getLabelListsList().stream(), right.getLabelListsList().stream())
                .distinct()
                .collect(Collectors.toList());
        return TUniqueLabelsResponse.newBuilder()
                .setStatus(EMetabaseStatusCode.OK)
                .addAllLabelLists(names)
                .build();
    }

    static TResolveLogsResponse reduce(TResolveLogsResponse left, TResolveLogsResponse right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        if (left.getStatus() != EMetabaseStatusCode.OK) {
            return left;
        }
        if (right.getStatus() != EMetabaseStatusCode.OK) {
            return right;
        }
        var names = Stream.concat(left.getResolvedLogMetricsList().stream(), right.getResolvedLogMetricsList().stream())
                .distinct()
                .collect(Collectors.toList());
        return TResolveLogsResponse.newBuilder()
                .setStatus(EMetabaseStatusCode.OK)
                .addAllResolvedLogMetrics(names)
                .build();
    }


}
