#include "queue_reader_cache.h"

#include <travel/hotels/lib/cpp/util/profiletimer.h>

#include <library/cpp/logger/global/global.h>

#include <type_traits>

using namespace NTravelProto;

namespace NTravel {

namespace NQuery {
namespace {

// ----- Schema (DDL) ----------------------------------------------------------------

// Single-row table with the schema version; InitDb() compares it against MIGRATIONS.
const TString CREATE_VERSION = R"(
    CREATE TABLE IF NOT EXISTS version(
        version INTEGER NOT NULL
    );
)";

// Seeds the version table; immediately followed by UPDATE_VERSION in InitDb().
const TString SET_ZERO_VERSION = R"(
    INSERT INTO VERSION VALUES (0);
)";

// Cached queue messages, one row per (ClusterName, TabletIndex, RowIndex) origin.
// Timestamp is stored in milliseconds; Bytes holds the packed message payload.
const TString CREATE_QUEUE = R"(
    CREATE TABLE IF NOT EXISTS queue (
        ClusterName TEXT NOT NULL,
        TabletIndex INTEGER NOT NULL,
        RowIndex INTEGER NOT NULL,
        Timestamp INTEGER NOT NULL,
        MessageType TEXT NOT NULL,
        Codec INTEGER NOT NULL,
        Bytes BLOB NOT NULL,
        MessageId TEXT NOT NULL,
        PRIMARY KEY (ClusterName, TabletIndex, RowIndex)
    )
)";

// Nullable expiration time in milliseconds; the writer stores NULL for messages
// without an ExpireTimestamp.
const TString ALTER_QUEUE_ADD_EXPIRETIMESTAMP = R"(
    ALTER TABLE queue ADD COLUMN ExpireTimestamp INTEGER
)";


// Supports the Timestamp range scan/delete queries below.
const TString CREATE_TIMESTAMP_INDEX = R"(
   CREATE INDEX IF NOT EXISTS idx_queue_Timestamp ON queue (
       Timestamp
   )
)";

// Supports DELETE_OLD_ROWS_BY_EXPIRETIMESTAMP.
const TString CREATE_EXPIRETIMESTAMP_INDEX = R"(
   CREATE INDEX IF NOT EXISTS idx_queue_ExpireTimestamp ON queue (
       ExpireTimestamp
   )
)";

// Last stored RowIndex per (ClusterName, TabletIndex) partition.
const TString CREATE_INDICES = R"(
    CREATE TABLE IF NOT EXISTS indices (
        ClusterName TEXT NOT NULL,
        TabletIndex INTEGER NOT NULL,
        RowIndex INTEGER NOT NULL,
        PRIMARY KEY (ClusterName, TabletIndex)
    )
)";

// Single-row table with the data version the cached rows were written with;
// CheckDataVersion() wipes queue and indices when it differs from the configured one.
const TString CREATE_TABLE_DATA_VERSION = R"(
    CREATE TABLE IF NOT EXISTS data_version (
        ConfigVersion INTEGER NOT NULL
    )
)";

const TString SET_ZERO_DATA_VERSION = R"(
    INSERT INTO data_version VALUES (0);
)";


// Schema migrations: MIGRATIONS[v] holds the statements upgrading schema version v
// to v + 1. InitDb() applies every entry from the stored version onwards, one
// transaction per step, and records the new version after each step.
// Existing databases rely on the indices here, so only ever append new entries.
const TVector<TString> MIGRATIONS[] = {
    // 0 -> 1
    TVector{CREATE_QUEUE, CREATE_INDICES, CREATE_TIMESTAMP_INDEX, CREATE_VERSION, SET_ZERO_VERSION},
    // 1 -> 2
    TVector{ALTER_QUEUE_ADD_EXPIRETIMESTAMP, CREATE_EXPIRETIMESTAMP_INDEX},
    // 2 -> 3
    TVector{CREATE_TABLE_DATA_VERSION, SET_ZERO_DATA_VERSION},
};

//-----------------------------------------------------------------------------------
// ----- Runtime queries (DML); `?` placeholders are bound positionally ---------------

const TString SELECT_VERSION = R"(
    SELECT version FROM version
)";

// Parameter: new schema version.
const TString UPDATE_VERSION = R"(
    UPDATE version SET version = ?
)";

