#pragma once

#include "addr.h"
#include "bpf.h"
#include "ephemeral.h"
#include "sockops.h"

#include <library/cpp/coroutine/engine/impl.h>
#include <library/cpp/coroutine/engine/network.h>
#include <library/cpp/coroutine/listener/listen.h>

#include <library/cpp/digest/murmur/murmur.h>

#include <kernel/p0f/p0f.h>

#include <util/network/address.h>
#include <util/network/socket.h>

#include <util/digest/numeric.h>

#include <util/generic/hash_set.h>
#include <util/generic/ptr.h>
#include <util/generic/strbuf.h>
#include <util/generic/vector.h>
#include <util/generic/yexception.h>
#include <util/random/random.h>
#include <util/string/builder.h>

#include <cerrno>
#include <cstring>
#include <utility>


namespace NSrvKernel {

    // Period, in seconds, of the p0f map clean-up loop (also the upper bound of its random
    // initial delay).
    // `inline constexpr` rather than an unnamed namespace: an unnamed namespace in a header gives
    // every including translation unit its own, distinct copy of the name.
    inline constexpr uint64_t P0F_TIMEOUT_SECS = 600;

    class TOwnListener {
    public:
        enum class EBindMode {
            Normal,
            ZeroPort,
            Skip
        };

        struct TOptions {
            TOptions() = default;

            TOptions& SetListenQueue(size_t len) noexcept {
                ListenQueue = len;
                return *this;
            }

            TOptions& SetSockBufSize(TSockBufSize size) noexcept {
                SockBufSize = size;
                return *this;
            }

            TOptions& SetBindMode(EBindMode mode) noexcept {
                BindMode = mode;
                return *this;
            }

            TOptions& SetIgnoreBindErrors(const THashSet<TIpAddr>& ignoreBindErrors) noexcept {
                IgnoreBindErrors = ignoreBindErrors;
                return *this;
            }

            TOptions& SetP0fEnabled(bool p0fEnabled) noexcept {
                P0fEnabled = p0fEnabled;
                return *this;
            }

        public:
            // TODO(velavokr): ???
            size_t ListenQueue = Max<size_t>();
            TSockBufSize SockBufSize;
            THashSet<TIpAddr> IgnoreBindErrors;
            EBindMode BindMode = EBindMode::Normal;
            bool P0fEnabled = false;
        };

    private:
        class TSocketListener {
        public:
            TSocketListener(THolder<TEphemeralBoundSocket>&& ephemeralBoundSocket, TContListener::ICallBack* c, TOwnListener* parent)
                : Options_(TOwnListener::TOptions().SetListenQueue(64))
                , ListenSocket_(ephemeralBoundSocket->Socket.Release())
                , Cb_(c)
                , Addr_(std::move(ephemeralBoundSocket->Address))
                , Parent_(parent)
            {}

            TSocketListener(NAddr::IRemoteAddrRef addr, TContListener::ICallBack* c, const TOwnListener::TOptions& options, TOwnListener* parent)
                : Options_(options)
                , Cb_(c)
                , Addr_(addr)
                , Parent_(parent)
            {
                CreateSocket(ListenSocket_);

                if (bind(ListenSocket_, Addr_->Addr(), Addr_->Len()) != 0) {
                    ythrow TSystemError(LastSystemError()) << "bind failed for " << *Addr_;
                }
            }

            TSocketListener(const TSocketListener& rhs) = delete;
            TSocketListener& operator=(const TSocketListener& rhs) = delete;

            TSocketListener(TSocketListener&& rhs) noexcept
                : Options_(rhs.Options_)
                , Cont_(rhs.Cont_)
                , P0fCleanUpCont_(rhs.P0fCleanUpCont_)
                , Cb_(rhs.Cb_)
                , Addr_(std::move(rhs.Addr_))
                , Parent_(rhs.Parent_)
            {
                ListenSocket_.Swap(rhs.ListenSocket_);
                rhs.Cont_ = nullptr;
                rhs.P0fCleanUpCont_ = nullptr;
                rhs.Cb_ = nullptr;
            }

            TSocketListener& operator=(TSocketListener&& rhs) noexcept {
                if (this != &rhs) {
                    std::swap(Options_, rhs.Options_);
                    ListenSocket_.Swap(rhs.ListenSocket_);
                    std::swap(Cont_, rhs.Cont_);
                    std::swap(P0fCleanUpCont_, rhs.P0fCleanUpCont_);
                    std::swap(Cb_, rhs.Cb_);
                    Addr_.Swap(rhs.Addr_);
                    std::swap(Parent_, rhs.Parent_);
                }

                return *this;
            }

