#include "service.h"
#include "handler.h"

#include <infra/netmon/agent/common/generator.h>

#include <util/generic/xrange.h>

namespace NNetmon {
    class TTcpPollerService::TImpl {
    private:
        // Probe generator bound to a single TTcpSocket: builds outgoing TCP probe
        // packets and validates the replies matched to them. Instances are
        // allocated from the owning TImpl's GeneratorPool_ (see Make()).
        class TProbeGenerator : public TSingleSocketProbeGenerator<TTcpIOHandler>, public TIntrusiveHashItem<TProbeGenerator>, public TObjectFromPool<TProbeGenerator> {
        public:
            using TRef = THolder<TProbeGenerator>;
            using TMapType = TIntrusiveHashWithAutoDelete<TProbeGenerator, TOps>;

            // Placement-allocates a generator in `pool`, forwarding `args` to the constructor.
            template <typename... Args>
            static inline TRef Make(TPool& pool, Args&&... args) {
                return TRef(new (&pool) TProbeGenerator(std::forward<Args>(args)...));
            }

            // The port range of `handler` is remembered and used by CreatePacket()
            // when the probe config does not pin a source port.
            inline TProbeGenerator(const TImpl* parent, const TProbeConfig& config, TLimitedLogger& logger, TContExecutor* executor, TTcpIOHandler& handler, ui8 probeId)
                : TSingleSocketProbeGenerator(config, parent->Protocol(), logger, executor, handler, probeId)
                , Parent_(parent)
                , MinPort_(handler.MinPort())
                , MaxPort_(handler.MaxPort())
            {
            }

            // Accepts `received` as the answer to `sent` only if it is not truncated
            // and its stats compare equal to the sent ones; on success the received
            // stats are copied into `sent`.
            bool ValidatePacket(const TTcpPacket& received, TTcpPacket& sent) noexcept override {
                if (received.Truncated()) {
                    return false;
                }

                // NOTE(review): presumably the caller pairs packets by seqno, so a
                // mismatch here is an internal invariant violation, not bad input.
                Y_VERIFY(received.Stats().Seqno == sent.Stats().Seqno);

                if (received.Stats() != sent.Stats()) {
                    return false;
                }

                TPacket::TStats::Copy(received.Stats(), sent.Stats());

                return true;
            }

            // Builds the probe packet carrying `seqno`. Ports come from the config;
            // a zero source port is replaced by a random port from
            // [MinPort_, MaxPort_]. The size is capped at MAX_PACKET_LENGTH.
            TTcpPacket::TRef CreatePacket(ui16 seqno) noexcept override {
                ui16 dstPort = ExtractPort(Config_.TargetAddr);
                ui16 srcPort = ExtractPort(Config_.SourceAddr);

                if (!srcPort) {
                    srcPort = MinPort_;
                    if (MaxPort_ > MinPort_) {
                        // Both bounds are inclusive (a one-port range yields MinPort_),
                        // while RandomNumber(n) returns a value in [0, n): without the
                        // +1 MaxPort_ could never be picked. ui32 keeps the count from
                        // overflowing when the range spans all 65536 ports.
                        const ui32 portCount = static_cast<ui32>(MaxPort_ - MinPort_) + 1;
                        srcPort = static_cast<ui16>(MinPort_ + RandomNumber<ui32>(portCount));
                    }
                }

                auto targetIpAddr = SetPort(*Config_.TargetAddr, 0);
                TTcpPacket::TRef packet(TTcpPacket::Make(
                    Parent_->PacketPool_,
                    Parent_->BufferPool_,
                    *targetIpAddr,
                    dstPort,
                    srcPort,
                    Config_.TypeOfService,
                    Config_.TimeToLive
                ));
                packet->Stats().Signature = 0x0;
                packet->Stats().ProbeId = Key_.second;
                packet->Stats().Seqno = seqno;
                packet->Stats().SourceSentTime = TInstant::Now().MicroSeconds();

                packet->Size() = Min(Config_.PacketSize, MAX_PACKET_LENGTH);
                return packet;
            }

        private:
            const TImpl* const Parent_;
            const ui16 MinPort_; // lowest source port to pick from (inclusive)
            const ui16 MaxPort_; // highest source port to pick from (inclusive)
        };

        // One local address the service is bound to. Owns the IO handler, a queue
        // of replies waiting for the socket to become writable, the generators
        // attached to this address, and the coroutine (Loop_) that services them.
        class TTcpSocket : public TIntrusiveHashItem<TTcpSocket> {
        public:
            using TRef = THolder<TTcpSocket>;

