package ru.yandex.travel.commons.logging.ydb;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.Collectors;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.yandex.ydb.core.Issue;
import com.yandex.ydb.core.Result;
import com.yandex.ydb.core.StatusCode;
import com.yandex.ydb.core.UnexpectedResultException;
import com.yandex.ydb.core.auth.AuthProvider;
import com.yandex.ydb.table.SessionRetryContext;
import com.yandex.ydb.table.SessionSupplier;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.settings.ExecuteDataQuerySettings;
import com.yandex.ydb.table.transaction.TxControl;
import com.yandex.ydb.table.values.ListType;
import com.yandex.ydb.table.values.ListValue;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.PrimitiveValue;
import com.yandex.ydb.table.values.StructType;
import lombok.extern.slf4j.Slf4j;

import ru.yandex.misc.lang.StringUtils;
import ru.yandex.travel.logging.ydb.TOrderLogRecord;

import static java.util.stream.Collectors.toList;

/**
 * Client for the YDB table that stores order log records: batch inserts, paginated selection and
 * counting of the records that belong to a set of owners.
 *
 * <p>Every operation goes through a freshly built {@link SessionRetryContext} configured with the
 * retry settings passed to the constructor.
 *
 * <p>Inspired by https://a.yandex-team.ru/arc/trunk/arcadia/kikimr/public/sdk/java/examples/src/main/java/com/yandex/ydb/examples/basic_example/BasicExampleApp.java
 */
@Slf4j
public class YdbLogTableClient implements AutoCloseable {
    /** Row type of the {@code $records} list passed to the insert query. */
    public static final StructType LOG_RECORD_TYPE = StructType.of(Map.of(
            "owner_uuid", PrimitiveType.utf8(),
            "timestamp", PrimitiveType.uint64(),
            "message_id", PrimitiveType.utf8(),
            "logger", PrimitiveType.utf8(),
            "level", PrimitiveType.utf8(),
            "message", PrimitiveType.utf8(),
            "context", PrimitiveType.json()
    ));

    /** Element type of the {@code $uuidList} parameter shared by the select and count queries. */
    private static final StructType OWNER_ID_TYPE = StructType.of(Map.of("owner_uuid", PrimitiveType.utf8()));

    /** Page size used by {@link #getLogRecords} when the caller passes no limit. */
    private static final int DEFAULT_LIMIT = 100;

    /** Page offset used by {@link #getLogRecords} when the caller passes no offset. */
    private static final int DEFAULT_OFFSET = 0;

    private final AuthProvider authProvider;
    private final TableClient tableClient;
    private final String insertLogRecordQuery;
    private final String logRecordsDeclareBase;
    private final String logRecordsSelectBase;
    private final String logRecordsCountBase;
    private final Duration timeout;
    private final int maxAttempts;
    private final Duration backoffSlot;
    private final int backoffCeiling;
    private final ObjectMapper mapper;

