#include "cache_invalidation_service.h"

#include <travel/hotels/lib/cpp/util/sizes.h>

#include <library/cpp/logger/global/global.h>

#include <util/digest/multi.h>

namespace NTravel {
    void TCacheInvalidationService::TCounters::QueryCounters(NMonitor::TCounterTable* ct) const {
        // Publishes every counter of this group into the monitoring table supplied by the caller.
        auto& table = *ct;

        // Rule ingestion / filtering.
        table.insert(MAKE_COUNTER_PAIR(NInvalidInvalidationRules));
        table.insert(MAKE_COUNTER_PAIR(NInvalidationRulesWithUnknownSource));
        table.insert(MAKE_COUNTER_PAIR(NInvalidationRulesWithBannedSource));

        // Current storage footprint.
        table.insert(MAKE_COUNTER_PAIR(NDataBytes));
        table.insert(MAKE_COUNTER_PAIR(NInvalidationRules));

        // Limit violations.
        table.insert(MAKE_COUNTER_PAIR(NHotelsWithRulesSizeSoftLimitViolation));
        table.insert(MAKE_COUNTER_PAIR(NHotelsWithRulesSizeHardLimitViolation));
        table.insert(MAKE_COUNTER_PAIR(MemorySoftLimitViolation));
        table.insert(MAKE_COUNTER_PAIR(MemoryHardLimitViolation));
    }

    TCacheInvalidationService::TCacheInvalidationService(const NTravelProto::NAppConfig::TConfigCacheInvalidationService& config)
        : Enabled(config.GetEnabled())
        , CleanupDelay(TDuration::Seconds(config.GetCleanupDelaySec()))
        , InvalidationEventLifetime(TDuration::Seconds(config.GetInvalidationEventLifetimeSec()))
        , DropRecordsWithUnknownInvalidationSource(config.GetDropRecordsWithUnknownInvalidationSource())
        , RulesSizeSoftLimit(config.GetRulesSizeSoftLimit())
        , RulesSizeHardLimit(config.GetRulesSizeHardLimit())
        , MemorySoftLimitBytes(config.GetMemorySoftLimitBytes())
        , MemoryHardLimitBytes(config.GetMemoryHardLimitBytes())
    {
        // Copy the repeated "banned sources" config field into a lookup set; events whose
        // source is in this set are rejected by ProcessCacheInvalidationMessage.
        const size_t bannedCount = config.BannedInvalidationSourcesSize();
        for (size_t idx = 0; idx != bannedCount; ++idx) {
            BannedInvalidationSources.insert(config.GetBannedInvalidationSources(idx));
        }
    }

    void TCacheInvalidationService::RegisterCounters(NMonitor::TCounterSource& counters) {
        // Exposes this service's counter group under the "CacheInvalidationService" name.
        auto* ownCounters = &Counters;
        counters.RegisterSource(ownCounters, "CacheInvalidationService");
    }

    void TCacheInvalidationService::Start() {
        // Launches the background cleanup thread exactly once; does nothing when the service is disabled.
        if (!Enabled) {
            return;
        }
        if (!Started.TrySet()) {
            // Already started earlier: keep the existing thread.
            WARNING_LOG << "Duplicate call of TCacheInvalidationService::Start() is ignored" << Endl;
            return;
        }
        // The thread captures `this`; Stop() joins it before releasing CleanupThread.
        CleanupThread = SystemThreadFactory()->Run([this]() { RunCleanup(); });
    }

    void TCacheInvalidationService::Stop() {
        // Stops the cleanup thread started by Start(). Only the first successful call does any work.
        if (!Enabled) {
            return;
        }
        if (!Started) {
            WARNING_LOG << "Call of TCacheInvalidationService::Stop() before Start()" << Endl;
            return;
        }
        // Setting Stopping makes RunCleanup leave its loop; signalling StopEvent wakes it
        // from the CleanupDelay wait so Join() does not block for a whole period.
        if (Stopping.TrySet() && CleanupThread) {
            StopEvent.Signal();
            CleanupThread->Join();
            CleanupThread = nullptr;
        }
    }

    bool TCacheInvalidationService::IsReady() const {
        // Readiness does not depend on any state: with no rules recorded, lookups simply
        // report TInstant::Zero(), so the service can answer from the start.
        constexpr bool alwaysReady = true;
        return alwaysReady;
    }

