#include "limiter.h"

#include <library/cpp/cache/cache.h>

#include <util/digest/numeric.h>
#include <util/generic/hash.h>
#include <util/generic/string.h>
#include <util/generic/yexception.h>
#include <util/stream/format.h>
#include <util/string/builder.h>

////////////////////////////////////////////////////////////////////////////////

namespace NUnbound {

namespace {

// Hashes the host part of a socket address.
//
// AF_INET / AF_INET6: CityHash64 over the raw address bytes only (sin_addr / sin6_addr);
// the port and the remaining sockaddr fields do not participate in the hash.
// Any other family: falls back to hashing the textual host produced by PrintHost().
// NOTE(review): the fallback allocates a string inside a noexcept function; an
// allocation failure there would terminate the process.
ui64 IpHash(const NAddr::IRemoteAddr& addr) noexcept {
    const sockaddr* const sa = addr.Addr();
    switch (sa->sa_family) {
        case AF_INET: {
            const auto* v4 = reinterpret_cast<const sockaddr_in*>(sa);
            return CityHash64(
                reinterpret_cast<const char*>(&v4->sin_addr.s_addr),
                sizeof(v4->sin_addr.s_addr));
        }
        case AF_INET6: {
            const auto* v6 = reinterpret_cast<const sockaddr_in6*>(sa);
            return CityHash64(
                reinterpret_cast<const char*>(&v6->sin6_addr.s6_addr),
                sizeof(v6->sin6_addr.s6_addr));
        }
        default:
            return ComputeHash(PrintHost(addr));
    }
}

} // anonymous namespace

TProjectId::TProjectId(ui32 id)
    : Id_(id)
{
}

// Returns the raw numeric project id this object was constructed with.
ui32 TProjectId::Id() const noexcept {
    return this->Id_;
}

// Renders the id in hexadecimal using util's Hex() formatter with flags = 0,
// then lower-cases the result.
TString TProjectId::ToString() const {
    TString text = TStringBuilder() << Hex(Id_, /* flags = */ 0);
    text.to_lower();
    return text;
}

// Streams the project id in the same textual form as TProjectId::ToString().
IOutputStream& operator<<(IOutputStream& o, const TProjectId& projectId) {
    return o << projectId.ToString();
}

} // namespace NUnbound

// Hash for TProjectId: the numeric id itself, widened to size_t.
template <>
struct THash<NUnbound::TProjectId> {
    size_t operator()(const NUnbound::TProjectId& key) const noexcept {
        const ui32 rawId = key.Id();
        return static_cast<size_t>(rawId);
    }
};

////////////////////////////////////////////////////////////////////////////////

namespace NUnbound {
namespace {

// Key identifying a (project, nameserver address) pair.
//
// The combined hash is computed once at construction and cached in `Hash`;
// equality compares that hash first, then the project id, then the addresses
// via NAddr::IsSame().
struct TCacheKey {
    TProjectId ProjectId;
    NAddr::IRemoteAddrRef NameserverAddress;
    // Cached CombineHashes(THash<TProjectId>(ProjectId), IpHash(*NameserverAddress)).
    size_t Hash;

    // Default-constructed key (project id 0, default TOpaqueAddr); needed by
    // containers / deserialization that construct keys before loading them.
    TCacheKey()
        : TCacheKey(TProjectId(0), MakeAtomicShared<NAddr::TOpaqueAddr>())
    {
    }

    TCacheKey(TProjectId projectId, NAddr::IRemoteAddrRef nameserverAddress)
        : ProjectId(std::move(projectId))
        // Moved rather than copied: avoids an extra atomic refcount inc/dec.
        , NameserverAddress(std::move(nameserverAddress))
        // Members are initialized in declaration order, so ProjectId and
        // NameserverAddress are already set here.
        , Hash(CombineHashes(
            THash<TProjectId>()(ProjectId),
            IpHash(*NameserverAddress)))
    {
    }

    bool operator==(const TCacheKey& rhs) const {
        return Hash == rhs.Hash &&
               ProjectId == rhs.ProjectId &&
               NAddr::IsSame(*NameserverAddress, *rhs.NameserverAddress);
    }

    // Serialized layout: ProjectId, address length, raw address bytes, Hash.
    void Save(IOutputStream* s) const {
        ::Save(s, ProjectId);
        ::Save(s, NameserverAddress->Len());
        ::SaveArray(s, reinterpret_cast<const ui8*>(NameserverAddress->Addr()), static_cast<size_t>(NameserverAddress->Len()));
        ::Save(s, Hash);
    }

