package ru.yandex.webmaster3.core.worker.client;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.joda.JodaModule;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
import com.google.common.base.Strings;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpStatus;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Required;
import org.springframework.web.util.UriComponentsBuilder;
import ru.yandex.webmaster3.core.http.ActionStatus;
import ru.yandex.webmaster3.core.http.HttpConstants;
import ru.yandex.webmaster3.core.http.WebmasterJsonModule;
import ru.yandex.webmaster3.core.util.conductor.ConductorClient;
import ru.yandex.webmaster3.core.util.conductor.ConductorHostInfo;
import ru.yandex.webmaster3.core.worker.task.*;

import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

/**
 * @author aherman
 */
public class RemoteWorkerClient implements WorkerClient {
    public static final int TASK_BATCH_SIZE = 100;

    private static final Logger log = LoggerFactory.getLogger(RemoteWorkerClient.class);

    private static final ObjectMapper OM = new ObjectMapper()
            .registerModule(new ParameterNamesModule())
            .registerModules(new JodaModule(), new WebmasterJsonModule(false));

    private static final String WORKER_OVERRIDE_PROPERTY = "WORKER_OVERRIDE_HOST";
    private static final int DEFAULT_WORKER_PORT = 80;
    private static final Duration PING_PERIOD = Duration.ofSeconds(5);

    private static final int socketTimeoutMs = 10_000;

    private CloseableHttpClient httpClient;
    private String applicationHostname;
    private String applicationName;
    private String applicationVersion;
    private String workerConductorGroupName;
    private String fallbackHostsString;
    private ConductorClient conductorClient;
    private List<WorkerHostInfo> workerHosts;
    private Timer pingerTimer = new Timer("worker-client-pinger", true);
    private String localDcName;

    private ExecutorService sendQueueExecutor;
    private final BlockingQueue<WorkerTaskData> sendQueue = new ArrayBlockingQueue<>(50000);

    public void init() {
        RequestConfig requestConfig = RequestConfig.custom()
                .setConnectionRequestTimeout(HttpConstants.DEFAULT_CONNECTION_REQUEST_TIMEOUT)
                .setConnectTimeout(HttpConstants.DEFAULT_CONNECT_TIMEOUT)
                .setSocketTimeout(socketTimeoutMs)
                .build();

        httpClient = HttpClients.custom()
                .setMaxConnTotal(60)
                .setMaxConnPerRoute(20)
                .setDefaultRequestConfig(requestConfig)
                .build();

        sendQueueExecutor = Executors.newSingleThreadExecutor(
                new ThreadFactoryBuilder()
                        .setNameFormat("worker-client-queue-%d")
                        .setDaemon(true)
                        .build()
        );
        workerHosts = getWorkerHosts();
        pingerTimer.scheduleAtFixedRate(new PingerTask(), PING_PERIOD.toMillis(), PING_PERIOD.toMillis());
        sendQueueExecutor.execute(new SendWorker());
    }

    private List<WorkerHostInfo> getWorkerHosts() {
        List<WorkerHostInfo> hostInfos = new ArrayList<>();
        boolean isFallback = false;
        boolean isOverride = false;

        String workerOverrideHost = System.getenv(WORKER_OVERRIDE_PROPERTY);
        if (!Strings.isNullOrEmpty(workerOverrideHost)) {
            hostInfos.add(new WorkerHostInfo(URI.create(workerOverrideHost)));
            isOverride = true;
        } else {
            try {
                if (!StringUtils.isEmpty(workerConductorGroupName)) {
                    List<ConductorHostInfo> hosts = conductorClient.listHostsInGroup(workerConductorGroupName);
                    for (ConductorHostInfo hostInfo : hosts) {
                        hostInfos.add(new WorkerHostInfo(URI.create("http://" + hostInfo.getHostName() + ":" + DEFAULT_WORKER_PORT)));
                    }
                }
            } catch (Exception e) {
                log.error("Failed to get worker hosts from conductor", e);
            }
            if (hostInfos.isEmpty()) {
                isFallback = true;
                String[] hosts = fallbackHostsString.split(",");
                for (String host : hosts) {
                    String uri;
                    if (host.contains(":")) {
                        uri = "http://" + host;
                    } else {
                        uri = "http://" + host + ":" + DEFAULT_WORKER_PORT;
                    }
                    hostInfos.add(new WorkerHostInfo(URI.create(uri)));
                }
            }
        }

        log.info("Discovered worker hosts: {}, from override = {}, from fallback = {}",
                hostInfos.stream().map(h -> h.workerAddress.toString()).collect(Collectors.joining(",")),
                isOverride, isFallback
        );

        return hostInfos;
    }

