#include "check_message.h"

#include <saas/library/attributes/attributes.h>

#include <library/cpp/testing/unittest/registar.h>

#include <util/random/random.h>


// Readable names for the "must Check() throw?" flag passed to CHECK_EXCEPTION
// and to the Bad*Checks helpers.
constexpr bool MUST_THROW_EXCEPTION = true;
constexpr bool NOT_THROW_EXCEPTION = !MUST_THROW_EXCEPTION;

// Executes COMMAND and asserts whether it threw a yexception.
//   MUST_THROW - bool expression: true if COMMAND is expected to throw a yexception.
//                It is parenthesized on expansion, so compound expressions are safe.
//   COMMAND    - statement to execute.
//   C          - streamable description of the case; appended via `<< C` to the
//                failure message (left unparenthesized on purpose so that a chain
//                like `"x" << y` keeps working).
// ::NUnitTest::TAssertException is caught first and rethrown, so that a failing
// unit-test assertion raised inside COMMAND is not mistaken for the exception under
// test (presumably it derives from yexception — confirm against unittest headers).
// Exceptions that are not yexceptions are not caught and propagate to the caller.
// The `break` in the yexception handler leaves the do/while, skipping the
// "should have thrown" assertion.
#define CHECK_EXCEPTION(MUST_THROW, COMMAND, C)                                                            \
    do {                                                                                                   \
        try {                                                                                              \
            COMMAND;                                                                                       \
        } catch (const ::NUnitTest::TAssertException&) {                                                   \
            throw;                                                                                         \
        } catch (const yexception&) {                                                                      \
            UNIT_ASSERT_C((MUST_THROW), "Exception has been thrown, but it shouldn't have happened " << C); \
            break;                                                                                         \
        }                                                                                                  \
        UNIT_ASSERT_C(!(MUST_THROW), "Exception hasn't been thrown, but it should have happened " << C);    \
    } while (false)


namespace {
    /// Returns a string of `length` characters, each drawn uniformly at random
    /// from the lowercase ASCII letters 'a'..'z'.
    TString GetRandString(ui32 length) {
        const ui32 alphabetSize = 'z' - 'a' + 1;
        TString result;
        result.reserve(length);
        while (result.size() < length) {
            result.push_back('a' + RandomNumber<ui32>(alphabetSize));
        }
        return result;
    }
}

using namespace NRTYServer;

