package ru.yandex.bannerstorage.harvester.queues.automoderation;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.io.UnsupportedEncodingException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Date;
import java.util.GregorianCalendar;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import javax.validation.constraints.NotNull;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.commons.dbcp2.BasicDataSource;
import org.apache.commons.lang3.tuple.Pair;
import org.hibernate.validator.constraints.NotEmpty;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bannerstorage.harvester.infrastructure.AsyncEmailMessageSender;
import ru.yandex.bannerstorage.harvester.queues.automoderation.services.virustotal.VirusTotalClient;
import ru.yandex.bannerstorage.harvester.queues.automoderation.services.virustotal.VirusTotalScanResult;

import static java.util.Collections.singletonList;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

/**
 * Polls VirusTotal for the results of previously submitted file scans and sends a warning
 * e-mail when a creative version turns out to contain files with suspicious detections.
 *
 * <p>A background executor calls {@link #run()} every {@code pollIntervalInMS} milliseconds.
 * Each call picks the creative version whose unfinished scan has the earliest due
 * {@code next_attempt_datetime}, locks that version's rows of
 * {@code t_creative_version_virustotal_scan} and, in a single transaction:
 * <ol>
 *     <li>requests the results of the version's unfinished scans from VirusTotal;</li>
 *     <li>marks every polled scan as finished, or reschedules it with a backoff of
 *     2^attempts minutes (recording an error text when something went wrong);</li>
 *     <li>once all scans of the version are finished, looks for detections that are not
 *     covered by {@code t_virustotal_whitelist} and e-mails them to the recipients.</li>
 * </ol>
 *
 * @author elwood
 */
public class VirusTotalPollQueueObserver implements AutoCloseable {
    private final BasicDataSource dataSource;
    private final VirusTotalClient virusTotalClient;
    private final AsyncEmailMessageSender emailMessageSender;
    private final String[] emailRecipients;
    /** Base URL of the premoderation UI, used to build the creative link in warning e-mails. */
    private final String premoderationUrl;

    private final ScheduledThreadPoolExecutor executor;
    private final ObjectMapper objectMapper;

    /**
     * (antivirus, virus) combinations that are ignored when looking for detections.
     * Loaded from {@code t_virustotal_whitelist} and cached for 10 minutes.
     */
    private final Supplier<List<Pair<String, String>>> whiteListSupplier;

    private final Logger logger = LoggerFactory.getLogger(VirusTotalPollQueueObserver.class);

    // Using time zone = MSK to avoid errors when saving datetimes via JDBC
    private static final String TIMEZONE_MSK = "GMT+3";

    /**
     * Creates the observer and immediately starts polling in a background executor.
     *
     * @param dataSource         database holding the scan queue and the whitelist
     * @param virusTotalClient   client used to fetch scan results
     * @param emailMessageSender sender of warning e-mails
     * @param emailRecipients    recipients of warning e-mails
     * @param premoderationUrl   base URL of the premoderation UI (used for links in e-mails)
     * @param pollIntervalInMS   period between poll starts, in milliseconds
     */
    public VirusTotalPollQueueObserver(@NotNull BasicDataSource dataSource,
                                       @NotNull VirusTotalClient virusTotalClient,
                                       @NotNull AsyncEmailMessageSender emailMessageSender,
                                       @NotNull String[] emailRecipients,
                                       @NotNull String premoderationUrl,
                                       int pollIntervalInMS) {
        this.dataSource = dataSource;
        this.virusTotalClient = virusTotalClient;
        this.emailMessageSender = emailMessageSender;
        this.emailRecipients = emailRecipients;
        this.premoderationUrl = premoderationUrl;
        // Scan results may contain fields VirusTotalScanResult does not model
        this.objectMapper = new ObjectMapper()
                .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        // Re-read the whitelist from the database at most once per 10 minutes
        this.whiteListSupplier = Suppliers.memoizeWithExpiration(() -> {
            try (Connection conn = dataSource.getConnection()) {
                try (PreparedStatement st =
                             conn.prepareStatement(
                                     "SELECT antivirus, virus" +
                                             " FROM t_virustotal_whitelist")) {
                    List<Pair<String, String>> pairs = new ArrayList<>();
                    try (ResultSet rs = st.executeQuery()) {
                        while (rs.next()) {
                            pairs.add(Pair.of(rs.getString("antivirus"), rs.getString("virus")));
                        }
                    }
                    return pairs;
                }
            } catch (SQLException e) {
                throw new RuntimeException(e);
            }
        }, 10, TimeUnit.MINUTES);
        // All fields are assigned before the first run is scheduled
        executor = new ScheduledThreadPoolExecutor(4, new ThreadFactoryBuilder()
                .setNameFormat("virustotal-poll-queue-observer-%d").build());
        executor.scheduleAtFixedRate(this::run, 0, pollIntervalInMS, TimeUnit.MILLISECONDS);
    }

