package ru.yandex.travel.commons.logging;

import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import javax.annotation.Nullable;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Metrics;
import io.opentracing.Scope;
import io.opentracing.Span;
import io.opentracing.Tracer;
import io.opentracing.propagation.Format;
import io.opentracing.tag.Tags;
import lombok.extern.slf4j.Slf4j;
import org.asynchttpclient.AsyncHttpClient;
import org.asynchttpclient.AsyncHttpClientConfig;
import org.asynchttpclient.ListenableFuture;
import org.asynchttpclient.Param;
import org.asynchttpclient.Request;
import org.asynchttpclient.RequestBuilder;
import org.asynchttpclient.Response;
import org.slf4j.Logger;
import org.slf4j.MDC;

import ru.yandex.travel.commons.logging.http.HttpStatus;
import ru.yandex.travel.commons.logging.http.Meters;
import ru.yandex.travel.commons.logging.masking.LogAwareBodyGenerator;
import ru.yandex.travel.commons.logging.masking.LogAwareRequestBuilder;
import ru.yandex.travel.commons.logging.masking.LogMaskingConverter;

/**
 * Wrapper around {@link AsyncHttpClient} to add logging and metrics reporting.
 */
@Slf4j
public class AsyncHttpClientWrapper implements IAsyncHttpClientWrapper {
    public static final String YA_REQUEST_ID = "x-ya-call-id";
    public static final String MDC_OFFER_TOKEN = "OfferToken";
    private final String localHostName;
    protected final AsyncHttpClient client;
    private final Logger logger;
    private final ObjectMapper logObjectMapper;
    private final String destinationName;
    private final Tracer tracer;
    //metrics
    private final ConcurrentMap<String, Meters> metersForMethod;
    private final AtomicInteger inFlightRequests = new AtomicInteger();

    public AsyncHttpClientWrapper(AsyncHttpClient client, Logger logger, String destinationName, Tracer tracer) {
        this(client, logger, destinationName, tracer, null);
    }

    public AsyncHttpClientWrapper(AsyncHttpClient client, Logger logger, String destinationName,
                                  Tracer tracer, Set<String> wellKnownMethods) {
        this.destinationName = destinationName;
        this.localHostName = getLocalHostName();
        this.client = client;
        this.logger = logger;
        this.tracer = tracer;
        this.logObjectMapper = LogMaskingConverter.getObjectMapperForLogEvents();

        // metrics
        metersForMethod = new ConcurrentHashMap<>();
        if (wellKnownMethods != null) {
            for (String method : wellKnownMethods) {
                metersForMethod.put(method, new Meters(destinationName, method));
            }
        }
        metersForMethod.putIfAbsent(DEFAULT_METHOD, new Meters(destinationName, DEFAULT_METHOD));
        Gauge.builder("http.client.inFlightRequests", () -> this.inFlightRequests)
                .register(Metrics.globalRegistry);
    }

    private static Counter createCounterForErrorWithException(String exceptionClassName, String method,
                                                              String destination) {
        return Counter.builder("http.client.errors.exceptions")
                .tag("destination", destination)
                .tag("method", method)
                .tag("exception", exceptionClassName)
                .register(Metrics.globalRegistry);
    }

    @Deprecated
    public AsyncHttpClient getClient() {
        return this.client;
    }

    @Deprecated
    public Logger getLogger() {
        return this.logger;
    }

    @Override
    public CompletableFuture<Response> executeRequest(RequestBuilder requestBuilder, String method, String requestId) {
        Preconditions.checkArgument(!Strings.isNullOrEmpty(method),
                "Provided method must be non-null and non-empty string");
        Meters meters = metersForMethod.computeIfAbsent(method, k -> new Meters(destinationName, method));
        Map<String, String> mdc = MDC.getCopyOfContextMap();

        Span httpSpan = tracer.buildSpan(this.destinationName + " " + method).start();
        inFlightRequests.incrementAndGet();
        try (Scope scope = tracer.scopeManager().activate(httpSpan)) {
            Request request = requestBuilder.build();
            Tags.HTTP_URL.set(httpSpan, request.getUrl());
            Tags.HTTP_METHOD.set(httpSpan, request.getMethod());

            String yaCallId = setRequestId(requestBuilder, requestId, request);

            logRequest(method, mdc, request, yaCallId);

            long requestSentTime = System.currentTimeMillis();

            tracer.inject(tracer.activeSpan().context(), Format.Builtin.HTTP_HEADERS,
                    new RequestBuilderInjectAdapter(requestBuilder));
            ListenableFuture<Response> future = client.executeRequest(requestBuilder);

            return future.toCompletableFuture().whenComplete((r, t) -> {
                try (Scope asyncScope = tracer.scopeManager().activate(httpSpan)) {
                    long responseTime = (System.currentTimeMillis() - requestSentTime);
                    logResponse(method, meters, mdc, httpSpan, yaCallId, responseTime, r, t);
                } finally {
                    inFlightRequests.decrementAndGet();
                    httpSpan.finish();
                }
            });
        }
    }

