#include "distribution_manager.h"

#include <infra/libs/yp_dns/dynamic_zones/helpers/yt/yt.h>
#include <infra/libs/sensors/macros.h>

#include <mapreduce/yt/util/ypath_join.h>

#include <util/generic/adaptor.h>

namespace NInfra::NController::NShardMaster {

namespace {

// Joins `directory`, the per-shard component "shard_id_<shardId>" and the
// "alive_hosts" leaf into a single path via NYT::JoinYPaths.
TString BuildLivenessDirectory(const TString& directory, size_t shardId) {
    TStringBuilder shardComponent;
    shardComponent << "shard_id_" << shardId;

    // Pass the component as a plain TString so JoinYPaths sees the same argument type as before.
    const TString& shardComponentName = shardComponent;
    return NYT::JoinYPaths(directory, shardComponentName, "alive_hosts");
}

// Joins `directory`, the per-shard component "shard_id_<shardId>" and the
// "given_job" leaf into a single path via NYT::JoinYPaths.
TString BuildGivenJobNode(const TString& directory, size_t shardId) {
    TStringBuilder shardComponent;
    shardComponent << "shard_id_" << shardId;

    // Pass the component as a plain TString so JoinYPaths sees the same argument type as before.
    const TString& shardComponentName = shardComponent;
    return NYT::JoinYPaths(directory, shardComponentName, "given_job");
}

} // anonymous namespace

////////////////////////////////////////////////////////////////////////////////

// Returns the path stored in the leading invader config; this path is used as the invader's name.
TString TDistributionManager::GetLeadingInvaderName() const {
    return LeadingInvaderConfig_.GetPath();
}

// Delegates to LeadingInvader_->EnsureLeading() under a read lock.
// If the leading invader is not set (e.g. after DestroyLeadingInvader()),
// an error with an explanatory message is returned instead.
TExpected<void, NLeadingInvader::TError> TDistributionManager::EnsureLeading() {
    TReadGuard guard(LeadingInvaderMutex_);
    if (LeadingInvader_) {
        return LeadingInvader_->EnsureLeading();
    }
    return NLeadingInvader::TError{TString("LeadingInvader has been destroyed")};
}

// Returns the leader info reported by the leading invader (read lock held).
// If the leading invader is not set, a TLeaderInfo with FAILED status and
// two empty strings is returned.
NLeadingInvader::TLeaderInfo TDistributionManager::GetLeaderInfo() const {
    using TLeaderInfo = NLeadingInvader::TLeaderInfo;

    TReadGuard guard(LeadingInvaderMutex_);
    if (LeadingInvader_) {
        return LeadingInvader_->GetLeaderInfo();
    }
    return TLeaderInfo{TLeaderInfo::EResolveLeaderStatus::FAILED, "", ""};
}

// Replaces the current leading invader with a freshly created one.
// The new instance is built from LeadingInvaderConfig_ and the two callbacks,
// which are forwarded to NLeadingInvader::CreateLeadingInvader as is
// (presumably invoked when the leader lock is acquired / lost — confirm against the invader API).
void TDistributionManager::ResetLeadingInvader(
    const std::function<void()>& onLockAcquired
    , const std::function<void()>& onLockLost
) {
    // Write lock: readers in EnsureLeading()/GetLeaderInfo() must not observe the pointer mid-replacement.
    TWriteGuard guard(LeadingInvaderMutex_);
    LeadingInvader_.Reset(NLeadingInvader::CreateLeadingInvader(LeadingInvaderConfig_, onLockAcquired, onLockLost));
}

// Destroys the current leading invader under the write lock.
// Afterwards EnsureLeading() returns an error and GetLeaderInfo() reports FAILED
// until ResetLeadingInvader() is called.
void TDistributionManager::DestroyLeadingInvader() {
    TWriteGuard guard(LeadingInvaderMutex_);
    LeadingInvader_.Destroy();
}

////////////////////////////////////////////////////////////////////////////////

// Returns a reference to the sensor group owned by this manager.
const TSensorGroup& TDistributionManager::GetSensorGroupRef() {
    return SensorGroup_;
}

// Resets the management duration sensor that ManageShardsDistribution() starts and updates.
// NOTE(review): presumably called by the owner once a whole management round is over — confirm against the caller.
void TDistributionManager::OnGlobalManagementFinish() {
    ManagementDurationSensor_.Reset();
}

// Parses the management interval, stored in the config as a string, into a TDuration.
// The string is re-parsed on every call.
TDuration TDistributionManager::GetManagementInterval() const {
    return FromString<TDuration>(ActualConfig_.GetManagementInterval());
}

// Returns the service name; it is attached to every log event emitted by this manager.
TString TDistributionManager::GetName() const {
    return ServiceName_;
}

// Reports value `x` for the rate sensor named `sensor` in this manager's sensor group.
void TDistributionManager::IncrementSensor(const TStringBuf sensor, ui64 x) {
    NON_STATIC_INFRA_RATE_SENSOR_X(GetSensorGroupRef(), sensor, x);
}

////////////////////////////////////////////////////////////////////////////////

// Runs one management iteration:
//   1. discovers alive hosts for every shard;
//   2. reports (log event + sensor) every shard whose set of alive hosts is empty;
//   3. reads the host currently assigned to every shard;
//   4. calls RebalanceShards() with the collected state.
// The duration sensor is started at the beginning and updated only when the
// iteration completes without an exception. Exceptions are propagated to the caller.
void TDistributionManager::ManageShardsDistribution(
    TLogFramePtr frame
) {
    ManagementDurationSensor_.Start();
    frame->LogEvent(NLogEvent::TManageShardsDistributionStart(GetName()));

    auto aliveHostsByShard = DiscoverAliveHostsPerShard(frame);

    for (const auto& [shardId, aliveHosts] : aliveHostsByShard) {
        if (!aliveHosts.empty()) {
            continue;
        }

        frame->LogEvent(
            ELogPriority::TLOG_ERR,
            NLogEvent::TShardWithoutAvailableHosts(GetName(), shardId)
        );
        IncrementSensor(NSensors::SHARD_WITHOUT_AVAILABLE_HOSTS, 1);
    }

    auto currentHostByShard = GetShardsCurrentHosts(frame);
    auto shardsByHost = GetListOfShardsPerHost(currentHostByShard, frame);

    try {
        RebalanceShards(aliveHostsByShard, currentHostByShard, shardsByHost, frame);
    } catch (...) {
        // Log the failure here and let the caller decide what to do with it.
        frame->LogEvent(
            ELogPriority::TLOG_ERR,
            NLogEvent::TRebalanceShardsError(GetName(), CurrentExceptionMessage())
        );
        throw;
    }

    ManagementDurationSensor_.Update();
}

////////////////////////////////////////////////////////////////////////////////

// For every shard id in [0, NumberOfShards) lists the locked nodes of the shard's
// "alive_hosts" directory and returns them as shardId -> set of host names.
// Any exception is logged as TDiscoverAliveHostsPerShardError and rethrown.
THashMap<size_t, THashSet<TString>> TDistributionManager::DiscoverAliveHostsPerShard(
    TLogFramePtr frame
) {
    try {
        frame->LogEvent(NLogEvent::TDiscoverAliveHostsPerShardStart(GetName()));

        const auto& cypressRoot = ActualConfig_.GetYtConfig().GetCypressRootPath();
        const size_t shardCount = ActualConfig_.GetNumberOfShards();

        THashMap<size_t, THashSet<TString>> aliveHostsByShard;
        for (size_t shardId = 0; shardId < shardCount; ++shardId) {
            const TString livenessDirectory = BuildLivenessDirectory(cypressRoot, shardId);
            aliveHostsByShard[shardId] = NYpDns::NYtHelpers::ListLockedNodes(YtClient_, livenessDirectory, frame);
        }

        frame->LogEvent(NLogEvent::TDiscoverAliveHostsPerShardSuccess(GetName()));
        return aliveHostsByShard;
    } catch (...) {
        frame->LogEvent(
            ELogPriority::TLOG_ERR,
            NLogEvent::TDiscoverAliveHostsPerShardError(GetName(), CurrentExceptionMessage())
        );
        throw;
    }
}

// For every shard id in [0, NumberOfShards) reads the shard's "given_job" node
// and returns shardId -> string value of that node (the host currently assigned to the shard).
// Any exception is logged as TGetShardsCurrentHostsError and rethrown.
THashMap<size_t, TString> TDistributionManager::GetShardsCurrentHosts(
    TLogFramePtr frame
) {
    try {
        frame->LogEvent(NLogEvent::TGetShardsCurrentHostsStart(GetName()));

        const auto& cypressRoot = ActualConfig_.GetYtConfig().GetCypressRootPath();
        const size_t shardCount = ActualConfig_.GetNumberOfShards();

        THashMap<size_t, TString> currentHostByShard;
        for (size_t shardId = 0; shardId < shardCount; ++shardId) {
            currentHostByShard[shardId] = NYpDns::NYtHelpers::TryGetNodeData(
                YtClient_,
                BuildGivenJobNode(cypressRoot, shardId),
                frame
            ).AsString();
        }

        frame->LogEvent(NLogEvent::TGetShardsCurrentHostsSuccess(GetName()));
        return currentHostByShard;
    } catch (...) {
        frame->LogEvent(
            ELogPriority::TLOG_ERR,
            NLogEvent::TGetShardsCurrentHostsError(GetName(), CurrentExceptionMessage())
        );
        throw;
    }
}

// Inverts the shardId -> host mapping into host -> list of shard ids.
// Shards whose host name is empty are not included in the result.
// `frame` is accepted for interface uniformity and is not used.
THashMap<TString, TVector<size_t>> TDistributionManager::GetListOfShardsPerHost(
    THashMap<size_t, TString> shardId2CurrentHost,
    TLogFramePtr frame
) {
    Y_UNUSED(frame);

    THashMap<TString, TVector<size_t>> shardsByHost;
    for (const auto& [shardId, assignedHost] : shardId2CurrentHost) {
        if (!assignedHost) {
            continue;
        }
        shardsByHost[assignedHost].push_back(shardId);
    }
    return shardsByHost;
}

// Performs at most one shard move per call.
//
// Pass 1: the first shard (in shard id order) that has alive hosts but whose current host
//         is empty or not among them is moved to its least loaded alive host.
// Pass 2: if no such shard exists, hosts are visited from the most loaded to the least loaded;
//         the first shard that can be moved to an alive host holding at least two shards fewer
//         than its current host is moved there.
// If neither pass moves anything, TOptimalDistributionReached is logged.
//
// Shards without alive hosts are skipped in both passes: there is nowhere to move them.
// Exceptions from MoveShard() and a failed Y_ENSURE propagate to the caller.
void TDistributionManager::RebalanceShards(
    THashMap<size_t, THashSet<TString>>& shardId2AliveHosts,
    THashMap<size_t, TString>& shardId2CurrentHost,
    THashMap<TString, TVector<size_t>>& host2ListOfShards,
    TLogFramePtr frame
) {
    frame->LogEvent(NLogEvent::TRebalanceShardsStart(GetName()));
    TMap<size_t, THashSet<TString>> currentWeight2ListOfHosts;
    for (const auto& [host, list] : host2ListOfShards) {
        currentWeight2ListOfHosts[list.size()].insert(host);
    }

    // Number of shards currently assigned to `host` (0 for hosts without shards).
    // Uses find() rather than operator[] so that the map is not modified while
    // one of its values is being iterated below.
    const auto getHostWeight = [&host2ListOfShards](const TString& host) -> size_t {
        const auto it = host2ListOfShards.find(host);
        return it == host2ListOfShards.end() ? 0 : it->second.size();
    };

    // Rebalance inactive shards at first
    for (size_t shardId = 0; shardId < ActualConfig_.GetNumberOfShards(); ++shardId) {
        auto& aliveHosts = shardId2AliveHosts[shardId];
        if (aliveHosts.empty()) {
            continue;
        }

        const TString& currentHost = shardId2CurrentHost[shardId];
        if (!currentHost || !aliveHosts.contains(currentHost)) {
            // aliveHosts is non-empty here, so the iterator is dereferenceable.
            const TString bestHost = *MinElementBy(aliveHosts, getHostWeight);

            Y_ENSURE(!bestHost.empty()); // Impossible situation because shard has at least one active host

            frame->LogEvent(NLogEvent::TRebalanceShardsFoundFreeShard(GetName(), shardId));
            MoveShard(shardId, currentHost, bestHost, frame);
            return;
        }
    }

    // Rebalance heavy hosts
    for (const auto& [weight, hosts] : Reversed(currentWeight2ListOfHosts)) {
        if (!weight) {
            break;
        }

        for (const auto& heavyHost : hosts) {
            for (auto shardId : host2ListOfShards.at(heavyHost)) {
                auto& aliveHosts = shardId2AliveHosts[shardId];
                if (aliveHosts.empty()) {
                    // A shard may still be assigned to a host while having no alive hosts at all
                    // (pass 1 skips such shards). MinElementBy would return end() for it,
                    // and dereferencing end() is undefined behaviour.
                    continue;
                }

                const TString bestHost = *MinElementBy(aliveHosts, getHostWeight);

                Y_ENSURE(!bestHost.empty()); // Impossible situation because the alive hosts set is not empty

                // `weight` is the number of shards on heavyHost. Move only if the difference is
                // at least two, otherwise the move would just swap the roles of the two hosts.
                // Written as an addition to stay clear of unsigned underflow.
                if (getHostWeight(bestHost) + 1 < weight) {
                    MoveShard(shardId, heavyHost, bestHost, frame);
                    return;
                }
            }
        }
    }

    frame->LogEvent(NLogEvent::TOptimalDistributionReached(GetName()));
}

// Assigns shard `shardId` to `newHost` by writing the host name into the shard's
// "given_job" node through YtClient_. `oldHost` is used for logging only.
// Start/success/error events are logged; any exception is rethrown after logging.
void TDistributionManager::MoveShard(
    size_t shardId,
    const TString& oldHost,
    const TString& newHost,
    TLogFramePtr frame
) {
    try {
        frame->LogEvent(NLogEvent::TMoveShardStart(GetName(), shardId, oldHost, newHost));

        YtClient_->Set(
            BuildGivenJobNode(ActualConfig_.GetYtConfig().GetCypressRootPath(), shardId),
            newHost
        );

        frame->LogEvent(NLogEvent::TMoveShardSuccess(GetName(), shardId, oldHost, newHost));
    } catch (...) {
        frame->LogEvent(
            ELogPriority::TLOG_ERR,
            NLogEvent::TMoveShardError(GetName(), shardId, oldHost, newHost, CurrentExceptionMessage())
        );
        throw;
    }
}

} // namespace NInfra::NController::NShardMaster
