package ru.yandex.ljinx;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.logging.Level;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;

import org.apache.http.HttpEntity;
import org.apache.http.HttpException;
import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.entity.ContentType;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.xml.sax.SAXException;

import ru.yandex.charset.StreamEncoder;
import ru.yandex.client.tvm2.Tvm2ServiceContextRenewalTask;
import ru.yandex.client.tvm2.Tvm2TicketRenewalTask;
import ru.yandex.http.proxy.AbstractProxySessionCallback;
import ru.yandex.http.util.HeaderUtils;
import ru.yandex.http.util.YandexHeaders;
import ru.yandex.http.util.nio.BasicAsyncRequestProducerGenerator;
import ru.yandex.http.util.nio.BasicAsyncResponseConsumerFactory;
import ru.yandex.http.util.nio.BasicAsyncResponseProducerGenerator;
import ru.yandex.http.util.nio.NByteArrayEntityGenerator;
import ru.yandex.http.util.nio.client.AsyncClient;
import ru.yandex.http.util.request.function.RequestFunctionValue;
import ru.yandex.io.DecodableByteArrayOutputStream;
import ru.yandex.json.dom.JsonMap;
import ru.yandex.json.dom.TypesafeValueContentHandler;
import ru.yandex.json.parser.JsonException;
import ru.yandex.json.writer.DollarJsonWriter;
import ru.yandex.json.writer.JsonWriter;
import ru.yandex.util.timesource.TimeSource;

/**
 * Cache storage that uploads full responses to MDS and keeps the STID
 * returned by MDS in the cache inherited from {@link LuceneCacheStorage}.
 *
 * <p>On {@link #put}, the response is serialized into a JSON document
 * (status, body, headers, expire timestamp) and uploaded to MDS. The STID
 * from the upload answer is then passed to the superclass {@code put}
 * together with the full value. While an upload is in flight, the full value
 * is served from an in-memory map. On {@link #get}, the STID found by the
 * superclass is used to download the JSON document back from MDS.
 */
public class MdsCacheStorage extends LuceneCacheStorage {
    private static final String TVM2 = "tvm2";
    // Keys of the JSON document uploaded to MDS for every stored response.
    private static final String HTTP_BODY = "http_body";
    private static final String HTTP_HEADERS = "http_headers";
    private static final String HTTP_STATUS = "http_status";
    private static final String HTTP_EXPIRE_TIMESTAMP = "http_expire_timestamp";
    // TTLs are handled in milliseconds; the MDS "expire" parameter is seconds.
    private static final long MILLIS_PER_SECOND = 1000L;

    /**
     * Values whose MDS upload has started but not finished, keyed by cache
     * key. {@link GetResultCallback} serves them from memory; they are removed
     * once the upload succeeds or fails.
     */
    private final ConcurrentHashMap<
        RequestFunctionValue,
        BasicAsyncResponseProducerGenerator> tempStoreMap =
            new ConcurrentHashMap<>();
    private final AsyncClient mdsWriterClient;
    private final AsyncClient mdsReaderClient;
    private final HttpHost backendWriterHost;
    private final HttpHost backendReaderHost;
    private final Tvm2ServiceContextRenewalTask serviceContextRenewalTask;
    private final Tvm2TicketRenewalTask tvm2RenewalTask;
    private final String mdsTvmClientId;
    private final String mdsNamespace;
    private final long mdsTTL;
    // When true, blobs downloaded from MDS are also put into the superclass
    // cache (see MdsGetResultCallback).
    private final boolean loadHitsToMemory;

