#include "controller_loop.h"

#include <infra/libs/sensors/macros.h>
#include <infra/libs/memory_lock/memory_lock.h>
#include <yp/cpp/yp/client.h>
#include <yp/cpp/yp/token.h>
#include <yp/cpp/yp/yson_interop.h>

namespace NInfra::NController {

namespace {

// Per-controller state handed to one worker thread started by TControllerLoop::RunSyncLoop.
//
// The constructor registers a config-update callback that captures `this`, so an instance must
// stay at a fixed address while that callback can fire (RunSyncLoop reserves the vector up front
// so emplace_back never relocates existing elements).
// NOTE(review): whether the subscription is released when this object is destroyed depends on
// NUpdatableProtoConfig::TAccessor — confirm it cannot outlive the instance.
struct TThreadData {
    TThreadData(
        TControllerLoop* controllerLoop
        , TControllerPtr controller
        , NUpdatableProtoConfig::TAccessor<TControllerConfig> config
    )
        : ControllerLoop(controllerLoop)
        , Controller(std::move(controller))
        , Config(std::move(config))
        , ActualConfig(*Config)
    {
        // Refresh ActualConfig only for update requests tagged with this controller's factory name
        // (see RunControllerSyncLoop) and only when the config really changed.
        Config.SubscribeForUpdate([this](const TControllerConfig& oldConfig, const TControllerConfig& newConfig, const NUpdatableProtoConfig::TWatchContext& context = {}) {
            if (context.Id == Controller->GetFactoryName() && !google::protobuf::util::MessageDifferencer::Equivalent(oldConfig, newConfig)) {
                ActualConfig = newConfig;
            }
        });
    }

    // Back-pointer to the loop that created this entry; not owned.
    TControllerLoop* ControllerLoop;
    TControllerPtr Controller;
    NUpdatableProtoConfig::TAccessor<TControllerConfig> Config;
    // Latest config snapshot accepted for this controller.
    TControllerConfig ActualConfig;
};

// YP client request context that writes DEBUG events to a log frame around every request:
// TYpClientRequestStart before the request and TYpClientRequestEnd from the after-request hook.
class TTransactionContext: public NYP::NClient::TRequestContext {
public:
    explicit TTransactionContext(TLogFramePtr frame)
        : Frame_(std::move(frame))
    {
    }

    // Logs the request id, type name, text-YSON body and target address.
    // NOTE(review): the request is serialized to YSON on every call, regardless of the log level
    // in effect — consider skipping the serialization when DEBUG events are not wanted.
    void BeforeRequestCallback(const NYP::NClient::TRequestInfo& requestInfo) override {
        const TString yson = NYP::NClient::ObjectToYson(*requestInfo.RequestPtr, NYT::NYson::EYsonFormat::Text);
        Frame_->LogEvent(ELogPriority::TLOG_DEBUG, NLogEvent::TYpClientRequestStart(requestInfo.Id, requestInfo.RequestPtr->GetTypeName(), yson, requestInfo.Address));
    }