    bool TCacheInvalidationService::ProcessCacheInvalidationMessage(const NTravelProto::NOfferBus::TOfferInvalidationMessage& cacheInvalidationMessage) {
        // Records every filter of the message's event as an invalidation rule for the event's
        // (hotel, currency) key. Always returns true. The following are counted, logged and
        // skipped rather than reported to the caller:
        //   - events from a banned source, or from an unknown source when configured to drop those;
        //   - filters that are empty;
        //   - filters that would exceed the per-hotel or memory hard limit.
        if (!Enabled) {
            return true;
        }
        const auto& event = cacheInvalidationMessage.GetEvent();
        if (event.GetOfferInvalidationSource() == NTravelProto::NOfferInvalidation::EOfferInvalidationSource::OIS_UNKNOWN) {
            Counters.NInvalidationRulesWithUnknownSource.Inc();
            WARNING_LOG << "Found invalidation event with unknown invalidation source. HotelId: " << event.GetHotelId() << Endl;
            if (DropRecordsWithUnknownInvalidationSource) {
                return true;
            }
        }
        if (BannedInvalidationSources.contains(event.GetOfferInvalidationSource())) {
            Counters.NInvalidationRulesWithBannedSource.Inc();
            WARNING_LOG << "Found invalidation event with banned invalidation source (" << event.GetOfferInvalidationSource() << "). HotelId: " << event.GetHotelId() << Endl;
            return true;
        }

        // An empty date string leaves that side of the filter unbounded.
        const auto parseOptionalDate = [](const auto& date) -> TMaybe<NOrdinalDate::TOrdinalDate> {
            return date.Empty() ? TMaybe<NOrdinalDate::TOrdinalDate>() : NOrdinalDate::FromString(date);
        };
        // Every filter of the event shares the event's timestamp.
        const TInstant eventTimestamp = TInstant::Seconds(event.GetTimestamp().seconds());

        TPreKey preKey{THotelId::FromProto(event.GetHotelId()), event.GetCurrency()};
        auto& bucket = Invalidations.GetBucketForKey(preKey);
        TWriteGuard g(bucket.GetMutex());
        auto [it, inserted] = bucket.GetMap().emplace(preKey, TInvalidationInfo());
        if (inserted) {
            Counters.NDataBytes += TTotalByteSize<decltype(Invalidations)::TActualMap::value_type>()(*it);
        }
        auto& invalidations = it->second;
        for (const auto& filter : event.GetFilters()) {
            if (!filter.HasCheckInCheckOutFilter() && !filter.HasTargetIntervalFilter()) {
                Counters.NInvalidInvalidationRules.Inc();
                ERROR_LOG << "Found invalidation event with empty filter. HotelId: " << event.GetHotelId() << Endl;
                continue;
            }
            if (invalidations.GetRulesCount() >= RulesSizeHardLimit) {
                // TODO (mpivko): log this case properly, now it creates too many log records, so it's debug level (or get rid of this case?)
                DEBUG_LOG << "Hotel already has " << RulesSizeHardLimit << " invalidation rules, dropping new rule. HotelId: " << event.GetHotelId() << Endl;
                continue;
            }
            const auto dataBytes = static_cast<size_t>(Counters.NDataBytes);
            Counters.MemorySoftLimitViolation = dataBytes >= MemorySoftLimitBytes;
            Counters.MemoryHardLimitViolation = dataBytes >= MemoryHardLimitBytes;
            if (dataBytes >= MemoryHardLimitBytes) {
                ERROR_LOG << "Memory limit violation, dropping new rule. HotelId: " << event.GetHotelId() << Endl;
                // Actually drop the rule, as logged. Without this, execution fell through and the
                // rule was stored anyway, so the memory hard limit was never enforced.
                continue;
            }
            const size_t oldAllocSize = invalidations.GetAllocSize();
            bool ruleAdded = false;
            if (filter.HasCheckInCheckOutFilter()) {
                const auto& currFilter = filter.GetCheckInCheckOutFilter();
                ruleAdded = invalidations.AddRule(TCacheInvalidationRuleByCheckInOut(
                    parseOptionalDate(currFilter.GetCheckInDate()),
                    parseOptionalDate(currFilter.GetCheckOutDate()),
                    eventTimestamp));
            } else {
                // The empty-filter check above guarantees a target-interval filter here.
                const auto& currFilter = filter.GetTargetIntervalFilter();
                ruleAdded = invalidations.AddRule(TCacheInvalidationRuleByTargetDates(
                    parseOptionalDate(currFilter.GetDateFromInclusive()),
                    parseOptionalDate(currFilter.GetDateToInclusive()),
                    eventTimestamp));
            }
            // Compute the size delta in signed arithmetic: subtracting a size_t would make the
            // expression unsigned and rely on wrap-around if the allocation ever shrank.
            Counters.NDataBytes += static_cast<i64>(invalidations.GetAllocSize()) - static_cast<i64>(oldAllocSize);
            if (!ruleAdded) {
                // A rule with the same key already existed (AddRule may have refreshed it).
                // The rule count did not change, so no limit was crossed by this filter. Without
                // this guard, a hotel sitting exactly at the soft limit was counted again on every refresh.
                continue;
            }
            Counters.NInvalidationRules += 1;
            const size_t rulesCount = invalidations.GetRulesCount();
            if (rulesCount == RulesSizeSoftLimit) {
                WARNING_LOG << "Hotel already has " << RulesSizeSoftLimit << " invalidation rules. HotelId: " << event.GetHotelId() << Endl;
                Counters.NHotelsWithRulesSizeSoftLimitViolation.Inc();
            }
            if (rulesCount == RulesSizeHardLimit) {
                ERROR_LOG << "Hotel already has " << RulesSizeHardLimit << " invalidation rules, will drop further rules. HotelId: " << event.GetHotelId() << Endl;
                Counters.NHotelsWithRulesSizeHardLimitViolation.Inc();
            }
        }
        return true;
    }
    bool TCacheInvalidationService::IsInvalidated(const TPreKey& preKey, NOrdinalDate::TOrdinalDate checkIn, NOrdinalDate::TOrdinalDate checkOut, TInstant timestamp) const {
        // An offer produced at `timestamp` is stale if some matching rule is at least as recent.
        // The explicit Enabled check matters: a disabled service yields TInstant::Zero(),
        // which would otherwise compare >= a zero timestamp and report it as invalidated.
        return Enabled && GetInvalidationTimestamp(preKey, checkIn, checkOut) >= timestamp;
    }

