#include "sd.h"

#include <balancer/kernel/process/sd/sd.h>

#include <util/generic/adaptor.h>
#include <util/stream/file.h>
#include <util/system/backtrace.h>

#include <atomic>

namespace NBalancerSD {
    using namespace NSrvKernel;

    // Collects the unique names of every backend whose Ready flag is false, across all
    // of the given endpoint sets, into a newly created TNotReadyBackendSet.
    static TIntrusivePtr<TNotReadyBackendSet> MakeNotReadySet(const TVector<TEndpointSetBackends>& allBackends) {
        auto result = MakeIntrusive<TNotReadyBackendSet>();

        for (const auto& endpointSet : allBackends) {
            for (const auto& candidate : endpointSet.Backends) {
                if (candidate.Ready) {
                    continue;
                }
                result->emplace(MakeUniqueBackendName(endpointSet.EndpointSetKey, candidate));
            }
        }

        return result;
    }

    // Builds the service discovery key for an endpoint set: (ClusterId, ServiceId) when
    // ServiceId is non-empty, otherwise a key constructed from BackendsFile.
    NYP::NServiceDiscovery::TEndpointSetKey TBackendsProvider::MakeEndpointSetKey(const TEndpointSet& eps) {
        using NYP::NServiceDiscovery::TEndpointSetKey;

        if (eps.ServiceId) {
            return TEndpointSetKey(eps.ClusterId, eps.ServiceId);
        }
        return TEndpointSetKey(eps.BackendsFile);
    }

    // Adapter between the service discovery manager and TBackendsProvider for one endpoint set.
    //
    // Update()/UpdateReadiness() stash a copy of the received endpoint set in Next_ (freeing any
    // copy that has not been consumed yet) and notify the owning provider. ReloadCurrentBackends()
    // later takes the stashed copy and converts it into CurrentBackends_.
    // NOTE(review): the atomic hand-off through Next_ suggests the callbacks may arrive on a
    // different thread than the one calling ReloadCurrentBackends() — confirm against the SD manager.
    class TBackendsProvider::TYpEndpointsProvider
        : public NYP::NServiceDiscovery::IEndpointSetProvider
    {
    public:
        explicit TYpEndpointsProvider(TBackendsProvider& backendsProvider)
            : BackendsProvider_(backendsProvider)
        {
        }

        // Frees an endpoint set that was stashed but never consumed.
        ~TYpEndpointsProvider() {
            Exchange(nullptr);
        }

        // Stashes a copy of the new endpoint set and notifies the provider of a full update.
        void Update(const NYP::NServiceDiscovery::TEndpointSetEx& endpointSet) override {
            Exchange(MakeHolder<NYP::NServiceDiscovery::TEndpointSetEx>(endpointSet));
            BackendsProvider_.OnUpdate(false);
        }

        // Stashes a copy of the new endpoint set and notifies the provider that only
        // readiness changed.
        void UpdateReadiness(const NYP::NServiceDiscovery::TEndpointSetEx& endpointSet) override {
            Exchange(MakeHolder<NYP::NServiceDiscovery::TEndpointSetEx>(endpointSet));
            BackendsProvider_.OnUpdate(true);
        }

        // Subscribes this object to `key` on the given manager and immediately loads
        // CurrentBackends_ from the endpoint set exposed by the subscription.
        void Subscribe(NYP::NServiceDiscovery::TEndpointSetManager* sdManager, const NYP::NServiceDiscovery::TEndpointSetKey& key, bool useUnistat) {
            NYP::NServiceDiscovery::TEndpointSetOptions opts;
            opts.SetUseUnistat(useUnistat);
            SubscribeRef_ = sdManager->Subscribe(key, *this, opts);
            ReloadCurrentBackends(SubscribeRef_->GetEndpointSet());
        }

        // Drops the subscription reference.
        void Unsubscribe() {
            SubscribeRef_.Reset();
        }

        // Consumes the stashed endpoint set, if any.
        // Returns false when nothing was stashed since the previous call, true otherwise.
        bool ReloadCurrentBackends() {
            THolder<NYP::NServiceDiscovery::TEndpointSetEx> nextEndpoints = Exchange(nullptr);

            if (!nextEndpoints) {
                return false;
            }

            return ReloadCurrentBackends(*nextEndpoints);
        }

        // Rebuilds CurrentBackends_ from `nextEndpoints`. Always returns true.
        bool ReloadCurrentBackends(const NYP::NServiceDiscovery::TEndpointSetEx& nextEndpoints) {
            TVector<TBackend> backends;
            // The final size is known up front, so allocate once.
            backends.reserve(static_cast<size_t>(nextEndpoints.endpoints_size()));

            for (ssize_t i = 0; i < nextEndpoints.endpoints_size(); ++i) {
                const auto& endpoint = nextEndpoints.endpoints(i);
                TBackend backend;

                backend.Host = endpoint.fqdn();
                backend.Port = endpoint.port();

                // The IPv6 address is preferred; IPv4 is used only when no IPv6 address is set.
                if (endpoint.ip6_address()) {
                    backend.Ip = endpoint.ip6_address();
                } else {
                    backend.Ip = endpoint.ip4_address();
                }

                backend.Ready = endpoint.ready();

                backends.push_back(std::move(backend));
            }

            // The local vector is not used afterwards: move it into the pair instead of copying.
            CurrentBackends_ = std::make_pair(std::move(backends), nextEndpoints.Info);

            return true;
        }

        // Returns the (backends, endpoint set info) pair produced by the last reload.
        const auto& GetCurrentBackends() const {
            return CurrentBackends_;
        }

        // Returns whether EnableUpdate() has been called.
        bool AllowUpdate() override {
            return AtomicGet(AllowUpdate_);
        }

        void EnableUpdate() {
            AtomicSet(AllowUpdate_, 1);
        }
    private:
        // Atomically replaces the stashed endpoint set with `nextEndpoints` and returns
        // ownership of the previously stashed one (which may be null).
        THolder<NYP::NServiceDiscovery::TEndpointSetEx> Exchange(THolder<NYP::NServiceDiscovery::TEndpointSetEx> nextEndpoints) {
            return THolder<NYP::NServiceDiscovery::TEndpointSetEx>(std::atomic_exchange(&Next_, nextEndpoints.Release()));
        }
    private:
        // Owning raw pointer to the latest endpoint set that has not been consumed yet.
        std::atomic<NYP::NServiceDiscovery::TEndpointSetEx*> Next_ = nullptr;
        std::pair<TVector<TBackend>, NYP::NServiceDiscovery::TEndpointSetInfo> CurrentBackends_;

        TBackendsProvider& BackendsProvider_;

        TAtomic AllowUpdate_ = 0;
        NYP::NServiceDiscovery::IEndpointSetSubscriberRef SubscribeRef_;
    };

