package ru.yandex.bannerstorage.harvester.queues.automoderation.services.virustotal;

import java.net.MalformedURLException;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.UrlResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.ResourceAccessException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;

import ru.yandex.bannerstorage.harvester.infrastructure.PassportAuthenticator;
import ru.yandex.bannerstorage.harvester.infrastructure.PassportSession;
import ru.yandex.bannerstorage.harvester.queues.automoderation.services.virustotal.exceptions.InvalidResponseException;
import ru.yandex.bannerstorage.harvester.queues.automoderation.services.virustotal.exceptions.NetworkErrorException;
import ru.yandex.bannerstorage.harvester.queues.automoderation.services.virustotal.exceptions.RequestRateExceededException;

/**
 * Creates {@link VirusTotalClient} instances that talk to the VirusTotal HTTP API
 * ({@code <serviceUrl>/vtapi/v2/file/...}) on behalf of a Passport robot account.
 *
 * @author egorovmv
 */
public final class VirusTotalClientFactory {
    /**
     * Creates an HTTP-backed client that authenticates every request with the robot's Passport session cookies.
     *
     * @param serviceUrl               base URL of the service
     * @param passportAuthenticator    authenticator used to obtain the robot's Passport session
     * @param robotLogin               robot account login
     * @param robotPassword            robot account password
     * @param renewSessionTimeoutInMin minutes an obtained session is reused before authenticating again; must be positive
     * @param connectTimeoutInMs       HTTP connect timeout in milliseconds; must be positive
     * @param readTimeoutInMs          HTTP read timeout in milliseconds; must be positive
     * @return a new client instance
     * @throws NullPointerException     if any reference argument is {@code null}
     * @throws IllegalArgumentException if any timeout is not positive
     */
    public static VirusTotalClient newInstance(
            @NotNull String serviceUrl,
            @NotNull PassportAuthenticator passportAuthenticator,
            @NotNull String robotLogin,
            @NotNull String robotPassword,
            int renewSessionTimeoutInMin,
            int connectTimeoutInMs,
            int readTimeoutInMs) {
        return new HttpVirusTotalClient(
                serviceUrl,
                passportAuthenticator,
                robotLogin,
                robotPassword,
                renewSessionTimeoutInMin,
                connectTimeoutInMs,
                readTimeoutInMs);
    }

    /**
     * {@link VirusTotalClient} implementation on top of Spring's {@link RestTemplate}.
     * <p>
     * Every request is an HTTP POST carrying the robot's Passport session cookies. The session is obtained
     * lazily and memoized for {@code renewSessionTimeoutInMin} minutes before re-authenticating.
     */
    private static class HttpVirusTotalClient implements VirusTotalClient {
        /** Maximum number of resources listed in a single batched request. */
        private static final int MAX_RESOURCES_IN_BATCH = 4;

        private static final Logger logger = LoggerFactory.getLogger(HttpVirusTotalClient.class);

        private final RestTemplate restTemplate;
        /** Base endpoint {@code <serviceUrl>/vtapi/v2/file}; operation names are appended as path segments. */
        private final URI serviceUrl;
        /** Returns the memoized robot session, authenticating again once the memoized value expires. */
        private final Supplier<PassportSession> robotAuthenticator;

        public HttpVirusTotalClient(
                @NotNull String serviceUrl,
                @NotNull PassportAuthenticator passportAuthenticator,
                @NotNull String robotLogin,
                @NotNull String robotPassword,
                int renewSessionTimeoutInMin,
                int connectTimeoutInMs,
                int readTimeoutInMs) {
            Objects.requireNonNull(serviceUrl, "serviceUrl");
            Objects.requireNonNull(passportAuthenticator, "passportAuthenticator");
            Objects.requireNonNull(robotLogin, "robotLogin");
            Objects.requireNonNull(robotPassword, "robotPassword");
            if (renewSessionTimeoutInMin <= 0) {
                throw new IllegalArgumentException("renewSessionTimeoutInMin: " + renewSessionTimeoutInMin);
            }
            if (connectTimeoutInMs <= 0) {
                throw new IllegalArgumentException("connectTimeoutInMs: " + connectTimeoutInMs);
            }
            if (readTimeoutInMs <= 0) {
                throw new IllegalArgumentException("readTimeoutInMs: " + readTimeoutInMs);
            }

            HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory();
            requestFactory.setConnectTimeout(connectTimeoutInMs);
            requestFactory.setReadTimeout(readTimeoutInMs);
            this.restTemplate = new RestTemplate(requestFactory);

            this.serviceUrl = UriComponentsBuilder.fromHttpUrl(serviceUrl)
                    .pathSegment("vtapi/v2/file")
                    .build()
                    .toUri();

            // Authentication happens on first use, not here; the session is then reused until it expires.
            this.robotAuthenticator = Suppliers.memoizeWithExpiration(
                    () -> passportAuthenticator.authenticate(robotLogin, robotPassword),
                    renewSessionTimeoutInMin,
                    TimeUnit.MINUTES);
        }

