#include <infra/netmon/probe_scheduler.h>
#include <infra/netmon/common_ut.h>
#include <infra/netmon/settings.h>
#include <infra/netmon/topology/common_ut.h>

#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/algorithm.h>
#include <util/generic/hash_set.h>

class TUniformSwitchProbeSchedulerTest: public TTestBase {
    UNIT_TEST_SUITE(TUniformSwitchProbeSchedulerTest);
    UNIT_TEST(TestIntraDcSchedule)
    UNIT_TEST(TestCrossDcSchedule)
    UNIT_TEST(TestIpv4Schedule)
    UNIT_TEST_SUITE_END();

public:
    TUniformSwitchProbeSchedulerTest()
    {
    }

private:
    bool SwitchHasHostWithBackboneInterfaces(const TSwitch& switch_) {
        return AnyOf(switch_.GetRealHosts(), [](const THost* host) -> bool {
            return host->GetBackboneInterface();
        });
    }

    bool SwitchHasHostWithIpv4Interfaces(const TSwitch& switch_) {
        return AnyOf(switch_.GetRealHosts(), [](const THost* host) -> bool {
            return host->GetIpv4Interface();
        });
    }

    inline void CheckProbesDistributionAmongHosts(const TSwitch& switch_, IProbeScheduler::TProbeSchedule& schedule, bool v4 = false) {
        int approxProbesCount = -1;
        for (const auto host : switch_.GetRealHosts()) {
            if (!v4 || host->GetIpv4Interface()) {
                const auto found(schedule.find(TTopology::THostRef(host)));
                if (!found.IsEnd()) {
                    int foundProbesCount;
                    if (v4) {
                        foundProbesCount = found->second.Ipv4IntraDc.size();
                    } else {
                        foundProbesCount = found->second.IntraDc.size();
                    }
                    if (approxProbesCount == -1) {
                        approxProbesCount = foundProbesCount; // probes count from first active host
                    }
                    // probes count from other hosts can not differ more than 1
                    UNIT_ASSERT(std::abs(static_cast<int>(foundProbesCount) - approxProbesCount) <= 1);
                }
            }
        }
    }

    inline void TestIntraDcSchedule() {
        TestIntraDcSchedule(true);
        TestIntraDcSchedule(false);
    }

    inline void TestIntraDcSchedule(bool podAsQueue) {
        TTopologySettings::Get()->SetUsePodAsQueue(podAsQueue);

        const std::size_t probesBetweenTwoSwitches = 3;
        const std::size_t probesBetweenTwoPods = 0; // disable cross-dc probes

        TTopologyUpdater topologyUpdater(false);
        TExpressionStorage expressionStorage(false);
        TGroupStorage groupStorage(false);
        auto topologyStorage_ = podAsQueue ? MakeAtomicShared<TTopologyStorage>(topologyUpdater, expressionStorage, groupStorage, false) : nullptr;

        TTopologyStorage& topologyStorage = podAsQueue ? *topologyStorage_ : TGlobalTopology::GetTopologyStorage();
        const auto topology(topologyStorage.GetTopology());

        TSettings::Get()->SetProbeScheduleDcs({"sas"});
        TSettings::Get()->SetCrossDcProbeScheduleFullMesh(false);
        TSettings::Get()->SetIpv4IntraDcScheduleEnabled(false);

        TVector<TTopology::TDatacenterRef> dcs;
        if (podAsQueue) {
            dcs.emplace_back(FindDatacenter("sas1", &topologyStorage));
            dcs.emplace_back(FindDatacenter("sas2", &topologyStorage));
        } else {
            dcs.emplace_back(FindDatacenter("sas", &topologyStorage));
        }

        auto switchFilter = [&dcs](const TSwitch& switch_) {
            for (const auto& dc : dcs) {
                if (switch_.GetDatacenter() == *dc)
                    return true;
            }
            return false;
        };

        TUniformSwitchProbeScheduler scheduler(
            topologyStorage, switchFilter,
            {{"default", probesBetweenTwoSwitches}},
            probesBetweenTwoPods
        );
        TTopologyStorage::THostSet interestedHosts;
        topology->ForEachSwitch([&](const TSwitch& switch_) {
            if (switchFilter(switch_)) {
                for (const auto host : switch_.GetRealHosts()) {
                    interestedHosts.emplace(host);
                }
            }
        });
        const TTopologyStorage::THostSet deadHosts;
        auto schedule(scheduler.Schedule(interestedHosts, deadHosts));

        // count probes between every two switches
        THashMap<TUniformSwitchProbeScheduler::TSwitchPair, std::size_t> switchProbes;
        for (const auto& hostConfig : schedule) {
            const auto source(hostConfig.first);
            for (const auto& target : hostConfig.second.IntraDc) {
                TUniformSwitchProbeScheduler::TSwitchPair pair(source->GetSwitch(), target->GetSwitch());
                ++switchProbes[pair];
            }
        }

        for (const auto& probes : switchProbes) {
            // check probes count between every two switches
            UNIT_ASSERT(probes.second == probesBetweenTwoSwitches);

            // check probes evenly distributed among hosts in every switch
            CheckProbesDistributionAmongHosts(*probes.first.First, schedule);
            CheckProbesDistributionAmongHosts(*probes.first.Second, schedule);
        }

        // check we cover all switch pairs
        std::size_t pairsCount(0);
        for (const auto& dc : dcs) {
            TVector<TTopology::TSwitchRef> switches;
            topology->ForEachSwitch([&](const TSwitch& switch_) {
                if (switch_.GetDatacenter() == dc) {
                    switches.emplace_back(switch_);
                }
            });
            for (auto firstIt = begin(switches); firstIt != end(switches); ++firstIt) {
                for (auto secondIt = std::next(firstIt); secondIt != end(switches); ++secondIt) {
                    const auto& first = **firstIt;
                    const auto& second = **secondIt;
                    // TODO: check fastbone interface existence too if needed
                    if (switchFilter(first) && switchFilter(second) &&
                        SwitchHasHostWithBackboneInterfaces(first) && SwitchHasHostWithBackboneInterfaces(second))
                    {
                        TUniformSwitchProbeScheduler::TSwitchPair pair(first, second);
                        UNIT_ASSERT(!switchProbes.find(pair).IsEnd());
                        ++pairsCount;
                    }
                };
            };
        }
        // and nothing more
        UNIT_ASSERT(switchProbes.size() == pairsCount);
    }