    @Override
    public <T> CompletableFuture<T> executeRequest(LogAwareRequestBuilder requestBuilder,
                                                   String method,
                                                   @Nullable String requestId,
                                                   ResponseParser<T> responseParser) {
        Preconditions.checkArgument(!Strings.isNullOrEmpty(method),
                "Provided method must be non-null and non-empty string");
        Meters meters = metersForMethod.computeIfAbsent(method, k -> new Meters(destinationName, method));
        Map<String, String> mdc = MDC.getCopyOfContextMap();

        Span httpSpan = tracer.buildSpan(this.destinationName + " " + method).start();
        inFlightRequests.incrementAndGet();
        try (Scope scope = tracer.scopeManager().activate(httpSpan)) {
            Request request = requestBuilder.build();
            Tags.HTTP_URL.set(httpSpan, request.getUrl());
            Tags.HTTP_METHOD.set(httpSpan, request.getMethod());

            String yaCallId = setRequestId(requestBuilder, requestId, request);

            logRequest(method, mdc, request, yaCallId);

            long requestSentTime = System.currentTimeMillis();

            tracer.inject(tracer.activeSpan().context(), Format.Builtin.HTTP_HEADERS,
                    new RequestBuilderInjectAdapter(requestBuilder));
            ListenableFuture<Response> future = client.executeRequest(requestBuilder);

            return future.toCompletableFuture().handle((r, t) -> {
                try (Scope asyncScope = tracer.scopeManager().activate(httpSpan)) {
                    long responseTime = (System.currentTimeMillis() - requestSentTime);
                    logResponse(method, meters, mdc, httpSpan, yaCallId, responseTime, r, t);
                    Optional<T> responseObj;
                    if (r != null) {
                        responseObj = Optional.ofNullable(responseParser.parse(r));
                    } else {
                        responseObj = Optional.empty();
                    }

                    rethrowExceptionIfPresent(t);
                    return responseObj.orElse(null);
                } finally {
                    inFlightRequests.decrementAndGet();
                    httpSpan.finish();
                }
            });
        }
    }

