#include "processor.h"

#include <drive/backend/logging/logging.h>

#include <drive/library/cpp/network/data/data.h>
#include <drive/library/cpp/threading/eventlogger.h>

#include <library/cpp/regex/pcre/regexp.h>
#include <library/cpp/string_utils/base64/base64.h>
#include <library/cpp/uilangdetect/uilangdetect.h>

#include <rtline/library/geometry/coord.h>
#include <rtline/library/json/exception.h>
#include <rtline/util/events_rate_calcer.h>
#include <rtline/util/types/exception.h>

// Binds the processor to the reply context of the request it serves and copies
// per-handler options out of its config: the handler name, the
// Access-Control-Allow-Origin value used for reports, and RateLimitFreshness
// (how often CheckRequestsRateLimit re-reads the "rate_limit" setting).
IRequestProcessor::IRequestProcessor(const IRequestProcessorConfig& config, IReplyContext::TPtr context, const IServerBase* server, const TString& typeName)
    : TBase(server, typeName)
    , Context(context)
    , HandlerName(config.GetName())
    , AccessControlAllowOrigin(config.GetAccessControlAllowOrigin())
    , RateLimitFreshness(config.GetRateLimitFreshness())
{
}

// Returns a copy of the stored reply-context pointer for the request being processed.
IReplyContext::TPtr IRequestProcessor::GetContext() const {
    return this->Context;
}

// Process-wide request-rate meters keyed by handler type name; used by
// IRequestProcessor::CheckRequestsRateLimit. In the code visible here entries are
// only inserted, never erased, so raw pointers to the held meters stay valid.
// RateMeterMutex guards the map structure itself.
namespace {
    TRWMutex RateMeterMutex;
    TMap<TString, THolder<TEventRate<100000>>> RateMeters;
}

bool IRequestProcessor::CheckRequestsRateLimit() const {
    if (RateLimitRefresh.load(std::memory_order_acquire) + RateLimitFreshness.Seconds() < Now().Seconds()) {
        const ui32 rateLimit = BaseServer->GetSettings().GetHandlerValueDef<ui32>(TypeName, "rate_limit", BaseServer->GetSettings().GetHandlerValueDef<ui32>("default", "rate_limit", 0));
        RateLimit = rateLimit;
        RateLimitRefresh = Now().Seconds();
    }
    if (!RateLimit.load(std::memory_order_acquire)) {
        return true;
    }
    TEventRate<100000>* er = nullptr;
    {
        TReadGuard rg(RateMeterMutex);
        auto it = RateMeters.find(TypeName);
        if (it != RateMeters.end()) {
            er = it->second.Get();
        };
    }
    {
        TWriteGuard rg(RateMeterMutex);
        er = RateMeters.emplace(TypeName, MakeHolder<TEventRate<100000>>()).first->second.Get();
    }
    if (!er) {
        return true;
    }
    TInstant start;
    TInstant finish;
    ui64 eventsCount;
    er->GetInterval(TDuration::Seconds(1), start, finish, eventsCount);
    if (eventsCount < RateLimit.load(std::memory_order_acquire)) {
        er->Hit();
        return true;
    }
    return false;
}

// True for both Android applications (the client and the service build); false otherwise.
bool IRequestProcessor::IsAndroid(EApp app) {
    return app == EApp::AndroidClient || app == EApp::AndroidService;
}

// Builds the report-builder context: the reply context, the configured
// Access-Control-Allow-Origin value and the versioned handler key (TypeName + Version).
IServerReportBuilder::TCtx IRequestProcessor::GetReportContext() const {
    const auto versionedKey = TVersionedKey(TypeName, Version).ToString();
    return { Context, AccessControlAllowOrigin, versionedKey };
}

// Writes a coded exception into the report through the guard's SetCode().
// Process(report, f) routes every exception caught from the handler body here.
// NOTE(review): presumably overridable by concrete processors to customize error
// reporting — confirm against the header declaration.
void IRequestProcessor::ProcessException(TJsonReport::TGuard& g, const TCodedException& exception) const {
    g.SetCode(exception);
}

void IRequestProcessor::Process() {
    auto report = MakeAtomicShared<TJsonReport>(GetReportContext());
    report->SetDumpEventLog(DumpEventLog);
    report->SetReportDebugInfo(ReportDebugInfo);

    auto secretVersion = Context->GetRequestData().HeaderIn("SecretVersion");
    if (secretVersion && !secretVersion->equal("0")) {
        report->SetSecretVersion(*secretVersion);
        report->SetSecretKey(Base64Decode(
            GetHandlerSetting<TString>(TStringBuilder() << "report_secret." << *secretVersion << ".data").GetOrElse(TString{})
        ));
    }

    Process(report, [this] (TJsonReport::TGuard& g) {
        R_ENSURE(CheckRequestsRateLimit(), HTTP_TOO_MANY_REQUESTS, "too many requests");
        DoProcess(g);
    });
}

