package ru.yandex.jns.client;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

import javax.annotation.ParametersAreNonnullByDefault;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Throwables;
import io.netty.handler.codec.http.HttpStatusClass;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ResponseStatus;

import ru.yandex.jns.config.JnsClientOptions;
import ru.yandex.jns.dto.GetEscalationRequest;
import ru.yandex.jns.dto.JnsEscalationPolicy;
import ru.yandex.jns.dto.JnsListEscalationPolicy;
import ru.yandex.jns.dto.JnsSendMessageRequest;
import ru.yandex.jns.dto.JnsSendRecipient;
import ru.yandex.jns.dto.ListEscalationRequest;
import ru.yandex.jns.dto.StartEscalationRequest;
import ru.yandex.jns.dto.StopEscalationRequest;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.primitives.Rate;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.selfmon.counters.AsyncMetrics;
import ru.yandex.solomon.util.future.RetryConfig;

import static ru.yandex.misc.concurrent.CompletableFutures.safeCall;
import static ru.yandex.solomon.util.future.RetryCompletableFuture.runWithRetries;

/**
 * @author Alexey Trushkin
 */
@ParametersAreNonnullByDefault
public class HttpJnsClient implements JnsClient {

    private static final Logger logger = LoggerFactory.getLogger(HttpJnsClient.class);
    private static final Duration DEFAULT_REQUEST_TIMEOUT_MILLIS = Duration.ofSeconds(30);
    private static final RetryConfig RETRY_CONFIG = RetryConfig.DEFAULT
            .withExceptionFilter(HttpJnsClient::needToRetry)
            .withNumRetries(2)
            .withDelay(1_000)
            .withMaxDelay(60_000);

    private final HttpClient httpClient;
    private final JnsClientOptions opts;

    private final ConcurrentMap<String, InnerMetrics> metricsMap = new ConcurrentHashMap<>();
    private final ObjectMapper mapper = new ObjectMapper();

    public HttpJnsClient(JnsClientOptions opts) {
        this.httpClient = HttpClient.newBuilder()
                .version(HttpClient.Version.HTTP_1_1)
                .followRedirects(HttpClient.Redirect.NEVER)
                .connectTimeout(opts.getConnectionTimeout())
                .executor(opts.getExecutor())
                .build();
        this.opts = opts;
    }

    @Override
    public CompletableFuture<Void> sendMessage(JnsSendMessageRequest request) {
        return executeWithRetries(Void.class, request, "/api/messages/send");
    }

    @Override
    public CompletableFuture<JnsListEscalationPolicy> listEscalations(ListEscalationRequest request) {
        final String endpoint = "/api/escalations/list";
        return executeGetWithRetries(JnsListEscalationPolicy.class, "?project=" + request.project() + "&name_filter=" + request.nameFilter(), endpoint)
                .thenApply(jnsListEscalationPolicy -> {
                    if (!StringUtils.isEmpty(jnsListEscalationPolicy.error())) {
                        opts.getMetricRegistry().rate("jns.control_plane.request.error", Labels.of("endpoint", endpoint)).inc();
                    }
                    return jnsListEscalationPolicy;
                });
    }

    @Override
    public CompletableFuture<Optional<JnsEscalationPolicy>> getEscalation(GetEscalationRequest request) {
        return executeGetWithRetries(JnsEscalationPolicy.class, "?project=" + request.project() + "&name=" + request.name(), "/api/escalations/get")
                .handle((jnsEscalationPolicy, throwable) -> {
                   if (throwable != null && throwable.getCause() instanceof ClientError ce && ce.code == 404) {
                       return Optional.empty();
                   } else if (throwable != null) {
                       Throwables.propagate(throwable);
                   }
                   return Optional.ofNullable(jnsEscalationPolicy);
                });
    }

    @Override
    public CompletableFuture<Void> startEscalation(StartEscalationRequest request) {
        return executeWithRetries(Void.class, request, "/api/escalations/start");
    }

