#include "request_rate_monitor.h"

#include <library/cpp/logger/global/global.h>

namespace NCaptchaServer {
    // Reads the request-rate-monitor settings from the fallback section of `config`
    // and, if monitoring is enabled, starts the background checker thread.
    //
    // `fallbackStateAggregator` is stored by reference and notified whenever the
    // monitor decides to activate/deactivate the fallback; it must outlive this object.
    TCaptchaRequestRateMonitor::TCaptchaRequestRateMonitor(const TCaptchaConfig& config, TCaptchaFallbackStateAggregator& fallbackStateAggregator)
        : FallbackStateAggregator(fallbackStateAggregator)
        , CheckerThread(CheckLoop, this)
    {
        const auto& fallbackConfig = config.GetFallback();
        const auto& curConfig = fallbackConfig.GetRequestRateMonitor();
        Enabled = curConfig.GetEnabled();
        CheckPeriodMs = curConfig.GetCheckPeriodMs();
        FallbackActivationRateLimit = curConfig.GetFallbackActivationRateLimit();
        FallbackActivationChecksRequired = curConfig.GetFallbackActivationChecksRequired();
        FallbackDeactivationRateLimit = curConfig.GetFallbackDeactivationRateLimit();
        // A zero deactivation limit falls back to the activation limit.
        if (!FallbackDeactivationRateLimit) {
            FallbackDeactivationRateLimit = FallbackActivationRateLimit;
        }
        FallbackDeactivationChecksRequired = curConfig.GetFallbackDeactivationChecksRequired();

        // The monitor is turned off when the fallback is neither active nor loaded
        // (presumably there is nothing to switch to in that case — confirm).
        if (!fallbackConfig.GetActive() && !fallbackConfig.GetLoaded()) {
            Enabled = false;
        }

        // With a zero period CheckLoop would never sleep and would spin a CPU core,
        // running checks back-to-back over near-empty intervals.
        if (Enabled && !CheckPeriodMs) {
            WARNING_LOG << "Request rate monitor disabled: CheckPeriodMs must be positive" << Endl;
            Enabled = false;
        }

        // First check deadline is measured from construction time.
        LastCheck = Now();

        // The destructor stops/joins the thread only when Enabled is set, so the
        // thread must be started under exactly the same condition.
        if (Enabled) {
            CheckerThread.Start();
        }
    }

    // Stops the checker thread, if it was started, and waits for it to exit.
    TCaptchaRequestRateMonitor::~TCaptchaRequestRateMonitor() {
        // The constructor starts CheckerThread only when Enabled is true.
        if (!Enabled) {
            return;
        }
        // Wakes CheckLoop out of its timed wait; it returns on seeing the signal.
        StopEvent.Signal();
        CheckerThread.Join();
    }

    void* TCaptchaRequestRateMonitor::CheckLoop(void* ptr) {
        TCaptchaRequestRateMonitor* thisptr = reinterpret_cast<TCaptchaRequestRateMonitor*>(ptr);
        TDuration checkPeriod = TDuration::MilliSeconds(thisptr->CheckPeriodMs);
        while (true) {
            TInstant nextCheck = thisptr->LastCheck + checkPeriod;
            if (thisptr->StopEvent.WaitD(nextCheck)) {
                return nullptr;
            }

            auto requestsSinceLastCheck = AtomicSwap(&thisptr->RequestsCounter, 0);
            switch (thisptr->State) {
                case EState::Normal:
                    thisptr->CheckNormal(requestsSinceLastCheck);
                    break;
                case EState::ConfirmingFallbackActivation:
                    thisptr->CheckConfirmingFallbackActivation(requestsSinceLastCheck);
                    break;
                case EState::ConfirmingFallbackDeactivation:
                    thisptr->CheckConfirmingFallbackDeactivation(requestsSinceLastCheck);
                    break;
                case EState::FallbackActive:
                    thisptr->CheckFallbackActive(requestsSinceLastCheck);
                    break;
            };

            thisptr->LastCheck = nextCheck;
        }
    }

    // Normal state: a period at or above the activation limit starts the
    // activation confirmation, and that same period counts as its first check.
    //
    // @param requests number of requests seen during the last check period.
    void TCaptchaRequestRateMonitor::CheckNormal(TAtomicBase requests) {
        if (requests < FallbackActivationRateLimit) {
            return;
        }
        ActivationChecksCounter = 0;
        State = EState::ConfirmingFallbackActivation;
        CheckConfirmingFallbackActivation(requests);
    }

    // Confirming activation: a period below the limit returns to Normal.
    // Otherwise the period is counted. Once FallbackActivationChecksRequired such
    // periods have been seen, the fallback is switched on via the aggregator.
    //
    // @param requests number of requests seen during the last check period.
    void TCaptchaRequestRateMonitor::CheckConfirmingFallbackActivation(TAtomicBase requests) {
        if (requests < FallbackActivationRateLimit) {
            State = EState::Normal;
            return;
        }

        ++ActivationChecksCounter;
        if (ActivationChecksCounter < FallbackActivationChecksRequired) {
            return;
        }

        NOTICE_LOG << "High request rate detected (" << requests << " requests at last check)" << Endl;
        FallbackStateAggregator.SetFallbackState(ECaptchaFallbackActivator::RequestRateMonitor, true);
        State = EState::FallbackActive;
    }

    // Fallback active: a period below the deactivation limit starts the
    // deactivation confirmation, and that same period counts as its first check.
    //
    // @param requests number of requests seen during the last check period.
    void TCaptchaRequestRateMonitor::CheckFallbackActive(TAtomicBase requests) {
        if (requests >= FallbackDeactivationRateLimit) {
            return;
        }
        DeactivationChecksCounter = 0;
        State = EState::ConfirmingFallbackDeactivation;
        CheckConfirmingFallbackDeactivation(requests);
    }

    // Confirming deactivation: a period at or above the limit returns to
    // FallbackActive. Otherwise the period is counted. Once
    // FallbackDeactivationChecksRequired such periods have been seen, the fallback
    // is switched off via the aggregator.
    //
    // @param requests number of requests seen during the last check period.
    void TCaptchaRequestRateMonitor::CheckConfirmingFallbackDeactivation(TAtomicBase requests) {
        if (requests >= FallbackDeactivationRateLimit) {
            State = EState::FallbackActive;
            return;
        }

        ++DeactivationChecksCounter;
        if (DeactivationChecksCounter < FallbackDeactivationChecksRequired) {
            return;
        }

        NOTICE_LOG << "Request rate back to normal (" << requests << " requests at last check)" << Endl;
        FallbackStateAggregator.SetFallbackState(ECaptchaFallbackActivator::RequestRateMonitor, false);
        State = EState::Normal;
    }
}
