#pragma once

#include "ipacl_utils.h"

#include <passport/infra/libs/cpp/utils/ipaddr.h>
#include <passport/infra/libs/cpp/utils/log/global.h>

#include <util/string/cast.h>

#include <cstdlib>
#include <ctime>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <unordered_map>
#include <utility>

namespace NPassport::NBb {
    //
    // Container, associating IP or IP range with a peer, NOT thread-safe!
    //
    template <typename T>
    class TIpAclMap {
    public:
        TIpAclMap() {
            TrypoNets.reserve(1000);
            SingleIpMap.reserve(1000);
        }

        ~TIpAclMap() = default;

        enum class EStorage {
            Range,
            Trypo,
            Single
        };

        // Return a Peer or nullptr
        std::shared_ptr<T> Find(const NUtils::TIpAddr& ip, EStorage* storage = nullptr) const {
            auto it = SingleIpMap.find(ip);
            if (it != SingleIpMap.end()) {
                if (storage) {
                    *storage = EStorage::Single;
                }
                return it->second;
            }

            // try to find as trypo net by project id
            std::shared_ptr<T> value = FindValueInTrypo(ip);
            if (value) {
                if (storage) {
                    *storage = EStorage::Trypo;
                }
                return value;
            }

            // if we didn't succeed, look in old style network-based ip range map
            if (storage) {
                *storage = EStorage::Range;
            }
            return FindValueInMap(RangeMap, ip);
        }

        void ParseAndAddEntry(TStringBuf curip, std::shared_ptr<T> value) {
            // Strip off whitespace
            TIpAclUtils::SkipSpace(curip);
            TIpAclUtils::SkipNonSpace(curip);

            if (!curip) {
                return;
            }

            NUtils::TIpAddr addr1;
            NUtils::TIpAddr addr2;

            const TString curipStr(curip);
            TStringBuf::size_type at = curip.find('@');

            if (at != TString::npos) {
                if (at == curip.size() - 1) {
                    throw yexception() << TIpAclUtils::Errmsg(curip);
                }
                TStringBuf projectId = curip.substr(0, at);
                if (ParseIpRange(curipStr.substr(at + 1), &addr1, &addr2)) {
                    AddTrypoNet(projectId, addr1, addr2, value);
                } else {
                    throw yexception() << TIpAclUtils::Errmsg(curip);
                }
            } else if (IsSingleIp(curip)) {
                if (!addr1.Parse(curipStr)) {
                    throw yexception() << TIpAclUtils::Errmsg(curip);
                }
                SingleIpMap.insert({addr1, std::move(value)});
            } else if (ParseIpRange(curipStr, &addr1, &addr2)) {
                AddValueRange(RangeMap, addr1, addr2, value);
            } else {
                throw yexception() << TIpAclUtils::Errmsg(curip);
            }
        }

        using TValuePrintFunc = std::function<TString(const T&)>;

        void Print(const TString& header, const TString& entry, TValuePrintFunc printFunc) const {
            TLog::Info("%s (%lu)", header.c_str(), RangeMap.size());

            // Print out resulting ACL
            PrintMap(RangeMap, entry, printFunc);

            if (!TrypoNets.empty()) {
                TLog::Info("TRYPO Nets (%lu):", TrypoNets.size());

                for (const auto& [projecId, ipRange] : TrypoNets) {
                    TLog::Info("Project Id: 0x%X", projecId);
                    PrintMap(ipRange, entry, printFunc);
                }
            }

            if (!SingleIpMap.empty()) {
                TLog::Info("Standalone ip entities (%lu):", SingleIpMap.size());

                for (const auto& [ip, value] : SingleIpMap) {
                    TStringStream fmt;
                    fmt << entry << " single IP=" << ip;
                    if (printFunc) {
                        fmt << printFunc(*value);
                    }
                    TLog::Info("%s", fmt.Str().c_str());
                }
            }
        }

        time_t GetMTime() const {
            return Mtime_;
        }

        void SetMTime(time_t time) {
            Mtime_ = time;
        }

    protected:
        struct TAclEntry {
            bool RangeStart;
            std::shared_ptr<NUtils::TIpAddr> Size;
            std::shared_ptr<T> Value;

            TAclEntry(bool start,
                      std::shared_ptr<NUtils::TIpAddr> sz,
                      std::shared_ptr<T> p)
                : RangeStart(start)
                , Size(sz)
                , Value(p)
            {
            }