    // Stores the module parameters and the config maker in the shared initialization params,
    // walks the module config with this object as the visitor, then drops the config pointer
    // from the stored module parameters.
    TBackendsProvider::TBackendsProvider(const NSrvKernel::TModuleParams& mp, TBackendsInitializationParams::TConfigMaker configMaker)
        : BackendsInitialization_(new TBackendsInitializationParams)
    {
        auto& init = *BackendsInitialization_;

        init.ModuleParams_ = mp;
        init.ConfigMaker_ = configMaker;

        init.ModuleParams_.Config->ForEach(this);
        init.ModuleParams_.Config = nullptr;
    }

    // Completes the setup started by the constructor: records the policy features and check
    // parameters, then creates the initial backends holder either from the statically
    // configured backends or from the subscribed service discovery endpoint sets.
    // Throws TConfigParseError when both sources are configured, or when endpoint sets are
    // used without a service discovery manager.
    void TBackendsProvider::FinishInitialization(
        NSrvKernel::TPolicyFeatures policyFeatures,
        const NSrvKernel::TBackendCheckParameters& checkParams
    ) {
        auto& configuration = BackendsInitialization_->BackendsConfiguration_;
        configuration.PolicyFeatures = policyFeatures;
        configuration.CheckParameters = checkParams;

        const bool hasConfigBackends = BackendsFromConfig_.size() > 0;
        const bool hasEndpointSets = EndpointSets_.size() > 0;

        if (hasConfigBackends && hasEndpointSets) {
            ythrow NConfig::TConfigParseError() << "try to setup both discovered and config backends";
        }

        if (hasConfigBackends) {
            BackendsHolder_ = CreateBackendsHolder(BackendsFromConfig_);
            return;
        }

        NYP::NServiceDiscovery::TEndpointSetManager* sdManager = BackendsInitialization_->ModuleParams_.Control->GetSDManager();
        Y_ENSURE_EX(sdManager, NConfig::TConfigParseError() << "try to use yp backends without service discovery");

        TVector<TEndpointSetBackends> collectedBackends;
        TVector<NYP::NServiceDiscovery::TEndpointSetInfo> collectedInfo;

        for (auto& endpointSet : EndpointSets_) {
            const NYP::NServiceDiscovery::TEndpointSetKey key = MakeEndpointSetKey(endpointSet);

            endpointSet.YpEndpointsProvider = MakeHolder<TYpEndpointsProvider>(*this);
            endpointSet.YpEndpointsProvider->Subscribe(sdManager, key, UseUnistat_);

            const auto& current = endpointSet.YpEndpointsProvider->GetCurrentBackends();

            // The backends are copied on purpose: the provider keeps its own instance.
            collectedBackends.emplace_back(TEndpointSetBackends{key.ToString(), current.first});
            collectedInfo.push_back(current.second);
        }

        BackendsHolder_ = CreateBackendsHolder(collectedBackends);

        // Only after the holder exists is each endpoint set's info reported as active.
        size_t index = 0;
        for (auto& endpointSet : EndpointSets_) {
            endpointSet.YpEndpointsProvider->SetActiveEndpointSetInfo(collectedInfo[index]);
            ++index;
        }
    }

