#include <infra/netmon/probe_scheduler.h>
#include <infra/netmon/settings.h>
#include <infra/netmon/topology/types.h>

#include <util/generic/xrange.h>
#include <util/random/random.h>
#include <util/random/shuffle.h>
#include <util/string/join.h>

#include <cmath>
#include <utility>

namespace NNetmon {
    namespace {
        // A host whose schedule holds more targets than this is reported by CheckSchedule.
        constexpr size_t SCHEDULED_TARGETS_WARN_THRESHOLD = 512;

        // Cap on the number of sparse intra-DC probes sent by each switch
        // (the effective value is min(switches - 1, this)).
        constexpr size_t SPARSE_SCHEDULE_TARGETS_PER_SWITCH = 100;

        // Which per-host target list a generated probe is appended to.
        enum class EScheduleType {
            INTRA_DC,        // full-mesh switch-to-switch probes
            CROSS_DC,        // pod-to-pod probes towards other datacenters
            INTRA_DC_SPARSE, // randomly sampled switch-to-switch probes
            IPV4_INTRA_DC    // full-mesh probes between hosts with an IPv4 interface
        };
    }

    class TUniformSwitchProbeScheduler::TImpl {
    public:
        // Stores the scheduler configuration.
        //
        // topologyStorage is kept by reference and must outlive this object.
        // filter and probesBetweenTwoSwitches are sink parameters and are moved into the
        // members. Every value of probesBetweenTwoSwitches must be >= 2 (one probe in each
        // direction); this is enforced with Y_VERIFY.
        // NOTE(review): the map's keys are presumably datacenter names plus "default",
        // like the settings map read in Schedule() — confirm against the caller.
        TImpl(const TTopologyStorage& topologyStorage,
              TSwitchFilter filter,
              THashMap<TString, std::size_t> probesBetweenTwoSwitches,
              std::size_t crossDcProbesBetweenTwoPods)
            : TopologyStorage(topologyStorage)
            , SwitchFilter(std::move(filter))
            , ProbesBetweenTwoSwitches(std::move(probesBetweenTwoSwitches))
            , CrossDcProbesBetweenTwoPods(crossDcProbesBetweenTwoPods)
            , MinHostCountPerPodThreshold(5 * crossDcProbesBetweenTwoPods)
        {
            // Validate the member: the parameter has been moved from.
            for (const auto& [_, probes] : ProbesBetweenTwoSwitches) {
                // at least one probe in each direction
                Y_VERIFY(probes >= 2);
            }
        }