    /**
     * Creates the storage, its MDS reader/writer clients and the TVM2 ticket
     * renewal task used to authorize MDS requests.
     *
     * @param storageName name of this storage, also used in client names
     * @param ljinx Ljinx instance providing HTTP clients and the logger
     * @param config storage configuration
     * @throws CacheStorageException if the TVM2 renewal task cannot be created
     */
    public MdsCacheStorage(
        final String storageName,
        final Ljinx ljinx,
        final MdsCacheStorageConfig config)
            throws CacheStorageException
    {
        super(storageName, ljinx, config);
        mdsWriterClient =
            ljinx.client(
                "MDS-to-Write-" + storageName,
                config.backendWriterConfig());
        mdsReaderClient =
            ljinx.client(
                "MDS-to-Read-" + storageName,
                config.backendReaderConfig());
        backendWriterHost = config.backendWriterConfig().host();
        backendReaderHost = config.backendReaderConfig().host();
        mdsTvmClientId = config.mdsTvmClientId();
        mdsNamespace = config.mdsNamespace();
        mdsTTL = config.mdsDefaultTTL();
        ljinx.logger().info("MdsCacheStorage: mdsTTL=" + mdsTTL);
        loadHitsToMemory = config.loadHitsToMemory();
        try {
            serviceContextRenewalTask = new Tvm2ServiceContextRenewalTask(
                ljinx.logger().addPrefix(TVM2),
                config.tvm2ServiceConfig(),
                config.dnsConfig());
            tvm2RenewalTask = new Tvm2TicketRenewalTask(
                ljinx.logger().addPrefix(TVM2),
                serviceContextRenewalTask,
                config.tvm2ClientConfig());
            // The ticket itself is a credential, so only the fact is logged.
            ljinx.logger().info(
                "MdsCacheStorage: TVM2 ticket renewal task created");
        } catch (HttpException | IOException | JsonException | URISyntaxException e) {
            ljinx.logger().log(
                Level.SEVERE,
                "Error occured while creating TVM2 renewal task for MDS cache storage plugin",
                e);
            throw new CacheStorageException(
                config.type(),
                storageName,
                "Error occured while creating of MDS cache storage plugin: " + e);
        }
    }

    /**
     * Always disables the superclass's loading of hits into memory. The
     * superclass cache holds STIDs, while its in-memory cache is expected to
     * hold full response blobs. Blobs downloaded from MDS are put into memory
     * by {@link MdsGetResultCallback} instead, when configured.
     *
     * @return always {@code false}
     */
    @Override
    protected boolean loadHitsToMemory() {
        return false;
    }

    /**
     * Starts the TVM2 ticket renewal task, then the superclass storage.
     *
     * @throws IOException if the superclass fails to start
     */
    @Override
    public void start() throws IOException {
        tvm2RenewalTask.start();
        super.start();
    }

    /**
     * Cancels the TVM2 ticket renewal task and closes the superclass storage.
     * The superclass is closed even if cancelling the task throws.
     *
     * @throws IOException if the superclass fails to close
     */
    @Override
    public void close() throws IOException {
        try {
            tvm2RenewalTask.cancel();
        } finally {
            super.close();
        }
    }

    /**
     * Returns the current TVM2 service ticket for the configured MDS client id.
     *
     * @return ticket to send in the service-ticket header of MDS requests
     */
    public String mdsTvm2Ticket() {
        return tvm2RenewalTask.ticket(mdsTvmClientId);
    }

    /**
     * Returns the TVM2 service context renewal task owned by this storage.
     *
     * @return the service context renewal task
     */
    public final Tvm2ServiceContextRenewalTask serviceContextRenewalTask() {
        return serviceContextRenewalTask;
    }

    /**
     * Resolves the TTL for a session. The session's own TTL is used when
     * present; otherwise the configured MDS default is used. A resulting zero
     * falls back to the superclass {@code defaultTTL}. {@link #put} treats the
     * value as milliseconds.
     *
     * @param session proxied request session
     * @return TTL to use for the session's response
     */
    @Override
    protected long ttl(final ProxyPassSession session) {
        Long ttl = session.ttl();
        if (ttl == null) {
            ttl = mdsTTL;
        }
        if (ttl == 0) {
            ttl = defaultTTL;
        }
        return ttl;
    }

    /**
     * Looks the session up in the superclass cache. {@link GetResultCallback}
     * then resolves an STID hit into the full blob stored in MDS.
     *
     * @param session proxied request session
     * @param callback receives the cached response, or {@code null} on a miss
     */
    @Override
    public void get(
        final ProxyPassSession session,
        final FutureCallback<CacheResponse> callback)
    {
        super.get(session, new GetResultCallback(session, callback));
    }

    /**
     * Stores a response. If the TTL is at most {@code minimalStoreTTL}, the
     * value is handed to the superclass without an STID and MDS is not used.
     * Otherwise the value is uploaded to MDS and kept in
     * {@code tempStoreMap} until the upload finishes.
     *
     * @param session proxied request session
     * @param value response to store
     * @param callback completion callback; failed if the upload request
     *     cannot be prepared
     */
    @Override
    @SuppressWarnings("FutureReturnValueIgnored")
    public void put(
        final ProxyPassSession session,
        final BasicAsyncResponseProducerGenerator value,
        final FutureCallback<Void> callback)
    {
        final long ttl = ttl(session);
        if (ttl <= minimalStoreTTL) {
            session.logger().info(
                "TTL too low for MDS store: " + ttl
                + " data will be stored only in memory");
            super.put(session, value, null, callback);
            return;
        }
        // Keep the value readable from memory while the upload is in flight.
        tempStoreMap.put(session.cacheKey(), value);

        final AsyncClient client;
        final Supplier<? extends HttpClientContext> contextGenerator;
        final BasicAsyncRequestProducerGenerator producerGenerator;
        try {
            client = mdsWriterClient.adjust(session.session().context());
            contextGenerator =
                session.session().listener().createContextGeneratorFor(client);
            producerGenerator = createUploadRequest(value, ttl, client);
        } catch (IOException | RuntimeException e) {
            // Any failure here must drop the tempStoreMap entry. Otherwise
            // GetResultCallback would keep serving it from memory, since
            // nothing else would ever remove it.
            session.logger().log(
                Level.SEVERE,
                "Error occured while saving object to MDS",
                e);
            tempStoreMap.remove(session.cacheKey(), value);
            callback.failed(e);
            return;
        }
        client.execute(
            backendWriterHost,
            producerGenerator,
            BasicAsyncResponseConsumerFactory.ANY_GOOD,
            contextGenerator,
            new MdsPutResultCallback(session, value, callback));
    }

