#include "net_devices_monitor.h"

#include <yandex_io/libs/net/netlink_monitor.h>

#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>

#include <sys/socket.h>

using namespace glagol;

namespace {
    using NetlinkMonitor = quasar::net::NetlinkMonitor;

    /**
     * NetDevicesMonitor backed by quasar::net::NetlinkMonitor.
     *
     * A dedicated thread runs NetlinkMonitor::monitor(true). The nested NetlinkHandler
     * receives link / address / MAC events (presumably on that thread; confirm against
     * NetlinkMonitor), keeps a table of NetDevice records keyed by interface index and
     * reports changes through the user-supplied callback.
     */
    class NetDevicesMonitorImpl final: public NetDevicesMonitor {
        /// Maps a socket address family to Family. Callers must already have filtered the
        /// value to AF_INET / AF_INET6: anything other than AF_INET is reported as IPV6.
        static Family fromSysFamily(int family) {
            return family == AF_INET ? Family::IPV4 : Family::IPV6;
        }

        /**
         * NetlinkMonitor::Handler that maintains the device table.
         *
         * Every access to devices_ is guarded by mutex_. Both onDeviceUpdated_ and the
         * eachRunningDevice() callback are invoked while mutex_ is held. std::mutex is not
         * recursive, so those callbacks must not call back into this monitor (for example
         * eachRunningDevice()); re-locking is undefined behaviour, typically a deadlock.
         */
        struct NetlinkHandler: public NetlinkMonitor::Handler {
            using IfFlags = NetlinkMonitor::IfFlags;
            using Scope = NetlinkMonitor::Scope;
            mutable std::mutex mutex_;
            OnDeviceCallback onDeviceUpdated_;
            std::unordered_map<IfIndex, NetDevice> devices_;

            explicit NetlinkHandler(OnDeviceCallback cb)
                : onDeviceUpdated_(std::move(cb))
            {
            }

            /// Returns the element of src whose .addr equals addr, or src.end() if none.
            template <typename Container>
            static typename Container::iterator findAddress(Container& src, const std::string& addr) {
                return std::find_if(src.begin(), src.end(), [&addr](const auto& a) {
                    return a.addr == addr;
                });
            }

            /// Records a reported address. Only non-local, universal-scope IPv4/IPv6 addresses
            /// are tracked, and the callback fires only when the address was not known yet.
            /// The device entry is created on demand, so it may exist before onLink() for idx.
            void onAddress(IfIndex idx, std::string addr, Family family, Scope scope, bool local) override {
                if (local || scope != Scope::UNIVERSAL || !(family == AF_INET || family == AF_INET6)) {
                    return;
                }
                std::scoped_lock lock(mutex_);
                auto& dev = devices_[idx];

                if (findAddress(dev.addresses, addr) == dev.addresses.end()) {
                    dev.addresses.emplace_back(Address{.family = fromSysFamily(family), .addr = std::move(addr)});
                    onDeviceUpdated_(dev);
                }
            }

            /// Forgets a removed address and notifies if it was being tracked. Scope is not
            /// checked: only universal-scope addresses are ever stored, so other scopes fall
            /// through as a no-op lookup.
            void onAddressRemove(IfIndex idx, std::string addr, Family family, Scope /*scope*/, bool local) override {
                if (local || !(family == AF_INET || family == AF_INET6)) {
                    return;
                }
                std::scoped_lock lock(mutex_);

                auto devIter = devices_.find(idx);
                if (devIter == devices_.end()) {
                    return;
                }

                auto& dev = devIter->second;

                auto iter = findAddress(dev.addresses, addr);
                if (iter != dev.addresses.end()) {
                    dev.addresses.erase(iter);
                    onDeviceUpdated_(dev);
                }
            }

