#include "resolver.h"

#include <crypta/graph/rt/sklejka/cid_resolver/future/wrapper.h>

#include <ads/bsyeti/libs/ytex/client/rpc.h>

#include <library/cpp/cache/cache.h>
#include <library/cpp/iterator/enumerate.h>
#include <library/cpp/iterator/zip.h>
#include <library/cpp/json/json_reader.h>
#include <library/cpp/logger/global/global.h>

#include <stdlib.h>

namespace NCrypta {

    // Crypta id resolver backed by lookups in a YT table.
    //
    // Identify() answers ids from `CryptaIdCache` where possible and fetches the rest
    // through `Loader`, picking the YT client by retry attempt number (see DoRawRequest).
    // All configuration (table, timeout, cache limits, batching, clients) comes from the
    // privately inherited TCryptaIdResolverData.
    class TCryptaIdResolver: public ICryptaIdResolver, private TCryptaIdResolverData {
    public:
        // Takes ownership of the resolver configuration. The TCryptaIdResolverData base
        // is initialized first, so `Table`, `Timeout`, `MaxCacheSize` and
        // `MaxCacheDuration` used below already hold the moved-in values.
        explicit TCryptaIdResolver(TCryptaIdResolverData&& resolverData)
            : TCryptaIdResolverData(std::move(resolverData))
            , Loader(static_cast<TString>(Table), Timeout, false)
            , CryptaIdCache(NYT::New<TIdentificationCache<TCryptaId>>("crypta_id", MaxCacheSize, MaxCacheDuration))
        {
        }

        // Resolves a chunk of ids to crypta ids; see the out-of-class definition.
        TChunkResponse<TCryptaId> Identify(const TChunkRequest& request, NSFStats::TSolomonContext& ctx) override;

    private:
        // Chunk response whose element type is the `value_type` of whatever `TFn`
        // returns when applied to a TGenericID record (e.g. TMaybe<X> -> X).
        template <class TFn>
        using TExtractIdChunkResult = TChunkResponse<typename std::invoke_result_t<TFn, const NIdentifiersProto::TGenericID&>::value_type>;

        // One lookup round trip for `keys`; `attempt` selects which YT client is used.
        template <class TFn>
        TExtractIdChunkResult<TFn> DoRawRequest(const TVector<TIdKey>& keys, ui32 attempt, NSFStats::TSolomonContext& ctx, TFn extractor);

        // Reads identifier rows from the configured table.
        NCrypta::TYtCryptaIdProtosLoader Loader;
        // Cache of previously resolved ids, created in the constructor.
        NYT::TIntrusivePtr<TIdentificationCache<TCryptaId>> CryptaIdCache;
    };

    // Namespace-level short names for the resolver interface types. They are needed by
    // the free functions and the out-of-class member definitions below, where the
    // ICryptaIdResolver member typedefs are not found by unqualified lookup.
    using TCryptaId = ICryptaIdResolver::TCryptaId;
    using TChunkRequest = ICryptaIdResolver::TChunkRequest;

    template <class T>
    using TChunkResponse = ICryptaIdResolver::TChunkResponse<T>;

    // Approximate size of one crypta-id cache entry: the key object itself plus the
    // capacity of its value buffer, plus the (fixed-size) optional crypta id.
    template <>
    size_t TIdentificationCacheSizeGetter<TCryptaId>::operator()(const TIdKey& key, const TMaybe<TCryptaId>& value) const {
        const size_t keyBytes = sizeof(key) + key.GetValue().capacity();
        const size_t valueBytes = sizeof(value);
        return keyBytes + valueBytes;
    }