    TInstant TCacheInvalidationService::GetInvalidationTimestamp(const TPreKey& preKey, NOrdinalDate::TOrdinalDate checkIn, NOrdinalDate::TOrdinalDate checkOut) const {
        // Latest invalidation time recorded for `preKey` that applies to the given stay dates.
        // Returns TInstant::Zero() when the service is disabled or nothing is recorded for the key.
        if (Enabled) {
            const auto& bucket = Invalidations.GetBucketForKey(preKey);
            TReadGuard guard(bucket.GetMutex());
            const auto& rulesByKey = bucket.GetMap();
            if (const auto found = rulesByKey.find(preKey); found != rulesByKey.end()) {
                return found->second.GetInvalidationTimestamp(checkIn, checkOut);
            }
        }
        return TInstant::Zero();
    }

    void TCacheInvalidationService::RunCleanup() {
        while (!Stopping) {
            StopEvent.WaitT(CleanupDelay);
            auto now = Now();
            for (auto& bucket : Invalidations.Buckets) {
                if (Stopping) {
                    break;
                }
                bool needCleanup = false;
                {
                    TReadGuard g(bucket.GetMutex());
                    for (const auto& [preKey, invalidationInfo] : bucket.GetMap()) {
                        for (const auto& [key, rule] : invalidationInfo.RulesByCheckInCheckOut) {
                            if (!IsAliveRule(&rule, now)) {
                                needCleanup = true;
                                break;
                            }
                        }
                        for (const auto& [key, rule] : invalidationInfo.RulesByTargetDates) {
                            if (!IsAliveRule(&rule, now)) {
                                needCleanup = true;
                                break;
                            }
                        }
                        if (needCleanup) {
                            break;
                        }
                    }
                }
                if (needCleanup) {
                    TWriteGuard g(bucket.GetMutex());
                    for (auto& [preKey, invalidationInfo] : bucket.GetMap()) {
                        Counters.NInvalidationRules -= invalidationInfo.GetRulesCount();
                        Counters.NDataBytes -= invalidationInfo.GetAllocSize();
                        auto hasSoftLimit = invalidationInfo.GetRulesCount() >= RulesSizeSoftLimit;
                        auto hasHardLimit = invalidationInfo.GetRulesCount() >= RulesSizeHardLimit;
                        TInvalidationInfo oldInvalidationInfo;
                        std::swap(oldInvalidationInfo, invalidationInfo);
                        for (auto& [key, rule] : oldInvalidationInfo.RulesByCheckInCheckOut) {
                            if (IsAliveRule(&rule, now)) {
                                invalidationInfo.AddRule(rule);
                            }
                        }
                        for (auto& [key, rule] : oldInvalidationInfo.RulesByTargetDates) {
                            if (IsAliveRule(&rule, now)) {
                                invalidationInfo.AddRule(rule);
                            }
                        }
                        if (hasSoftLimit && invalidationInfo.GetRulesCount() < RulesSizeSoftLimit) {
                            Counters.NHotelsWithRulesSizeSoftLimitViolation.Dec();
                        }
                        if (hasHardLimit && invalidationInfo.GetRulesCount() < RulesSizeHardLimit) {
                            Counters.NHotelsWithRulesSizeHardLimitViolation.Dec();
                        }
                        Counters.NInvalidationRules += invalidationInfo.GetRulesCount();
                        Counters.NDataBytes += invalidationInfo.GetAllocSize();
                    }
                }
            }
            Counters.MemorySoftLimitViolation = static_cast<size_t>(Counters.NDataBytes) >= MemorySoftLimitBytes;
            Counters.MemoryHardLimitViolation = static_cast<size_t>(Counters.NDataBytes) >= MemoryHardLimitBytes;
        }
    }