            ~TSocketListener() {
                Cancel();  // Because if executor is aborted, Join() will do nothing.
                Join();
            }

            void Loop(TCont*) noexcept {
                DoLoop();
                Cont_ = nullptr;
            }

            void P0fCleanUpLoop(TCont*) noexcept {
                DoP0fCleanUpLoop();
                P0fCleanUpCont_ = nullptr;
            }

            void StartAccept(TContExecutor& executor) {
                Y_VERIFY(!Cont_);

                if (listen(ListenSocket_, (int)Min<size_t>(Max<int>(), Options_.ListenQueue)) != 0) {
                    ythrow TSystemError(LastSystemError()) << "listen failed for " << *Addr_;
                }

                Cont_ = executor.Create<TSocketListener, &TSocketListener::Loop>(this, "listener");
#ifdef _linux_
                if (Options_.P0fEnabled && !P0fCleanUpCont_) {
                    P0fCleanUpCont_ = executor.Create<TSocketListener, &TSocketListener::P0fCleanUpLoop>(this, "p0fcleanup");
                }
#endif
            }

            void Join() noexcept {
                if (Cont_) {
                    Cont_->Executor()->Running()->Join(Cont_, TInstant());
                }

                if (P0fCleanUpCont_) {
                    P0fCleanUpCont_->Executor()->Running()->Join(P0fCleanUpCont_, TInstant());
                }
            }

            void Cancel() noexcept {
                if (Cont_) {
                    Cont_->Cancel();
                }

                if (P0fCleanUpCont_) {
                    P0fCleanUpCont_->Cancel();
                }
            }

            void CancelAndClose() noexcept {
                Cancel();
                ListenSocket_.Close();
            }

            TError CloseUsingBPF() noexcept {
                return CloseSocketUsingBPF(SOCKET(ListenSocket_));
            }

        private:
            class TNonRecoverableError : public TSystemError {
            public:
                TNonRecoverableError(int err)
                    : TSystemError(err)
                {}
            };

            void DoLoop() noexcept {
                try {
                    SetNonBlock(ListenSocket_);
                } catch (...) {
                    return;
                }

                AttachP0fIfNeed();

                bool registred = false;

                while (!Cont_->Cancelled()) {
                    try {
                        if (!registred) {
                            Parent_->RegisterAcceptStart();
                            registred = true;
                        }
                        NAddr::TOpaqueAddr remote;
                        const int res = NCoro::AcceptI(Cont_, ListenSocket_, remote.MutableAddr(), remote.LenPtr());

                        if (res < 0) {
                            const int err = -res;

                            if (err != ECONNABORTED) {
                                if (err == ECANCELED) {
                                    break;
                                }
                                if (err == EMFILE || err == ENFILE) {
                                    if (ECANCELED == Cont_->SleepT(TDuration::MilliSeconds(1))) {
                                        break;
                                    }
                                } else {
                                    ythrow TNonRecoverableError(err) << "error in accept loop that might have lead to an infinite loop";
                                }

                                ythrow TSystemError(err) << "can not accept";
                            }
                        } else {
                            TSocketHolder c((SOCKET)res);

                            const ::TContListener::ICallBack::TAcceptFull acc = {
                                &c,
                                &remote,
                                Addr_.Get()
                            };

                            Cb_->OnAcceptFull(acc);
                        }
                    } catch (TNonRecoverableError& err) {
                        try {
                            Cb_->OnError();
                        } catch (...) {
                        }
                        break;
                    } catch (...) {
                        try {
                            Cb_->OnError();
                        } catch (...) {
                        }
                    }
                }

                if (registred) {
                    Parent_->RegisterAcceptStop();
                }

                try {
                    Cb_->OnStop(&ListenSocket_);
                } catch (...) {
                }
            }

            void DoP0fCleanUpLoop() noexcept {
#ifdef _linux_
                auto delay = TDuration::Seconds(RandomNumber<ui64>() % P0F_TIMEOUT_SECS).ToDeadLine();
                do {
                    P0fCleanUpCont_->SleepD(delay);
                    NP0f::CleanupMapCoroLock();
                    delay = TDuration::Seconds(P0F_TIMEOUT_SECS).ToDeadLine();
                } while (!P0fCleanUpCont_->Cancelled());
#endif
            }