    // Resolves a chunk of identifiers, serving what it can from `cache` and requesting
    // the rest through `rawResolver`.
    //
    // Steps:
    //  1. Look every id up in `cache`. Only cached values that are defined count as hits;
    //     a cached `Nothing()` is requested again.
    //  2. Deduplicate the remaining ids so each distinct key is requested once
    //     (first-seen order is kept in `deduplicatedKeys`).
    //  3. Request the distinct keys in batches of at most `maxRequestBatchSize`, trying
    //     each batch up to `attemptsCount` times. Exceptions from `rawResolver` are
    //     counted in "requests_timeout", logged, and the next attempt is made.
    //  4. Store every fetched result in `cache` and copy it to every row carrying that key.
    //
    // Parameters:
    //  rowsIds             - ids to resolve, one per output row.
    //  rawResolver         - called as rawResolver(keys, attempt); its result is matched
    //                        to `keys` positionally, so it must have one entry per key.
    //  attemptsCount       - attempts per batch. 0 skips requesting entirely (debug mode):
    //                        every uncached id resolves to Nothing().
    //  maxRequestBatchSize - max keys per rawResolver call; expected to be > 0.
    //  cache               - read in step 1 and updated in step 4.
    //  ctx                 - Solomon metrics context.
    //
    // Returns one entry per element of `rowsIds`, in the same order; Nothing() for ids
    // that were not resolved.
    // Throws TCryptaIdResolver::TNoResponseException when a batch gets no response after
    // all attempts, and yexception (Y_ENSURE) when a response size does not match its batch.
    template <class T>
    TChunkResponse<T> DoIdentificationRequest(
        const ICryptaIdResolver::TChunkRequest& rowsIds,
        const TResolverFun<T> rawResolver,
        const ui32 attemptsCount,
        const ui32 maxRequestBatchSize,
        TIdentificationCache<T>& cache,
        NSFStats::TSolomonContext& ctx) {
        using TSumMetric = NSFStats::TSumMetric<ui64>;

        ctx.Inc("requests_total", 1);
        ctx.Inc("rows_total", rowsIds.size());
        auto cacheHit{ctx.Get<TSumMetric>("identifiers_cache_hit")};
        auto identifiersTotal{ctx.Get<TSumMetric>("identifiers_total")};

        TChunkResponse<T> answer(rowsIds.size(), Nothing());

        // Step 1: cache lookup.
        TVector<ui32> srcIndex;        // requestedKeys index -> rowsIds index
        TVector<TIdKey> requestedKeys; // ids that missed the cache
        for (auto [index, simpleId] : Enumerate(rowsIds)) {
            identifiersTotal.Inc(1);
            if (const auto& cacheItem{cache.Find(simpleId)};
                cacheItem && cacheItem->GetValue().Defined()) {
                cacheHit.Inc(1);
                answer[index] = *cacheItem->GetValue();
            } else {
                srcIndex.emplace_back(index);
                requestedKeys.emplace_back(simpleId);
            }
        }

        ctx.Inc("interesting_identifiers_total", requestedKeys.size());

        if (requestedKeys.empty()) {
            return answer;
        }

        // Step 2: deduplication. deduplicatedIndex maps a key to every row index that carries it.
        THashMap<TIdKey, TVector<ui32>> deduplicatedIndex{};
        TVector<TIdKey> deduplicatedKeys{};
        for (const auto& [key, index] : Zip(requestedKeys, srcIndex)) {
            auto& rowIndices = deduplicatedIndex[key];
            if (rowIndices.empty()) {
                deduplicatedKeys.emplace_back(key);
            }
            rowIndices.emplace_back(index);
        }

        ctx.Inc("deduplicated_identifiers_total", deduplicatedIndex.size());
        using TIdentifyBatchSize = NSFStats::TSolomonThresholdMetric<10, 25, 50, 100, 250, 500, 750, 1000>;

        Y_ENSURE(deduplicatedIndex.size() > 0);

        // Step 3: batched requests with retries.
        ui64 keysRequested{0};
        TChunkResponse<T> deduplicatedResponse;
        deduplicatedResponse.reserve(deduplicatedKeys.size());

        while (keysRequested < deduplicatedKeys.size()) {
            auto newBatchSize = std::min(static_cast<ui64>(maxRequestBatchSize), deduplicatedKeys.size() - keysRequested);
            ctx.Get<TIdentifyBatchSize>("identify_batch_size_hist").Add(newBatchSize);

            TVector<TIdKey> keysBatch{};
            keysBatch.reserve(newBatchSize);
            std::copy_n(
                deduplicatedKeys.begin() + keysRequested,
                newBatchSize,
                std::back_inserter(keysBatch));

            TChunkResponse<T> partialResponse;
            for (ui32 attempt{0}; attempt < attemptsCount; ++attempt) {
                try {
                    partialResponse = rawResolver(keysBatch, attempt);
                } catch (...) {
                    ctx.Inc("requests_timeout", 1);
                    WARNING_LOG << "Can't identify (probably timeout) (attempt = " << attempt << "): "
                                << CurrentExceptionMessage() << "\n";
                    continue;
                }
                break;
            }

            if (attemptsCount == 0) {
                // Identification is skipped for debug purposes: answer Nothing() for every
                // key of THIS batch (not of the whole request), so the per-batch results
                // add up to exactly deduplicatedKeys.size().
                partialResponse.resize(keysBatch.size(), Nothing());
            }
            if (partialResponse.empty()) {
                ctx.Inc("requests_final_empty", 1);
                ERROR_LOG << "Did not get any identification response after " << attemptsCount << " attempts.\n";
                ythrow TCryptaIdResolver::TNoResponseException{};
            }
            // Results are matched to keys by position; a size mismatch would shift every
            // following answer onto the wrong key. Also guarantees forward progress.
            Y_ENSURE(partialResponse.size() == keysBatch.size());

            keysRequested += newBatchSize;
            std::copy(partialResponse.begin(),
                      partialResponse.end(),
                      std::back_inserter(deduplicatedResponse));
        }

        Y_ENSURE(deduplicatedResponse.size() == deduplicatedKeys.size());

        auto keysWithResponse{ctx.Get<TSumMetric>("deduplicated_identifiers_identified")};
        for (auto& ans : deduplicatedResponse) {
            keysWithResponse.Inc(ans.Defined());
        }

        // Step 4: update the cache and fan results out to all rows with the same key.
        for (auto [key, cryptaId] : Zip(deduplicatedKeys, deduplicatedResponse)) {
            cache.Update(NIdentifiers::TGenericID{key}, cryptaId);
            for (const auto index : deduplicatedIndex[key]) {
                if (!answer[index].Defined()) {
                    answer[index] = cryptaId;
                }
            }
        }

        auto rowsIdentified{ctx.Get<TSumMetric>("rows_identified")};
        for (auto& ans : answer) {
            rowsIdentified.Inc(ans.Defined());
        }

        return answer;
    }