    bool TCacheInvalidationService::IsAliveRule(const TCacheInvalidationRuleBase* rule, TInstant now) const {
        // A rule stays in effect for InvalidationEventLifetime after its own Timestamp.
        const auto expiresAt = rule->Timestamp + InvalidationEventLifetime;
        return now < expiresAt;
    }

    namespace {
        // Stores `rule` under its own key (rule.GetKey()).
        // If a rule with that key is already present, the stored rule is replaced only when
        // `rule` carries a strictly newer Timestamp.
        // Returns true only when a new key was inserted, i.e. when the map grew by one rule.
        template <typename TRulesMap, typename TRule>
        bool AddNewestRule(TRulesMap& rules, TRule& rule) {
            auto key = rule.GetKey();
            if (auto it = rules.find(key); it != rules.end()) {
                // Assign through the iterator instead of a second `rules[key]` lookup.
                if (it->second.Timestamp < rule.Timestamp) {
                    it->second = std::move(rule);
                }
                return false;
            }
            rules[key] = std::move(rule);
            return true;
        }
    }

    bool TCacheInvalidationService::TInvalidationInfo::AddRule(TCacheInvalidationRuleByTargetDates rule) {
        // Keeps only the newest rule per target-dates key; returns true if a new key was added.
        return AddNewestRule(RulesByTargetDates, rule);
    }

    bool TCacheInvalidationService::TInvalidationInfo::AddRule(TCacheInvalidationRuleByCheckInOut rule) {
        // Keeps only the newest rule per (check-in, check-out) key; returns true if a new key was added.
        return AddNewestRule(RulesByCheckInCheckOut, rule);
    }

    size_t TCacheInvalidationService::TInvalidationInfo::GetRulesCount() const {
        // Total number of stored rules across both rule kinds.
        const size_t checkInOutRules = RulesByCheckInCheckOut.size();
        return checkInOutRules + RulesByTargetDates.size();
    }

    size_t TCacheInvalidationService::TInvalidationInfo::GetAllocSize() const {
        // Heap bytes owned by the two rule maps, as reported by GetHashMapByteSizeWithoutElementAllocations.
        const auto targetDatesBytes = GetHashMapByteSizeWithoutElementAllocations(RulesByTargetDates);
        const auto checkInOutBytes = GetHashMapByteSizeWithoutElementAllocations(RulesByCheckInCheckOut);
        return targetDatesBytes + checkInOutBytes;
    }

    size_t TCacheInvalidationService::TInvalidationInfo::CalcTotalByteSize() const {
        return sizeof(TCacheInvalidationService::TInvalidationInfo) + GetAllocSize();
    }

    TInstant TCacheInvalidationService::TInvalidationInfo::GetInvalidationTimestamp(NOrdinalDate::TOrdinalDate checkIn, NOrdinalDate::TOrdinalDate checkOut) const {
        // Returns the latest timestamp among rules matching the stay, or TInstant::Zero() if none.
        using TOptionalDate = TMaybe<NOrdinalDate::TOrdinalDate>;
        TInstant latest = TInstant::Zero();

        // Target-dates rules decide their own applicability.
        for (const auto& targetDatesEntry : RulesByTargetDates) {
            latest = Max(latest, targetDatesEntry.second.GetInvalidationTimestamp(checkIn, checkOut));
        }

        // Check-in/check-out rules are looked up by exact key. A side left empty in the key
        // matches any date, so only these four key combinations can apply.
        const TOptionalDate exactCheckIn(checkIn);
        const TOptionalDate exactCheckOut(checkOut);
        const TOptionalDate anyDate;
        const auto mergeRule = [&](const TOptionalDate& in, const TOptionalDate& out) {
            const auto found = RulesByCheckInCheckOut.find(std::make_pair(in, out));
            if (found != RulesByCheckInCheckOut.end()) {
                latest = Max(latest, found->second.Timestamp);
            }
        };
        mergeRule(exactCheckIn, exactCheckOut);
        mergeRule(anyDate, exactCheckOut);
        mergeRule(exactCheckIn, anyDate);
        mergeRule(anyDate, anyDate);
        return latest;
    }
}
