#include "asio_connector.h"

#include "asio_logging.h"
#include "serialize.h"

#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/logging/logging.h>

#include <yandex_io/protos/quasar_proto.pb.h>

#include <util/system/yassert.h>

using namespace quasar::ipc::detail::asio_ipc;
using namespace std::chrono_literals;

//
// Callbacks
//

AsioConnector::Callbacks::Callbacks(std::shared_ptr<ICallbackQueue> callbackQueue)
    : queue(std::move(callbackQueue))
{
}

//
// RequestHandlerMap
//

AsioConnector::RequestHandlerMap::RequestHandlerMap(std::shared_ptr<AsioAsyncWorker> worker)
    : AsioAsyncObject(std::move(worker))
    , strand_(ioContext().get_executor())
{
}

// Only traces destruction; any entries still in handlers_ are destroyed with
// the map without their callbacks being invoked here.
AsioConnector::RequestHandlerMap::~RequestHandlerMap()
{
    YIO_LOG_TRACE(*this << ": destroy");
}

// Writes a short identifying tag, e.g. "<Connector.RequestHandlerMap 0x...>", for logs.
void AsioConnector::RequestHandlerMap::debugPrintDescription(std::ostream& out) const
{
    out << "<Connector.RequestHandlerMap " << this << '>';
}

// Nothing to start: entries and their timeout timers are created lazily by add().
void AsioConnector::RequestHandlerMap::doAsyncStart()
{
    YIO_LOG_TRACE(*this << ": start");
}

void AsioConnector::RequestHandlerMap::doAsyncShutdown() {
    auto lock = std::scoped_lock{handlersMutex_};
    YIO_LOG_TRACE(*this << ": shutdown; will cancel " << handlers_.size() << " handlers");

    for (auto& item : handlers_) {
        YIO_LOG_TRACE(*this << ": cancel: " << item.first);
        item.second.onError("shutting down");
    }
    handlers_.clear();
}

// Registers a pending request and arms its timeout timer.
//
// onDone  - invoked by notifyDone() when a response with the returned id arrives.
// onError - invoked by notifyError() on timeout or timer error, or by doAsyncShutdown().
// timeout - delay after which the request is failed with "request timeout".
// Returns a fresh id from makeUUID() that is not used by any currently pending request.
std::string AsioConnector::RequestHandlerMap::add(OnDone onDone, OnError onError, std::chrono::milliseconds timeout) {
    auto lock = std::scoped_lock{handlersMutex_};

    // Regenerate on the (unlikely) collision with a still-pending id.
    std::string key;
    do {
        key = makeUUID();
    } while (handlers_.find(key) != handlers_.end());
    YIO_LOG_TRACE(*this << ": add: " << key);

    auto& entry = handlers_.emplace(std::piecewise_construct, std::forward_as_tuple(key), std::forward_as_tuple(ioContext(), std::move(onDone), std::move(onError))).first->second;
    entry.timer.expires_after(timeout);
    entry.timer.async_wait(asio::bind_executor(strand_, [key, weakThis = weak_from_this()](const auto& ec) {
        if (ec == asio::error::operation_aborted) {
            // The wait was cancelled because the entry (and its timer) was
            // destroyed. In this file that only happens after the request was
            // already resolved by notifyDone()/notifyError() or failed by
            // doAsyncShutdown(). There is nothing left to notify. Calling
            // notifyError() here would only log a spurious "handler (error)
            // not found" warning for every completed request.
            return;
        }
        // Other timer errors are reported verbatim; a clean expiry is a timeout.
        if (auto this_ = weakThis.lock()) {
            this_->notifyError(key, ec ? ec.message() : "request timeout");
        }
    }));

    return key;
}

bool AsioConnector::RequestHandlerMap::notifyDone(const std::string& key, const SharedMessage& message) {
    YIO_LOG_TRACE(*this << ": notify done: " << key);

    OnDone onDone;
    {
        auto lock = std::scoped_lock{handlersMutex_};
        if (auto it = handlers_.find(key); it != handlers_.end()) {
            onDone = std::move(it->second.onDone);
            handlers_.erase(it);
        } else {
            YIO_LOG_WARN(*this << ": handler (done) not found for request id = " << key);
            return false;
        }
    }

    onDone(message);
    return true;
}

bool AsioConnector::RequestHandlerMap::notifyError(const std::string& key, const std::string& errorMessage) {
    YIO_LOG_TRACE(*this << ": notify error: " << key << "; " << errorMessage);

    OnError onError;
    {
        auto lock = std::scoped_lock{handlersMutex_};
        if (auto it = handlers_.find(key); it != handlers_.end()) {
            onError = std::move(it->second.onError);
            handlers_.erase(it);
        } else {
            YIO_LOG_WARN(*this << ": handler (error) not found for request id = " << key);
            return false;
        }
    }

    onError(errorMessage);
    return true;
}

