package ru.yandex.webmaster3.storage.util.ydb;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Stream;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.yandex.ydb.core.Result;
import com.yandex.ydb.core.Status;
import com.yandex.ydb.core.StatusCode;
import com.yandex.ydb.table.Session;
import com.yandex.ydb.table.SessionRetryContext;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.query.DataQuery;
import com.yandex.ydb.table.query.DataQueryResult;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.settings.PrepareDataQuerySettings;
import com.yandex.ydb.table.settings.ReadTableSettings;
import com.yandex.ydb.table.transaction.Transaction;
import com.yandex.ydb.table.transaction.TransactionMode;
import com.yandex.ydb.table.transaction.TxControl;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;

import ru.yandex.webmaster3.core.solomon.SolomonSensor;
import ru.yandex.webmaster3.core.solomon.metric.SolomonCounter;
import ru.yandex.webmaster3.core.solomon.metric.SolomonGauge;
import ru.yandex.webmaster3.core.solomon.metric.SolomonKey;
import ru.yandex.webmaster3.core.solomon.metric.SolomonMetricRegistry;
import ru.yandex.webmaster3.core.solomon.metric.SolomonTimer;
import ru.yandex.webmaster3.core.solomon.metric.SolomonTimerConfiguration;
import ru.yandex.webmaster3.core.tracer.YdbTracer;
import ru.yandex.webmaster3.core.util.functional.Functions;
import ru.yandex.webmaster3.storage.util.ydb.exception.WebmasterYdbException;
import ru.yandex.webmaster3.storage.util.ydb.query.PreparedStatement;
import ru.yandex.webmaster3.storage.util.ydb.querybuilder.typesafe.RowMapper;

/**
 * ishalaru
 * 03.06.2020
 **/
@Slf4j
@Service
public class ThreadLocalYdbTransactionManager {

    /** While streaming a table in {@code readTable}, {@code session.keepAlive()} is called once per this many rows. */
    private static final int KEEP_ALIVE_ROWS_INTERVAL = 10000;
    /** Timeout, in seconds, passed to the query-preparation request (the client waits one second longer). */
    private static final long PREPARE_QUERY_TIMEOUT = 5L;

    private final TableClient tableClient;
    /** Session and transaction bound to the calling thread for the duration of {@code executeInTx}. */
    private final ThreadLocal<TransactionData> current = new ThreadLocal<>();
    /** Prepared queries keyed by (session id, query text); never reassigned, hence final. */
    private final Cache<Pair<String, String>, DataQuery> preparedStatements = CacheBuilder.newBuilder()
            .maximumSize(10_000L)
            .build();

    private final ThreadLocalYdbTransactionManager.YdbMetrics ydbMetrics;
    private final SessionRetryContext sessionRetryContext;
    /** Limits the number of queries in flight; one permit is held per running query. */
    private final Semaphore mutex;
    /** How long a caller waits for a free permit before the query is rejected. */
    private final Duration waitAcquireDuration;
    /** Default limit for waiting on a query result when the statement has no timeout of its own. */
    private final Duration maxWaitingTimeout;

    /**
     * Wires the manager with its collaborators.
     *
     * @param tableClient            YDB table client (only exposed through {@link #getTableClientForTests()})
     * @param ydbSessionRetryContext context used to run operations on a session
     * @param mutex                  semaphore limiting the number of queries in flight
     * @param waitAcquireDuration    how long to wait for a semaphore permit
     * @param maxWaitingTimeout      default limit for waiting on a query result
     * @param solomonMetricRegistry  registry in which the YDB sensors are created
     */
    public ThreadLocalYdbTransactionManager(TableClient tableClient,
                                            SessionRetryContext ydbSessionRetryContext,
                                            Semaphore mutex,
                                            @Qualifier("waitAcquireDuration") Duration waitAcquireDuration,
                                            @Qualifier("maxWaitingTimeout") Duration maxWaitingTimeout,
                                            SolomonMetricRegistry solomonMetricRegistry) {
        // YDB access
        this.tableClient = tableClient;
        this.sessionRetryContext = ydbSessionRetryContext;

        // throttling and timeouts
        this.mutex = mutex;
        this.waitAcquireDuration = waitAcquireDuration;
        this.maxWaitingTimeout = maxWaitingTimeout;

        // monitoring
        this.ydbMetrics = new YdbMetrics(solomonMetricRegistry);
    }

