package ru.yandex.direct.useractionlog.writer.initdictionaries;

import java.io.File;
import java.io.FileNotFoundException;
import java.nio.file.Paths;
import java.sql.Connection;
import java.sql.SQLException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.jooq.Table;
import org.jooq.TableField;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.binlog.reader.EnrichedRow;
import ru.yandex.direct.config.DirectConfigFactory;
import ru.yandex.direct.db.config.DbConfig;
import ru.yandex.direct.db.config.DbConfigFactory;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapper;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapperProvider;
import ru.yandex.direct.dbutil.wrapper.ShardedDb;
import ru.yandex.direct.mysql.MySQLBinlogState;
import ru.yandex.direct.mysql.MySQLServerBuilder;
import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.tracing.TraceProfile;
import ru.yandex.direct.useractionlog.TableNames;
import ru.yandex.direct.useractionlog.db.ReadActionLogTable;
import ru.yandex.direct.useractionlog.db.ReadWriteDictTable;
import ru.yandex.direct.useractionlog.db.ShardReplicaChooser;
import ru.yandex.direct.useractionlog.db.StateReaderWriter;
import ru.yandex.direct.useractionlog.db.UserActionLogStates;
import ru.yandex.direct.useractionlog.dict.CacheDictRepository;
import ru.yandex.direct.useractionlog.dict.ClickHouseDictRepository;
import ru.yandex.direct.useractionlog.dict.DictDataCategory;
import ru.yandex.direct.useractionlog.dict.DictRepository;
import ru.yandex.direct.useractionlog.dict.DictRequest;
import ru.yandex.direct.useractionlog.dict.DictRequestsFiller;
import ru.yandex.direct.useractionlog.dict.DictResponsesAccessor;
import ru.yandex.direct.useractionlog.dict.FreshDictValuesFiller;
import ru.yandex.direct.useractionlog.schema.ActionLogRecord;
import ru.yandex.direct.useractionlog.schema.RecordSource;
import ru.yandex.direct.useractionlog.writer.ActionProcessor;
import ru.yandex.direct.useractionlog.writer.BufferedDictRepository;
import ru.yandex.direct.useractionlog.writer.ErrorWrapper;
import ru.yandex.direct.useractionlog.writer.generator.BatchRowDictProcessing;
import ru.yandex.direct.useractionlog.writer.generator.DictFiller;
import ru.yandex.direct.useractionlog.writer.generator.DictFillerTableSwitch;
import ru.yandex.direct.useractionlog.writer.generator.RowProcessingDefaults;
import ru.yandex.direct.useractionlog.writer.generator.RowProcessingStrategy;
import ru.yandex.direct.utils.Checked;
import ru.yandex.direct.utils.Completer;
import ru.yandex.direct.utils.InterruptedRuntimeException;
import ru.yandex.direct.utils.Interrupts;
import ru.yandex.direct.utils.JsonUtils;
import ru.yandex.direct.utils.MonotonicTime;
import ru.yandex.direct.utils.NanoTimeClock;
import ru.yandex.direct.utils.io.FileUtils;

import static ru.yandex.direct.dbschema.ppc.Ppc.PPC;

/**
 * Единоразовое заполнение всех словарных таблиц. Эта задача может записывать только в пустую базу user_action_log.
 * Следует запускать перед первым запуском writer.
 * <p>
 * Используемые источники:
 * <ul>
 * <li>mysql-базы ppc</li>
 * </ul>
 */
@ParametersAreNonnullByDefault
public class InitDictionaries implements Interrupts.InterruptibleCheckedRunnable<InterruptedException> {
    private static final Duration SAVE_STATE_OCCASIONALLY_DURATION = Duration.ofSeconds(30);
    private static final String STATE_FILE = "init_dictionaries_state.json";
    private static final Logger logger = LoggerFactory.getLogger(InitDictionaries.class);
    private final int chunkSize;
    private final DbConfigFactory dbConfigFactory;
    private final DatabaseWrapperProvider databaseWrapperProvider;
    @Nullable
    private final Integer mysqlServerId;
    private final StateReaderWriter stateTable;
    private final ReadWriteDictTable dictTable;
    private final int clickHouseThreads;
    private final DictFillerTableSwitch dictFiller;
    private final ImmutableList<ImmutableSet<TableField>> fetchTableFieldsQueue;
    private final InitDictionariesState state;
    private final ReadActionLogTable readActionLogTable;
    private final Semaphore binlogStateFetchingSemaphore;
    private final Semaphore schemaReplicaMysqlSemaphore;
    private final Duration binlogKeepAliveTimeout;
    private MonotonicTime lastSaveStateTime;

