#include "trie_sharding.h"

#include <saas/library/searchserver/exception.h>

#include <kernel/saas_trie/idl/saas_trie.pb.h>
#include <kernel/saas_trie/idl/trie_key.h>
#include <kernel/saas_trie/trie_url_mask_iterator.h>

#include <library/cpp/string_utils/scan/scan.h>

namespace NProxyMeta {
    // Per-source routing result, indexed like the sources vector: the first-realm keys
    // routed to that source, plus a bitmask of the key kinds (TrieNormalKeyBit /
    // TrieUrlMaskBit) whose shard lookup selected it.
    using TShardedKeys = TVector<std::pair<TDeque<TString>, ui32>>;

    // Bits of the key-kind mask kept in TShardedKeys.
    constexpr ui32 TrieNormalKeyBit = 1u << 0;
    constexpr ui32 TrieUrlMaskBit = 1u << 1;

    // Value of the "component" CGI parameter for which trie sharding is applied.
    const TStringBuf TrieComponentName("TRIE");

    // Returns the index within key.AllRealms of the realm named by key.KeyRealms(0).
    // Throws TSearchException(HTTP_BAD_REQUEST) when KeyRealms is empty or when no
    // realm in AllRealms carries that name.
    size_t GetFirstRealm(const NSaasTrie::TComplexKey& key) {
        Y_ENSURE_EX(key.KeyRealmsSize() > 0, TSearchException(HTTP_BAD_REQUEST) << "Trie key is invalid: zero key realms found");
        const auto& wantedName = key.GetKeyRealms(0);

        size_t position = 0;
        for (const auto& candidate : key.GetAllRealms()) {
            if (candidate.GetName() == wantedName) {
                return position;
            }
            ++position;
        }

        ythrow TSearchException(HTTP_BAD_REQUEST) << "Trie key is invalid: cannot find first realm";
    }

    // Serializes `key` (via NSaasTrie::SerializeToCgi, honoring `packed`) as if:
    //  - the realm at index `realm` of AllRealms held exactly `realmKeys`, and
    //  - MainKey were absent unless `keyTypes` has TrieNormalKeyBit, and
    //  - UrlMaskPrefix were absent unless `keyTypes` has TrieUrlMaskBit.
    // `key` is modified only temporarily: every change is undone before returning,
    // including when serialization throws (the exception is then rethrown).
    // The strings in `realmKeys` are moved into the serialized key; callers must not
    // rely on their contents afterwards.
    TString BuildNewKey(NSaasTrie::TComplexKey& key, size_t realm, TDeque<TString>& realmKeys, ui32 keyTypes, bool packed) {
        NSaasTrie::TRealm originalRealm;
        auto realmToReplace = key.MutableAllRealms(realm);
        // Park the original keys in a scratch realm. The scratch realm starts with no keys,
        // so after the swap realmToReplace's key list is empty and ready to be refilled.
        originalRealm.MutableKey()->Swap(realmToReplace->MutableKey());

        for (auto& realmKey : realmKeys) {
            *realmToReplace->AddKey() = std::move(realmKey);
        }

        // Hide MainKey / UrlMaskPrefix when the corresponding lookup kind did not route
        // any key to this shard.
        TMaybe<TString> originalMainKey;
        TMaybe<TString> originalUrlMaskPrefix;
        if (key.HasMainKey() && !(keyTypes & TrieNormalKeyBit)) {
            originalMainKey = std::move(*key.MutableMainKey());
            key.ClearMainKey();
        }
        if (key.HasUrlMaskPrefix() && !(keyTypes & TrieUrlMaskBit)) {
            originalUrlMaskPrefix = std::move(*key.MutableUrlMaskPrefix());
            key.ClearUrlMaskPrefix();
        }

        // Undoes every temporary modification made above.
        auto restore = [&]() {
            if (originalMainKey.Defined()) {
                key.SetMainKey(std::move(*originalMainKey));
            }
            if (originalUrlMaskPrefix.Defined()) {
                key.SetUrlMaskPrefix(std::move(*originalUrlMaskPrefix));
            }
            originalRealm.MutableKey()->Swap(realmToReplace->MutableKey());
        };

        TString newKey;
        try {
            newKey = NSaasTrie::SerializeToCgi(key, packed);
        } catch (...) {
            // Leave the caller's key intact even on failure, then propagate the error.
            restore();
            throw;
        }
        restore();
        return newKey;
    }