//
// ConnectorChannel
//

AsioConnector::ConnectorChannel::ConnectorChannel(std::shared_ptr<AsioAsyncWorker> asyncWorker, std::shared_ptr<AsioConnector::Callbacks> callbacks, asio::ip::tcp::socket sock, const std::string& serviceName, OnDisconnect onDisconnect)
    : AsioChannel(std::move(asyncWorker), std::move(sock), serviceName)
    , callbacks_(std::move(callbacks))
    , onDisconnect_(std::move(onDisconnect))
{
    Y_VERIFY(callbacks_ != nullptr);
    Y_VERIFY(onDisconnect_ != nullptr);
}

// Writes "<Connector.Channel:SERVICE 0x...>" for logs. The ":SERVICE" part is
// omitted when serviceDescription() is empty.
void AsioConnector::ConnectorChannel::debugPrintDescription(std::ostream& out) const {
    const auto& description = serviceDescription();

    out << "<Connector.Channel";
    if (!description.empty()) {
        out << ':' << description;
    }
    out << ' ' << this << '>';
}

// Serializes `message` into a framed shared buffer and queues it for sending.
void AsioConnector::ConnectorChannel::sendMessageImpl(const Message& message) {
    auto framed = SerializeToSharedBufferWithFraming(message);
    asyncSend(std::move(framed));
}

// Sends `message` as a request.
//
// A new request id is registered in the handler map (created and started on
// the first request) and stamped into `message` before sending. The reply is
// routed to onDone; a timeout or shutdown is routed to onError.
void AsioConnector::ConnectorChannel::sendRequestImpl(Message& message, OnDone onDone, OnError onError, std::chrono::milliseconds timeout) {
    auto lock = std::scoped_lock{requestHandlersMutex_};

    if (requestHandlers_ == nullptr) {
        requestHandlers_ = worker()->create<RequestHandlerMap>(worker());
        YIO_LOG_DEBUG(*this << ": create request handler map: " << *requestHandlers_);
        requestHandlers_->asyncStart();
    }

    const auto requestId = requestHandlers_->add(std::move(onDone), std::move(onError), timeout);
    message.set_request_id(TString(requestId));

    sendMessageImpl(message);
}

void AsioConnector::ConnectorChannel::onIpcConnect() {
    if (!callbacks_->onConnect) {
        YIO_LOG_TRACE(*this << ": on connect: no handler");
        return;
    }

    YIO_LOG_TRACE(*this << ": on connect");
    callbacks_->queue->add([cb = callbacks_] {
        YIO_LOG_TRACE("connector on connect");
        cb->onConnect();
    });
}

void AsioConnector::ConnectorChannel::onIpcDisconnect() {
    if (callbacks_->onDisconnect) {
        callbacks_->queue->add([cb = callbacks_] {
            YIO_LOG_TRACE("connector on disconnect");
            cb->onDisconnect();
        });
    }
    {
        auto lock = std::scoped_lock{requestHandlersMutex_};
        if (requestHandlers_) {
            requestHandlers_->asyncShutdown();
        }
    }
    onDisconnect_(this);
}

void AsioConnector::ConnectorChannel::onIpcConnectionError(std::string errorMessage) {
    if (!callbacks_->onConnectionError) {
        YIO_LOG_TRACE(*this << ": on connection error: no handler");
        return;
    }

    YIO_LOG_TRACE(*this << ": on connection error");
    callbacks_->queue->add([cb = callbacks_, errorMessage = std::move(errorMessage)] {
        YIO_LOG_TRACE("connector on connection error (" << errorMessage << ")");
        cb->onConnectionError(errorMessage);
    });
}

// Copies the received bytes and defers parsing to the callback queue.
//
// The closure holds only a weak reference, so a channel destroyed in the
// meantime simply drops the payload. Always returns true.
bool AsioConnector::ConnectorChannel::onIpcMessage(asio::const_buffer buffer) {
    YIO_LOG_TRACE(*this << ": on message");

    // The buffer is not guaranteed to stay valid after return, so own a copy.
    const char* const bytes = static_cast<const char*>(buffer.data());
    std::string payload(bytes, bytes + buffer.size());

    callbacks_->queue->add([weakSelf = weak_from_this(), payload = std::move(payload)] {
        const auto self = std::static_pointer_cast<ConnectorChannel>(weakSelf.lock());
        if (!self) {
            return;
        }
        self->handleIncomingPayload(payload);
    });

    return true;
}