            static TAclEntry MakeStart(std::shared_ptr<NUtils::TIpAddr> size, std::shared_ptr<T> value) {
                return TAclEntry(true, size, value);
            }

            static TAclEntry MakeStart(const TAclEntry& orig) {
                return TAclEntry(true, orig.Size, orig.Value);
            }

            static TAclEntry MakeEnd(std::shared_ptr<NUtils::TIpAddr> size, std::shared_ptr<T> value) {
                return TAclEntry(false, size, value);
            }

            static TAclEntry MakeEnd(const TAclEntry& orig) {
                return TAclEntry(false, orig.Size, orig.Value);
            }
        };

        using TIpRangeMap = std::multimap<NUtils::TIpAddr, TAclEntry>;
        using TProjectIdTable = std::unordered_map<ui32, TIpRangeMap>;
        using TIpPeerMap = std::unordered_map<NUtils::TIpAddr, std::shared_ptr<T>>;

        void PrintMap(const TIpRangeMap& rangeMap, const TString& entry, TValuePrintFunc printFunc) const {
            for (typename TIpRangeMap::const_iterator it = rangeMap.begin(); it != rangeMap.end(); ++it) {
                NUtils::TIpAddr ipstart = it->first;
                ++it;
                NUtils::TIpAddr ipend = it->first;

                TStringStream fmt;
                fmt << entry << " IP=" << ipstart;
                if (ipstart != ipend) {
                    fmt << "-" << ipend;
                }

                if (printFunc) {
                    fmt << printFunc(*it->second.Value);
                }

                TLog::Info("%s", fmt.Str().c_str());
            }
        }

        std::shared_ptr<T> FindValueInMap(const TIpRangeMap& rangeMap, const NUtils::TIpAddr& ip) const {
            typename TIpRangeMap::const_iterator pos = rangeMap.lower_bound(ip);
            if (pos == rangeMap.end()) {
                return std::shared_ptr<T>();
            }
            if (pos->first != ip && pos->second.RangeStart) {
                return std::shared_ptr<T>();
            }
            return pos->second.Value;
        }
        // parse ip or ip range, return true on success
        static bool ParseIpRange(const TString& iprange, NUtils::TIpAddr* addr1, NUtils::TIpAddr* addr2) {
            // Each IP-range is either a base-address/number-of-fixed-high-bits
            // or a base-address/subnet-like-mask. The part starting from the '/'
            // may not be present, designating a single address.
            //
            // Try parse it and convert to the address pair delimiting the range.
            //
            NUtils::TIpAddr tmpAddr;
            TString::size_type slash = iprange.find('/');
            TString::size_type dash = iprange.find('-');

            // take first address
            if (!tmpAddr.Parse(iprange, 0, (slash != TString::npos) ? slash : dash)) {
                return false;
            }

            // Determine mask
            if (slash != TString::npos) {
                if (slash == iprange.size() - 1) {
                    return false;
                }

                char* endp;
                unsigned long numbits = strtol(iprange.c_str() + slash + 1, &endp, 10);
                if (*endp || numbits > (tmpAddr.IsIpv4() ? 32 : 128)) {
                    return false;
                }

                *addr1 = tmpAddr.GetRangeStart(numbits);
                *addr2 = tmpAddr.GetRangeEnd(numbits);
            } else if (dash != TString::npos) {
                if (dash == iprange.size() - 1) {
                    return false;
                }

                *addr1 = tmpAddr;

                if (!tmpAddr.Parse(iprange, dash + 1, TString::npos)) {
                    return false;
                }

                *addr2 = tmpAddr;
            } else {
                *addr2 = *addr1 = tmpAddr;
            }

            return true;
        }

