package ru.yandex.travel.suburban.partners.aeroexpress;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeoutException;

import javax.validation.constraints.NotNull;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.namespace.QName;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.soap.MessageFactory;
import javax.xml.soap.MimeHeaders;
import javax.xml.soap.SOAPBody;
import javax.xml.soap.SOAPConstants;
import javax.xml.soap.SOAPEnvelope;
import javax.xml.soap.SOAPException;
import javax.xml.soap.SOAPMessage;
import javax.xml.soap.SOAPPart;

import com.google.common.base.Strings;
import lombok.Data;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.asynchttpclient.RequestBuilder;
import org.asynchttpclient.Response;
import org.asynchttpclient.util.HttpConstants;
import org.w3c.dom.Document;

import ru.yandex.travel.commons.logging.AsyncHttpClientWrapper;
import ru.yandex.travel.suburban.exceptions.SuburbanException;
import ru.yandex.travel.suburban.exceptions.SuburbanRetryableException;
import ru.yandex.travel.suburban.partners.aeroexpress.exceptions.AeroexpressParseResponseException;
import ru.yandex.travel.suburban.partners.aeroexpress.exceptions.AeroexpressUnknownException;
import ru.yandex.travel.suburban.partners.aeroexpress.model.GetOrderTickets;
import ru.yandex.travel.suburban.partners.aeroexpress.model.GetOrderTicketsResponse;
import ru.yandex.travel.suburban.partners.aeroexpress.model.GetVersionInfo;
import ru.yandex.travel.suburban.partners.aeroexpress.model.GetVersionInfoResponse;
import ru.yandex.travel.suburban.partners.aeroexpress.model.PayOrder;
import ru.yandex.travel.suburban.partners.aeroexpress.model.PayOrderResponse;
import ru.yandex.travel.suburban.partners.aeroexpress.model.RequestTickets3;
import ru.yandex.travel.suburban.partners.aeroexpress.model.RequestTickets3Response;
import ru.yandex.travel.suburban.partners.aeroexpress.model.RequestTicketsResponse;
import ru.yandex.travel.suburban.partners.aeroexpress.model.TicketHistory2;
import ru.yandex.travel.suburban.partners.aeroexpress.model.VersionInfoResponse;


/**
 * Blocking SOAP 1.1 client for the Aeroexpress ticketing API.
 *
 * <p>Each operation wraps a JAXB model object into a SOAP envelope and POSTs it to
 * {@code baseUrl + endpoint path} through {@link AsyncHttpClientWrapper}. It then waits for the
 * reply via {@link #sync(CompletionStage)} and unmarshals the SOAP body into the expected
 * response class.
 */
@Slf4j
public class DefaultAeroexpressClient implements AeroexpressClient {
    /** Target namespace of the Aeroexpress operations; qualifies every request root element. */
    private static final String XML_NAMESPACE = "http://service.ticket.lynx.ru";

    @NotNull
    private final Config conf;
    private final AsyncHttpClientWrapper asyncHttpClient;
    /** JAXB context over the generated model package; reused for every marshal/unmarshal. */
    private final JAXBContext modelJC;
    /** SAAJ factory producing SOAP 1.1 messages for both requests and responses. */
    private final MessageFactory messageFactory;

    /** Connection settings for the Aeroexpress API. */
    @Data
    public static class Config {
        /** URL prefix to which each endpoint path is appended. */
        @NotNull
        private String baseUrl;

        /** Read and request timeout applied to SOAP calls. */
        @NotNull
        private Duration timeout;

        /** Read and request timeout applied to ticket PDF downloads. */
        @NotNull
        private Duration ticketDownloadTimeout;

        @NotNull
        private SslConfig ssl;
    }

    /** Client-certificate settings; not referenced inside this class. */
    @Data
    public static class SslConfig {
        // NOTE(review): @NotNull is a no-op on primitive fields, which can never be null.
        @NotNull
        private boolean enabled;

        @NotNull
        private String clientCertificate;

        @NotNull
        private String privateKey;

        // NOTE(review): presumably selects the *File fields below instead of inline values — confirm.
        @NotNull
        private boolean getFromFiles;

        private String clientCertificateFile;

        private String privateKeyFile;
    }

    /**
     * Creates a client.
     *
     * @param asyncHttpClient HTTP client used for every outgoing request
     * @param conf            endpoint and timeout settings
     */
    @SneakyThrows
    public DefaultAeroexpressClient(AsyncHttpClientWrapper asyncHttpClient, Config conf) {
        this.conf = conf;
        this.asyncHttpClient = asyncHttpClient;
        this.modelJC = JAXBContext.newInstance("ru.yandex.travel.suburban.partners.aeroexpress.model");
        this.messageFactory = MessageFactory.newInstance(SOAPConstants.SOAP_1_1_PROTOCOL);
    }

    /** Calls the version-info operation and returns its payload. */
    @Override
    public VersionInfoResponse versionInfo() {
        var request = new GetVersionInfo();
        GetVersionInfoResponse response = call(
                Endpoint.VERSION_INFO, request, GetVersionInfo.class, GetVersionInfoResponse.class, conf.timeout);

        return response.getGetVersionInfoResponse();
    }