    /**
     * Executes a single statement and blocks until its result is available.
     * <p>
     * When a transaction is bound to the current thread the statement runs inside it without committing;
     * otherwise it runs through {@code sessionRetryContext} with a serializable read-write transaction
     * that is committed together with the query.
     * <p>
     * A permit of {@code mutex} is held for as long as the query is in flight. It is released when the
     * query future completes, or immediately if the query could not be submitted at all.
     *
     * @param st statement to execute
     * @return result of the query
     * @throws WebmasterYdbException if no permit could be acquired in time, the thread was interrupted,
     *                               or waiting for the result failed
     */
    public DataQueryResult execute(PreparedStatement st) {
        final long startProcessingTime = System.nanoTime();
        Functions.F3<Session, TxControl, DataQuery, CompletableFuture<Result<DataQueryResult>>> op =
                (session, txControl, dataQuery) -> execute(session, txControl, st, dataQuery);
        Optional<TransactionData> transactionData = getCurrent();
        try {
            if (!mutex.tryAcquire(waitAcquireDuration.toMillis(), TimeUnit.MILLISECONDS)) {
                this.ydbMetrics.failed.add(1);
                throw new WebmasterYdbException("Too many query acquired.");
            }
            this.ydbMetrics.setMutexFreeSlots(mutex.availablePermits());
            final long startTime = System.nanoTime();
            final AsyncQueryWithStats stats = new AsyncQueryWithStats();
            boolean succeeded = false;
            try {
                CompletableFuture<Result<DataQueryResult>> pending;
                try {
                    if (transactionData.isPresent()) {
                        TransactionData t = transactionData.get();
                        DataQuery query = prepareQuery(t.session, st);
                        stats.setQueryStarted(System.currentTimeMillis());
                        pending = op.apply(t.session, TxControl.id(t.transaction).setCommitTx(false), query);
                    } else {
                        pending = sessionRetryContext.supplyResult(session -> {
                            DataQuery query = prepareQuery(session, st);
                            stats.setQueryStarted(System.currentTimeMillis());
                            return op.apply(session, TxControl.serializableRw().setCommitTx(true), query);
                        });
                    }
                } catch (RuntimeException | Error e) {
                    // The query was never submitted, so no completion callback will return the permit.
                    mutex.release();
                    throw e;
                }
                DataQueryResult result = getResult(pending.whenComplete((a, b) -> {
                    mutex.release();
                    stats.setQueryFinished(System.currentTimeMillis());
                }), st.getTimeout());
                succeeded = true;
                return result;
            } finally {
                YdbTracer.getCurrentTrace().addQueryStats(st.getTableName(), st.getOperationType().name(), stats.getQueryTime());
                // Count the outcome that actually happened rather than always reporting success.
                if (succeeded) {
                    this.ydbMetrics.success.add(1);
                } else {
                    this.ydbMetrics.failed.add(1);
                }
                this.ydbMetrics.ydbQueryTime.add(org.joda.time.Duration.millis(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)), 1);
            }
        } catch (InterruptedException e) {
            // Preserve the interruption for callers further up the stack.
            Thread.currentThread().interrupt();
            this.ydbMetrics.failed.add(1);
            throw new WebmasterYdbException(e);
        } finally {
            this.ydbMetrics.processingTime.add(org.joda.time.Duration.millis(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startProcessingTime)), 1);
        }
    }

    /**
     * Submits all statements without waiting in between, then collects their results in input order.
     * <p>
     * One permit of {@code mutex} is acquired per statement and released exactly once, when the future
     * returned for that statement completes (or immediately if submission throws). Timing stats are
     * created per statement on the calling thread, so {@code statsList.get(i)} always belongs to
     * {@code list.get(i)} no matter when, or how many times, the session callback is invoked.
     *
     * @param list statements to execute
     * @return query results, in the same order as {@code list}
     * @throws WebmasterYdbException if a permit could not be acquired in time, the thread was interrupted,
     *                               or waiting for a result failed
     */
    public List<DataQueryResult> asyncExecute(List<PreparedStatement> list) {
        final long startProcessingTime = System.nanoTime();
        try {
            List<CompletableFuture<Result<DataQueryResult>>> queryResult = new ArrayList<>(list.size());
            List<DataQueryResult> result = new ArrayList<>(list.size());
            List<AsyncQueryWithStats> statsList = new ArrayList<>(list.size());
            for (var statement : list) {
                try {
                    if (!mutex.tryAcquire(waitAcquireDuration.toMillis(), TimeUnit.MILLISECONDS)) {
                        this.ydbMetrics.failed.add(list.size());
                        throw new WebmasterYdbException("Too many query acquired.");
                    }
                } catch (InterruptedException e) {
                    // Preserve the interruption for callers further up the stack.
                    Thread.currentThread().interrupt();
                    this.ydbMetrics.failed.add(list.size());
                    throw new WebmasterYdbException(e);
                }

                final long startTime = System.nanoTime();
                final AsyncQueryWithStats stats = new AsyncQueryWithStats();
                final CompletableFuture<Result<DataQueryResult>> submitted;
                try {
                    submitted = this.asyncExecuteInTransactionOrTmpSession((session, txControl) -> {
                        DataQuery dataQuery = prepareQuery(session, statement);
                        stats.setQueryStarted(System.currentTimeMillis());

                        //todo add YdbTrace
                        return execute(session, txControl, statement, dataQuery)
                                .whenComplete((a, b) -> stats.setQueryFinished(System.currentTimeMillis()));
                    });
                } catch (RuntimeException | Error e) {
                    // Submission failed synchronously: no completion callback will return the permit.
                    mutex.release();
                    this.ydbMetrics.failed.add(1);
                    throw e;
                }

                statsList.add(stats);
                // Release and account on the outer future: it completes exactly once per acquired permit,
                // whereas the session callback may presumably run zero or several times (retries).
                queryResult.add(submitted.whenComplete((res, err) -> {
                    mutex.release();
                    if (err == null && res != null && res.isSuccess()) {
                        this.ydbMetrics.success.add(1);
                    } else {
                        this.ydbMetrics.failed.add(1);
                    }
                    this.ydbMetrics.ydbQueryTime.add(org.joda.time.Duration.millis(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)), 1);
                }));
            }
            for (int i = 0; i < queryResult.size(); i++) {
                DataQueryResult dataQueryResult = getResult(queryResult.get(i));
                PreparedStatement preparedStatement = list.get(i);
                YdbTracer.getCurrentTrace().addQueryStats(
                        preparedStatement.getTableName(),
                        preparedStatement.getOperationType().name(),
                        statsList.get(i).getQueryTime()
                );
                result.add(dataQueryResult);
            }
            return result;
        } finally {
            this.ydbMetrics.processingTime.add(org.joda.time.Duration.millis(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startProcessingTime)), list.size());
        }
    }


    /**
     * Streams the rows of a table, mapping each row with {@code rowMapper} and handing it to {@code consumer}.
     * Blocks until the read has finished.
     * <p>
     * {@code session.keepAlive()} is called on the first row and then once every
     * {@code KEEP_ALIVE_ROWS_INTERVAL} rows (presumably so a long read does not let the session expire).
     *
     * @param tablePath         path of the table to read
     * @param rowMapper         converts the reader's current row into a {@code T}
     * @param readTableSettings settings passed to the read request
     * @param consumer          receives every mapped row
     */
    public <T> void readTable(String tablePath, RowMapper<T> rowMapper, ReadTableSettings readTableSettings, Consumer<T> consumer) {
        sessionRetryContext.supplyStatus(session ->
                session.readTable(tablePath, readTableSettings, setReader -> {
                    for (int row = 0; setReader.next(); row++) {
                        consumer.accept(rowMapper.get(setReader));
                        if (row % KEEP_ALIVE_ROWS_INTERVAL == 0) {
                            session.keepAlive();
                        }
                    }
                })
        ).join().expect("readTable failed");
    }

    /**
     * Runs {@code op} inside a new transaction bound to the executing thread, committing on success.
     * <p>
     * On failure a rollback is attempted only while the transaction is still bound to the thread:
     * {@link #commit()} clears the binding even when it fails, and calling {@link #rollback()} afterwards
     * would only raise "No current transaction" and hide the real cause. A rollback failure is attached
     * to the original exception as suppressed instead of replacing it.
     *
     * @param op   work to perform; statements it issues through this manager join the transaction
     * @param mode transaction mode to begin with
     * @return the value produced by {@code op}
     */
    public <T> T executeInTx(Supplier<T> op, TransactionMode mode) {
        return sessionRetryContext.supplyResult(session -> {

            Transaction transaction = session.beginTransaction(mode).join().expect("Can't begin transaction session");
            setCurrent(session, transaction);
            try {
                T result = op.get();
                commit();
                return CompletableFuture.completedFuture(Result.success(result));
            } catch (Throwable t) {
                // The trailing Throwable is logged as the exception, so it needs no placeholder of its own.
                log.error("Failed to execute transaction in session {}", session, t);
                if (getCurrent().isPresent()) {
                    try {
                        rollback();
                    } catch (RuntimeException rollbackFailure) {
                        t.addSuppressed(rollbackFailure);
                    }
                }
                throw t;
            } finally {
                clearCurrent();
            }
        }).join().expect("Can't execute transactional query");
    }

    /**
     * Runs {@code op} in the transaction bound to the current thread (without committing), or, when there
     * is none, through {@code sessionRetryContext} with a serializable read-write transaction that commits
     * together with the operation.
     *
     * @param op operation receiving the session and transaction control to use
     * @return future with the operation's result
     */
    private <T> CompletableFuture<Result<T>> asyncExecuteInTransactionOrTmpSession(Functions.F2<Session, TxControl, CompletableFuture<Result<T>>> op) {
        final Optional<TransactionData> active = getCurrent();
        if (active.isEmpty()) {
            // No transaction on this thread: run on a session supplied by the retry context.
            return sessionRetryContext.supplyResult(
                    session -> op.apply(session, TxControl.serializableRw().setCommitTx(true)));
        }

        final TransactionData tx = active.get();
        return op.apply(tx.session, TxControl.id(tx.transaction).setCommitTx(false));
    }


    /**
     * Binds the statement's parameters to the prepared query and starts executing it.
     * The returned future also drops the cached prepared query if the server answers NOT_FOUND.
     *
     * @param session   session the query was prepared on
     * @param txControl transaction control to execute with
     * @param statement source of the parameters and the query text
     * @param query     prepared query to execute
     * @return future with the query result
     * @throws WebmasterYdbException if starting the query throws
     */
    private CompletableFuture<Result<DataQueryResult>> execute(Session session,
                                                               TxControl txControl,
                                                               PreparedStatement statement,
                                                               DataQuery query) {
        // Parameter binding stays outside the try block: its failures are not wrapped.
        final Params boundParams = statement.getParamsBuilder().apply(query);
        try {
            final CompletableFuture<Result<DataQueryResult>> running = query.execute(txControl, boundParams);
            return withInvalidateMissingPreparedQuery(running, session, statement.getQuerySql());
        } catch (Exception e) {
            throw new WebmasterYdbException(e);
        }
    }

    /**
     * Decorates a query future so that a NOT_FOUND failure evicts the cached prepared query for the
     * given session and SQL text. The query result itself is passed through unchanged.
     *
     * @param query   future of the running query
     * @param session session the query was prepared on
     * @param sql     query text, part of the cache key
     * @return future completing with the same result as {@code query}
     */
    private CompletableFuture<Result<DataQueryResult>> withInvalidateMissingPreparedQuery(
            CompletableFuture<Result<DataQueryResult>> query, Session session, String sql) {
        return query.thenApply(queryResult -> {
            if (queryResult.isSuccess()) {
                return queryResult;
            }
            if (queryResult.getCode() == StatusCode.NOT_FOUND) {
                invalidatePreparedQuery(session, sql);
            }
            return queryResult;
        });
    }

    /**
     * Waits for the future using the default {@code maxWaitingTimeout}.
     *
     * @param cf future to wait for
     * @return the unwrapped result value
     * @throws WebmasterYdbException if waiting fails
     */
    private <T> T getResult(CompletableFuture<Result<T>> cf) {
        return getResult(cf, null);
    }

    /**
     * Waits for the future and unwraps its result.
     *
     * @param cf      future to wait for
     * @param timeout how long to wait; {@code null} means {@code maxWaitingTimeout}
     * @return the unwrapped result value
     * @throws WebmasterYdbException if the wait is interrupted, times out, or the future failed
     */
    private <T> T getResult(CompletableFuture<Result<T>> cf, Duration timeout) {
        final Duration effectiveTimeout = timeout != null ? timeout : maxWaitingTimeout;
        try {
            return cf.get(effectiveTimeout.toMillis(), TimeUnit.MILLISECONDS).expect("Cannot execute data query");
        } catch (InterruptedException e) {
            // Restore the flag so the interruption is not lost behind the wrapping exception.
            Thread.currentThread().interrupt();
            throw new WebmasterYdbException("Error in ydb get answer", e);
        } catch (ExecutionException | TimeoutException e) {
            throw new WebmasterYdbException("Error in ydb get answer", e);
        }
    }

    /**
     * Commits the transaction bound to the current thread and waits for the outcome.
     * The thread-local binding is cleared whether or not the commit succeeds.
     *
     * @throws WebmasterYdbException if there is no current transaction, the commit fails,
     *                               or the wait is interrupted
     */
    public void commit() {
        try {
            Status result = getCurrent()
                    .orElseThrow(() -> new IllegalStateException("No current transaction"))
                    .transaction.commit()
                    .get();
            result.expect("Can't commit transaction");
        } catch (InterruptedException e) {
            // Restore the flag so the interruption is not lost behind the wrapping exception.
            Thread.currentThread().interrupt();
            throw new WebmasterYdbException(e);
        } catch (Exception e) {
            throw new WebmasterYdbException(e);
        } finally {
            clearCurrent();
        }
    }

    /**
     * Rolls back the transaction bound to the current thread and waits for the outcome.
     * The thread-local binding is cleared whether or not the rollback succeeds.
     *
     * @throws WebmasterYdbException if there is no current transaction, the rollback fails,
     *                               or the wait is interrupted
     */
    public void rollback() {
        try {
            getCurrent()
                    .orElseThrow(() -> new IllegalStateException("No current transaction"))
                    .transaction.rollback()
                    .get()
                    .expect("Can't rollback transaction");
        } catch (InterruptedException e) {
            // Restore the flag so the interruption is not lost behind the wrapping exception.
            Thread.currentThread().interrupt();
            throw new WebmasterYdbException(e);
        } catch (Exception e) {
            throw new WebmasterYdbException(e);
        } finally {
            clearCurrent();
        }
    }

    /**
     * Builds the prepared-statement cache key: the session id paired with the query text.
     */
    private Pair<String, String> getCacheId(Session s, String query) {
        final String sessionId = s.getId();
        return Pair.of(sessionId, query);
    }

    /**
     * Returns the prepared form of the statement for the given session, preparing and caching it on a miss.
     * The time spent here (cache hit or not) is reported to the current trace as prepare stats.
     *
     * @param s         session to prepare the query on
     * @param statement statement whose SQL text is prepared
     * @return prepared query
     * @throws WebmasterYdbException if loading the cache entry fails with an {@code ExecutionException}
     */
    private DataQuery prepareQuery(Session s, PreparedStatement statement) {
        final String sql = statement.getQuerySql();
        final Pair<String, String> cacheKey = getCacheId(s, sql);
        final long startedAt = System.currentTimeMillis();
        try {
            return preparedStatements.get(cacheKey, () -> prepareQueryImpl(s, sql));
        } catch (ExecutionException e) {
            throw new WebmasterYdbException(e);
        } finally {
            // add prepare stats to tracer
            final long elapsedMillis = System.currentTimeMillis() - startedAt;
            YdbTracer.getCurrentTrace().addPrepareStats(
                    statement.getTableName(),
                    statement.getOperationType().name(),
                    elapsedMillis);
        }
    }

    /**
     * Evicts the cached prepared query for the given session and query text.
     */
    private void invalidatePreparedQuery(Session s, String query) {
        log.debug("Invalidating query (session {}): {}", s, query);
        final Pair<String, String> cacheKey = getCacheId(s, query);
        preparedStatements.invalidate(cacheKey);
    }

    /**
     * Prepares the query on the session, asking for it to be kept in the query cache, and waits for the outcome.
     * The request carries a {@code PREPARE_QUERY_TIMEOUT}-second timeout; the client waits one second longer.
     *
     * @param s     session to prepare the query on
     * @param query query text
     * @return prepared query
     * @throws WebmasterYdbException if preparation fails, times out, or the wait is interrupted
     */
    private DataQuery prepareQueryImpl(Session s, String query) {
        try {
            log.debug("Preparing query (session {}): {}", s, query);
            final var settings = new PrepareDataQuerySettings()
                    .keepInQueryCache()
                    .setTimeout(PREPARE_QUERY_TIMEOUT, TimeUnit.SECONDS);
            final long clientWaitSeconds = PREPARE_QUERY_TIMEOUT + 1;
            return s.prepareDataQuery(query, settings)
                    .get(clientWaitSeconds, TimeUnit.SECONDS)
                    .expect("Can't prepare query");
        } catch (Exception e) {
            throw new WebmasterYdbException(e);
        }
    }

    /**
     * Binds a session and transaction to the current thread.
     *
     * @throws IllegalStateException if the thread already has a transaction bound
     */
    private void setCurrent(Session session, Transaction transaction) {
        final boolean alreadyBound = current.get() != null;
        if (alreadyBound) {
            throw new IllegalStateException("Transaction is already present");
        }
        final TransactionData binding = new TransactionData(session, transaction);
        current.set(binding);
    }

    /**
     * Returns the transaction bound to the current thread, or an empty Optional when there is none.
     */
    private Optional<TransactionData> getCurrent() {
        final TransactionData bound = current.get();
        return Optional.ofNullable(bound);
    }

    /**
     * Removes the transaction binding of the current thread, if any.
     */
    private void clearCurrent() {
        current.remove();
    }

    /**
     * Exposes the underlying table client; intended for tests, as the name says.
     */
    public TableClient getTableClientForTests() {
        return tableClient;
    }

    /**
     * The session and transaction bound to a thread while {@code executeInTx} is running.
     * The enclosing class reads both fields directly.
     */
    @AllArgsConstructor
    @Data
    private static class TransactionData {
        // Session on which the transaction was begun.
        private final Session session;
        // Transaction that statements join, and that commit()/rollback() act on.
        private final Transaction transaction;
    }

    /**
     * Solomon sensors published by the transaction manager: two timers (query time and total
     * processing time), success/failure counters and a gauge of free semaphore permits.
     */
    @Value
    private static class YdbMetrics {
        private static final String MAIN_LABEL = "ydb";
        private static final String DURATION = "duration";
        private static final String QUERY_TIME = "query";
        private static final String PROCESSING_TIME = "processing";
        private static final String MUTEX_FREE_COUNT = "query_free_slots";
        private static final String OPERATION_RESULT = "operation_result";
        SolomonTimer ydbQueryTime;
        SolomonTimer processingTime;
        SolomonCounter success;
        SolomonCounter failed;
        SolomonGauge<Long> mutexFreeSlots;

        /**
         * Creates all sensors in the given registry under the {@code ydb} category key.
         */
        public YdbMetrics(SolomonMetricRegistry solomonMetricRegistry) {
            final SolomonKey categoryKey = SolomonKey.create(SolomonSensor.LABEL_CATEGORY, MAIN_LABEL);
            final SolomonTimerConfiguration timerConfig = makeSolomonTimeConfiguration();

            final SolomonKey queryTimeKey = categoryKey.withLabel(DURATION, QUERY_TIME);
            this.ydbQueryTime = solomonMetricRegistry.createTimer(timerConfig, queryTimeKey);

            final SolomonKey processingTimeKey = categoryKey.withLabel(DURATION, PROCESSING_TIME);
            this.processingTime = solomonMetricRegistry.createTimer(timerConfig, processingTimeKey);

            final SolomonKey freeSlotsKey = categoryKey.withLabel(MUTEX_FREE_COUNT, "count");
            this.mutexFreeSlots = solomonMetricRegistry.createGauge(freeSlotsKey);

            final SolomonKey successKey = categoryKey.withLabel(OPERATION_RESULT, "success");
            this.success = solomonMetricRegistry.createSimpleCounter(successKey, 1.0d);

            final SolomonKey failedKey = categoryKey.withLabel(OPERATION_RESULT, "failed");
            this.failed = solomonMetricRegistry.createSimpleCounter(failedKey, 1.0d);
        }

        /**
         * Publishes the number of currently free semaphore permits.
         */
        public void setMutexFreeSlots(int freeSlots) {
            final long asLong = freeSlots;
            this.mutexFreeSlots.set(asLong);
        }

        /**
         * Builds the configuration shared by both timers: millisecond unit, fixed buckets and groupings.
         */
        private SolomonTimerConfiguration makeSolomonTimeConfiguration() {
            final var bucketBounds = Stream.of(4, 8, 16, 32, 128, 512, 1024, 4096, 8192, 32000)
                    .map(org.joda.time.Duration::millis)
                    .toList();
            final List<String> groupings = List.of(
                    "duration,time_bucket", // times per task, tasks per time
                    "time_bucket", // times across all launch types
                    "" // total number of requests
            );

            final SolomonTimerConfiguration config = new SolomonTimerConfiguration();
            config.setTimeUnit(TimeUnit.MILLISECONDS);
            config.setDurationIndicatorName("ydb_query_duration");
            config.setBuckets(bucketBounds);
            config.setGroupBy(groupings);
            config.setDurationGroupBy(List.of(""));
            return config;
        }
    }


    /**
     * Wall-clock timestamps (epoch milliseconds) of when a query started and finished.
     * Both are 0 until set through the generated setters.
     */
    @Setter
    public static class AsyncQueryWithStats {
        private long queryStarted;
        private long queryFinished;

        /**
         * Returns the measured query duration in milliseconds.
         * <p>
         * Returns 0 when the measurement is incomplete (the query never started, or has not finished,
         * e.g. after a timeout) instead of a meaningless negative or epoch-sized value, and never goes
         * below 0 if the wall clock moved backwards between the two timestamps.
         */
        public long getQueryTime() {
            if (queryStarted == 0L || queryFinished == 0L) {
                return 0L;
            }
            return Math.max(0L, queryFinished - queryStarted);
        }
    }
}
