#include "sharding.h"

#include <saas/protos/rtyserver.pb.h>
#include <saas/util/json/json.h>

#include <library/cpp/logger/global/global.h>

#include <util/string/cast.h>
#include <util/generic/ymath.h>
#include <util/stream/output.h>
#include <util/digest/city.h>
#include <util/string/split.h>

namespace NSaas {
    namespace {
        // Adapts a std::function callback to the IShardIntervalCallback interface
        // taken by IShardingRule::EnumerateIntervals.
        //
        // The adapter only borrows the callback: it keeps a reference, so the
        // wrapped std::function must outlive the adapter. In this file it is only
        // created as a stack-local inside TShardsDispatcher::EnumerateIntervals.
        struct TShardIntervalCallbackWrapper final : IShardIntervalCallback {
            explicit TShardIntervalCallbackWrapper(const std::function<void (size_t)>& callback)
                : Callback(callback)
            {
            }

            void OnShardInterval(size_t intervalIndex) override {
                Callback(intervalIndex);
            }

        private:
            const std::function<void (size_t)>& Callback;
        };
    }

    // Base-class constructor for sharding rules: initializes Context from
    // `context` and checks, via VERIFY_WITH_LOG, that its kps shift is below 16.
    TShardsDispatcher::IShardingRule::IShardingRule(const TShardsDispatcher::TContext& context)
        : Context(context)
    {
        VERIFY_WITH_LOG(Context.KpsShift < 16, "Incorrect kpsShift");
    }

    // Base-rule default: always reports that full coverage is required.
    bool TShardsDispatcher::IShardingRule::NeedFullCoverage() const {
        constexpr bool fullCoverageRequired = true;
        return fullCoverageRequired;
    }

    // Returns the context this rule was constructed with.
    const TShardsDispatcher::TContext& TShardsDispatcher::IShardingRule::GetContext() const {
        return this->Context;
    }

    // Shard for a whole indexing message: delegates to the (url, key prefix)
    // overload using the document's URL and KeyPrefix.
    NSearchMapParser::TShardIndex TShardsDispatcher::IShardingRule::GetShard(const NRTYServer::TMessage& message) const {
        const auto& doc = message.GetDocument();
        const auto& url = doc.GetUrl();
        const auto keyPrefix = doc.GetKeyPrefix();
        return GetShard(url, keyPrefix);
    }

    // Default interval membership test: defers to TInterval::Check.
    bool TShardsDispatcher::IShardingRule::CheckInterval(NSearchMapParser::TShardIndex shard, const TInterval<NSearchMapParser::TShardIndex>& interval) const {
        const bool insideInterval = interval.Check(shard);
        return insideInterval;
    }

    // Convenience overload taking a std::function: wraps it into an
    // IShardIntervalCallback adapter and hands it to the sharding rule.
    void TShardsDispatcher::EnumerateIntervals(TStringBuf url, TKeyPrefix kps, const TShardIntervals& sortedIntervals,
                                               const std::function<void (size_t intervalIndex)>& callback) const {
        TShardIntervalCallbackWrapper adapter(callback);
        IShardIntervalCallback& asInterface = adapter;
        Impl->EnumerateIntervals(url, kps, sortedIntervals, asInterface);
    }

    // Forwards interval enumeration to the sharding rule selected for this dispatcher.
    void TShardsDispatcher::EnumerateIntervals(TStringBuf url, TKeyPrefix kps, const TShardIntervals& sortedIntervals,
                                               IShardIntervalCallback& callback) const {
        auto& rule = *Impl;
        rule.EnumerateIntervals(url, kps, sortedIntervals, callback);
    }

    // Full constructor: sharding type, shards-count mode and kps shift.
    // The shift is verified to be below 16 via VERIFY_WITH_LOG (FromString
    // enforces the same limit by throwing instead).
    TShardsDispatcher::TContext::TContext(ShardingType type, ShardsCount shards, ui32 kpsShift)
        : Type(type)
        , Shards(shards)
        , KpsShift(kpsShift)
    {
        VERIFY_WITH_LOG(KpsShift < 16, "Incorrect kpsShift");
    }