    // Drops the service discovery subscription of every endpoint set that has a provider.
    TBackendsProvider::~TBackendsProvider() {
        for (const auto& endpointSet : EndpointSets_) {
            if (!endpointSet.YpEndpointsProvider) {
                continue;
            }
            endpointSet.YpEndpointsProvider->Unsubscribe();
        }
    }

    // Creates a brand new global backends holder for `instances` and installs the set of
    // not-ready backends computed from the same list.
    TGlobalBackendsHolderPtr TBackendsProvider::CreateBackendsHolder(const TVector<TEndpointSetBackends>& instances) {
        auto created = MakeShared<TGlobalBackendsHolder, TAtomicCounter>(BackendsInitialization_, instances);

        auto notReady = MakeNotReadySet(instances);
        created->Backends().ChangeNotReadySet(std::move(notReady));

        return created;
    }

    void TBackendsProvider::UpdateBackendsHolder(const TVector<TEndpointSetBackends>& backends) {
        TGlobalBackendsHolderPtr next = BackendsHolder_->Update(backends);
        if (next) {
            next->Backends().ChangeNotReadySet(MakeNotReadySet(backends));
        } else {
            next = CreateBackendsHolder(backends);
        }
        with_lock (Lock_) {
            BackendsHolder_ = next;
        }
    }

    // Calls EnableUpdate() on the provider of every endpoint set.
    // Every endpoint set is expected to have a provider by now (asserted).
    void TBackendsProvider::EnableUpdate() {
        for (const auto& eps : EndpointSets_) {
            Y_ASSERT(eps.YpEndpointsProvider);
            auto& provider = *eps.YpEndpointsProvider;
            provider.EnableUpdate();
        }
    }

    void TBackendsProvider::UpdateBackends(bool onlyReadinessChanged) {
        bool hasChanges = false;
        for (auto& eps : EndpointSets_) {
            Y_ASSERT(eps.YpEndpointsProvider);
            if (eps.YpEndpointsProvider->ReloadCurrentBackends()) {
                hasChanges = true;
            }
        }

        if (!hasChanges) {
            return;
        }

        TVector<TEndpointSetBackends> newBackends;
        TVector<NYP::NServiceDiscovery::TEndpointSetInfo> newInfo;
        for (const auto& eps : EndpointSets_) {
            const auto& [backends, info] = eps.YpEndpointsProvider->GetCurrentBackends();
            newBackends.emplace_back(TEndpointSetBackends{MakeEndpointSetKey(eps).ToString(), backends});
            newInfo.push_back(info);
        }

        if (!onlyReadinessChanged) {
            UpdateBackendsHolder(newBackends);
        } else {
            if (BackendsHolder_) {
                // No need for locking here, because this routine is the only place, where BackendsHolder_ is write-accessed.
                BackendsHolder_->Backends().ChangeNotReadySet(MakeNotReadySet(newBackends));
            } else {
                Y_ASSERT(false);
            }
        }

        for (size_t i = 0; i < EndpointSets_.size(); ++i) {
            EndpointSets_[i].YpEndpointsProvider->SetActiveEndpointSetInfo(newInfo[i]);
        }
    }

