#include "ip_matcher.h"

#include <library/cpp/config/sax.h>
#include <util/string/split.h>

using namespace NConfig;
using namespace NSrvKernel;

// Inclusive [first, last] address ranges with both ends in host byte order.
// They are collected while parsing the source list and then handed to the
// per-family FixedRangeSet used for matching.
using Ipv4Range = FixedRangeSet<TIpHost>::Range;
using Ipv6Range = FixedRangeSet<TIpMatcher::TIpv6Host>::Range;

namespace {

    TIpMatcher::TIpv6Host Ipv6FromString(const char* str) {
        struct in6_addr addr;

        if (inet_pton(AF_INET6, str, &addr) != 1) {
            ythrow TSystemError() << "Failed to convert (" <<  str << ") to ipv6 address";
        }

        return TIpMatcher::TIpv6Host(addr.s6_addr);
    }

    void parseCidr(const TVector<TString>& cidrTokens,
                   TVector<Ipv4Range>& ipv4Ranges,
                   TVector<Ipv6Range>& ipv6Ranges) {
        if (cidrTokens.size() == 1) {
            if (cidrTokens[0].find(':') == TString::npos) {
                auto firstIp = InetToHost(IpFromString(cidrTokens[0].c_str()));
                ipv4Ranges.emplace_back(Ipv4Range{firstIp, firstIp});
            } else {
                auto firstIp = Ipv6FromString(cidrTokens[0].c_str()).InetToHost();
                ipv6Ranges.emplace_back(Ipv6Range{firstIp, firstIp});
            }
        } else if (cidrTokens.size() == 2) {
            if (cidrTokens[0].find(':') == TString::npos) {
                const unsigned maskBits = FromString<unsigned>(cidrTokens[1]);

                if (maskBits > 32) {
                    ythrow TConfigParseError() << "incorrect mask value " << maskBits;
                }

                auto firstIp = InetToHost(IpFromString(cidrTokens[0].c_str()));
                auto maxHost = (maskBits ? TIpHost(1) << (32 - maskBits) : TIpHost(0)) - 1;
                auto lastIp = firstIp + maxHost;
                ipv4Ranges.emplace_back(Ipv4Range{firstIp, lastIp});
            } else {
                const unsigned maskBits = FromString<unsigned>(cidrTokens[1]);

                if (maskBits > 128) {
                    ythrow TConfigParseError() << "incorrect mask value " << maskBits;
                }

                auto firstIp = Ipv6FromString(cidrTokens[0].c_str()).InetToHost();
                if (!maskBits) { // We need to avoid overflow: left shift by big value is UB
                    ipv6Ranges.emplace_back(Ipv6Range{firstIp, firstIp - 1});
                } else {
                    auto maxHostHi = maskBits <= 64 ? ui64(1) << (64 - maskBits) : 0;
                    auto maxHostLow = maskBits <= 64 ? 0 : ui64(1) << (128 - maskBits);
                    auto maxHost = TIpMatcher::TIpv6Host(maxHostHi, maxHostLow) - 1;
                    auto lastIp = firstIp + maxHost;
                    ipv6Ranges.emplace_back(Ipv6Range{firstIp, lastIp});
                }
            }
        } else {
            ythrow TConfigParseError() << "incorrect masks list";
        }
    }

    void parseRange(const TVector<TString>& rangeTokens,
                    TVector<Ipv4Range>& ipv4Ranges,
                    TVector<Ipv6Range>& ipv6Ranges) {
        if (rangeTokens.size() == 2) {
            if (rangeTokens[0].find(':') == TString::npos) {
                auto firstIp = InetToHost(IpFromString(rangeTokens[0].c_str()));
                auto lastIp = InetToHost(IpFromString(rangeTokens[1].c_str()));
                Y_ENSURE(firstIp <= lastIp);
                ipv4Ranges.emplace_back(Ipv4Range{firstIp, lastIp});
            } else {
                auto firstIp = Ipv6FromString(rangeTokens[0].c_str()).InetToHost();
                auto lastIp = Ipv6FromString(rangeTokens[1].c_str()).InetToHost();
                Y_ENSURE(firstIp <= lastIp);
                ipv6Ranges.emplace_back(Ipv6Range{firstIp, lastIp});
            }
        } else {
            ythrow TConfigParseError() << "incorrect masks list";
        }
    }

} // namespace

// Builds the matcher from a comma-separated list of entries. Each entry is
// either an explicit range "first-last" or a CIDR/single address "addr[/len]".
// Entries are accumulated per address family and then stored as the
// IPv4 and IPv6 range sets. Parse errors from the helpers propagate.
TIpMatcher::TIpMatcher(const TString& sources) {
    TVector<Ipv4Range> v4;
    TVector<Ipv6Range> v6;

    for (const auto& part : StringSplitter(sources).Split(',').SkipEmpty()) {
        const auto item = part.Token();
        // Any '-' in the entry means an explicit range; otherwise it is CIDR.
        const bool isRange = item.find('-') != TString::npos;

        TVector<TString> tokens;
        StringSplitter(item).Split(isRange ? '-' : '/').SkipEmpty().Collect(&tokens);

        if (isRange) {
            parseRange(tokens, v4, v6);
        } else {
            parseCidr(tokens, v4, v6);
        }
    }

    IpSet_v4_ = FixedRangeSet<TIpHost>(std::move(v4));
    IpSet_v6_ = FixedRangeSet<TIpv6Host>(std::move(v6));
}

// Returns true if `addr` is an IPv4 or IPv6 address that falls inside one of
// the configured ranges. Addresses of any other family never match.
bool TIpMatcher::Match(const NAddr::IRemoteAddr& addr) const noexcept {
    const sockaddr* const raw = addr.Addr();
    const auto family = raw->sa_family;

    if (family == AF_INET) {
        const auto* const v4 = reinterpret_cast<const sockaddr_in*>(raw);
        const TIpHost netOrder = v4->sin_addr.s_addr;
        return IpSet_v4_.Contains(InetToHost(netOrder));
    }

    if (family == AF_INET6) {
        // The const_cast is kept because TIpv6Host is built from a mutable byte
        // pointer here (NOTE(review): presumably its constructor takes ui8* —
        // confirm against ip_matcher.h).
        auto* const v6 = const_cast<sockaddr_in6*>(reinterpret_cast<const sockaddr_in6*>(raw));
        const TIpv6Host netOrder = v6->sin6_addr.s6_addr;
        return IpSet_v6_.Contains(netOrder.InetToHost());
    }

    return false;
}
