#pragma once

#include <balancer/kernel/helpers/errors.h>
#include <library/cpp/config/sax.h>
#include <util/network/address.h>

namespace NSrvKernel {

    // An immutable set of closed ranges [first, last] over T, built once and
    // then queried for membership with a binary search.
    //
    // On construction the ranges are validated, sorted, and overlapping or
    // adjacent ranges are merged. Construction throws
    // NConfig::TConfigParseError if any input range has last < first.
    //
    // T should have the following:
    // 1. operator<  (a strict total order)
    // 2. operator<= (consistent with operator<)
    // 3. operator=  (copy assignment)
    // 4. operator+ with an integer (an overflow should be well defined: x + 1 < x on overflow)
    template <class T>
    class FixedRangeSet {
    public:
        // A closed range: both `first` and `last` belong to it.
        struct Range {
            T first;
            T last;

            // Lexicographic order: by `first`, then by `last`.
            // Written with operator< only, so T needs no operator==.
            bool operator<(const Range& y) const noexcept {
                if (first < y.first) {
                    return true;
                }
                if (y.first < first) {
                    return false;
                }
                return last < y.last;
            }
        };

        FixedRangeSet() = default;
        FixedRangeSet(const FixedRangeSet&) = default;
        FixedRangeSet(FixedRangeSet&&) = default;
        FixedRangeSet& operator=(const FixedRangeSet&) = default;
        FixedRangeSet& operator=(FixedRangeSet&&) = default;

        // Builds the set from `ranges` (any order, may overlap).
        // Throws NConfig::TConfigParseError if some range has last < first.
        FixedRangeSet(const TVector<Range>& ranges) : Ranges_(ranges) {
            Init();
        }

        // Same as above, taking ownership of `ranges`.
        FixedRangeSet(TVector<Range>&& ranges) : Ranges_(std::move(ranges)) {
            Init();
        }

        // Returns true if `value` lies inside one of the ranges (ends inclusive).
        [[nodiscard]] bool Contains(const T& value) const noexcept {
            // After Init() the ranges are disjoint and sorted, so `last` is
            // increasing and the predicate below partitions the vector as
            // lower_bound requires: first range whose last >= value.
            auto it = std::lower_bound(Ranges_.begin(), Ranges_.end(), value,
            [] (const Range& r, const T& x) {
                return r.last < x;
            });
            return it != Ranges_.end() && it->first <= value;
        }

    private:

        // Validates every range, then sorts and merges them.
        void Init() {
            // Check all ranges, including the one that ends up first after sorting.
            for (const Range& range : Ranges_) {
                if (range.last < range.first) {
                    ythrow NConfig::TConfigParseError() << __FILE__ << ": got incorrect ip range";
                }
            }
            if (Ranges_.empty()) {
                return;
            }
            std::sort(Ranges_.begin(), Ranges_.end());
            MergeOverlappingRanges();
        }

        // Compacts Ranges_ in place so that no two ranges overlap or touch.
        // Precondition: Ranges_ is non-empty, sorted, and every range is valid.
        void MergeOverlappingRanges() {
            size_t write = 0;
            for (size_t read = 1; read < Ranges_.size(); ++read) {
                Range& prev = Ranges_[write];
                const Range& cur = Ranges_[read];
                // prev.last + 1 may wrap around at the top of T's domain; std::max
                // then falls back to prev.last, so the overlap/adjacency test stays correct.
                const T afterPrev = prev.last + 1;
                if (cur.first <= std::max(prev.last, afterPrev)) {
                    prev.last = std::max(prev.last, cur.last);
                } else {
                    ++write;
                    if (write != read) {
                        Ranges_[write] = cur;
                    }
                }
            }
            // erase() instead of resize(): shrinking must not require a default-constructible T.
            Ranges_.erase(Ranges_.begin() + write + 1, Ranges_.end());
            Ranges_.shrink_to_fit();
        }

        // Sorted, pairwise disjoint and non-adjacent ranges.
        TVector<Range> Ranges_;
    };


    class TIpMatcher  {
    public:
        struct TIpv6Host {
            union {
                struct {
                    ui64 Hi;
                    ui64 Low;
                };
                unsigned char Addr[16];
            };

            TIpv6Host(unsigned char addr[16]) noexcept {
                memcpy(Addr, addr, 16);
            }

            TIpv6Host(ui64 hi, ui64 low) noexcept
                : Hi(hi)
                , Low(low)
            {}

            TIpv6Host() = default;

            TIpv6Host operator&(const TIpv6Host& s) const noexcept {
                return {Hi & s.Hi, Low & s.Low};
            }

            bool operator==(const TIpv6Host& s) const noexcept {
                return Hi == s.Hi && Low == s.Low;
            }

            bool operator<(const TIpv6Host& s) const noexcept {
                return std::make_pair(Hi, Low) < std::make_pair(s.Hi, s.Low);
            }

            bool operator<=(const TIpv6Host& s) const noexcept {
                return std::make_pair(Hi, Low) <= std::make_pair(s.Hi, s.Low);
            }

            TIpv6Host InetToHost() const noexcept {
                return {::InetToHost(Hi), ::InetToHost(Low)};
            }

            TIpv6Host operator+(ui64 x) const noexcept {
                if (Low + x < Low) {
                return {Hi + 1, Low + x};
                } else {
                return {Hi, Low + x};
                }
            }

            TIpv6Host operator-(ui64 x) const noexcept {
                if (Low - x > Low) {
                return {Hi - 1, Low - x};
                } else {
                return {Hi, Low - x};
                }
            }

            TIpv6Host operator+(const TIpv6Host& x) const noexcept {
                if (Low + x.Low < Low) {
                return {Hi + x.Hi + 1, Low + x.Low};
                } else {
                return {Hi + x.Hi, Low + x.Low};
                }
            }

        };

        template <typename T>
        struct TMask {
            template <class T1, class T2>
            TMask(T1&& base, T2&& subnet)
                : Base(std::forward<T1>(base))
                , SubNet(std::forward<T2>(subnet))
            {}

        public:
            T Base;
            T SubNet;
        };

        // Takes comma separated net ranges in two formats:
        // 1) a.b.c.d/m       (ipv4 cidr)
        // 2) a.b.c.d-u.v.r.s (ipv4 range)
        // and similar for ipv6
        TIpMatcher(const TString& sources);

        [[nodiscard]] bool Match(const NAddr::IRemoteAddr& addr) const noexcept;

    private:
        FixedRangeSet<TIpHost> IpSet_v4_;
        FixedRangeSet<TIpv6Host> IpSet_v6_;
    };
}
