#pragma once

#include <infra/netmon/agent/common/probe.h>
#include <infra/netmon/agent/common/loop.h>
#include <infra/netmon/agent/common/metrics.h>
#include <infra/netmon/agent/common/utils.h>

#include <library/cpp/coroutine/engine/events.h>

#include <util/random/random.h>
#include <util/generic/xrange.h>

namespace NNetmon {
    template <class T>
    class TPacketGenerator;

    /// Base class for a single probe.
    ///
    /// Holds a copy of the probe configuration, the probe's statistics and the
    /// TPacketGenerator that runs the packet exchange. Concrete probes supply
    /// packet creation/validation and the I/O handler via the pure virtuals.
    ///
    /// @tparam T I/O handler type; must expose `TProbeIdType` and `TPacket`.
    template <class T>
    class TBaseProbeGenerator {
    public:
        using TProbeIdType = typename T::TProbeIdType;
        // Probe identity: target address plus probe id.
        using TKey = std::pair<NAddr::IRemoteAddrRef, TProbeIdType>;

        // Hash/equality/key-extraction policy for intrusive hash containers of probes.
        struct TOps: public ::TCommonIntrHashOps {
            // Two keys are equal when their addresses compare equal via TSa and the probe ids match.
            static inline bool EqualTo(const TKey& lhs, const TKey& rhs) {
                return (
                    TSa(lhs.first->Addr()) == TSa(rhs.first->Addr())
                    && lhs.second == rhs.second
                );
            }

            static inline size_t Hash(const TKey& key) noexcept {
                return MultiHash(TSa(key.first->Addr()).hash(), key.second);
            }

            static inline const TKey& ExtractKey(const TBaseProbeGenerator<T>& obj) noexcept {
                return obj.Key_;
            }
        };

        // STL-style hasher over TKey. The call operator is const so the functor can be
        // invoked through a const reference, as standard hash containers do.
        struct THash {
            inline size_t operator() (const TKey& key) const noexcept {
                return TOps::Hash(key);
            }
        };

        // STL-style key equality over TKey (const-callable for the same reason as THash).
        struct TEqualKey {
            inline bool operator() (const TKey& lhs, const TKey& rhs) const {
                return TOps::EqualTo(lhs, rhs);
            }
        };

        TBaseProbeGenerator(const TProbeConfig& config, EProtocol protocol, TLimitedLogger& logger,
                            TContExecutor* executor, TProbeIdType probeId)
            : Config_(config)
            , Protocol_(protocol)
            , Logger_(logger)
            , ProbeId_(probeId)
            // Key_, Statistics_ and Generator_ read members declared above them,
            // so the declaration order of the data members below matters.
            , Key_(Config_.TargetAddr, ProbeId_)
            , Statistics_(Config_.Histogram)
            , Generator_(this, Logger_, executor)
        {
        }

        // Generator_ keeps a back-pointer to `this`; a copied/moved object would
        // have a generator reporting into the source object.
        TBaseProbeGenerator(const TBaseProbeGenerator&) = delete;
        TBaseProbeGenerator& operator=(const TBaseProbeGenerator&) = delete;

        virtual ~TBaseProbeGenerator() = default;

        inline const TProbeConfig& Config() const noexcept {
            return Config_;
        }

        inline EProtocol Protocol() const noexcept {
            return Protocol_;
        }

        inline const TKey& Key() const noexcept {
            return Key_;
        }

        inline TProbeStatistics& Statistics() noexcept {
            return Statistics_;
        }

        inline TPacketGenerator<T>& Generator() noexcept {
            return Generator_;
        }

        // Delegates to TProbeStatistics::Report with this probe's config and protocol.
        void Report(TMaybe<TProbeReport>& report) noexcept {
            Statistics_.Report(Config_, Protocol_, report);
        }

        using TPacket = typename T::TPacket;

        // Returns true when `received` is an acceptable reply to the in-flight packet `sent`.
        virtual bool ValidatePacket(const TPacket& received, TPacket& sent) noexcept = 0;
        // Builds the outgoing packet carrying sequence number `seqno`.
        virtual typename TPacket::TRef CreatePacket(ui16 seqno) noexcept = 0;
        // Handler used to write packets and poll for writability.
        virtual T& GetIOHandler() noexcept = 0;

    protected:
        const TProbeConfig Config_;
        EProtocol Protocol_;
        TLimitedLogger& Logger_;
        const TProbeIdType ProbeId_;
        const TKey Key_;
        TProbeStatistics Statistics_;
        TPacketGenerator<T> Generator_;
    };