    /**
     * Processes one batch of the queue in its own transaction. It commits on success and rolls
     * back on any {@link Exception} or {@link AssertionError}.
     *
     * <p>Never throws: an exception escaping a periodic task would suppress all its subsequent
     * executions.
     */
    private void run() {
        try {
            try (Connection conn = dataSource.getConnection()) {
                conn.setAutoCommit(false);
                try {
                    work(conn);
                    conn.commit();
                } catch (Exception | AssertionError e) {
                    logger.error("Unexpected exception, transaction will be rolled back", e);
                    conn.rollback();
                }
            }
        } catch (SQLException e) {
            logger.error("Unexpected SQL exception", e);
        } catch (Throwable e) {
            // Catch Throwable so that an unforeseen exception or AssertionError doesn't kill the thread
            logger.error("Unexpected exception", e);
        }
    }

    /**
     * Polls VirusTotal for the next due creative version, saves the updated scan rows and, when
     * all of the version's scans are finished, e-mails any non-whitelisted detections.
     *
     * @param conn connection with an open transaction (auto-commit disabled)
     * @throws SQLException if reading or updating the scan rows fails
     */
    private void work(Connection conn) throws SQLException {
        List<VirustotalScanInfo> allScans = selectNextScans(conn);
        if (allScans.isEmpty()) {
            return;
        }
        int creativeVersionId = allScans.get(0).getCreativeVersionId();

        // Scans whose result has not been received yet
        List<VirustotalScanInfo> incompleteScans = allScans.stream()
                .filter(s -> s.getVirustotalScanFinished() == null)
                .collect(toList());
        if (incompleteScans.isEmpty()) {
            // Nothing to poll. Don't re-run the detection check (and possibly re-send the e-mail)
            // for a creative version whose scans are all finished already.
            return;
        }

        // Updated copies of the polled scans, keyed by row id (nmb). Keying by row id rather than by
        // VirusTotal resource id keeps rows that share a resource id from colliding.
        Map<Integer, VirustotalScanInfo> scansToSave = new HashMap<>();
        try {
            logger.info("Polling {} incompleted scans (of {} total) for creative version {}",
                    incompleteScans.size(), allScans.size(), creativeVersionId);

            // Request every resource only once, even if several rows share it
            List<String> resourceIds = incompleteScans.stream()
                    .map(VirustotalScanInfo::getVirustotalResourceId)
                    .distinct()
                    .collect(toList());
            Map<String, VirusTotalScanResult> scanResults = virusTotalClient.getScanResult(resourceIds);

            // A successful result marks the scan as finished. Errors, "not found" and not-yet-ready
            // results reschedule the scan for a later attempt.
            for (VirustotalScanInfo scanInfo : incompleteScans) {
                VirusTotalScanResult scanResult = scanResults.get(scanInfo.getVirustotalResourceId());
                VirustotalScanInfo updated;
                if (scanResult == null) {
                    updated = markWithError(scanInfo, "Scan result was not found in VirusTotal response", null);
                } else if (scanResult.isSuccess()) {
                    updated = markComplete(scanInfo, scanResult);
                } else if (scanResult.isError()) {
                    updated = markWithError(scanInfo, "isError flag is set in VirusTotal response", scanResult);
                } else if (scanResult.isNotFound()) {
                    updated = markWithError(scanInfo, "isNotFound flag is set in VirusTotal response", scanResult);
                } else {
                    updated = markIncomplete(scanInfo, scanResult);
                }
                scansToSave.put(scanInfo.getId(), updated);
            }
        } catch (Exception e) {
            logger.error("Got error when retrieving VirusTotal scan results", e);
            // Reschedule only the scans being polled; already finished scans keep their results
            String trace = getExceptionTrace(e);
            for (VirustotalScanInfo scanInfo : incompleteScans) {
                scansToSave.put(scanInfo.getId(), markWithError(scanInfo, trace, null));
            }
        }

        // Persist the updated rows
        for (VirustotalScanInfo scanInfo : scansToSave.values()) {
            save(conn, scanInfo);
        }

        // If every pending scan has now finished, check the results for viruses and notify
        boolean allPendingScansCompleted =
                scansToSave.values().stream()
                        .allMatch(s -> s.getVirustotalScanFinished() != null);
        if (allPendingScansCompleted) {
            // Up-to-date data for every scan of the creative version
            List<VirustotalScanInfo> actualScans = new ArrayList<>();
            actualScans.addAll(allScans.stream().filter(s -> s.getVirustotalScanFinished() != null).collect(toList()));
            actualScans.addAll(scansToSave.values());

            Map<VirustotalScanInfo, List<Pair<String, String>>> detectionsMap =
                    getViruses(objectMapper, actualScans, whiteListSupplier.get());

            // NOTE: the e-mail is handed to the sender before the transaction commits. If the commit
            // fails, these scans are polled again and the e-mail may be sent twice.
            if (!detectionsMap.isEmpty()) {
                logger.info("Found viruses in creative version {}: {}, email will be sent",
                        creativeVersionId, detectionsMap.values());
                sendWarningEmail(conn, detectionsMap);
            }
        } else {
            List<VirustotalScanInfo> pendingScans =
                    scansToSave.values().stream()
                            .filter(s -> s.getVirustotalScanFinished() == null)
                            .collect(toList());
            logger.info(
                    "Not all scans are completed now (only {} of {}) for creative version {}. Next check will be at {}",
                    pendingScans.size(), allScans.size(), creativeVersionId,
                    pendingScans.get(0).getNextAttemptDatetime());
        }
    }