    /**
     * Создать экземпляр задачи для единоразового заполнения всех словарных таблиц.
     *
     * @param clickHouseThreads      Количество конкурентных записей в clickhouse
     * @param binlogKeepAliveTimeout
     * @param mysqlServerId          Server ID, который будет использоваться при соединении как MySQL slave
     * @param chunkSize              Сколько записей за раз считывать из MySQL и записывать в ClickHouse
     */
    public InitDictionaries(
            DbConfigFactory dbConfigFactory,
            DatabaseWrapperProvider databaseWrapperProvider,
            ShardReplicaChooser shardReplicaChooser,
            int clickHouseThreads,
            Duration binlogKeepAliveTimeout,
            @Nullable Integer mysqlServerId,
            int chunkSize) {
        this.binlogKeepAliveTimeout = binlogKeepAliveTimeout;
        if (clickHouseThreads < 1) {
            throw new IllegalArgumentException(
                    "clickHouseThreads should be positive integer, got " + clickHouseThreads);
        }
        this.dbConfigFactory = dbConfigFactory;
        this.mysqlServerId = mysqlServerId;
        this.chunkSize = chunkSize;
        this.databaseWrapperProvider = databaseWrapperProvider;
        this.clickHouseThreads = clickHouseThreads;
        this.stateTable = new StateReaderWriter(
                shardReplicaChooser::getForReading,
                shardReplicaChooser::getForWriting,
                TableNames.USER_ACTION_LOG_STATE_TABLE);
        this.dictTable = new ReadWriteDictTable(
                shardReplicaChooser::getForReading,
                shardReplicaChooser::getForWriting,
                TableNames.DICT_TABLE);
        this.readActionLogTable = new ReadActionLogTable(
                shardReplicaChooser::getForReading,
                TableNames.READ_USER_ACTION_LOG_TABLE);
        this.binlogStateFetchingSemaphore = new Semaphore(4);
        this.schemaReplicaMysqlSemaphore = new Semaphore(4);

        RecordSource recordSource = RecordSource.makeDaemonRecordSource();
        dictFiller = RowProcessingDefaults.defaultRowToActionLog(recordSource).makePureDictFiller();

        fetchTableFieldsQueue = ImmutableList.of(
                ImmutableSet.of(
                        PPC.RETARGETING_CONDITIONS.RET_COND_ID,
                        PPC.RETARGETING_CONDITIONS.CONDITION_NAME),
                ImmutableSet.of(
                        PPC.CAMPAIGNS.CID,
                        PPC.CAMPAIGNS.CLIENT_ID,
                        PPC.CAMPAIGNS.DISABLED_IPS,
                        PPC.CAMPAIGNS.DISABLED_SSP,
                        PPC.CAMPAIGNS.DONT_SHOW,
                        PPC.CAMPAIGNS.GEO,
                        PPC.CAMPAIGNS.LAST_CHANGE,
                        PPC.CAMPAIGNS.NAME,
                        PPC.CAMPAIGNS.STRATEGY_DATA,
                        PPC.CAMPAIGNS.TIME_TARGET),
                ImmutableSet.of(
                        PPC.CAMP_OPTIONS.CID,
                        PPC.CAMP_OPTIONS.MINUS_WORDS),
                ImmutableSet.of(
                        PPC.PHRASES.CID,
                        PPC.PHRASES.GEO,
                        PPC.PHRASES.GROUP_NAME,
                        PPC.PHRASES.LAST_CHANGE,
                        PPC.PHRASES.PID),
                ImmutableSet.of(
                        PPC.BANNERS.BID,
                        PPC.BANNERS.CID,
                        PPC.BANNERS.TITLE,
                        PPC.BANNERS.LAST_CHANGE,
                        PPC.BANNERS.PID),
                Arrays.stream(PPC.HIERARCHICAL_MULTIPLIERS.fields())
                        .map(field -> (TableField) field)
                        .filter(field -> field != PPC.HIERARCHICAL_MULTIPLIERS.SYNTETIC_KEY_HASH)
                        .collect(ImmutableSet.toImmutableSet()),
                Arrays.stream(PPC.DEMOGRAPHY_MULTIPLIER_VALUES.fields())
                        .map(field -> (TableField) field)
                        .collect(ImmutableSet.toImmutableSet()),
                Arrays.stream(PPC.GEO_MULTIPLIER_VALUES.fields())
                        .map(field -> (TableField) field)
                        .collect(ImmutableSet.toImmutableSet()),
                Arrays.stream(PPC.RETARGETING_MULTIPLIER_VALUES.fields())
                        .map(field -> (TableField) field)
                        .collect(ImmutableSet.toImmutableSet()));
        verifyFetchTableFieldsQueue(dictFiller, fetchTableFieldsQueue);
        InitDictionariesState stateInit;
        try (Scanner scanner = new Scanner(new File(STATE_FILE))) {
            stateInit = JsonUtils.fromJson(scanner.useDelimiter("\\Z").next(), InitDictionariesState.class);
        } catch (FileNotFoundException e) {
            stateInit = new InitDictionariesState();
        }
        this.state = stateInit;

        Set<String> supportedTables = RowProcessingDefaults.rowProcessingStrategyMap(recordSource).keySet();
        if (!Util.PRIMARY_KEY_MAP.keySet().containsAll(supportedTables)) {
            throw new IllegalStateException("not all supported tables have their primary key defined");
        }
    }