        // For every range we insert two entries into the multimap:
        // (range-low, (true, size, ptr)) and (range-high, (false, size, ptr)) where in each pair ptr
        // is the pointer to the data structure associated with this particular
        // range and the first field indicates whether this entry is the first one in
        // the pair. For a single-IP range we insert two entries with identical
        // keys (i.e. range-low=range-high).
        //
        // Whenever we lookup a would-be insertion point for an IP, we determine
        // whether it falls within or outside of an existing range by testing the
        // boolean value conmtained by the entry -- true designates first entry in
        // a pair.
        void AddValueRange(TIpRangeMap& rangeMap,
                           NUtils::TIpAddr low,
                           NUtils::TIpAddr high,
                           std::shared_ptr<T> value) {
            // Just a sanity check -- shall be assert() normally
            if (high < low) {
                throw yexception() << "Internal error: bad input to IpAcl::addValueRange()";
            }

            std::shared_ptr<NUtils::TIpAddr> rangeSize = std::make_shared<NUtils::TIpAddr>(high);
            *rangeSize -= low;

            // Finds out where begining of this new range belongs in
            // the current map. If it turns out to be the end - simply
            // add the new range there and return.
            typename TIpAclMap::TIpRangeMap::iterator lower = rangeMap.lower_bound(low);
            if (lower == rangeMap.end()) {
                lower = rangeMap.insert(lower, std::make_pair(high, TAclEntry::MakeEnd(rangeSize, value)));
                rangeMap.insert(lower, std::make_pair(low, TAclEntry::MakeStart(rangeSize, value)));
                return;
            }

            // New range interferes with existing ones, so we need to locate
            // the overlapping and then handle it.
            typename TIpAclMap<T>::TIpRangeMap::iterator upper = rangeMap.upper_bound(high);

            // if this range, has 0 or 1 common point with existing bounds
            // i.e. either standalone, fits completely inside another range or has 1 common bound
            if (std::distance(lower, upper) < 2) {
                // This is a standalone range or a range falling within another range, wider range.

                TAclEntry currentEntry = lower->second;
                if (upper == lower) {
                    // Even if this range lies within an outer range, the two ranges
                    // do not share neither top nor bottom.
                    // NB: if the range lies entirely in some other one it is surely more specific
                    if (!currentEntry.RangeStart) {
                        // The entry we found is the second in the pair
                        // so we're entirely within an outer range. Create high
                        // and low sub-ranges of the outer range (the middle
                        // one will be created later). Maintain lower as the
                        // hint for insertion of the middle sub-range.
                        // new lower value for higher part of existing range, keep original range size
                        lower = rangeMap.insert(lower, std::make_pair(high.Next(), TAclEntry::MakeStart(currentEntry)));
                        // new higher value for lower part of existing range
                        rangeMap.insert(lower, std::make_pair(low.Prev(), TAclEntry::MakeEnd(currentEntry)));

                        // do not report that range overriden to mimimize log noise
                        // TLog::Error("%s", rangeMessage(low, high, p->getName(), peer->getName(), true).c_str());
                    }
                    // else it is a standalone range, just added below
                } else {
                    NUtils::TIpAddr oldLow;
                    NUtils::TIpAddr oldHigh;
                    // New range aligns with either top or bottom of an
                    // existing range, so we have to narrow this existing
                    // range to make space for the new range. Maintain lower
                    // as the hint for insertion of the middle sub-range.
                    if (currentEntry.RangeStart) {
                        // aligns with another range start
                        oldLow = lower->first;
                        oldHigh = high;
                        // need to check which range is more specific
                        if (*rangeSize < *currentEntry.Size) {
                            // new entry more specific, overrides part of the existing
                            rangeMap.erase(lower);
                            lower = rangeMap.insert(upper, std::make_pair(high.Next(), TAclEntry::MakeStart(currentEntry)));
                        } else {
                            // existing entry is smaller, need to shrink current one
                            high = oldLow.Prev();
                        }
                    } else {
                        // aligns with another range end
                        oldLow = low;
                        oldHigh = lower->first;
                        // need to check which range is more specific
                        if (*rangeSize < *currentEntry.Size) {
                            // new entry more specific, overrides part of the existing
                            rangeMap.erase(lower);
                            lower = rangeMap.insert(upper, std::make_pair(low.Prev(), TAclEntry::MakeEnd(currentEntry)));
                        } else {
                            // existing entry is smaller, need to shrink current one
                            low = oldHigh.Next();
                        }
                        lower = upper;
                    }

                    // if ranges were of equal size, report the range conflict
                    if (*rangeSize == *currentEntry.Size) {
                        TLog::Debug("%s", TIpAclUtils::RangeMessage(oldLow, oldHigh, currentEntry.Value->GetName(), value->GetName(), false).c_str());
                    }
                }

                // Finally, create new range. Under all circumstances we've
                // maintained lower as the insertion hint.
                if (low <= high) {
                    lower = rangeMap.insert(lower, std::make_pair(high, TAclEntry::MakeEnd(rangeSize, value)));
                    rangeMap.insert(lower, std::make_pair(low, TAclEntry::MakeStart(rangeSize, value)));
                }
            } else {
                // This new range actually covers one or more existing
                // ranges. We need to fill gaps. Do so by advancing low
                // towards high until lower becomes equal to upper and
                // skip over existing ranges along the way.

                // skip the beginning if it fits inside another range
                if (!lower->second.RangeStart) {
                    TAclEntry currentEntry = lower->second;
                    if (*rangeSize < *lower->second.Size) {
                        // need to override this part since new one is more specific
                        typename TIpRangeMap::iterator next = lower;
                        ++next;
                        rangeMap.insert(lower, std::make_pair(low.Prev(), TAclEntry::MakeEnd(currentEntry)));
                        rangeMap.erase(lower);
                        lower = next;
                    } else {
                        if (*rangeSize == *lower->second.Size) {
                            // report equal range size conflict
                            TLog::Debug("%s", TIpAclUtils::RangeMessage(low, lower->first, lower->second.Value->GetName(), value->GetName(), false).c_str());
                        }
                        low = lower->first.Next();
                        ++lower;
                    }
                }

                while (lower != upper) {
                    // lower->first is the beginning of the range
                    // if 'low' address before the beginning, add range part
                    if (low < lower->first) {
                        rangeMap.insert(lower, std::make_pair(low, TAclEntry::MakeStart(rangeSize, value)));
                        rangeMap.insert(lower, std::make_pair(lower->first.Prev(), TAclEntry::MakeEnd(rangeSize, value)));
                    }

                    NUtils::TIpAddr oldLow = lower->first;

                    if (*rangeSize < *lower->second.Size) {
                        // we have range fragment from bigger range and need to overwrite it
                        lower->second.Size = rangeSize;
                        lower->second.Value = value;
                        ++lower;
                        if (lower == upper) {
                            rangeMap.insert(lower, std::make_pair(high.Next(), TAclEntry::MakeStart(upper->second.Size, upper->second.Value)));
                            rangeMap.insert(lower, std::make_pair(high, TAclEntry::MakeEnd(rangeSize, value)));
                        } else {
                            lower->second.Size = rangeSize;
                            lower->second.Value = value;
                        }
                    } else {
                        if (*rangeSize == *lower->second.Size) {
                            // report equal range size conflict
                            TLog::Debug("%s", TIpAclUtils::RangeMessage(oldLow, lower->first, lower->second.Value->GetName(), value->GetName(), false).c_str());
                        }
                        ++lower;
                    }

                    low = lower->first.Next();
                    // now 'lower' is at the range end, and 'low' is the next address
                    // Caution: it is possible that we already skipped 'high' address

                    if (lower == upper) {
                        break;
                    }

                    ++lower;
                }

                if (low <= high) {
                    rangeMap.insert(lower, std::make_pair(low, TAclEntry::MakeStart(rangeSize, value)));
                    rangeMap.insert(lower, std::make_pair(high, TAclEntry::MakeEnd(rangeSize, value)));
                }
            }
        }

