package ru.yandex.infra.stage.rest;

import java.io.PrintWriter;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.function.Function;

import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.google.common.collect.Maps;
import com.google.protobuf.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.infra.stage.cache.Cache;
import ru.yandex.infra.stage.cache.CacheSet;
import ru.yandex.infra.stage.cache.CacheStorage;
import ru.yandex.infra.stage.cache.CacheStorageFactory;
import ru.yandex.infra.stage.cache.CachedObjectType;

/**
 * Administrative servlet operating on a {@link CacheSet}.
 *
 * <p>Each instance serves exactly one {@link RequestType} via HTTP POST:
 * <ul>
 *   <li>{@link RequestType#EXPORT}: for every {@link CachedObjectType} in {@code CachedObjectType.ALL}, writes all
 *       cached records (converted to proto) into a storage created by the factory that
 *       {@code storageFactorySupplier} returns for the {@code to} request parameter;</li>
 *   <li>{@link RequestType#REMOVE}: removes the record with key {@code id} from the cache of type {@code type}.</li>
 * </ul>
 *
 * <p>The response body is plain-text progress output that ends with {@code done} on success, or with
 * {@code failed: <exception>} on failure (in which case the servlet also attempts to set HTTP 500).
 */
public class CacheServlet extends HttpServlet {
    private static final Logger LOG = LoggerFactory.getLogger(CacheServlet.class);

    private final RequestType requestType;
    private final CacheSet caches;
    /** Resolves a destination storage type name to a factory; {@code null} when export is not configured. */
    private final Function<String, CacheStorageFactory> storageFactorySupplier;

    /** Operation performed by a {@link CacheServlet} instance. */
    public enum RequestType {
        /** Export the records of every cached object type to another storage. */
        EXPORT,
        /** Remove a single record from one cache. */
        REMOVE
    }

    /**
     * Creates a servlet without an export storage supplier.
     * EXPORT requests served by such an instance fail with an {@link IllegalStateException}.
     *
     * @param requestType operation this servlet performs on POST
     * @param caches      caches to operate on
     */
    public CacheServlet(RequestType requestType, CacheSet caches) {
        this(requestType, caches, null);
    }

    /**
     * @param requestType            operation this servlet performs on POST
     * @param caches                 caches to operate on
     * @param storageFactorySupplier maps the {@code to} request parameter to the destination storage factory;
     *                               may be {@code null} if this instance never serves EXPORT requests
     */
    public CacheServlet(RequestType requestType, CacheSet caches, Function<String, CacheStorageFactory> storageFactorySupplier) {
        this.requestType = requestType;
        this.caches = caches;
        this.storageFactorySupplier = storageFactorySupplier;
    }

    /**
     * Exports every cached object type, one after another, to the storage selected by
     * {@code destinationStorageType}.
     *
     * <p>A failure of an individual type's asynchronous export is reported to {@code writer} and the log, and
     * the remaining types are still exported. An interruption aborts the whole export.
     *
     * @param writer                 response writer receiving progress and per-type failure messages
     * @param destinationStorageType value passed to {@code storageFactorySupplier}
     * @throws IllegalStateException if no storage factory supplier was configured, or the thread is interrupted
     */
    private void export(PrintWriter writer, String destinationStorageType) {
        if (storageFactorySupplier == null) {
            throw new IllegalStateException("Cache export is not configured for this servlet");
        }
        CacheStorageFactory destinationStorageFactory = storageFactorySupplier.apply(destinationStorageType);
        for (CachedObjectType<?, ?> objectType : CachedObjectType.ALL.values()) {
            try {
                export(writer, destinationStorageFactory, objectType).get();
            } catch (ExecutionException e) {
                // Best effort: report this type's failure and continue with the remaining types.
                String msg = String.format("Failed to export %s cache records: %s", objectType.getName(), e);
                LOG.error(msg, e);
                writer.println(msg);
            } catch (InterruptedException e) {
                // Restore the interrupt flag instead of swallowing it, and stop exporting.
                Thread.currentThread().interrupt();
                throw new IllegalStateException(
                        String.format("Interrupted while exporting %s cache records", objectType.getName()), e);
            }
        }
    }