    /// Probe generator whose packets all go through one externally supplied I/O handler.
    ///
    /// The handler is stored by reference and is not owned: the caller keeps it
    /// alive for at least as long as this generator.
    template <class T>
    class TSingleSocketProbeGenerator : public TBaseProbeGenerator<T> {
    public:
        using TProbeIdType = typename TBaseProbeGenerator<T>::TProbeIdType;
        using TKey = typename TBaseProbeGenerator<T>::TKey;

        TSingleSocketProbeGenerator(const TProbeConfig& config, EProtocol protocol, TLimitedLogger& logger,
                                    TContExecutor* executor, T& handler, TProbeIdType probeId)
            : TBaseProbeGenerator<T>(config, protocol, logger, executor, probeId)
            , IOHandler_(handler)
        {
        }

        // Always hands out the same shared handler.
        T& GetIOHandler() noexcept override {
            return IOHandler_;
        }

    private:
        // Shared, non-owned handler passed in at construction.
        T& IOHandler_;
    };

    /// Blocks `waiter` until every generator in `generators` has finished, erasing
    /// each one whose packet generator reports Cancelled() (which includes "not running").
    /// Does nothing when `waiter` is null.
    template <class TProbeGenerator, class TOps>
    void WaitForGenerators(TIntrusiveHashWithAutoDelete<TProbeGenerator, TOps>& generators, TCont* waiter) {
        if (!waiter) {
            return;
        }

        while (!generators.empty()) {
            auto head = generators.Begin();
            auto& packetGenerator = head->Generator();
            if (!packetGenerator.Cancelled()) {
                // Still running: wait for it, then re-examine the head on the next pass.
                packetGenerator.Wait(waiter);
            } else {
                // Finished or cancelled: drop it (the container deletes the object).
                generators.Erase(head->Key());
            }
        }
    }

    /// Runs the packet exchange of one probe inside its own coroutine.
    ///
    /// Run() sends up to Config().PacketCount packets (spaced by
    /// Min(Config().Delay, MAX_PACKET_DELAY)), waits up to
    /// Min(Config().Timeout, MAX_PACKET_DELAY) for replies, counts the remaining
    /// in-flight packets as lost unless cancelled, and then spawns every
    /// subscribed callback as a separate coroutine. Replies are delivered from
    /// outside through OnPacketAvailable().
    template <class T>
    class TPacketGenerator {
    public:
        using TPacket = typename T::TPacket;

        TPacketGenerator(TBaseProbeGenerator<T>* parent, TLimitedLogger& logger, TContExecutor* executor)
            : Parent_(parent)
            , Statistics_(Parent_->Statistics())
            , Logger_(logger)
            , Executor_(executor)
            , Cont_(nullptr)
            , WakeupEvent_(Executor_)
            , Loop_(Executor_, this, "packet_generator")
        {
            Callbacks_.reserve(2);
        }

        // Joins the generator coroutine if it is running; otherwise yields once.
        void Wait(TCont* waiter) {
            if (Cont_) {
                waiter->Join(Cont_);
            } else {
                // prevent busy loop when generator is ended
                waiter->SleepT(TDuration::Zero());
            }
        }

        // Registers `cb` to be invoked with the parent probe once Run() completes.
        // Must be called before the generator starts; callbacks shouldn't throw exceptions.
        void Subscribe(const std::function<void(TBaseProbeGenerator<T>*)>& cb) {
            Y_VERIFY(!Cont_);
            Callbacks_.emplace_back([this, cb] (TCont*) noexcept {
                cb(Parent_);
            });
        }

        inline void Start() {
            Loop_.Start();
        }

        // True when the coroutine was cancelled or is not running at all.
        inline bool Cancelled() const noexcept {
            if (Cont_) {
                return Cont_->Cancelled();
            }
            return true;
        }

        // Matches an incoming reply (or error-queue notification) against the
        // in-flight packets and updates statistics accordingly.
        inline void OnPacketAvailable(const TPacket& packet) noexcept {
            // Look the packet up WITHOUT taking ownership: if the reply does not
            // validate, the in-flight packet must stay queued so that a genuine
            // reply can still match it, or it is counted as lost on timeout.
            TPacket* candidate = FindPacket(packet);
            if (!candidate || !Parent_->ValidatePacket(packet, *candidate)) {
                PushSignal(EPushSignals::PacketMartian, 1);
                return;
            }

            // Accepted: take ownership, detach from the in-flight list; the packet
            // is destroyed when inFlightPacket leaves scope.
            typename TPacket::TRef inFlightPacket(candidate);
            inFlightPacket->Unlink();

            // check that packet sizes differs
            Statistics_.Truncated_ = Statistics_.Truncated_ ||
                    packet.Size() != inFlightPacket->Size() ||
                    packet.Truncated();

            if (packet.FromErrorQueue()) {
                if (!packet.Offender().Empty()) {
                    Statistics_.Offender_ = packet.Offender();
                }
                OnPacketLost(*inFlightPacket);
            } else {
                OnPacketReceived(packet);
                if (inFlightPacket->TypeOfService() &&
                    !TPacket::SameTypeOfService(inFlightPacket->TypeOfService(),
                                                packet.TypeOfService()))
                {
                    OnPacketTosChanged(packet);
                }
            }
            // signal for generator that some message arrives
            WakeupEvent_.BroadCast();
        }