    /**
     * Serializes {@code value} into the JSON document stored in MDS and builds
     * the authorized upload request for it.
     *
     * @param value response to store
     * @param ttl time to live, in milliseconds
     * @param client writer client; its request charset encodes the JSON
     * @return producer generator for the MDS upload request
     * @throws IOException if the response cannot be serialized
     */
    private BasicAsyncRequestProducerGenerator createUploadRequest(
        final BasicAsyncResponseProducerGenerator value,
        final long ttl,
        final AsyncClient client)
        throws IOException
    {
        DecodableByteArrayOutputStream out =
            new DecodableByteArrayOutputStream();
        try (JsonWriter writer =
                new DollarJsonWriter(
                    new StreamEncoder(out, client.requestCharset())))
        {
            HttpResponse httpMessage = value.get().generateResponse();
            HttpEntity entity = httpMessage.getEntity();
            writer.startObject();
            writer.key(HTTP_STATUS);
            writer.value(httpMessage.getStatusLine().getStatusCode());
            writer.key(HTTP_BODY);
            writeEntity(entity, writer);
            writer.key(HTTP_HEADERS);
            writeHeaders(httpMessage.headerIterator(), entity, writer);
            writer.key(HTTP_EXPIRE_TIMESTAMP);
            // Absolute expiry time, in seconds.
            writer.value(
                TimeUnit.MILLISECONDS.toSeconds(
                    TimeSource.INSTANCE.currentTimeMillis() + ttl));
            writer.endObject();
        }
        String uri = "/upload-" + mdsNamespace;
        if (ttl > MILLIS_PER_SECOND) {
            uri += "/?expire=" + (ttl / MILLIS_PER_SECOND) + "s";
        }
        BasicAsyncRequestProducerGenerator producerGenerator =
            new BasicAsyncRequestProducerGenerator(
                uri,
                new NByteArrayEntityGenerator(
                    out,
                    ContentType.APPLICATION_JSON.withCharset(client.requestCharset())));
        producerGenerator.addHeader(
            HeaderUtils.createHeader(
                YandexHeaders.X_YA_SERVICE_TICKET,
                mdsTvm2Ticket()));
        return producerGenerator;
    }

    /**
     * Lets inner callbacks store the full value together with its STID
     * response through the superclass.
     */
    private void cachePut(
        final ProxyPassSession session,
        final BasicAsyncResponseProducerGenerator fullValue,
        final BasicAsyncResponseProducerGenerator stidValue,
        final FutureCallback<Void> callback)
    {
        super.put(session, fullValue, stidValue, callback);
    }

    /**
     * Lets inner callbacks store a value under {@code key} until
     * {@code expireTimestamp} through the superclass.
     */
    private void cachePut(
        final RequestFunctionValue key,
        final BasicAsyncResponseProducerGenerator value,
        final long expireTimestamp)
    {
        super.put(key, value, expireTimestamp);
    }

    /**
     * Extends the superclass status with the status of both MDS clients.
     *
     * @param verbose whether to include verbose client status
     * @return status map
     */
    @Override
    public Map<String, Object> status(final boolean verbose) {
        Map<String, Object> status = super.status(verbose);
        status.put("mdsclient-to-read", mdsReaderClient.status(verbose));
        status.put("mdsclient-to-write", mdsWriterClient.status(verbose));
        return status;
    }

