#include "pinger.h"

#include "icmp_header.h"
#include "ipv4_header.h"

#include <yandex_io/libs/errno/errno_exception.h>
#include <yandex_io/libs/logging/logging.h>
#include <yandex_io/libs/threading/i_callback_queue.h>

#include <util/generic/scope.h>
#include <util/system/yassert.h>

#include <sstream>
#include <stdio.h>
#include <string.h>

#include <netdb.h>
#include <poll.h>
#include <strings.h>
#include <unistd.h>
#include <netinet/in.h>

namespace quasar {

    /**
     * Opens the ICMP socket of the requested type and, on success, starts the
     * thread that waits for incoming ICMP packets (waitICMPPackets()).
     *
     * If the socket cannot be created the pinger stays inert: no listener thread
     * is started, and later sendto() calls in sendPing() fail on the invalid
     * descriptor and are only logged.
     */
    Pinger::Pinger(
        const Duration& interval,
        const Duration& timeout,
        Listener* listener,
        ICallbackQueue* worker,
        SocketType socketType)
        : pingId_(0)
        , interval_(interval)
        , timeout_(timeout)
        , listener_(listener)
        , sequenceNumber_(0)
        , worker_(worker)
        , pid_(getpid())
        , stopped_(false)
        , socketType_(socketType)
    {
        if (socketType_ == SocketType::DGRAM) {
            socket_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP);
        } else if (socketType_ == SocketType::RAW) {
            socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
        }