    @SafeVarargs
    private static void verifyFetchTableFieldsQueue(
            DictFillerTableSwitch dictFiller, ImmutableList<ImmutableSet<TableField>>... queues) {
        Set<String> tablesFromQueue = new HashSet<>();
        for (ImmutableList<ImmutableSet<TableField>> queue : queues) {
            for (ImmutableSet<TableField> fields : queue) {
                String tableName = fields.stream()
                        .map(TableField::getTable)
                        .map(Table::getName)
                        .reduce((table1, table2) -> {
                            if (table1.equals(table2)) {
                                return table1;
                            } else {
                                throw new IllegalStateException(
                                        "Two different tables in one list: " + table1 + " and " + table2);
                            }
                        })
                        .orElseThrow(() -> new IllegalStateException("Empty field list"));
                if (tablesFromQueue.contains(tableName)) {
                    throw new IllegalStateException("No need to query one table twice: " + tableName);
                }
                tablesFromQueue.add(tableName);
            }
        }
        if (!dictFiller.getTableNames().equals(tablesFromQueue)) {
            throw new IllegalStateException("Expected exactly these tables: "
                    + dictFiller.getTableNames().stream().sorted().collect(Collectors.toList())
                    + " but got these tables: "
                    + tablesFromQueue.stream().sorted().collect(Collectors.toList()));
        }
    }

    private void saveInitDictState() {
        // Этот метод не потокобезопасный, он должен выполняться в одном synchronized-блоке с любыми изменениями
        // состояния.
        FileUtils.atomicWrite(JsonUtils.toJson(state), Paths.get(STATE_FILE).toAbsolutePath());
        lastSaveStateTime = NanoTimeClock.now();
    }