    @Override
    public CompletableFuture<Void> stopEscalation(StopEscalationRequest request) {
        return executeWithRetries(Void.class, request, "/api/escalations/stop");
    }

    private <T> CompletableFuture<T> request(Class<T> clz, Object requestModel, CompletableFuture<T> resultFuture, String endpoint) throws JsonProcessingException {
        UUID uuid = UUID.randomUUID();
        String uri = opts.getUrl() + endpoint;
        var json = mapper.writeValueAsString(requestModel);
        var request = HttpRequest.newBuilder(URI.create(uri))
                .POST(HttpRequest.BodyPublishers.ofString(json))
                .header(opts.getTokenHeaderProvider().get(), opts.getTokenProvider().get())
                .header("Content-Type", "application/json")
                .timeout(DEFAULT_REQUEST_TIMEOUT_MILLIS)
                .build();
        logger.info("Jns request:\n{}\n{}\n{}", uuid, request, json);
        return execute(clz, resultFuture, endpoint, uuid, request);
    }

    private <T> CompletableFuture<T> requestGet(Class<T> clz, String params, CompletableFuture<T> resultFuture, String endpoint) throws JsonProcessingException {
        UUID uuid = UUID.randomUUID();
        String uri = opts.getUrl() + endpoint + params;
        var request = HttpRequest.newBuilder(URI.create(uri))
                .GET()
                .header(opts.getTokenHeaderProvider().get(), opts.getTokenProvider().get())
                .header("Content-Type", "application/json")
                .timeout(DEFAULT_REQUEST_TIMEOUT_MILLIS)
                .build();
        logger.info("Jns request:\n{}\n{}\n{}", uuid, request, params);
        return execute(clz, resultFuture, endpoint, uuid, request);
    }

