#include "abc_resolver.h"

#include <saas/searchproxy/core/abc_resolver.pb.h>

#include <library/cpp/http/client/client.h>
#include <library/cpp/logger/global/global.h>
#include <library/cpp/json/json_reader.h>

#include <util/datetime/base.h>
#include <util/generic/hash.h>
#include <util/generic/hash_set.h>
#include <util/random/random.h>
#include <util/stream/file.h>
#include <util/string/builder.h>
#include <util/string/join.h>
#include <util/system/condvar.h>
#include <util/system/fs.h>
#include <util/system/rwlock.h>
#include <util/system/thread.h>
#include <util/thread/factory.h>

namespace NSearchProxy {

    /// Background resolver of ABC group ids into member uid lists.
    ///
    /// After Start(), a dedicated thread periodically fetches the members of every
    /// group from `GroupsToResolve` through the ABC HTTP API, authenticating with a
    /// TVM service ticket. Each update builds a fresh mapping and publishes it by
    /// swapping the shared pointer under `MappingLock`, so Resolve() holds the lock
    /// only long enough to copy that pointer. When a cache directory is configured,
    /// every new mapping is also written to disk before publishing, and the cache
    /// is loaded in the constructor so lookups work before the first fetch ends.
    class TAbcResolver::TImpl {
    private:
        /// Cached state of a single ABC group.
        struct TGroup {
            TVector<TUid> Uids;      // members from the last successful fetch
            TInstant Mtime;          // time of the last successful fetch
            TInstant LastErrorTime;  // time of the last failed fetch
        };

        using TMapping = THashMap<TAbcId, TGroup>;
        using TMappingPtr = TAtomicSharedPtr<TMapping>;

    public:
        explicit TImpl(const TParams& params)
            : ApiUrl(params.ApiUrl)
            , ApiTvmId(params.ApiTvmId)
            , TvmClient(params.TvmClient)
            , UpdatePeriod(params.UpdatePeriod)
            , MaxErrorCountPerUpdate(params.MaxErrorCountPerUpdate)
            , FetchConfig(params.FetchConfig)
        {
            Y_ENSURE(TvmClient);

            if (params.CacheDir) {
                CacheFile = params.CacheDir + "/abc_resolver";
            }

            {
                // Deduplicate requested groups. Their order does not matter because
                // SelectAndRankCandidates() re-sorts them on every update.
                THashSet<TAbcId> groups(params.GroupsToResolve.begin(), params.GroupsToResolve.end());
                GroupsToResolve.insert(GroupsToResolve.end(), groups.begin(), groups.end());
            }

            {
                TWriteGuard g(MappingLock);
                Mapping = new TMapping();
            }

            TryLoadFromCache();
        }

        /// Spawns the background update thread. Expected to be called once.
        void Start() {
            FetchThread = SystemThreadFactory()->Run([this] { FetchLoop(); });
        }

        ~TImpl() {
            Stop();
        }

        /// Replaces the current mapping with the on-disk cache, if one is configured
        /// and present. A broken cache is logged and ignored.
        void TryLoadFromCache() {
            if (!CacheFile || !NFs::Exists(CacheFile)) {
                return;
            }

            TMappingPtr mapping;
            try {
                mapping = LoadFromCache(CacheFile);
            } catch (...) {
                ERROR_LOG << "Failed to load ABC cache from " << CacheFile
                    << ", error: " << CurrentExceptionMessage() << Endl;
                return;
            }

            TWriteGuard g(MappingLock);
            Mapping = mapping;
        }

        /// Converts the serialized cache representation into a mapping.
        TMappingPtr ProtoToMapping(const NSearchProxy::TAbcMappingProto& proto) const {
            TMappingPtr ret = new TMapping();
            for (const auto& groupProto : proto.GetGroups()) {
                auto& group = (*ret)[groupProto.GetId()];
                group.Mtime = TInstant::Seconds(groupProto.GetMtime());
                group.LastErrorTime = TInstant::Seconds(groupProto.GetLastErrorTime());
                group.Uids.insert(group.Uids.end(), groupProto.GetUids().begin(), groupProto.GetUids().end());
            }
            return ret;
        }