    // Logs only the id of the request.
    void AfterRequestCallback(const NYP::NClient::TRequestInfo& requestInfo) override {
        Frame_->LogEvent(ELogPriority::TLOG_DEBUG, NLogEvent::TYpClientRequestEnd(requestInfo.Id));
    }

private:
    TLogFramePtr Frame_;
};

} // namespace

// Builds the loop:
//  - moves the controller list into Controllers_;
//  - seeds the sequential-failure threshold from MaxSequentialFailedSyncLoopsCount
//    (read via CONFIG_SNAPSHOT_VALUE at construction time);
//  - creates the histogram sensors tagged with {ORIGIN, CONTROLLER_LOOP};
//  - pre-creates one MTP queue per controller factory name;
//  - applies NMemoryLock::LockSelfMemory with the initial config's MemoryLock settings;
//  - marks the loop as running.
TControllerLoop::TControllerLoop(
    NUpdatableProtoConfig::TAccessor<TControllerConfig> config
    , TLogger& logger
    , TVector<TControllerPtr> controllers
)
    : Config_(std::move(config))
    , Logger_(logger)
    , Controllers_(std::move(controllers))
    , SequentialFailsCounter_(CONFIG_SNAPSHOT_VALUE(Config_, GetController().GetMaxSequentialFailedSyncLoopsCount()))
    , HistogramSensors_(
        GenerateHistogramSensors(
            HISTOGRAMS_REQUEST_TYPES_FOR_CONTROLLER_LOOP
            , HISTOGRAMS_SENSOR_TYPES_FOR_CONTROLLER_LOOP
            , TVector<std::pair<TStringBuf, TStringBuf>>({
                {ORIGIN, CONTROLLER_LOOP}
            })
        )
    )
{
    // Default-construct every per-factory queue now, so later code can look them up with .at().
    for (const auto& controller : Controllers_) {
        MtpQueues_[controller->GetFactoryName()];
    }

    const auto initialConfig = Config_.Get();
    NMemoryLock::LockSelfMemory(initialConfig->GetMemoryLock(), Logger_.SpawnFrame(), CTL_SENSOR_GROUP);
    AtomicSet(Running_, 1);
}

// Log reopening is delegated entirely to the config accessor (RequestReopenLogs).
void TControllerLoop::ReopenLogs() {
    Config_.RequestReopenLogs();
}

// Reports whether leading-invader locks are currently held.
//  - With a name: true iff that lock is tracked and its flag is set.
//  - Without a name: true iff at least one lock is tracked and every tracked flag is set.
// NOTE(review): LockAcquired_ entries are inserted with operator[] from the controller threads'
// lock callbacks; confirm that iterating/looking up here concurrently is safe.
bool TControllerLoop::LockAcquired(const TMaybe<TString>& shardLeadingInvaderName) const {
    if (shardLeadingInvaderName.Defined()) {
        const auto* lockFlag = LockAcquired_.FindPtr(shardLeadingInvaderName.GetRef());
        return lockFlag != nullptr && AtomicGet(*lockFlag);
    }

    bool allHeld = !LockAcquired_.empty();
    for (const auto& entry : LockAcquired_) {
        if (!AtomicGet(entry.second)) {
            allHeld = false;
            break;
        }
    }
    return allHeld;
}

void TControllerLoop::RunSyncLoop() {
    Config_.RequestUpdate(NUpdatableProtoConfig::TWatchContext());
    const TControllerConfig actualConfig = *Config_;

    TVector<TThreadData> controllerThreadsData;
    controllerThreadsData.reserve(Controllers_.size());
    for (TControllerPtr controller : Controllers_) {
        MtpQueues_.at(controller->GetFactoryName()).Start(actualConfig.GetController().GetThreadPoolSize());
        controllerThreadsData.emplace_back(this, controller, Config_);
    }

    if (actualConfig.GetController().GetAuxThreadPoolSize() == 0) {
        AuxMtpQueue_.Reset(new TFakeThreadPool());
    } else {
        AuxMtpQueue_.Reset(new TThreadPool());
    }
    AuxMtpQueue_->Start(actualConfig.GetController().GetAuxThreadPoolSize());

    TVector<TSimpleSharedPtr<TThread>> controllerThreads(controllerThreadsData.size());
    for (size_t i = 0; i < controllerThreadsData.size(); ++i) {
        controllerThreads[i] = MakeSimpleShared<TThread>(TControllerLoop::RunControllerSyncLoop, (void*)&controllerThreadsData[i]);
    }

    for (auto controllerThread : controllerThreads) {
        controllerThread->Start();
    }

    for (auto controllerThread : controllerThreads) {
        controllerThread->Join();
    }

    AuxMtpQueue_->Stop();

    for (TControllerPtr controller : Controllers_) {
        MtpQueues_.at(controller->GetFactoryName()).Stop();
        controller->OnGlobalSyncFinish();
    }
}

// Returns the leader info of the controller whose full leading-invader name matches.
// Without a name, the request is only answerable when the loop holds exactly one controller;
// that controller's name is then used. Otherwise, or when no controller matches, the result
// has status FAILED and empty fields.
NLeadingInvader::TLeaderInfo TControllerLoop::GetLeaderInfo(TMaybe<TString> shardLeadingInvaderName) const {
    const auto unresolved = []() {
        return NLeadingInvader::TLeaderInfo{NLeadingInvader::TLeaderInfo::EResolveLeaderStatus::FAILED, "", ""};
    };

    if (!shardLeadingInvaderName.Defined()) {
        if (Controllers_.size() != 1) {
            return unresolved();
        }
        shardLeadingInvaderName = Controllers_[0]->GetFullLeadingInvaderName();
    }

    const TString& wantedName = shardLeadingInvaderName.GetRef();
    for (size_t idx = 0; idx < Controllers_.size(); ++idx) {
        if (Controllers_[idx]->GetFullLeadingInvaderName() == wantedName) {
            return Controllers_[idx]->GetLeaderInfo();
        }
    }

    return unresolved();
}

// Performs one sync cycle of `controller`:
//  1. rebuilds balancing on every YP client;
//  2. runs controller->Sync() with the controller's MTP queue and the shared aux pool;
//  3. on success, resets the controller's failure streak and updates the sync sensors;
//  4. sleeps until at least SleepUntilSyncTime has passed since the cycle started.
// Exceptions from the controller are not caught here; they propagate to the caller.
void TControllerLoop::SyncController(
    TControllerPtr controller
    , const TControllerConfig& actualConfig
    , const THashMap<TString, NYP::NClient::TClientPtr>& clients
    , const THashMap<TString, NYP::NClient::TTransactionFactoryPtr>& transactionFactories
) {
    auto cycleFrame = Logger_.SpawnFrame();
    for (const auto& clusterEntry : clients) {
        auto sensorContext = MakeHolder<TSensorContext>(
            cycleFrame
            , HistogramSensors_.at(RECONSTRUCT_BALANCING)
        );
        clusterEntry.second->ReconstructBalancing({}, std::move(sensorContext));
    }

    controller->IncrementFactorySensor("sync_cycles", 1);

    const TInstant cycleStart = TInstant::Now();
    const TDuration minCycleDuration = FromString<TDuration>(actualConfig.GetController().GetSleepUntilSyncTime());
    const TInstant wakeUpAt = cycleStart + minCycleDuration;
    cycleFrame->LogEvent(NLogEvent::TStartSyncCycle(controller->GetFactoryName()));

    controller->Sync(
        clients
        , transactionFactories
        , MtpQueues_.at(controller->GetFactoryName())
        , *AuxMtpQueue_
        , cycleFrame
        , actualConfig.GetController().GetRetryCount()
        , actualConfig.GetController().GetYpRequestsBatchSize()
        , actualConfig.GetController().GetMaxSequentialFailedObjMngrsCount()
    );

    SequentialFailsCounter_.FeedbackSuccess(controller->GetFactoryName());
    controller->IncrementFactorySensor("successful_sync_cycles", 1);

    const auto elapsed = TInstant::Now() - cycleStart;
    const TString elapsedText = ToString(elapsed);
    cycleFrame->LogEvent(NLogEvent::TSyncCycleSuccess(elapsedText, controller->GetFactoryName(), controller->GetShardId()));
    cycleFrame->LogEvent(NLogEvent::TSleepUntilSyncTime(elapsedText, actualConfig.GetController().GetSleepUntilSyncTime(), controller->GetFactoryName(), controller->GetShardId()));

    SleepUntil(wakeUpAt);
}

// Drives a single controller on the calling thread until Shutdown() clears Running_.
//
//  1. A controller managed by master first registers its liveness, retrying (with a pause of
//     FollowerLoopInterval) until registration succeeds.
//  2. Each iteration fetches the current config, applies master distribution to the leading
//     invader and tries to take the lock. If the lock was not taken and the controller is not
//     responsible for it, the loop waits FollowerLoopInterval and retries.
//  3. Otherwise YP clients are built from:
//       - the controller's own client configs, or
//       - the single YpClient, or
//       - the YpClients list.
//     SyncController() is then repeated with freshly created clients (pinned to a newly
//     generated snapshot timestamp unless the config fixes one) while all of these hold:
//       - the controller leads or is responsible for the lock;
//       - no config update was reported;
//       - the failure threshold is not exceeded.
//  4. Any exception from step 3 counts as a failed sync cycle. Once the failure threshold is
//     exceeded, the leading invader is destroyed, the loop sleeps GlobalLoopInterval and
//     SequentialFailsCounter_.DoubleThresholdAndReset() is applied.
//
// `getActualConfig` returns the controller's current config snapshot and whether the config
// update request reported a change.
void TControllerLoop::ControllerSyncLoop(
    TControllerPtr controller
    , TGetActualConfig getActualConfig
) {
    // The lock callbacks below are handed to the controller (leading invader and liveness
    // registration). The liveness registration is never torn down in this function, so those
    // callbacks may outlive this stack frame. Capture the controller object itself instead of a
    // reference to the local `controller` smart pointer, which would dangle after return. The
    // pointee is owned by the caller (Controllers_ when started via RunSyncLoop), like `this`.
    auto* const controllerObj = &*controller;

    const auto onLockAcquired = [this, controllerObj]() {
        AtomicSet(LockAcquired_[controllerObj->GetFullLeadingInvaderName()], 1);
        NON_STATIC_INFRA_INT_GAUGE_SENSOR(controllerObj->GetSensorGroupRef(), "lock_acquired", 1);
    };

    const auto onLockLost = [this, controllerObj]() {
        AtomicSet(LockAcquired_[controllerObj->GetFullLeadingInvaderName()], 0);
        NON_STATIC_INFRA_INT_GAUGE_SENSOR(controllerObj->GetSensorGroupRef(), "lock_acquired", 0);
    };

    const auto onLivenessLockAcquired = [this, controllerObj]() {
        AtomicSet(LivenessLockAcquired_[controllerObj->GetFullLeadingInvaderName()], 1);
        NON_STATIC_INFRA_INT_GAUGE_SENSOR(controllerObj->GetSensorGroupRef(), "liveness_lock_acquired", 1);
    };

    const auto onLivenessLockLost = [this, controllerObj]() {
        AtomicSet(LivenessLockAcquired_[controllerObj->GetFullLeadingInvaderName()], 0);
        NON_STATIC_INFRA_INT_GAUGE_SENSOR(controllerObj->GetSensorGroupRef(), "liveness_lock_acquired", 0);
    };

    // Applies master distribution to the leading invader. When master says this task must be
    // aborted, the invader is destroyed. Otherwise it is (re)installed with the lock callbacks;
    // `ignoreIfSet` presumably keeps an already existing invader.
    const auto refreshLeadingInvader = [&controller, &onLockAcquired, &onLockLost](auto& frame) {
        if (controller->ShouldAbortTaskDueToMasterDistribution(frame)) {
            controller->DestroyLeadingInvader();
        } else {
            controller->ResetLeadingInvader(onLockAcquired, onLockLost, true /*ignoreIfSet*/);
        }
    };

    auto outerFrame = Logger_.SpawnFrame();
    while (AtomicGet(Running_) == 1 &&
        controller->IsManagedByMaster() &&
        !controller->RegisterLiveness(outerFrame, onLivenessLockAcquired, onLivenessLockLost)
    ) {
        // The controller could not register its liveness (YT requests may fail). A controller
        // managed by master cannot do anything without it, so keep retrying — but pause between
        // attempts instead of hammering YT in a tight loop.
        Sleep(FromString<TDuration>(getActualConfig().first.GetController().GetFollowerLoopInterval()));
    }

    while (AtomicGet(Running_) == 1) {
        while (AtomicGet(Running_) == 1 && !SequentialFailsCounter_.IsThresholdExceeded(controller->GetFactoryName(), controller->GetNumberOfShards())) {
            auto [actualConfig, _] = getActualConfig();
            auto frame = Logger_.SpawnFrame();

            // The controller makes sure that master gave it permission to run a sync cycle.
            // If no master is alive, the default LeadingInvader logic is applied.
            refreshLeadingInvader(frame);

            if (auto result = controller->EnsureLeading(); !(bool)result) {
                frame->LogEvent(
                    ELogPriority::TLOG_WARNING
                    , NLogEvent::TLockAcquireError(controller->GetFullLeadingInvaderName(), result.Error().Reason)
                );

                if (!controller->IsResponsibleForLock()) {
                    Sleep(FromString<TDuration>(actualConfig.GetController().GetFollowerLoopInterval()));
                    continue;
                }
            } else {
                frame->LogEvent(NLogEvent::TLockAcquireSuccess(controller->GetFullLeadingInvaderName()));
            }

            try {
                // Client configs: controller-provided ones take precedence, then the single
                // YpClient, then the YpClients list.
                TVector<NInfra::NController::TClientConfig> configs;
                if (auto customConfigs = controller->GetYpClientConfigs(); customConfigs) {
                    configs = std::move(*customConfigs);
                } else if (actualConfig.HasYpClient()) {
                    configs.push_back(actualConfig.GetYpClient());
                } else {
                    configs.reserve(actualConfig.GetYpClients().ConfigsSize());
                    for (const auto& config : actualConfig.GetYpClients().GetConfigs()) {
                        configs.push_back(config);
                    }
                }
                Y_ENSURE(!configs.empty(), "Clients config is empty");

                // One auxiliary client per cluster. These serve as option templates and
                // timestamp generators for the per-cycle clients created below.
                THashMap<TString, NYP::NClient::TClientPtr> auxClients;
                THashSet<TString> clusterNames;
                for (const auto& ypClientConfig : configs) {
                    if (!clusterNames.insert(ypClientConfig.GetClusterName()).second) {
                        ythrow yexception() << "Cluster names are not unique in config, duplicated name: " << ypClientConfig.GetClusterName();
                    }
                    NYP::NClient::TClientOptions ypOpts;
                    auto defaultSnapshotTimestamp = ypClientConfig.GetSnapshotTimestamp();
                    ypOpts
                        .SetAddress(ypClientConfig.GetAddress())
                        .SetTimeout(FromString<TDuration>(ypClientConfig.GetTimeout()))
                        .SetToken(NYP::NClient::FindToken())
                        .SetEnableSsl(ypClientConfig.GetEnableSsl())
                        .SetMaxReceiveMessageSize(ypClientConfig.GetMaxReceiveMessageSize())
                        .SetReadOnlyMode(ypClientConfig.GetReadOnlyMode())
                        .SetSnapshotTimestamp(defaultSnapshotTimestamp)
                        .SetThreadsNum(ypClientConfig.GetThreadPoolSize())
                        .SetDefaultRequestContextGenerator([frame]() {
                            return MakeHolder<TTransactionContext>(frame);
                        });

                    auxClients[ypClientConfig.GetClusterName()] = NYP::NClient::CreateClient(ypOpts);
                }

                refreshLeadingInvader(frame);

                auto isLeader = (bool)controller->EnsureLeading();
                while (AtomicGet(Running_) == 1 && (controller->IsResponsibleForLock() || isLeader) &&
                        !SequentialFailsCounter_.IsThresholdExceeded(controller->GetFactoryName(), controller->GetNumberOfShards())) {
                    // Leave the inner loop on a config update so clients are rebuilt from the new config.
                    auto [_, hasUpdate] = getActualConfig();
                    if (hasUpdate) {
                        frame->LogEvent(ELogPriority::TLOG_DEBUG, NLogEvent::TInnerSyncCycleBreakAfterConfigUpdate(controller->GetFactoryName()));
                        break;
                    }

                    THashMap<TString, NYP::NClient::TClientPtr> clients;
                    THashMap<TString, NYP::NClient::TTransactionFactoryPtr> transactionFactories;

                    for (const auto& ypClientConfig : configs) {
                        const TString& clusterName = ypClientConfig.GetClusterName();
                        NYP::NClient::TClientOptions ypOpts = auxClients[clusterName]->Options();
                        auto defaultSnapshotTimestamp = ypClientConfig.GetSnapshotTimestamp();
                        if (!defaultSnapshotTimestamp) {
                            // No fixed snapshot timestamp configured: pin this cycle to a fresh one.
                            ypOpts.SetSnapshotTimestamp(auxClients[clusterName]->GenerateTimestamp(
                                MakeHolder<TSensorContext>(
                                    frame
                                    , HistogramSensors_.at(GENERATE_TIMESTAMP)
                                )
                            ).GetValue(TDuration::Seconds(60)));
                        }

                        auto client = NYP::NClient::CreateClient(ypOpts);
                        clients[clusterName] = client;
                        NYP::NClient::TTransactionOptions factoryOptions;
                        if (actualConfig.GetStartTransactionFromSnapshotTimestamp()) {
                            factoryOptions
                                .SetStartTimestamp(client->Options().SnapshotTimestamp())
                                .SetSnapshotIsolation(true);
                        }
                        transactionFactories[clusterName] = NYP::NClient::CreateTransactionFactory(*client, factoryOptions);
                    }

                    controller->SetLeadership(isLeader);

                    SyncController(
                        controller
                        , actualConfig
                        , clients
                        , transactionFactories
                    );

                    frame->LogEvent(ELogPriority::TLOG_DEBUG
                        , NLogEvent::TCacheSize(controller->GetCacheSize())
                    );

                    frame->LogEvent(ELogPriority::TLOG_DEBUG
                        , NLogEvent::TCachedMatchersSize(controller->GetCachedMatchersSize())
                    );

                    frame->LogEvent(
                        ELogPriority::TLOG_DEBUG
                        , NLogEvent::TSleep(
                            "Regular"
                            , ToString(controller->GetSyncInterval())
                            , controller->GetFactoryName()
                        )
                    );
                    Sleep(controller->GetSyncInterval());

                    refreshLeadingInvader(frame);
                    isLeader = (bool)controller->EnsureLeading();
                }

                if (!controller->IsResponsibleForLock() && !isLeader) {
                    SensorRegistry().Reset();
                }
            } catch (...) {
                SequentialFailsCounter_.FeedbackFailure(controller->GetFactoryName());
                controller->IncrementFactorySensor("failed_sync_cycles", 1);
                frame->LogEvent(
                    ELogPriority::TLOG_ERR
                    , NLogEvent::TSyncLoopError(CurrentExceptionMessage(), controller->GetFactoryName())
                );
            }

            if (AtomicGet(Running_) == 1) {
                Sleep(FromString<TDuration>(actualConfig.GetController().GetLeaderLoopInterval()));
            }
        }

        controller->DestroyLeadingInvader();

        if (AtomicGet(Running_) == 1) {
            // Too many errors or too long sync cycles, so release the lock and sleep
            Logger_.SpawnFrame()->LogEvent(
                ELogPriority::TLOG_ERR
                , NLogEvent::TSequentialFailsCountThresholdExceeded()
            );

            auto [actualConfig, _] = getActualConfig();
            Sleep(FromString<TDuration>(actualConfig.GetController().GetGlobalLoopInterval()));

            SequentialFailsCounter_.DoubleThresholdAndReset(controller->GetFactoryName(), controller->GetNumberOfShards());
        }
    }
}

void* TControllerLoop::RunControllerSyncLoop(
    void* data
) {
    TThreadData* threadData = (TThreadData*)data;

    TGetActualConfig getActualConfig = [&]() -> std::pair<const TControllerConfig&, bool> {
        bool hasUpdate = threadData->Config.RequestUpdate({threadData->Controller->GetFactoryName()});
        return {threadData->ActualConfig, hasUpdate};
    };

    threadData->ControllerLoop->ControllerSyncLoop(
        threadData->Controller,
        std::move(getActualConfig)
    );
    return nullptr;
}

// Requests a stop: the controller sync loops poll Running_ and leave their loops once it is 0.
// Returns immediately; the worker threads are joined by RunSyncLoop, not here.
void TControllerLoop::Shutdown() {
    AtomicSet(Running_, /* running = */ 0);
}

// Starts with an empty failure history. `maxFailsCount` is the initial number of consecutive
// failures tolerated; a streak must be strictly greater than it to exceed the threshold.
TControllerLoop::TSequentialFailsCounter::TSequentialFailsCounter(const ui32 maxFailsCount)
    : MaxFailsCount_(maxFailsCount)
{
}

void TControllerLoop::TSequentialFailsCounter::FeedbackSuccess(const TString& factoryName) {
    TWriteGuardBase<TLightRWLock> guard(Lock_);
    SequentialFailsCount_[factoryName] = 0;
}

void TControllerLoop::TSequentialFailsCounter::FeedbackFailure(const TString& factoryName) {
    TWriteGuardBase<TLightRWLock> guard(Lock_);
    ++SequentialFailsCount_[factoryName];
}

bool TControllerLoop::TSequentialFailsCounter::IsThresholdExceeded(const TString& factoryName, const size_t numberOfShards) const {
    TReadGuardBase<TLightRWLock> guard(Lock_);
    if (numberOfShards > 1) {
        return SequentialFailsCount_.contains(factoryName) && SequentialFailsCount_.at(factoryName) > MaxFailsCount_;
    }

    for (auto&& [k, v] : SequentialFailsCount_) {
        if (v > MaxFailsCount_) {
            return true;
        }
    }

    return false;
}

// Backs off after the failure threshold was exceeded, then clears failure history.
//
// Threshold: doubled while that cannot overflow. A threshold of 0 is raised to 1, because
// doubling 0 would keep it at 0 forever and the back-off would never take effect. The threshold
// is shared by all factories.
//
// History: in sharded mode (numberOfShards > 1) only the streak of `factoryName` is cleared;
// otherwise all streaks are dropped.
void TControllerLoop::TSequentialFailsCounter::DoubleThresholdAndReset(const TString& factoryName, const size_t numberOfShards) {
    TWriteGuardBase<TLightRWLock> guard(Lock_);
    if (MaxFailsCount_ == 0) {
        MaxFailsCount_ = 1;
    } else if (MaxFailsCount_ < Max<decltype(MaxFailsCount_)>() / 2) {
        MaxFailsCount_ *= 2;
    }

    if (numberOfShards > 1) {
        SequentialFailsCount_[factoryName] = 0;
    } else {
        SequentialFailsCount_.clear();
    }
}

} // namespace NInfra::NController