    struct TTrieShardIntervalEnumerator : NSaas::IShardIntervalCallback {
        TTrieShardIntervalEnumerator(const NSaas::TShardsDispatcher& shardDispatcher,
                                     const NSaas::TShardIntervals& shardIntervals,
                                     TShardedKeys& shardedKeys)
            : ShardDispatcher(shardDispatcher)
            , ShardIntervals(shardIntervals)
            , ShardedKeys(shardedKeys)
        {
        }

        void FindShard(const TString& originalKey, const TString& shardKey, ui64 kps, ui32 typeMask) {
            OriginalKey = &originalKey;
            TypeMask = typeMask;
            FoundSource = false;
            ShardDispatcher.EnumerateIntervals(shardKey, kps, ShardIntervals, *this);
            if (!FoundSource) {
                ythrow TSearchException(HTTP_INTERNAL_SERVER_ERROR) << "No shard found for key " << shardKey << ':' << kps;
            }
        }

        void OnShardInterval(size_t interval) override {
            auto& shardContext = ShardedKeys[interval];
            shardContext.second |= TypeMask;
            auto& keys = shardContext.first;
            if (keys.empty() || keys.back().c_str() != OriginalKey->c_str()) {
                keys.push_back(*OriginalKey);
            }
            FoundSource = true;
        }

    private:
        const NSaas::TShardsDispatcher& ShardDispatcher;
        const NSaas::TShardIntervals& ShardIntervals;
        TShardedKeys& ShardedKeys;
        const TString* OriginalKey = nullptr;
        ui32 TypeMask = 0;
        bool FoundSource = false;
    };