    private <T> CompletableFuture<T> execute(Class<T> clz, CompletableFuture<T> resultFuture, String endpoint, UUID uuid, HttpRequest request) {
        var future = wrapMetrics(endpoint, metrics -> httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
                .thenApply(response -> {
                    handleException(metrics, response, uuid);
                    if (clz == Void.class) {
                        return null;
                    }
                    try {
                        return mapper.readValue(response.body(), clz);
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                }));
        CompletableFutures.whenComplete(future, resultFuture);
        return future;
    }

    private <R> CompletableFuture<R> wrapMetrics(String endpoint, Function<InnerMetrics, CompletableFuture<R>> func) {
        InnerMetrics metrics = metricsMap.computeIfAbsent(endpoint, s -> new InnerMetrics(endpoint));
        var future = func.apply(metrics);
        metrics.asyncMetrics.forFuture(future);
        return future;
    }

    private void handleException(InnerMetrics metrics, HttpResponse<String> response, UUID headerUuid) {
        metrics.status(response.statusCode());
        if (!HttpStatusClass.SUCCESS.contains(response.statusCode())) {
            logger.error("Jns response (status {}) ({}):\n{}, body {}",
                    response.statusCode(), headerUuid, response, response.body());
            if (response.statusCode() == HttpStatus.CONFLICT.value()) {
                // already sent
                return;
            }
            if (HttpStatusClass.CLIENT_ERROR.contains(response.statusCode()) && response.statusCode() != HttpStatus.TOO_MANY_REQUESTS.value()) {
                // wouldnt retry
                throw new ClientError(response.body(), response.statusCode());
            }
            //retry errors
            throw new Error("Jns response status " + response.statusCode(), response);
        }
    }

    private <T> CompletableFuture<T> executeWithRetries(Class<T> clzResult, Object requestModel, String endpoint) {
        CompletableFuture<T> resultFuture = new CompletableFuture<>();
        try {
            runWithRetries(() -> safeCall(() -> request(clzResult, requestModel, resultFuture, endpoint)), RETRY_CONFIG);
        } catch (Exception t) {
            resultFuture.completeExceptionally(t);
        }
        return resultFuture;
    }

    private <T> CompletableFuture<T> executeGetWithRetries(Class<T> clzResult, String params, String endpoint) {
        CompletableFuture<T> resultFuture = new CompletableFuture<>();
        try {
            runWithRetries(() -> safeCall(() -> requestGet(clzResult, params, resultFuture, endpoint)), RETRY_CONFIG);
        } catch (Exception t) {
            resultFuture.completeExceptionally(t);
        }
        return resultFuture;
    }

    private static boolean needToRetry(Throwable throwable) {
        return !(throwable.getCause() instanceof ClientError);
    }

    @Override
    public void close() {
    }

    private class InnerMetrics {
        private final ConcurrentMap<Integer, Rate> statusCodes = new ConcurrentHashMap<>();
        private final AsyncMetrics asyncMetrics;
        private final MetricRegistry registry;
        private final Labels commonLabels;

        private InnerMetrics(String endpoint) {
            registry = opts.getMetricRegistry();
            commonLabels = Labels.of("endpoint", endpoint);
            this.asyncMetrics = new AsyncMetrics(opts.getMetricRegistry(),
                    "jns.control_plane.request", commonLabels);
        }

        public void status(int statusCode) {
            Labels labels = commonLabels.toBuilder()
                    .add("code", Integer.toString(statusCode))
                    .build();
            statusCodes.computeIfAbsent(statusCode,
                    code -> registry.rate("jns.control_plane.request.status", labels)).inc();
        }
    }

    @ResponseStatus(HttpStatus.BAD_REQUEST)
    public static class ClientError extends RuntimeException {
        private final int code;

        public ClientError(String message, int code) {
            super(message);
            this.code = code;
        }
    }

    @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
    public static class Error extends RuntimeException {
        private final HttpResponse<String> response;

        public Error(String message, HttpResponse<String> response) {
            super(message);
            this.response = response;
        }

        public HttpResponse<String> getResponse() {
            return response;
        }
    }

    public static void main(String[] args) {
        var alertName = "alerting-cluster-membership";
        var subAlertLabels = List.of(Map.of("string_value", "alerting"), Map.of("string_value", "sas-00"));
        var opts = JnsClientOptions.newBuilder()
                .setUrl("http://jns-prestable.yandex-team.ru")
                .setExecutor(Executors.newSingleThreadExecutor())
                .setTokenHeaderProvider(() -> HttpHeaders.AUTHORIZATION)
                .setTokenProvider(() -> "OAuth ")
                .build();
        var client = new HttpJnsClient(opts);

        var user = JnsSendRecipient.duty("solomon", "alextrushkin-test");
        // var user = JnsSendRecipient.login("alextrushkin");

        var recepient = new JnsSendRecipient();
        recepient.phone = new JnsSendRecipient.PhoneChannelOptions();
        recepient.phone.internal = List.of(user);

        var message = new JnsSendMessageRequest();
        message.projectWithTemplate = "alextrushkin_test2";
        message.template = "phone1";
        message.projectAbcService = "solomon";
        message.targetJnsProject = "alextrushkin_test2";
        message.params = Map.of(
                "alertName", Map.of("string_value", alertName),
                //  "subAlertLabels", Map.of("list_value", subAlertLabels),
                "monitoringProject", Map.of("string_value", message.targetJnsProject)
        );
        message.recipient = recepient;
       // client.sendMessage(message).join();

        var list = client.listEscalations(new ListEscalationRequest("alextrushkin_test", "")).join();
        list.policies();

        var res = client.getEscalation(new GetEscalationRequest("alextrushkin_test", "12")).join();
        res.toString();

        StartEscalationRequest r = new StartEscalationRequest();
        r.idempotencyKey = "1";
        r.template = "default";
        r.escalation = "test1";
        r.params = Map.of(
                "title", Map.of("string_value", "this is title"),
                "message", Map.of("string_value", "this is message")
        );
        client.startEscalation(r).join();

        StopEscalationRequest r2 = new StopEscalationRequest(new StopEscalationRequest.Key("alextrushkin_test", "1"),
                new StopEscalationRequest.Terminator("Source", "terminated", ""));
        client.stopEscalation(r2).join();
    }
}
