#include "module.h"

#include "stats.h"
#include <balancer/modules/aab_cookie_verify/stats.h_serialized.h>
#include <util/system/sanitizers.h>

#include <balancer/kernel/cookie/cookie.h>
#include <balancer/kernel/helpers/errors.h>
#include <balancer/kernel/http/parser/common_headers.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/process/thread_info.h>

#include <library/cpp/string_utils/base64/base64.h>
#include <openssl/aes.h>

#include <util/generic/is_in.h>
#include <util/generic/strbuf.h>
#include <util/generic/string.h>
#include <util/stream/file.h>
#include <util/thread/singleton.h>

using namespace NConfig;
using namespace NSrvKernel;
using namespace NModAabCookieVerify;

namespace {

    // Position-weighted checksum of `str`: the first character is multiplied by
    // `i`, the second by `i + 1`, and so on; the products are summed in ui64
    // arithmetic. An empty string yields 0.
    //
    // NOTE(review): characters are read as plain `char`, so the contribution of
    // bytes >= 0x80 depends on whether `char` is signed on the target platform.
    ui64 ComputeHash(TStringBuf str, ui64 i = 1) {
        ui64 sum = 0;
        ui64 weight = i;
        for (size_t pos = 0; pos < str.size(); ++pos) {
            sum += str[pos] * weight;
            ++weight;
        }
        return sum;
    }

// HEADER_FSM(name, val) declares a final struct `name` that derives from TFsm
// and from NSrvKernel::TWithDefaultInstance<name>. Its default constructor
// builds the FSM from the pattern `val` with the case-insensitive option set.
// The generated types are used below through `name::Instance()`
// (presumably supplied by TWithDefaultInstance — confirm in its header).
#define HEADER_FSM(name, val)\
    struct name final : public TFsm, public NSrvKernel::TWithDefaultInstance<name> {\
        name() noexcept \
            : TFsm(val, TOptions().SetCaseInsensitive(true))\
        {}\
    }

    // Case-insensitive matchers for the two request header names whose values
    // take part in cookie verification.
    HEADER_FSM(TAcceptLanguageFsm, "Accept-Language");
    HEADER_FSM(TUserAgentFsm, "User-Agent");

class TStats {
public:
    TStats(std::array<TMaybe<TSharedCounter>, GetEnumItemsCount<ECookieVerifyResult>()>& holders, size_t workerId)
        : SharedCounters{}
        , Counters{}
    {
        for (size_t i = 0; i < holders.size(); ++i) {
            SharedCounters[i] = TSharedCounter(*holders[i], workerId);
        }
    }

    void OnEvent(ECookieVerifyResult event) {
        const ui32 index = static_cast<ui32>(event);
        ++Counters[index];
        SharedCounters[index]->Inc();
        if (event == ECookieVerifyResult::SuccessIpMismatch) {
            const ui32 succ = static_cast<ui32>(ECookieVerifyResult::Success);
            ++Counters[succ];
            SharedCounters[succ]->Inc();
        }
    }

private:
    std::array<TMaybe<TSharedCounter>, GetEnumItemsCount<ECookieVerifyResult>()> SharedCounters{};
    std::array<ui64, GetEnumItemsCount<NModAabCookieVerify::ECookieVerifyResult>()> Counters{};
};

} // namespace

// State object handed to DoRun() as `tls`; instances are created by
// TModule::DoInitTls(). Holds the outcome statistics and the checker for the
// optional "disable" file.
Y_TLS(aab_cookie_verify) {
    TTls(std::array<TMaybe<TSharedCounter>, GetEnumItemsCount<ECookieVerifyResult>()>& holders, size_t workerId)
        : Stats(holders, workerId)
    {}

    // Cookie verification is active unless the disable-file checker reports
    // that the file exists.
    bool AntiadblockEnabled() const {
        if (DisableAntiAdblockModule.Exists()) {
            return false;
        }
        return true;
    }

    // Counters of verification outcomes.
    TStats Stats;
    // Default-constructed here; DoInitTls() replaces it only when a disable
    // file is configured.
    TSharedFileExistsChecker DisableAntiAdblockModule;
};