    @Override
    public void run() throws InterruptedException {
        if (!readActionLogTable.isEmpty()) {
            throw new UnsupportedOperationException("Filling non-empty database is not supported yet.");
        }

        // Если процесс записи из бинлога ещё не запускался, то нужно переналить все словарные данные,
        // даже если таблица со словарными данными непустая, т.к. процесс наливки словарных данных мог неожиданно
        // прерваться в предыдущий раз. Поэтому нет проверки на пустоту словаря.

        List<String> shardNames = dbConfigFactory.getShardNumbers(ShardedDb.PPC.toString()).stream()
                .sorted()
                .map(ShardedDb.PPC::getDbNameForShard)
                .collect(Collectors.toList());

        List<Future<Void>> futures = new ArrayList<>(shardNames.size());
        Completer.Builder completerBuilder = new Completer.Builder(Duration.ofMinutes(5));
        Semaphore clickHouseDictWriteSemaphore = new Semaphore(clickHouseThreads);
        Trace rootTrace = Trace.current();
        for (String shardName : shardNames) {
            futures.add(completerBuilder.submitVoid(
                    "init_dictionary:sql_stage:" + shardName,
                    () -> processPpcShard(clickHouseDictWriteSemaphore, rootTrace, shardName)));
        }
        try (Completer completer = completerBuilder.build()) {
            completer.waitAll();
        }
        for (Future<Void> future : futures) {
            Checked.run(future::get);
        }
    }

    private void processPpcShard(Semaphore clickHouseDictWriteSemaphore, Trace rootTrace, String shardName)
            throws InterruptedException {
        Trace.push(rootTrace);
        logger.info("InitDictionary started for shard {}", shardName);
        String traceProfileName = "processPpcShardSqlSelectStage";
        try (TraceProfile ignored = Trace.current().profile(traceProfileName, shardName)) {
            DictRepository dictRepository = new ParallelLimitedDictRepository(
                    new ClickHouseDictRepository(shardName, dictTable),
                    clickHouseDictWriteSemaphore);
            DatabaseWrapper ppcShardWrapper = databaseWrapperProvider.get(dbConfigFactory.get(shardName).getDbName());

            InitDictionariesState.PpcShardState ppcShardState;
            synchronized (state) {
                ppcShardState = state.shardStates.computeIfAbsent(ppcShardWrapper.getDbname(),
                        k -> new InitDictionariesState.PpcShardState());
                saveInitDictState();
            }

            // Инициализация разбивается на два шага.
            // 1. Заливка грязных словарных данных sql-запросами
            // 2. Приведение словарных данных к консистентной форме путём повторения собыйтий из бинлога
            // Второй пункт полностью совпадает с новым алгоритмом синхронизации с отдельными состояниями
            // для логов и словаря. Поэтому стейт в начале грязной заливки записывается как стейт словаря,
            // а стейт по окончании грязной заливки - как стейт логов. В итоге ActionProcessor сравняет эти
            // два стейта и остановится.

            UserActionLogStates states = stateTable.read.getActualStates(shardName);
            if (states.getLog() == null) {
                if (states.getDict() == null) {
                    MySQLBinlogState startState = getPpcShardBinlogState(ppcShardWrapper);
                    stateTable.write.saveDictState(shardName, startState);
                    states = UserActionLogStates.builder()
                            .withDict(startState)
                            .build();
                }
                fillDictBySelectRequests(ppcShardWrapper, dictRepository, fetchTableFieldsQueue, ppcShardState);
                MySQLBinlogState endState = getPpcShardBinlogState(ppcShardWrapper);
                stateTable.write.saveLogState(shardName, endState);
                states = UserActionLogStates.builder()
                        .withDict(states.getDict())
                        .withLog(endState)
                        .build();
            } else {
                logger.info("Log state already written, nothing to do.");
            }

            // Пока словарные данные скачивались SELECT-запросами, в базе появились новые словарные данные. Проход по
            // бинлогу от момента начала скачивания до конца скачивания должен привести словарные данные в консистентное
            // состояние на момент окончания скачивания. В довесок стоит убедиться, что этот проход решает все те
            // необработанные словарные запросы, которые возникли на прошлом шаге.
            DbConfig ppcDbConfig = dbConfigFactory.get(shardName);
            fillJournalDictByBinlog(
                    shardName,
                    ppcDbConfig,
                    ppcShardWrapper,
                    new CacheDictRepository(dictRepository),
                    states);
            logger.info("InitDictionary finished for shard {}", shardName);
        } finally {
            Trace.pop();
        }
    }