    private:
        // Coroutine body driven by Loop_.
        inline void Run(TCont* cont) noexcept {
            Cont_ = cont;

            // Random start jitter in [0, StartDelay).
            if (Parent_->Config().StartDelay) {
                const auto delay(TDuration::MicroSeconds(
                    RandomNumber<ui64>(Parent_->Config().StartDelay.MicroSeconds())
                ));
                Cont_->SleepT(delay);
            }

            auto packetDelay = Min(Parent_->Config().Delay, MAX_PACKET_DELAY);

            for (const auto idx : xrange(Parent_->Config().PacketCount)) {
                if (Cancelled()) {
                    break;
                }

                try {
                    OneShot();
                } catch(...) {
                    Statistics_.Failed_ = true;
                    Statistics_.Error_ = CurrentExceptionMessage();
                    Logger_ << TLOG_WARNING << TStringBuf("Error occured while sending probe to ")
                                      << NAddr::PrintHostAndPort(*Parent_->Config().TargetAddr) << TStringBuf(": ")
                                      << CurrentExceptionMessage() << Endl;
                    break;
                }

                // wait for delay (not after the last packet)
                if (packetDelay && idx != Parent_->Config().PacketCount - 1) {
                    Cont_->SleepT(packetDelay);
                }
            }

            // Wait for outstanding replies; OnPacketAvailable() wakes us on each accepted one.
            const auto deadline(Min(Parent_->Config().Timeout, MAX_PACKET_DELAY).ToDeadLine());
            while (!Cancelled() && !InFlightPackets_.Empty() && deadline > TInstant::Now()) {
                WakeupEvent_.WaitD(deadline);
            }

            // Everything still in flight after the deadline is lost.
            while (!Cancelled() && !InFlightPackets_.Empty()) {
                typename TPacket::TRef lostPacket(InFlightPackets_.PopFront());
                OnPacketLost(*lostPacket);
            }

            // NOTE(review): Create() receives each stored std::function by reference;
            // Callbacks_ (and thus this generator) presumably must outlive the
            // spawned coroutines — confirm against the executor's Create semantics.
            for (auto it(Callbacks_.rbegin()); it != Callbacks_.rend(); ++it) {
                Executor_->Create(*it, "packet_generator_callback");
            }

            Cont_ = nullptr;
        }

        // Creates the next packet and writes it, polling for writability until it is
        // sent or the coroutine is cancelled. A sent packet joins InFlightPackets_.
        inline void OneShot() {
            // sent another one packet
            typename TPacket::TRef inFlightPacket(Parent_->CreatePacket(++Statistics_.LastSeqno_));
            auto& ioHandler(Parent_->GetIOHandler());

            // write or wait
            while (!Cancelled()) {
                auto ioStatus(ioHandler.Write(*inFlightPacket));
                if (ioStatus) {
                    if (ioStatus->Checked() != inFlightPacket->Size()) {
                        Statistics_.Truncated_ = true;
                    }
                    inFlightPacket->DetachBuffer();
                    InFlightPackets_.PushBack(inFlightPacket.Release());
                    TUnistat::Instance().PushSignalUnsafe(ESignals::PacketSent, 1);
                    break;
                } else {
                    ioHandler.PollT(Cont_, CONT_POLL_WRITE, TDuration::Max());
                }
            }
        }

        // Returns the in-flight packet with the same seqno (newest first), or nullptr.
        TPacket* FindPacket(const TPacket& packet) {
            for (auto it(InFlightPackets_.RBegin()); it != InFlightPackets_.REnd(); ++it) {
                if (it->Stats().Seqno == packet.Stats().Seqno) {
                    return &*it;
                }
            }
            return nullptr;
        }