// Parameter: Timestamp lower bound in ms (exclusive). Returns rows oldest first.
// ReadRows() reads the result columns by position, so keep this column order.
const TString SELECT_BUS_ROW = R"(
    SELECT ClusterName,
        TabletIndex,
        RowIndex,
        Timestamp,
        MessageId,
        MessageType,
        Codec,
        Bytes,
        ExpireTimestamp
    FROM queue
    WHERE Timestamp > ?
    ORDER BY Timestamp
)";

const TString SELECT_TABLET_INDEX = R"(
    SELECT ClusterName,
        TabletIndex,
        RowIndex
    FROM indices
)";

// Parameters follow the column list; ExpireTimestamp may be bound as NULL.
const TString APPEND_BUS_ROW = R"(
    INSERT INTO queue (
        ClusterName,
        TabletIndex,
        RowIndex,
        Timestamp,
        MessageType,
        Codec,
        Bytes,
        MessageId,
        ExpireTimestamp
    ) VALUES (
        ?, ?, ?, ?, ?, ?, ?, ?, ?
    )
)";

// Upsert (REPLACE) of the stored position for one (ClusterName, TabletIndex).
const TString UPDATE_INDEX = R"(
    REPLACE INTO indices (
        ClusterName,
        TabletIndex,
        RowIndex
    ) VALUES (
        ?, ?, ?
    )
)";

// Parameter: cutoff in ms; rows with Timestamp <= cutoff are removed.
const TString DELETE_OLD_ROWS_BY_TIMESTAMP = R"(
    DELETE
    FROM queue
    WHERE Timestamp <= ?
)";

// Parameter: "now" in ms; rows without an ExpireTimestamp (NULL) are kept.
const TString DELETE_OLD_ROWS_BY_EXPIRETIMESTAMP = R"(
    DELETE
    FROM queue
    WHERE ExpireTimestamp IS NOT NULL AND ExpireTimestamp <= ?
)";

// Parameters: ClusterName, TabletIndex, exclusive upper bound on RowIndex.
const TString TRIM_ROWS = R"(
    DELETE
    FROM queue
    WHERE ClusterName = ? AND
        TabletIndex = ? AND
        RowIndex < ?
)";

const TString SELECT_DATA_VERSION = R"(
    SELECT ConfigVersion FROM data_version
)";


// Parameter: new data version.
const TString UPDATE_DATA_VERSION = R"(
    UPDATE data_version SET ConfigVersion = ?
)";

const TString CLEAR_ALL_ROWS = R"(
    DELETE FROM queue
)";

const TString CLEAR_ALL_INDECIES = R"(
    DELETE FROM indices
)";

} // anonymous namespace
} // namespace NQuery

#define CACHE_LOG LogPrefix_

// Exports the cache's monitoring counters into `ct`, in a fixed order.
void TYtQueueReaderCache::TCounters::QueryCounters(NMonitor::TCounterTable* ct) const {
    auto& table = *ct;
    table.insert(MAKE_COUNTER_PAIR(NCacheBroken));
    table.insert(MAKE_COUNTER_PAIR(NWriteQueueMessages));
    table.insert(MAKE_COUNTER_PAIR(NTransactions));
    table.insert(MAKE_COUNTER_PAIR(NReadItems));
}

// Opens (and migrates) the SQLite cache named in the config.
// The cache ends up disabled (Working_ never set) when no file name is configured
// or when opening/migrating the database throws an NSQLite error; the latter
// also raises the NCacheBroken gauge.
TYtQueueReaderCache::TYtQueueReaderCache(const NTravelProto::NAppConfig::TConfigYtQueueReader& config)
    : LogPrefix_(config.GetCacheFileName() + ": ")
    , MaxWriteQueueItems_(config.GetCacheMaxWriteQueueItems())
    , MinBatchSize_(config.GetCacheMinBatchSizeMessages())
    , MaxBatchSize_(config.GetCacheMaxBatchSizeMessages())
    , DataVersion_(config.GetCacheDataVersion())
{
    const auto& dbPath = config.GetCacheFileName();
    if (dbPath.Empty()) {
        DEBUG_LOG << CACHE_LOG << config.GetTablePath() << ": Database filepath is empty, cache disabled" << Endl;
        return;
    }
    try {
        // Open and migrate through a local owner so Db_ only ever receives a fully initialized database.
        THolder<NSQLite::TDatabase> database{new NSQLite::TDatabase(dbPath)};
        InitDb(database.Get());
        Db_.Reset(database.Release());
    } catch (const NSQLite::TExceptionError&) {
        ERROR_LOG << CACHE_LOG << "Failed to open database, cache disabled: " << CurrentExceptionMessage() << Endl;
        Counters_.NCacheBroken = 1;
        return;
    }
    INFO_LOG << CACHE_LOG << "Database " << dbPath << " is successfully opened, cache enabled" << Endl;
    Working_.Set();
}