    /**
     * Selects and locks every scan row of the creative version that has an unfinished scan with
     * the earliest {@code next_attempt_datetime} in the past.
     *
     * <p>ROWLOCK+UPDLOCK+READPAST make the pick atomic. Once this transaction has the rows they
     * stay locked, and concurrent transactions skip them instead of picking them up too.
     *
     * @param conn connection with an open transaction
     * @return the locked rows; empty when nothing is due
     * @throws SQLException if the query fails
     */
    private List<VirustotalScanInfo> selectNextScans(Connection conn) throws SQLException {
        List<VirustotalScanInfo> scans = new ArrayList<>();
        try (PreparedStatement st = conn.prepareStatement("SELECT s.*" +
                " FROM t_creative_version_virustotal_scan s WITH ( ROWLOCK, UPDLOCK, READPAST )" +
                " WHERE s.creative_version_nmb IN (" +
                "  SELECT TOP 1 creative_version_nmb" +
                "  FROM t_creative_version_virustotal_scan s WITH ( ROWLOCK, UPDLOCK, READPAST )" +
                "  WHERE virustotal_scan_finished IS NULL AND next_attempt_datetime < CURRENT_TIMESTAMP" +
                "  ORDER BY next_attempt_datetime)"
        );
             ResultSet rs = st.executeQuery()) {
            while (rs.next()) {
                scans.add(readScanInfo(rs));
            }
        }
        return scans;
    }

