#include "beacon.h"

#include "net_utils.h"

#include <yandex_io/libs/logging/logging.h>

#include <util/generic/scope.h>
#include <util/system/byteorder.h>

#include <random>

#include <arpa/inet.h>
#include <sys/eventfd.h>
#include <sys/ioctl.h>
#include <sys/poll.h>
#include <sys/socket.h>

YIO_DEFINE_LOG_MODULE("net_lib");

namespace quasar {

    /// Creates an idle beacon; nothing is opened until start() is called.
    /// @param port            UDP port the beacon binds to in start()
    /// @param responseMessage payload sent back to every requester by the worker
    Beacon::Beacon(int port, std::vector<char> responseMessage)
        : port_(port),
          responseMessage_(std::move(responseMessage))
    {
    }

    /// Stops the worker thread (if running) and closes the beacon sockets.
    /// NOTE(review): stop() throws std::runtime_error if the shutdown eventfd cannot be
    /// written; an exception escaping this destructor terminates the program.
    Beacon::~Beacon()
    {
        static_cast<void>(stop());
    }

    bool Beacon::start()
    {
        std::unique_lock<std::mutex> lock1(workerMutex_);
        if (workerThread_.joinable()) {
            throw std::runtime_error("Beacon already started");
        }

        int shutdownFd = ::eventfd(0, 0);
        if (shutdownFd == -1) {
            throw std::runtime_error("Fail to start beacon");
        }
        Y_DEFER {
            ::close(shutdownFd);
            shutdownFd = -1;
        };

        int listenFd = socket(AF_INET, SOCK_DGRAM, 0);
        if (listenFd < 0) {
            throw std::runtime_error("Fail to open server socket for beacon (error=" + std::to_string(errno) + ")");
        }

        Y_DEFER {
            ::close(listenFd);
            listenFd = -1;
        };

        struct sockaddr_in listenAddr;
        std::memset(&listenAddr, 0, sizeof(listenAddr));
        listenAddr.sin_family = AF_INET;
        listenAddr.sin_port = HostToInet<uint16_t>(port_);
        if (!inet_aton("0.0.0.0", &listenAddr.sin_addr)) {
            throw std::runtime_error("Invalid interface address");
        }

        int broadcast = 1;
        int reuseaddr = 1;

        if (listenAddr.sin_addr.s_addr == INADDR_ANY) {
            if (setsockopt(listenFd, SOL_SOCKET, SO_BROADCAST, &broadcast, sizeof(broadcast)) < 0) {
                throw std::runtime_error("Fail to set SO_BROADCAST option");
            }
        }

        const int nonblocking = 1;
        if (ioctl(listenFd, FIONBIO, &nonblocking)) {
            throw std::runtime_error("Fail to set non-blocking server socket");
        }

        if (bind(listenFd, (struct sockaddr*)&listenAddr, sizeof(listenAddr))) {
            throw std::runtime_error("Fail to bind server socket");
        }

        std::vector<int> broadcastFds;
        Y_DEFER {
            for (int fd : broadcastFds) {
                ::close(fd);
            }
            broadcastFds.clear();
        };

        if (listenAddr.sin_addr.s_addr != INADDR_ANY) {
            auto bcasts = broadcastAddresses(toIpAddress(listenAddr));
            broadcastFds.reserve(bcasts.size());
            for (const auto& bcast : bcasts) {
                const auto& bcastIp = bcast.broadcast;
                int bcastFd = socket(AF_INET, SOCK_DGRAM, 0);
                if (bcastFd < 0) {
                    int e = errno;
                    YIO_LOG_DEBUG("Fail to create broadcast socket at interface " << bcast.name << ": " << std::strerror(e));
                    continue;
                }
                struct sockaddr_in bcastAddr {};
                bcastAddr.sin_family = AF_INET;
                bcastAddr.sin_port = HostToInet<uint16_t>(port_);
                if (!inet_aton(bcastIp.c_str(), &bcastAddr.sin_addr) ||
                    setsockopt(bcastFd, SOL_SOCKET, SO_BROADCAST, &broadcast, sizeof(broadcast)) < 0 ||
                    setsockopt(bcastFd, SOL_SOCKET, SO_REUSEADDR, &reuseaddr, sizeof(reuseaddr)) < 0 ||
                    ioctl(bcastFd, FIONBIO, &nonblocking) ||
                    bind(bcastFd, (struct sockaddr*)&bcastAddr, sizeof(listenAddr)))
                {
                    int e = errno;
                    YIO_LOG_DEBUG("Fail to setup broadcast socket at interface " << bcast.name << ": " << std::strerror(e));
                    ::close(bcastFd);
                    continue;
                }
                broadcastFds.push_back(bcastFd);
            }
        }

        workerThread_ = std::thread(&Beacon::worker, this);
        shutdownFd_ = shutdownFd;
        shutdownFd = -1;
        listenFd_ = listenFd;
        listenFd = -1;
        broadcastFds_ = std::move(broadcastFds);
        broadcastFds.clear();

        return true;
    }