        /// Reads and parses the binary cache file. Throws on I/O or parse errors.
        TMappingPtr LoadFromCache(const TString& file) const {
            NSearchProxy::TAbcMappingProto proto;
            TFileInput in(file);
            Y_ENSURE(proto.ParseFromArcadiaStream(&in));
            return ProtoToMapping(proto);
        }

        /// Converts a mapping into its serialized cache representation.
        /// Timestamps are stored with one-second precision.
        NSearchProxy::TAbcMappingProto MappingToProto(const TMapping& mapping) const {
            NSearchProxy::TAbcMappingProto proto;
            for (const auto& [groupId, group] : mapping) {
                auto& groupProto = *proto.MutableGroups()->Add();
                groupProto.SetId(groupId);
                groupProto.SetMtime(group.Mtime.Seconds());
                groupProto.SetLastErrorTime(group.LastErrorTime.Seconds());
                std::copy(group.Uids.begin(), group.Uids.end(),
                    google::protobuf::RepeatedFieldBackInserter(groupProto.MutableUids()));
            }
            return proto;
        }

        /// Persists `mapping` to `file`. The data is written to `<file>.tmp` and
        /// then renamed over `file`, so an interrupted write does not leave a
        /// truncated cache. A human-readable dump goes to `<file>.dump.txt`.
        /// Throws on any I/O error.
        void SaveToCache(const TMapping& mapping, const TString& file) const {
            auto proto = MappingToProto(mapping);
            const TString tmpFile = file + ".tmp";
            {
                TFileOutput out(tmpFile);
                Y_ENSURE(proto.SerializeToArcadiaStream(&out));
                out.Flush();
            }
            Y_ENSURE(NFs::Rename(tmpFile, file));

            {
                TFileOutput out(file + ".dump.txt");
                out << proto.DebugString();
                out.Flush();
            }
        }

        /// Body of the update thread. Runs one update immediately, then one update
        /// per jittered UpdatePeriod until Stop() is requested.
        void FetchLoop() {
            TThread::SetCurrentThreadName("AbcResolver");

            // Every update, including the first, goes through TryUpdateMapping():
            // an exception escaping a thread entry point terminates the process.
            TryUpdateMapping();

            while (WaitForNextUpdate()) {
                TryUpdateMapping();
            }
        }

        /// Sleeps for UpdatePeriod scaled by a random factor in [0.7, 1.3), waking
        /// early if Stop() is requested. Returns false when the loop must exit.
        bool WaitForNextUpdate() {
            // Scale in microseconds. Scaling whole seconds would truncate fractional
            // periods, and turn periods shorter than 1s into a busy loop.
            const double jitter = 0.7 + 0.6 * RandomNumber<float>();
            const auto delay = TDuration::MicroSeconds(static_cast<ui64>(UpdatePeriod.MicroSeconds() * jitter));
            const auto deadline = delay.ToDeadLine();
            INFO_LOG << "Sleep until " << deadline.ToStringLocalUpToSeconds() << Endl;

            TGuard<TMutex> guard(StopLock);
            while (!ShouldStop) {
                if (!StopCondVar.WaitD(StopLock, deadline)) {
                    break; // deadline reached
                }
            }
            return !ShouldStop;
        }

        /// Runs UpdateMapping(). Failures are logged rather than propagated, so the
        /// update thread keeps running.
        void TryUpdateMapping() {
            try {
                INFO_LOG << "Start UpdateMapping" << Endl;
                UpdateMapping();
                INFO_LOG << "Finish UpdateMapping" << Endl;
            } catch (...) {
                ERROR_LOG << "UpdateMapping failed: " << CurrentExceptionMessage() << Endl;
            }
        }

        /// Performs an HTTP GET of `url` using the configured timeouts and retries.
        /// Adds the TVM service ticket header when `tvmTicket` is non-empty.
        NHttpFetcher::TResultRef Fetch(const TString& url, const TString& tvmTicket) const {
            NHttp::TFetchOptions options;

            options.ConnectTimeout = FetchConfig.ConnectTimeout;
            options.Timeout = FetchConfig.Timeout;
            options.RetryCount = FetchConfig.RetryCount;
            options.RetryDelay = FetchConfig.RetryDelay;

            TVector<TString> headers;
            if (tvmTicket) {
                headers.push_back("X-Ya-Service-Ticket: " + tvmTicket);
            }

            NHttp::TFetchQuery query(url, headers, options);

            return NHttp::Fetch(query);
        }