    /**
     * Creates a client bound to one log table.
     *
     * @param authProvider   closed together with this client in {@link #close()}
     * @param tableClient    source of YDB sessions; closed together with this client in {@link #close()}
     * @param table          table name substituted verbatim (inside backticks) into every query, so it
     *                       must come from trusted configuration, never from user input
     * @param timeout        session supply timeout of each retry context
     * @param maxAttempts    total number of attempts per operation, including the first one; must be positive
     * @param backoffSlot    backoff slot used for both the regular and the fast retry backoff
     * @param backoffCeiling backoff ceiling used for both the regular and the fast retry backoff; must be positive
     * @throws IllegalArgumentException if {@code maxAttempts} or {@code backoffCeiling} is not positive
     */
    public YdbLogTableClient(AuthProvider authProvider, TableClient tableClient, String table,
                             Duration timeout, int maxAttempts, Duration backoffSlot, int backoffCeiling) {
        Preconditions.checkArgument(maxAttempts > 0, "maxAttempts should be a positive number: %s", maxAttempts);
        Preconditions.checkArgument(backoffCeiling > 0, "backoffCeiling should be a positive number: %s", backoffCeiling);

        this.mapper = new ObjectMapper();
        this.authProvider = authProvider;
        this.tableClient = tableClient;
        this.timeout = timeout;
        this.maxAttempts = maxAttempts;
        this.backoffSlot = backoffSlot;
        this.backoffCeiling = backoffCeiling;

        this.insertLogRecordQuery = String.format("" +
                        "DECLARE $records as \"List<Struct<\n" +
                        "    owner_uuid: Utf8,\n" +
                        "    timestamp: Uint64,\n" +
                        "    message_id: Utf8,\n" +
                        "    logger: Utf8,\n" +
                        "    level: Utf8,\n" +
                        "    message: Utf8,\n" +
                        "    context: Json\n" +
                        ">>\";\n" +
                        // REPLACE rather than INSERT: a write interrupted on the client side may be retried
                        // after it has actually been applied, and a plain INSERT would then fail on the rows
                        // that already exist, while REPLACE simply overwrites them
                        "replace into `%s`(owner_uuid, timestamp, message_id, logger, level, message, context)\n" +
                        "select owner_uuid, timestamp, message_id, logger, level, message, context from as_table($records);\n",
                table);

        this.logRecordsDeclareBase = "" +
                "DECLARE $uuidList AS \"List<Struct<owner_uuid: Utf8>>\";\n";

        this.logRecordsSelectBase = String.format("\n" +
                        "SELECT t.owner_uuid AS owner_uuid, t.timestamp AS dtm, t.message_id AS message_id, " +
                        "t.context AS context, t.level AS level, t.logger AS logger, t.message AS message " +
                        " FROM AS_TABLE($uuidList) AS l INNER JOIN `%s` AS t ON t.owner_uuid = l.owner_uuid\n",
                table);

        this.logRecordsCountBase = String.format("\n" +
                "SELECT count(t.message_id) AS logs_count " +
                "  FROM AS_TABLE($uuidList) AS l INNER JOIN `%s` AS t on t.owner_uuid = l.owner_uuid\n",
                table);
    }

    /**
     * Writes the given records in a single serializable read-write transaction that is committed
     * by the query itself. Rows with an existing primary key are overwritten (the query uses REPLACE).
     *
     * @param records records to write; an empty list completes immediately without contacting YDB
     * @return a future completed with {@code null} once the query succeeds, or completed exceptionally
     *         when {@code Result.expect} rejects the final result
     */
    public CompletableFuture<Void> insertLogRecords(List<TOrderLogRecord> records) {
        if (records.isEmpty()) {
            // nothing to write: skip the session acquisition and the round trip entirely
            return CompletableFuture.completedFuture(null);
        }
        SessionRetryContext ctx = createContext();

        TxControl<?> txControl = TxControl.serializableRw().setCommitTx(true);
        Params params = makeParams(records);
        ExecuteDataQuerySettings settings = new ExecuteDataQuerySettings().keepInQueryCache();

        return ctx.supplyResult(session -> session.executeDataQuery(insertLogRecordQuery, txControl, params, settings))
                .thenApply(r -> r.expect("failed to insert the messages"))
                .thenApply(r -> null);
    }

