#include "dns_cache.h"

#include <infra/netmon/agent/common/metrics.h>

#include <library/cpp/containers/intrusive_hash/intrhash.h>
#include <library/cpp/containers/intrusive_rb_tree/rb_tree.h>

#include <util/random/random.h>
#include <util/digest/multi.h>
#include <util/memory/smallobj.h>

namespace NNetmon {
    namespace {
        // Upper bound (exclusive, in whole seconds) of the random extra lifetime
        // that TDnsCache::Set() adds on top of CACHE_TTL.
        const TDuration MAX_JITTER = TDuration::Minutes(5);
        // Base lifetime of entries stored through TDnsCache::Set().
        const TDuration CACHE_TTL = TDuration::Minutes(15);

        // Ordering used by the deadline tree: earlier Deadline sorts first.
        // Ties are broken by object address, so two distinct entries with the
        // same deadline never compare equal and can coexist in the tree.
        struct TComparator {
            template <class T>
            static inline bool Compare(const T& lhs, const T& rhs) {
                return (
                    lhs.Deadline < rhs.Deadline
                    || (
                        lhs.Deadline == rhs.Deadline
                        && &lhs < &rhs
                    )
                );
            }
        };

        // One cache record. An entry is linked into two intrusive containers at
        // the same time:
        //   - a red-black tree ordered by Deadline (TRbTreeItem base), used by
        //     Cleanup() to find expired entries;
        //   - a hash keyed by Key (TIntrusiveHashItem base), used for lookups.
        // Storage comes from a TObjectFromPool pool; see Make().
        // NOTE(review): keep the base-class order as is - it determines the
        // order in which the intrusive bases are destroyed.
        struct TEntry : public TRbTreeItem<TEntry, TComparator>,
                        public TIntrusiveHashItem<TEntry>,
                        public TObjectFromPool<TEntry>
        {
            // Lookup key: (name, port, family), as passed to the constructor.
            using TKey = std::tuple<TString, int, int>;
            // Owning handle for an entry that has not (yet) been handed to a map.
            using TRef = THolder<TEntry>;
            using TTree = TRbTree<TEntry, TComparator>;

            // Hooks TIntrusiveHash uses to extract, compare and hash the key of
            // an entry. All three components of the tuple take part.
            struct TOps: public ::TCommonIntrHashOps {
                static inline const TKey& ExtractKey(const TEntry& entry) noexcept {
                    return entry.Key;
                }

                static inline bool EqualTo(const TKey& lhs, const TKey& rhs) noexcept {
                    return lhs == rhs;
                }

                static inline size_t Hash(const TKey& key) noexcept {
                    return MultiHash(std::get<0>(key), std::get<1>(key), std::get<2>(key));
                }
            };

            // Hash index that deletes its entries when they are erased or when
            // the map itself is destroyed.
            using TMapType = TIntrusiveHashWithAutoDelete<TEntry, TOps>;

            // Allocates an entry from `pool` (TPool is inherited from
            // TObjectFromPool) and returns it wrapped in an owning TRef.
            template <typename... Args>
            static inline TRef Make(TPool& pool, Args&&... args) {
                return TRef(new (&pool) TEntry(std::forward<Args>(args)...));
            }

            inline TEntry(const TString& name, int port, int family)
                : Key(name, port, family)
            {
            }

            // Immutable once constructed: it is the hash key of this entry.
            const TKey Key;
            // Cached resolution result.
            TDnsCache::TAddrs Addrs;
            // Absolute expiry time; also the sort key of TTree (see TComparator).
            TInstant Deadline;
        };
    } // namespace

    /// State behind TDnsCache.
    ///
    /// Every cached entry is linked into two indexes at once:
    ///  - Map_: a hash keyed by (name, port, family), used for lookups;
    ///  - Tree_: a red-black tree ordered by deadline, used for expiry.
    ///
    /// Entries are allocated from Pool_ and owned by Map_, an auto-deleting
    /// intrusive hash. Tree_ only links them.
    /// NOTE(review): there is no locking here; presumably callers serialize
    /// access to a shared instance - confirm.
    class TDnsCache::TImpl {
    public:
        TImpl()
            : Pool_(TDefaultAllocator::Instance())
        {
        }