    /**
     * Resolves the superclass cache lookup into a full response.
     *
     * <ul>
     *   <li>A MEMORY hit is returned as is.</li>
     *   <li>Otherwise a value whose upload is still in flight is served from
     *       {@code tempStoreMap}.</li>
     *   <li>Otherwise the STID from the lookup is used to download the blob
     *       from MDS.</li>
     * </ul>
     *
     * Every failure is reported to the caller as a miss ({@code null}).
     */
    private class GetResultCallback
        extends AbstractProxySessionCallback<CacheResponse>
    {
        private final ProxyPassSession ppSession;
        private final FutureCallback<CacheResponse> callback;

        GetResultCallback(
            final ProxyPassSession ppSession,
            final FutureCallback<CacheResponse> callback)
        {
            super(ppSession.session());
            this.ppSession = ppSession;
            this.callback = callback;
        }

        @Override
        public void completed(final CacheResponse response) {
            BasicAsyncResponseProducerGenerator responseProducer = null;
            if (response != null) {
                responseProducer = response.response();
                if (response.cacheType() == CacheResponse.CacheType.MEMORY) {
                    // Full blob already in memory, no MDS round trip needed.
                    ppSession.logger().info(
                        "GetResultCallback: got full blob entity from memory");
                    callback.completed(response);
                    return;
                }
            }
            if (responseProducer == null) {
                responseProducer = tempStoreMap.get(ppSession.cacheKey());
                if (responseProducer != null) {
                    callback.completed(
                        new CacheResponse(
                            responseProducer,
                            CacheResponse.CacheType.MEMORY,
                            TimeSource.INSTANCE.currentTimeMillis()
                                + ttl(ppSession)));
                    ppSession.logger().info(
                        "GetResultCallback: got full blob entity from "
                            + "outstanding cache");
                    return;
                }
                // Neither an STID nor an in-flight upload.
                ppSession.logger().info(
                    "GetResultCallback: responseProducer is null!");
                miss();
                return;
            }
            // Here responseProducer came from response, so response != null.
            try {
                String stid = readStid(responseProducer);
                if (stid.isEmpty()) {
                    ppSession.logger().severe(
                        "GetResultCallback: STID is empty!");
                } else {
                    ppSession.logger().info("Found STID: " + stid);
                    requestFromMds(stid, response.expireTimestamp());
                    return;
                }
            } catch (IOException e) {
                ppSession.logger().log(
                    Level.SEVERE,
                    "Error occured while receiving of stid for MDS object",
                    e);
            } catch (RuntimeException e) {
                ppSession.logger().log(
                    Level.SEVERE,
                    "Error occured while trying to receive MDS object",
                    e);
            }
            miss();
        }

        @Override
        public void failed(final Exception e) {
            ppSession.logger().log(
                Level.SEVERE,
                "Lucene cache request for stid failed",
                e);
            callback.completed(null);
        }

        /** Reports a cache miss to the caller. */
        private void miss() {
            ppSession.logger().info(
                "MDS cache miss because of failed request for stid");
            callback.completed(null);
        }

        /** Reads the STID stored as the entity of a cached response. */
        private String readStid(
            final BasicAsyncResponseProducerGenerator producer)
            throws IOException
        {
            HttpEntity entity = producer.entityGenerator().get();
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            entity.writeTo(out);
            return new String(out.toByteArray(), StandardCharsets.UTF_8);
        }

        /** Starts an authorized download of the blob stored under stid. */
        @SuppressWarnings("FutureReturnValueIgnored")
        private void requestFromMds(
            final String stid,
            final long expireTimestamp)
            throws IOException
        {
            final AsyncClient client =
                mdsReaderClient.adjust(ppSession.session().context());
            Supplier<? extends HttpClientContext> contextGenerator =
                ppSession.session().listener()
                    .createContextGeneratorFor(client);
            BasicAsyncRequestProducerGenerator producerGenerator =
                new BasicAsyncRequestProducerGenerator(
                    "/get-" + mdsNamespace + '/' + stid);
            producerGenerator.addHeader(
                HeaderUtils.createHeader(
                    YandexHeaders.X_YA_SERVICE_TICKET,
                    mdsTvm2Ticket()));
            client.execute(
                backendReaderHost,
                producerGenerator,
                BasicAsyncResponseConsumerFactory.ANY_GOOD,
                contextGenerator,
                new MdsGetResultCallback(
                    ppSession,
                    callback,
                    expireTimestamp));
        }
    }