            struct TOps: public TAddrIntrHashOps {
                // Sockets are keyed by the IO handler's local address.
                static inline const NAddr::IRemoteAddr& ExtractKey(const TTcpSocket& obj) noexcept {
                    return *obj.IOHandler_.Addr();
                }
            };

            using TMapType = TIntrusiveHashWithAutoDelete<TTcpSocket, TOps>;

            // Non-throwing factory: logs any construction error and returns nullptr.
            static TRef Make(const TImpl* parent, const NAddr::IRemoteAddrRef& addr, ui16 minPort, ui16 maxPort) noexcept {
                try {
                    return MakeHolder<TTcpSocket>(parent, addr, minPort, maxPort);
                } catch(...) {
                    static TLimitedLogger logger(parent->Logger_);
                    logger << TLOG_WARNING << CurrentExceptionMessage() << Endl;
                }
                return nullptr;
            }

            // Creates the IO handler for `addr` (passing it the port range) and
            // starts the servicing coroutine.
            TTcpSocket(const TImpl* parent, const NAddr::IRemoteAddrRef& addr, ui16 minPort, ui16 maxPort)
                : Parent_(parent)
                , IOHandler_(Parent_->Logger_, addr, minPort, maxPort)
                , Cont_(nullptr)
                , Loop_(Parent_->Executor_, this, "tcp_socket")
            {
                Loop_.Start();
            }

            // Creates a generator for `config` under the first probe id whose key
            // is not yet present in Generators_, registers it and starts it. The
            // subscribed callback erases it from Generators_ again (presumably
            // when the generator completes). Returns nullptr if every probe id is
            // already taken.
            TBaseProbeGenerator<TTcpIOHandler>* Attach(const TProbeConfig& config) {
                static TLimitedLogger logger(Parent_->Logger_);

                for (const auto probeId : xrange(TTcpIOHandler::MaxProbeId())) {
                    TProbeGenerator::TRef generator(
                        TProbeGenerator::Make(Parent_->GeneratorPool_, Parent_, config, logger, Parent_->Executor_, IOHandler_, probeId)
                    );
                    if (!Generators_.Has(generator->Key())) {
                        generator->Generator().Subscribe([this] (TBaseProbeGenerator<TTcpIOHandler>* finished) noexcept {
                            Generators_.Erase(TProbeGenerator::TOps::ExtractKey(*finished));
                        });
                        Generators_.Push(generator.Get());
                        generator->Generator().Start();
                        return generator.Release();
                    }
                }

                // only one generator with a given address and probe id can exist
                return nullptr;
            }

        private:
            // Coroutine body: runs OneShot() until cancelled, logging and
            // recovering from any exception, then waits for the generators.
            inline void Run(TCont* cont) noexcept {
                Cont_ = cont;

                while (!Cont_->Cancelled()) {
                    bool failed = false;
                    try {
                        OneShot();
                    } catch(...) {
                        static TLimitedLogger logger(Parent_->Logger_);
                        logger << TLOG_WARNING << CurrentExceptionMessage() << Endl;
                        failed = true;
                    }

                    if (failed) {
                        // OneShot() may throw before reaching its poll/sleep (e.g. on
                        // a persistent read error). Yield here, outside the handler,
                        // so a failing socket cannot monopolize the cooperative
                        // executor with an immediate retry loop.
                        Cont_->SleepT(TDuration::Zero());
                    }
                }

                // wait for running generators
                WaitForGenerators(Generators_, Cont_);

                Cont_ = nullptr;
            }

            // Services the socket until cancelled. Each iteration does three steps:
            // it reads one packet, it flushes the reply queue, and then it either
            // polls for whatever could not proceed or yields. Fresh requests are
            // queued for a reply; every other packet goes to DispatchPacket().
            inline void OneShot() {
                while (!Cont_->Cancelled()) {
                    TTcpPacket::TRef packet(TTcpPacket::Make(Parent_->PacketPool_, Parent_->BufferPool_));

                    ui16 opFilter = 0;

                    auto ioStatus(IOHandler_.Read(*packet));
                    if (ioStatus) {
                        packet->Size() = ioStatus->Checked();
                        if (!packet->Size()) {
                            ythrow yexception() << TStringBuf("no data was read");
                        } else {
                            if (packet->Type() == TTcpPacket::EType::Request && !packet->FromErrorQueue()) {
                                WriteQueue_.PushBack(packet.Release());
                            } else {
                                DispatchPacket(*packet);
                            }
                        }
                    } else {
                        // Read produced no status (presumably it would block):
                        // wait for readability below.
                        opFilter |= CONT_POLL_READ;
                    }

                    while (!WriteQueue_.Empty()) {
                        TTcpPacket::TRef reply(WriteQueue_.PopFront());

                        if (!ReplyPacket(*reply)) {
                            // Not sent: keep it at the head of the queue and retry
                            // the same packet once the socket is writable.
                            WriteQueue_.PushFront(reply.Release());
                            opFilter |= CONT_POLL_WRITE;
                            break;
                        }
                    }

                    if (opFilter) {
                        IOHandler_.PollT(Cont_, opFilter, TDuration::Max());
                    } else {
                        Cont_->SleepT(TDuration::Zero());
                    }
                }
            }