        /// Extracts member uids from a `services/members` JSON response.
        ///
        /// Throws in any of these cases:
        /// - the payload is not valid JSON;
        /// - the payload is not an object with a `results` array;
        /// - an entry lacks `person.uid`;
        /// - an entry's uid is zero or not convertible to an integer.
        ///
        /// The `*Safe` accessors matter here: the plain GetMap()/GetArray() return
        /// empty containers on a type mismatch. A malformed payload would then be
        /// published as an empty member list.
        /// NOTE(review): only the returned document is parsed. If the API paginates
        /// member lists, large groups may be truncated; confirm against the ABC API.
        TVector<TUid> ParsePersonsJson(const TString& data) const {
            NJson::TJsonValue json;
            Y_ENSURE(NJson::ReadJsonTree(data, &json), "Malformed JSON");
            TVector<TUid> uids;
            for (const auto& person : json.GetMapSafe().at("results").GetArraySafe()) {
                const auto uid = person.GetMapSafe().at("person").GetMapSafe().at("uid").GetIntegerRobust();
                Y_ENSURE(uid, "Invalid person.uid in ABC response");
                uids.push_back(uid);
            }
            return uids;
        }

        /// Returns all groups to resolve, ordered by their last fetch attempt
        /// (successful or failed), oldest first. Groups absent from `mapping`
        /// count as never tried and sort to the front.
        TVector<TAbcId> SelectAndRankCandidates(const TMapping& mapping) const {
            struct TUpdateCandidate {
                TAbcId GroupId{};
                TInstant LastTry;
            };

            TVector<TUpdateCandidate> candidates;
            for (const auto& groupId : GroupsToResolve) {
                if (auto* group = mapping.FindPtr(groupId)) {
                    candidates.push_back({groupId, std::max(group->Mtime, group->LastErrorTime)});
                } else {
                    candidates.push_back({groupId, TInstant::Zero()});
                }
            }

            Sort(candidates.begin(), candidates.end(), [](const auto& lhs, const auto& rhs) {
                return lhs.LastTry < rhs.LastTry;
            });

            TVector<TAbcId> sorted;
            for (const auto& cand : candidates) {
                sorted.push_back(cand.GroupId);
            }

            return sorted;
        }

        /// Returns whether Stop() has been requested.
        bool IsStopping() {
            TGuard<TMutex> guard(StopLock);
            return ShouldStop;
        }

        /// Fetches `candidates` in order and updates `mapping` in place.
        /// On success, a group's uids and Mtime are replaced. On failure, its
        /// existing uids (if any) are kept and LastErrorTime is set. Stops early
        /// when the error budget is exhausted or Stop() is requested.
        void FetchAbcGroups(TMapping& mapping, const TVector<TAbcId>& candidates) {
            // The budget is refilled after every success. It therefore limits
            // consecutive failures, not the total number of failures per update.
            size_t errorBudget = MaxErrorCountPerUpdate;

            for (const auto& groupId : candidates) {
                // Don't make shutdown wait for the remaining (possibly slow) requests.
                if (IsStopping()) {
                    INFO_LOG << "Stop requested, skipping remaining ABC groups" << Endl;
                    break;
                }

                const auto fetchTime = TInstant::Now();
                INFO_LOG << "Updating group " << groupId << Endl;

                const TString url = TStringBuilder()
                    << ApiUrl << "/api/v4/services/members/?service=" << groupId
                    << "&fields=person.uid,person.login";

                auto& group = mapping[groupId];
                try {
                    auto resp = Fetch(url, TvmClient->GetServiceTicketFor(ApiTvmId));
                    Y_ENSURE(resp->Code == 200, "Unexpected HTTP code " << resp->Code);
                    TVector<TUid> uids = ParsePersonsJson(resp->Data);
                    group.Mtime = fetchTime;
                    group.Uids = std::move(uids);
                    errorBudget = MaxErrorCountPerUpdate;
                } catch (...) {
                    ERROR_LOG << "Failed to handle request '" << url << "': " << CurrentExceptionMessage() << Endl;
                    group.LastErrorTime = fetchTime;
                    if (!errorBudget) {
                        break;
                    }
                    --errorBudget;
                }
            }
        }