    // Inverse of Save().
    // Throws yexception if the stored address length does not fit into a sockaddr_storage.
    void Load(IInputStream* s) {
        ::Load(s, ProjectId);
        socklen_t addrLen = 0;
        ::Load(s, addrLen);
        const size_t addrSize = static_cast<size_t>(addrLen);
        // The length comes from the stream. Without this check a corrupt or hostile
        // value would overflow the stack buffer below.
        Y_ENSURE(addrSize <= sizeof(sockaddr_storage),
                 "TCacheKey::Load: nameserver address length " << addrSize
                 << " exceeds sizeof(sockaddr_storage)");
        // Zero-initialized so that a short length never leaves sa_family (or any
        // other field read later) indeterminate.
        sockaddr_storage addr{};
        ::LoadArray(s, reinterpret_cast<ui8*>(&addr), addrSize);
        NameserverAddress = MakeAtomicShared<NAddr::TOpaqueAddr>(reinterpret_cast<sockaddr*>(&addr));
        ::Load(s, Hash);
    }
};

// Per-key accounting data: the nameserver's name, how many requests were
// counted, and the timestamp stored with the entry.
struct TCacheValue {
    TString NameserverName;
    ui64 Requests{0};
    TInstant LastTimestamp{};

    TCacheValue() = default;

    TCacheValue(TString nameserverName, ui64 requests, TInstant timestamp)
        : NameserverName{std::move(nameserverName)}
        , Requests{requests}
        , LastTimestamp{timestamp}
    {
    }

    // Convenience overload: `timestamp` is a time_t interpreted as whole seconds.
    TCacheValue(TString nameserverName, ui64 requests, time_t timestamp)
        : TCacheValue(std::move(nameserverName), requests, TInstant::Seconds(static_cast<ui64>(timestamp)))
    {
    }

    // Field order here defines the serialized layout; keep it stable.
    Y_SAVELOAD_DEFINE(NameserverName, Requests, LastTimestamp);
};

// LRU-evicting map from (project, nameserver) to its accounting data.
using TCache = TLRUCache<TCacheKey, TCacheValue>;

} // anonymous namespace
} // namespace NUnbound

// Hash for TCacheKey: reuses the value the key cached when it was constructed
// (or restored from a stream), so no rehashing of the address happens here.
template <>
struct THash<NUnbound::TCacheKey> {
    size_t operator()(const NUnbound::TCacheKey& key) const noexcept {
        const size_t cached = key.Hash;
        return cached;
    }
};

////////////////////////////////////////////////////////////////////////////////

namespace NUnbound {

////////////////////////////////////////////////////////////////////////////////

// Implementation of TProjectOutboundLimiter: counts outbound requests per
// (project, nameserver) pair in a bounded LRU cache.
class TProjectOutboundLimiter::TImpl {
public:
    explicit TImpl(TProjectOutboundLimiterConfig config)
        : Config_(std::move(config))
        , Cache_(Config_.MaxCacheSize)
    {
    }

    // Accounts one outbound request from the client's project to `nameserverAddress`.
    //
    // Requests without a client subnet address, or whose subnet does not map to a
    // project id, are ignored. Every path returns OK: this code only collects
    // statistics and never rejects a request.
    TProjectOutboundLimiter::EStatus OnRequest(
        NAddr::IRemoteAddrPtr clientSubnetAddress,
        NAddr::IRemoteAddrRef nameserverAddress,
        TString nameserverName,
        time_t now
    ) {
        if (!clientSubnetAddress) {
            return TProjectOutboundLimiter::OK;
        }

        TMaybe<TProjectId> projectId = TProjectId::FromClientSubnetAddress(*clientSubnetAddress);
        if (!projectId.Defined()) {
            return TProjectOutboundLimiter::OK;
        }

        TCacheKey key{std::move(*projectId), std::move(nameserverAddress)};

        // Look up first and check against End() instead of inferring presence from a
        // failed Insert(). Insert() may also report failure when the freshly inserted
        // item is evicted immediately (e.g. a zero-sized cache). Find() would then
        // return End(), and dereferencing it is undefined behaviour.
        // This also avoids copying the key and building a value on the common
        // "already cached" path.
        if (TCache::TIterator it = Cache_.Find(key); it != Cache_.End()) {
            ++it->Requests;
            it->LastTimestamp = TInstant::Seconds(static_cast<ui64>(now));
        } else {
            Cache_.Insert(key, TCacheValue{std::move(nameserverName), 1, now});
        }

        return TProjectOutboundLimiter::OK;
    }

    // Adds every entry currently held in the cache to `stats`.
    void AddStats(
        TProjectOutboundStats& stats
    ) {
        for (auto it = Cache_.Begin(); it != Cache_.End(); ++it) {
            const TCacheKey& key = it.Key();
            const TCacheValue& value = it.Value();
            stats.Add(
                key.ProjectId,
                key.NameserverAddress,
                value.NameserverName,
                value.Requests,
                value.LastTimestamp
            );
        }
    }

private:
    TProjectOutboundLimiterConfig Config_;
    TCache Cache_;
};

////////////////////////////////////////////////////////////////////////////////

// Public TProjectOutboundLimiter API: thin forwarders to TImpl (pimpl).

TProjectOutboundLimiter::TProjectOutboundLimiter(TProjectOutboundLimiterConfig config)
    : Impl_(MakeHolder<TImpl>(std::move(config)))
{
}

// Defaulted here, after TImpl is fully defined, as is usual for a pimpl holder.
TProjectOutboundLimiter::~TProjectOutboundLimiter() = default;

TProjectOutboundLimiter::EStatus TProjectOutboundLimiter::OnRequest(
    NAddr::IRemoteAddrPtr clientSubnetAddress,
    NAddr::IRemoteAddrRef nameserverAddress,
    TString nameserverName,
    time_t now
) {
    TImpl& impl = *Impl_;
    return impl.OnRequest(std::move(clientSubnetAddress), std::move(nameserverAddress), std::move(nameserverName), now);
}

void TProjectOutboundLimiter::AddStats(
    TProjectOutboundStats& stats
) {
    TImpl& impl = *Impl_;
    impl.AddStats(stats);
}

////////////////////////////////////////////////////////////////////////////////

// Implementation of TProjectOutboundStats: an aggregated, serializable map from
// (project, nameserver) to request counters, plus the smallest timestamp seen.
class TProjectOutboundStats::TImpl {
public:
    TImpl() = default;