        // Builds the full probe schedule.
        //
        // Intra-DC probe participants and cross-DC probe senders must be in interestedHosts
        // and in neither deadHosts nor mutedHosts. Cross-DC targets are taken from
        // crossDcTargetHosts, skipping dead hosts and hosts without a backbone interface.
        //
        // In full-mesh mode a single intra-DC schedule spans every switch accepted by
        // SwitchFilter and only the "default" probe count is used. Otherwise each DC listed
        // in the settings is scheduled on its own: DCs missing from the topology are logged
        // and skipped, a per-DC probe count overrides the default, and cross-DC probes are
        // added when crossDcTargetHosts is non-empty. The settings probe map must contain a
        // "default" entry (read with at()).
        TProbeSchedule Schedule(const TTopologyStorage::THostSet& interestedHosts,
                                const TTopologyStorage::THostSet& deadHosts,
                                const TTopologyStorage::THostSet& crossDcTargetHosts,
                                const TTopologyStorage::THostSet& mutedHosts) const
        {
            TProbeSchedule schedule;

            const auto hostFilter = [&interestedHosts, &deadHosts, &mutedHosts](const TTopology::THostRef& host) {
                return interestedHosts.contains(host) &&
                       !deadHosts.contains(host) &&
                       !mutedHosts.contains(host);
            };
            const auto otherDcHostFilter = [&deadHosts](const TTopology::THostRef& host) {
                return !deadHosts.contains(host) &&
                       host->GetBackboneInterface();
            };

            // Fetch the settings handle once for the whole call. probesMap refers into the
            // settings object, and every option below is read through this same handle.
            // If Get() returns an owning handle, this also keeps that object alive.
            const auto& settings = TSettings::Get();
            const auto& probesMap = settings->GetScheduledProbesBetweenTwoSwitches();
            const std::size_t defaultProbeCount = probesMap.at("default");

            if (settings->IsCrossDcProbeScheduleFullMesh()) {
                if (probesMap.size() > 1) {
                    WARNING_LOG << "Ignoring per-dc probe density settings" << Endl;
                }
                GenerateIntraDcSchedule(SwitchFilter, hostFilter, defaultProbeCount, schedule);
            } else {
                for (const auto& dcName : settings->GetProbeScheduleDcs()) {
                    INFO_LOG << "Generating schedule for " << dcName << "..." << Endl;
                    auto dc = TopologyStorage.FindDatacenter(dcName);
                    if (!dc) {
                        ERROR_LOG << "Datacenter " << dcName << " not found in topology" << Endl;
                        continue;
                    }
                    auto dcFilter = [this, &dc](const TSwitch& switch_) {
                        return SwitchFilter(switch_) && switch_.GetDatacenter() == dc;
                    };

                    // Per-DC override if configured, otherwise the default (single lookup).
                    const auto probesIt = probesMap.find(dcName);
                    const std::size_t probesBetweenTwoSwitches =
                        probesIt != probesMap.end() ? probesIt->second : defaultProbeCount;

                    INFO_LOG << "Using " << probesBetweenTwoSwitches << " probes between every two switches" << Endl;
                    GenerateIntraDcSchedule(dcFilter, hostFilter, probesBetweenTwoSwitches, schedule);
                    if (!crossDcTargetHosts.empty()) {
                        AddCrossDcProbes(dc, hostFilter, otherDcHostFilter, crossDcTargetHosts, schedule);
                    }
                }
            }

            CheckSchedule(schedule);
            return schedule;
        }

    private:
        // Type-erased host predicate accepted by the schedule builders below.
        using THostFilter = std::function<bool(const TTopology::THostRef&)>;

        // One probing host and the targets assigned to it so far.
        struct THostConfig {
            TTopology::THostRef Host;
            TVector<TTopology::THostRef> Targets;

            THostConfig() = default;

            THostConfig(const THost* host)
                : Host(host)
            {
            }
        };

        class TProbeDistributor {
        public:
            TProbeDistributor() = default;

            TProbeDistributor(const TSwitch& switch_, const THostFilter& filter)
                : TProbeDistributor(ExtractHosts(switch_, filter))
            {
            }

            TProbeDistributor(const TPod& pod,
                              const THostFilter& filter,
                              size_t minHostCountThreshold)
                : TProbeDistributor(ExtractHosts(pod, filter, minHostCountThreshold))
            {
            }

            TProbeDistributor(TVector<const THost*>&& hosts) {
                ShuffleRange(hosts);
                HostConfigs.reserve(hosts.size());
                for (const auto host : hosts) {
                    HostConfigs.emplace_back(host);
                }
            }

            static void AddProbe(TProbeDistributor& from, TProbeDistributor& to) {
                from.AddOutgoingProbe(to.GetTargetHost());
            }

            const TVector<THostConfig>& GetHostConfigs() const {
                return HostConfigs;
            }

            void AddToGlobalSchedule(EScheduleType type, TProbeSchedule& schedule) const {
                for (const auto& hostConfig : HostConfigs) {
                    for (const auto& target : hostConfig.Targets) {
                        auto& hostSchedule = schedule[hostConfig.Host];
                        switch (type) {
                            case EScheduleType::INTRA_DC:
                                hostSchedule.IntraDc.emplace_back(target);
                                break;
                            case EScheduleType::CROSS_DC:
                                hostSchedule.CrossDc.emplace_back(target);
                                break;
                            case EScheduleType::INTRA_DC_SPARSE:
                                hostSchedule.IntraDcSparse.emplace_back(target);
                                break;
                            case EScheduleType::IPV4_INTRA_DC:
                                hostSchedule.Ipv4IntraDc.emplace_back(target);
                                break;
                        }
                    }
                }
            }