// Module that verifies an AES-encrypted cookie and routes the request:
// when verification is enabled and the cookie does not verify, the request
// goes to the "antiadblock" submodule, otherwise to the "default" submodule.
MODULE_WITH_TLS_BASE(aab_cookie_verify, TModuleWithSubModule) {
public:
    // Builds the module from its config section.
    // Throws TConfigParseError when the "default" or "antiadblock" submodule
    // or "aes_key_path" is missing, when the key file holds fewer than
    // 16 bytes, or when AES_set_decrypt_key reports a failure.
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
        , Default_(Submodule_)
    {
        // One shared counter per verification outcome, named "antiadblock-<outcome>".
        for (const auto& e: GetEnumAllValues<ECookieVerifyResult>()) {
            SharedCounters_[static_cast<ui32>(e)] = Control->SharedStatsManager().MakeCounter("antiadblock-"  + ToString(e)).Build();
        }

        // Feeds every config key into the parse handler declared below.
        Config->ForEach(this);
        if (!Default_) {
            ythrow TConfigParseError() << "No default module in aab_cookie_verify";
        }
        if (!AntiAdblock_) {
            ythrow TConfigParseError() << "No antiadblock module in aab_cookie_verify";
        }
        if (!AesKeyPath_) {
            ythrow TConfigParseError() << "No aes_key_path in aab_cookie_verify";
        }

        // The whole file is read, but only its first 16 bytes (128 bits) are
        // installed as the AES decryption key.
        const TString key = TUnbufferedFileInput(AesKeyPath_).ReadAll();
        if (key.size() < 16) {
            ythrow TConfigParseError() << "aes key length is less than 16";
        }
        if (AES_set_decrypt_key(reinterpret_cast<const ui8*>(key.data()), 16 * 8, &DecryptKey_)) {
            ythrow TConfigParseError() << "failed to load key, AES_set_decrypt_key return non zero";
        }
        // Combined matcher for the header names of interest. MatchHeaders()
        // relies on this order: 0 = cookie, 1 = User-Agent,
        // 2 = Accept-Language, 3 = the configured ip_header.
        HeadersFsm_ = MakeHolder<TFsm>(TCookieFsm::Instance() | TUserAgentFsm::Instance() | TAcceptLanguageFsm::Instance() | TFsm(IpHeader_, TFsm::TOptions().SetCaseInsensitive(true)));
    }

private:
    // Config keys handled explicitly here: aes_key_path,
    // disable_antiadblock_file, cookie, cookie_lifetime, ip_header, and the
    // two submodule sections "default" and "antiadblock".
    // STATS_ATTR and PARSE_EVENTS are project macros defined elsewhere.
    START_PARSE {
        STATS_ATTR;
        PARSE_EVENTS;
        ON_KEY("aes_key_path", AesKeyPath_) {
            return;
        }
        ON_KEY("disable_antiadblock_file", DisableAntiAdblockFile_) {
            return;
        }
        ON_KEY("cookie", Cookie_) {
            return;
        }
        ON_KEY("cookie_lifetime", CookieLifetime_) {
            return;
        }
        ON_KEY("ip_header", IpHeader_) {
            return;
        }
        if (key == "default") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(Default_);
            return;
        }
        if (key == "antiadblock") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(AntiAdblock_);
            return;
        }
    } END_PARSE

    // Values of the request headers that take part in verification.
    // A field stays empty when the corresponding header was not matched.
    struct TMatchedHeaders {
        TMaybe<TStringBuf> Cookie;
        TMaybe<TStringBuf> Ip;
        TMaybe<TStringBuf> UserAgent;
        TMaybe<TStringBuf> AcceptLanguage;
    };
    // Collects the header values listed in TMatchedHeaders from `headers`.
    TMatchedHeaders MatchHeaders(const THeaders& headers) const;
    // Extracts, decodes and decrypts the cookie of the request, then checks it.
    ECookieVerifyResult VerifyCookie(const TConnDescr& descr) const;
    // Validates the decrypted cookie payload against the request attributes.
    ECookieVerifyResult Check(const TStringBuf cookie, const TMatchedHeaders& headers, const TAddrHolder& remoteAddr) const;

    // Creates the TTls state for `process`. When a disable file is configured,
    // a shared file-existence checker for it is installed
    // (the TDuration::Seconds(1) argument is presumably the re-check period — confirm).
    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        auto tls = MakeHolder<TTls>(SharedCounters_, process->WorkerId());
        if (DisableAntiAdblockFile_) {
            tls->DisableAntiAdblockModule = process->SharedFiles()->FileChecker(DisableAntiAdblockFile_, TDuration::Seconds(1));
        }
        return tls;
    }

    // Request entry point. With verification enabled, the outcome is recorded
    // in the statistics; any outcome other than Success / SuccessIpMismatch
    // sends the request to the antiadblock submodule. Everything else
    // (including the disabled state) goes to the default submodule.
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        if (tls.AntiadblockEnabled()) {
            const ECookieVerifyResult result = VerifyCookie(descr);
            tls.Stats.OnEvent(result);
            if (!IsIn({ECookieVerifyResult::Success, ECookieVerifyResult::SuccessIpMismatch}, result)) {
                const TExtraAccessLogEntry entry(descr, "antiadblock");
                // The summary is written when this scope is left, i.e. after
                // the submodule has finished.
                Y_DEFER {
                    descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "antiadblock");
                };
                return AntiAdblock_->Run(descr);
            }
        }
        const TExtraAccessLogEntry entry(descr, "default");
        return Default_->Run(descr);
    }


    // Always reports true.
    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