// Stops and joins the background writer thread before the members it uses are
// destroyed. Without this, a cache destroyed without an explicit Stop() would
// leave WriteThreadLoop() running against a dead object.
// NOTE(review): assumes the IThread destructor does not join by itself.
// Stop() is a no-op when the thread was never started or has already been stopped.
TYtQueueReaderCache::~TYtQueueReaderCache() {
    Stop();
}

// Brings the database schema up to date. It takes the exclusive locking mode,
// reads the stored schema version and applies each remaining NQuery::MIGRATIONS
// step in its own transaction, together with the matching version bump.
// NSQLite errors propagate to the caller.
void TYtQueueReaderCache::InitDb(NSQLite::TDatabase* db) {
    db->ApplyPragma("locking_mode = EXCLUSIVE");
    const int latestVersion = static_cast<int>(Y_ARRAY_SIZE(NQuery::MIGRATIONS));
    int version = GetDbVersion(db);
    while (version < latestVersion) {
        const int nextVersion = version + 1;
        INFO_LOG << CACHE_LOG << "Migrating DB from " << version << " to " << nextVersion << Endl;
        auto tr = db->Transaction();
        for (const TString& statement : NQuery::MIGRATIONS[version]) {
            DEBUG_LOG << CACHE_LOG << "Query: ------" << Endl << statement << Endl << "-------------" << Endl;
            tr->Exec(statement);
        }
        DEBUG_LOG << CACHE_LOG << "Set DB version to " << nextVersion << Endl;
        tr->Exec(NQuery::UPDATE_VERSION, {static_cast<i64>(nextVersion)});
        tr->Commit();
        version = nextVersion;
        INFO_LOG << CACHE_LOG << "Migrated DB to " << version << Endl;
    }
    INFO_LOG << CACHE_LOG << "Will work with DB structure version " << version << Endl;
}

// Returns the schema version stored in the `version` table.
// Returns 0 when the table has no row, or when querying it fails with an NSQLite
// error (presumably the case for a fresh database without the table).
// The read-only transaction is never committed.
int TYtQueueReaderCache::GetDbVersion(NSQLite::TDatabase* db) {
    auto readTx = db->Transaction();
    try {
        TVector<NSQLite::TQueryParams> columns;
        auto cursor = readTx->Exec(NQuery::SELECT_VERSION);
        if (!cursor->Fetch(&columns)) {
            WARNING_LOG << CACHE_LOG << "DB Version not specified. assume 0" << Endl;
            return 0;
        }
        return NSQLite::NumericCast<int>(columns[0]);
    } catch (const NSQLite::TExceptionError&) {
        WARNING_LOG << CACHE_LOG << "Failed to determine DB version, assume 0" << Endl;
        return 0;
    }
}

// Publishes this cache's counters to the monitoring `source` under `name`.
void TYtQueueReaderCache::RegisterCounters(NMonitor::TCounterSource& source, const TString& name) {
    auto* const counters = &Counters_;
    source.RegisterSource(counters, name);
}

// Validates the cached data version and launches the background writer thread.
// Does nothing when the cache is disabled.
// The data-version check, which wipes queue and indices on a mismatch, runs
// BEFORE the writer starts. Otherwise batches queued via Write() before Start()
// could be persisted and then erased by that clear. It also means a throwing
// check leaves no thread behind.
void TYtQueueReaderCache::Start() {
    if (!Working_) {
        return;
    }
    CheckDataVersion();
    if (!Working_) {
        // The check hit a DB error and disabled the cache; a writer would have nothing to do.
        return;
    }
    with_lock (WriteLock_) {
        WriteThread_ = SystemThreadFactory()->Run([this](){ WriteThreadLoop(); });
    }
}

void TYtQueueReaderCache::Stop() {
    TAutoPtr<IThreadFactory::IThread> thr;
    with_lock (WriteLock_) {
        if (!WriteThread_) {
            return;
        }
        thr = WriteThread_.Release();
    }
    StopFlag_.Set();
    WakeUpEvent_.Signal();
    thr->Join();
}