        void OnPacketReceived(const TPacket& packet) {
            ui64 roundTripTime = 0;
            double clockSkewSys = 0;
            std::tie(roundTripTime, clockSkewSys) = packet.Stats().ComputeRoundTrip();

            if (roundTripTime) {
                if (Statistics_.Histogram_) {
                    Statistics_.Histogram_->RecordValue(roundTripTime);
                }
                // Incremental mean. NOTE(review): the divisor is ReceivedCounter_, which also
                // counts replies whose roundTripTime is 0 (skipped here), so such replies bias
                // the average toward 0 — confirm whether a zero RTT can actually occur.
                Statistics_.Average_ = Statistics_.Average_ * ((double)Statistics_.ReceivedCounter_ / (Statistics_.ReceivedCounter_ + 1)) + (double)roundTripTime / (Statistics_.ReceivedCounter_ + 1);
                TUnistat::Instance().PushSignalUnsafe(ESignals::PacketRoundTripDelay, roundTripTime);
            }

            if (clockSkewSys > 0) {
                Statistics_.MaxClockSkewSys_ = Max(Statistics_.MaxClockSkewSys_, clockSkewSys);
                PushSignal(EPushSignals::ClockSkewSys, clockSkewSys);
            }

            Statistics_.ReceivedCounter_++;
            TUnistat::Instance().PushSignalUnsafe(ESignals::PacketReceived, 1);

            if (packet.CongestionEncountered()) {
                Statistics_.CongestedCounter_++;
                PushSignal(EPushSignals::PacketCongested, 1);
            }

            if (!packet.ValidateData(Parent_->Protocol() == EProtocol::TCP)) {
                Statistics_.CorruptedCounter_++;
                PushSignal(EPushSignals::PacketCorrupted, 1);
                Logger_ << TLOG_WARNING << TStringBuf("Packet corrupted") << TStringBuf(" when sending from ")
                        << NAddr::PrintHostAndPort(*Parent_->Config().SourceAddr)
                        << TStringBuf(" to ") << NAddr::PrintHostAndPort(*Parent_->Config().TargetAddr)
                        << TStringBuf(" with ProbeId = ") << packet.Stats().ProbeId << TStringBuf(" and Seqno = ") << packet.Stats().Seqno << Endl;
            }
        }

        void OnPacketLost(const TPacket&) {
            Statistics_.LostCounter_++;
            TUnistat::Instance().PushSignalUnsafe(ESignals::PacketLost, 1);
        }

        void OnPacketTosChanged(const TPacket&) {
            Statistics_.TosChangedCounter_++;
        }

        TBaseProbeGenerator<T>* Parent_;
        // Alias of the parent's statistics object.
        TProbeStatistics& Statistics_;

        TVector<std::function<void(TCont*)>> Callbacks_;

        TLimitedLogger& Logger_;
        TContExecutor* Executor_;

        // Non-null only while Run() is executing.
        TCont* Cont_;
        TContSimpleEvent WakeupEvent_;

        typename TPacket::TListType InFlightPackets_;

        TSimpleContLoop<TPacketGenerator, &TPacketGenerator::Run> Loop_;
    };

    /// Attaches a probe for each config in `configList` via `parent->Attach()` and
    /// collects one TProbeReport per config once the probes complete.
    ///
    /// Configs for which no probe could be attached get an immediate failure report.
    template <class T>
    class TProbeCollector : public TNonCopyable {
    public:
        inline TProbeCollector(T* parent, TCont* cont, const TVector<TProbeConfig>& configList)
            : Cont_(cont)
            , GeneratorCount_(0)
        {
            // One slot per config, filled either here (attach failure) or by the
            // generator's completion callback.
            Reports_.resize(configList.size());

            std::size_t index(0);
            for (const auto& config : configList) {
                auto* generator(parent->Attach(config));
                if (generator) {
                    GeneratorCount_++;
                    // NOTE(review): the callback captures `this`; it presumably must not fire
                    // after this collector is destroyed (e.g. if Dump() returned early on
                    // cancellation while probes were still running) — confirm with callers.
                    generator->Generator().Subscribe([this, index] (auto* ptr) {
                        ptr->Report(Reports_[index]);
                        GeneratorCount_--;
                        Cont_->ReSchedule();
                    });
                } else {
                    Reports_[index].ConstructInPlace(
                        config.Type,
                        parent->Protocol(),
                        config.SourceAddr,
                        config.TargetAddr,
                        config.TypeOfService,
                        "Can't create appropriate probe"
                    );
                }

                index++;
            }
        }

        /// Waits (on Cont_) until all attached probes have reported or Cont_ is
        /// cancelled, then appends the collected reports to `reports` in config order.
        ///
        /// On cancellation, probes that have not reported yet have no report; those
        /// slots are skipped rather than dereferenced (an empty TMaybe would throw
        /// from GetRef() and terminate the process inside this noexcept function).
        void Dump(TVector<TProbeReport>& reports) noexcept {
            while (!Cont_->Cancelled() && GeneratorCount_ > 0) {
                Cont_->SleepI();
            }

            reports.reserve(reports.size() + Reports_.size());
            for (const auto& report : Reports_) {
                if (report.Defined()) {
                    reports.emplace_back(report.GetRef());
                }
            }
        }

    private:
        TCont* Cont_;
        // Number of attached probes that have not reported yet.
        std::size_t GeneratorCount_;

        TVector<TMaybe<TProbeReport>> Reports_;
    };
}