private:
    // Shared counters, indexed by the numeric value of ECookieVerifyResult.
    std::array<TMaybe<TSharedCounter>, GetEnumItemsCount<ECookieVerifyResult>()> SharedCounters_{};
    // Not referenced directly in this class; presumably filled by STATS_ATTR — confirm.
    TString StatsAttr_;
    // AES-128 decryption key schedule, initialised in the constructor.
    AES_KEY DecryptKey_;
    // Maximum accepted age of a cookie ("cookie_lifetime").
    TDuration CookieLifetime_ = TDuration::Days(14);
    // Path of the file holding the AES key ("aes_key_path", mandatory).
    TString AesKeyPath_;
    // Name of the cookie to verify ("cookie").
    TString Cookie_ = "cycada";
    // Optional path; while this file exists verification is skipped
    // ("disable_antiadblock_file").
    TString DisableAntiAdblockFile_;
    // Name of the header whose value, when present, is hashed in place of the
    // remote address ("ip_header").
    TString IpHeader_;
    // Submodule that receives requests whose cookie did not verify.
    THolder<IModule> AntiAdblock_;
    // Alias of the base class' Submodule_; receives all other requests.
    THolder<IModule>& Default_;
    // Combined header-name matcher built at the end of the constructor.
    THolder<TFsm> HeadersFsm_;
};

// Walks over the request headers and picks the values needed for cookie
// verification. A header name is run through HeadersFsm_; the index of the
// matched expression selects the destination field:
//   0 -> Cookie, 1 -> UserAgent, 2 -> AcceptLanguage, 3 -> Ip.
// The last value of a matched header is stored; fields of unmatched headers
// stay empty.
TModule::TMatchedHeaders TModule::MatchHeaders(const THeaders& headers) const {
    TMatchedHeaders found;
    for (const auto& header : headers) {
        TMatcher matcher(*HeadersFsm_);
        if (NSrvKernel::Match(matcher, header.first.AsStringBuf()).Final()) {
            // Choose where the value goes first, assign once afterwards.
            TMaybe<TStringBuf>* slot = nullptr;
            switch (*matcher.MatchedRegexps().first) {
                case 0:
                    slot = &found.Cookie;
                    break;
                case 1:
                    slot = &found.UserAgent;
                    break;
                case 2:
                    slot = &found.AcceptLanguage;
                    break;
                case 3:
                    slot = &found.Ip;
                    break;
            }
            if (slot != nullptr) {
                *slot = header.second.back().AsStringBuf();
            }
        }
    }
    return found;
}