    /**
     * Parses the JSON blob downloaded from MDS into a cached response.
     * Optionally puts it into the superclass cache. Any failure is reported to
     * the caller as a miss ({@code null}).
     */
    private class MdsGetResultCallback
        extends AbstractProxySessionCallback<HttpResponse>
    {
        private final ProxyPassSession ppSession;
        private final FutureCallback<CacheResponse> callback;
        private final long expireTimestamp;

        MdsGetResultCallback(
            final ProxyPassSession ppSession,
            final FutureCallback<CacheResponse> callback,
            final long expireTimestamp)
        {
            super(ppSession.session());
            this.ppSession = ppSession;
            this.callback = callback;
            this.expireTimestamp = expireTimestamp;
        }

        @Override
        public void completed(final HttpResponse response) {
            try {
                HttpEntity entity = response.getEntity();
                ByteArrayOutputStream out = new ByteArrayOutputStream();
                entity.writeTo(out);
                // Decode with the writer client's request charset, which is
                // what put() encoded the JSON document with.
                String body =
                    new String(
                        out.toByteArray(),
                        mdsWriterClient.requestCharset());
                HashMap<String, String> headers = new HashMap<>();
                JsonMap json = TypesafeValueContentHandler.parse(body).asMap();
                for (String key : json.keySet()) {
                    headers.put(key, json.get(key).asStringOrNull());
                }
                CacheResponse cacheResponse =
                    new CacheResponse(
                        prepareResponse(headers),
                        CacheResponse.CacheType.MDS,
                        expireTimestamp);
                if (loadHitsToMemory) {
                    cachePut(
                        ppSession.cacheKey(),
                        cacheResponse.response(),
                        expireTimestamp);
                }
                callback.completed(cacheResponse);
                return;
            } catch (IOException | JsonException e) {
                ppSession.logger().log(
                    Level.SEVERE,
                    "Error occured while receiving of MDS object",
                    e);
            } catch (RuntimeException e) {
                // e.g. a response without an entity; still complete as a miss
                // instead of leaving the caller waiting.
                ppSession.logger().log(
                    Level.SEVERE,
                    "Error occured while receiving of MDS object",
                    e);
            }
            callback.completed(null);
        }

        @Override
        public void failed(final Exception e) {
            ppSession.logger().log(Level.SEVERE, "MDS request failed", e);
            callback.completed(null);
        }
    }

    /**
     * Extracts the STID from the MDS upload answer (the {@code key} attribute
     * of the XML root element). On success, stores the value together with a
     * response holding the STID through the superclass. The in-flight entry
     * is always removed from {@code tempStoreMap}.
     */
    private class MdsPutResultCallback
        extends AbstractProxySessionCallback<HttpResponse>
    {
        private static final String XML_ERROR =
            "Error occured while parsing of received MDS answer's XML";

        private final ProxyPassSession ppSession;
        private final BasicAsyncResponseProducerGenerator value;
        private final FutureCallback<Void> callback;

        MdsPutResultCallback(
            final ProxyPassSession ppSession,
            final BasicAsyncResponseProducerGenerator value,
            final FutureCallback<Void> callback)
        {
            super(ppSession.session());
            this.ppSession = ppSession;
            this.callback = callback;
            this.value = value;
        }

        @Override
        public void completed(final HttpResponse response) {
            try {
                String stid = parseStid(response.getEntity());
                // getAttribute() returns "" when "key" is absent; an empty
                // STID must not be cached as a successful upload.
                if (!stid.isEmpty()) {
                    final BasicAsyncResponseProducerGenerator responseProducer =
                        new BasicAsyncResponseProducerGenerator(
                            HttpStatus.SC_OK,
                            stid);
                    ppSession.logger().info(
                        "MDS put: obtained STID - " + stid);
                    cachePut(
                        ppSession,
                        value,
                        responseProducer,
                        callback);
                    tempStoreMap.remove(ppSession.cacheKey(), value);
                    return;
                }
            } catch (ParserConfigurationException | SAXException | IOException | RuntimeException e) {
                ppSession.logger().log(Level.SEVERE, XML_ERROR, e);
            }
            ppSession.logger().severe("MDS put: STID not found!");
            tempStoreMap.remove(ppSession.cacheKey(), value);
            callback.completed(null);
        }

        @Override
        public void failed(final Exception e) {
            ppSession.logger().log(Level.SEVERE, "MDS put failed!", e);
            tempStoreMap.remove(ppSession.cacheKey(), value);
            callback.failed(e);
        }

        /**
         * Parses the upload answer and returns the root's "key" attribute,
         * or "" if the attribute is absent. The content stream is closed.
         */
        private String parseStid(final HttpEntity entity)
            throws ParserConfigurationException, SAXException, IOException
        {
            DocumentBuilder documentBuilder =
                DocumentBuilderFactory.newInstance().newDocumentBuilder();
            try (InputStream in = entity.getContent()) {
                Document document = documentBuilder.parse(in);
                Element root = document.getDocumentElement();
                return root.getAttribute("key");
            }
        }
    }
}