// Parses a received payload and dispatches it.
//
// A message carrying a request id that matches a pending request is consumed
// by handleRequestResponse(). Any other message goes to the user's onMessage
// callback (if set). A payload that fails to parse shuts the channel down.
void AsioConnector::ConnectorChannel::handleIncomingPayload(std::string_view payload) {
    YIO_LOG_TRACE(*this << ": begin parse message (size = " << payload.size() << ")");

    auto parsed = UniqueMessage::create();
    if (!parsed->ParseFromArray(payload.data(), payload.size())) {
        YIO_LOG_WARN(*this << ": parse message: fail");
        // TODO: notify and break connection
        asyncShutdown();
        return;
    }

    YIO_LOG_TRACE(*this << ": parse message: ok");
    SharedMessage message = std::move(parsed);

    const bool consumedAsResponse = message->has_request_id() && handleRequestResponse(message);
    if (!consumedAsResponse && callbacks_->onMessage) {
        callbacks_->onMessage(message);
    }
}

// Routes a message that carries a request id to the pending-request map.
//
// Returns true if a pending request with that id was found and its onDone was
// invoked. Returns false if no request was ever sent on this channel or the id
// is unknown; the caller then delivers the message as a regular one.
bool AsioConnector::ConnectorChannel::handleRequestResponse(const SharedMessage& message) {
    // Only copy the map pointer under the lock. notifyDone() runs the user's
    // onDone callback synchronously. Holding requestHandlersMutex_ across that
    // call would deadlock if the callback sends another request on this channel,
    // because sendRequestImpl() takes the same mutex. At minimum it would stall
    // every other sender for the duration of the callback.
    decltype(requestHandlers_) handlers;
    {
        auto lock = std::scoped_lock{requestHandlersMutex_};
        handlers = requestHandlers_;
    }

    if (!handlers) {
        return false;
    }

    YIO_LOG_TRACE(*this << ": message request id = " << message->request_id() << "; find handler");
    return handlers->notifyDone(message->request_id(), message);
}

//
// AsioConnector
//

AsioConnector::AsioConnector(std::string serviceName, AsioTcpConnector::Address address, std::shared_ptr<AsioAsyncWorker> worker, std::shared_ptr<Callbacks> callbacks)
    : AsioAsyncObject(std::move(worker))
    , serviceName_(std::move(serviceName))
    , address_(std::move(address))
    , callbacks_(std::move(callbacks))
{
    Y_VERIFY(callbacks_ != nullptr);
}

// Writes "<Connector:SERVICE 0x...>" for logs.
void AsioConnector::debugPrintDescription(std::ostream& out) const
{
    out << "<Connector:" << serviceName_
        << ' ' << this
        << '>';
}

// Forwards `message` to the current channel.
// Returns false without sending when there is no channel.
bool AsioConnector::sendMessageImpl(const Message& message) {
    auto lock = std::scoped_lock{channelMutex_};
    if (!channel_) {
        return false;
    }
    channel_->sendMessageImpl(message);
    return true;
}

// Returns a strong reference to the current channel (null when disconnected),
// so callers can use it after channelMutex_ is released.
std::shared_ptr<AsioConnector::ConnectorChannel> AsioConnector::lockChannel() {
    std::scoped_lock guard{channelMutex_};
    return channel_;
}

// True while a channel is installed.
bool AsioConnector::isConnected() const {
    std::scoped_lock guard{channelMutex_};
    return static_cast<bool>(channel_);
}

// Blocks until a channel is installed or `timeout` elapses.
// Returns whether a channel is present at the end of the wait.
bool AsioConnector::waitUntilConnected(const std::chrono::milliseconds& timeout) const {
    YIO_LOG_TRACE(*this << ": wait until connected: begin(timeout = " << timeout.count() << "ms)");

    const auto hasChannel = [this] {
        YIO_LOG_TRACE(*this << ": wait until connected: check(channel = " << channel_.get() << ")");
        return channel_ != nullptr;
    };

    std::unique_lock guard{channelMutex_};
    const bool connected = channelCV_.wait_for(guard, timeout, hasChannel);

    YIO_LOG_TRACE(*this << ": wait until connected: end(result = " << (connected ? "true" : "false") << ")");
    return connected;
}

// Blocks until no channel is installed or `timeout` elapses.
// Returns whether the channel is absent at the end of the wait.
bool AsioConnector::waitUntilDisconnected(const std::chrono::milliseconds& timeout) const {
    YIO_LOG_TRACE(*this << ": wait until disconnected: begin(timeout = " << timeout.count() << "ms)");

    const auto hasNoChannel = [this] {
        YIO_LOG_TRACE(*this << ": wait until disconnected: check(channel = " << channel_.get() << ")");
        return channel_ == nullptr;
    };

    std::unique_lock guard{channelMutex_};
    const bool disconnected = channelCV_.wait_for(guard, timeout, hasNoChannel);

    YIO_LOG_TRACE(*this << ": wait until disconnected: end(result = " << (disconnected ? "true" : "false") << ")");
    return disconnected;
}