// Validates a decrypted cookie payload.
//
// The payload must consist of four tab-separated unsigned integers:
// issue time in seconds, IP hash, User-Agent hash, Accept-Language hash.
//
// Returns:
//   CookieInvalid     - the payload does not parse into the four numbers;
//   CookieExpired     - the issue time is older than Now() - CookieLifetime_;
//   Success           - the non-zero IP hash equals the hash of the ip header
//                       value or, when that header is absent, of the remote
//                       address string;
//   SuccessIpMismatch - the IP differs, but both the User-Agent and the
//                       Accept-Language hashes are non-zero and match;
//   AllMismatch       - anything else.
ECookieVerifyResult TModule::Check(const TStringBuf cookie, const TMatchedHeaders& headers, const TAddrHolder& remoteAddr) const {
    ui64 issuedAt = 0;
    ui64 storedIpHash = 0;
    ui64 storedUaHash = 0;
    ui64 storedLangHash = 0;
    const bool parsed = StringSplitter(cookie).Split('\t').TryCollectInto(&issuedAt, &storedIpHash, &storedUaHash, &storedLangHash);
    if (!parsed) {
        return ECookieVerifyResult::CookieInvalid;
    }

    const TInstant oldestAccepted = Now() - CookieLifetime_;
    if (TInstant::Seconds(issuedAt) < oldestAccepted) {
        return ECookieVerifyResult::CookieExpired;
    }

    // The explicit ip header wins over the address of the connection.
    ui64 actualIpHash = 0;
    if (headers.Ip) {
        actualIpHash = ComputeHash(*headers.Ip);
    } else {
        actualIpHash = ComputeHash(remoteAddr.AddrStr());
    }
    if (storedIpHash != 0 && storedIpHash == actualIpHash) {
        return ECookieVerifyResult::Success;
    }

    // Fallback: identify the client by User-Agent + Accept-Language.
    const bool fallbackPossible = headers.UserAgent.Defined() && headers.AcceptLanguage.Defined()
        && storedUaHash != 0 && storedLangHash != 0;
    if (!fallbackPossible) {
        return ECookieVerifyResult::AllMismatch;
    }
    if (storedUaHash != ComputeHash(*headers.UserAgent)) {
        return ECookieVerifyResult::AllMismatch;
    }
    if (storedLangHash != ComputeHash(*headers.AcceptLanguage)) {
        return ECookieVerifyResult::AllMismatch;
    }
    return ECookieVerifyResult::SuccessIpMismatch;
}

// Runs the whole verification pipeline for one request:
//   1. find the configured cookie in the Cookie header (CookieAbsent if the
//      header or the cookie is missing or the value is empty);
//   2. Base64-decode it (CookieInvalid if decoding throws);
//   3. require the length to be a multiple of 16 bytes (CookieInvalid otherwise);
//   4. decrypt it block by block with DecryptKey_;
//   5. strip the plaintext and pass it to Check() together with the matched
//      headers and the remote address of the parent connection.
ECookieVerifyResult TModule::VerifyCookie(const TConnDescr& descr) const {
    constexpr size_t blockSize = 16;

    const TMatchedHeaders matched = MatchHeaders(descr.Request->Headers());
    const TStringBuf cookieHeader = matched.Cookie.GetOrElse({});
    const TStringBuf encoded = FindCookieK(cookieHeader, Cookie_).GetOrElse({});
    if (encoded.empty()) {
        return ECookieVerifyResult::CookieAbsent;
    }

    TString cipher;
    try {
        cipher = Base64Decode(encoded);
    } catch (...) {
        return ECookieVerifyResult::CookieInvalid;
    }

    const ui64 total = cipher.size();
    if (total % blockSize != 0) {
        return ECookieVerifyResult::CookieInvalid;
    }

    TBuffer plain(total);
    const char* const src = cipher.data();
    char* const dst = plain.Data();
    for (size_t offset = 0; offset < total; offset += blockSize) {
        AES_ecb_encrypt(reinterpret_cast<const ui8*>(src + offset), reinterpret_cast<ui8*>(dst + offset), &DecryptKey_, AES_DECRYPT);
    }

    const TStringBuf text(plain.Data(), total);
    // Mark the decrypted bytes as initialised for the sanitizer.
    NSan::Unpoison(text.data(), text.size());
    return Check(StripString(text), matched, *descr.Properties->Parent.RemoteAddress);
}

// Public accessor of the module handle: forwards to TModule::Handle().
IModuleHandle* NModAabCookieVerify::Handle() {
    return TModule::Handle();
}