    // Returns a copy of the current holder pointer, taken while holding Lock_.
    TGlobalBackendsHolderPtr TBackendsProvider::GetBackendsHolder() {
        TGlobalBackendsHolderPtr snapshot;
        with_lock (Lock_) {
            snapshot = BackendsHolder_;
        }
        return snapshot;
    }

    // Invokes the stored config maker with the saved backends/proxy options, the port
    // settings and this holder's backend descriptions, and returns the resulting config.
    THolder<NConfig::IConfig> TGlobalBackendsHolder::CreateConfig() {
        auto& configuration = InitializationParams_->BackendsConfiguration_;

        return InitializationParams_->ConfigMaker_(
                configuration.BackendsOptions,
                configuration.ProxyOptions,
                configuration.ProxyWrapper,
                configuration.Port,
                configuration.PortOffset,
                BackendDescriptions_
        );
    }

    // Builds a holder from scratch: generates a config for `backends`, loads the configured
    // balancing algorithm through CommonBackends(), then applies the policy features and
    // check parameters. Throws (Y_ENSURE) if loading produced no backends object.
    TGlobalBackendsHolder::TGlobalBackendsHolder(
        TBackendsInitializationParams::TPtr initializationParams,
        const TVector<TEndpointSetBackends>& backends
    )
        : InitializationParams_(initializationParams)
        , BackendDescriptions_(backends)
    {
        // The generated config must stay alive while the module params refer to it.
        THolder<NConfig::IConfig> generatedConfig = CreateConfig();
        NSrvKernel::TModuleParams params = InitializationParams_->ModuleParams_.Copy(generatedConfig.Get());
        params.Modules = nullptr;

        auto& configuration = InitializationParams_->BackendsConfiguration_;

        Backends_ = NSrvKernel::CommonBackends()->Load(configuration.BackendsType, params, configuration.UID);

        Y_ENSURE(Backends_);

        Backends_->ProcessPolicyFeatures(configuration.PolicyFeatures);
        Backends_->SetCheckParameters(configuration.CheckParameters);
    }

    // Builds a holder by updating an existing backends object: generates a config for
    // `instances` and asks `backends` to produce its updated counterpart.
    // Throws (Y_ENSURE) if the update produced no backends object.
    TGlobalBackendsHolder::TGlobalBackendsHolder(
        TBackendsInitializationParams::TPtr initializationParams,
        const TVector<TEndpointSetBackends>& instances,
        const NSrvKernel::IBackends& backends
    )
        : InitializationParams_(initializationParams)
        , BackendDescriptions_(instances)
    {
        // The generated config must stay alive while the module params refer to it.
        THolder<NConfig::IConfig> generatedConfig = CreateConfig();

        NSrvKernel::TModuleParams params = InitializationParams_->ModuleParams_.Copy(generatedConfig.Get());
        params.Modules = nullptr;

        Backends_ = backends.Update(params);

        Y_ENSURE(Backends_);
    }

    // Returns a new holder derived from the current backends for the given list, or null
    // when the current backends report that they cannot be updated.
    TGlobalBackendsHolder::TPtr TGlobalBackendsHolder::Update(const TVector<TEndpointSetBackends>& backends) {
        if (!Backends_->CanUpdate()) {
            return nullptr;
        }

        return TPtr(new TGlobalBackendsHolder(InitializationParams_, backends, Backends()));
    }

