#include <balancer/kernel/helpers/syscalls.h>
#include <balancer/kernel/net/ut_common/ut_sock.h>
#include <balancer/kernel/net/socket.h>
#include <balancer/kernel/net/sockops.h>

#include <library/cpp/threading/future/async.h>
#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/testing/unittest/tests_data.h>

#include <util/generic/xrange.h>
#include <util/string/join.h>


Y_UNIT_TEST_SUITE(TTestBalancerSocketIo) {
    using namespace NSrvKernel;


    // Builds a TSockBufSize whose Rcv and Snd fields are both set to 0.
    // NOTE(review): no caller is visible in this part of the suite — confirm it is still needed.
    TSockBufSize MinBufs() {
        TSockBufSize zeroed{.Rcv = 0, .Snd = 0};
        return zeroed;
    }

    // Reads from `sh` inside a coroutine and returns every chunk list handed back by Recv,
    // in the order it was received.
    //
    // sh       - socket to read from; wrapped in a TSocketIo for the duration of the call.
    // pollMode - poll mode forwarded to the TSocketIo constructor.
    //
    // Each Recv call gets its own deadline of "now + 1 second". The list filled by Recv is
    // appended to the result before the outcome is inspected, so the final element is the
    // one produced by the call that ended the loop. The loop ends on a timeout or on an
    // error-free Recv that left the list empty; any other error goes through err.Throw().
    auto ReadAll(TSocketHolder& sh, EPollMode pollMode) {
        TVector<TChunkList> collected;
        TContExecutor executor(32000);  // NOTE(review): 32000 is presumably the coroutine stack size — confirm
        auto reader = [&](TCont*) {
            TSocketIo io{&sh, &executor, TSocketOut::DefaultMaxInQueue, pollMode};
            for (;;) {
                TChunkList piece;
                auto err = io.In().Recv(piece, TInstant::Now() + TDuration::Seconds(1));
                // Must be sampled before `piece` is moved into the result.
                const bool nothingRead = piece.Empty();
                collected.emplace_back(std::move(piece));
                if (HasTimedOut(err)) {
                    break;
                } else if (err) {
                    err.Throw();
                } else if (nothingRead) {
                    break;
                }
            }
        };
        // Execute is given a named callable; `reader` captures the locals above by reference.
        executor.Execute(reader);
        return collected;
    }

    // Round-trip check for a single poll mode: a writer thread sends `inStr` over one end of a
    // TCP socket pair and finishes the connection according to `f`, while ReadAll drains the
    // other end.
    //
    // inStr    - payload handed to the writer thread and expected back from the reader.
    // f        - finish mode passed to the writer thread.
    // pollMode - poll mode passed through to ReadAll.
    // sz       - buffer sizes passed to CreateTcpSocketPair.
    //
    // Asserts that something was collected, that the last collected chunk list is empty, and
    // that joining all chunk lists reproduces `inStr`.
    void DoTestRead(TString inStr, NNetUt::ETcpFinish f, EPollMode pollMode, TSockBufSize sz) {
        auto sockets = NNetUt::CreateTcpSocketPair(sz);

        // The second socket of the pair is moved into the writer; the first stays here for reading.
        NNetUt::TWriteThread writer(std::move(sockets.second), {inStr}, f);
        writer.Start();

        auto res = ReadAll(sockets.first, pollMode);

        UNIT_ASSERT(res);
        UNIT_ASSERT(res.back().Empty());

        auto str = JoinSeq("", res);
        UNIT_ASSERT_VALUES_EQUAL(str, inStr);
    }

    // Runs the four-argument DoTestRead once per poll mode: level-triggered first, then
    // edge-triggered.
    //
    // inStr - payload to send and expect back.
    // f     - finish mode for the writing side.
    // sz    - socket buffer sizes; defaults to a value-initialized TSockBufSize.
    void DoTestRead(TString inStr, NNetUt::ETcpFinish f, TSockBufSize sz = {}) {
        const EPollMode pollModes[] = {PM_LEVEL_TRIGGERED, PM_EDGE_TRIGGERED};
        for (const EPollMode pollMode : pollModes) {
            DoTestRead(inStr, f, pollMode, sz);
        }
    }

    // Empty payload; the writing side finishes with ETcpFinish::Fin.
    Y_UNIT_TEST(TestRead0Fin) {
        const auto finish = NNetUt::ETcpFinish::Fin;
        DoTestRead("", finish);
    }

    // Empty payload; the writing side finishes with ETcpFinish::Rst.
    Y_UNIT_TEST(TestRead0Rst) {
        const auto finish = NNetUt::ETcpFinish::Rst;
        DoTestRead("", finish);
    }

    // One-character payload; the writing side finishes with ETcpFinish::Fin.
    Y_UNIT_TEST(TestRead1Fin) {
        const auto finish = NNetUt::ETcpFinish::Fin;
        DoTestRead("a", finish);
    }

    // One-character payload; the writing side finishes with ETcpFinish::Rst.
    Y_UNIT_TEST(TestRead1Rst) {
        const auto finish = NNetUt::ETcpFinish::Rst;
        DoTestRead("a", finish);
    }

    // Generated payload from GenString(65535) with Rcv/Snd both set to 6543; the writing side
    // finishes with ETcpFinish::Fin.
    // NOTE(review): presumably meant to make the payload exceed the socket buffers — confirm.
    Y_UNIT_TEST(TestReadMaxFin) {
        const TString payload = NNetUt::TStableRandomInput().GenString(65535);
        const TSockBufSize bufs{.Rcv = 6543, .Snd = 6543};
        DoTestRead(payload, NNetUt::ETcpFinish::Fin, bufs);
    }

    // Generated payload from GenString(65535) with Rcv/Snd both set to 6543; the writing side
    // finishes with ETcpFinish::Rst.
    // NOTE(review): presumably meant to make the payload exceed the socket buffers — confirm.
    Y_UNIT_TEST(TestReadMaxRst) {
        const TString payload = NNetUt::TStableRandomInput().GenString(65535);
        const TSockBufSize bufs{.Rcv = 6543, .Snd = 6543};
        DoTestRead(payload, NNetUt::ETcpFinish::Rst, bufs);
    }

}