// Runs `f` with a TJsonReport guard over `report` and turns any exception it throws
// into a coded error via ProcessException():
//  - TCodedException: forwarded unchanged (this handler is listed first so coded
//    errors keep their own status);
//  - yexception: wrapped into UnknownErrorStatus with the exception message and a
//    "backtrace" (the exception's own backtrace if it has one, otherwise one taken
//    from the current exception);
//  - anything else: UnknownErrorStatus with a "backtrace" plus CurrentExceptionInfo()
//    under "exception".
// The status config is fetched before the try block, so a failure there propagates.
void IRequestProcessor::Process(IServerReportBuilder::TPtr report, std::function<void(TJsonReport::TGuard&)>&& f) {
    const auto& configHttpCodes = BaseServer->GetHttpStatusManagerConfig();
    TJsonReport::TGuard g(report);

    try {
        f(g);
    } catch (const TCodedException& e) {
        ProcessException(g, e);
    } catch (const yexception& e) {
        TCodedException ce(configHttpCodes.UnknownErrorStatus);
        ce << e.AsStrBuf();
        if (auto backtrace = e.BackTrace()) {
            ce.AddInfo("backtrace", NJson::ToJson(*backtrace));
        } else {
            // Must be called while still inside the handler to see the in-flight exception.
            auto bt = TBackTrace::FromCurrentException();
            ce.AddInfo("backtrace", NJson::ToJson(bt));
        }
        ProcessException(g, ce);
    } catch (...) {
        TCodedException ce(configHttpCodes.UnknownErrorStatus);
        auto bt = TBackTrace::FromCurrentException();
        ce.AddInfo("backtrace", NJson::ToJson(bt));
        ce.AddInfo("exception", CurrentExceptionInfo());
        ProcessException(g, ce);
    }
}

// Resolves the user's position. If `cgiUserLocation` names a CGI parameter and it
// is present, that value wins. Otherwise the "Lat" and "Lon" request headers are
// used, and both must be non-empty. Returns an empty TMaybe when nothing is
// available. Malformed header values are reported by ParseValue.
TMaybe<TGeoCoord> IRequestProcessor::GetUserLocation(const TString& cgiUserLocation) const {
    Y_ENSURE_BT(Context);
    if (cgiUserLocation) {
        if (const TMaybe<TGeoCoord> fromCgi = GetValue<TGeoCoord>(Context->GetCgiParameters(), cgiUserLocation, /*required=*/false)) {
            return *fromCgi;
        }
    }

    const TServerRequestData& requestData = Context->GetRequestData();
    const TStringBuf latHeader = requestData.HeaderInOrEmpty("Lat");
    const TStringBuf lonHeader = requestData.HeaderInOrEmpty("Lon");
    if (latHeader.empty() || lonHeader.empty()) {
        return {};
    }

    // Latitude is parsed before longitude.
    const auto latitude = ParseValue<double>(latHeader);
    const auto longitude = ParseValue<double>(lonHeader);
    // TGeoCoord is constructed as (longitude, latitude).
    return TGeoCoord(longitude, latitude);
}

// Identifies the client application by which "*_AppBuild" header it sends.
// Headers are probed in table order and the first one present wins; with none of
// them present the result is EApp::Unknown.
IRequestProcessor::EApp IRequestProcessor::GetUserApp() const {
    Y_ENSURE_BT(Context);
    struct TAppBuildHeader {
        const char* Name;
        EApp App;
    };
    static const TAppBuildHeader appBuildHeaders[] = {
        {"AC_AppBuild", EApp::AndroidClient},
        {"AS_AppBuild", EApp::AndroidService},
        {"IC_AppBuild", EApp::iOS},
        {"BC_AppBuild", EApp::Business},
        {"B2C_AppBuild", EApp::WebDesktop},
        {"B2CM_AppBuild", EApp::WebMobile},
    };

    const auto& requestData = Context->GetRequestData();
    const auto& headers = requestData.HeadersIn();
    for (const auto& candidate : appBuildHeaders) {
        if (headers.contains(candidate.Name)) {
            return candidate.App;
        }
    }
    return EApp::Unknown;
}

