#include "network_agent_impl.h"
#include "messenger_request_handler.h"
#include "util.h"

#include <mail/ratesrv/src/common/format.h>

#include <ymod_messenger/module.h>

#include <yplatform/coroutine.h>
#include <yplatform/find.h>
#include <yplatform/time_traits.h>

#include <util/random/random.h>

#include <algorithm>
#include <iterator>
#include <set>

namespace NRateSrv::NRouter {

using namespace yplatform::time_traits;

// Builds the agent around the supplied reactor, logger and configuration and
// installs an initial node manager built from an empty host list.
//
// Note that InitNodeManager still looks up the "messenger" module and throws
// std::runtime_error if no port can be extracted from its address, so the
// constructor can throw as well.
TNetworkAgentImpl::TNetworkAgentImpl(
    yplatform::reactor& reactor,
    const yplatform::log::source& logger,
    TConfigurationPtr configuration)
    : yplatform::log::contains_logger(logger)
    , Reactor(reactor)
    , Configuration(std::move(configuration))
{
    // No local host name and no peers are known at construction time.
    const std::string noLocalHost;
    InitNodeManager(noLocalHost, std::vector<std::string>{});
}

// Rebuilds the node manager for the given cluster membership.
//
//   localhost - host name (without port) of this instance; the entry of
//               `hosts` equal to it is flagged as local.
//   hosts     - host names (without port) of the cluster members.
//
// Every host is suffixed with ":<port>", where the port is extracted from the
// messenger's own address. The messenger is then asked to connect to the
// resulting set of addresses and a freshly built TNodeManager replaces the
// current one under NodeManagerLock.
//
// Throws std::runtime_error when no port can be extracted from the messenger
// address.
void TNetworkAgentImpl::InitNodeManager(const std::string& localhost, std::vector<std::string> hosts) {
    auto messenger = yplatform::find<ymod_messenger::module>("messenger");
    // Query the address once so the error message reports exactly the value
    // the port was parsed from.
    const auto myRealAddress = messenger->my_address();
    const auto port = ExtractPort(myRealAddress);
    if (port.empty()) {
        throw std::runtime_error(Format("Fail to extract port from address %1%", myRealAddress));
    }

    TNodeManager::TNodes nodes;
    std::set<std::string> orderedHosts;

    // `hosts` is an owned copy, so iterate by mutable reference: this lets the
    // host string really be moved from (std::move on a const reference would
    // silently fall back to a copy).
    for (auto& host : hosts) {
        // Compare before the move below leaves `host` in an unspecified state.
        const bool isLocal = host == localhost;
        std::string hostWithPort = std::move(host);
        hostWithPort += ':';
        hostWithPort += port;
        orderedHosts.insert(hostWithPort);
        nodes.emplace_back(std::move(hostWithPort), isLocal);
    }

    messenger->connect_to_cluster(orderedHosts);

    // `nodeManager` is declared before `guard`, so after the swap the previous
    // manager is released only once the guard has already been destroyed.
    auto nodeManager = std::make_shared<TNodeManager>(std::move(nodes));
    auto guard = Guard(NodeManagerLock);
    std::swap(NodeManager, nodeManager);
}

// Registers all messenger handlers of the agent: first the request handler
// for the non-increase mode, then the one for the increase mode, and finally
// the handler for responses.
void TNetworkAgentImpl::Listen() {
    for (const bool increase : {false, true}) {
        BindRequestMessages(increase);
    }
    BindResponseMessages();
}

// Subscribes to incoming request messages of one kind.
//
//   increase - true binds EMessageTypes::IncreaseRequest and runs handlers in
//              ERequestMode::Increase; false binds EMessageTypes::GetRequest
//              and runs handlers in ERequestMode::Get.
//
// For every received message a TMessengerRequestHandler is created and spawned
// on the reactor's executor; the completion callback handed to it sends the
// storage response back to the sender under the original request id.
void TNetworkAgentImpl::BindRequestMessages(bool increase) {
    EMessageTypes type = increase ? EMessageTypes::IncreaseRequest : EMessageTypes::GetRequest;
    auto messenger = yplatform::find<ymod_messenger::module>("messenger");
    messenger->bind_messages<TMessengerRequest>(
        // `agent` keeps this object alive for as long as the binding exists,
        // which is what makes the captured `this` and `reactor` safe to use.
        [agent = shared_from_this(), this, increase, &reactor = Reactor]
            (const std::string& address, TMessengerRequest req)
        {
            // Remember the id before `req` is moved into the handler.
            auto requestId = req.RequestId;
            auto handler = std::make_shared<TMessengerRequestHandler>(
                address,
                std::move(req),
                increase ? ERequestMode::Increase : ERequestMode::Get,
                // The keep-alive pointer is deliberately COPIED: this outer
                // lambda is invoked once per message and must keep its own
                // reference. (The capture is const here, so std::move on it
                // would only have been a disguised copy anyway.)
                [agent, this, requestId, address](NStorage::TResponse response) {
                    SendResponse(address, requestId, std::move(response));
                });
            yplatform::spawn(reactor.io()->get_executor(), handler);
        },
        GetMessageType(type));
}

void TNetworkAgentImpl::BindResponseMessages() {
    auto messenger = yplatform::find<ymod_messenger::module>("messenger");
    messenger->bind_messages<TMessengerResponse>(
        [agent = shared_from_this(), this](const std::string&, TMessengerResponse res) {
            auto callback = CompleteRequest(res.RequestId,  true);
            if (callback) {
                callback(res.RequestId, std::move(res.StorageResponse));
            }
        },
        GetMessageType(EMessageTypes::Response));
}

// Registers a pending request for node `nodeNum` of `nodeManager` and sends it
// through the messenger as a message of kind `type`.
//
// Returns the id of the registered request, or 0 when sending threw a
// std::exception. In the failure case the error is logged, the pending request
// is completed as unsuccessful and its callback is dropped without being
// invoked.
ui64 TNetworkAgentImpl::SendRequest(
    TNodeManagerPtr nodeManager,
    TTaskContextPtr ctx,
    size_t nodeNum,
    NStorage::TRequest request,
    EMessageTypes type,
    TCallback callback)
{
    // Look the node up before ownership of the manager moves into the pending
    // request record.
    const auto& target = nodeManager->Get(nodeNum);
    const auto pendingId =
        CreateRequest(nodeNum, std::move(nodeManager), ctx->uniq_id(), request, type, std::move(callback));

    try {
        yplatform::find<ymod_messenger::module>("messenger")->send(
            target.GetAddress(),
            TMessengerRequest{pendingId, ctx->uniq_id(), std::move(request)},
            GetMessageType(type));
        return pendingId;
    } catch (const std::exception& exp) {
        YLOG_CTX_LOCAL(ctx, error) <<
            Format("Fail to send request to node %1%, error: %2%", target.GetAddress(), exp.what());
        CompleteRequest(pendingId, false);
    }

    return 0;
}

void TNetworkAgentImpl::SendResponse(const std::string& address, ui64 requestId, NStorage::TResponse response) {
    try {
        auto messenger = yplatform::find<ymod_messenger::module>("messenger");
        messenger->send(
            address,
            TMessengerResponse{requestId, std::move(response)},
            GetMessageType(EMessageTypes::Response));
    } catch (const std::exception& exp) {
        YLOG_L(error) << Format("Fail to send response to node %1%, error: %2%", address, exp.what());
    }
}

// Removes the pending request `requestId` from the registry, updates the ban
// state of the node it was addressed to and returns the request's callback.
//
//   success - true resets the node's ban state; false bans the node using the
//             base ban duration and ban quorum from the configuration.
//
// Returns an empty callback when `requestId` is 0 (the value SendRequest
// returns on failure and CreateRequest never hands out) or when no such
// request is registered any more.
INetworkAgent::TCallback TNetworkAgentImpl::CompleteRequest(ui64 requestId, bool success) {
    if (requestId == 0) {
        return {};
    }

    // Take the record out of the registry under the lock; everything after
    // that works on the local copy only.
    TRequest finished;
    {
        auto guard = Guard(RequestLock);
        auto pos = Requests.find(requestId);
        if (pos == Requests.end()) {
            return {};
        }
        finished = std::move(pos->second);
        Requests.erase(pos);
    }

    auto& node = finished.NodeManager->Get(finished.NodeNum);
    if (!success) {
        const auto banDuration =
            node.Ban(Configuration->GetBaseBanDuration(), Configuration->GetBanQuorum(), false);
        const auto banSeconds = duration_cast<seconds>(banDuration);
        // Only a ban of at least one whole second is reported.
        if (banSeconds > seconds::zero()) {
            YLOG_L(warning) << Format("Node %1% banned for %2% seconds", node.GetAddress(), banSeconds.count());
        }
    } else {
        const auto banTimes = node.ResetBan(false);
        // Log only when the reset reports a positive ban count.
        if (banTimes > 0) {
            YLOG_L(info) << Format("Node %1% unbanned, ban times: %2%", node.GetAddress(), banTimes);
        }
    }

    return std::move(finished.Callback);
}

// Returns the node manager that is current at the time of the call; the
// pointer is copied while NodeManagerLock is held.
TNodeManagerPtr TNetworkAgentImpl::GetNodeManager() {
    TNodeManagerPtr current;
    {
        auto guard = Guard(NodeManagerLock);
        current = NodeManager;
    }
    return current;
}

// Registers a new pending request and returns its id.
//
// The id is drawn at random until it is non-zero and not used by any other
// registered request; drawing and insertion both happen under RequestLock.
ui64 TNetworkAgentImpl::CreateRequest(
        size_t nodeNum,
        TNodeManagerPtr nodeManager,
        std::string contextId,
        NStorage::TRequest request,
        EMessageTypes type,
        TCallback callback) {
    auto guard = Guard(RequestLock);

    // Starting from 0 forces at least one draw, since 0 is never a valid id.
    ui64 requestId = 0;
    while (requestId == 0 || Requests.find(requestId) != Requests.end()) {
        requestId = RandomNumber<ui64>();
    }

    TRequest record{
        nodeNum,
        std::move(nodeManager),
        std::move(contextId),
        std::move(request),
        type,
        std::move(callback)};
    Requests.emplace(requestId, std::move(record));

    return requestId;
}

} // namespace NRateSrv::NRouter