    inline void TestCrossDcSchedule() {
        TestCrossDcSchedule(true);
        TestCrossDcSchedule(false);
    }

    inline void TestCrossDcSchedule(bool podAsQueue) {
        TTopologySettings::Get()->SetUsePodAsQueue(podAsQueue);

        const std::size_t probesBetweenTwoSwitches = 3;
        const std::size_t probesBetweenTwoPods = 300;

        TTopologyUpdater topologyUpdater(false);
        TExpressionStorage expressionStorage(false);
        TGroupStorage groupStorage(false);
        auto topologyStorage_ = podAsQueue ? MakeAtomicShared<TTopologyStorage>(topologyUpdater, expressionStorage, groupStorage, false) : nullptr;

        TTopologyStorage& topologyStorage = podAsQueue ? *topologyStorage_ : TGlobalTopology::GetTopologyStorage();
        const auto topology(topologyStorage.GetTopology());

        TSettings::Get()->SetProbeScheduleDcs({"sas"});
        TSettings::Get()->SetCrossDcProbeScheduleFullMesh(false);
        TSettings::Get()->SetIpv4IntraDcScheduleEnabled(false);

        TVector<TTopology::TDatacenterRef> sourceDcs;
        if (podAsQueue) {
            sourceDcs.emplace_back(FindDatacenter("sas1", &topologyStorage));
            sourceDcs.emplace_back(FindDatacenter("sas2", &topologyStorage));
        } else {
            sourceDcs.emplace_back(FindDatacenter("sas", &topologyStorage));
        }

        auto targetDcs = sourceDcs;
        targetDcs.emplace_back(FindDatacenter("man", &topologyStorage));
        targetDcs.emplace_back(FindDatacenter("vla", &topologyStorage));

        auto sourceDcSwitchFilter = [&sourceDcs](const TSwitch& switch_) {
            return AnyOf(sourceDcs, [&switch_](const TTopology::TDatacenterRef& dc) {
                return switch_.GetDatacenter() == *dc;
            });
        };

        TUniformSwitchProbeScheduler scheduler(
            topologyStorage, sourceDcSwitchFilter,
            {{"default", probesBetweenTwoSwitches}}, 
            probesBetweenTwoPods
        );
        TTopologyStorage::THostSet interestedHosts;
        TTopologyStorage::THostSet deadHosts;
        TTopologyStorage::THostSet crossDcTargetHosts;

        topology->ForEachSwitch([&](const TSwitch& switch_) {
            if (sourceDcSwitchFilter(switch_)) {
                for (const auto host : switch_.GetRealHosts()) {
                    interestedHosts.emplace(host);
                }
            }
            if (FindPtr(targetDcs, switch_.GetDatacenter())) {
                for (const auto host : switch_.GetRealHosts()) {
                    crossDcTargetHosts.emplace(host);
                }
            }
        });
        auto schedule(scheduler.Schedule(interestedHosts, deadHosts, crossDcTargetHosts));

        // count cross-dc probes between every two pods
        THashMap<std::pair<TTopology::TPodRef, TTopology::TPodRef>, std::size_t> crossDcPodProbes;
        for (const auto& hostConfig : schedule) {
            const auto source(hostConfig.first);
            for (const auto& target : hostConfig.second.CrossDc) {
                UNIT_ASSERT(source->GetDatacenter() != target->GetDatacenter());
                if (source->GetPod() && target->GetPod()) {
                    std::pair<TTopology::TPodRef, TTopology::TPodRef> pair(*source->GetPod(), *target->GetPod());
                    ++crossDcPodProbes[pair];
                }
            }
        }

        // check we cover every pod pair exactly once
        std::size_t podPairsCount(0);
        for (const auto& targetDc : targetDcs) {
            for (const auto& sourceDc : sourceDcs) {
                if (sourceDc == targetDc) {
                    continue;
                }

                for (const auto& targetPod : targetDc->GetPods()) {
                    for (const auto& sourcePod : sourceDc->GetPods()) {
                        std::pair<TTopology::TPodRef, TTopology::TPodRef> pair(*sourcePod, *targetPod);
                        UNIT_ASSERT_C(crossDcPodProbes.contains(pair), pair.first->GetName() + " - " + pair.second->GetName());
                        UNIT_ASSERT(crossDcPodProbes.at(pair) == probesBetweenTwoPods);
                        ++podPairsCount;
                    }
                }
            }
        }
        UNIT_ASSERT(crossDcPodProbes.size() == podPairsCount);
    }

