#include <passport/infra/libs/cpp/dbpool/db_pool.h>
#include <passport/infra/libs/cpp/dbpool/destination.h>
#include <passport/infra/libs/cpp/dbpool/exception.h>
#include <passport/infra/libs/cpp/dbpool/query.h>
#include <passport/infra/libs/cpp/dbpool/result.h>
#include <passport/infra/libs/cpp/dbpool/misc/db_pool.h>
#include <passport/infra/libs/cpp/dbpool/misc/handle.h>
#include <passport/infra/libs/cpp/dbpool/misc/http-driver.h>

#include <library/cpp/testing/mock_server/server.h>
#include <library/cpp/testing/unittest/registar.h>

#include <util/system/event.h>

#include <functional>
#include <utility>

using namespace NPassport::NDbPool;

Y_UNIT_TEST_SUITE(DbPoolHttp) {
    class TPong: public TRequestReplier {
        TDuration Sleep_;

    public:
        TPong(TDuration sleep = TDuration())
            : Sleep_(sleep)
        {
        }

        bool DoReply(const TReplyParams& params) override {
            if (TParsedHttpFull(params.Input.FirstLine()).Path != "/ping") {
                UNIT_ASSERT(false);
            }

            Sleep(Sleep_);

            ::THttpResponse resp(HTTP_OK);
            resp.OutTo(params.Output);

            return true;
        }
    };

    // A responsive backend: the pool must ping successfully and report itself healthy.
    Y_UNIT_TEST(simpleStart) {
        TPortManager pm;
        const ui16 port = pm.GetPort(80);
        NMock::TMockServer server(port, []() { return new TPong; });

        TDbPool pool(
            TDbPoolSettings{
                .Dsn = TDestination::CreateHttp(),
                .Hosts = {{"localhost", port}},
            });
        UNIT_ASSERT_NO_EXCEPTION(pool.TryPing());

        // IsOk() writes its status line into `msg`; it is checked verbatim below.
        TString msg;
        UNIT_ASSERT_C(pool.IsOk(&msg), msg);
        UNIT_ASSERT_STRINGS_EQUAL("dr=http;db= is OK: 1/1 [1/1]", msg);
    }

    // A backend slower than the query timeout: the initial ping is expected to fail and,
    // once more than FailThreshold has elapsed, the pool must report the host as down.
    Y_UNIT_TEST(slowPong) {
        const TDuration pongDelay = TDuration::MilliSeconds(500);

        TPortManager portManager;
        const ui16 port = portManager.GetPort(80);
        NMock::TMockServer server(port, [pongDelay]() { return new TPong(pongDelay); });

        TDbPool pool(
            TDbPoolSettings{
                .Dsn = TDestination::CreateHttp(),
                .Hosts = {{"::1", port}},
                .Size = 70,
                .GetTimeout = TDuration(),
                .ConnectionTimeout = TDuration::MilliSeconds(200),
                // Shorter than pongDelay, so the ping cannot get its reply in time.
                .QueryTimeout = TDuration::MilliSeconds(300),
                .FailThreshold = TDuration::MilliSeconds(50),
            });
        UNIT_ASSERT_EXCEPTION_CONTAINS(pool.TryPing(),
                                       TException,
                                       "dr=http;db=: DbPool failed to ping: too few alive weight: 0/1 (rate=0). Min required: 0.6");

        // Let more than FailThreshold (50ms) pass before asking for the status.
        Sleep(TDuration::MilliSeconds(100));

        TString status;
        UNIT_ASSERT_C(!pool.IsOk(&status), status);
        UNIT_ASSERT_STRING_CONTAINS(status, "dr=http;db= is unavailable for ");
        UNIT_ASSERT_STRING_CONTAINS(status, ": 0/1 [0/1];down hosts:[::1]");
    }

    class TAnyAction: public TRequestReplier {
    public:
        using TFunc = std::function<void()>;

        TAnyAction(TFunc func)
            : Func_(func)
        {
        }

        bool DoReply(const TReplyParams& params) override {
            if (TParsedHttpFull(params.Input.FirstLine()).Path != "/ping") {
                UNIT_ASSERT(false);
            }

            Func_();
            ::THttpResponse resp(HTTP_OK);
            resp.OutTo(params.Output);

            return true;
        }

    private:
        TFunc Func_;
    };

    // A handle whose query timed out must become bad, yet must stay alive (still counted
    // in TotalHandles) until the stalled HTTP request is finally answered.
    Y_UNIT_TEST(dyingHandle) {
        // Replier callback; replaced below to stall the server on demand.
        std::function<void()> func = []() {};
        TAutoEvent ev;

        TPortManager pm;
        ui16 port = pm.GetPort(80);
        // The factory copies the current `func` into every replier it creates, so later
        // assignments to `func` only affect repliers created afterwards.
        NMock::TMockServer server(port, [&func]() { return new TAnyAction(func); });

        TCountersPtr c = std::make_shared<TCounters>("foo", 42);
        UNIT_ASSERT_VALUES_EQUAL(0, c->TotalHandles.GetCount());
        {
            THandle::TNonblockingInit nb = THandle::CreateNonblocking(
                THandleSettings{
                    .Dsn = TDestination::CreateHttp(),
                    .DbHost = {"localhost", port},
                    .ConnectionTimeout = TDuration::Minutes(1),
                    .QueryTimeout = TDuration::MilliSeconds(300),
                },
                THandleUnistatCtx{
                    .Counters = c,
                    .PoolCounters = std::make_shared<TCounters>("foo", 42),
                },
                {});
            UNIT_ASSERT(nb.InitionError.Wait(TDuration::Minutes(1)));
            UNIT_ASSERT_C(!nb.InitionError.GetValue(), *nb.InitionError.GetValue());
            THandle& h = *nb.Handle;

            UNIT_ASSERT(!h.Bad());
            UNIT_ASSERT_VALUES_EQUAL(1, c->TotalHandles.GetCount());
            UNIT_ASSERT_NO_EXCEPTION(h.Query("/ping"));
            UNIT_ASSERT(!h.Bad());

            // Stall the next reply until `ev` is signalled (or 5s pass).
            func = [&ev]() { ev.WaitT(TDuration::Seconds(5)); };
            h.NonBlockingQuery("/ping", TDuration::Seconds(50)); // curl timeout far beyond the 300ms wait below
            // WaitResult() must give up on its own 300ms deadline, long before curl or the server.
            TInstant start = TInstant::Now();
            UNIT_ASSERT_EXCEPTION_CONTAINS(h.WaitResult(TDuration::MilliSeconds(300)),
                                           TTimeoutException,
                                           "query timeout (300 ms)");
            UNIT_ASSERT_LT(TInstant::Now() - start, TDuration::Seconds(4));
            UNIT_ASSERT(h.Bad());
        }

        // The request is still in flight, so the handle has not been released yet.
        UNIT_ASSERT_VALUES_EQUAL(1, c->TotalHandles.GetCount());
        ev.Signal();

        // The handle goes away only after the reply arrives; poll with a bounded deadline
        // instead of relying on a single fixed sleep being long enough.
        const TInstant deadline = TInstant::Now() + TDuration::Seconds(5);
        while (c->TotalHandles.GetCount() != 0 && TInstant::Now() < deadline) {
            Sleep(TDuration::MilliSeconds(10));
        }
        UNIT_ASSERT_VALUES_EQUAL(0, c->TotalHandles.GetCount());
    }

    // With a curl-level timeout far shorter than WaitResult()'s, the query must fail with
    // curl's own "Timeout was reached" error and leave the handle bad.
    Y_UNIT_TEST(curlTimeout) {
        // Replier callback; replaced below to stall the server on demand.
        std::function<void()> func = []() {};
        TAutoEvent ev;

        TPortManager pm;
        ui16 port = pm.GetPort(80);
        NMock::TMockServer server(port, [&func]() { return new TAnyAction(func); });

        THandle::TNonblockingInit nb = THandle::CreateNonblocking(
            THandleSettings{
                .Dsn = TDestination::CreateHttp(),
                .DbHost = {"localhost", port},
                .ConnectionTimeout = TDuration::Minutes(1),
                .QueryTimeout = TDuration::MilliSeconds(300),
            },
            THandleUnistatCtx{
                .Counters = std::make_shared<TCounters>("foo", 42),
                .PoolCounters = std::make_shared<TCounters>("foo", 42),
            },
            {});

        UNIT_ASSERT(nb.InitionError.Wait(TDuration::Minutes(1)));
        UNIT_ASSERT_C(!nb.InitionError.GetValue(), *nb.InitionError.GetValue());
        THandle& h = *nb.Handle;

        UNIT_ASSERT(!h.Bad());
        UNIT_ASSERT_NO_EXCEPTION(h.Query("/ping"));
        UNIT_ASSERT(!h.Bad());

        // Stall the next reply until `ev` is signalled (or 5s pass).
        func = [&ev]() { ev.WaitT(TDuration::Seconds(5)); };
        h.NonBlockingQuery("/ping", TDuration::MilliSeconds(5)); // set timeout in curl
        // curl must give up after 5ms, long before WaitResult()'s own 300s deadline.
        TInstant start = TInstant::Now();
        UNIT_ASSERT_EXCEPTION_CONTAINS(h.WaitResult(TDuration::Seconds(300)),
                                       TException,
                                       "query failure: <Timeout was reached>[28]");
        UNIT_ASSERT_LT(TInstant::Now() - start, TDuration::Seconds(4));
        UNIT_ASSERT(h.Bad());

        // Unblock the stalled replier; otherwise it keeps waiting on `ev` for up to 5s
        // while the mock server is being torn down.
        ev.Signal();
    }

    Y_UNIT_TEST(buildBaseUrl) {
        // Asserts that THttpDriver::BuildBaseUrl(host, port) yields `expected`.
        // Generic parameters keep the exact argument types of a direct call with literals.
        const auto check = [](const auto& expected, const auto& host, auto port) {
            UNIT_ASSERT_VALUES_EQUAL(expected, THttpDriver::BuildBaseUrl(host, port));
        };

        // Empty host: only the scheme, plus ":port" when the port is non-zero.
        check("http://:17", "", 17);
        check("http://", "", 0);

        // Bare hosts: port 0 adds no suffix; IPv6 literals get bracketed.
        check("http://localhost:17", "localhost", 17);
        check("http://localhost", "localhost", 0);
        check("http://127.0.0.1:17", "127.0.0.1", 17);
        check("http://127.0.0.1", "127.0.0.1", 0);
        check("http://[::1]:17", "::1", 17);
        check("http://[::1]", "::1", 0);

        // Hosts that already carry the scheme are used as-is (IPv6 is not bracketed here).
        check("http://localhost:17", "http://localhost", 17);
        check("http://localhost", "http://localhost", 0);
        check("http://127.0.0.1:17", "http://127.0.0.1", 17);
        check("http://127.0.0.1", "http://127.0.0.1", 0);
        check("http://::1:17", "http://::1", 17);
        check("http://::1", "http://::1", 0);

        // Already-bracketed IPv6 hosts are not bracketed twice.
        check("http://[::1]:17", "[::1]", 17);
        check("http://[::1]", "[::1]", 0);
    }

    Y_UNIT_TEST(calculateState) {
        // Re-exports the constructors and CalculateState of TDbPoolImpl for the test.
        class TTestPool: public TDbPoolImpl {
        public:
            using TDbPoolImpl::TDbPoolImpl;
            using TDbPoolImpl::CalculateState;
        };

        // Three weighted hosts without explicit addresses; the ping is expected to fail,
        // leaving no host alive.
        TTestPool pool(
            TDbPoolSettings{
                .Dsn = TDestination::CreateHttp(),
                .Hosts = {
                    TDbHost{.Weight = 1},
                    TDbHost{.Weight = 2},
                    TDbHost{.Weight = 3},
                },
                .GetTimeout = TDuration(),
                .FailThreshold = TDuration::MicroSeconds(1),
            },
            {});
        UNIT_ASSERT_EXCEPTION(pool.TryPing(), TException);

        // The same index container is reused for both collection modes, as before.
        TTestPool::THostIdxs hostIdxs;

        // Collecting alive hosts: nothing is collected.
        const auto aliveResult = pool.CalculateState(hostIdxs, TTestPool::ECollectHosts::Alive);
        UNIT_ASSERT_VALUES_EQUAL(0, aliveResult);
        UNIT_ASSERT_VALUES_EQUAL(0, hostIdxs.size());

        // Collecting bad hosts: all three host indexes are collected.
        const auto badResult = pool.CalculateState(hostIdxs, TTestPool::ECollectHosts::Bad);
        UNIT_ASSERT_VALUES_EQUAL(0, badResult);
        UNIT_ASSERT_VALUES_EQUAL(TTestPool::THostIdxs({0, 1, 2}), hostIdxs);

        TString status;
        UNIT_ASSERT_C(!pool.IsOk(&status), status);
    }

    Y_UNIT_TEST(calculateSizes) {
        // Re-exports the static TDbPoolImpl::CalculateSizes (presumably protected) for the test.
        class TTestPool: public TDbPoolImpl {
        public:
            using TDbPoolImpl::CalculateSizes;
        };

        // One scenario: configured host weights, total pool size, indexes of the hosts
        // passed to CalculateSizes, the weight sum passed alongside, and the expected sizes.
        struct TTestCase {
            TString Name;
            std::vector<size_t> HostWeights;
            size_t Size = 0;
            TDbPoolImpl::THostIdxs Idxs;
            size_t WeightSum = 0;
            TDbPoolImpl::THostSizes Expected;
        };

        const std::vector<TTestCase> cases = {
            {
                .Name = "down",
                .HostWeights = {1, 2, 3},
                .Size = 1024,
                .Idxs = {},
                .WeightSum = 0,
                .Expected = {},
            },
            {
                .Name = "only first host",
                .HostWeights = {1, 2, 3},
                .Size = 1024,
                .Idxs = {0},
                .WeightSum = 1,
                .Expected = {1024},
            },
            {
                .Name = "first and second host",
                .HostWeights = {1, 2, 3},
                .Size = 1024,
                .Idxs = {0, 1},
                .WeightSum = 3,
                .Expected = {341, 683},
            },
            {
                .Name = "first and second host with illegal weightSum",
                .HostWeights = {1, 2, 3},
                .Size = 1024,
                .Idxs = {0, 1},
                .WeightSum = 4,
                .Expected = {256, 768},
            },
            {
                .Name = "second and third hosts",
                .HostWeights = {1, 2, 3},
                .Size = 1024,
                .Idxs = {1, 2},
                .WeightSum = 5,
                .Expected = {409, 615},
            },
            {
                .Name = "all hosts",
                .HostWeights = {1, 2, 3},
                .Size = 1024,
                .Idxs = {0, 1, 2},
                .WeightSum = 6,
                .Expected = {170, 342, 512},
            },
            {
                .Name = "uniform distribution",
                .HostWeights = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                .Size = 96,
                .Idxs = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26},
                .WeightSum = 27,
                .Expected = {3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 4, 3, 4, 3, 4, 3, 4, 3, 4, 4, 3, 4, 3, 4, 3, 4, 4},
            },
        };

        for (const TTestCase& testCase : cases) {
            TDbPoolSettings settings{
                .Dsn = TDestination::CreateHttp(),
                .Size = testCase.Size,
                .GetTimeout = TDuration(),
            };
            // Replace whatever hosts the settings start with by one host per configured weight.
            settings.Hosts.clear();
            for (const size_t weight : testCase.HostWeights) {
                settings.Hosts.push_back(TDbHost{.Weight = weight});
            }

            TDbPoolImpl::THostSizes actualSizes;
            TTestPool::CalculateSizes(settings, testCase.Idxs, testCase.WeightSum, actualSizes);
            UNIT_ASSERT_VALUES_EQUAL_C(testCase.Expected, actualSizes, testCase.Name);
        }
    }
}

// Prints host indexes as "[i0, i1, ...]" ("[]" when empty), so UNIT_ASSERT_VALUES_EQUAL
// failures over THostIdxs stay readable and an empty container is visibly empty.
template <>
void Out<TDbPoolImpl::THostIdxs>(IOutputStream& o, const TDbPoolImpl::THostIdxs& value) {
    o << '[';
    bool first = true;
    for (auto idx : value) {
        if (!first) {
            o << ", ";
        }
        first = false;
        o << idx;
    }
    o << ']';
}