            void CreateSocket(TSocketHolder& target) const {
                if (!Addr_) {
                    ythrow yexception() << "do not know an address to create socket";
                }

                int sock = socket(Addr_->Addr()->sa_family, SOCK_STREAM, 0);

                if (sock < 0) {
                    ythrow TSystemError(LastSystemError()) << "socket failed";
                }

                TSocketHolder retval(sock);
                FixIPv6ListenSocket(retval);
                CheckedSetSockOpt(retval, SOL_SOCKET, SO_REUSEADDR, 1, "reuse addr");
                SetReusePort(retval, true);
                FillSocketOptions(retval, Options_);
                target.Swap(retval);
            }

            void AttachP0fIfNeed() {
#ifdef _linux_
                if (Options_.P0fEnabled) {
                    NP0f::Attach((SOCKET)ListenSocket_);
                }
#endif
            }

            static void FillSocketOptions(TSocketHolder& sock, const TOwnListener::TOptions& options) {
                SetNonBlock(sock);
                TryRethrowError(SetSockBufSize(sock, options.SockBufSize));
            }

        private:
            TOwnListener::TOptions Options_;
            TSocketHolder ListenSocket_;
            TCont* Cont_ = nullptr;
            TCont* P0fCleanUpCont_ = nullptr;
            TContListener::ICallBack* Cb_ = nullptr;
            NAddr::IRemoteAddrRef Addr_;
            TOwnListener* Parent_ = nullptr;
        };

    public:
        TOwnListener(const TVector<TNetworkAddress>& addresses, TContListener::ICallBack* cb, const TOwnListener::TOptions& options) {
            if (options.BindMode == EBindMode::Skip) {
                return;
            }

            for (const auto& addr: addresses) {
                THashSet<TString> seenAddresses;

                for (auto info = addr.Begin(); info != addr.End(); ++info) {
                    auto a = MakeHolder<TSockAddr>(TSockAddr::FromSockAddr(*info->ai_addr).GetOrThrow());

                    if (options.BindMode == EBindMode::ZeroPort) {
                        a->SetPort(0);
                    }

                    if (!seenAddresses.emplace(a->ToString()).second) {
                        continue;
                    }

                    const auto ip = a->Ip();

                    try {
                        Listeners_.push_back(TSocketListener(a.Release(), cb, options, this));
                    } catch (...) {
                        if (options.IgnoreBindErrors.contains(ip)) {
                            Cerr << "IGNORED failed bind " << CurrentExceptionMessage() << Endl;
                        } else {
                            throw;
                        }
                    }
                }
            }
        }

        TOwnListener(THolder<TEphemeralBoundSocket>&& ephemeralBoundSocket, TContListener::ICallBack* cb) {
            Y_ENSURE(ephemeralBoundSocket && cb, "bad parameters passed to TOwnListener");
            Y_ENSURE(ephemeralBoundSocket->Socket != INVALID_SOCKET, "Creating listener with INVALID_SOCKET");
            Listeners_.push_back(TSocketListener(std::move(ephemeralBoundSocket), cb, this));
        }

        void Start(TContExecutor& executor) {
            Y_ASSERT(!IsStarted_);
            for (auto& listener: Listeners_) {
                listener.StartAccept(executor);
            }
        }

        void StartIfNotStarted(TContExecutor& executor) {
            if (!IsStarted_) {
                Start(executor);
                IsStarted_ = true;
            }
        }

        void CancelAndClose() noexcept {
            for (auto& listener : Listeners_) {
                listener.CancelAndClose();
            }
        }

        TError CloseUsingBPF() noexcept {
            for (auto& listener : Listeners_) {
                if (TError error = listener.CloseUsingBPF()) {
                    return error;
                }
            }
            return {};
        }

    private:
        void RegisterAcceptStart() noexcept {
            ++Accepting_;
            IsListening_ = (Listeners_.size() == Accepting_);
        }

        void RegisterAcceptStop() noexcept {
            if (Accepting_) {
                --Accepting_;
            }
            IsListening_ = false;
        }
    private:
        TVector<TSocketListener> Listeners_;
        size_t Accepting_ = 0;
        bool IsListening_ = false;
        bool IsStarted_ = false;
    };
}