        /// Looks up the cached addresses for (name, port, family).
        ///
        /// Pushes a DnsCacheHit or DnsCacheMiss signal. The deadline is not
        /// checked here, so an expired entry is still returned until Cleanup()
        /// removes it.
        inline TMaybe<TAddrs> Get(const TString& name, int port, int family) noexcept {
            // Build just the key tuple. There is no need to allocate a full
            // TEntry from the pool only to obtain a lookup key.
            const TEntry::TKey key(name, port, family);

            const auto it = Map_.Find(key);
            if (it == Map_.End()) {
                PushSignal(EPushSignals::DnsCacheMiss, 1);
                return Nothing();
            }

            PushSignal(EPushSignals::DnsCacheHit, 1);
            return it->Addrs;
        }

        /// Caches `addrs` for (name, port, family) with a deadline of now + ttl.
        ///
        /// Insert-if-absent: if the key is already cached, even with a deadline
        /// in the past, the existing entry is kept unchanged and `addrs`/`ttl`
        /// are ignored.
        inline void SetWithTTL(const TString& name, int port, int family, const TAddrs& addrs,
                               const TDuration& ttl) noexcept {
            // Check for an existing entry before allocating, so the
            // "already cached" path does no pool allocation at all.
            const TEntry::TKey key(name, port, family);
            if (Map_.Find(key) != Map_.End()) {
                return;
            }

            TEntry::TRef entry(TEntry::Make(Pool_, name, port, family));
            entry->Addrs = addrs;
            entry->Deadline = TInstant::Now() + ttl;

            // Link into both indexes, then give up the TRef: from here on the
            // entry is owned by Map_, which deletes it on Erase() / destruction.
            Map_.Push(entry.Get());
            Tree_.Insert(entry.Get());
            Y_UNUSED(entry.Release());
        }

        /// Removes every entry whose deadline is at or before the current time.
        ///
        /// Tree_ is ordered by deadline, so the scan stops at the first entry
        /// that has not expired yet.
        inline void Cleanup() noexcept {
            const TInstant now = TInstant::Now();
            while (!Tree_.Empty()) {
                TEntry& oldest = *Tree_.Begin();
                if (oldest.Deadline > now) {
                    return;
                }
                // Erasing from Map_ deletes the entry. Loop progress relies on
                // the entry's TRbTreeItem base unlinking it from Tree_ when it is
                // destroyed; otherwise Begin() would keep yielding the same node.
                Map_.Erase(TEntry::TOps::ExtractKey(oldest));
            }
        }

    private:
        // Declared first so it is destroyed last: entries deleted by Map_ hand
        // their storage back to this pool.
        TEntry::TPool Pool_;
        TEntry::TMapType Map_;
        TEntry::TTree Tree_;
    };

    TDnsCache::TDnsCache()
        : Impl(MakeHolder<TImpl>())
    {
    }

    // Defined out of line, where TImpl is a complete type, so the THolder can
    // destroy it (presumably TImpl is only forward-declared in the header).
    TDnsCache::~TDnsCache() = default;

    // Returns the cached addresses for (name, port, family), or Nothing() on a
    // miss. Y_VERIFY-checks that the impl exists before forwarding.
    TMaybe<TDnsCache::TAddrs> TDnsCache::Get(const TString& name, int port, int family) noexcept {
        Y_VERIFY(Impl);
        auto& impl = *Impl;
        return impl.Get(name, port, family);
    }

    void TDnsCache::SetWithTTL(const TString& name, int port, int family, const TAddrs& addrs,
                               unsigned int ttlSecs) noexcept {
        Y_VERIFY(Impl);
        return Impl->SetWithTTL(name, port, family, addrs, TDuration::Seconds(ttlSecs));
    }

    void TDnsCache::Set(const TString& name, int port, int family, const TAddrs& addrs) noexcept {
        Y_VERIFY(Impl);
        auto ttl = CACHE_TTL +
                   TDuration::Seconds(RandomNumber<ui64>() % MAX_JITTER.Seconds());
        return Impl->SetWithTTL(name, port, family, addrs, ttl);
    }

    // Evicts entries whose deadline has passed; forwards to TImpl::Cleanup().
    void TDnsCache::Cleanup() noexcept {
        Y_VERIFY(Impl);
        Impl->Cleanup();
    }
}