        void AddTrypoNet(TStringBuf strid,
                         NUtils::TIpAddr low,
                         NUtils::TIpAddr high,
                         std::shared_ptr<T> value) {
            ui32 id = 0;
            if (!TryIntFromString<16>(strid, id)) {
                throw yexception() << "bad project id: " << strid;
            }

            if (TrypoNets.find(id) != TrypoNets.end()) { // duplicate entry
                TLog::Debug("TRYPO project id 0x%X already present, rule for peer %s is ignored!", id, value->GetName().c_str());
                return;
            }
            TIpRangeMap rangeMap;
            AddValueRange(rangeMap, low, high, value);
            TrypoNets.emplace(id, std::move(rangeMap));
        }

    public: // For tests
        TIpRangeMap RangeMap;
        TProjectIdTable TrypoNets;
        TIpPeerMap SingleIpMap;

    private:
        std::shared_ptr<T> FindValueInTrypo(const NUtils::TIpAddr& ip) const {
            ui32 id = ip.ProjectId();
            if (id) {
                typename TProjectIdTable::const_iterator p = TrypoNets.find(id);
                if (p != TrypoNets.end()) {
                    return FindValueInMap(p->second, ip);
                }
            }
            return {};
        }
        static bool IsSingleIp(TStringBuf ip) {
            return !ip.Contains('-') && !ip.Contains('/');
        }

    private:
        time_t Mtime_ = 0;
    };
}