    TShardsDispatcher::TContext::TContext(ShardingType type, ui32 kpsShift)
        : TContext(type, ShardsCount::Legacy, kpsShift)
    {
    }

    // Context with the legacy shards count and a zero kps shift.
    TShardsDispatcher::TContext::TContext(ShardingType type)
        : TContext(type, ShardsCount::Legacy, 0)
    {}

    // Shard-space bound for this context's ShardsCount, as given by NSaas::GetShards.
    NSearchMapParser::TShardIndex TShardsDispatcher::TContext::GetShardsMax() const {
        const NSearchMapParser::TShardIndex bound = NSaas::GetShards(Shards);
        return bound;
    }

    // Serializes the context as "<type>[-<kpsShift>[-<shards>]]".
    // The shift is written when it is non-zero, and also whenever a non-legacy
    // shards count follows it, because the format is positional.
    TString TShardsDispatcher::TContext::ToString() const {
        const bool nonLegacyShards = (Shards != ShardsCount::Legacy);
        TStringStream out;
        out << Type;
        if (nonLegacyShards) {
            out << "-" << KpsShift << "-" << Shards;
        } else if (KpsShift != 0) {
            out << "-" << KpsShift;
        }
        return out.Str();
    }

    // Parses the textual form produced by ToString():
    //   "<type>", "<type>-<kpsShift>" or "<type>-<kpsShift>-<shards>".
    // Empty '-'-separated pieces are skipped. Throws yexception when there are
    // more than three pieces, a piece cannot be parsed, or the kps shift is 16
    // or more.
    TShardsDispatcher::TContext TShardsDispatcher::TContext::FromString(const TString& str) {
        TContext result;
        TVector<TString> parts;
        StringSplitter(str).Split('-').SkipEmpty().Collect(&parts);
        if (parts.size() > 3 || parts.size() < 1) {
            ythrow yexception() << "Incorrect sharding context string: " << str;
        }

        if (!TryFromString(parts[0], result.Type)) {
            ythrow yexception() << "Incorrect sharding context string: incorrect type " << parts[0];
        }

        // The shift is the second piece in both the two- and the three-part
        // forms (ToString() always writes it before a non-legacy shards count),
        // so it has to be parsed whenever it is present.
        if (parts.size() >= 2) {
            if (!TryFromString(parts[1], result.KpsShift)) {
                ythrow yexception() << "Incorrect sharding context string: incorrect shift " << parts[1];
            }
        }

        if (parts.size() == 3) {
            if (!TryFromString(parts[2], result.Shards)) {
                ythrow yexception() << "Incorrect sharding context string: incorrect shards count " << parts[2];
            }
        }

        // Fields were assigned after construction, so no constructor check
        // covered the parsed shift; enforce the limit here.
        if (result.KpsShift >= 16) {
            ythrow yexception() << "KpsShift too big";
        }
        return result;
    }