            // Turns a received request into a reply (type, swapped ports, target
            // timestamps) and tries to send it.
            // Return value:
            //   - false: the write produced no status. OneShot() then re-queues
            //     this very packet and calls us again later.
            //   - true: the packet is finished. This includes short writes, which
            //     are dropped without counting PacketSent.
            // Because a blocked reply is retried on the same object, the
            // request->reply conversion is guarded by the packet type. Re-running
            // it would swap the ports back to the request's original values.
            inline bool ReplyPacket(TTcpPacket& packet) {
                if (packet.Type() != TTcpPacket::EType::Reply) {
                    packet.Type() = TTcpPacket::EType::Reply;
                    std::swap(packet.SrcPort(), packet.DstPort());
                    packet.Stats().TargetReceivedTime = packet.Stats().SourceReceivedTime;
                }

                packet.Stats().TargetSentTime = TInstant::Now().MicroSeconds();

                SetPort(packet.ClientAddrData(), 0);

                auto ioStatus(IOHandler_.Write(packet));
                if (!ioStatus) {
                    return false;
                }

                if (ioStatus->Checked() == packet.Size()) {
                    TUnistat::Instance().PushSignalUnsafe(ESignals::PacketSent, 1);
                }

                return true;
            }

            // Hands a packet that is not a fresh request (a reply, or anything read
            // from the error queue) to the generator registered under
            // (client address, probe id). Unknown packets count as PacketMartian and
            // throw; Run() logs the exception.
            inline void DispatchPacket(TTcpPacket& packet) {
                auto clientAddr = MakeAtomicShared<NAddr::TOpaqueAddr>(packet.ClientAddrData());

                if (packet.FromErrorQueue()) {
                    // Linux reported the remote client address along with the error,
                    // but the packet contents (and port values) remained intact.
                    // Override the address port with the packet's destination port
                    // so the key matches the corresponding generator.
                    SetPort(clientAddr->MutableAddr(), packet.DstPort());
                }

                const TProbeGenerator::TKey key{
                    clientAddr,
                    packet.Stats().ProbeId
                };

                auto it(Generators_.Find(key));
                if (it != Generators_.End()) {
                    it->Generator().OnPacketAvailable(packet);
                } else {
                    PushSignal(EPushSignals::PacketMartian, 1);
                    ythrow yexception() << TStringBuf("Unknown packet from ") << NAddr::PrintHostAndPort(*key.first)
                        << TStringBuf(" received by ") << NAddr::PrintHostAndPort(*IOHandler_.Addr())
                        << TStringBuf(", probe id ") << (ui32)key.second;
                }
            }

            const TImpl* const Parent_;
            TTcpIOHandler IOHandler_;
            TTcpPacket::TListType WriteQueue_; // replies waiting for a writable socket
            TCont* Cont_;                      // servicing coroutine while Run() is active

            TProbeGenerator::TMapType Generators_;

            TSimpleContLoop<TTcpSocket, &TTcpSocket::Run> Loop_;
        };

    public:
        inline TImpl(TLog& logger, TContExecutor* executor)
            : Logger_(logger)
            , Executor_(executor)
            , PacketPool_(TDefaultAllocator::Instance())
            , GeneratorPool_(TDefaultAllocator::Instance())
            , BufferPool_(TDefaultAllocator::Instance())
        {
        }

        // Attaches a probe to the socket bound to the probe's source IP (the source
        // port is ignored for the lookup). Returns nullptr in three cases: the
        // source and target address families differ, no such socket exists, or
        // every probe id on that socket is taken.
        TBaseProbeGenerator<TTcpIOHandler>* Attach(const TProbeConfig& config) {
            if (config.SourceAddr->Addr()->sa_family != config.TargetAddr->Addr()->sa_family) {
                return nullptr;
            }

            auto srcIpAddr = SetPort(*config.SourceAddr, 0);
            auto it(Sockets_.Find(*srcIpAddr));
            return it != Sockets_.End() ? it->Attach(config) : nullptr;
        }