    /**
     * Selects one page of log records belonging to the given owners, newest first (by timestamp).
     *
     * <p>{@code level} is matched exactly; {@code logger} and {@code searchText} (matched against the
     * message) are substring filters built with {@code LIKE}. {@code %} and {@code _} inside them are
     * not escaped and therefore act as wildcards. Blank filters are ignored.
     *
     * @param listOfOwnerIds owners whose records are selected; duplicate ids are ignored
     * @param level          exact level to match, or blank for any
     * @param logger         logger substring, or blank for any
     * @param offset         number of records to skip; {@code null} means {@value #DEFAULT_OFFSET}
     * @param limit          maximum number of records to return; {@code null} means {@value #DEFAULT_LIMIT}
     * @param searchText     message substring, or blank for any
     * @return a future with the selected records in query order
     * @throws IllegalArgumentException if {@code offset} or {@code limit} is negative
     */
    public CompletableFuture<List<TOrderLogRecord>> getLogRecords(List<String> listOfOwnerIds, String level, String logger,
                                                                  Integer offset, Integer limit, String searchText) {
        int effectiveLimit = limit == null ? DEFAULT_LIMIT : limit;
        int effectiveOffset = offset == null ? DEFAULT_OFFSET : offset;
        // both values are sent as Uint32, so a negative int would be reinterpreted as a huge unsigned number
        Preconditions.checkArgument(effectiveLimit >= 0, "limit should be a non-negative number: %s", effectiveLimit);
        Preconditions.checkArgument(effectiveOffset >= 0, "offset should be a non-negative number: %s", effectiveOffset);

        SessionRetryContext ctx = createContext();
        TxControl<?> txControl = TxControl.staleRo();

        Params params = Params.create();
        params.put("$uuidList", makeOwnerIdsValue(listOfOwnerIds));
        params.put("$limit", PrimitiveValue.uint32(effectiveLimit));
        params.put("$offset", PrimitiveValue.uint32(effectiveOffset));

        StringBuilder declareBuilder = new StringBuilder(this.logRecordsDeclareBase);
        declareBuilder.append("DECLARE $limit AS Uint32;\n");
        declareBuilder.append("DECLARE $offset AS Uint32;\n");
        StringBuilder queryBuilder = new StringBuilder(this.logRecordsSelectBase);

        addFiltersToQuery(level, logger, searchText, params, declareBuilder, queryBuilder);

        String logRecordsSelectQuery = "" + declareBuilder +
                queryBuilder +
                "    ORDER BY dtm DESC LIMIT $limit OFFSET $offset;";
        ExecuteDataQuerySettings settings = new ExecuteDataQuerySettings().keepInQueryCache();
        return ctx.supplyResult(session -> session.executeDataQuery(logRecordsSelectQuery, txControl, params, settings))
                .thenApply(r -> r.expect("failed to select the messages"))
                .thenApply(r -> mapToLogRecord(r.getResultSet(0)));
    }

    /**
     * Counts the log records of the given owners that pass the same filters as {@link #getLogRecords}.
     *
     * @param setOfOwnerIds owners whose records are counted
     * @param level         exact level to match, or blank for any
     * @param logger        logger substring, or blank for any
     * @param searchText    message substring, or blank for any
     * @return a future with the number of matching records, or {@code 0} if the query returned no row
     */
    public CompletableFuture<Long> countLogRecords(Set<String> setOfOwnerIds, String level, String logger, String searchText) {
        SessionRetryContext ctx = createContext();
        TxControl<?> txControl = TxControl.staleRo();

        Params params = Params.create();
        params.put("$uuidList", makeOwnerIdsValue(setOfOwnerIds));

        StringBuilder declareBuilder = new StringBuilder(this.logRecordsDeclareBase);
        StringBuilder queryBuilder = new StringBuilder(this.logRecordsCountBase);

        addFiltersToQuery(level, logger, searchText, params, declareBuilder, queryBuilder);

        String logRecordsCountQuery = "" + declareBuilder + queryBuilder + ";";

        ExecuteDataQuerySettings settings = new ExecuteDataQuerySettings().keepInQueryCache();
        return ctx.supplyResult(session -> session.executeDataQuery(logRecordsCountQuery, txControl, params, settings))
                .thenApply(r -> r.expect("failed to select the messages"))
                .thenApply(r -> {
                    ResultSetReader resultSet = r.getResultSet(0);
                    if (resultSet.next()) {
                        return resultSet.getColumn("logs_count").getUint64();
                    } else {
                        return 0L;
                    }
                });
    }

    /**
     * Builds the {@code $uuidList} parameter value: one {@code owner_uuid} struct per distinct id.
     * Deduplication keeps the join from repeating rows when the same owner is passed twice.
     */
    private static ListValue makeOwnerIdsValue(Collection<String> ownerIds) {
        return ListType.of(OWNER_ID_TYPE).newValue(ownerIds.stream()
                .distinct()
                .map(id -> OWNER_ID_TYPE.newValue(Map.of("owner_uuid", PrimitiveValue.utf8(id))))
                .collect(Collectors.toUnmodifiableList()));
    }