    /** Sends a ticket request (RequestTickets3) and returns the partner's reply payload. */
    @Override
    public RequestTicketsResponse requestTickets(RequestTickets3 request) {
        RequestTickets3Response response = call(
                Endpoint.REQUEST_TICKETS, request, RequestTickets3.class, RequestTickets3Response.class, conf.timeout);

        return response.getRequestTicketsResponse();
    }

    /**
     * Calls the pay-order operation for the given partner order; the response body is parsed
     * but not returned.
     */
    @Override
    public void payOrder(Long providerOrderId) {
        var payOrderRequest = new PayOrder();
        payOrderRequest.setOrderId(providerOrderId);

        call(Endpoint.PAY_ORDER, payOrderRequest, PayOrder.class, PayOrderResponse.class, conf.timeout);
    }

    /**
     * Fetches the tickets of a partner order and returns the first one.
     *
     * @throws AeroexpressUnknownException if the partner returned no tickets for the order
     */
    @Override
    public TicketHistory2 getOrderTickets(Long providerOrderId) {
        var ticketRequest = new GetOrderTickets();
        ticketRequest.setOrderId(providerOrderId);

        GetOrderTicketsResponse response = call(
                Endpoint.GET_ORDER_TICKETS, ticketRequest, GetOrderTickets.class, GetOrderTicketsResponse.class, conf.timeout);

        var tickets = response.getOrderTicketsResponse().getTickets().getTicket();
        // Report an empty answer explicitly instead of an opaque IndexOutOfBoundsException.
        if (tickets == null || tickets.isEmpty()) {
            throw new AeroexpressUnknownException(
                    String.format("Aeroexpress returned no tickets for order %s", providerOrderId));
        }
        return tickets.get(0);
    }

    /** Starts an asynchronous GET of a ticket PDF; completes with the body of a 2xx response. */
    private CompletableFuture<byte[]> downloadTicketAsync(String ticketPdfUrl) {
        int timeoutMillis = Math.toIntExact(conf.ticketDownloadTimeout.toMillis());
        RequestBuilder requestBuilder = new RequestBuilder()
                .setMethod(HttpConstants.Methods.GET)
                .setUrl(ticketPdfUrl)
                // A GET carries no body, so the expected payload type belongs in Accept, not Content-Type.
                .setHeader("Accept", "application/pdf")
                .setReadTimeout(timeoutMillis)
                .setRequestTimeout(timeoutMillis);

        return asyncHttpClient.executeRequest(requestBuilder)
                .thenApply(response -> {
                    int statusCode = response.getStatusCode();
                    if (isSuccess(statusCode)) {
                        return response.getResponseBodyAsBytes();
                    }
                    // Do not hand an error page back to the caller as if it were the PDF.
                    throw unexpectedStatus(statusCode);
                });
    }

    /**
     * Downloads a ticket PDF, blocking until it arrives.
     *
     * @return raw bytes of the response body
     */
    public byte[] downloadTicket(String ticketPdfUrl) {
        return sync(downloadTicketAsync(ticketPdfUrl));
    }

    /**
     * Performs one blocking SOAP call.
     *
     * @param endpoint      operation; its path is both the URL suffix and the request root element name
     * @param request       request payload
     * @param requestClass  declared type of {@code request} for the JAXB root element
     * @param responseClass type the SOAP body is unmarshalled into
     * @param timeout       read and request timeout
     * @return the unmarshalled response, or {@code null} for an empty response body
     */
    private <Req, Rsp> Rsp call(
            Endpoint endpoint, Req request, Class<Req> requestClass, Class<Rsp> responseClass, Duration timeout) {

        var rootElement = new JAXBElement<>(new QName(XML_NAMESPACE, endpoint.getPath()), requestClass, request);
        return sync(sendRequest(
                endpoint, createSOAPRequestBody(rootElement), responseClass, timeout));
    }

    /**
     * Serializes {@code request} into a SOAP 1.1 envelope string.
     *
     * @throws AeroexpressUnknownException if JAXB marshalling or SAAJ serialization fails
     */
    private <Req> String createSOAPRequestBody(Req request) {
        try {
            SOAPMessage soapMessage = messageFactory.createMessage();
            SOAPPart soapPart = soapMessage.getSOAPPart();

            SOAPEnvelope envelope = soapPart.getEnvelope();
            envelope.addNamespaceDeclaration("soapenv", SOAPConstants.URI_NS_SOAP_1_1_ENVELOPE);
            envelope.addNamespaceDeclaration("ns2", XML_NAMESPACE);

            SOAPBody soapBody = envelope.getBody();
            soapBody.addDocument(marshalToDocument(request));

            soapMessage.saveChanges();

            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            soapMessage.writeTo(baos);
            // SAAJ's default message encoding is UTF-8; decode with it explicitly instead of
            // relying on the platform default charset.
            return baos.toString(StandardCharsets.UTF_8);
        }
        catch (SOAPException | IOException | JAXBException | ParserConfigurationException ex) {
            log.error("Error serializing request", ex);
            throw new AeroexpressUnknownException(ex);
        }
    }