        /// Builds a new mapping from the current one and refreshes groups in
        /// least-recently-tried order. Persists the result to the cache (if
        /// configured) and then publishes it for Resolve().
        void UpdateMapping() {
            TVector<TAbcId> candidates;
            TMappingPtr mapping = new TMapping();
            {
                TMappingPtr prevMapping;
                {
                    TReadGuard g(MappingLock);
                    prevMapping = Mapping;
                }

                candidates = SelectAndRankCandidates(*prevMapping);

                // Carry over known groups so failed or skipped fetches keep their
                // previous members.
                for (const auto& groupId : GroupsToResolve) {
                    if (const TGroup* group = prevMapping->FindPtr(groupId)) {
                        (*mapping)[groupId] = *group;
                    }
                }
            }

            FetchAbcGroups(*mapping, candidates);

            if (CacheFile) {
                try {
                    SaveToCache(*mapping, CacheFile);
                } catch (...) {
                    ERROR_LOG << "Failed to write cache to " << CacheFile << ": " << CurrentExceptionMessage() << Endl;
                }
            }

            {
                TWriteGuard g(MappingLock);
                Mapping.Swap(mapping);
            }
        }

        /// Requests the update thread to stop and waits for it to exit.
        /// Safe without a prior Start() and safe to call repeatedly: the
        /// destructor calls it even after an explicit TAbcResolver::Stop().
        void Stop() {
            with_lock (StopLock) {
                ShouldStop = true;
                StopCondVar.Signal();
            }
            // FetchThread is null if Start() was never called. Resetting it after
            // the join makes a repeated Stop() a no-op instead of a second join.
            if (FetchThread) {
                FetchThread->Join();
                FetchThread.Reset();
            }
        }

        /// Returns the concatenated uids of `groups`, in request order.
        /// Unknown groups are skipped with a warning.
        TVector<TUid> Resolve(const TVector<TAbcId>& groups) const {
            TMappingPtr mapping;
            {
                TReadGuard g(MappingLock);
                mapping = Mapping;
            }
            TVector<TUid> ret;
            for (const TAbcId group : groups) {
                if (auto ptr = mapping->FindPtr(group)) {
                    auto& uids = ptr->Uids;
                    ret.insert(ret.end(), uids.begin(), uids.end());
                } else {
                    WARNING_LOG << "Requested ABC group is not resolved id=" << group << Endl;
                }
            }
            return ret;
        }

    private:
        TString ApiUrl;
        ui32 ApiTvmId{};
        TVector<TAbcId> GroupsToResolve;  // deduplicated; order unspecified
        TTvmClientPtr TvmClient;
        TDuration UpdatePeriod;
        size_t MaxErrorCountPerUpdate = 0;
        const TParams::TFetchConfig FetchConfig;
        TString CacheFile;  // empty when caching is disabled

        THolder<IThreadFactory::IThread> FetchThread;  // null until Start()
        TMappingPtr Mapping;  // current published mapping, guarded by MappingLock
        TRWMutex MappingLock;
        TMutex StopLock;  // guards ShouldStop; used with StopCondVar
        TCondVar StopCondVar;
        bool ShouldStop = false;
    };

    // Builds the implementation (which loads the on-disk cache, if configured);
    // background updates begin only after Start().
    TAbcResolver::TAbcResolver(const TParams& params) : Impl(new TImpl(params)) {}

    // Defined in this file, where TImpl is complete, so the pimpl pointer can
    // destroy it. TImpl's destructor stops the background thread.
    TAbcResolver::~TAbcResolver() = default;

    // Launches the periodic background refresh.
    void TAbcResolver::Start() { Impl->Start(); }

    // Signals the background refresh to finish and waits for it.
    void TAbcResolver::Stop() { Impl->Stop(); }

    TVector<TAbcResolver::TUid> TAbcResolver::Resolve(const TVector<TAbcId>& abcGroups) const {
        return Impl->Resolve(abcGroups);
    }
}