    public void destroy() {
        sendQueueExecutor.shutdownNow();
        try {
            sendQueueExecutor.awaitTermination(20, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Ignore
        }
        IOUtils.closeQuietly(httpClient);
    }

    @Override
    public <TD extends WorkerTaskData> void enqueueTask(TD taskData){
        UUID taskId = taskData.getTaskId();
        WorkerTaskType taskType = taskData.getTaskType();
        WorkerTaskPriority taskPriority = taskData.getTaskPriority();

        log.trace("Send task: id={} type={} {} queueSize={} taskPriority={}", taskId, taskType, taskData.getShortDescription(),
                sendQueue.size(), taskPriority);
        boolean enqueued;
        try {
            enqueued = sendQueue.offer(taskData, 1, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            enqueued = false;
        }
        if (!enqueued) {
            log.error("Unable to add task: id={} type={} {}", taskId, taskType, taskData.getShortDescription());
        }
    }

    @Override
    public <TD extends WorkerTaskData> boolean checkedEnqueueTask(TD taskData) {
        UUID taskId = taskData.getTaskId();
        WorkerTaskType taskType = taskData.getTaskType();
        WorkerTaskPriority taskPriority = taskData.getTaskPriority();

        log.trace("Send task: id={} type={} {} queueSize={} taskPriority={}", taskId, taskType, taskData.getShortDescription(),
                sendQueue.size(), taskPriority);
        return internalEnqueue(Collections.singletonList(taskData));
    }

    @Override
    public <TD extends WorkerTaskData> boolean checkedEnqueueBatch(Collection<TD> batch) {
        return internalEnqueue(batch);
    }

    private class SendWorker implements Runnable {
        @Override
        public void run() {
            log.info("Start task");
            while (!Thread.interrupted()) {
                WorkerTaskData taskData;
                try {
                    taskData = sendQueue.take();
                } catch (InterruptedException e) {
                    break;
                }
                List<WorkerTaskData> taskDataList = new ArrayList<>(TASK_BATCH_SIZE);
                taskDataList.add(taskData);
                // затягиваем до 99 элементов (1 уже есть)
                sendQueue.drainTo(taskDataList, TASK_BATCH_SIZE - 1);
                internalEnqueue(taskDataList);

            }
            log.info("Stop task");
        }
    }

    private boolean internalEnqueue(Collection<? extends WorkerTaskData> taskDataList) {
        log.trace("Send {} tasks: {}", taskDataList.size(), taskDataList);
        if (taskDataList.isEmpty()) {
            return true;
        }
        boolean shouldRunLocally = taskDataList.iterator().next().shouldRunLocally();
        List<WorkerHostInfo> candidates = workerHosts
                .stream()
                .filter(WorkerHostInfo::isAlive)
                .collect(Collectors.toList());
        if (shouldRunLocally) {
            List<WorkerHostInfo> localCandidates = candidates.stream()
                    .filter(host -> Objects.equals(host.localDc.get(), localDcName))
                    .collect(Collectors.toList());
            if (!localCandidates.isEmpty()) {
                candidates = localCandidates;
            }
        }
        if (candidates.isEmpty()) {
            candidates = workerHosts;
        }
        WorkerHostInfo workerHost = candidates.get(ThreadLocalRandom.current().nextInt(candidates.size()));
        URI hostUri = workerHost.workerAddress;
        HttpPost post;
        try {
            post = toPost(taskDataList, hostUri);
        } catch (Exception e) {
            log.error("{} tasks send failed", taskDataList.size(), e);
            return false;
        }

        try (CloseableHttpResponse httpResponse = httpClient.execute(post)) {
            String status = null;
            if (httpResponse.getStatusLine().getStatusCode() == 200) {
                JsonNode node = OM.readTree(httpResponse.getEntity().getContent());
                status = node.get("status").textValue();
                if (!status.equals(ActionStatus.SUCCESS.name())) {
                    log.error("Failed to enqueue {} tasks to {}, status: {}\n", taskDataList.size(), hostUri, httpResponse.getStatusLine());
                    return false;
                }
            }
            log.info("Tasks send result:status={} actionStatus={}",
                    httpResponse.getStatusLine().getStatusCode(), status);
            return true;
        } catch (IOException e) {
            workerHost.alive = false;
            log.error("Send {} tasks failed: {}", taskDataList.size(), taskDataList, e);
            return false;
        }
    }

    HttpPost toPost(Collection<? extends WorkerTaskData> taskDataList, URI hostUrl) throws JsonProcessingException {

        URI requestUri = UriComponentsBuilder.fromUri(hostUrl)
                .path("/task/enqueueList.json")
                .queryParam("applicationHostname", applicationHostname)
                .queryParam("applicationName", applicationName)
                .queryParam("applicationVersion", applicationVersion)
                .build()
                .toUri();

        HttpPost post = new HttpPost(requestUri);
        String dataString = OM.writeValueAsString(
                taskDataList.stream().map(WorkerTaskDataWrapper::fromWorkerTaskData).collect(Collectors.toList()));
        post.setEntity(new StringEntity("{\"tasks\":" + dataString + "}", ContentType.APPLICATION_JSON));
        return post;
    }

    private class PingerTask extends TimerTask {
        @Override
        public void run() {
            for (WorkerHostInfo hostInfo : workerHosts) {
                try (CloseableHttpResponse response = httpClient.execute(new HttpGet(hostInfo.workerAddress + "/ping?showInfo=true"))) {
                    if (response.getStatusLine().getStatusCode() != 200) {
                        failure(hostInfo, null);
                    } else {
                        String workerLocalDcName = null;
                        try {
                            JsonNode root = OM.readTree(response.getEntity().getContent());
                            if (root.has("localDC")) {
                                workerLocalDcName = root.get("localDC").textValue();
                            }
                        } catch (Exception e) {
                            log.error("Failed to parse ping response", e);
                        }
                        success(hostInfo, workerLocalDcName);
                    }
                } catch (Exception e) {
                    failure(hostInfo, e);
                }
            }
        }

        private void success(WorkerHostInfo hostInfo, String localDc) {
            if (!hostInfo.alive) {
                hostInfo.alive = true;
                log.info("Worker host {} is alive now", hostInfo.workerAddress);
            }
            localDc = localDc == null ? null : localDc.intern();
            String prevDc = hostInfo.localDc.get();
            // использую !=, потому что имя dc тут всегда будет intern'нутым
            if (prevDc != localDc && hostInfo.localDc.compareAndSet(prevDc, localDc)) {
                log.info("Worker host {} changed local dc {} => {}", hostInfo.workerAddress, prevDc, localDc);
            }
        }

        private void failure(WorkerHostInfo hostInfo, Exception e) {
            if (hostInfo.alive) {
                log.warn("Worker host " + hostInfo.workerAddress + " didn't answer on ping, marking it dead", e);
                hostInfo.alive = false;
            }
        }
    }

    private static class WorkerHostInfo {
        private final URI workerAddress;
        private volatile boolean alive;
        private AtomicReference<String> localDc = new AtomicReference<>();

        public WorkerHostInfo(URI workerAddress) {
            this.workerAddress = workerAddress;
        }

        public boolean isAlive() {
            return alive;
        }
    }

    @Required
    public void setWorkerConductorGroupName(String workerConductorGroupName) {
        this.workerConductorGroupName = workerConductorGroupName;
    }

    @Required
    public void setFallbackHostsString(String fallbackHostsString) {
        this.fallbackHostsString = fallbackHostsString;
    }

    @Required
    public void setApplicationHostname(String applicationHostname) {
        this.applicationHostname = applicationHostname;
    }

    @Required
    public void setApplicationName(String applicationName) {
        this.applicationName = applicationName;
    }

    @Required
    public void setApplicationVersion(String applicationVersion) {
        this.applicationVersion = applicationVersion;
    }

    @Required
    public void setConductorClient(ConductorClient conductorClient) {
        this.conductorClient = conductorClient;
    }

    @Required
    public void setLocalDcName(String localDcName) {
        this.localDcName = localDcName.intern();
    }
}