    /**
     * Appends the optional level / logger / message filters to a query.
     *
     * <p>For every non-blank filter, this method adds a {@code DECLARE} to {@code declareBuilder}, binds
     * the value in {@code params}, and collects a condition. The conditions are then appended to
     * {@code queryBuilder} as a single {@code WHERE ... AND ...} clause. Nothing is appended when every
     * filter is blank.
     */
    private void addFiltersToQuery(String level, String logger, String searchText, Params params,
                                   StringBuilder declareBuilder, StringBuilder queryBuilder) {
        List<String> filters = new ArrayList<>();
        if (StringUtils.isNotBlank(level)) {
            declareBuilder.append("DECLARE $level AS Utf8;\n");
            filters.add("level = $level");
            params.put("$level", PrimitiveValue.utf8(level));
        }
        if (StringUtils.isNotBlank(logger)) {
            declareBuilder.append("DECLARE $logger AS Utf8;\n");
            filters.add("logger LIKE($logger)");
            params.put("$logger", PrimitiveValue.utf8("%" + logger + "%"));
        }
        if (StringUtils.isNotBlank(searchText)) {
            declareBuilder.append("DECLARE $searchText AS Utf8;\n");
            filters.add("message LIKE($searchText)");
            params.put("$searchText", PrimitiveValue.utf8("%" + searchText + "%"));
        }
        if (!filters.isEmpty()) {
            queryBuilder.append("  WHERE ");
            queryBuilder.append(String.join(" AND ", filters));
            queryBuilder.append("\n");
        }
    }

    /**
     * Converts every remaining row of the select query into a {@link TOrderLogRecord}, reading the
     * column aliases produced by {@code logRecordsSelectBase} and deriving the host name from the
     * JSON context.
     */
    private List<TOrderLogRecord> mapToLogRecord(ResultSetReader resultReader) {
        var result = new ArrayList<TOrderLogRecord>();

        while (resultReader.next()) {
            var logRecordBuilder = TOrderLogRecord.newBuilder();

            logRecordBuilder.setOwnerId(resultReader.getColumn("owner_uuid").getUtf8());
            logRecordBuilder.setTimestamp(resultReader.getColumn("dtm").getUint64());
            logRecordBuilder.setMessageId(resultReader.getColumn("message_id").getUtf8());
            logRecordBuilder.setContext(resultReader.getColumn("context").getJson());
            logRecordBuilder.setLevel(resultReader.getColumn("level").getUtf8());
            logRecordBuilder.setLogger(resultReader.getColumn("logger").getUtf8());
            logRecordBuilder.setMessage(resultReader.getColumn("message").getUtf8());
            logRecordBuilder.setHostName(getHostNameFromContext(logRecordBuilder.getContext()));

            result.add(logRecordBuilder.build());
        }

        return result;
    }

    /**
     * Extracts the {@code HostName} property from a record's JSON context.
     *
     * <p>Never throws for malformed or unexpected content: a failure here would abort the mapping of the
     * whole result set, and the protobuf setter the value is passed to rejects {@code null}.
     *
     * @param context JSON text read from the {@code context} column
     * @return the {@code HostName} value as text (numbers and booleans are rendered as text), or an empty
     *         string when the context cannot be parsed, is not a JSON object, lacks the key, or holds JSON
     *         {@code null} or a nested object/array under it
     */
    private String getHostNameFromContext(String context) {
        try {
            var root = mapper.readTree(context);
            // readTree yields null (older Jackson) or a MissingNode (2.10+) for empty input
            if (root == null || !root.isObject()) {
                return "";
            }
            var hostName = root.get("HostName");
            if (hostName == null || hostName.isNull() || !hostName.isValueNode()) {
                return "";
            }
            return hostName.asText();
        } catch (IOException e) {
            log.warn("Could not deserialize context into map: {}", context, e);
            return "";
        }
    }