        private:
            TVector<const THost*> ExtractHosts(const TSwitch& switch_, const THostFilter& filter) {
                TVector<const THost*> hosts;
                hosts.reserve(switch_.GetRealHosts().size());
                for (const auto host : switch_.GetRealHosts()) {
                    TTopology::THostRef hostRef(host);
                    if (filter(hostRef)) {
                        hosts.emplace_back(host);
                    }
                }
                return hosts;
            }

            TVector<const THost*> ExtractHosts(const TPod& pod,
                                               const THostFilter& filter,
                                               size_t minHostCountThreshold) {
                TVector<const THost*> hosts;
                for (const auto& switch_ : pod.GetSwitches()) {
                    for (const auto& host : switch_->GetRealHosts()) {
                        TTopology::THostRef hostRef(host);
                        // prefer RTC hosts for cross-dc probes
                        if (filter(hostRef) && hostRef->IsRtc()) {
                            hosts.emplace_back(host);
                        }
                    }
                }
                if (hosts.size() < minHostCountThreshold) {
                    WARNING_LOG << "Non-RTC hosts will be included in schedule "
                                << "for pod " << pod.GetName() << Endl;
                    for (const auto& switch_ : pod.GetSwitches()) {
                        for (const auto& host : switch_->GetRealHosts()) {
                            TTopology::THostRef hostRef(host);
                            if (filter(hostRef) && !hostRef->IsRtc()) {
                                hosts.emplace_back(host);
                            }
                        }
                    }
                }
                return hosts;
            }

            void AddOutgoingProbe(TTopology::THostRef target) {
                HostConfigs[SourceHostIdx].Targets.emplace_back(target);
                SourceHostIdx = (SourceHostIdx + 1) % HostConfigs.size();
            }

            TTopology::THostRef GetTargetHost() {
                auto host(HostConfigs[TargetHostIdx].Host);
                TargetHostIdx = (TargetHostIdx + 1) % HostConfigs.size();
                return host;
            }

            TVector<THostConfig> HostConfigs;
            std::size_t SourceHostIdx = 0; // next probe sender, index of host in HostConfigs
            std::size_t TargetHostIdx = 0; // next target for incoming probe in this switch, index of host in HostConfigs
        };

        template <class THostFilter>
        void GenerateIntraDcSchedule(const TSwitchFilter& switchFilter,
                                     const THostFilter& hostFilter,
                                     std::size_t probesBetweenTwoSwitches,
                                     TProbeSchedule& schedule) const {
            const auto& bbHostFilter = [&hostFilter](const TTopology::THostRef& host) {
                // TODO: check fastbone interface existence too if needed
                return hostFilter(host) && host->GetBackboneInterface();
            };

            TVector<TTopology::TSwitchRef> switches;
            TopologyStorage.GetTopology()->ForEachSwitch([&switchFilter, &switches](const TSwitch& switch_) {
                if (switchFilter(switch_)) {
                    switches.emplace_back(switch_);
                }
            });

            AddFullMeshIntraDcProbes(
                switches, bbHostFilter, probesBetweenTwoSwitches, EScheduleType::INTRA_DC, schedule
            );

            if (!TSettings::Get()->GetVlanInversionSelector().Empty()) {
                const auto sparseFilter = [&hostFilter](const TTopology::THostRef& host) {
                    return hostFilter(host) && host->IsRtc();
                };
                AddSparseIntraDcProbes(switches, sparseFilter, schedule);
            }

            if (TSettings::Get()->IsIpv4IntraDcScheduleEnabled()) {
                const auto ipv4HostFilter = [&hostFilter](const TTopology::THostRef& host) {
                    return hostFilter(host) && host->GetIpv4Interface();
                };
                AddFullMeshIntraDcProbes(
                    switches, ipv4HostFilter, probesBetweenTwoSwitches, EScheduleType::IPV4_INTRA_DC, schedule
                );
            }
        }