    bool Beacon::stop()
    {
        std::unique_lock<std::mutex> lock(workerMutex_);
        if (!workerThread_.joinable()) {
            return false;
        }
        int res = eventfd_write(shutdownFd_, 1);
        if (res == -1) {
            throw std::runtime_error("Fail to shutdown beacon");
        }
        workerThread_.join();
        ::close(shutdownFd_);
        ::close(listenFd_);
        for (int bcastFd : broadcastFds_) {
            ::close(bcastFd);
        }
        shutdownFd_ = -1;
        listenFd_ = -1;
        broadcastFds_.clear();

        return true;
    }

    void Beacon::worker()
    {
        try {
            bool fShutdown = false;
            std::vector<pollfd> rset;
            int shutdownFd = -1;
            int listenFd = -1;
            {
                std::lock_guard<std::mutex> lock(workerMutex_);
                rset.reserve(2 + broadcastFds_.size());
                rset.emplace_back(pollfd{shutdownFd_, POLLIN | POLLPRI | POLLERR, 0});
                rset.emplace_back(pollfd{listenFd_, POLLIN | POLLPRI | POLLERR, 0});
                for (int bcastFd : broadcastFds_) {
                    rset.emplace_back(pollfd{bcastFd, POLLIN | POLLPRI | POLLERR, 0});
                }
                shutdownFd = shutdownFd_;
                listenFd = listenFd_;
            }
            auto t0 = std::chrono::steady_clock::now();
            int counter = 0;
            do {
                int nready = ::poll(rset.data(), rset.size(), -1);
                if (nready < 0) {
                    break;
                }
                for (const auto& r : rset) {
                    if (r.revents & (POLLIN | POLLPRI)) {
                        if (r.fd == shutdownFd) {
                            fShutdown = true;
                            break;
                        } else {
                            socklen_t clientAddressLen = sizeof(struct sockaddr_in6);
                            uint8_t clientAddressRaw[clientAddressLen];
                            struct sockaddr* clientAddress = reinterpret_cast<struct sockaddr*>(clientAddressRaw);
                            uint8_t buffer[1024];
                            int len = ::recvfrom(r.fd, buffer, sizeof(buffer), 0, clientAddress, &clientAddressLen);
                            if (len < 0) {
                                if (counter == 0 || std::chrono::steady_clock::now() - t0 > std::chrono::seconds(1)) {
                                    int e = errno;
                                    YIO_LOG_DEBUG("Unexpected error on " << (r.fd != listenFd ? "broadcast" : "data") << " socket (errno=" << std::to_string(e) << ") counter=" << counter);
                                }
                                ++counter;
                                t0 = std::chrono::steady_clock::now();
                            } else if (len > 0) {
                                YIO_LOG_DEBUG("Receive beacon impulse from " << toIpAddress(*clientAddress) << ":" << ((struct sockaddr_in*)clientAddress)->sin_port);
                                sendto(listenFd, responseMessage_.data(), responseMessage_.size(), MSG_DONTWAIT, clientAddress, clientAddressLen);
                            }
                        }
                    }
                }
            } while (!fShutdown);
        } catch (const std::exception& ex) {
            YIO_LOG_ERROR_EVENT("Beacon.WorkingThreadException", "Beacon worker thread catch unexpected exception: " << ex.what());
        }
    }