    /** Marshals a JAXB object into a fresh DOM document, suitable for {@link SOAPBody#addDocument}. */
    private <Req> Document marshalToDocument(Req request) throws JAXBException, ParserConfigurationException {
        DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
        DocumentBuilder db = dbf.newDocumentBuilder();
        Document document = db.newDocument();
        Marshaller marshaller = modelJC.createMarshaller();
        marshaller.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, Boolean.TRUE);
        marshaller.marshal(request, document);
        return document;
    }

    /** POSTs a serialized SOAP envelope and asynchronously converts the reply via {@link #handleResponse}. */
    private <Rsp> CompletableFuture<Rsp> sendRequest(Endpoint endpoint, String requestBody,
                                                     Class<Rsp> responseClass, Duration timeout) {
        int timeoutMillis = Math.toIntExact(timeout.toMillis());
        RequestBuilder requestBuilder = new RequestBuilder()
                .setMethod(HttpConstants.Methods.POST)
                .setUrl(conf.baseUrl + endpoint.getPath())
                .setBody(requestBody)
                .setHeader("Content-Type", SOAPConstants.SOAP_1_1_CONTENT_TYPE)
                .setHeader("SOAPAction", endpoint.getPath())
                .setReadTimeout(timeoutMillis)
                .setRequestTimeout(timeoutMillis);

        return asyncHttpClient.executeRequest(requestBuilder, endpoint.toString())
                .thenApply(response -> handleResponse(response, responseClass));
    }

    /**
     * Parses a 2xx reply; otherwise throws {@link SuburbanRetryableException} for status >= 400
     * and {@link AeroexpressUnknownException} for anything else.
     */
    private <Rsp> Rsp handleResponse(Response response, Class<Rsp> responseClass) {
        int statusCode = response.getStatusCode();
        if (isSuccess(statusCode)) {
            return parseResponse(response, responseClass);
        }
        throw unexpectedStatus(statusCode);
    }

    /** Returns whether {@code statusCode} is in the 2xx range. */
    private static boolean isSuccess(int statusCode) {
        return statusCode >= 200 && statusCode < 300;
    }

    /** Builds the exception reported for a non-2xx HTTP status. */
    private static RuntimeException unexpectedStatus(int statusCode) {
        if (statusCode >= 400) {
            return new SuburbanRetryableException(String.format("Failed to get response: %s", statusCode));
        }
        return new AeroexpressUnknownException(
                String.format("Do not know how to handle status %s from aeroexpress", statusCode));
    }

    /**
     * Unmarshals the SOAP body of {@code response} into {@code responseClass}.
     *
     * @return the parsed value, or {@code null} when the body is empty or {@code responseClass} is null
     * @throws AeroexpressParseResponseException if the body is not a parseable SOAP message
     */
    protected <Rsp> Rsp parseResponse(Response response, Class<Rsp> responseClass) {
        if (responseClass == null) {
            return null;
        }
        // Hand the raw bytes to the XML parser so it determines the encoding itself; a
        // decode-to-String / re-encode-as-UTF-8 round trip can corrupt non-ASCII content.
        byte[] body = response.getResponseBodyAsBytes();
        if (body == null || body.length == 0) {
            return null;
        }
        try {
            SOAPMessage message = messageFactory.createMessage(new MimeHeaders(), new ByteArrayInputStream(body));
            Unmarshaller unmarshaller = modelJC.createUnmarshaller();
            return unmarshaller.unmarshal(message.getSOAPBody().extractContentAsDocument(), responseClass).getValue();
        }
        catch (IOException | SOAPException | JAXBException ex) {
            throw new AeroexpressParseResponseException(ex);
        }
    }

    /**
     * Blocks on {@code future} and maps a failure onto the client's exception types.
     *
     * <p>{@link SuburbanException}s are rethrown unchanged. Timeouts and I/O errors become
     * {@link SuburbanRetryableException}. Other completion failures become
     * {@link AeroexpressUnknownException}, and any other exception (e.g. cancellation) is wrapped
     * in a {@link RuntimeException}.
     */
    public static <Rsp> Rsp sync(CompletionStage<Rsp> future) {
        try {
            return future.toCompletableFuture().join();
        } catch (CompletionException ex) {
            // join() wraps the real failure; fall back to the wrapper itself if it has no cause.
            Throwable cause = ex.getCause() != null ? ex.getCause() : ex;
            if (cause instanceof SuburbanException) {
                throw (SuburbanException) cause;
            } else if (cause instanceof TimeoutException) {
                throw new SuburbanRetryableException("call timeout", cause);
            } else if (cause instanceof IOException) {
                throw new SuburbanRetryableException("call io error", cause);
            } else {
                throw new AeroexpressUnknownException(cause);
            }
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }
}