    private void fillDictBySelectRequests(
            DatabaseWrapper ppcShardWrapper,
            DictRepository dictRepository,
            ImmutableList<ImmutableSet<TableField>> fetchTableFieldsQueue,
            InitDictionariesState.PpcShardState ppcShardState
    ) throws InterruptedException {
        for (ImmutableSet<TableField> fields : fetchTableFieldsQueue) {
            try {
                String tableName = fields.iterator().next().getTable().getName();

                logger.info("Start fetching SQL table {}", tableName);

                InitDictionariesState.PpcShardState.SelectState selectState;
                synchronized (state) {
                    selectState = ppcShardState.selects.computeIfAbsent(
                            tableName,
                            k -> new InitDictionariesState.PpcShardState.SelectState());
                    saveInitDictState();
                }
                if (selectState.finished) {
                    continue;
                }

                BufferedDictRepository bufferedDictRepository =
                        new BufferedDictRepository(dictRepository);

                // Далее идёт эмуляция бинлога. Все события вставки строки в таблицу будут приходить якобы от одной из
                // реплик с инкрементальным увеличением eventId. Это нужно из-за того, что добавление зависимых
                // корректировок ставок превращается в UPDATE основной корректировки и требует записи по тому
                // же ключу.
                SqlDictFetcher sqlDictFetcher = new SqlDictFetcher(
                        fields,
                        ppcShardWrapper.getDslContext(),
                        ppcShardState.lastEventId,
                        selectState.lastReadPk,
                        chunkSize);
                try (TraceProfile ignored = Trace.current().profile(
                        "processPpcShardSqlSelectStage", ppcShardWrapper.getDbname() + " " + tableName)) {
                    sqlDictFetcher.start();
                    BlockingQueue<CompletableFuture<SqlDictFetcher.Chunk>> queue = sqlDictFetcher.getResultQueue();
                    while (true) {
                        SqlDictFetcher.Chunk chunk = Checked.get(queue.take()::get);
                        if (chunk.rows.isEmpty()) {
                            break;
                        }
                        BatchRowDictProcessing.Result result;
                        try (TraceProfile ignored2 = Trace.current().profile(
                                "BatchRowDictProcessing", ppcShardWrapper.getDbname(), chunk.rows.size())) {
                            result = BatchRowDictProcessing.handleEvents(
                                    bufferedDictRepository, dictFiller,
                                    chunk.rows,
                                    new ErrorWrapper(false));
                        }
                        ActionProcessor.logUnprocessedRequests(result.unprocessed);
                        boolean timeToFlush = lastSaveStateTime.plus(SAVE_STATE_OCCASIONALLY_DURATION)
                                .isAtOrBefore(NanoTimeClock.now());
                        if (timeToFlush
                                || bufferedDictRepository.bufferSize() > ClickHouseDictRepository.FETCH_CHUNK_SIZE) {
                            bufferedDictRepository.flush();
                            if (timeToFlush) {
                                synchronized (state) {
                                    selectState.lastReadPk = chunk.lastReadPk;
                                    ppcShardState.lastEventId = chunk.lastEventId;
                                    saveInitDictState();
                                }
                            }
                        }
                    }
                } finally {
                    sqlDictFetcher.interrupt();
                }
                bufferedDictRepository.flush();
                synchronized (state) {
                    selectState.finished = true;
                    saveInitDictState();
                }
            } catch (RuntimeException e) {
                throw new IllegalStateException("Error while fetching fields " + fields, e);
            }
        }
    }

