#include <saas/searchproxy/proxy_meta/rearrange/abstract/rearrange.h>
#include <saas/searchproxy/proxy_meta/rearrange/trie_sharding.h>

#include <search/meta/scatter/source.h>

#include <kernel/saas_trie/idl/trie_key.h>

#include <library/cpp/getopt/small/last_getopt.h>

#include <util/datetime/base.h>
#include <util/generic/deque.h>
#include <util/stream/file.h>
#include <util/string/split.h>
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/thread/pool.h>

struct TConfig {
    TString KeysPath;
    TString UrlMaskPrefix;
    TString ShardingType{"urlhash"};
    ui32 ShardCount = 2;
    ui32 Iterations = 1;
    bool ValidateKeys = false;

    void Print(IOutputStream& out) {
        out << "--- Config --\n";
        out << "KeysPath: " << KeysPath << '\n';
        out << "ShardCount: " << ShardCount << '\n';
        out << "Iterations: " << Iterations << '\n';
        out << "ValidateKeys: " << ValidateKeys << '\n';
        out << "UrlMaskPrefix: " << UrlMaskPrefix << '\n';
        out << "ShardingType: " << ShardingType << '\n';
        out << "---" << Endl;
    }

    void Init(int argc, const char* argv[]) {
        auto makeOpt = [](TString name, TString help, bool required = true) {
            NLastGetopt::TOpt opt;
            opt.AddLongName(name);
            opt.Help(help);
            if (required) {
                opt.Required();
            }
            return opt;
        };
        NLastGetopt::TOpts opts;
        opts.AddOption(makeOpt("keys", "path to keys", true).StoreResult(&KeysPath));
        opts.AddOption(makeOpt("shards", "number of shards", true).StoreResult(&ShardCount));
        opts.AddOption(makeOpt("iter", "number of iterations", false).StoreResult(&Iterations));
        opts.AddOption(makeOpt("validinput", "validate input", false).StoreResult(&ValidateKeys));
        opts.AddOption(makeOpt("urlmaskprefix", "setup url mask prefix", false).StoreResult(&UrlMaskPrefix));
        opts.AddOption(makeOpt("shardrule", "sharding rule", false).StoreResult(&ShardingType));
        opts.AddHelpOption('h');
        NLastGetopt::TOptsParseResult parser(&opts, argc, argv);
    }
};

// Reads `path` and turns every non-empty line (a CGI-serialized trie key) into a TRIE search request.
// The key is parsed as NSaasTrie::TComplexKey when `validateKeys` is set or `urlMaskPrefix` is non-empty.
// When `urlMaskPrefix` is non-empty, it is written into the key, which is then re-serialized.
// Lines are split on '\n' only. NOTE(review): a trailing '\r' from CRLF files is kept as part of the key.
TDeque<TCgiParameters> LoadRequests(const TString& path, const TString& urlMaskPrefix, bool validateKeys) {
    const bool overrideMask = !urlMaskPrefix.empty();
    const bool mustParse = validateKeys || overrideMask;

    const auto fileContent = TIFStream(path).ReadAll();
    TDeque<TCgiParameters> requests;
    for (const auto& line : StringSplitter(fileContent).Split('\n').SkipEmpty()) {
        TString keyText{line.Token()};
        if (mustParse) {
            NSaasTrie::TComplexKey parsedKey;
            NSaasTrie::DeserializeFromCgi(parsedKey, keyText, true);
            if (overrideMask) {
                parsedKey.SetUrlMaskPrefix(urlMaskPrefix);
                keyText = NSaasTrie::SerializeToCgi(parsedKey, true);
            }
        }
        requests.emplace_back(TCgiParameters{
            {"component", "TRIE"},
            {"comp_search", "comp:TRIE;max_docs:100500;key_type:complex_key_packed"},
            {"text", std::move(keyText)}
        });
    }
    return requests;
}

