#include "router_impl.h"

#include <mail/ratesrv/src/common/format.h>
#include <mail/ratesrv/src/common/types.h>
#include <mail/ratesrv/src/instances/task.h>
#include <mail/ratesrv/src/logger.h>
#include <mail/ratesrv/src/scheduler/scheduler.h>

#include <yplatform/find.h>
#include <yplatform/coroutine.h>
#include <yplatform/ptree.h>

#include <yaml-cpp/yaml.h>

#include <boost/algorithm/string/join.hpp>

namespace NRateSrv::NRouter {

TRouterImpl::TRouterImpl(yplatform::reactor& reactor)
    : Reactor(reactor)
    , Ready(false)
    , HostsUpdateMethod(EHostsUpdateMethod::File)
{}

// Module initialisation entry point.
//
// Parses `configuration`, creates the network agent and calls Listen() on it,
// then performs the initial hosts load using the configured update method.
// Exceptions thrown by ReadConfiguration() propagate to the caller.
void TRouterImpl::Init(const yplatform::ptree& configuration) {
    ReadConfiguration(configuration);

    NetworkAgent = std::make_shared<TNetworkAgentImpl>(Reactor, logger(), Configuration);
    NetworkAgent->Listen();

    switch (HostsUpdateMethod) {
        case EHostsUpdateMethod::File:
            // Hosts come from the YAML file named in the configuration.
            ReadHostsConfiguration();
            break;
        case EHostsUpdateMethod::CloudApi:
            // Hosts are refreshed by a task registered in the scheduler.
            CloudHostlist();
            break;
    }
}

// Reload hook. The passed configuration tree is ignored; only the hosts file
// is re-read, and only when hosts are sourced from a file.
void TRouterImpl::Reload(const yplatform::ptree&) {
    if (HostsUpdateMethod != EHostsUpdateMethod::File) {
        return;
    }
    ReadHostsConfiguration();
}

// Removes from the "scheduler" module every scheduler group recorded in
// SchedulerGroupIds (populated by CloudHostlist()).
void TRouterImpl::Stop() {
    auto schedulerModule = yplatform::find<NScheduler::TScheduler>("scheduler");
    for (auto registeredGroupId : SchedulerGroupIds) {
        schedulerModule->RemoveGroup(registeredGroupId);
    }
}

// Builds a fresh TConfiguration from the property tree and selects the hosts
// update method.
//
// Every option falls back to the value a default-constructed TConfiguration
// already holds. Throws std::runtime_error when the node hash salt is empty,
// when a timeout is not strictly positive, or when "hosts.method" names an
// unknown method. "hosts.method" itself (and "hosts.file" for the "file"
// method) is read without a default.
void TRouterImpl::ReadConfiguration(const yplatform::ptree& configuration) {
    Configuration = std::make_shared<TConfiguration>();
    auto& settings = *Configuration;

    // Shared validation for the two timeout options.
    const auto requirePositive = [](const auto& timeout, const char* message) {
        if (timeout <= TConfiguration::TTimeout::zero()) {
            throw std::runtime_error(message);
        }
    };

    settings.SetHashSeed(configuration.get<ui64>("hash_seed", settings.GetHashSeed()));
    settings.SetBucketCount(configuration.get<size_t>("bucket_count", settings.GetBucketCount()));

    settings.SetBaseBanDuration(configuration.get("base_ban_duration", settings.GetBaseBanDuration()));
    settings.SetBanQuorum(configuration.get("ban_quorum", settings.GetBanQuorum()));

    settings.SetNodeHashSalt(configuration.get<std::string>("node_hash_salt", settings.GetNodeHashSalt()));
    if (settings.GetNodeHashSalt().empty()) {
        throw std::runtime_error("Node hash salt must be not empty");
    }

    settings.SetTimeout(configuration.get("request_timeout", settings.GetTimeout()));
    requirePositive(settings.GetTimeout(), "Request timeout must be greater than 0");

    settings.SetPingTimeout(configuration.get("ping_timeout", settings.GetPingTimeout()));
    requirePositive(settings.GetPingTimeout(), "Ping timeout must be greater than 0");

    settings.SetMaxAttemptsForNode(configuration.get("max_attempts_for_node", settings.GetMaxAttemptsForNode()));
    settings.SetMaxAttempts(configuration.get("max_attempts", settings.GetMaxAttempts()));

    // Source of the hosts list: a local file or the cloud API.
    const auto methodName = configuration.get<std::string>("hosts.method");
    if (methodName == "file") {
        HostsUpdateMethod = EHostsUpdateMethod::File;
        HostsFile = configuration.get<std::string>("hosts.file");
    } else if (methodName == "qloud-api") {
        HostsUpdateMethod = EHostsUpdateMethod::CloudApi;
    } else {
        throw std::runtime_error(Format("Unknown hosts update method: %1%", methodName));
    }

    // Keeps the current HostlistPeriod when "hosts.period" is not configured.
    HostlistPeriod = configuration.get("hosts.period", HostlistPeriod);
}

void TRouterImpl::ReadHostsConfiguration() {
    std::string localhost;
    std::vector<std::string> hosts;

    try {
        auto hostsNode = YAML::LoadFile(HostsFile);
        auto hostsNodeArr = hostsNode["hosts"];

        localhost = hostsNode["localhost"].as<std::string>();
        for (size_t i = 0; i < hostsNodeArr.size(); ++i) {
            hosts.push_back(hostsNodeArr[i].as<std::string>());
        }
    } catch (const std::exception& exp) {
        YLOG_L(error) << Format("Fail to read hosts list: %1%", exp.what());
        return;
    }

    SetHosts(localhost, std::move(hosts));
}

void TRouterImpl::CloudHostlist() {
    auto callback = [this, self = shared_from_this()](auto ec, auto response) {
        if (ec) {
            YLOG_L(error) << "Fail to get cloud instances: " << ec.message();
        } else {
            RATESRV_LOG_COUNTER_NOCTX(notice, "cluster_usage_percent", response.ClusterUsagePersent);
            SetHosts(response.LocalHost, std::move(response.Hosts));
        }
    };

    NScheduler::TGroupSettings groupSettings;
    groupSettings.Duration = HostlistPeriod;
    groupSettings.Policy = NScheduler::EExecutionPolicyWhenTaskAdding::WaitInLine;

    auto scheduler = yplatform::find<NScheduler::TScheduler>("scheduler");
    auto groupId = scheduler->CreateGroup(std::move(groupSettings));
    auto taskId = scheduler->AddTask(
        std::make_unique<NInstances::TTask>(std::move(callback), *Reactor.io()), groupId);
    SchedulerGroupIds.push_back(groupId);

    YLOG_L(info) << Format("Add repository hostlist task to scheduler, group id %1%, task id %2%", groupId, taskId);
}

// Applies a new hosts list: logs it, stores it in CurrentHosts, passes it to
// the network agent's InitNodeManager() and updates the readiness flag.
//
// localhost - name of the local host, forwarded to InitNodeManager().
// hosts     - full hosts list; taken by value and moved into CurrentHosts.
//
// Readiness ordering: the release-store of `true` happens only AFTER
// InitNodeManager() has returned, so an IsReady() caller that observes `true`
// also observes the completed initialisation. Previously the flag was raised
// before InitNodeManager() ran. An empty list still clears the flag up front.
// NOTE(review): presumably InitNodeManager() sets up what GetNodeManager()
// hands to workers in CreateWorker() — confirm against TNetworkAgentImpl.
void TRouterImpl::SetHosts(const std::string& localhost, std::vector<std::string> hosts) {
    YLOG_L(info) << Format(
        "Hosts list is reloading: [%1%], localhost is [%2%]",
         boost::algorithm::join(hosts, ", "),
         localhost);

    CurrentHosts = std::move(hosts);
    const bool hasHosts = !CurrentHosts.empty();

    if (!hasHosts) {
        // Nothing to route to: report "not ready" before touching the node manager.
        Ready.store(false, std::memory_order_release);
    }

    NetworkAgent->InitNodeManager(localhost, CurrentHosts);

    if (hasHosts) {
        // Publish readiness only once the node manager reflects the new list.
        Ready.store(true, std::memory_order_release);
    }
}

// Creates a TCountersWorker for a single counter request and spawns it via
// yplatform::spawn.
//
// ctx      - task context, moved into the worker.
// request  - counter request, moved into the worker.
// callback - completion callback, moved into the worker.
// increase - true selects ERequestMode::Increase, false ERequestMode::Get.
void TRouterImpl::CreateWorker(TTaskContextPtr ctx, TCounterRequest request, TCallback callback, bool increase) {
    const ERequestMode mode = increase ? ERequestMode::Increase : ERequestMode::Get;

    auto countersWorker = std::make_shared<TCountersWorker>(
        Reactor,
        Configuration,
        NetworkAgent,
        NetworkAgent->GetNodeManager(),
        std::move(ctx),
        std::move(request),
        mode,
        std::move(callback));

    yplatform::spawn(countersWorker);
}

// Starts a worker in ERequestMode::Get for `request`; all arguments are
// forwarded to CreateWorker().
void TRouterImpl::AsyncGet(TTaskContextPtr ctx, TCounterRequest request, TCallback callback) {
    constexpr bool increase = false;
    CreateWorker(std::move(ctx), std::move(request), std::move(callback), increase);
}

// Starts a worker in ERequestMode::Increase for `request`; all arguments are
// forwarded to CreateWorker().
void TRouterImpl::AsyncIncrease(TTaskContextPtr ctx, TCounterRequest request, TCallback callback) {
    constexpr bool increase = true;
    CreateWorker(std::move(ctx), std::move(request), std::move(callback), increase);
}

bool TRouterImpl::IsReady() {
    return Ready.load(std::memory_order_acquire);
}

} // namespace NRateSrv::NRouter