    private void fillJournalDictByBinlog(String shardName,
                                         DbConfig ppcDbConfig,
                                         DatabaseWrapper ppcShardWrapper,
                                         DictRepository dictRepository,
                                         UserActionLogStates states)
            throws InterruptedException {
        MySQLServerBuilder schemaReplicaMysqlBuilder = new MySQLServerBuilder()
                .setGracefulStopTimeout(Duration.ZERO)
                .addExtraArgs(Collections.singletonList("--skip-innodb-use-native-aio"))
                .withNoSync(true);

        try (TraceProfile ignored = Trace.current().profile("processPpcShard:binlog", ppcShardWrapper.getDbname())) {
            new ActionProcessor.Builder()
                    .withBinlogKeepAliveTimeout(binlogKeepAliveTimeout)
                    .withBinlogStateFetchingSemaphore(binlogStateFetchingSemaphore)
                    .withDirectConfig(DirectConfigFactory.getConfig())
                    .withDbConfig(ppcDbConfig)
                    .withDictRepository(dictRepository)
                    .withBatchDuration(Duration.ofMinutes(2))
                    .withEventBatchSize(chunkSize)
                    .withRecordBatchSize(chunkSize)
                    .withInitialServerId(mysqlServerId)
                    .withMaxBufferedEvents(chunkSize)
                    .withReadWriteStateTable(stateTable)
                    .withRowProcessingStrategy(new NoLogRowProcessingStrategy(dictFiller))
                    .withSchemaReplicaMysqlBuilder(schemaReplicaMysqlBuilder)
                    .withSchemaReplicaMysqlSemaphore(schemaReplicaMysqlSemaphore)
                    .withUntilGtidSet(Objects.requireNonNull(states.getLog()).getGtidSet())
                    .withWriteActionLogTable(null)
                    .build()
                    .run();
            // Последние несколько событий до untilSet могут не добавить никаких новых словарных данных, и в таблицу
            // будет записан стейт немного старше, чем endState. Не страшно, но ведёт к холостому прогону. Поэтому явно
            // записывается стейт для словаря такой же, как стейт для логов.
            stateTable.write.saveDictState(shardName, states.getLog());
        }
    }

    /**
     * Получение MySQLBinlogState для одного шарда
     */
    private MySQLBinlogState getPpcShardBinlogState(DatabaseWrapper wrapper) {
        logger.info("Fetching binlog state from {}", wrapper.getDbname());
        try (TraceProfile ignored = Trace.current().profile("getPpcShardBinlogState", wrapper.getDbname());
             Connection connection = wrapper.getDataSource().getConnection()) {
            return MySQLBinlogState.snapshot(connection);
        } catch (SQLException exc) {
            throw new IllegalStateException(exc);
        }
    }

    private static class NoLogRowProcessingStrategy implements RowProcessingStrategy {
        private final DictFiller dictFiller;

        private NoLogRowProcessingStrategy(DictFiller dictFiller) {
            this.dictFiller = dictFiller;
        }

        @Override
        public void fillFreshDictValues(EnrichedRow row, DictResponsesAccessor dictData,
                                        FreshDictValuesFiller freshDictValues) {
            dictFiller.fillFreshDictValues(row, dictData, freshDictValues);
        }

        @Override
        public void fillDictRequests(EnrichedRow row, DictRequestsFiller dictRequests) {
            dictFiller.fillDictRequests(row, dictRequests);
        }

        @Nonnull
        @Override
        public List<ActionLogRecord> processEvent(EnrichedRow row, DictResponsesAccessor dictResponsesAccessor) {
            throw new IllegalStateException("This method should not be called");
        }

        @Override
        public Collection<DictDataCategory> provides() {
            return dictFiller.provides();
        }

        @Override
        public DictFiller makePureDictFiller() {
            return dictFiller;
        }
    }

    /**
     * Ограничивается лишь количество конкурентных записей в словарь. Чтение не ограничивается.
     * Так как процесс может захотеть получить данные, которые только что записал, то запись идёт синхронно.
     */
    private class ParallelLimitedDictRepository implements DictRepository {
        private final DictRepository forwardDict;
        private final Semaphore semaphore;

        ParallelLimitedDictRepository(DictRepository forwardDict, Semaphore semaphore) {
            this.forwardDict = forwardDict;
            this.semaphore = semaphore;
        }

        @Nonnull
        @Override
        public Map<DictRequest, Object> getData(Collection<DictRequest> dictRequests) {
            return forwardDict.getData(dictRequests);
        }

        @Override
        public void addData(Map<DictRequest, Object> data) {
            try {
                semaphore.acquire();
                try {
                    forwardDict.addData(data);
                } finally {
                    semaphore.release();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new InterruptedRuntimeException(e);
            }
        }
    }
}