    /**
     * Probes hosts of the interface's subnet with a "BEACONHELLO\n" UDP datagram and
     * collects the replies.
     *
     * For subnets larger than maxHostCount only a window of indices around the interface's
     * own address is probed. Each host gets its own non-blocking UDP socket, and hosts are
     * probed in shuffled order. Replies are awaited until `timeout` elapses. Every non-empty
     * reply (at most 1024 bytes are read) is passed to responseHandler together with the
     * sender address as produced by toIpAddress().
     *
     * @param iface           interface whose ip/netmask define the probed range
     * @param port            destination UDP port of the beacons
     * @param timeout         total time to wait for replies
     * @param responseHandler callback for each reply; an empty function makes this return false
     * @return true if at least one reply was delivered to responseHandler
     */
    bool Beacon::impulse(
        const InterfaceConfig& iface,
        int port,
        std::chrono::milliseconds timeout,
        const std::function<void(const std::string&, const std::vector<char>&)>& responseHandler)
    {
        const char* const requestMessage = "BEACONHELLO\n";
        const size_t requestMessageLen = strlen(requestMessage);

        if (!responseHandler) {
            return false;
        }

        IpRange ipRange(iface.ip, iface.netmask);
        const size_t hostCapacity = ipRange.hostCapacity();
        if (hostCapacity < 2) {
            return false;
        }

        const auto myIpIndex = ipRange.indexOf(iface.ip);
        if (!myIpIndex) {
            return false;
        }

        // Large subnets: limit probing to a window of indices around our own address.
        constexpr size_t maxHostCount = 254;
        const size_t rightIndex = (hostCapacity > maxHostCount ? std::min(*myIpIndex + maxHostCount / 2, hostCapacity) : hostCapacity);
        const size_t leftIndex = (hostCapacity > maxHostCount ? *myIpIndex - std::min(*myIpIndex, maxHostCount / 2) : 0);

        // Every index in [leftIndex, rightIndex) except our own, in shuffled order.
        std::vector<size_t> ipIndexPool;
        ipIndexPool.reserve(rightIndex > leftIndex ? rightIndex - leftIndex : 0);
        for (size_t index = leftIndex; index < rightIndex; ++index) {
            if (index != *myIpIndex) {
                ipIndexPool.push_back(index);
            }
        }
        std::shuffle(
            ipIndexPool.begin(),
            ipIndexPool.end(),
            std::default_random_engine(std::chrono::steady_clock::now().time_since_epoch().count()));

        std::vector<pollfd> discoveryFds;
        Y_DEFER {
            for (const auto& pfd : discoveryFds) {
                if (pfd.fd != -1) {
                    ::close(pfd.fd);
                }
            }
        };
        discoveryFds.reserve(ipIndexPool.size());

        const int nonblocking = 1;
        for (const size_t index : ipIndexPool) {
            struct sockaddr_in addr = ipRange.sockaddr(index);
            addr.sin_port = HostToInet<uint16_t>(port);

            const int discoveryFd = ::socket(AF_INET, SOCK_DGRAM, 0);
            if (discoveryFd < 0) {
                YIO_LOG_DEBUG("Fail to create request socket: " << ipRange.ip(index));
                continue;
            }
            // Owned by the scope guard from here on, so an exception below cannot leak it.
            discoveryFds.push_back(pollfd{discoveryFd, POLLIN | POLLPRI | POLLERR, 0});

            if (ioctl(discoveryFd, FIONBIO, &nonblocking)) {
                ::close(discoveryFd);
                discoveryFds.pop_back();
                YIO_LOG_DEBUG("Fail to set socket non-blocking mode: " << ipRange.ip(index));
                continue;
            }

            if (::sendto(discoveryFd, requestMessage, requestMessageLen, 0, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)) < 0) {
                ::close(discoveryFd);
                discoveryFds.pop_back();
                YIO_LOG_DEBUG("Fail to send beacon request: " << ipRange.ip(index));
                continue;
            }
        }

        if (discoveryFds.empty()) {
            YIO_LOG_DEBUG("Fail to send any message");
            return false;
        }

        constexpr size_t maxMessageSize = 1024;
        constexpr std::chrono::milliseconds timeoutMinQuantum{10};
        constexpr std::chrono::milliseconds timeoutMaxQuantum{100};
        // Wait in short slices (timeout/10, clamped to [10ms, 100ms]) until the deadline.
        const std::chrono::milliseconds timeoutQuantum = std::max(timeoutMinQuantum, std::min(timeoutMaxQuantum, timeout / 10));
        const auto deadline = std::chrono::steady_clock::now() + timeout;

        bool result = false;
        std::vector<char> responseMessage;
        do {
            const auto timeoutRest = std::chrono::duration_cast<std::chrono::milliseconds>(deadline - std::chrono::steady_clock::now());
            const int waitMs = static_cast<int>(std::min(timeoutQuantum, std::max(std::chrono::milliseconds{0}, timeoutRest)).count());
            const int nready = ::poll(discoveryFds.data(), discoveryFds.size(), waitMs);
            if (nready <= 0) {
                // Nothing ready (or poll failed, e.g. EINTR): re-check the deadline and retry.
                continue;
            }
            for (const auto& rfd : discoveryFds) {
                if (!(rfd.revents & (POLLIN | POLLPRI))) {
                    continue;
                }
                responseMessage.resize(maxMessageSize);
                struct sockaddr_in addr {};
                socklen_t len = sizeof(addr);
                const ssize_t bytes = ::recvfrom(rfd.fd, responseMessage.data(), responseMessage.size(), 0, reinterpret_cast<struct sockaddr*>(&addr), &len);
                if (bytes > 0) {
                    const auto from = toIpAddress(addr);
                    YIO_LOG_DEBUG("Receive beacon answer from " << from);
                    responseMessage.resize(static_cast<size_t>(bytes));
                    responseHandler(from, responseMessage);
                    result = true;
                }
            }
        } while (std::chrono::steady_clock::now() < deadline);

        return result;
    }

} // namespace quasar