    /** Maps the current row of a {@code t_creative_version_virustotal_scan} query. */
    private static VirustotalScanInfo readScanInfo(ResultSet rs) throws SQLException {
        VirustotalScanInfo scanInfo = new VirustotalScanInfo();
        scanInfo.setId(rs.getInt("nmb"));
        scanInfo.setCreativeVersionId(rs.getInt("creative_version_nmb"));
        scanInfo.setFileInstanceId(rs.getInt("file_instance_nmb"));
        scanInfo.setSendScanDatetime(rs.getTimestamp("send_scan_datetime"));
        scanInfo.setNextAttemptDatetime(rs.getTimestamp("next_attempt_datetime"));
        scanInfo.setVirustotalSendScanResult(rs.getString("virustotal_send_scan_result"));
        scanInfo.setVirustotalResourceId(rs.getString("virustotal_resource_id"));
        scanInfo.setVirustotalPermalink(rs.getString("virustotal_permalink"));
        scanInfo.setVirustotalScanResult(rs.getString("virustotal_scan_result"));
        scanInfo.setVirustotalLastError(rs.getString("virustotal_last_error"));
        scanInfo.setVirustotalScanFinished(rs.getTimestamp("virustotal_scan_finished"));
        scanInfo.setVirustotalAttempts(rs.getInt("virustotal_attempts"));
        return scanInfo;
    }

    /**
     * Finds detections in the stored VirusTotal results that are not covered by the whitelist.
     *
     * <p>A detection (antivirus, result) is whitelisted when, for some entry, the antivirus name
     * contains the entry's antivirus and the detection result contains the entry's virus. Both
     * are substring matches.
     *
     * @param objectMapper mapper used to parse {@link VirustotalScanInfo#getVirustotalScanResult()}
     * @param scans        scans whose stored JSON result is inspected
     * @param whiteList    (antivirus, virus) substring pairs to ignore
     * @return every scan with at least one non-whitelisted detection, mapped to those
     *         (antivirus, result) pairs; empty when nothing suspicious was found
     */
    public static Map<VirustotalScanInfo, List<Pair<String, String>>> getViruses(
            ObjectMapper objectMapper,
            List<VirustotalScanInfo> scans,
            List<Pair<String, String>> whiteList) {
        Map<VirustotalScanInfo, List<Pair<String, String>>> detectionsMap = new HashMap<>();
        for (VirustotalScanInfo scanInfo : scans) {
            VirusTotalScanResult scanResult = fromJson(objectMapper, scanInfo.getVirustotalScanResult(), VirusTotalScanResult.class);
            if (scanResult.getPositiveCount() > 0) {
                // Detected (antivirus, virus) pairs that are not in the whitelist
                List<Pair<String, String>> detections =
                        scanResult.getScans().entrySet().stream()
                                .filter(e -> e.getValue().isDetected())
                                .map(e -> Pair.of(e.getKey(), e.getValue().getResult()))
                                .filter(pair -> whiteList.stream().noneMatch(
                                        wl -> pair.getLeft().contains(wl.getKey())
                                                && pair.getRight().contains(wl.getValue())))
                                .collect(toList());
                if (!detections.isEmpty()) {
                    detectionsMap.put(scanInfo, detections);
                }
            }
        }
        return detectionsMap;
    }