    // Writes this context into the sharding fields of `result`
    // (ShardBy, Shards and KpsShift).
    //
    // Throws yexception if Type or Shards has no proto counterpart. Both enum
    // values are translated before any setter is called, so a throw leaves
    // `result` unmodified instead of half-written.
    void TShardsDispatcher::TContext::ToProto(NSaasProto::TService& result) const {
        decltype(result.GetShardBy()) shardBy;
        switch (Type) {
        case NSaas::KeyPrefix:
            shardBy = NSaasProto::TService::KeyPrefix;
            break;
        case NSaas::UrlHash:
            shardBy = NSaasProto::TService::UrlHash;
            break;
        case NSaas::External:
            shardBy = NSaasProto::TService::External;
            break;
        case NSaas::Broadcast:
            shardBy = NSaasProto::TService::Broadcast;
            break;
        case NSaas::Geo:
            shardBy = NSaasProto::TService::Geo;
            break;
        case NSaas::GeoRestrict:
            shardBy = NSaasProto::TService::GeoRestrict;
            break;
        case NSaas::UrlHashErasure:
            shardBy = NSaasProto::TService::UrlHashErasure;
            break;
        case NSaas::QuerySearch:
            shardBy = NSaasProto::TService::QuerySearch;
            break;
        case NSaas::UrlToLastOctothorp:
            shardBy = NSaasProto::TService::UrlToLastOctothorp;
            break;
        case NSaas::UrlToLastUnderscore:
            shardBy = NSaasProto::TService::UrlToLastUnderscore;
            break;
        default:
            ythrow yexception() << "Unknown ShardingType";
        }

        decltype(result.GetShards()) shards;
        switch (Shards) {
        case ShardsCount::Legacy:
            shards = NSaasProto::TService::Legacy;
            break;
        case ShardsCount::UI32:
            shards = NSaasProto::TService::UI32;
            break;
        default:
            ythrow yexception() << "Disallowed Shards count";
        }

        // Only mutate the output once every value has been validated.
        result.SetShardBy(shardBy);
        result.SetShards(shards);
        result.SetKpsShift(KpsShift);
    }

    // Builds a context from the sharding fields (ShardBy, Shards, KpsShift) of a
    // service proto. Throws yexception for proto enum values that have no NSaas
    // counterpart.
    TShardsDispatcher::TContext TShardsDispatcher::TContext::FromProto(const NSaasProto::TService& service) {
        const NSaas::ShardingType shardingType = [&service]() -> NSaas::ShardingType {
            switch (service.GetShardBy()) {
            case NSaasProto::TService::KeyPrefix:
                return NSaas::KeyPrefix;
            case NSaasProto::TService::UrlHash:
                return NSaas::UrlHash;
            case NSaasProto::TService::External:
                return NSaas::External;
            case NSaasProto::TService::Broadcast:
                return NSaas::Broadcast;
            case NSaasProto::TService::Geo:
                return NSaas::Geo;
            case NSaasProto::TService::GeoRestrict:
                return NSaas::GeoRestrict;
            case NSaasProto::TService::UrlHashErasure:
                return NSaas::UrlHashErasure;
            case NSaasProto::TService::QuerySearch:
                return NSaas::QuerySearch;
            case NSaasProto::TService::UrlToLastOctothorp:
                return NSaas::UrlToLastOctothorp;
            case NSaasProto::TService::UrlToLastUnderscore:
                return NSaas::UrlToLastUnderscore;
            default:
                ythrow yexception() << "Unknown ShardingType";
            }
        }();

        const ShardsCount shardsCount = [&service]() -> ShardsCount {
            switch (service.GetShards()) {
            case NSaasProto::TService::Legacy:
                return ShardsCount::Legacy;
            case NSaasProto::TService::UI32:
                return ShardsCount::UI32;
            default:
                ythrow yexception() << "Unknown Shards count";
            }
        }();

        return TContext(shardingType, shardsCount, service.GetKpsShift());
    }

    // Constructs the sharding rule registered in IShardingRule::TFactory for
    // context.Type (passing it the context) and stores it in Impl. A missing
    // registration (null rule) is reported through VERIFY_WITH_LOG.
    TShardsDispatcher::TShardsDispatcher(const TContext& context) {
        Impl.Reset(IShardingRule::TFactory::Construct(context.Type, context));
        VERIFY_WITH_LOG(!!Impl.Get(), "Can't find rule for %s", ToString(context.Type).data());
    }

    // Shard for a whole message, as decided by the active sharding rule.
    NSearchMapParser::TShardIndex TShardsDispatcher::GetShard(const NRTYServer::TMessage& message) const {
        auto& rule = *Impl;
        return rule.GetShard(message);
    }