    inline void TestIpv4Schedule() {
        const std::size_t probesBetweenTwoSwitches = 3;
        const std::size_t probesBetweenTwoPods = 0; // disable cross-dc probes

        auto& topologyStorage(TGlobalTopology::GetTopologyStorage());
        const auto topology(topologyStorage.GetTopology());

        TSettings::Get()->SetProbeScheduleDcs({"myt"});
        TSettings::Get()->SetCrossDcProbeScheduleFullMesh(false);
        TSettings::Get()->SetIpv4IntraDcScheduleEnabled(true);

        auto dc(FindDatacenter("myt"));
        auto switchFilter = [&dc](const TSwitch& switch_) { return switch_.GetDatacenter() == *dc; };

        TUniformSwitchProbeScheduler scheduler(
            topologyStorage, switchFilter,
            {{"default", probesBetweenTwoSwitches}},
            probesBetweenTwoPods
        );
        TTopologyStorage::THostSet interestedHosts;
        topology->ForEachSwitch([&](const TSwitch& switch_) {
            if (switchFilter(switch_)) {
                for (const auto host : switch_.GetRealHosts()) {
                    interestedHosts.emplace(host);
                }
            }
        });
        const TTopologyStorage::THostSet deadHosts;
        auto schedule(scheduler.Schedule(interestedHosts, deadHosts));

        // count probes between every two switches
        THashMap<TUniformSwitchProbeScheduler::TSwitchPair, std::size_t> switchProbes;
        for (const auto& hostConfig : schedule) {
            const auto source(hostConfig.first);
            for (const auto& target : hostConfig.second.Ipv4IntraDc) {
                TUniformSwitchProbeScheduler::TSwitchPair pair(source->GetSwitch(), target->GetSwitch());
                ++switchProbes[pair];
            }
        }

        for (const auto& probes : switchProbes) {
            // check probes count between every two switches
            UNIT_ASSERT(probes.second == probesBetweenTwoSwitches);

            // check probes evenly distributed among hosts in every switch
            CheckProbesDistributionAmongHosts(*probes.first.First, schedule, true);
            CheckProbesDistributionAmongHosts(*probes.first.Second, schedule, true);
        }

        // check we cover all switch pairs
        std::size_t pairsCount(0);
        topology->ForEachSwitch([&](const TSwitch& first) {
            topology->ForEachSwitch([&](const TSwitch& second) {
                if (first != second &&
                    switchFilter(first) && switchFilter(second) &&
                    SwitchHasHostWithIpv4Interfaces(first) && SwitchHasHostWithIpv4Interfaces(second))
                {
                    TUniformSwitchProbeScheduler::TSwitchPair pair(first, second);
                    UNIT_ASSERT(!switchProbes.find(pair).IsEnd());
                    ++pairsCount;
                }
            });
        });
        // and nothing more
        UNIT_ASSERT(switchProbes.size() == pairsCount / 2);
    }
};

// Registers TUniformSwitchProbeSchedulerTest with the unittest framework.
UNIT_TEST_SUITE_REGISTRATION(TUniformSwitchProbeSchedulerTest);