    /**
     * E-mails the recipients a report listing every suspicious file (name, content URL and
     * detections) and a link to the creative in the premoderation UI.
     *
     * @param conn          connection used to look up file names and URLs
     * @param detectionsMap non-empty map from scan to its non-whitelisted detections; all scans
     *                      are expected to belong to the same creative version
     * @throws SQLException if the file lookup fails
     */
    private void sendWarningEmail(@NotNull Connection conn,
                                  @NotEmpty Map<VirustotalScanInfo, List<Pair<String, String>>> detectionsMap) throws SQLException {
        // For every fileInstanceId, fetch the pair (fileName, contentUrl).
        // The IN list is built only from int ids, so string formatting is injection-safe here.
        Map<Integer, Pair<String, String>> filesMap = new HashMap<>();
        try (PreparedStatement st =
                     conn.prepareStatement(
                             String.format(
                                     "SELECT fi.nmb, fi.file_name, f.stillage_file_url" +
                                             " FROM t_file_instance fi JOIN t_file f ON fi.file_nmb = f.nmb" +
                                             " WHERE fi.nmb IN (%s)",
                                     detectionsMap.keySet().stream()
                                             .map(VirustotalScanInfo::getFileInstanceId)
                                             .map(Object::toString)
                                             .collect(joining(","))))) {
            try (ResultSet rs = st.executeQuery()) {
                while (rs.next()) {
                    filesMap.put(rs.getInt("nmb"), Pair.of(rs.getString("file_name"), rs.getString("stillage_file_url")));
                }
            }
        }

        StringBuilder sb = new StringBuilder();
        sb.append("Внимание ! В креативе были обнаружены подозрительные файлы:\n\n");
        for (Map.Entry<VirustotalScanInfo, List<Pair<String, String>>> entry : detectionsMap.entrySet()) {
            VirustotalScanInfo scanInfo = entry.getKey();
            List<Pair<String, String>> detections = entry.getValue();
            Pair<String, String> file = filesMap.get(scanInfo.getFileInstanceId());
            String fileName;
            String contentUrl;
            if (file != null) {
                fileName = file.getLeft();
                contentUrl = file.getRight();
            } else {
                // Still report the detection instead of failing with an NPE. That would roll back
                // the whole transaction, and the same scans would be picked up again on the next poll.
                logger.warn("File instance {} was not found, reporting it without file name and URL",
                        scanInfo.getFileInstanceId());
                fileName = "file_instance_nmb=" + scanInfo.getFileInstanceId();
                contentUrl = "";
            }
            sb.append(String.format("Файл %s %s :\n", fileName, contentUrl));
            for (Pair<String, String> detection : detections) {
                sb.append(String.format(" -> найден %s (антивирус %s)\n", detection.getRight(), detection.getLeft()));
            }
        }
        int creativeVersionId = detectionsMap.keySet().iterator().next().getCreativeVersionId();
        sb.append("\nСсылка на креатив ")
                .append(premoderationUrl).append("/creative/index.jsp?nmb=")
                .append(creativeVersionId)
                .append("\n\n");
        sb.append("Креатив НЕ БУДЕТ отклонён автоматически\n");
        sb.append("Если креатив нужно отклонить, вы можете это сделать в интерфейсе модерации");

        emailMessageSender.sendMessage(
                emailRecipients,
                "Обнаружен креатив с подозрительными файлами",
                sb.toString());
    }

    /**
     * Returns a copy of the scan rescheduled after a failed attempt.
     *
     * @param scanInfo   scan to copy; not modified
     * @param error      text stored in {@code virustotal_last_error}
     * @param scanResult raw VirusTotal result to store, or {@code null} to keep the current one
     * @return the copy with the error set, the backoff applied and the attempt counter incremented
     */
    private VirustotalScanInfo markWithError(VirustotalScanInfo scanInfo,
                                             String error,
                                             @Nullable VirusTotalScanResult scanResult) {
        VirustotalScanInfo copy = new VirustotalScanInfo(scanInfo);
        copy.setVirustotalLastError(error);
        if (scanResult != null) {
            copy.setVirustotalScanResult(toJson(scanResult));
        }
        copy.setNextAttemptDatetime(nextAttemptDate(scanInfo.getVirustotalAttempts()));
        copy.setVirustotalAttempts(scanInfo.getVirustotalAttempts() + 1);
        return copy;
    }