        if (socket_ != -1) {
            // The thread captures `this`; the destructor joins it before members go away.
            socketListener_ = std::thread([this]() { waitICMPPackets(); });
        } else {
            YIO_LOG_DEBUG("Can't create ICMP socket, errno=" << errno);
        }
    }

    std::shared_ptr<Pinger> Pinger::create(
        const Duration& interval,
        const Duration& timeout,
        Listener* listener,
        ICallbackQueue* worker,
        SocketType socketType)
    {
        return std::make_shared<Pinger>(interval, timeout, listener, worker, socketType);
    }

    /**
     * Stops the polling thread, waits for it, and releases the ICMP socket.
     */
    Pinger::~Pinger() {
        // Ask the polling loop to exit and wake it up from its blocking poll().
        stopped_ = true;
        wakeupFd_.signal();

        if (socketListener_.joinable()) {
            socketListener_.join();
        }

        // The socket opened in the constructor is owned by this object and is not
        // closed anywhere else. Close it only after its reader thread has been joined.
        if (socket_ != -1) {
            ::close(socket_);
        }
    }

    std::weak_ptr<Pinger> Pinger::weakFromThis() {
        return shared_from_this();
    }

    uint64_t Pinger::startPing(const std::string& host, const Duration& interval, const Duration& timeout) {
        const uint64_t id = ++pingId_;
        Duration thisInterval = interval == Duration(0) ? interval_ : interval;
        Duration thisTimeout = timeout == Duration(0) ? timeout_ : timeout;

        const auto weakRef = weakFromThis();

        worker_->add([this, weakRef, id, host, thisInterval, thisTimeout]() {
            if (const auto ref = weakRef.lock()) {
                pings_.insert(std::make_pair(id, Ping(host, thisInterval, thisTimeout)));
                sendPing(id);
            }
        });

        return id;
    }

    void Pinger::stopPing(uint64_t pingId) {
        const auto weakRef = weakFromThis();

        worker_->add([this, weakRef, pingId]() {
            if (const auto ref = weakRef.lock()) {
                pings_.erase(pingId);
            }
        });
    }

    /**
     * Sends one ICMP echo request for the ping `pingId` and schedules the next round.
     *
     * Called only from tasks running on worker_. If the ping no longer exists
     * (stopped), nothing is sent and no further round is scheduled. On resolution
     * or send failure the error is logged and the next round is still scheduled.
     * On success a PACKET_SENT event is emitted, the request is recorded in
     * requests_, and a finalize (loss check) is scheduled after ping.timeout.
     */
    void Pinger::sendPing(uint64_t pingId) {
        const auto it = pings_.find(pingId);

        if (it == pings_.end()) {
            return;
        }

        const Ping& ping = it->second;

        // The host is resolved on every round; nothing is cached here.
        struct addrinfo hints;
        memset(&hints, 0, sizeof(addrinfo));
        hints.ai_family = AF_INET;
        hints.ai_socktype = SOCK_RAW;
        hints.ai_protocol = IPPROTO_ICMP;

        struct addrinfo* addrs = nullptr;
        Y_DEFER {
            if (addrs) {
                freeaddrinfo(addrs);
            }
        };

        const int res = getaddrinfo(ping.host.c_str(), nullptr, &hints, &addrs);
        if (res != 0) {
            YIO_LOG_DEBUG("getaddrinfo(" << ping.host << ") failed, returncode=" << res);
            scheduleNextPing(pingId, ping.interval);
            return;
        }

        if (addrs == nullptr) {
            YIO_LOG_DEBUG("no ICMPv4 addresses found for " << ping.host);
            scheduleNextPing(pingId, ping.interval);
            return;
        }

        icmp_header echoRequest;
        std::string body;
        echoRequest.type(icmp_header::echo_request);
        echoRequest.code(0);
        if (socketType_ != SocketType::DGRAM) {
            /**
             * In case of SOCK_DGRAM kernel sets id to the local port of the socket: https://lwn.net/Articles/420800/
             * Echo replies are then multiplexed based on these id values.
             * Otherwise, we need to match echo with echo replies ourselves, so we set it to unique value of pid
             */
            echoRequest.identifier(pid_);
        }
        echoRequest.sequence_number(++sequenceNumber_);
        compute_checksum(echoRequest, body.begin(), body.end());

        const auto sequenceNumber = echoRequest.sequence_number();

        // Sequence numbers are uint16_t and shared by all pings, so they wrap around.
        // If the previous request with this number has not been finalized yet, the
        // insert below would silently fail while a second finalize is still
        // scheduled for the same key. That second finalize would then trip
        // Y_VERIFY in finalizePing(). Skip this round instead.
        if (requests_.find(sequenceNumber) != requests_.end()) {
            YIO_LOG_DEBUG("ICMP sequence number " << sequenceNumber << " is still in flight, skipping ping to " << ping.host);
            scheduleNextPing(pingId, ping.interval);
            return;
        }

        std::stringstream s;
        s << echoRequest;
        // Serialize once and use the same buffer for both data pointer and size.
        const std::string packet = s.str();

        const auto sendTime = Clock::now();

        const auto n = sendto(socket_, packet.data(), packet.size(), 0, addrs->ai_addr, addrs->ai_addrlen);
        if (n == -1) {
            YIO_LOG_DEBUG("ICMP sendto(" << ping.host << ") failed, errno=" << errno);
            scheduleNextPing(pingId, ping.interval);
            return;
        }
        listener_->onEvent(Event(EventType::PACKET_SENT, sendTime, pingId, sequenceNumber, ping.host));

        requests_.insert(std::make_pair(sequenceNumber, Request(pingId, sendTime)));

        scheduleFinalizePing(sequenceNumber, ping.timeout);

        scheduleNextPing(pingId, ping.interval);
    }

    void Pinger::scheduleNextPing(uint64_t pingId, const Duration& interval) {
        const auto weakRef = weakFromThis();

        worker_->addDelayed(
            [this, weakRef, pingId]() {
                if (const auto ref = weakRef.lock()) {
                    sendPing(pingId);
                }
            },
            std::chrono::duration_cast<std::chrono::milliseconds>(interval));
    }

    void Pinger::finalizePing(uint16_t sequenceNumber) {
        const auto it1 = requests_.find(sequenceNumber);

        Y_VERIFY(it1 != requests_.end());
        const Request& request = it1->second;

        const auto it2 = pings_.find(request.pingId);
        if (it2 == pings_.end()) {
            // Ping stopped
            requests_.erase(sequenceNumber);
            return;
        }
        const Ping& ping = it2->second;

        if (request.repliesCount == 0) {
            const auto now = Clock::now();
            listener_->onEvent(Event(EventType::PACKET_LOST, now, request.pingId, sequenceNumber, ping.host, now - request.sendTime));
        }

        requests_.erase(sequenceNumber);
    }

    void Pinger::scheduleFinalizePing(uint16_t sequenceNumber, const Duration& timeout) {
        const auto weakRef = weakFromThis();

        worker_->addDelayed(
            [this, weakRef, sequenceNumber]() {
                if (const auto ref = weakRef.lock()) {
                    finalizePing(sequenceNumber);
                }
            },
            std::chrono::duration_cast<std::chrono::milliseconds>(timeout));
    }

    /**
     * Reads one packet from the ICMP socket (runs on the polling thread).
     *
     * For RAW sockets the IPv4 header is parsed and skipped first. Packets that
     * are malformed, are not echo replies, or (for non-DGRAM sockets) carry a
     * different identifier than pid_ are ignored. Matching replies are handed to
     * the worker queue via scheduleProcessPingReply().
     */
    void Pinger::acceptICMPPacket() {
        uint8_t msg[256];

        // The sender address is never used, so none is requested; POSIX allows a
        // null address. This also avoids passing a size_t* as a socklen_t*, which
        // breaks strict aliasing and is wrong on big-endian targets.
        const ssize_t ret = recvfrom(socket_, msg, sizeof(msg), 0, nullptr, nullptr);
        if (ret == -1) {
            YIO_LOG_DEBUG("ICMP recvfrom failed, errno=" << errno);
            return;
        }

        const auto recvTime = Clock::now();

        std::stringstream s;
        s << std::string(reinterpret_cast<const char*>(msg), static_cast<size_t>(ret));

        if (socketType_ == SocketType::RAW) {
            // RAW ICMP sockets deliver the IPv4 header in front of the ICMP message.
            ipv4_header ipv4Header;
            s >> ipv4Header;
            if (s.rdstate() & (std::ios_base::failbit | std::ios_base::eofbit)) {
                YIO_LOG_DEBUG("IP header malformed");
                return;
            }
        }

        icmp_header icmpHeader;
        s >> icmpHeader;
        if (s.rdstate() & (std::ios_base::failbit | std::ios_base::eofbit)) {
            YIO_LOG_DEBUG("ICMP header malformed");
            return;
        }

        if (icmpHeader.type() != icmp_header::echo_reply) {
            // This is not our reply
            return;
        }

        if (socketType_ != SocketType::DGRAM && icmpHeader.identifier() != pid_) {
            // This is not our reply
            return;
        }

        scheduleProcessPingReply(recvTime, icmpHeader.sequence_number());
    }

    void Pinger::processPingReply(const TimePoint& recvTime, uint16_t sequenceNumber) {
        auto it1 = requests_.find(sequenceNumber);
        if (it1 == requests_.end()) {
            // This is either not our reply or this sequenceNumber is already announced lost
            return;
        }
        Request& request = it1->second;

        const auto it2 = pings_.find(request.pingId);
        if (it2 == pings_.end()) {
            // Ping stopped
            return;
        }
        const Ping& ping = it2->second;

        const auto elapsed = recvTime - request.sendTime;

        if (request.repliesCount++ == 0) {
            listener_->onEvent(Event(EventType::PACKET_RECEIVED, recvTime, request.pingId, sequenceNumber, ping.host, elapsed));

        } else {
            listener_->onEvent(Event(EventType::PACKET_RECEIVED_DUPLICATE, recvTime, request.pingId, sequenceNumber, ping.host, elapsed));
        }
    }

    void Pinger::scheduleProcessPingReply(const TimePoint& recvTime, uint16_t sequenceNumber) {
        const auto weakRef = weakFromThis();

        worker_->add([this, weakRef, recvTime, sequenceNumber]() {
            if (const auto ref = weakRef.lock()) {
                processPingReply(recvTime, sequenceNumber);
            }
        });
    }

    void Pinger::waitICMPPackets() {
        struct pollfd pollFd[2];

        pollFd[0].fd = socket_;
        pollFd[0].events = POLLIN;

        pollFd[1].fd = wakeupFd_.fd();
        pollFd[1].events = POLLIN | POLLPRI;

        while (!stopped_) {
            pollFd[0].revents = 0;
            pollFd[1].revents = 0;

            int ret = poll(pollFd, 2, -1);

            if (ret == -1) {
                // This should never happen. Sleep to avoid busy wait
                YIO_LOG_DEBUG("poll failed, errno=" << errno);
                std::this_thread::sleep_for(std::chrono::seconds(1));
                continue;
            }

            if (ret == 0) {
                YIO_LOG_DEBUG("poll unexpected timeout");
                std::this_thread::sleep_for(std::chrono::seconds(1));
                continue;
            }

            if (pollFd[1].revents & (POLLIN | POLLPRI)) {
                YIO_LOG_INFO("polling thread got wakeup event");
                continue;
            }

            if (pollFd[0].revents & POLLIN) {
                acceptICMPPacket();
            }
        }

        YIO_LOG_INFO("polling thread stopped");
    }

} // namespace quasar