    // Splits every trie "text" CGI parameter into per-source sub-keys.
    //
    // Returns false without touching keysForSource when the request is not a complex-key
    // trie request: "component" differs from TRIE, or the comp_search key_type is Text or
    // absent. Otherwise it handles each "text" value in turn:
    //  - routes every key of the first key realm to the shards selected by MainKey+key
    //    (normal lookup) and/or UrlMaskPrefix+'\t'+host (url-mask lookup);
    //  - appends one rebuilt key to keysForSource for every source that received at least
    //    one key;
    //  - returns true.
    //
    // Throws TSearchException(HTTP_BAD_REQUEST) on a malformed key_type or trie key, and
    // TSearchException(HTTP_INTERNAL_SERVER_ERROR) when a key maps to no shard.
    bool DoTrieSharding(const TCgiParameters& cgi,
                        const NSaas::TShardsDispatcher& shardDispatcher,
                        const NSaas::TShardIntervals& shardIntervals,
                        const TVector<THolder<NScatter::ISource>>& sources,
                        TKeysForSource& keysForSource) {
        if (cgi.Get("component") != TrieComponentName) {
            return false;
        }
        NSaasTrie::ETrieKeyType keyType = NSaasTrie::ETrieKeyType::Text;
        ScanKeyValue<false, ';', ':'>(cgi.Get("comp_search"), [&](TStringBuf key, TStringBuf value) {
            if (key == "key_type") {
                // key_type comes from the client: report a bad value as a bad request,
                // the same way an undecodable trie key is reported below.
                try {
                    keyType = FromString<NSaasTrie::ETrieKeyType>(value);
                } catch (...) {
                    ythrow TSearchException(HTTP_BAD_REQUEST) << "Invalid trie key_type: " << CurrentExceptionMessage();
                }
            }
        });
        if (keyType == NSaasTrie::ETrieKeyType::Text) {
            return false;
        }
        const bool packed = keyType == NSaasTrie::ETrieKeyType::ComplexKeyPacked;

        for (const auto& text: cgi.Range("text")) {
            NSaasTrie::TComplexKey key;
            try {
                NSaasTrie::DeserializeFromCgi(key, text, packed);
            } catch (...) {
                ythrow TSearchException(HTTP_BAD_REQUEST) << "Can not unpack trie key: " << CurrentExceptionMessage();
            }

            // A key with neither MainKey nor UrlMaskPrefix is treated as a normal lookup
            // with an empty MainKey.
            const bool hasMainKey = key.HasMainKey() || !key.HasUrlMaskPrefix();
            const bool hasUrlMaskPrefix = key.HasUrlMaskPrefix();

            const size_t firstRealm = GetFirstRealm(key);
            const auto& firstKeys = key.GetAllRealms(firstRealm).GetKey();

            TShardedKeys shardedKeys(sources.size());
            TTrieShardIntervalEnumerator shardEnumerator(shardDispatcher, shardIntervals, shardedKeys);

            for (const auto& firstKey : firstKeys) {
                if (hasMainKey) {
                    const TString shardKey = key.GetMainKey() + firstKey;
                    shardEnumerator.FindShard(firstKey, shardKey, key.GetKeyPrefix(), TrieNormalKeyBit);
                }
                if (hasUrlMaskPrefix) {
                    // The url follows the first tab of the first-realm key; enumerate its hosts.
                    NSaasTrie::TUrlMaskHostIterator hostIterator;
                    if (hostIterator.SetUrl(TStringBuf{firstKey}.After('\t')) && hostIterator.CheckOwner()) {
                        do {
                            const TString shardKey = TString::Join(key.GetUrlMaskPrefix(), "\t", hostIterator.CurrentHost());
                            shardEnumerator.FindShard(firstKey, shardKey, key.GetKeyPrefix(), TrieUrlMaskBit);
                        } while (hostIterator.Next());
                    }
                }
            }
            for (size_t i = 0, imax = sources.size(); i < imax; ++i) {
                auto& shardContext = shardedKeys[i];
                if (!shardContext.first.empty()) {
                    TString newKey = BuildNewKey(key, firstRealm, shardContext.first, shardContext.second, packed);
                    keysForSource[sources[i].Get()].emplace_back(std::move(newKey));
                }
            }
        }
        return true;
    }

    // Custom rearrange rule "TrieSharding": delegates shard selection for a request to
    // DoTrieSharding, feeding it the request's CGI parameters.
    class TTrieSharding : public ICustomRearrange {
    public:
        TString GetName() const final {
            return "TrieSharding";
        }

        // Returns whatever DoTrieSharding returns for this request's CGI parameters.
        bool DoProcessSharding(IReplyContext& context,
                               const NSaas::TShardsDispatcher& shardDispatcher,
                               const NSaas::TShardIntervals& shardIntervals,
                               const TVector<THolder<NScatter::ISource>>& sources,
                               TKeysForSource& keysForSource) const final
        {
            return DoTrieSharding(
                context.GetRequestData().CgiParam,
                shardDispatcher,
                shardIntervals,
                sources,
                keysForSource);
        }
    };

    // Factory for TTrieSharding rules, registered in ICustomRearrangeFactory::TFactory
    // under the key "TrieSharding". The construction parameters are not used.
    class TTrieShardingFactory : public ICustomRearrangeFactory {
    public:
        // explicit: the factory is only ever direct-initialized by its registrator, so an
        // implicit conversion from TCustomRearrangeParams serves no purpose.
        explicit TTrieShardingFactory(const TCustomRearrangeParams&)
        {}

        TString GetName() const final {
            return "TrieSharding";
        }

        ICustomRearrange::TPtr CreateRerrangeInstance() const final {
            return new TTrieSharding;
        }

    private:
        static ICustomRearrangeFactory::TFactory::TRegistrator<TTrieShardingFactory> Registrator;
    };

    ICustomRearrangeFactory::TFactory::TRegistrator<TTrieShardingFactory> TTrieShardingFactory::Registrator("TrieSharding");
}