    /**
     * Returns a copy of the scan rescheduled because VirusTotal has no final result yet. The
     * copy stores the interim result and clears the last error.
     */
    private VirustotalScanInfo markIncomplete(VirustotalScanInfo scanInfo,
                                              VirusTotalScanResult scanResult) {
        VirustotalScanInfo copy = new VirustotalScanInfo(scanInfo);
        copy.setNextAttemptDatetime(nextAttemptDate(scanInfo.getVirustotalAttempts()));
        copy.setVirustotalAttempts(scanInfo.getVirustotalAttempts() + 1);
        copy.setVirustotalLastError(null);
        copy.setVirustotalScanResult(toJson(scanResult));
        return copy;
    }

    /** Returns a copy of the scan marked as finished now, storing the final result. */
    private VirustotalScanInfo markComplete(VirustotalScanInfo scanInfo,
                                            VirusTotalScanResult scanResult) {
        VirustotalScanInfo copy = new VirustotalScanInfo(scanInfo);
        copy.setVirustotalScanFinished(new Date());
        copy.setVirustotalScanResult(toJson(scanResult));
        copy.setVirustotalLastError(null);
        return copy;
    }

    /**
     * Returns "now + 2^attempts minutes". This is the exponential backoff before the next poll
     * of a scan that has already been attempted {@code attempts} times.
     */
    private static Date nextAttemptDate(int attempts) {
        GregorianCalendar calendar = new GregorianCalendar();
        // The double -> int cast saturates at Integer.MAX_VALUE for very large attempt counts
        calendar.add(Calendar.MINUTE, (int) Math.pow(2, attempts));
        return calendar.getTime();
    }

    /**
     * Writes the mutable fields of the scan back to {@code t_creative_version_virustotal_scan}.
     * Datetimes are written through an MSK calendar.
     */
    private void save(Connection conn, VirustotalScanInfo scanInfo) throws SQLException {
        GregorianCalendar calendar = new GregorianCalendar(TimeZone.getTimeZone(TIMEZONE_MSK));
        try (PreparedStatement st =
                     conn.prepareStatement(
                             "UPDATE t_creative_version_virustotal_scan" +
                                     " SET" +
                                     "  next_attempt_datetime = ?," +
                                     "  virustotal_scan_result = ?," +
                                     "  virustotal_last_error = ?," +
                                     "  virustotal_scan_finished = ?," +
                                     "  virustotal_attempts = ?" +
                                     " WHERE nmb = ?")) {
            st.setTimestamp(1,
                    new java.sql.Timestamp(scanInfo.getNextAttemptDatetime().getTime()),
                    calendar);
            st.setString(2, scanInfo.getVirustotalScanResult());
            st.setString(3, scanInfo.getVirustotalLastError());
            if (scanInfo.getVirustotalScanFinished() != null) {
                st.setTimestamp(4,
                        new java.sql.Timestamp(scanInfo.getVirustotalScanFinished().getTime()),
                        calendar);
            } else {
                st.setTimestamp(4, null);
            }
            st.setInt(5, scanInfo.getVirustotalAttempts());
            st.setInt(6, scanInfo.getId());

            st.executeUpdate();
        }
    }

    /**
     * Renders the stack trace of {@code e} as a string. It uses a character stream end to end,
     * so no byte/charset conversion can garble non-ASCII messages.
     */
    private String getExceptionTrace(@NotNull Exception e) {
        StringWriter stringWriter = new StringWriter();
        try (PrintWriter printWriter = new PrintWriter(stringWriter)) {
            e.printStackTrace(printWriter);
        }
        return stringWriter.toString();
    }

    /** Stops scheduling new polls. A poll that is already running is not waited for. */
    @Override
    public void close() throws Exception {
        executor.shutdown();
    }

    /** Mutable in-memory copy of one {@code t_creative_version_virustotal_scan} row. */
    public static class VirustotalScanInfo {
        private int id;
        private int creativeVersionId;
        private int fileInstanceId;
        private Date sendScanDatetime;
        private Date nextAttemptDatetime;
        private String virustotalSendScanResult;
        private String virustotalResourceId;
        private String virustotalPermalink;
        private String virustotalScanResult;
        private String virustotalLastError;
        private Date virustotalScanFinished;
        private int virustotalAttempts;

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        public int getCreativeVersionId() {
            return creativeVersionId;
        }