    // Parses the "endpoint_sets" section into EndpointSets_.
    //
    // Every container value starts a new endpoint set; the scalar options found while walking
    // that container fill in the most recently started set. Recognised options:
    //   endpoint_set_id / service_id -> ServiceId
    //   cluster_name / cluster_id    -> ClusterId
    //   backends_file                -> BackendsFile
    // Throws TConfigParseError for an unknown option, or for a recognised option that appears
    // before any endpoint set container has been opened.
    void TBackendsProvider::ParseEndpointSets(NConfig::IConfig& config) {
        class TParser: public NConfig::IConfig::IFunc {
        public:
            TVector<TEndpointSet> Result;
        private:
            // Returns the endpoint set currently being filled. Without this check, a scalar
            // option placed directly in the section would call back() on an empty vector.
            TEndpointSet& Current(const TString& key) {
                if (Result.empty()) {
                    ythrow NConfig::TConfigParseError() << "option " << key << " is set outside of an endpoint set";
                }
                return Result.back();
            }

            void DoConsume(const TString& key, NConfig::IConfig::IValue* value) final {
                if (value->IsContainer()) {
                    Result.push_back({});
                    value->AsSubConfig()->ForEach(this);
                } else if (key == "endpoint_set_id" || key == "service_id") {
                    Current(key).ServiceId = value->AsString();
                } else if (key == "cluster_name" || key == "cluster_id") {
                    Current(key).ClusterId = value->AsString();
                } else if (key == "backends_file") {
                    Current(key).BackendsFile = value->AsString();
                } else {
                    ythrow NConfig::TConfigParseError() << "unknown key " << key;
                }
            }
        } p;

        config.ForEach(&p);

        EndpointSets_.swap(p.Result);
    }

    void TBackendsProvider::OnUpdate(bool onlyReadinessChanged) {
        if (GetBackendsHolder()) {
            UpdateBackends(onlyReadinessChanged);
        }

        with_lock (Lock_) {
            WorkerRefs_.ForEach([](TWorkerBackendsRef* ref) {
                ref->OnUpdate();
            });
        }
    }

    void TBackendsProvider::DoConsume(const TString& key, NConfig::IConfig::IValue* value) {
        if (key == "proxy_options") {
            BackendsInitialization_->BackendsConfiguration_.ProxyOptions = Save(*value->AsSubConfig());
            return;
        }

        if (key == "proxy_wrapper") {
            BackendsInitialization_->BackendsConfiguration_.ProxyWrapper = Save(*value->AsSubConfig());
            return;
        }

        if (key == "port") {
            BackendsInitialization_->BackendsConfiguration_.Port = FromString<ui16>(value->AsString());
            return;
        }

        if (key == "port_offset") {
            BackendsInitialization_->BackendsConfiguration_.PortOffset = FromString<i32>(value->AsString());
            return;
        }

        if (key == "backends") {
            BackendsFromConfig_ = ParseBackends(*value->AsSubConfig());
            return;
        }

        if (key == "endpoint_sets") {
            ParseEndpointSets(*value->AsSubConfig());
            return;
        }

        if (key == "termination_delay") {
            TerminationDelay_ = FromString<TDuration>(value->AsString());
            return;
        }

        if (key == "termination_deadline") {
            TerminationDeadline_ = FromString<TDuration>(value->AsString());
            return;
        }

        ON_KEY("use_unistat", UseUnistat_) {
            return;
        } 

        if (CommonBackends()->SelectHandle(key)) {
            if (BackendsInitialization_->BackendsConfiguration_.BackendsType) {
                ythrow NConfig::TConfigParseError() << "try to setup second balancing algorithm";
            }

            BackendsInitialization_->BackendsConfiguration_.BackendsType = key;
            BackendsInitialization_->BackendsConfiguration_.BackendsOptions = Save(*value->AsSubConfig());
            return;
        }

        ythrow NConfig::TConfigParseError() << "unsupported option: " << key;
    }