// Owns `numberOfPointers` fake pointer values wrapped in THolder, so they can be passed wherever
// a `TVector<THolder<T>>` of sources is expected without creating real objects.
// The values are distinct, non-null and aligned for T. They are never dereferenced here, and
// they are released (not deleted) on destruction.
// NOTE(review): assumes consumers only compare/hash these pointers and never dereference them.
template<typename T>
struct TFakePointers {
    explicit TFakePointers(size_t numberOfPointers) {
        Pointers.resize(numberOfPointers);
        for (size_t i = 0; i < numberOfPointers; ++i) {
            // Start at 1 so that no fake pointer is nullptr (index 0 used to produce a null "source").
            // Scale by alignof(T) so every value is a suitably aligned address for T.
            Pointers[i].Reset(reinterpret_cast<T*>((i + 1) * alignof(T)));
        }
    }

    // Copying would release the same fake pointers twice; ownership must stay unique.
    TFakePointers(const TFakePointers&) = delete;
    TFakePointers& operator=(const TFakePointers&) = delete;

    ~TFakePointers() {
        // The pointers do not point to real objects, so they must never reach THolder's deleter.
        for (auto& ptr : Pointers) {
            Y_UNUSED(ptr.Release());
        }
    }

    // Returns the fake holders; index i corresponds to shard i.
    const TVector<THolder<T>>& Get() const {
        return Pointers;
    }

private:
    TVector<THolder<T>> Pointers;
};

int main(int argc, const char* argv[]) {
    try {
        TConfig config;
        config.Init(argc, argv);
        config.Print(Cout);

        constexpr ui32 MaxShardIndex = 65533;
        Y_ENSURE(config.ShardCount > 1 && config.ShardCount <= MaxShardIndex + 1);

        auto requests = LoadRequests(config.KeysPath, config.UrlMaskPrefix, config.ValidateKeys);
        NSaas::ShardingType shardingType = NSaas::ShardingType::UrlHash;
        if (config.ShardingType == "urlhash") {
            shardingType = NSaas::ShardingType::UrlHash;
        } else if (config.ShardingType == "broadcast") {
            shardingType = NSaas::ShardingType::Broadcast;
        } else if (config.ShardingType == "keyprefix") {
            shardingType = NSaas::ShardingType::KeyPrefix;
        }

        NSaas::TShardsDispatcher::TContext shardContext{shardingType};
        NSaas::TShardsDispatcher shardDispatcher{shardContext};

        Cout << "Shard intervals:\n";
        NSaas::TShardIntervals shardIntervals;
        ui32 from = 0;
        for (ui32 i = 0; i < config.ShardCount; ++i) {
            ui32 to = (i + 1) * (MaxShardIndex + 1) / config.ShardCount;
            shardIntervals.emplace_back(from, to - 1);
            Cout << from << " - " << (to - 1) << '\n';
            from = to;
        }

        TFakePointers<NScatter::ISource> fakeSources{config.ShardCount};

        size_t incomingSize = 0;
        for (auto& request : requests) {
            incomingSize += request.size();
        }
        size_t outgoingRequests = 0;
        size_t outgoingSize = 0;

        Cout << "--- Start searching ---" << Endl;
        auto startTime = TInstant::Now();

        for (ui32 i = 0; i < config.Iterations; ++i) {
            for (const auto& request : requests) {
                NProxyMeta::TKeysForSource keysForSource;
                NProxyMeta::DoTrieSharding(request, shardDispatcher, shardIntervals, fakeSources.Get(), keysForSource);
                outgoingRequests += keysForSource.size();
                for (auto& src : keysForSource) {
                    outgoingSize += src.second.size();
                }
            }
        }

        auto endTime = TInstant::Now();

        Cout << "Incoming requests: " << requests.size() << Endl;
        Cout << "Outgoing requests: " << (outgoingRequests / config.Iterations) << Endl;
        Cout << "RPS amplification: " << (double(outgoingRequests / config.Iterations) / requests.size()) << Endl;
        Cout << "Network amplification: " << (double(outgoingSize / config.Iterations) / incomingSize) << Endl;

        double totalTime = (endTime - startTime).MicroSeconds();
        Cout << "Time per iteration: " << size_t(totalTime / config.Iterations + 0.5) << "us" << Endl;
        Cout << "Time per request: " << size_t(totalTime / (config.Iterations * requests.size()) + 0.5) << "us" << Endl;

        return 0;

    } catch (...) {
        Cerr << CurrentExceptionMessage() << Endl;
        return 1;
    }
}