        /** Builds a form body with a single {@code resource} field holding a comma-separated id list. */
        private static MultiValueMap<String, String> createRequest(@NotNull String resources) {
            MultiValueMap<String, String> result = new LinkedMultiValueMap<>();
            result.add("resource", resources);
            return result;
        }

        /**
         * Builds a multipart body with a single {@code file} part whose content is read from {@code fileUrl}
         * and whose reported file name is {@code fileName}.
         *
         * @throws MalformedURLException if {@code fileUrl} cannot be converted to a URL
         */
        private static MultiValueMap<String, Object> createRequest(
                @NotNull String fileName,
                @NotNull URI fileUrl) throws MalformedURLException {
            MultiValueMap<String, Object> result = new LinkedMultiValueMap<>();
            result.add(
                    "file",
                    new UrlResource(fileUrl) {
                        @Override
                        public String getFilename() {
                            return fileName;
                        }
                    });
            return result;
        }

        /** Builds request headers: content type, JSON accept type and the session's Passport cookies. */
        private static HttpHeaders createHeaders(@NotNull MediaType contentType, @NotNull PassportSession session) {
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(contentType);
            headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));

            headers.add(HttpHeaders.COOKIE, PassportAuthenticator.YANDEX_UID_COOKIE + "=" + session.getYandexUid());
            headers.add(HttpHeaders.COOKIE, PassportAuthenticator.YANDEX_LOGIN_COOKIE + "=" + session.getYandexLogin());
            headers.add(HttpHeaders.COOKIE, PassportAuthenticator.YANDEX_SESSIONID_COOKIE + "=" + session.getSessionId());
            headers.add(HttpHeaders.COOKIE, PassportAuthenticator.YANDEX_SESSIONID2_COOKIE + "=" + session.getSessionId2());
            return headers;
        }

        /**
         * Converts a possibly-null response array into a list. A missing body becomes an empty list so that
         * batch validation reports it as {@link InvalidResponseException} instead of failing with an NPE.
         */
        private static <T> List<T> nullSafeList(T[] items) {
            if (items == null) {
                return Collections.emptyList();
            }
            return Arrays.asList(items);
        }

        /**
         * POSTs {@code requestBody} to {@code <serviceUrl>/<path>} with the robot's session cookies and returns
         * the deserialized response body.
         *
         * @return the response body; {@code null} if the service returned no body
         * @throws RequestRateExceededException if the service answers {@code 204 No Content}
         * @throws NetworkErrorException        if the request fails with an I/O error or a 5xx response;
         *                                      other {@code RestTemplate} exceptions propagate unchanged
         */
        private <Request, Response> Response executeRequest(
                @NotNull String path,
                @NotNull MediaType contentType,
                @NotNull Request requestBody,
                @NotNull Class<Response> responseClazz) {
            URI targetUrl = UriComponentsBuilder.fromUri(serviceUrl)
                    .pathSegment(path)
                    .build()
                    .toUri();

            HttpHeaders headers = createHeaders(contentType, robotAuthenticator.get());

            // Correlates the request and response log lines.
            UUID requestId = UUID.randomUUID();

            try {
                logger.debug(
                        "Sending request (RequestId: {}, TargetUrl: {}, Request: {})",
                        requestId, targetUrl, requestBody);

                ResponseEntity<Response> responseEntity = restTemplate.exchange(
                        targetUrl,
                        HttpMethod.POST,
                        new HttpEntity<>(requestBody, headers),
                        responseClazz);
                // The service signals an exceeded request rate with 204 No Content.
                if (responseEntity.getStatusCode().equals(HttpStatus.NO_CONTENT)) {
                    throw new RequestRateExceededException();
                }

                Response response = responseEntity.getBody();

                logger.debug("Got response (RequestId: {}, Response: {})", requestId, response);

                return response;
            } catch (ResourceAccessException | HttpServerErrorException e) {
                throw new NetworkErrorException(e);
            }
        }

        /**
         * Splits {@code resourceIds} into batches of at most {@link #MAX_RESOURCES_IN_BATCH}, runs one request
         * per batch and maps each id to the result at the same position in that batch's response.
         * A result for a repeated id does not overwrite an earlier result for the same id.
         *
         * @param resourceIds     ids to query; the first id's length determines the padding id's length
         * @param requestExecutor performs a request for a comma-separated id list and returns its results
         * @return results keyed by resource id; empty if {@code resourceIds} is empty
         * @throws InvalidResponseException if a batch response is missing or shorter than the batch
         */
        private <Response> Map<String, Response> executeBatch(
                @NotNull List<String> resourceIds,
                @NotNull Function<String, List<Response>> requestExecutor) {
            if (resourceIds.isEmpty()) {
                return Collections.emptyMap();
            }

            Map<String, Response> result = new HashMap<>(resourceIds.size());

            // All-zero id with the same length as the real ids, used only to pad short batches.
            String fakeResourceId = StringUtils.repeat('0', resourceIds.get(0).length());

            for (int from = 0; from < resourceIds.size(); from += MAX_RESOURCES_IN_BATCH) {
                List<String> batch = resourceIds.subList(
                        from, Math.min(from + MAX_RESOURCES_IN_BATCH, resourceIds.size()));

                String resources = batch
                        .stream()
                        .collect(Collectors.joining(", "));
                if (batch.size() < MAX_RESOURCES_IN_BATCH) {
                    // Padding ensures the response is always a JSON array.
                    // NOTE(review): presumably a single-resource request yields a bare object; confirm.
                    resources += ", " + fakeResourceId;
                }

                List<Response> response = requestExecutor.apply(resources);
                if (response == null || response.size() < batch.size()) {
                    throw new InvalidResponseException();
                }

                // Results must come back in the same order as the ids were listed in the request;
                // otherwise they could not be matched to their ids.
                for (int j = 0; j < batch.size(); j++) {
                    result.putIfAbsent(batch.get(j), response.get(j));
                }
            }

            return result;
        }

        @Override
        public Map<String, VirusTotalScanResult> getScanResult(@NotNull List<String> resourceIds) {
            Objects.requireNonNull(resourceIds, "resourceIds");
            return executeBatch(
                    resourceIds,
                    resources -> nullSafeList(
                            executeRequest(
                                    "report",
                                    MediaType.APPLICATION_FORM_URLENCODED,
                                    createRequest(resources),
                                    VirusTotalScanResult[].class)));
        }

        /**
         * Uploads the file at {@code fileUrl} for scanning.
         *
         * @return the service response, or {@code null} if {@code fileUrl} cannot be converted to a URL
         *         (logged as a warning) or the service returned no body
         */
        @Override
        public VirusTotalSendScanResult sendForScan(@NotNull String fileName, @NotNull URI fileUrl) {
            Objects.requireNonNull(fileName, "fileName");
            Objects.requireNonNull(fileUrl, "fileUrl");

            MultiValueMap<String, Object> request;
            try {
                request = createRequest(fileName, fileUrl);
            } catch (MalformedURLException e) {
                // Keep the existing null-result contract, but do not lose the failure silently.
                logger.warn(
                        "Can't send file for scan: malformed file url (FileName: {}, FileUrl: {})",
                        fileName, fileUrl, e);
                return null;
            }

            return executeRequest(
                    "scan",
                    MediaType.MULTIPART_FORM_DATA,
                    request,
                    VirusTotalSendScanResult.class);
        }

        @Override
        public Map<String, VirusTotalSendScanResult> sendForRescanByMd5Hashes(@NotNull List<String> md5Hashes) {
            Objects.requireNonNull(md5Hashes, "md5Hashes");
            return executeBatch(
                    md5Hashes,
                    resources -> nullSafeList(
                            executeRequest(
                                    "rescan",
                                    MediaType.APPLICATION_FORM_URLENCODED,
                                    createRequest(resources),
                                    VirusTotalSendScanResult[].class)));
        }

    }

}