    // Shard for a (url, key prefix) pair, as decided by the active sharding rule.
    NSearchMapParser::TShardIndex TShardsDispatcher::GetShard(const TStringBuf& url, TKeyPrefix kps) const {
        return Impl.Get()->GetShard(url, kps);
    }

    // Lets the active sharding rule validate `message`; the rule may fill `error`.
    bool TShardsDispatcher::CheckMessage(const NRTYServer::TMessage& message, TString& error) const {
        auto& rule = *Impl;
        return rule.CheckMessage(message, error);
    }

    // Interval membership test delegated to the active sharding rule.
    bool TShardsDispatcher::CheckInterval(NSearchMapParser::TShardIndex shard, const TInterval<NSearchMapParser::TShardIndex>& interval) const {
        return (*Impl).CheckInterval(shard, interval);
    }

    // Context of the active sharding rule.
    const TShardsDispatcher::TContext& TShardsDispatcher::GetContext() const {
        const IShardingRule& rule = *Impl;
        return rule.GetContext();
    }

    // Search-side interval check for (url, key prefix), delegated to the active sharding rule.
    bool TShardsDispatcher::CheckSearchInterval(TStringBuf url, TKeyPrefix kps, const TInterval<NSearchMapParser::TShardIndex>& interval) const {
        auto* const rule = Impl.Get();
        return rule->CheckSearchInterval(url, kps, interval);
    }


    bool TShardsDispatcher::NeedFullCoverage() const {
        return Impl->NeedFullCoverage();
    }

    // Maps a ShardsCount mode to its shard-space bound: SearchMapShards for
    // Legacy, UI32Shards for UI32. Any other value is reported through FAIL_LOG
    // (presumably fatal; no value is returned on that path).
    NSearchMapParser::TShardIndex GetShards(ShardsCount count) {
        if (count == NSaas::ShardsCount::UI32) {
            return NSearchMapParser::UI32Shards;
        }
        if (count == NSaas::ShardsCount::Legacy) {
            return NSearchMapParser::SearchMapShards;
        }
        FAIL_LOG("incorrect ShardsCount");
    }

    // Formats a shard index as a zero-padded decimal string of at least 8 digits.
    // The index is printed through an unsigned conversion. Shard indexes are
    // non-negative, and (assuming TShardIndex is an unsigned 32-bit type, as UI32
    // sharding suggests) they can exceed INT_MAX, which "%d" would render as a
    // negative number.
    TString GetShardString(const NSearchMapParser::TShardIndex shard) {
        return Sprintf("%.8llu", static_cast<unsigned long long>(shard));
    }
}

NJson::TJsonValue NSaas::TSlotInfo::SerializeToJson() const {
    // Every field is stored under its member name converted to lower case
    // (e.g. ShardMin -> "shardmin"); DeserializeFromJson reads the same keys.
    NJson::TJsonValue json;
    const auto put = [&json](TString key, const auto& value) {
        key.to_lower();
        json.InsertValue(key, value);
    };
    put("ServiceType", ServiceType);
    put("CType", CType);
    put("Service", Service);
    put("Slot", Slot);
    put("ConfigType", ConfigType);
    put("ShardMin", ShardMin);
    put("ShardMax", ShardMax);
    put("DisableIndexing", DisableIndexing);
    put("DisableSearch", DisableSearch);
    put("DisableSearchFiltration", DisableSearchFiltration);
    put("DisableFetch", DisableFetch);
    put("IsIntSearch", IsIntSearch);
    put("DC", DC);
    return json;
}

// Text form of the slot: the JSON produced by SerializeToJson(), rendered by NUtil::JsonToString.
TString NSaas::TSlotInfo::ToString() const {
    auto json = SerializeToJson();
    return NUtil::JsonToString(json);
}