// Removes cached rows whose Timestamp is at or before `now - maxAge`. When
// `checkExpireTimestamp` is set, it also removes rows whose ExpireTimestamp is
// at or before `now`. Both deletes run in one transaction.
void TYtQueueReaderCache::DeleteOldRows(TInstant now, TDuration maxAge, bool checkExpireTimestamp) {
    WithTransaction("DeleteOldRows", [this, checkExpireTimestamp, maxAge, now](NSQLite::TTransaction& tr) {
        TProfileTimer timer;
        const i64 ageCutoffMs = i64((now - maxAge).MilliSeconds());
        const size_t deletedByAge = tr.Exec(NQuery::DELETE_OLD_ROWS_BY_TIMESTAMP, {ageCutoffMs})->Changes();
        DEBUG_LOG << CACHE_LOG << "Deleted " << deletedByAge << " rows by Timestamp from 'queue' table in " << timer.Step() << Endl;
        if (!checkExpireTimestamp) {
            return;
        }
        const i64 nowMs = i64(now.MilliSeconds());
        const size_t deletedByExpiry = tr.Exec(NQuery::DELETE_OLD_ROWS_BY_EXPIRETIMESTAMP, {nowMs})->Changes();
        DEBUG_LOG << CACHE_LOG << "Deleted " << deletedByExpiry << " rows by ExpireTimestamp from 'queue' table in " << timer.Step() << Endl;
    });
}

void TYtQueueReaderCache::Trim(const TYtQueueMessageOrigin& till) {
    WithTransaction("Trim", [this, &till](NSQLite::TTransaction& tr) {
        TProfileTimer started;
        size_t numberOfChanges = tr.Exec(NQuery::TRIM_ROWS, {till.ClusterName, i64(till.TabletIndex), till.RowIndex + 1})->Changes();
        DEBUG_LOG << CACHE_LOG << "Trimmed " << numberOfChanges << " rows from 'queue' table in " << started.Get() << Endl;
    });
}

// Feeds `cb` with every cached message whose Timestamp is strictly after `since`,
// oldest first. Iteration stops as soon as `cb` returns false.
void TYtQueueReaderCache::ReadRows(TInstant since, std::function<bool(const TYtQueueMessagePacked& message)> cb) {
    WithTransaction("ReadRows", [this, &cb, &since](NSQLite::TTransaction& tr) {
        DEBUG_LOG << CACHE_LOG << "Reading rows since " << since << Endl;
        auto cursor = tr.Exec(NQuery::SELECT_BUS_ROW, {i64(since.MilliSeconds())});
        TVector<NSQLite::TQueryParams> row;
        size_t rowsRead = 0;
        bool keepGoing = true;
        // Short-circuit: no further Fetch() once the callback asked to stop.
        while (keepGoing && cursor->Fetch(&row)) {
            ++rowsRead;
            Counters_.NReadItems.Inc();
            // Column positions follow NQuery::SELECT_BUS_ROW.
            TYtQueueMessagePacked message;
            message.Timestamp = TInstant::MilliSeconds(NSQLite::NumericCast<ui64>(row[3]));
            message.MessageId = ToString(row[4]);
            message.MessageType = ToString(row[5]);
            message.Codec = EMessageCodec(NSQLite::NumericCast<std::underlying_type_t<EMessageCodec>>(row[6]));
            message.BytesPacked = ToString(row[7]);
            // ExpireTimestamp is nullable: NULL maps to a default (zero) TInstant.
            message.ExpireTimestamp = std::holds_alternative<NSQLite::TNone>(row[8])
                ? TInstant()
                : TInstant::MilliSeconds(NSQLite::NumericCast<ui64>(row[8]));
            message.Origin.ClusterName = ToString(row[0]);
            message.Origin.TabletIndex = NSQLite::NumericCast<int>(row[1]);
            message.Origin.RowIndex = NSQLite::NumericCast<i64>(row[2]);
            keepGoing = cb(message);
        }
        INFO_LOG << CACHE_LOG << "Successfully read " << rowsRead << " rows" << Endl;
    });
}