// Returns the numeric build of the client application, read from the
// "*_AppBuild" header that matches GetUserApp(). Yields 0 for an unknown app or an
// unparsable value.
ui32 IRequestProcessor::GetAppBuild() const {
    Y_ENSURE_BT(Context);
    const char* buildHeader = nullptr;
    switch (GetUserApp()) {
        case EApp::iOS:
            buildHeader = "IC_AppBuild";
            break;
        case EApp::AndroidClient:
            buildHeader = "AC_AppBuild";
            break;
        case EApp::AndroidService:
            buildHeader = "AS_AppBuild";
            break;
        case EApp::Business:
            buildHeader = "BC_AppBuild";
            break;
        case EApp::WebDesktop:
            buildHeader = "B2C_AppBuild";
            break;
        case EApp::WebMobile:
            buildHeader = "B2CM_AppBuild";
            break;
        default:
            break;
    }

    TString buildStr;
    if (buildHeader) {
        buildStr = TString(Context->GetRequestData().HeaderInOrEmpty(buildHeader));
    }
    ui32 appBuild = 0;
    TryFromString(buildStr, appBuild);
    return appBuild;
}

namespace {
    // Maps a detected UI language onto a supported localization; any other
    // language maps to `fallback`.
    ELocalization LocalizationFromDetectedLanguage(ELanguage language, ELocalization fallback) {
        switch (language) {
            case LANG_CZE:
                return ELocalization::Cze;
            case LANG_ENG:
                return ELocalization::Eng;
            case LANG_GER:
                return ELocalization::Ger;
            case LANG_RUS:
                return ELocalization::Rus;
            default:
                return fallback;
        }
    }
}

// Returns the request locale, computing it on first call and caching it in Locale.
// Resolution order:
//  1. the "lang" CGI parameter (an unparsable value becomes DefaultLocale);
//  2. language detection on the request, if the "detect_user_language" handler
//     setting is not false (it defaults to true);
//  3. DefaultLocale.
// The first computation is reported to the thread's event logger, if one is set.
// NOTE(review): Locale is a cache written from a const method; concurrent calls on
// one processor are presumably not expected — confirm.
ELocalization IRequestProcessor::GetLocale() const {
    if (Locale) {
        return *Locale;
    }

    const auto& cgi = Context->GetCgiParameters();
    const auto& lang = cgi.Get("lang");
    if (lang) {
        Locale = FromStringWithDefault<ELocalization>(lang, DefaultLocale);
    }

    TMaybe<ELanguage> language;
    if (!Locale) {
        // Read the setting only when detection could actually be used.
        const bool detectUserLanguage = GetHandlerSetting<bool>("detect_user_language").GetOrElse(true);
        if (detectUserLanguage) {
            const auto& rd = Context->GetRequestData();
            language = TryDetectUserLanguage(rd);
            if (language) {
                Locale = LocalizationFromDetectedLanguage(*language, DefaultLocale);
            }
        }
    }

    if (!Locale) {
        Locale = DefaultLocale;
    }

    auto evlog = NThreading::GetEventLogger();
    if (evlog) {
        evlog->AddEvent(NJson::TMapBuilder
            ("event", "GetLocale")
            ("language", ToString(language))
            ("locale", ToString(Locale))
        );
    }

    return *Locale;
}

// Settings source, by priority: the processor's own Settings if set, otherwise the
// server's settings, otherwise the Default<TFakeSettings>() instance when there is
// no server either.
const NDrive::ISettingGetter& IRequestProcessor::GetSettings() const {
    if (!Settings && !BaseServer) {
        return Default<TFakeSettings>();
    }
    if (!Settings) {
        return BaseServer->GetSettings();
    }
    return *Settings;
}

// Process-wide set of handler names whose unistat signals have already been
// registered, so each name is registered once (see the IRequestProcessorConfig
// constructor). Mutex guards RegisteredHandlers.
namespace {
    TRWMutex Mutex;
    TSet<TString> RegisteredHandlers;
}

// Stores the handler's identity (full versioned name, bare type name, key) and
// makes sure unistat signals for this handler name are registered once per process.
IRequestProcessorConfig::IRequestProcessorConfig(const TVersionedKey& procKey)
    : Name(procKey.ToString())
    , Type(procKey.GetName())
    , Key(procKey)
{
    // Fast path: a shared lock is enough when the name is already registered.
    bool alreadyRegistered = false;
    {
        TReadGuard readGuard(Mutex);
        alreadyRegistered = RegisteredHandlers.contains(Name);
    }
    if (alreadyRegistered) {
        return;
    }
    // Slow path: re-check under the exclusive lock, since another thread may have
    // registered the name between the two locks.
    TWriteGuard writeGuard(Mutex);
    if (!RegisteredHandlers.contains(Name)) {
        NDrive::TUnistatSignals::RegisterSignals(Name);
        RegisteredHandlers.emplace(Name);
    }
}