    /** Builds a retry context for a single operation from the settings given to the constructor. */
    private SessionRetryContext createContext() {
        return SessionRetryContext.create(retryableSessionSupplier())
                // maxAttempts counts the first attempt too, so the number of retries is one less
                .maxRetries(maxAttempts - 1)
                .sessionSupplyTimeout(timeout)
                .backoffSlot(backoffSlot)
                .backoffCeiling(backoffCeiling)
                .fastBackoffSlot(backoffSlot)
                .fastBackoffCeiling(backoffCeiling)
                .build();
    }

    /**
     * Wraps {@link TableClient#getOrCreateSession} so that session-creation failures surface with their
     * original status code instead of the wrapping {@code CLIENT_INTERNAL_ERROR}.
     */
    private SessionSupplier retryableSessionSupplier() {
        return (sessionTimeout) -> tableClient.getOrCreateSession(sessionTimeout)
                .thenApply(sessionResult -> {
                    // temporary W/A for session creation re-tries:
                    // session creation errors are wrapped into CLIENT_INTERNAL_ERROR
                    // errors which aren't retryable themselves;
                    // we need to unwrap the source error code and pass it further
                    Optional<Throwable> err = sessionResult.error();
                    if (err.isPresent() && err.get() instanceof UnexpectedResultException) {
                        UnexpectedResultException ure = (UnexpectedResultException) err.get();
                        if (ure.getStatusCode() == StatusCode.CLIENT_INTERNAL_ERROR && ure.getCause() instanceof CompletionException) {
                            // peel off every CompletionException layer down to the first non-wrapper cause
                            Throwable cause = ure.getCause();
                            while (cause instanceof CompletionException && cause.getCause() != null) {
                                cause = cause.getCause();
                            }
                            if (cause instanceof UnexpectedResultException) {
                                UnexpectedResultException rootCause = (UnexpectedResultException) cause;
                                // keep the root issues and add the root message as an INFO issue
                                List<Issue> issues = new ArrayList<>(Arrays.asList(rootCause.getIssues()));
                                issues.add(Issue.of(rootCause.getMessage(), Issue.Severity.INFO));
                                return Result.fail(rootCause.getStatusCode(), issues.toArray(new Issue[0]));
                            }
                        }
                    }
                    return sessionResult;
                });
    }

    /** Converts the records into the {@code $records} list parameter of the insert query. */
    private Params makeParams(List<TOrderLogRecord> records) {
        // we have to store uuid-s as strings and timestamp-s as integers
        // because the Uuid and Timestamp types can't be used in primary keys
        ListType listType = ListType.of(LOG_RECORD_TYPE);
        ListValue converted = listType.newValue(records.stream()
                .map(r -> LOG_RECORD_TYPE.newValue(Map.of(
                        "owner_uuid", PrimitiveValue.utf8(r.getOwnerId()),
                        "timestamp", PrimitiveValue.uint64(r.getTimestamp()),
                        "message_id", PrimitiveValue.utf8(r.getMessageId()),
                        "logger", PrimitiveValue.utf8(r.getLogger()),
                        "level", PrimitiveValue.utf8(r.getLevel()),
                        "message", PrimitiveValue.utf8(r.getMessage()),
                        "context", PrimitiveValue.json(r.getContext())
                )))
                .collect(toList()));
        return Params.of("$records", converted);
    }

    /**
     * Closes the table client and then the auth provider. Both are attempted even if the first one
     * fails.
     *
     * @throws RuntimeException if any close failed; each failure is attached as a suppressed exception
     */
    @Override
    public void close() {
        Collection<AutoCloseable> resources = List.of(
                tableClient,
                authProvider
        );
        Collection<Exception> failures = new ArrayList<>();
        for (AutoCloseable resource : resources) {
            try {
                resource.close();
            } catch (Exception e) {
                failures.add(e);
            }
        }
        if (!failures.isEmpty()) {
            RuntimeException e = new RuntimeException("Failed to successfully close all resources");
            for (Exception failure : failures) {
                e.addSuppressed(failure);
            }
            throw e;
        }
    }
}