    /**
     * Starts exporting all records of one cached object type.
     *
     * @param writer                    response writer receiving the record count
     * @param destinationStorageFactory factory creating the destination storage
     * @param cachedObjectType          type whose cache is exported
     * @return future completing after the destination storage is initialised, written and flushed
     */
    private <TValue, TProtoValue extends Message>
        CompletableFuture<?> export(PrintWriter writer,
                                    CacheStorageFactory destinationStorageFactory,
                                    CachedObjectType<TValue, TProtoValue> cachedObjectType) {

        var values = caches.get(cachedObjectType).getAll();
        // Lazy Guava view: values are converted to proto when the storage reads them.
        var protoValues = Maps.transformValues(values, cachedObjectType.getToProto()::apply);

        writer.println(String.format("Exporting %d %s cache records",
                values.size(),
                cachedObjectType.getName()));

        CacheStorage<TProtoValue> exportStorage = destinationStorageFactory.createStorage(cachedObjectType);
        return exportStorage.init()
            .thenCompose(x -> exportStorage.write(protoValues))
            .thenCompose(x -> exportStorage.flush());
    }

    /**
     * Removes a single record from the cache of the given type.
     *
     * @param objectType name of the cached object type (a key of {@code CachedObjectType.ALL})
     * @param key        key of the record to remove
     * @throws IllegalArgumentException if the type is missing or unknown, or the key is missing
     */
    private void remove(String objectType, String key) {
        // Guard against null before the lookup: some Map implementations reject null keys in get().
        CachedObjectType<?, ?> type = objectType == null ? null : CachedObjectType.ALL.get(objectType);
        if (type == null) {
            throw new IllegalArgumentException("Wrong type: " + objectType);
        }
        if (key == null) {
            throw new IllegalArgumentException("Missing 'id' parameter");
        }

        Cache<?> cache = caches.get(type);
        cache.remove(key);
        LOG.info("Removed {} from {} cache records", key, objectType);
    }

    /**
     * Runs {@code action} against the response writer and reports the outcome.
     *
     * <p>On success writes {@code done} and sets HTTP 200. On any {@link Throwable} writes
     * {@code failed: <exception>} (if the writer was obtained), logs the error with its stack trace and
     * sets HTTP 500.
     *
     * @param req    incoming request (used for logging)
     * @param resp   response to write to
     * @param action request-specific work writing progress to the given writer
     */
    private void executeRequest(HttpServletRequest req, HttpServletResponse resp,
                                Consumer<PrintWriter> action) {
        PrintWriter writer = null;
        try {
            // Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
            final long startNanos = System.nanoTime();
            writer = resp.getWriter();
            action.accept(writer);

            LOG.info("Processing {} '{}' request took: {} ms", req.getMethod(), req.getRequestURI(),
                    (System.nanoTime() - startNanos) / 1_000_000);
            writer.println("done");
            resp.setStatus(HttpServletResponse.SC_OK);
        } catch (Throwable e) {
            if (writer != null) {
                writer.println("failed: " + e);
            }
            LOG.error("Failed to process {} '{}'", req.getMethod(), req.getRequestURI(), e);
            // NOTE: if the action already produced enough output to commit the response, the status can no
            // longer be changed; the "failed:" line in the body is then the only failure indicator.
            resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        }
    }

    /**
     * Dispatches the POST request according to this servlet's {@link RequestType}.
     * EXPORT reads the {@code to} parameter; REMOVE reads the {@code type} and {@code id} parameters.
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
        executeRequest(req, resp, writer -> {
            switch (requestType) {
                case EXPORT:
                    export(writer, req.getParameter("to"));
                    break;
                case REMOVE:
                    remove(req.getParameter("type"), req.getParameter("id"));
                    break;
                default:
                    // Fail loudly instead of reporting "done" for an operation that was never performed.
                    throw new IllegalStateException("Unsupported request type: " + requestType);
            }
        });
    }
}