// Passes every stored (ClusterName, TabletIndex, RowIndex) position from the
// `indices` table to `cb`.
void TYtQueueReaderCache::ReadIndicies(std::function<void(const TYtQueueMessageOrigin& origin)> cb) {
    WithTransaction("ReadIndicies", [this, &cb](NSQLite::TTransaction& tr) {
        TVector<NSQLite::TQueryParams> row;
        size_t fetched = 0;
        for (auto cursor = tr.Exec(NQuery::SELECT_TABLET_INDEX); cursor->Fetch(&row); ++fetched) {
            TYtQueueMessageOrigin origin;
            origin.ClusterName = ToString(row[0]);
            origin.TabletIndex = NSQLite::NumericCast<int>(row[1]);
            origin.RowIndex = NSQLite::NumericCast<i64>(row[2]);
            cb(origin);
        }
        INFO_LOG << CACHE_LOG << "Successfully got " << fetched << " indeces from DB" << Endl;
    });
}

// Moves `messages` into the open write batch (WriteItem_). It also records
// origin.RowIndex as the position to store for (origin.ClusterName,
// origin.TabletIndex) when the batch is written.
// When the batch grows past MaxBatchSize_ it is sealed into WriteQueue_. If the
// queue then exceeds MaxWriteQueueItems_, the cache is disabled.
// No-op while the cache is disabled.
void TYtQueueReaderCache::Write(TVector<TYtQueueMessagePacked>&& messages, const TYtQueueMessageOrigin& origin) {
    if (!Working_) {
        return;
    }
    bool overflow = false;
    with_lock (WriteLock_) {
        // Re-check under the lock. OnCacheBroken() clears Working_ and only then empties
        // the batches under WriteLock_. Without this check, a Write() racing with it could
        // re-populate WriteItem_ and NWriteQueueMessages after that cleanup, leaving
        // a batch and a gauge that nothing would ever persist or reset.
        if (!Working_) {
            return;
        }
        if (!WriteItem_) {
            WriteItem_ = new TWriteItem;
        }
        WriteItem_->Messages.reserve(WriteItem_->Messages.size() + messages.size());
        for (auto& msg: messages) {
            WriteItem_->Messages.push_back(std::move(msg));
        }
        // Moving the elements out does not shrink `messages`, so its size is still valid.
        Counters_.NWriteQueueMessages += messages.size();
        const size_t cnt = WriteItem_->Messages.size();
        WriteItem_->Indecies[TOriginKey{origin.ClusterName, origin.TabletIndex}] = origin.RowIndex;
        if (cnt > MaxBatchSize_) {
            // Seal the batch and hand it over to the writer thread.
            WriteQueue_.push_back(WriteItem_);
            WakeUpEvent_.Signal();
            WriteItem_ = nullptr;
            if (WriteQueue_.size() > MaxWriteQueueItems_) {
                ERROR_LOG << CACHE_LOG << "Write queue items overflow (" << WriteQueue_.size() << " items) -> disable cache" << Endl;
                overflow = true;
            }
        } else if (cnt >= MinBatchSize_ && WriteQueue_.empty()) {
            // The writer only picks up the open batch when no sealed batches are pending.
            WakeUpEvent_.Signal();
        }
    }
    // Called outside WriteLock_: OnCacheBroken() takes that lock itself.
    if (overflow) {
        OnCacheBroken();
    }
}

void TYtQueueReaderCache::CheckDataVersion() {
    WithTransaction("CheckDataVersion", [this](NSQLite::TTransaction& tr) {
        auto q = tr.Exec(NQuery::SELECT_DATA_VERSION);
        TVector<NSQLite::TQueryParams> row;
        if (!q->Fetch(&row)) {
            throw yexception() << "No data version in DB, this cannot be so";
        }
        i64 actualVersion = NSQLite::NumericCast<i64>(row[0]);
        if (DataVersion_ == actualVersion) {
            INFO_LOG << CACHE_LOG << "Data version is OK: " << DataVersion_ << Endl;
            return;
        }
        WARNING_LOG << CACHE_LOG << "Data version DOES NOT MATCH: expected " << DataVersion_ << ", got " << actualVersion << ". Clearing the cache! " << Endl;
        tr.Exec(NQuery::UPDATE_DATA_VERSION, {DataVersion_});
        tr.Exec(NQuery::CLEAR_ALL_ROWS, {});
        tr.Exec(NQuery::CLEAR_ALL_INDECIES, {});
    });
}