        void AddFullMeshIntraDcProbes(const TVector<TTopology::TSwitchRef>& switches,
                                      const THostFilter& hostFilter,
                                      std::size_t probesBetweenTwoSwitches,
                                      EScheduleType scheduleType,
                                      TProbeSchedule& schedule) const {
            const auto topology(TopologyStorage.GetTopology());

            THashMap<TTopology::TSwitchRef, TProbeDistributor> switchSchedules;
            THashSet<TSwitchPair> switchPairs;

            for (auto firstIt = begin(switches); firstIt != end(switches); ++firstIt) {
                const auto& first = **firstIt;
                for (auto secondIt = std::next(firstIt); secondIt != end(switches); ++secondIt) {
                    const auto& second = **secondIt;
                    TSwitchPair pair(first, second);
                    TProbeDistributor firstSchedule(first, hostFilter);
                    TProbeDistributor secondSchedule(second, hostFilter);
                    // check each switch has at least one alive host
                    if (!firstSchedule.GetHostConfigs().empty() && !secondSchedule.GetHostConfigs().empty()) {
                        switchPairs.emplace(std::move(pair));
                        switchSchedules.emplace(first, std::move(firstSchedule));
                        switchSchedules.emplace(second, std::move(secondSchedule));
                    }
                }
            }

            for (const auto& pair : switchPairs) {
                auto& first(switchSchedules[pair.First]);
                auto& second(switchSchedules[pair.Second]);

                std::size_t probes(probesBetweenTwoSwitches);

                // add at least one probe in each direction
                TProbeDistributor::AddProbe(first, second);
                TProbeDistributor::AddProbe(second, first);
                probes -= 2;

                if (probes) {
                    // distribute probes according to hosts count
                    std::size_t firstHosts(pair.First->GetRealHosts().size());
                    std::size_t secondHosts(pair.Second->GetRealHosts().size());
                    std::size_t totalHosts(firstHosts + secondHosts);
                    double firstWeight(static_cast<double>(firstHosts) / totalHosts);
                    std::size_t firstProbes(std::round(firstWeight * probes));
                    for (std::size_t i = 0; i < firstProbes; ++i) {
                        TProbeDistributor::AddProbe(first, second);
                    }
                    for (std::size_t i = firstProbes; i < probes; ++i) {
                        TProbeDistributor::AddProbe(second, first);
                    }
                }
            }

            for (const auto& switchSchedule : switchSchedules) {
                switchSchedule.second.AddToGlobalSchedule(scheduleType, schedule);
            }
        }

        void AddSparseIntraDcProbes(const TVector<TTopology::TSwitchRef>& switches,
                                    const THostFilter& hostFilter,
                                    TProbeSchedule& schedule) const {
            const auto topology(TopologyStorage.GetTopology());
            TVector<TProbeDistributor> switchSchedules;

            for (const auto& switchRef : switches) {
                switchSchedules.emplace_back(*switchRef, hostFilter);
                if (switchSchedules.back().GetHostConfigs().empty()) {
                    switchSchedules.pop_back();
                }
            };

            if (switchSchedules.size() <= 1) {
                return;
            }
            size_t targetsPerSwitch = Min(switchSchedules.size() - 1,
                                          SPARSE_SCHEDULE_TARGETS_PER_SWITCH);
            for (size_t sourceIdx : xrange(switchSchedules.size())) {
                for (size_t iter = 0; iter < targetsPerSwitch; ++iter) {
                    size_t targetIdx = sourceIdx;
                    while (targetIdx == sourceIdx) {
                        targetIdx = RandomNumber(switchSchedules.size());
                    }
                    TProbeDistributor::AddProbe(switchSchedules[sourceIdx],
                                                switchSchedules[targetIdx]);
                }
            }

            for (const auto& switchSchedule : switchSchedules) {
                switchSchedule.AddToGlobalSchedule(EScheduleType::INTRA_DC_SPARSE, schedule);
            }
        }