    // Loads `requestedKeys` from the YT table in a single call and converts every
    // returned record with `extractId`.
    //
    // The client is YtClients[attempt % YtClients.size()], so successive retry attempts
    // rotate through the configured clusters. The result has one entry per requested
    // key, in the same order; entries for which `extractId` gave an empty value stay
    // Nothing().
    // Throws (Y_ENSURE) when no clients are configured or the loader returned a row
    // count different from the number of requested keys.
    template <class TFn>
    TCryptaIdResolver::TExtractIdChunkResult<TFn> TCryptaIdResolver::DoRawRequest(
        const TVector<TIdKey>& requestedKeys,
        const ui32 attempt,
        NSFStats::TSolomonContext& ctx,
        TFn extractId) {
        Y_UNUSED(ctx); // no metrics are reported at this level

        Y_ENSURE(YtClients.size() > 0);
        const auto clientIndex = attempt % YtClients.size();
        const auto rows{Loader.LoadRowset(requestedKeys, YtClients[clientIndex])};
        Y_ENSURE(rows.size() == requestedKeys.size());

        TExtractIdChunkResult<TFn> extracted(requestedKeys.size(), Nothing());
        for (auto [position, row] : Enumerate(rows)) {
            auto maybeId{extractId(row.GetCryptaId())};
            if (maybeId) {
                extracted[position] = *std::move(maybeId);
            }
        }
        return extracted;
    }

    // Resolves `ids` to crypta ids via DoIdentificationRequest, using this resolver's
    // cache, retry count and batch size. Each raw lookup goes through DoRawRequest;
    // only records of type CRYPTA_ID yield a value, all other records resolve to Nothing().
    TChunkResponse<TCryptaId> TCryptaIdResolver::Identify(
        const ICryptaIdResolver::TChunkRequest& ids,
        NSFStats::TSolomonContext& ctx) {
        const auto extractCryptaId = [](const NIdentifiersProto::TGenericID& record) -> TMaybe<TCryptaId> {
            if (record.GetType() != NIdentifiersProto::NIdType::CRYPTA_ID) {
                return Nothing();
            }
            return record.GetCryptaId().GetValue();
        };

        const auto rawResolver = [this, &ctx, extractCryptaId](const TVector<TIdKey>& requestedKeys, const ui32 attempt) {
            return DoRawRequest(requestedKeys, attempt, ctx, extractCryptaId);
        };

        return DoIdentificationRequest<TCryptaId>(ids, rawResolver, AttemptsCount, MaxRequestBatchSize, *CryptaIdCache, ctx);
    }

    // Builds TCryptaIdResolverData from the protobuf config: copies the scalar settings
    // (millisecond values become TDuration) and creates one RPC client per configured
    // YT cluster. YtClients[i] and YtClientAliases[i] refer to the same cluster.
    TCryptaIdResolverData CreateCryptaIdResolverData(const TCryptaIdResolverConfig& config) {
        TCryptaIdResolverData result{
            .Table = config.GetTable(),
            .AttemptsCount = config.GetAttemptsCount(),
            .Timeout = TDuration::MilliSeconds(config.GetTimeoutMs()),
            .MaxCacheSize = config.GetMaxCacheSize(),
            .MaxCacheDuration = TDuration::MilliSeconds(config.GetMaxCacheDurationMs()),
            .MaxRequestBatchSize = config.GetMaxRequestBatchSize(),
        };

        for (const auto& clusterName : config.GetYtClusters()) {
            auto client = NYTRpc::CreateClusterClient(clusterName);
            result.YtClients.emplace_back(std::move(client));
            result.YtClientAliases.emplace_back(clusterName);
        }

        return result;
    }

    // Creates a resolver from ready-made data. Returns an empty pointer (and logs an
    // error) when no YT clients are configured or the table name is empty.
    TCryptaIdResolverPtr CreateCryptaIdResolver(TCryptaIdResolverData&& data) {
        const bool hasClients = !data.YtClients.empty();
        if (hasClients && !(data.Table == "")) {
            return MakeIntrusive<TCryptaIdResolver>(std::move(data));
        }
        ERROR_LOG << "Invalid data for constructing TCryptaIdResolver\n";
        return {};
    }

    // Convenience overload: converts `config` to TCryptaIdResolverData and delegates
    // to the data-based factory (which may return an empty pointer on invalid config).
    TCryptaIdResolverPtr CreateCryptaIdResolver(const TCryptaIdResolverConfig& config) {
        auto resolverData = CreateCryptaIdResolverData(config);
        return CreateCryptaIdResolver(std::move(resolverData));
    }

}