// Runs `func` inside one SQLite transaction on Db_, serialized by DbLock_, then
// commits. Does nothing once the database has been closed (cache disabled).
// NSQLite errors are logged and disable the cache through OnCacheBroken(), which
// is called after DbLock_ is released. Any other exception thrown by `func`
// propagates to the caller without a commit.
void TYtQueueReaderCache::WithTransaction(const TString& name, std::function<void(NSQLite::TTransaction&)> func) {
    bool broken = false;
    with_lock (DbLock_) {
        if (Db_) {
            try {
                auto transaction = Db_->Transaction();
                Counters_.NTransactions.Inc();
                func(*transaction.Get());
                transaction->Commit();
            } catch (const NSQLite::TExceptionError&) {
                ERROR_LOG << CACHE_LOG << "Cache operation exception during '" << name << "', cache disabled: " << CurrentExceptionMessage() << Endl;
                broken = true;
            }
        }
    }
    if (broken) {
        OnCacheBroken();
    }
}

// Body of the background writer thread launched by Start().
// Each time it is woken, it drains WriteQueue_. Once the queue is empty, it also
// takes the open batch WriteItem_ if that batch has reached MinBatchSize_. Every
// taken batch is written in a single transaction: its queue rows plus its index
// updates. The loop exits after a wake-up once StopFlag_ is set.
void TYtQueueReaderCache::WriteThreadLoop() {
    while (!StopFlag_) {
        WakeUpEvent_.WaitI();
        while (true) {
            TWriteItemRef item;
            with_lock (WriteLock_) {
                if (WriteQueue_.empty()) {
                    if (WriteItem_ && WriteItem_->Messages.size() >= MinBatchSize_) {
                        item.Swap(WriteItem_);
                    }
                } else {
                    item = WriteQueue_.front();
                    WriteQueue_.pop_front();
                }
            }
            if (!item) {
                break;
            }
            TProfileTimer started;
            WithTransaction("Write", [item](NSQLite::TTransaction& tr) {
                for (const auto& msg: item->Messages) {
                    TVector<NSQLite::TQueryParams> params = {
                        msg.Origin.ClusterName, i64(msg.Origin.TabletIndex), i64(msg.Origin.RowIndex),
                        i64(msg.Timestamp.MilliSeconds()), msg.MessageType, i64(msg.Codec),
                        NSQLite::TBlob{msg.BytesPacked}, msg.MessageId,
                    };
                    // ExpireTimestamp is nullable: a default (zero) TInstant is stored as NULL.
                    if (msg.ExpireTimestamp) {
                        params.push_back(i64(msg.ExpireTimestamp.MilliSeconds()));
                    } else {
                        params.push_back(NSQLite::TNone());
                    }
                    tr.Exec(NQuery::APPEND_BUS_ROW, params);
                }
                for (const auto& [key, rowIndex] : item->Indecies) {
                    tr.Exec(NQuery::UPDATE_INDEX, {key.first, i64(key.second), i64(rowIndex)});
                }
            });
            with_lock (WriteLock_) {
                // OnCacheBroken() clears Working_ and then resets this gauge to 0 under
                // WriteLock_. That happens when the transaction above failed, or
                // concurrently from Write(). Subtracting this batch after such a reset
                // would drive the gauge negative, so only account for it while the
                // cache is alive.
                if (Working_) {
                    Counters_.NWriteQueueMessages -= item->Messages.size();
                }
            }
            DEBUG_LOG << CACHE_LOG << "Written " << item->Messages.size() << " messages and "
                      << item->Indecies.size() << " indecies in " << started.Get() << Endl;
        }
    }
}

// Permanently switches the cache off after an unrecoverable error.
// It closes the database, so WithTransaction() turns into a no-op, and raises
// the NCacheBroken gauge. It then clears Working_ (Write() and Start() return
// early on it) and discards every batch not yet written.
// DbLock_ and WriteLock_ are taken one after the other, never nested. Callers in
// this file invoke it only after releasing their own lock.
// NOTE(review): Working_ is cleared before WriteLock_ is taken; keep this order if
// other code starts reading Working_ under WriteLock_.
void TYtQueueReaderCache::OnCacheBroken() {
    with_lock (DbLock_) {
        Db_.Reset();
    }
    Counters_.NCacheBroken = 1;
    Working_.Clear();
    with_lock (WriteLock_) {
        WriteItem_.Reset();
        WriteQueue_.clear();
        // Nothing is pending any more, so the pending-messages gauge drops to zero.
        Counters_.NWriteQueueMessages = 0;
    }
}

}// namespace NTravel