        // Groups the `hosts` that belong to a pod and pass hostFilter into one distributor
        // per pod.
        //
        // RTC hosts are preferred. A pod's non-RTC hosts are added only when it has fewer
        // than MinHostCountPerPodThreshold RTC hosts, and such pods are listed in a warning.
        // Pods whose passing hosts are all non-RTC are kept, not dropped.
        THashMap<TTopology::TPodRef, TProbeDistributor> CreatePodSchedules(const TTopologyStorage::THostSet& hosts,
                                                                           const THostFilter& hostFilter) const {
            // A single pass splits passing hosts into RTC / non-RTC per pod. The non-RTC
            // fallback therefore needs neither a second scan of `hosts` nor a second
            // std::function filter call per host.
            THashMap<TTopology::TPodRef, TVector<const THost*>> hostsByPod;
            THashMap<TTopology::TPodRef, TVector<const THost*>> nonRtcHostsByPod;
            for (const auto& host : hosts) {
                if (!host->GetPod() || !hostFilter(host)) {
                    continue;
                }
                const auto& pod = *host->GetPod();

                // Touch the bucket unconditionally: don't lose pods that have no rtc hosts at all.
                auto& rtcHosts = hostsByPod[pod];
                // prefer RTC hosts for cross-dc probes
                if (host->IsRtc()) {
                    rtcHosts.emplace_back(&*host);
                } else {
                    nonRtcHostsByPod[pod].emplace_back(&*host);
                }
            }

            bool hasDeficientPods = false;
            TString podNames;
            for (auto& [podRef, podHosts] : hostsByPod) {
                if (podHosts.size() >= MinHostCountPerPodThreshold) {
                    continue;
                }
                hasDeficientPods = true;
                podNames += (podNames.empty() ? "" : ",") + podRef->GetName();
                if (const auto it = nonRtcHostsByPod.find(podRef); it != nonRtcHostsByPod.end()) {
                    podHosts.insert(podHosts.end(), it->second.begin(), it->second.end());
                }
            }
            if (hasDeficientPods) {
                WARNING_LOG << "Non-RTC hosts will be included in schedule "
                            << "for the following target pods: " << podNames << Endl;
            }

            THashMap<TTopology::TPodRef, TProbeDistributor> podSchedules;
            podSchedules.reserve(hostsByPod.size());
            for (auto& [podRef, podHosts] : hostsByPod) {
                podSchedules.emplace(podRef, std::move(podHosts));
            }
            return podSchedules;
        }

        // Cross-DC schedule for sourceDc. Every pod of sourceDc sends
        // CrossDcProbesBetweenTwoPods probes to every target pod located in a different
        // datacenter.
        //
        // Source hosts are the pod's hosts passing sourceHostFilter, with RTC hosts
        // preferred. Target pods are built from targetHosts filtered by targetHostFilter.
        // Source pods without usable hosts are logged and skipped. Target pods without
        // hosts are skipped, because drawing a target from them would be invalid.
        void AddCrossDcProbes(const TTopology::TDatacenterRef& sourceDc,
                              const THostFilter& sourceHostFilter,
                              const THostFilter& targetHostFilter,
                              const TTopologyStorage::THostSet& targetHosts,
                              TProbeSchedule& schedule) const {
            auto targetPodSchedules = CreatePodSchedules(targetHosts, targetHostFilter);
            for (const auto& pod : sourceDc->GetPods()) {
                TProbeDistributor sourcePodSchedule(*pod, sourceHostFilter, MinHostCountPerPodThreshold);
                if (sourcePodSchedule.GetHostConfigs().empty()) {
                    ERROR_LOG << "Pod " << pod->GetName() << " has no available hosts" << Endl;
                    continue;
                }
                for (auto& [targetPod, targetPodSchedule] : targetPodSchedules) {
                    // only probe pods of other datacenters that have at least one host
                    if (targetPod->GetDatacenter() == sourceDc || targetPodSchedule.GetHostConfigs().empty()) {
                        continue;
                    }
                    for (std::size_t i = 0; i < CrossDcProbesBetweenTwoPods; ++i) {
                        TProbeDistributor::AddProbe(sourcePodSchedule, targetPodSchedule);
                    }
                }
                sourcePodSchedule.AddToGlobalSchedule(EScheduleType::CROSS_DC, schedule);
            }
        }