    private void rethrowExceptionIfPresent(Throwable t) {
        if (t != null) {
            if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            }
            if (t instanceof Error) {
                throw (Error) t;
            }
            throw new CompletionException(t);
        }
    }

    private void logResponse(String method, Meters meters, Map<String, String> mdc, Span httpSpan, String yaCallId,
                             long responseTime, Response r, Throwable t) {
        LogEventResponse logEventResponse = new LogEventResponse();
        logEventResponse.setTimestamp(System.currentTimeMillis());
        logEventResponse.setDateTime(LocalDateTime.now());
        logEventResponse.setDestinationName(this.destinationName);
        logEventResponse.setDestinationMethod(method);
        logEventResponse.setCallId(yaCallId);
        logEventResponse.setFqdn(localHostName);
        logEventResponse.setResponseTime(responseTime);
        logEventResponse.setMdc(mdc);

        HttpStatus status = HttpStatus.ERROR;
        if (t != null) {
            Tags.ERROR.set(httpSpan, true);
            if (t.getClass().equals(TimeoutException.class)) {
                status = HttpStatus.TIMEOUT;
            }
            logEventResponse.setEventKind(LogEventType.EXCEPTION.getValue());
            logEventResponse.setException(t);
            String exceptionClassName = t.getClass().getCanonicalName();
            logEventResponse.setExceptionClass(exceptionClassName);
            meters.getErrorWithExceptionCounters().computeIfAbsent(exceptionClassName,
                    k -> createCounterForErrorWithException(exceptionClassName, method, destinationName)).increment();
            meters.getErrorCounter().increment();
        }
        if (r != null) {
            int statusCode = r.getStatusCode();
            Tags.HTTP_STATUS.set(httpSpan, statusCode);
            if (100 <= statusCode && statusCode < 200) {
                status = HttpStatus.CODE_1XX;
            } else if (200 <= statusCode && statusCode < 300) {
                status = HttpStatus.CODE_2XX;
            } else if (300 <= statusCode && statusCode < 400) {
                status = HttpStatus.CODE_3XX;
            } else if (400 <= statusCode && statusCode < 500) {
                status = HttpStatus.CODE_4XX;
            } else if (500 <= statusCode) {
                status = HttpStatus.CODE_5XX;
            } else {
                status = HttpStatus.ERROR;
            }
            logEventResponse.setResponseCode(r.getStatusCode());
            if (r.getResponseBodyAsBytes() != null) {
                logEventResponse.setResponseSize(r.getResponseBodyAsBytes().length);
                meters.getResponseByteCounters().get(status).increment(r.getResponseBodyAsBytes().length);
            }
            Map<String, String> respHeaders = new HashMap<>();
            r.getHeaders().iteratorAsString().forEachRemaining(h -> respHeaders.put(h.getKey(),
                    h.getValue()));
            logEventResponse.setResponseHeaders(respHeaders);
            Optional.ofNullable(r.getResponseBody())
                    .map(this::toJsonOrStringForLogs)
                    .ifPresent(logEventResponse::setResponseBody);
        }

        try (var ignored = NestedMdc.empty()) {
            if (mdc != null) {
                MDC.setContextMap(mdc);
            }
            logger.info(LoggingMarkers.HTTP_REQUEST_RESPONSE_MARKER,
                    logObjectMapper.writeValueAsString(logEventResponse));
        } catch (JsonProcessingException e) {
            log.error("Unable to log response", e);
        }

        meters.getCallCounters().get(status).increment();
        if (status == HttpStatus.CODE_2XX) {
            // Other codes are also interesting, but solomon sensors quota is very limited
            // So we do not track timings of non-2xx requests
            meters.getCallTimer2xx().record(responseTime, TimeUnit.MILLISECONDS);
        }
    }

    private void logRequest(String method, Map<String, String> mdc, Request request, String yaCallId) {
        LogEventRequest logEventRequest = new LogEventRequest();
        logEventRequest.setCallId(yaCallId);
        logEventRequest.setTimestamp(System.currentTimeMillis());
        logEventRequest.setDateTime(LocalDateTime.now());
        logEventRequest.setDestinationName(this.destinationName);
        logEventRequest.setDestinationMethod(method);
        logEventRequest.setFqdn(localHostName);
        Map<String, String> headers = new HashMap<>();
        request.getHeaders().iteratorAsString().forEachRemaining(h -> headers.put(h.getKey(), h.getValue()));
        logEventRequest.setRequestHeaders(headers);
        if (request.getStringData() != null) {
            logEventRequest.setRequestBody(toJsonOrStringForLogs(request.getStringData()));
        } else if (request.getFormParams() != null && !request.getFormParams().isEmpty()) {
            logEventRequest.setRequestBody(request.getFormParams().stream().map(param -> {
                if (param instanceof LogAwareRequestBuilder.FormParam) {
                    return new Param(param.getName(), ((LogAwareRequestBuilder.FormParam) param).getValueForLogs());
                }
                return param;
            }).collect(Collectors.toList()));
        } else if (request.getBodyGenerator() != null && request.getBodyGenerator() instanceof LogAwareBodyGenerator) {
            logEventRequest.setRequestBody(((LogAwareBodyGenerator) request.getBodyGenerator()).getContentForLogs());
        }
        logEventRequest.setRequestMethod(request.getMethod());
        if (request.getByteData() != null) {
            logEventRequest.setRequestSize(request.getByteData().length);
        }
        logEventRequest.setRequestUrl(request.getUrl());
        logEventRequest.setMdc(mdc);

        try {
            logger.info(LoggingMarkers.HTTP_REQUEST_RESPONSE_MARKER,
                    logObjectMapper.writeValueAsString(logEventRequest));
        } catch (JsonProcessingException e) {
            log.error("Unable to log request", e);
        }
    }

    private String setRequestId(RequestBuilder requestBuilder, String requestId, Request request) {
        String yaCallId;
        if (requestId != null) {
            yaCallId = requestId;
            requestBuilder.addHeader(YA_REQUEST_ID, requestId);
        } else {
            if (request.getHeaders().contains(YA_REQUEST_ID)) {
                yaCallId = request.getHeaders().get(YA_REQUEST_ID);
            } else {
                yaCallId = UUID.randomUUID().toString();
                requestBuilder.addHeader(YA_REQUEST_ID, yaCallId);
            }
        }
        return yaCallId;
    }

    public AsyncHttpClientConfig getConfig() {
        return client.getConfig();
    }

    private String getLocalHostName() {
        try {
            return InetAddress.getLocalHost().getCanonicalHostName();
        } catch (UnknownHostException e) {
            throw new RuntimeException("Unknown host", e);
        }
    }

    private Object toJsonOrStringForLogs(String data) {
        try {
            return logObjectMapper.readTree(data);
        } catch (IOException e) {
            return data;
        }
    }
}