        // Runs `configs` through a TProbeCollector and dumps its results into `reports`.
        void ScheduleChecks(TCont* cont, const TVector<TProbeConfig>& configs, TVector<TProbeReport>& reports) {
            TProbeCollector<TImpl> collector(this, cont, configs);
            collector.Dump(reports);
        }

        // Reconciles Sockets_ with `addresses`:
        //   - sockets are keyed by address with the port cleared;
        //   - sockets whose address disappeared are destroyed;
        //   - sockets for new addresses are created.
        // The min/max port over all given addresses becomes the port range of
        // the newly created sockets.
        // NOTE(review): already existing sockets keep the range they were
        // created with, even if the port set changed -- confirm this is intended.
        // Throws after the loop if any new socket could not be created; the ones
        // that succeeded are kept.
        void SyncAddresses(const TVector<TNetworkAddress>& addresses) {
            TOpaqueAddrSet target;
            ui16 minPort = 0xffff;
            ui16 maxPort = 0x0;

            for (auto& addr : addresses) {
                for (auto it(addr.Begin()); it != addr.End(); ++it) {
                    NAddr::TOpaqueAddr copy(it->ai_addr);
                    ui16 port = ExtractPort(copy.Addr());

                    minPort = Min(minPort, port);
                    maxPort = Max(maxPort, port);

                    SetPort(copy.MutableAddr(), 0);
                    target.emplace(&copy);
                }
            }

            // Collect the keys first: erasing while iterating Sockets_ is unsafe.
            TVector<NAddr::IRemoteAddrRef> keysToDelete;
            for (auto& sock : Sockets_) {
                const auto& key(TTcpSocket::TOps::ExtractKey(sock));
                auto it(target.find(key));
                if (it.IsEnd()) {
                    keysToDelete.emplace_back(CopyRemoteAddr(key));
                } else {
                    target.erase(it);
                }
            }
            for (const auto& key : keysToDelete) {
                Sockets_.Erase(*key);
            }

            // Whatever is left in `target` has no socket yet.
            bool failed = false;
            for (auto& addr : target) {
                try {
                    TTcpSocket::TRef sk(MakeHolder<TTcpSocket>(this, CopyRemoteAddr(addr), minPort, maxPort));
                    Sockets_.Push(sk.Release());
                } catch(...) {
                    static TLimitedLogger logger(Logger_);
                    logger << TLOG_WARNING << CurrentExceptionMessage() << Endl;
                    failed = true;
                }
            }

            if (failed) {
                ythrow yexception() << "can't bind to all given addresses";
            }
        }

        inline EProtocol Protocol() const noexcept {
            return EProtocol::TCP;
        }

    private:
        TLog& Logger_;
        TContExecutor* Executor_;

        // mutable: the nested classes allocate from these pools through their
        // `const TImpl*` Parent_ pointer.
        mutable TTcpPacket::TPool PacketPool_;
        mutable TProbeGenerator::TPool GeneratorPool_;
        mutable TOnDemandBuffer::TPool BufferPool_;

        // Declared after the pools, so by reverse destruction order the sockets
        // (and the generators/packets they hold) go away before the pools.
        TTcpSocket::TMapType Sockets_;
    };

    TTcpPollerService::TTcpPollerService(TLog& logger, TContExecutor* executor)
        : Impl(MakeHolder<TImpl>(logger, executor))
    {
    }

    // Defined out of line, where TImpl is a complete type, so the holder can destroy it.
    TTcpPollerService::~TTcpPollerService() = default;

    // Destroys the implementation together with everything it owns. A later
    // call to Stop(), SyncAddresses() or ScheduleChecks() fails its Y_VERIFY.
    void TTcpPollerService::Stop() noexcept {
        Y_VERIFY(Impl.Get() != nullptr);
        Impl.Destroy();
    }

    void TTcpPollerService::SyncAddresses(const TVector<TNetworkAddress>& addresses) {
        Y_VERIFY(Impl);
        Impl->SyncAddresses(addresses);
    }

    // Forwards to the implementation; must not be called after Stop().
    void TTcpPollerService::ScheduleChecks(TCont* cont, const TVector<TProbeConfig>& configs, TVector<TProbeReport>& reports) {
        Y_VERIFY(Impl.Get() != nullptr);
        TImpl& impl = *Impl;
        impl.ScheduleChecks(cont, configs, reports);
    }
}
