#include <passport/infra/libs/cpp/dbpool/misc/choose_policy.h>

#include <library/cpp/testing/unittest/registar.h>

using namespace NPassport::NDbPool;

Y_UNIT_TEST_SUITE(Choose) {
    // Equal weights: `total` consecutive requests on one chooser must all
    // succeed and return pairwise distinct indices; the request after that
    // must come back empty.
    Y_UNIT_TEST(RandomWithEqualWeight) {
        const size_t total = 100;

        std::vector<TDbHost> hosts;
        for (size_t i = 0; i < total; ++i) {
            hosts.push_back(TDbHost{.Weight = 1});
        }

        TRandomFactory factory(hosts);
        TChooserPtr chooser = factory.CreateChooser();

        // Indices the chooser has already handed out.
        std::set<size_t> seen;
        for (size_t attempt = 0; attempt < total; ++attempt) {
            const IChooser::TOptionalIdx picked = chooser->TryGetIdx();
            UNIT_ASSERT_C(picked, attempt);

            const bool isNew = seen.insert(*picked).second;
            UNIT_ASSERT_C(isNew, "One host was got twice: " << *picked);
        }

        // After `total` successful picks the chooser is expected to be exhausted.
        UNIT_ASSERT(!chooser->TryGetIdx());
    }

    // A single host: the first request must yield index 0, and a second
    // request on the same chooser must yield nothing.
    Y_UNIT_TEST(RandomWithSingleHost) {
        std::vector<TDbHost> hosts;
        hosts.push_back(TDbHost{.Weight = 1});

        TRandomFactory factory(hosts);
        TChooserPtr chooser = factory.CreateChooser();

        const IChooser::TOptionalIdx first = chooser->TryGetIdx();
        UNIT_ASSERT(first);
        UNIT_ASSERT_VALUES_EQUAL(*first, 0);

        const IChooser::TOptionalIdx second = chooser->TryGetIdx();
        UNIT_ASSERT(!second);
    }

    // Skewed weights: over many freshly created choosers, the first pick must
    // land on the heavily weighted host (index 0) more often than on either
    // of the lightly weighted ones.
    Y_UNIT_TEST(RandomWithDifferentWeight) {
        std::vector<TDbHost> hosts{
            TDbHost{.Weight = 10000},
            TDbHost{.Weight = 1},
            TDbHost{.Weight = 3},
        };

        TRandomFactory factory(hosts);

        // One counter per host; sized from `hosts` so the two cannot drift apart.
        std::vector<size_t> hits(hosts.size(), 0);
        for (size_t idx = 0; idx < 1000; ++idx) {
            // A new chooser per iteration, so only the first pick is sampled.
            TChooserPtr chooser = factory.CreateChooser();
            const IChooser::TOptionalIdx res = chooser->TryGetIdx();
            UNIT_ASSERT(res);
            // Fail the test instead of writing out of bounds if the chooser
            // ever returns an index that does not correspond to a host.
            UNIT_ASSERT_LT(*res, hits.size());
            ++hits[*res];
        }

        UNIT_ASSERT_LT(hits[1], hits[0]);
        UNIT_ASSERT_LT(hits[2], hits[0]);
    }

    // Drives TRandomChooser::TryGetIdx with explicit values instead of real
    // randomness and checks which host index each call returns.
    Y_UNIT_TEST(RandomTryGet) {
        // Runs one scenario.
        //   id      - case number, included in every failure message
        //   weights - weight of each host, in host-index order
        //   randoms - value passed to TryGetIdx on each successive call
        //   results - host index expected from each successive call
        // All three vectors must have the same length: the scenario makes one
        // call per host and then expects the chooser to be exhausted.
        auto test = [](size_t id,
                       const std::vector<size_t>& weights,
                       const std::vector<size_t>& randoms,
                       const std::vector<size_t>& results) {
            UNIT_ASSERT_VALUES_EQUAL_C(weights.size(), randoms.size(), "id=" << id);
            UNIT_ASSERT_VALUES_EQUAL_C(results.size(), randoms.size(), "id=" << id);

            std::vector<TDbHost> h;
            h.reserve(weights.size());
            for (size_t w : weights) {
                h.push_back(TDbHost{.Weight = w});
            }
            TRandomChooser c(TRandomFactory::PrepareState(h));

            for (size_t idx = 0; idx < randoms.size(); ++idx) {
                IChooser::TOptionalIdx res = c.TryGetIdx(randoms[idx]);
                UNIT_ASSERT_C(res, "id=" << id);
                UNIT_ASSERT_VALUES_EQUAL_C(*res, results[idx], "id=" << id << ". idx=" << idx);
            }

            // One pick was made per host, so nothing may be left; carry the
            // case id so a failure here can be traced to its scenario.
            UNIT_ASSERT_C(!c.TryGetIdx(), "id=" << id);
        };

        // NOTE(review): the expectations below look consistent with the value
        // being mapped onto the cumulative weights of the hosts not yet
        // returned - confirm against TRandomChooser::TryGetIdx.
        test(0,
             {3, 1},
             {3, 2},
             {1, 0});
        test(1,
             {1, 1},
             {0, 0},
             {0, 1});
        test(2,
             {1, 100},
             {1, 0},
             {1, 0});
        test(3,
             {1, 100},
             {0, 50},
             {0, 1});
        test(4,
             {1, 100},
             {50, 0},
             {1, 0});
    }
}