namespace NSaas {

/// Populates an attribute in one call.
/// @param attr  attribute to fill (typically freshly added to a document).
/// @param name  value for SetName().
/// @param value value for SetValue().
/// @param type  value for SetType().
void FillAttr(
    TAttribute* attr,
    const TString& name,
    const TString& value,
    TAttribute::TAttributeType type
) {
    TAttribute& target = *attr;
    target.SetType(type);
    target.SetName(name);
    target.SetValue(value);
}

/// Builds the reference MODIFY_DOCUMENT message that the default
/// TMessageChecker accepts (asserted by TestCorrectMessageOK).
/// The message has:
///  - URL "some_url" and key prefix 1;
///  - a fixed, past modification timestamp;
///  - a plain-text body;
///  - two search attributes and two group attributes.
/// Other helpers start from this message and break exactly one property.
TMessage GetCorrectModifyMessage() {
    TMessage result;
    result.SetMessageType(TMessage::MODIFY_DOCUMENT);

    auto* document = result.MutableDocument();
    document->SetUrl("some_url");
    document->SetKeyPrefix(1);
    document->SetModificationTimestamp(1577826000);
    document->SetMimeType("text/plain");
    document->SetBody("some body value");

    // Search attributes, then group attributes; the insertion order is kept
    // stable because it determines the serialized layout.
    FillAttr(document->AddSearchAttributes(), "s_str", "str val", TAttribute::LITERAL_ATTRIBUTE);
    FillAttr(document->AddSearchAttributes(), "i_int", "12345", TAttribute::INTEGER_ATTRIBUTE);
    FillAttr(document->AddGroupAttributes(), "s_gstr", "str group", TAttribute::LITERAL_ATTRIBUTE);
    FillAttr(document->AddGroupAttributes(), "s_gint", "98765", TAttribute::INTEGER_ATTRIBUTE);

    return result;
}

/// Fills `attr` as a literal attribute with an empty name.
/// The value and type are ordinary, so the empty name is presumably what the
/// checker rejects (confirm against TMessageChecker's attribute validation).
void FillBadAttr(TAttribute* attr) {
    const TString emptyName;
    FillAttr(attr, emptyName, "string attribut value", TAttribute::LITERAL_ATTRIBUTE);
}

/// Returns the reference message with one extra, invalid search attribute
/// (see FillBadAttr) appended after the valid ones.
TMessage GetMessageWithBadSearchAttr() {
    TMessage result = GetCorrectModifyMessage();
    auto* extra = result.MutableDocument()->AddSearchAttributes();
    FillBadAttr(extra);
    return result;
}

/// Returns the reference message with one extra, invalid group attribute
/// (see FillBadAttr) appended after the valid ones.
TMessage GetMessageWithBadGroupAttr() {
    TMessage result = GetCorrectModifyMessage();
    auto* extra = result.MutableDocument()->AddGroupAttributes();
    FillBadAttr(extra);
    return result;
}

/// Builds a copy of the reference message whose serialized size
/// (ByteSizeLong()) is exactly `size` bytes. The size is reached by replacing
/// the body with random lowercase letters.
///
/// Throws yexception (via Y_ENSURE) in two cases:
///  - `size` is smaller than the reference message with an empty body;
///  - no body length produces exactly `size` bytes.
TMessage GetMessageWithSize(ui32 size) {
    auto message = GetCorrectModifyMessage();
    auto* document = message.MutableDocument();
    document->SetBody("");

    const size_t baseSize = message.ByteSizeLong();
    // Without this guard `size - baseSize` wraps around, and GetRandString would
    // be asked for a ~4 GiB string after truncation to ui32.
    Y_ENSURE(size >= baseSize,
        "requested message size " << size << " is smaller than the minimal message size " << baseSize);

    // Each body byte adds one serialized byte. The length varints of the body
    // and of the enclosing document may also grow, so this is an upper bound;
    // trim one character at a time until the message fits.
    document->SetBody(GetRandString(size - baseSize));
    while (message.ByteSizeLong() > size) {
        document->SetBody(document->GetBody().substr(1));
    }

    // Crossing a varint-length boundary can shrink the message by more than one
    // byte and skip over `size`; report that instead of returning a wrong size.
    Y_ENSURE(size == message.ByteSizeLong());
    return message;
}

void BadBaseChecks(const TMessageChecker& checker, bool throwException) {
    {
        auto m = GetCorrectModifyMessage();
        m.MutableDocument()->SetUrl(GetRandString(TMessageLimits::MAX_URL_SIZE + 1));
        CHECK_EXCEPTION(throwException, checker.Check(m), "url size more then limit");
    }
    {
        auto m = GetCorrectModifyMessage();
        m.MutableDocument()->SetKeyPrefix(0);
        CHECK_EXCEPTION(throwException, checker.Check(m), "zero KeyPrefix");
    }
    auto m = GetCorrectModifyMessage();
    m.MutableDocument()->SetModificationTimestamp((TInstant::Now() + TDuration::Minutes(5)).Seconds());
    CHECK_EXCEPTION(throwException, checker.Check(m), "modifation timestamp is more then now");
    m.MutableDocument()->ClearModificationTimestamp();
    CHECK_EXCEPTION(throwException, checker.Check(m), "modifation timestamp is not set");
    m.ClearDocument();
    CHECK_EXCEPTION(throwException, checker.Check(m), "document is not set");
}

void BadAttributesChecks(const TMessageChecker& checker, bool throwException) {
    CHECK_EXCEPTION(throwException, checker.Check(GetMessageWithBadSearchAttr()), "bad search attribute");
    CHECK_EXCEPTION(throwException, checker.Check(GetMessageWithBadSearchAttr()), "bad group attribute");
}

void BadMaxSizeCheck(const TMessageChecker& checker, bool throwException) {
    auto m = GetMessageWithSize(TMessageLimits::MAX_MESSAGE_SIZE_BYTES + 1);
    CHECK_EXCEPTION(throwException, checker.Check(m), "big message size");
}

void CheckAllBadCases(const TMessageChecker& checker, bool throwException) {
    BadBaseChecks(checker, throwException);
    BadAttributesChecks(checker, throwException);
    BadMaxSizeCheck(checker, throwException);
}


// Unit tests for TMessageChecker and TCheckMessageSettings.
Y_UNIT_TEST_SUITE(TestCheckMessage) {
    // The default checker accepts the reference message.
    Y_UNIT_TEST(TestCorrectMessageOK) {
        TMessageChecker checker;

        UNIT_CHECK_GENERATED_NO_EXCEPTION_C(checker.Check(GetCorrectModifyMessage()), yexception, "correct message");
    }

    // The default checker rejects every bad case.
    Y_UNIT_TEST(TestDefaultChecks) {
        TMessageChecker checker;

        CheckAllBadCases(checker, MUST_THROW_EXCEPTION);
    }

    // With SetDisabled(true), none of the bad cases throws.
    Y_UNIT_TEST(TestDisableAllChecks) {
        TCheckMessageSettings settings;
        TMessageChecker checker(settings.SetDisabled(true));

        CheckAllBadCases(checker, NOT_THROW_EXCEPTION);
    }

    // SetCheckAttributes(false) turns off only attribute validation; the base
    // and size checks still throw.
    Y_UNIT_TEST(TestDisableAttrsCheck) {
        TCheckMessageSettings settings;
        TMessageChecker checker(settings.SetCheckAttributes(false));

        BadBaseChecks(checker, MUST_THROW_EXCEPTION);
        BadAttributesChecks(checker, NOT_THROW_EXCEPTION);
        BadMaxSizeCheck(checker, MUST_THROW_EXCEPTION);
    }

    // SetCheckMaxSizeBytes(0) turns off only the size check: an oversized
    // message passes while the other checks still throw.
    Y_UNIT_TEST(TestDisableMaxSizeCheck) {
        TCheckMessageSettings settings;
        TMessageChecker checker(settings.SetCheckMaxSizeBytes(0));

        BadBaseChecks(checker, MUST_THROW_EXCEPTION);
        BadAttributesChecks(checker, MUST_THROW_EXCEPTION);
        BadMaxSizeCheck(checker, NOT_THROW_EXCEPTION);
    }

    // A manual size limit is inclusive: a message of exactly `size` bytes
    // passes, and one of `size + 1` bytes is rejected. The limit is tried up to
    // the hard cap MAX_MESSAGE_SIZE_BYTES.
    Y_UNIT_TEST(TestMaxSizeCheckForManualLimit) {
        TVector<ui32> sizes = {1000, 1024 * 1024, TMessageLimits::MAX_MESSAGE_SIZE_BYTES};
        for (auto size : sizes) {
            TCheckMessageSettings settings;
            TMessageChecker checker(settings.SetCheckMaxSizeBytes(size));
            UNIT_CHECK_GENERATED_NO_EXCEPTION_C(checker.Check(GetMessageWithSize(size)), yexception, "message size is equal to max size");
            UNIT_CHECK_GENERATED_EXCEPTION_C(checker.Check(GetMessageWithSize(size + 1)), yexception, "message size is greater than max size");
        }
    }

    // The setter itself throws when asked for a limit above the hard cap.
    Y_UNIT_TEST(TestSetMaxSizeLimit) {
        TCheckMessageSettings settings;
        UNIT_CHECK_GENERATED_EXCEPTION_C(settings.SetCheckMaxSizeBytes(TMessageLimits::MAX_MESSAGE_SIZE_BYTES + 1), yexception,
            "MaxSize argumnet more then MAX_MESSAGE_SIZE_BYTES"
        );
    }

    // URLs up to and including MAX_URL_SIZE characters pass; one character
    // more is rejected. The same message `m` is reused with a new URL each time.
    Y_UNIT_TEST(TestUrlSizeCheck) {
        TVector<ui32> sizes = {1, 100, TMessageLimits::MAX_URL_SIZE};
        TMessageChecker checker;
        auto m = GetCorrectModifyMessage();
        for (auto size : sizes) {
            m.MutableDocument()->SetUrl(GetRandString(size));
            UNIT_CHECK_GENERATED_NO_EXCEPTION_C(checker.Check(m), yexception, "correct url size");
        }
        {
            m.MutableDocument()->SetUrl(GetRandString(TMessageLimits::MAX_URL_SIZE + 1));
            UNIT_CHECK_GENERATED_EXCEPTION_C(checker.Check(m), yexception, "url size is more then limit");
        }
    }

    // Key-prefix policy, for each EKeyPrefixCheckType:
    //   any      - zero and non-zero prefixes both pass;
    //   zero     - only 0 passes;
    //   non_zero - only non-zero values pass (extremes and negatives included).
    // NOTE(review): one `settings` object is mutated and reused for all three
    // checkers. This relies on each TMessageChecker capturing its settings at
    // construction (e.g. by copy) — not visible here; confirm.
    Y_UNIT_TEST(TestKeyPrefixCheck) {
        TVector<i64> nonZeroKps = {Min<i64>(), -123456, -1, 1, 123456, Max<i64>()};
        TCheckMessageSettings settings;
        {
            TMessageChecker checker(settings.SetKeyPrefixCheckType(EKeyPrefixCheckType::any));
            auto m = GetCorrectModifyMessage();
            {
                m.MutableDocument()->SetKeyPrefix(0);
                UNIT_CHECK_GENERATED_NO_EXCEPTION_C(checker.Check(m), yexception, "zero kps and ckeck type: any");
            }
            for (i64 kps : nonZeroKps) {
                m.MutableDocument()->SetKeyPrefix(kps);
                UNIT_CHECK_GENERATED_NO_EXCEPTION_C(checker.Check(m), yexception, "non zero kps and ckeck type: any");
            }
        }
        {
            TMessageChecker checker(settings.SetKeyPrefixCheckType(EKeyPrefixCheckType::zero));
            auto m = GetCorrectModifyMessage();
            {
                m.MutableDocument()->SetKeyPrefix(0);
                UNIT_CHECK_GENERATED_NO_EXCEPTION_C(checker.Check(m), yexception, "zero kps and ckeck type: zero");
            }
            for (i64 kps : nonZeroKps) {
                m.MutableDocument()->SetKeyPrefix(kps);
                UNIT_CHECK_GENERATED_EXCEPTION_C(checker.Check(m), yexception, "non zero kps and ckeck type: zero");
            }
        }
        {
            TMessageChecker checker(settings.SetKeyPrefixCheckType(EKeyPrefixCheckType::non_zero));
            auto m = GetCorrectModifyMessage();
            {
                m.MutableDocument()->SetKeyPrefix(0);
                UNIT_CHECK_GENERATED_EXCEPTION_C(checker.Check(m), yexception, "zero kps and ckeck type: non_zero");
            }
            for (i64 kps : nonZeroKps) {
                m.MutableDocument()->SetKeyPrefix(kps);
                UNIT_CHECK_GENERATED_NO_EXCEPTION_C(checker.Check(m), yexception, "non zero kps and ckeck type: non_zero");
            }
        }
    }
}

}