        // Logs a single warning that summarizes the hosts with more than
        // SCHEDULED_TARGETS_WARN_THRESHOLD targets, together with their switches.
        //
        // Every target list that AddToGlobalSchedule can fill is counted. Previously only
        // IntraDc and CrossDc were counted, so load from the sparse and IPv4 schedules was
        // invisible to this check.
        void CheckSchedule(const TProbeSchedule& schedule) const {
            size_t overloadedHosts = 0;
            THashSet<TStringBuf> overloadedSwitches;
            for (const auto& [host, hostSchedule] : schedule) {
                const size_t scheduleSize = hostSchedule.IntraDc.size()
                                          + hostSchedule.CrossDc.size()
                                          + hostSchedule.IntraDcSparse.size()
                                          + hostSchedule.Ipv4IntraDc.size();
                if (scheduleSize > SCHEDULED_TARGETS_WARN_THRESHOLD) {
                    ++overloadedHosts;
                    overloadedSwitches.emplace(host->GetSwitch().GetName());
                }
            }
            if (overloadedHosts) {
                WARNING_LOG << overloadedHosts << " hosts from " << overloadedSwitches.size()
                            << " switches have more than " << SCHEDULED_TARGETS_WARN_THRESHOLD
                            << " targets in schedule. Affected switches: "
                            << JoinSeq(TStringBuf(","), overloadedSwitches) << Endl;
            }
        }

        // Held by reference, not owned; must outlive this object.
        const TTopologyStorage& TopologyStorage;
        // Switch predicate for the intra-DC schedules (in per-DC mode it is additionally
        // restricted to the DC being scheduled).
        TSwitchFilter SwitchFilter;
        // Validated in the constructor (every value >= 2).
        // NOTE(review): Schedule() does not read this map; it takes probe counts from
        // TSettings instead. Confirm whether that is intended.
        THashMap<TString, std::size_t> ProbesBetweenTwoSwitches;
        // Number of probes each source pod sends to each target pod in another datacenter.
        std::size_t CrossDcProbesBetweenTwoPods;
        // 5 * CrossDcProbesBetweenTwoPods. Pods with fewer passing RTC hosts than this also
        // get their non-RTC hosts scheduled.
        std::size_t MinHostCountPerPodThreshold;
    };

    // Creates the scheduler. `filter` and `probesBetweenTwoSwitches` are moved into the
    // implementation, not copied. topologyStorage is kept by reference and must outlive
    // the scheduler. Every probesBetweenTwoSwitches value must be >= 2.
    TUniformSwitchProbeScheduler::TUniformSwitchProbeScheduler(const TTopologyStorage& topologyStorage,
                                                               TSwitchFilter filter,
                                                               THashMap<TString, std::size_t> probesBetweenTwoSwitches,
                                                               std::size_t crossDcProbesBetweenTwoPods)
        : Impl(MakeHolder<TImpl>(topologyStorage,
                                 std::move(filter),
                                 std::move(probesBetweenTwoSwitches),
                                 crossDcProbesBetweenTwoPods))
    {
    }

    // Defined out of line, after TImpl's definition above. Presumably required because Impl
    // (created with MakeHolder<TImpl>) needs the complete TImpl type to be destroyed.
    TUniformSwitchProbeScheduler::~TUniformSwitchProbeScheduler() = default;

    // Builds a probe schedule; all of the work is done by TImpl::Schedule.
    IProbeScheduler::TProbeSchedule TUniformSwitchProbeScheduler::Schedule(const TTopologyStorage::THostSet& interestedHosts,
                                                                           const TTopologyStorage::THostSet& deadHosts,
                                                                           const TTopologyStorage::THostSet& crossDcTargetHosts,
                                                                           const TTopologyStorage::THostSet& mutedHosts) const
    {
        return Impl->Schedule(interestedHosts, deadHosts, crossDcTargetHosts, mutedHosts);
    }
}

// Stream output for a switch pair: writes both switch names joined by " <-> ".
template <>
void Out<NNetmon::TUniformSwitchProbeScheduler::TSwitchPair>(IOutputStream& stream, TTypeTraits<NNetmon::TUniformSwitchProbeScheduler::TSwitchPair>::TFuncParam pair) {
    stream << pair.First->GetName() << " <-> " << pair.Second->GetName();
}
