#pragma once

#include <balancer/kernel/helpers/errors.h>

#include <library/cpp/digest/md5/md5.h>
#include <library/cpp/string_utils/base64/base64.h>

#include <util/generic/strbuf.h>
#include <util/datetime/base.h>
#include <util/stream/mem.h>
#include <util/ysaveload.h>

namespace NModAntiDDOS {
    using namespace NSrvKernel;

    // Size in bytes of the MD5 digest stored in each signature.
    constexpr size_t DIGEST_SIZE = 16;
    // CGI parameter introducing the encoded signature appended to a target.
    // Read-only: it is defined in a header, so each translation unit gets its
    // own internal-linkage copy and nothing must ever write to it.
    static constexpr char CGI_PREFIX[] = "&antiddos=";

    // Protection applied to a request, ordered from weakest to strongest.
    // The numeric value is stored as a single byte inside signatures, so the
    // values are spelled out explicitly to make the encoding visible.
    enum EProtectionLevel {
        PROTECTION_BYPASS = 0,
        PROTECTION_HTTP_REDIRECT = 1,
        PROTECTION_HTML_REDIRECT = 2,
        PROTECTION_JAVASCRIPT_REDIRECT = 3,
        //PROTECTION_SHOW_CAPTCHA,
        PROTECTION_DROP = 4,
        // Number of levels above; not a real protection level.
        PROTECTION_LEVELS_NUMBER = 5
    };

    // Raw fields of a signature: the protection level (an EProtectionLevel value
    // narrowed to one byte), the instant the signature stays valid until, and a
    // DIGEST_SIZE-byte digest.
    //
    // Alternative design: instead of EProtectionLevel, transmit the maximum rate
    // at which the current protection is still triggered. The downside is that
    // if several balancers with different capacities are deployed, they may be
    // configured with different thresholds for the same rates, and then deriving
    // the protection level from the rate would not work.
    //
    // Members are value-initialized so a default-constructed object (e.g. one
    // whose loading failed part-way) never exposes indeterminate bytes.
    // NOTE(review): names starting with "__" are reserved for the implementation;
    // renaming requires updating the class that derives from this struct.
    struct __TSignatureInfo {
        ui8 ProtLevel_ = 0;
        TInstant BestBefore_;
        char Digest_[DIGEST_SIZE] = {};
    };

    class TSignatureInfo : protected __TSignatureInfo {
    public:
        static const size_t DecodedSize =
            sizeof((*(__TSignatureInfo*)nullptr).ProtLevel_) +
            sizeof((*(__TSignatureInfo*)nullptr).BestBefore_) +
            sizeof((*(__TSignatureInfo*)nullptr).Digest_);

        static const size_t EncodedSize = Base64EncodeBufSize(DecodedSize) - 1;

        TStringBuf Save(char* signInfo) const {
            char data[DecodedSize];
            TMemoryOutput stream(data, sizeof(data));
            ::Save(&stream, ProtLevel_);
            ::Save(&stream, BestBefore_);
            SaveArray(&stream, Digest_, sizeof(Digest_));
            return Base64EncodeUrl(TStringBuf(data, sizeof(data)), signInfo);
        }

        TErrorOr<bool> Load(TStringInput& signInfo) {
            char data[Base64DecodeBufSize(EncodedSize)];

            try{
                char encoded[EncodedSize];
                if (signInfo.Load(encoded, EncodedSize) != EncodedSize) {
                    return Y_MAKE_ERROR(yexception{} << "not enough data in stream to load signature info");
                }
                if (Base64Decode(data, encoded, encoded + EncodedSize) != DecodedSize) {
                    return false;
                }

                TMemoryInput stream(data, DecodedSize);
                ::Load(&stream, ProtLevel_);
                ::Load(&stream, BestBefore_);
                LoadArray(&stream, Digest_, sizeof(Digest_));
            } Y_TRY_STORE(yexception);
            return true;
        }

        TErrorOr<bool> Load(const TString& target) noexcept {
            if (target.length() < CgiLength) {
                return false;
            }

            TStringBuf signature(target);
            signature = signature.substr(0, signature.length() - CgiLength);
            if (!signature.StartsWith(TStringBuf(CGI_PREFIX, Y_ARRAY_SIZE(CGI_PREFIX) - 1))) {
                return false;
            }

            signature = signature.substr(0, signature.length() - TSignatureInfo::EncodedSize);

            auto stringIn = TString(target);
            TStringInput stream{ stringIn };
            return Load(stream);
        }

        bool Verify(const TString& target, TStringBuf remoteHost, TInstant requestTime, const TString& salt) noexcept {
            if (requestTime > BestBefore_) {
                return false;
            }

            char digest[DIGEST_SIZE];
            CalculateHash(target, remoteHost, salt, digest);

            return Digest() == TStringBuf(digest, sizeof(digest));
        }

        void CalculateHash(const TString& target, TStringBuf remoteHost,
                           const TString& salt, char digest[DIGEST_SIZE]) const noexcept
        {
            MD5 hash;
            hash.Update(target.data(), target.length());
            hash.Update(&ProtLevel_, sizeof(ProtLevel_));
            hash.Update(&BestBefore_, sizeof(BestBefore_));
            hash.Update(remoteHost.data(), remoteHost.size());
            hash.Update(salt.c_str(), salt.length());
            hash.Final((unsigned char*)digest);
        }

        EProtectionLevel ProtectionLevel() const noexcept {
            return (EProtectionLevel)ProtLevel_;
        }

        static void CutOffSignature(TString& target) noexcept {
            target.resize(target.length() - CgiLength);
        }

    protected:
        static const size_t CgiLength = Y_ARRAY_SIZE(CGI_PREFIX) - 1 + EncodedSize;

        TStringBuf Digest() const noexcept {
            return TStringBuf(Digest_, sizeof(Digest_));
        }
    };

    class TSignedTarget : public TSignatureInfo {
    public:
        TSignedTarget(TString target, TStringBuf remoteHost, EProtectionLevel protLevel, TInstant bestBefore, const TString& salt) noexcept {
            Reset(std::move(target), remoteHost, protLevel, bestBefore, salt);
        }

        void Reset(TString target, TStringBuf remoteHost, EProtectionLevel protLevel, TInstant bestBefore, const TString& salt) noexcept {
            Target_ = std::move(target);
            ProtLevel_ = (ui8) protLevel;
            BestBefore_ = bestBefore;

            CalculateHash(Target_, remoteHost, salt, Digest_);
        }

        size_t Length() const noexcept {
            return Target_.length() + CgiLength;
        }

        void Write(IOutputStream& out) const {
            char signInfo[TSignatureInfo::EncodedSize + 1];

            out << Target_
                << TStringBuf(CGI_PREFIX, sizeof(CGI_PREFIX) - 1)
                << Save(signInfo);
        }

    protected:
        TString Target_;
    };

    // Streams a signed target via TSignedTarget::Write and returns the stream
    // so insertions can be chained.
    inline IOutputStream& operator<<(IOutputStream& out, const TSignedTarget& target) {
        IOutputStream& sink = out;
        target.Write(sink);
        return sink;
    }
}