    // Per-worker view of the provider's backends.
    //
    // Builds the initial worker holder from the provider's current global holder (with no
    // previous worker holder), starts the "sd_updater" coroutine on the worker's executor and
    // registers this object with the provider via AddWorker().
    TWorkerBackendsRef::TWorkerBackendsRef(NSrvKernel::IWorkerCtl* process, TBackendsProvider& backendsProvider)
        : Provider_(&backendsProvider)
        , Process_(process)
        , WorkerBackendsHolder_(MakeWorkerBackendsHolder(backendsProvider.GetBackendsHolder(), nullptr))
    {
        Updater_ = NSrvKernel::TCoroutine{
            ECoroType::Common,
            "sd_updater",
            &process->Executor(),
            [this](IWorkerCtl* proc) {
                auto* const cont = proc->Executor().Running();

                // Loop until the coroutine is cancelled: wait for an update signal, then Sync().
                while (!cont->Cancelled()) {
                    auto error = [&]() -> TError {
                        try {
                            // Waits on UpdateEvent_ with a deadline of TInstant::Max().
                            int res = UpdateEvent_.WaitD(TInstant::Max(), cont);
                            if (res == 0) {
                                Sync();
                            } else if (res != ECANCELED) {
                                // ECANCELED is not an error; the loop condition handles it.
                                ythrow TSystemError{res} << "failed to wait on UpdateEvent_";
                            }
                        // NOTE(review): presumably turns the listed exception types into a
                        // returned TError — confirm against the macro definition.
                        } Y_TRY_STORE(TSystemError, yexception)
                        return {};
                    }();

                    if (error) {
                        LastUpdateFailed_ = true;
                        Y_VERIFY_DEBUG(false, "%s", GetErrorMessage(error).c_str());
                        LastUpdateError_ = std::move(error);
                        // Back off for 5 seconds before waiting for the next signal.
                        cont->SleepT(TDuration::Seconds(5));
                    } else {
                        LastUpdateFailed_ = false;
                    }
                }
            }, process
        };

        Provider_->AddWorker(*this);
    }

    void TWorkerBackendsRef::Sync() {
        TGlobalBackendsHolderPtr next = Provider_->GetBackendsHolder();
        if (next.Get() == WorkerBackendsHolder_->GetBackendsHolder().Get()) {
            return;
        }

        auto prev = WorkerBackendsHolder_;
        WorkerBackendsHolder_ = MakeWorkerBackendsHolder(next, prev);
        prev->Invalidate(WorkerBackendsHolder_);
        prev->ScheduleTermination(Process_->Executor(), Provider_->GetTerminationDelay(), Provider_->GetTerminationDeadline());
    }

    // Signals UpdateEvent_, the event the "sd_updater" coroutine waits on before calling Sync().
    void TWorkerBackendsRef::OnUpdate() {
        UpdateEvent_.Signal();
    }

    // Unregisters this worker from the provider. Safe to call repeatedly: after the first
    // call Provider_ is null and further calls do nothing.
    void TWorkerBackendsRef::Dispose() {
        if (!Provider_) {
            return;
        }

        Provider_->RemoveWorker(*this);
        Provider_ = nullptr;
    }

    // Unregisters from the provider if Dispose() has not been called already.
    TWorkerBackendsRef::~TWorkerBackendsRef() {
        Dispose();
    }

    // Builds the per-worker holder for the global holder `next`.
    // For every backend descriptor: if `prev` already has a disposer registered for the same
    // descriptor object, that disposer is shared and the descriptor is not initialised
    // again; otherwise the descriptor is initialised for this worker and a new disposer is
    // registered. Finally the backends object itself is initialised for the worker.
    TWorkerBackendsHolder::TWorkerBackendsHolder(NSrvKernel::IWorkerCtl* process, TGlobalBackendsHolderPtr next, TWorkerBackendsHolder::TPtr prev)
        : GlobalBackendsHolder_(std::move(next))
        , Process_(process)
    {
        auto& globalBackends = GlobalBackendsHolder_->Backends();

        for (const auto& descriptor : globalBackends.BackendDescriptors()) {
            bool inherited = false;

            if (prev) {
                auto known = prev->Disposers_.find(descriptor.Get());
                if (known != prev->Disposers_.end()) {
                    Disposers_.emplace(known->first, known->second);
                    inherited = true;
                }
            }

            if (!inherited) {
                descriptor->Init(Process_);
                Disposers_.emplace(descriptor.Get(), MakeSimpleShared<TBackendDescriptorDisposer>(descriptor, Process_));
            }
        }

        globalBackends.Init(Process_);
    }