// Starting the connector means launching the first connection attempt.
void AsioConnector::doAsyncStart()
{
    YIO_LOG_TRACE(*this << ": start");
    asyncConnect();
}

// Stops any in-progress connection attempt (dropping the connector) and asks
// the current channel, if any, to shut down. Both mutexes are held throughout.
void AsioConnector::doAsyncShutdown() {
    std::scoped_lock guard{channelMutex_, connectorMutex_};
    YIO_LOG_TRACE(*this << ": shutdown; connector = " << connector_.get() << "; channel = " << channel_.get());

    if (connector_ != nullptr) {
        connector_->asyncShutdown();
        connector_ = nullptr;
    }

    if (channel_ != nullptr) {
        channel_->asyncShutdown();
    }
}

// Creates and starts a fresh TCP connector. It is a hard error (Y_VERIFY) to
// call this while a previous connector is still installed.
void AsioConnector::asyncConnect() {
    std::scoped_lock guard{connectorMutex_};
    Y_VERIFY(connector_ == nullptr);

    connector_ = createConnector();

    YIO_LOG_DEBUG(*this << ": start connector: " << *connector_);
    connector_->asyncStart();
}

std::shared_ptr<AsioTcpConnector> AsioConnector::createConnector() {
    auto weakThis = weak_from_this();
    return worker()->create<AsioTcpConnector>(worker(), serviceName_, address_, AsioTcpConnector::Callbacks{
                                                                                    .onConnect = [weakThis](asio::ip::tcp::socket peer) {
            if (auto this_ = weakThis.lock()) {
                YIO_LOG_TRACE(*this_ << ": connector on connect");
                this_->setTcpChannel(std::move(peer));
                // TODO: destroy connector after a channel is established?
            } },
                                                                                    .onConnectionFailure = [weakThis]() -> AsioTcpConnector::RetryAction {
                                                                                        if (auto this_ = weakThis.lock()) {
                                                                                            if (!this_->isStopped()) {
                                                                                                // TODO: adaptive retry policy?
                                                                                                return AsioTcpConnector::RetryAfter{1s};
                                                                                            }
                                                                                        }

                                                                                        return AsioTcpConnector::StopTrying{};
                                                                                    },
                                                                                });
}

// Wraps a freshly connected socket in a ConnectorChannel, installs it via
// setChannel() and then starts it.
//
// The channel's disconnect hook holds only a weak reference to this connector.
// When it fires, it clears the installed channel via setChannel(nullptr).
void AsioConnector::setTcpChannel(asio::ip::tcp::socket peer) {
    YIO_LOG_TRACE(*this << ": set tcp channel for " << AsioTcpEndpointsLog{peer});

    auto handleChannelDisconnect = [weakSelf = weak_from_this()](auto* closedChannel) {
        const auto self = weakSelf.lock();
        if (Y_UNLIKELY(!self)) {
            YIO_LOG_TRACE(*closedChannel << ": on disconnect: connector impl is dead");
            return;
        }

        YIO_LOG_TRACE(*closedChannel << ": on disconnect [in " << *self << "]");
        self->setChannel(nullptr);
    };

    auto newChannel = worker()->create<ConnectorChannel>(worker(), callbacks_, std::move(peer), serviceName_, std::move(handleChannelDisconnect));

    YIO_LOG_DEBUG(*this << ": new channel: " << *newChannel);
    setChannel(newChannel);

    newChannel->asyncStart();
}

void AsioConnector::setChannel(std::shared_ptr<ConnectorChannel> channel) {
    if (channel) {
        YIO_LOG_DEBUG(*this << ": setChannel " << *channel);
    } else {
        YIO_LOG_DEBUG(*this << ": setChannel (null)");
    }

    const bool needReconnect = !channel;

    {
        auto lock = std::scoped_lock{channelMutex_, connectorMutex_};
        channel_ = std::move(channel);

        if (connector_) {
            connector_->asyncShutdown();
            connector_ = nullptr;
        }
    }

    channelCV_.notify_all();

    if (needReconnect) {
        const bool isThisRunning = isRunning();
        const bool isWorkerRunning = worker()->isRunning();
        YIO_LOG_TRACE(*this << ": channel needs reconnect? (running: "
                            << (isThisRunning ? "true" : "false") << "; worker running: "
                            << (isWorkerRunning ? "true" : "false") << ")");
        if (isThisRunning && isWorkerRunning) {
            YIO_LOG_DEBUG(*this << ": reconnect");
            asyncConnect();
        }
    }
}