// Defined out of line; there is nothing to release explicitly.
IRequestProcessorConfig::~IRequestProcessorConfig() = default;

// Validation hook for the server a processor will run on. The base implementation
// does nothing and ignores `server`.
// NOTE(review): presumably overridden by configs that need specific server features — confirm.
void IRequestProcessorConfig::CheckServerForProcessor(const IServerBase* /*server*/) const {
}

// Overrides the shared processor options from `section`. Directives missing from
// the section keep their current values. The CORS origin regex is rebuilt
// from the (possibly updated) Access-Control-Allow-Origin value.
void IRequestProcessorConfig::ReadDefaults(const TYandexConfig::Section* section) {
    const TYandexConfig::Directives& defaults = section->GetDirectives();
    const auto readInto = [&defaults](const char* key, auto& field) {
        field = defaults.Value(key, field);
    };

    readInto("Access-Control-Allow-Origin", AccessControlAllowOrigin);
    AccessControlAllowOriginRegEx = !AccessControlAllowOrigin ? nullptr : MakeHolder<TRegExMatch>(AccessControlAllowOrigin);
    readInto("DumpEventLog", DumpEventLog);
    readInto("ReportDebugInfo", ReportDebugInfo);
    readInto("RequestTimeout", RequestTimeout);
}

// Records an access for this handler name, builds the concrete processor via
// DoConstructProcessor() and, when one is returned, copies the event-log and
// debug-info flags onto it. Returns whatever DoConstructProcessor() produced,
// possibly null.
IRequestProcessor::TPtr IRequestProcessorConfig::ConstructProcessor(IReplyContext::TPtr context, const IServerBase* server) const {
    NDrive::TUnistatSignals::OnAccess(Name);
    IRequestProcessor::TPtr processor = DoConstructProcessor(context, server);
    if (!processor) {
        return processor;
    }
    processor->SetDumpEventLog(DumpEventLog);
    processor->SetReportDebugInfo(ReportDebugInfo);
    return processor;
}

// Loads the handler's own options from `section`. Directives that are absent keep
// their current values. It then rebuilds the CORS origin regex and delegates
// handler-specific parsing to DoInit().
void IRequestProcessorConfig::Init(const TYandexConfig::Section* section) {
    const TYandexConfig::Directives& directives = section->GetDirectives();
    const auto readInto = [&directives](const char* key, auto& field) {
        field = directives.Value(key, field);
    };

    readInto("AdditionalCgi", AdditionalCgi);
    readInto("DumpEventLog", DumpEventLog);
    readInto("OverrideCgi", OverrideCgi);
    readInto("OverrideCgiPart", OverrideCgiPart);
    readInto("OverridePost", OverridePost);
    readInto("AliasFor", AliasFor);
    readInto("Access-Control-Allow-Origin", AccessControlAllowOrigin);
    AccessControlAllowOriginRegEx = AccessControlAllowOrigin ? MakeHolder<TRegExMatch>(AccessControlAllowOrigin) : nullptr;
    readInto("ReportDebugInfo", ReportDebugInfo);
    readInto("RequestTimeout", RequestTimeout);
    readInto("RateLimitFreshness", RateLimitFreshness);
    DoInit(section);
}

// Dumps the processor options as "Key: value" lines. AliasFor is printed only when
// set. Handler-specific options are appended by DoToString().
void IRequestProcessorConfig::ToString(IOutputStream& os) const {
    if (!!AliasFor) {
        os << "AliasFor: " << AliasFor << Endl;
    }
    os
        << "AdditionalCgi: " << AdditionalCgi << Endl
        << "DumpEventLog: " << DumpEventLog << Endl
        << "OverrideCgi: " << OverrideCgi << Endl
        << "OverrideCgiPart: " << OverrideCgiPart << Endl
        << "OverridePost: " << OverridePost << Endl
        << "ReportDebugInfo: " << ReportDebugInfo << Endl
        << "RequestTimeout: " << RequestTimeout << Endl
        << "RateLimitFreshness: " << RateLimitFreshness << Endl;
    DoToString(os);
}