    // Disposes the backends for this worker and releases all disposers.
    // Runs at most once; later calls are no-ops.
    void TWorkerBackendsHolder::DisposeBackends() {
        if (!Disposed_) {
            GlobalBackendsHolder_->Backends().Dispose(Process_);
            Disposers_.clear();
            Disposed_ = true;
        }
    }

    // Appends `ref` to Refs_, the list of references attached to this holder.
    // Preconditions, enforced with Y_VERIFY: `ref` already points at this holder, and
    // ref.Empty() holds (presumably: it is not linked into any list yet — confirm).
    void TWorkerBackendsHolder::AddRef(TBackendsRef& ref) {
        Y_VERIFY(ref.IsStored(this));
        Y_VERIFY(ref.Empty());

        Refs_.PushBack(&ref);
    }

    // Asks every attached reference to move over to `next`.
    // A reference may unlink itself from Refs_ during the call, so the cursor is advanced
    // before the reference is invalidated.
    void TWorkerBackendsHolder::Invalidate(TWorkerBackendsHolder::TPtr next) {
        auto cursor = Refs_.begin();
        while (cursor != Refs_.end()) {
            TBackendsRef& current = *cursor;
            ++cursor;
            current.Invalidate(next);
        }
    }

    // Starts the "sd_terminator" coroutine, which cancels the coroutines holding references
    // that are still attached to this (outdated) holder. Does nothing when no references
    // are attached.
    //
    // The coroutine first sleeps for `delay`. It then spends up to max(deadline, 1) seconds
    // cancelling references in batches, one batch per second, and finally cancels whatever
    // is still attached. It stops early if one of its sleeps is cancelled.
    void TWorkerBackendsHolder::ScheduleTermination(TContExecutor& executor, TDuration delay, TDuration deadline) {
        if (Refs_.Empty()) {
            return;
        }

        Terminator_ = NSrvKernel::TCoroutine{
            ECoroType::Service,
            "sd_terminator",
            &executor,
            [this, delay, deadline](TContExecutor* const exec) {
                auto* const cont = exec->Running();
                if (cont->SleepT(delay) == ECANCELED) {
                    return;
                }

                // At least one round; batch size is ceil(number of refs / number of rounds).
                size_t deadlineSeconds = Max(deadline.Seconds(), 1LU);
                size_t group = (Refs_.Size() + deadlineSeconds - 1) / deadlineSeconds;

                for (size_t i = 0; i < deadlineSeconds; ++i) {
                    if (Refs_.Empty()) {
                        break;
                    }

                    // Cancel the first `group` references in list order.
                    // NOTE(review): every round starts from the head of Refs_, so spreading
                    // the cancellations relies on cancelled references leaving the list
                    // before the next round — confirm.
                    size_t j = 0;
                    for (auto& ref : Refs_) {
                        ref.CancelHoldingCoroutine();
                        if (++j >= group) {
                            break;
                        }
                    }

                    if (cont->SleepT(TDuration::Seconds(1)) == ECANCELED) {
                        return;
                    }
                }

                // Deadline reached: cancel everything that is still attached.
                for (auto& ref : Refs_) {
                    ref.CancelHoldingCoroutine();
                }
            }, &executor
        };
    }

    // Verifies that no reference is still attached to this holder, then disposes the
    // backends (a no-op if DisposeBackends() already ran).
    TWorkerBackendsHolder::~TWorkerBackendsHolder() {
        Y_VERIFY(Refs_.Empty());
        DisposeBackends();
    };

    // Returns true when this reference currently points at the holder `ptr`.
    bool TBackendsRef::IsStored(const TWorkerBackendsHolder* ptr) const noexcept {
        return Get() == ptr;
    }

    // Returns the backends of the referenced worker holder, or nullptr when this reference
    // is empty. Also clears CanInvalidate_, after which Invalidate() leaves this reference
    // untouched.
    NSrvKernel::IBackends* TBackendsRef::Backends() noexcept {
        CanInvalidate_ = false;

        if (*this) {
            return &Get()->Backends();
        }
        return nullptr;
    }

