#include <balancer/kernel/custom_io/ut/test_common.h>
#include <balancer/kernel/custom_io/rewind.h>
#include <balancer/kernel/helpers/errors.h>
#include <library/cpp/testing/unittest/registar.h>
#include <util/generic/deque.h>
#include <util/generic/xrange.h>
#include <util/string/join.h>

Y_UNIT_TEST_SUITE(TCustomIoRewindTest) {
    using namespace NSrvKernel;

    // Source payload; every scenario serves some prefix of this string (see TTestData).
    const TString ExpData("ABCDE");

    // Scripted IIoInput. Each Recv() hands out one single-character chunk of
    // ExpectedData_ and finally an empty chunk list (EOF). Once the number of bytes
    // handed out reaches ErrPos_, the next Recv() fails exactly once with ECANCELED;
    // later calls resume serving the remaining data.
    struct TMockInput : public NSrvKernel::IIoInput {
    public:
        explicit TMockInput(TString expData, size_t errPos)
            : ExpectedData_(std::move(expData))
            , ErrPos_(errPos)
        {
            // One chunk per byte, then a default (empty) list that the readers treat as EOF.
            for (const char ch : ExpectedData_) {
                InputData_.emplace_back(TString(1, ch));
            }
            InputData_.emplace_back();
        }

        [[nodiscard]] size_t ErrPos() const noexcept {
            return ErrPos_;
        }

        // Bytes already handed out: total payload length minus what is still queued.
        [[nodiscard]] size_t Pos() const noexcept {
            size_t pending = 0;
            for (const auto& chunk : InputData_) {
                pending += chunk.size();
            }
            return ExpectedData_.size() - pending;
        }

        // NOTE(review): TestName_ is never assigned in this file, so this is always empty.
        [[nodiscard]] TString TestName() const noexcept {
            return TestName_;
        }

        [[nodiscard]] TString LastRecv() const noexcept {
            return TStringBuilder() << LastRecv_;
        }

        // Prefix of the payload that has already been handed out.
        [[nodiscard]] TString ConsumedInput() const noexcept {
            return ExpectedData_.substr(0, Pos());
        }

        // Concatenation of every chunk still waiting in the queue.
        [[nodiscard]] TString UnconsumedInput() const noexcept {
            TStringBuilder out;
            for (const auto& chunk : InputData_) {
                out << chunk;
            }
            return out;
        }

        [[nodiscard]] TString TotalInput() const noexcept {
            return ExpectedData_;
        }

        // Part of the payload that precedes the error position.
        [[nodiscard]] TString HealthyInput() const noexcept {
            return ExpectedData_.substr(0, ErrPos_);
        }

        // Every deadline passed to DoRecv so far.
        [[nodiscard]] const THashSet<TInstant>& Deadlines() const noexcept {
            return Deadlines_;
        }

        [[nodiscard]] bool ErrorFired() const noexcept {
            return ErrorFired_;
        }

    protected:
        // NOTE(review): UNIT_ASSERT_C may report failure by throwing; escaping this
        // noexcept override would then end in std::terminate — confirm acceptable.
        TError DoRecv(TChunkList& lst, TInstant deadline) noexcept override {
            Deadlines_.insert(deadline);
            UNIT_ASSERT_C(!InputData_.empty(), TestName_);

            // Fail exactly once, on the first read issued at or past the error position.
            const bool failNow = !ErrorFired_ && Pos() >= ErrPos_;
            if (failNow) {
                ErrorFired_ = true;
                return Y_MAKE_ERROR(TSystemError{ECANCELED});
            }

            lst = std::move(InputData_.front());
            InputData_.pop_front();
            LastRecv_ = lst.Copy();
            return {};
        }

    protected:
        TString TestName_;
        const TString ExpectedData_;
        THashSet<TInstant> Deadlines_;
        std::deque<TChunkList> InputData_;
        TChunkList LastRecv_;
        const size_t ErrPos_;
        bool ErrorFired_ = false;
    };


    // One parameterized scenario: a prefix of ExpData served by TMockInput, a rewind
    // limit, an optional one-shot error position, plus the shared checks that are run
    // against a TLimitedRewindableInput wrapping the mock.
    struct TTestData {
        const TInstant Deadline = TInstant::Now();
        const TString Data;
        const size_t LimitLen;

        TMockInput Mock;
        TString Name;

    public:
        // dataLen  - how many leading characters of ExpData the mock serves.
        // limitLen - rewind limit; clamped to at most dataLen + 1.
        // errPos   - number of bytes handed out after which the mock fails once; clamped
        //            to dataLen + 1. The mock's position never exceeds dataLen, so the
        //            default (all bits set) means the error never fires.
        //
        // Name describes the scenario for assertion messages: '%' marks the error
        // position and '|' marks the rewind limit inside the data.
        TTestData(size_t dataLen, size_t limitLen, size_t errPos = static_cast<size_t>(-1))
            : Data(ExpData.substr(0, dataLen))
            , LimitLen(std::min(limitLen, dataLen + 1))
            , Mock(Data, std::min(errPos, dataLen + 1))
        {
            TStringBuilder descr;
            descr << "data='";

            if (Data.size() < std::min(errPos, limitLen)) {
                // Neither marker falls inside the data.
                descr << Data;
            } else if (errPos <= limitLen) {
                if (limitLen <= Data.size()) {
                    descr << Data.substr(0, errPos) << "%" << Data.substr(errPos, limitLen - errPos) << "|" << Data.substr(limitLen);
                } else {
                    descr << Data.substr(0, errPos) << "%" << Data.substr(errPos);
                }
            } else {
                descr << Data.substr(0, limitLen) << "|" << Data.substr(limitLen);
            }
            descr << "'";

            Name = descr;
        }

        // Everything received must equal what the mock handed out, and every mock read
        // must have been issued with this scenario's deadline.
        void CheckRecv(const TChunkList& recv) {
            UNIT_ASSERT_VALUES_EQUAL_C(recv, Mock.ConsumedInput(), Name);
            UNIT_ASSERT_VALUES_EQUAL_C(Mock.Deadlines().size(), 1, Name);
            UNIT_ASSERT_C(Mock.Deadlines().contains(Deadline), Name);
        }

        // While at most LimitLen bytes have been pulled from the mock, the input must be
        // rewindable and replay them; past that point it must refuse to rewind.
        void CheckRewind(TLimitedRewindableInput& in) {
            if (Mock.Pos() <= LimitLen) {
                UNIT_ASSERT_C(in.IsRewindable(), Name);

                in.Rewind();
                RecvAndCheck(in, Mock.ConsumedInput().substr(0, LimitLen));
            } else {
                UNIT_ASSERT_C(!in.IsRewindable(), Name);
            }
        }

        // Pushes a copy of `last` back and checks that the next read returns exactly it.
        void CheckUnrecv(TLimitedRewindableInput& in, const TChunkList& last) {
            UnrecvCopy(in, last);
            RecvAndCheck(in, ToString(last));
        }

        // Pushes a copy of `last` back, then: if rewinding is still allowed, rewinds and
        // expects everything consumed from the mock; otherwise expects just `last`.
        void CheckUnrecvAndRewind(TLimitedRewindableInput& in, const TChunkList& last) {
            UnrecvCopy(in, last);

            if (Mock.Pos() <= LimitLen) {
                UNIT_ASSERT_C(in.IsRewindable(), Name);
                in.Rewind();

                RecvAndCheck(in, Mock.ConsumedInput());
            } else {
                UNIT_ASSERT_C(!in.IsRewindable(), Name);

                RecvAndCheck(in, ToString(last));
            }
        }

    private:
        // Hands a copy of `last` to UnRecv and asserts the copy was fully taken.
        void UnrecvCopy(TLimitedRewindableInput& in, const TChunkList& last) {
            TChunkList copy = last.Copy();
            in.UnRecv(std::move(copy));
            // Inspecting the moved-from list is deliberate: the test requires UnRecv to
            // have taken every chunk, leaving it empty.
            UNIT_ASSERT_C(copy.Empty(), Name);
        }

        // Performs one read and compares it with `expected`; skipped while nothing has
        // been consumed from the mock yet.
        void RecvAndCheck(TLimitedRewindableInput& in, const TString& expected) {
            if (Mock.Pos() > 0) {
                TChunkList lst;
                auto err = in.Recv(lst, Deadline);
                UNIT_ASSERT_C(!err, Name << ", err: " << GetErrorMessage(err));
                UNIT_ASSERT_VALUES_EQUAL_C(ToString(lst), expected, Name);
            }
        }
    };

    template <class TTest>
    void DoTestNoError(TTest&& test) {
        for (auto dataLen : xrange<size_t>(0, ExpData.size())) {
            for (auto limitLen : xrange<size_t>(0, dataLen + 1)) {
                for (auto checkAfter : xrange<int>(0, dataLen + 1)) {
                    test(TTestData(dataLen, limitLen), checkAfter);
                }
            }
        }
    }


    template <class TTest>
    void DoTestWithError(TTest&& test) {
        for (auto dataLen : xrange<size_t>(0, ExpData.size())) {
            for (auto errPos : xrange<size_t>(0, dataLen + 1)) {
                for (auto limitLen : xrange<size_t>(errPos, dataLen + 1)) {
                    for (auto checkAfter : xrange<int>(0, errPos + 1)) {
                        test(TTestData(dataLen, limitLen, errPos), checkAfter);
                    }
                }
            }
        }
    }

    // Reads each scenario to EOF and probes Rewind() after every read once `checkAfter`
    // reads have gone by. The "Upgade" spelling is kept: it is the registered test name.
    Y_UNIT_TEST(TestRewind_NoError_NoUnrecv_NoUpgade) {
        DoTestNoError([](auto&& test, int checkAfter) {
            TLimitedRewindableInput in(test.Mock, test.LimitLen);
            UNIT_ASSERT(in.IsRewindable());

            TChunkList received;
            bool eof = false;
            do {
                TChunkList chunk;
                auto err = in.Recv(chunk, test.Deadline);
                eof = chunk.Empty();
                UNIT_ASSERT(!err);

                received.Append(std::move(chunk));

                if (checkAfter <= 0) {
                    test.CheckRewind(in);
                } else {
                    --checkAfter;
                }
            } while (!eof);

            test.CheckRecv(received);
        });
    }

    // Reads each scenario to EOF; after the skipped reads, every step exercises UnRecv,
    // UnRecv followed by Rewind, and a plain Rewind. Test name spelling is kept as registered.
    Y_UNIT_TEST(TestRewind_NoError_WithUnrecv_NoUpgade) {
        DoTestNoError([](auto&& test, int checkAfter) {
            TLimitedRewindableInput in(test.Mock, test.LimitLen);
            UNIT_ASSERT(in.IsRewindable());

            TChunkList received;
            bool eof = false;
            do {
                TChunkList chunk;
                auto err = in.Recv(chunk, test.Deadline);
                eof = chunk.Empty();
                UNIT_ASSERT(!err);

                received.Append(std::move(chunk));

                if (checkAfter <= 0) {
                    test.CheckUnrecv(in, received);
                    test.CheckUnrecvAndRewind(in, received);
                    test.CheckRewind(in);
                } else {
                    --checkAfter;
                }
            } while (!eof);

            test.CheckRecv(received);
        });
    }

    // After ResetRewind() the input must stay non-rewindable for the whole read-to-EOF pass.
    Y_UNIT_TEST(TestRewind_NoError_NoUnrecv_WithUpgrade) {
        DoTestNoError([](auto&& test, int) {
            TLimitedRewindableInput in(test.Mock, test.LimitLen);
            UNIT_ASSERT(in.IsRewindable());

            in.ResetRewind();
            UNIT_ASSERT(!in.IsRewindable());

            TChunkList received;
            bool eof = false;
            do {
                TChunkList chunk;
                auto err = in.Recv(chunk, test.Deadline);
                eof = chunk.Empty();
                UNIT_ASSERT(!err);
                UNIT_ASSERT(!in.IsRewindable());

                received.Append(std::move(chunk));
            } while (!eof);

            test.CheckRecv(received);
        });
    }

    // With rewinding disabled up front, UnRecv must still push data back correctly.
    // Test name spelling is kept as registered.
    Y_UNIT_TEST(TestRewind_NoError_WithUnrecv_WithUpgade) {
        DoTestNoError([](auto&& test, int checkAfter) {
            TLimitedRewindableInput in(test.Mock, test.LimitLen);
            UNIT_ASSERT(in.IsRewindable());

            in.ResetRewind();
            UNIT_ASSERT(!in.IsRewindable());

            TChunkList received;
            bool eof = false;
            do {
                TChunkList chunk;
                auto err = in.Recv(chunk, test.Deadline);
                eof = chunk.Empty();
                UNIT_ASSERT(!err);
                UNIT_ASSERT(!in.IsRewindable());

                received.Append(std::move(chunk));

                if (checkAfter <= 0) {
                    test.CheckUnrecv(in, received);
                } else {
                    --checkAfter;
                }
            } while (!eof);

            test.CheckRecv(received);
        });
    }

    // The mock fails once mid-stream: the failed read must yield no data and Rewind()
    // must still behave; the loop then continues to EOF and the error must have been seen.
    Y_UNIT_TEST(TestRewind_WithError_NoUnrecv_NoUpgrade) {
        DoTestWithError([](auto&& test, int checkAfter) {
            TLimitedRewindableInput in(test.Mock, test.LimitLen);
            UNIT_ASSERT(in.IsRewindable());

            TChunkList received;
            bool eof = false;
            bool hadErr = false;
            do {
                TChunkList chunk;
                auto err = in.Recv(chunk, test.Deadline);

                if (err) {
                    hadErr = true;
                    UNIT_ASSERT(chunk.Empty());
                    test.CheckRewind(in);
                } else {
                    eof = chunk.Empty();
                }

                received.Append(std::move(chunk));

                if (checkAfter <= 0) {
                    test.CheckRewind(in);
                } else {
                    --checkAfter;
                }
            } while (!eof);

            test.CheckRecv(received);
            UNIT_ASSERT(hadErr);
        });
    }

    // The mock fails once mid-stream while UnRecv/Rewind are exercised after every read.
    // The failed read must yield no data, the loop continues to EOF, and the error must
    // actually have been observed.
    Y_UNIT_TEST(TestRewind_WithError_WithUnrecv_NoUpgrade) {
        DoTestWithError([](auto&& test, int checkAfter) {
            TLimitedRewindableInput in(test.Mock, test.LimitLen);
            UNIT_ASSERT(in.IsRewindable());

            TChunkList recv;
            bool eof = false;
            // Must start false and only be set by the error branch; otherwise the final
            // UNIT_ASSERT(hadErr) can never fail.
            bool hadErr = false;
            while (!eof) {
                TChunkList lst;
                auto err = in.Recv(lst, test.Deadline);

                if (err) {
                    hadErr = true;
                    UNIT_ASSERT(lst.Empty());
                    test.CheckRewind(in);
                } else {
                    eof = lst.Empty();
                }

                recv.Append(std::move(lst));

                if (checkAfter > 0) {
                    checkAfter -= 1;
                } else {
                    test.CheckUnrecv(in, recv);
                    test.CheckUnrecvAndRewind(in, recv);
                    test.CheckRewind(in);
                }
            }

            test.CheckRecv(recv);
            UNIT_ASSERT(hadErr);
        });
    }

    // With rewinding disabled, reads proceed until the mock's one-shot error; the failed
    // read must yield no data and the input must never become rewindable again.
    Y_UNIT_TEST(TestRewind_WithError_NoUnrecv_WithUpgrade) {
        DoTestWithError([](auto&& test, int) {
            TLimitedRewindableInput in(test.Mock, test.LimitLen);
            UNIT_ASSERT(in.IsRewindable());

            in.ResetRewind();
            UNIT_ASSERT(!in.IsRewindable());

            TChunkList received;
            bool hadErr = false;
            do {
                TChunkList chunk;
                auto err = in.Recv(chunk, test.Deadline);
                UNIT_ASSERT(!in.IsRewindable());

                if (err) {
                    hadErr = true;
                    UNIT_ASSERT(chunk.Empty());
                }

                received.Append(std::move(chunk));
            } while (!hadErr);

            test.CheckRecv(received);
        });
    }

    // Same as the no-UnRecv variant, but after the skipped reads every step also checks
    // that UnRecv pushes the accumulated data back correctly.
    Y_UNIT_TEST(TestRewind_WithError_WithUnrecv_WithUpgrade) {
        DoTestWithError([](auto&& test, int checkAfter) {
            TLimitedRewindableInput in(test.Mock, test.LimitLen);
            UNIT_ASSERT(in.IsRewindable());

            in.ResetRewind();
            UNIT_ASSERT(!in.IsRewindable());

            TChunkList received;
            bool hadErr = false;
            do {
                TChunkList chunk;
                auto err = in.Recv(chunk, test.Deadline);
                UNIT_ASSERT(!in.IsRewindable());

                if (err) {
                    hadErr = true;
                    UNIT_ASSERT(chunk.Empty());
                }

                received.Append(std::move(chunk));

                if (checkAfter <= 0) {
                    test.CheckUnrecv(in, received);
                } else {
                    --checkAfter;
                }
            } while (!hadErr);

            test.CheckRecv(received);
        });
    }
};