// Fills the slot fields from `json`, reading each one from the lower-cased
// member name (ShardMin -> "shardmin"), which mirrors SerializeToJson().
// Values are converted with the *Robust getters. DC is only overwritten when a
// "dc" key is present. Always returns true.
bool NSaas::TSlotInfo::DeserializeFromJson(const NJson::TJsonValue& json) {
    // Reads go through the const operator[] rather than a mutable deep copy of
    // the input. NOTE(review): for a missing key this yields a default
    // (undefined) value, the same kind the non-const operator[] would insert
    // into a copy; confirm against NJson if the *Robust conversions change.
#define PROCESS(member, type) { TString name(#member); name.to_lower(); member = json[name].Get ## type ## Robust(); }
    PROCESS(ServiceType, String);
    PROCESS(CType, String);
    PROCESS(Service, String);
    PROCESS(Slot, String);
    PROCESS(ConfigType, String);
    PROCESS(ShardMin, UInteger);
    PROCESS(ShardMax, UInteger);
    PROCESS(DisableIndexing, Boolean);
    PROCESS(DisableSearch, Boolean);
    PROCESS(DisableSearchFiltration, Boolean);
    PROCESS(DisableFetch, Boolean);
    PROCESS(IsIntSearch, Boolean);
    if (json.Has("dc")) {
        PROCESS(DC, String);
    }
#undef PROCESS
    return true;
}

bool NSaas::TSlotInfo::FromString(const TString& string) {
    NJson::TJsonValue json;
    if (!NUtil::JsonFromString(string, json)) {
        Cerr << "cannot parse json from string" << Endl;
        return false;
    }
    return DeserializeFromJson(json);
}

// Returns the `index`-th of `count` consecutive sub-intervals of
// [shardMin, shardMax]. Each part is (shardMax - shardMin) / count wide, and the
// last part extends to shardMax, absorbing the division remainder.
NSearchMapParser::TShardsInterval NSaas::TSharding::GetInterval(ui32 index, ui32 count, ui32 shardMin, ui32 shardMax) {
    Y_ASSERT(count && index < count);
    const ui32 step = (shardMax - shardMin) / count;
    const bool isLast = (index + 1 == count);
    NSearchMapParser::TShardsInterval result;
    result.SetMin(shardMin + step * index);
    result.SetMax(isLast ? shardMax : shardMin + step * (index + 1) - 1);
    return result;
}

// Same as the four-argument overload over the legacy shard range [0, SearchMapShards].
NSearchMapParser::TShardsInterval NSaas::TSharding::GetInterval(ui32 index, ui32 count) {
    const ui32 firstShard = 0;
    const ui32 lastShard = NSearchMapParser::SearchMapShards;
    return GetInterval(index, count, firstShard, lastShard);
}

// Splits `interval` into `parts` consecutive sub-intervals. Each sub-interval
// starts right after the previous one and spans GetLength() / parts + 1 indexes;
// the last one is stretched to interval.GetMax(). Returns an empty vector when
// `parts` is 0.
TVector<NSearchMapParser::TShardsInterval> NSaas::TSharding::SplitInterval(const NSearchMapParser::TShardsInterval& interval, ui32 parts)
{
    TVector<NSearchMapParser::TShardsInterval> ans;
    // Without this guard the division below is by zero and ans.back() is
    // called on an empty vector, both undefined behaviour.
    if (parts == 0) {
        return ans;
    }
    ans.reserve(parts);

    // Integer division already floors non-negative lengths; the original double
    // round-trip computed the same value.
    const ui32 length = interval.GetLength() / parts;
    NSearchMapParser::TShardIndex min = interval.GetMin();
    for (ui32 i = 0; i < parts; ++i) {
        ans.emplace_back(min, min + length);
        min += length + 1;
    }
    ans.back().SetMax(interval.GetMax());
    // NOTE(review): every part but the last covers length + 1 indexes, so when
    // parts > length + 1 the trailing parts can start past interval.GetMax()
    // (e.g. [0, 9] split into 6 parts yields a last part [10, 9]). Confirm whether
    // callers can request such splits.
    return ans;
}