        public void setCreativeVersionId(int creativeVersionId) {
            this.creativeVersionId = creativeVersionId;
        }

        public int getFileInstanceId() {
            return fileInstanceId;
        }

        public void setFileInstanceId(int fileInstanceId) {
            this.fileInstanceId = fileInstanceId;
        }

        public Date getSendScanDatetime() {
            return sendScanDatetime;
        }

        public void setSendScanDatetime(Date sendScanDatetime) {
            this.sendScanDatetime = sendScanDatetime;
        }

        public Date getNextAttemptDatetime() {
            return nextAttemptDatetime;
        }

        public void setNextAttemptDatetime(Date nextAttemptDatetime) {
            this.nextAttemptDatetime = nextAttemptDatetime;
        }

        public String getVirustotalSendScanResult() {
            return virustotalSendScanResult;
        }

        public void setVirustotalSendScanResult(String virustotalSendScanResult) {
            this.virustotalSendScanResult = virustotalSendScanResult;
        }

        public String getVirustotalResourceId() {
            return virustotalResourceId;
        }

        public void setVirustotalResourceId(String virustotalResourceId) {
            this.virustotalResourceId = virustotalResourceId;
        }

        public String getVirustotalPermalink() {
            return virustotalPermalink;
        }

        public void setVirustotalPermalink(String virustotalPermalink) {
            this.virustotalPermalink = virustotalPermalink;
        }

        public String getVirustotalScanResult() {
            return virustotalScanResult;
        }

        public void setVirustotalScanResult(String virustotalScanResult) {
            this.virustotalScanResult = virustotalScanResult;
        }

        public String getVirustotalLastError() {
            return virustotalLastError;
        }

        public void setVirustotalLastError(String virustotalLastError) {
            this.virustotalLastError = virustotalLastError;
        }

        public Date getVirustotalScanFinished() {
            return virustotalScanFinished;
        }

        public void setVirustotalScanFinished(Date virustotalScanFinished) {
            this.virustotalScanFinished = virustotalScanFinished;
        }

        public int getVirustotalAttempts() {
            return virustotalAttempts;
        }

        public void setVirustotalAttempts(int virustotalAttempts) {
            this.virustotalAttempts = virustotalAttempts;
        }

        public VirustotalScanInfo() {
        }

        /**
         * Field-by-field copy constructor. Date fields are shared with {@code o}, not cloned.
         */
        public VirustotalScanInfo(@NotNull VirustotalScanInfo o) {
            this.id = o.id;
            this.creativeVersionId = o.creativeVersionId;
            this.fileInstanceId = o.fileInstanceId;
            this.sendScanDatetime = o.sendScanDatetime;
            this.nextAttemptDatetime = o.nextAttemptDatetime;
            this.virustotalSendScanResult = o.virustotalSendScanResult;
            this.virustotalResourceId = o.virustotalResourceId;
            this.virustotalPermalink = o.virustotalPermalink;
            this.virustotalScanResult = o.virustotalScanResult;
            this.virustotalLastError = o.virustotalLastError;
            this.virustotalScanFinished = o.virustotalScanFinished;
            this.virustotalAttempts = o.virustotalAttempts;
        }
    }

    /** Parses {@code json} into {@code clazz}, rethrowing parse failures as unchecked. */
    private static <T> T fromJson(ObjectMapper objectMapper, String json, Class<T> clazz) {
        try {
            return objectMapper.readValue(json, clazz);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Serializes {@code o} to JSON, rethrowing serialization failures as unchecked. */
    private String toJson(Object o) {
        try {
            return objectMapper.writeValueAsString(o);
        } catch (JsonProcessingException e) {
            throw new UncheckedIOException(e);
        }
    }
}