            /// Creates or refreshes the device for idx and notifies when the entry is new or
            /// its name, loopback or running state changed.
            void onLink(IfIndex idx, std::string name, IfFlags flags) override {
                std::scoped_lock lock(mutex_);
                // The entry may already exist, created by onMac()/onAddress() before the link
                // was reported, so the fields must be filled in on every call, not only on insert.
                auto [iter, inserted] = devices_.try_emplace(idx);
                auto& dev = iter->second;
                bool changed = inserted;
                if (dev.interfaceName != name) {
                    dev.interfaceName = std::move(name);
                    changed = true;
                }
                if (dev.loopback != flags.loopback) {
                    dev.loopback = flags.loopback;
                    changed = true;
                }
                if (dev.running != flags.running) {
                    dev.running = flags.running;
                    changed = true;
                }
                if (changed) {
                    onDeviceUpdated_(dev);
                }
            }

            /// Drops the device for idx. If it was running or still had addresses, it is first
            /// reported one last time as stopped with no addresses.
            void onLinkRemoved(IfIndex idx) override {
                std::scoped_lock lock(mutex_);
                auto iter = devices_.find(idx);
                if (iter == devices_.end()) {
                    return;
                }
                auto& dev = iter->second;
                if (dev.running || !dev.addresses.empty()) {
                    dev.running = false;
                    dev.addresses.clear();
                    onDeviceUpdated_(dev);
                }
                devices_.erase(iter);
            }

            /// Stores the MAC for idx (creating the entry if needed). Does not notify.
            void onMac(IfIndex idx, std::string mac) override {
                std::scoped_lock lock(mutex_);
                devices_[idx].MAC = std::move(mac);
            }

            /// Invokes cb for every running, non-loopback device, under mutex_.
            void eachRunningDevice(OnDeviceCallback& cb) const {
                std::scoped_lock lock(mutex_);
                for (const auto& [idx, dev] : devices_) {
                    if (dev.running && !dev.loopback) {
                        cb(dev);
                    }
                }
            }

            /// Interface statistics are not tracked.
            void onStats(IfIndex /*idx*/, const Stats& /*stats*/) override {
            }
        };

        // Declaration order matters: members are initialized top to bottom, so handler_
        // exists before netlinkMonitor_ is built from it, and both exist before
        // netlinkThread_ starts using them.
        NetlinkHandler handler_;
        std::unique_ptr<NetlinkMonitor> netlinkMonitor_;
        std::thread netlinkThread_;

    public:
        /// Starts monitoring immediately on a dedicated thread; cb receives device updates.
        explicit NetDevicesMonitorImpl(OnDeviceCallback cb)
            : handler_(std::move(cb))
            , netlinkMonitor_(quasar::net::makeNetlinkMonitor(handler_))
            , netlinkThread_([this] {
                netlinkMonitor_->monitor(true);
            })
        {
        }

        /// Stops monitoring and joins the worker thread. Later calls are no-ops.
        /// Not synchronized: do not call concurrently from several threads.
        void stop() override {
            if (!netlinkMonitor_) {
                return;
            }
            netlinkMonitor_->stop();
            if (netlinkThread_.joinable()) {
                netlinkThread_.join();
            }
            // Destroy the monitor only after the thread that uses it has exited.
            netlinkMonitor_.reset();
        }

        ~NetDevicesMonitorImpl() {
            stop();
        }

        /// Invokes cb for every currently running, non-loopback device.
        void eachRunningDevice(OnDeviceCallback cb) const override {
            handler_.eachRunningDevice(cb);
        }
    };
} // namespace

namespace glagol {
    /// Factory for the default, netlink-backed NetDevicesMonitor.
    ///
    /// The returned monitor starts watching immediately and reports device changes through
    /// onDeviceUpdated. Calling stop() on it, or destroying it, stops the watcher thread.
    NetDevicesMonitorPtr makeNetDevicesMonitor(NetDevicesMonitor::OnDeviceCallback onDeviceUpdated) {
        NetDevicesMonitorPtr monitor = std::make_unique<NetDevicesMonitorImpl>(std::move(onDeviceUpdated));
        return monitor;
    }
} // namespace glagol