    // Returns the number of backends in the referenced holder.
    // NOTE(review): unlike Backends(), this does not check for an empty reference —
    // confirm that callers never invoke it on one.
    size_t TBackendsRef::BackendsCount() const noexcept {
        return Get()->Backends().Size();
    }

    // Forwards to IsHashing() of the referenced holder's backends.
    // NOTE(review): unlike Backends(), this does not check for an empty reference —
    // confirm that callers never invoke it on one.
    bool TBackendsRef::IsHashing() const noexcept {
        return Get()->Backends().IsHashing();
    }

    // IAlgorithm decorator that forwards every call to the wrapped algorithm and remembers
    // which backends it has handed out. While alive it is registered on the owning
    // TBackendsRef as its invalidate handler, so that the wrapped algorithm can be replaced
    // when the reference is moved to a newer holder.
    class TBackendsRef::TInvalidateHandler : public IAlgorithm {
    public:
        TInvalidateHandler(THolder<IAlgorithm> algorithm, TBackendsRef* ref)
            : IAlgorithm(algorithm->Process_)
            , Algorithm_(std::move(algorithm))
            , BackendsRef_(ref)
        {
            BackendsRef_->SetInvalidateHandler(this);
        }

        ~TInvalidateHandler() {
            BackendsRef_->SetInvalidateHandler(nullptr);
        }

        // Asks the backends of `next` for a replacement of the wrapped algorithm, passing the
        // backends handed out so far. Returns true and switches to the replacement when one
        // is produced; returns false and keeps the current algorithm otherwise.
        bool OnInvalidate(TWorkerBackendsHolder::TPtr next) {
            THolder<IAlgorithm> replacement = next->Backends().InvalidateAlgorithm(*Algorithm_, AccessedBackends_);
            if (!replacement) {
                return false;
            }

            Algorithm_ = std::move(replacement);
            return true;
        }

        // Plain forwarding, nothing is recorded.
        void RemoveSelected(IBackend* backend) override {
            Algorithm_->RemoveSelected(backend);
        }

        void Reset() override {
            Algorithm_->Reset();
        }

        void Select(IBackend* backend) override {
            Algorithm_->Select(backend);
        }

        // Forwarding that records the returned backend.
        IBackend* Next() override {
            return Remember(Algorithm_->Next());
        }

        IBackend* SelectNext() override {
            return Remember(Algorithm_->SelectNext());
        }

        IBackend* NextByName(TStringBuf name, bool allowZeroWeights) override {
            return Remember(Algorithm_->NextByName(name, allowZeroWeights));
        }

        IBackend* NextByHash(IHashProcessor& processor) override {
            return Remember(Algorithm_->NextByHash(processor));
        }
    private:
        // Records `backend` as handed out and returns it unchanged.
        IBackend* Remember(IBackend* backend) {
            AccessedBackends_.insert(backend);
            return backend;
        }

        THolder<IAlgorithm> Algorithm_;
        TBackendsRef* BackendsRef_;
        THashSet<IBackend*> AccessedBackends_;
    };

    // Constructs an algorithm from the referenced backends and wraps it in a
    // TInvalidateHandler bound to this reference. Only one handler may exist per
    // reference at a time (verified).
    THolder<NSrvKernel::IAlgorithm> TBackendsRef::ConstructAlgorithm(const NSrvKernel::TStepParams& params) noexcept {
        Y_VERIFY(!InvalidateHandler_);

        auto wrapped = Get()->Backends().ConstructAlgorithm(params);
        return MakeHolder<TInvalidateHandler>(std::move(wrapped), this);
    }

    // Moves this reference from its current holder to `next`, unless invalidation has been
    // disabled (CanInvalidate_ is false) or the active invalidate handler could not switch
    // its algorithm over to `next`.
    void TBackendsRef::Invalidate(TWorkerBackendsHolder::TPtr next) {
        if (!CanInvalidate_) {
            return;
        }

        // With no handler there is no algorithm to migrate, so the move is always allowed.
        const bool mayRebind = !InvalidateHandler_ || InvalidateHandler_->OnInvalidate(next);
        if (!mayRebind) {
            return;
        }

        // Leave the old holder's list, point at the new holder, then join its list.
        Unlink();
        Reset(next);
        next->AddRef(*this);
    }
}