    // Adds `requestsNumber` requests for the (projectId, nameserverAddress) pair.
    // A new entry stores `nameserverName`. An existing entry keeps its original name
    // and is merged via MergeEntry().
    void Add(
        TProjectId projectId,
        NAddr::IRemoteAddrRef nameserverAddress,
        TString nameserverName,
        ui64 requestsNumber,
        TInstant lastTimestamp
    ) {
        TCacheKey key{std::move(projectId), std::move(nameserverAddress)};
        if (auto [it, inserted] = Stats_.try_emplace(
                std::move(key),
                std::move(nameserverName),
                requestsNumber,
                lastTimestamp);
            !inserted)
        {
            MergeEntry(it->second, requestsNumber, lastTimestamp);
        }
        UpdateLowestTimestamp(lastTimestamp);
    }

    // Merges all entries of `other` into this object.
    void Add(const TImpl& other) {
        for (const auto& [key, value] : other.Stats_) {
            if (auto [it, inserted] = Stats_.try_emplace(key, value); !inserted) {
                MergeEntry(it->second, value.Requests, value.LastTimestamp);
            }
        }
        UpdateLowestTimestamp(other.LowestTimestamp_);
    }

    // Calls `func` for each entry. Returns false as soon as `func` returns false
    // (stopping the iteration). Returns true if every entry was visited.
    bool Iterate(TProjectOutboundStats::TIterateCallbackFunc&& func) const {
        for (const auto& [key, value] : Stats_) {
            if (!func(key.ProjectId,
                      key.NameserverAddress,
                      value.NameserverName,
                      value.Requests))
            {
                return false;
            }
        }
        return true;
    }

    // Smallest timestamp passed to Add() (directly or via a merged TImpl);
    // TInstant::Max() if nothing has been added yet.
    TInstant LowestTimestamp() const noexcept {
        return LowestTimestamp_;
    }

    size_t EntriesNumber() const noexcept {
        return Stats_.size();
    }

    Y_SAVELOAD_DEFINE(Stats_, LowestTimestamp_);

private:
    // Folds another observation of the same key into `entry`.
    // NOTE(review): this keeps the *smaller* timestamp, although the field is named
    // LastTimestamp and the limiter stores the most recent request time in it.
    // Confirm whether max was intended. It is left unchanged because the value is
    // part of the serialized state.
    static void MergeEntry(TCacheValue& entry, ui64 requests, TInstant timestamp) {
        entry.Requests += requests;
        if (entry.LastTimestamp > timestamp) {
            entry.LastTimestamp = timestamp;
        }
    }

    void UpdateLowestTimestamp(TInstant timestamp) noexcept {
        if (LowestTimestamp_ > timestamp) {
            LowestTimestamp_ = timestamp;
        }
    }

    THashMap<TCacheKey, TCacheValue> Stats_;
    TInstant LowestTimestamp_ = TInstant::Max();
};

////////////////////////////////////////////////////////////////////////////////

TProjectOutboundStats::TProjectOutboundStats()
    : Impl_(MakeHolder<TImpl>())
{
}

TProjectOutboundStats::~TProjectOutboundStats() = default;

void TProjectOutboundStats::Add(
    TProjectId projectId,
    NAddr::IRemoteAddrRef nameserverAddress,
    TString nameserverName,
    ui64 requestsNumber,
    TInstant lastTimestamp
) {
    Impl_->Add(
        std::move(projectId),
        nameserverAddress,
        std::move(nameserverName),
        requestsNumber,
        lastTimestamp);
}

void TProjectOutboundStats::Add(
    const TProjectOutboundStats& other
) {
    Impl_->Add(*other.Impl_);
}

bool TProjectOutboundStats::Iterate(TProjectOutboundStats::TIterateCallbackFunc&& func) const {
    return Impl_->Iterate(std::move(func));
}

TInstant TProjectOutboundStats::LowestTimestamp() const noexcept {
    return Impl_->LowestTimestamp();
}

size_t TProjectOutboundStats::EntriesNumber() const noexcept {
    return Impl_->EntriesNumber();
}

void TProjectOutboundStats::Save(IOutputStream* s) const {
    ::SaveMany(s, *Impl_);
}

void TProjectOutboundStats::Load(IInputStream* s) {
    ::LoadMany(s, *Impl_);
}

////////////////////////////////////////////////////////////////////////////////

} // namespace NUnbound
