#include "asio_server.h"

#include "asio_logging.h"
#include "serialize.h"

#include <yandex_io/libs/ipc/helpers.h>

#include <yandex_io/libs/json_utils/json_utils.h>
#include <yandex_io/libs/logging/logging.h>

#include <yandex_io/protos/quasar_proto.pb.h>

#include <util/system/yassert.h>

YIO_DEFINE_LOG_MODULE("asio_ipc");

using namespace quasar::ipc;
using namespace quasar::ipc::detail::asio_ipc;
using namespace std::chrono_literals;

//
// Callbacks
//

AsioServer::Callbacks::Callbacks(std::shared_ptr<ICallbackQueue> callbackQueue)
    : queue(std::move(callbackQueue))
{
}

//
// WeakClientConnection
//

// Client handle handed to user callbacks. It keeps only a weak reference to the
// channel, so holding the handle never prolongs the channel's lifetime: once the
// channel is gone, send/unsafeSendBytes/scheduleClose silently do nothing.
class AsioServer::WeakClientConnection: public IServer::IClientConnection, public std::enable_shared_from_this<AsioServer::WeakClientConnection> {
public:
    // explicit: a weak channel pointer must never silently turn into a client connection.
    explicit WeakClientConnection(std::weak_ptr<AsioChannel> weakChannel)
        : weakChannel_(std::move(weakChannel))
    {
    }

    ~WeakClientConnection() {
        YIO_LOG_TRACE("destroy weak shared client connection " << this);
    }

    // Returns another owning reference to this handle; the object must already be
    // owned by a std::shared_ptr (shared_from_this requirement).
    std::shared_ptr<IServer::IClientConnection> share() override {
        return shared_from_this();
    }

    // Serializes the message with framing and passes it to the channel's asyncSend.
    void send(const SharedMessage& message) override {
        if (auto channel = weakChannel_.lock()) {
            channel->asyncSend(SerializeToSharedBufferWithFraming(*message));
        }
    }

    // Same as above for a message passed by rvalue reference.
    void send(Message&& message) override {
        if (auto channel = weakChannel_.lock()) {
            channel->asyncSend(SerializeToSharedBufferWithFraming(message));
        }
    }

    // Passes the raw bytes to asyncSend as-is: unlike send(), no framing is added,
    // so the caller is presumably responsible for supplying a well-formed frame.
    void unsafeSendBytes(std::string_view data) override {
        if (auto channel = weakChannel_.lock()) {
            channel->asyncSend(MakeSharedBuffer(data));
        }
    }

    // Requests an asynchronous shutdown of the channel, if it is still alive.
    void scheduleClose() override {
        if (auto channel = weakChannel_.lock()) {
            channel->asyncShutdown();
        }
    }

private:
    std::weak_ptr<AsioChannel> weakChannel_;
};

//
// ServerChannel
//

// Server-side channel for one accepted TCP connection.
//   asyncWorker, sock, serviceName - forwarded to the AsioChannel base.
//   callbacks                       - the server's user callbacks (must be non-null).
//   onDisconnect                    - invoked with this channel from onIpcDisconnect() (must be non-null).
AsioServer::ServerChannel::ServerChannel(std::shared_ptr<AsioAsyncWorker> asyncWorker, std::shared_ptr<AsioServer::Callbacks> callbacks, asio::ip::tcp::socket sock, const std::string& serviceName, OnDisconnect onDisconnect)
    // asyncWorker is a by-value sink: move it instead of paying for an extra atomic refcount copy.
    : AsioChannel(std::move(asyncWorker), std::move(sock), serviceName)
    , callbacks_(std::move(callbacks))
    , onDisconnect_(std::move(onDisconnect))
{
    Y_VERIFY(callbacks_ != nullptr);
    Y_VERIFY(onDisconnect_ != nullptr);
}

// Writes "<Server.Channel[:service] ADDRESS>"; the ":service" part is omitted
// when the service description is empty.
void AsioServer::ServerChannel::debugPrintDescription(std::ostream& out) const {
    out << "<Server.Channel";
    const auto& service = serviceDescription();
    if (!service.empty()) {
        out << ':' << service;
    }
    out << ' ' << this << '>';
}

// Returns the single WeakClientConnection of this channel, creating it lazily on
// first use; guarded by sharedClientConnectionMutex_.
std::shared_ptr<IServer::IClientConnection> AsioServer::ServerChannel::sharedClientConnection() {
    std::scoped_lock guard(sharedClientConnectionMutex_);
    if (sharedClientConnection_) {
        return sharedClientConnection_;
    }

    sharedClientConnection_ = std::make_shared<WeakClientConnection>(weak_from_this());
    YIO_LOG_TRACE(*this << ": create weak shared client connection " << sharedClientConnection_.get());
    return sharedClientConnection_;
}

void AsioServer::ServerChannel::onIpcConnect() {
    if (!callbacks_->onClientConnected) {
        YIO_LOG_TRACE(*this << ": on connect: no handler");
        return;
    }

    YIO_LOG_TRACE(*this << ": on connect");
    auto channel = std::static_pointer_cast<ServerChannel>(shared_from_this());
    callbacks_->queue->add([channel{std::move(channel)}] {
        YIO_LOG_TRACE(*channel << ": on connect");
        auto& client = *channel->sharedClientConnection();
        channel->callbacks_->onClientConnected(client);
    });
}

void AsioServer::ServerChannel::onIpcDisconnect() {
    onDisconnect_(this);

    if (!callbacks_->onClientDisconnected) {
        YIO_LOG_TRACE(*this << ": on disconnect: no handler");
        return;
    }

    YIO_LOG_TRACE(*this << ": on disconnect");
    auto channel = std::static_pointer_cast<ServerChannel>(shared_from_this());
    callbacks_->queue->add([channel{std::move(channel)}] {
        YIO_LOG_TRACE(*channel << ": on disconnect");
        auto& client = *channel->sharedClientConnection();
        channel->callbacks_->onClientDisconnected(client);
    });
}

void AsioServer::ServerChannel::onIpcConnectionError(std::string errorMessage [[maybe_unused]]) {
    YIO_LOG_TRACE(*this << ": skip on connection error: " << errorMessage);
}

// Copies the incoming frame payload and offloads parsing/dispatch to the user
// callback queue. Returns false (message not consumed) when no onMessage handler
// is installed. The queued task captures the channel weakly, so a payload queued
// for a channel that has since been destroyed is dropped.
bool AsioServer::ServerChannel::onIpcMessage(asio::const_buffer buffer) {
    if (callbacks_->onMessage) {
        YIO_LOG_TRACE(*this << ": on message -> offload to user callback queue");
        const auto* bytes = static_cast<const char*>(buffer.data());
        std::string payload(bytes, bytes + buffer.size());
        callbacks_->queue->add([weakSelf = weak_from_this(), payload = std::move(payload)] {
            const auto channel = std::static_pointer_cast<ServerChannel>(weakSelf.lock());
            if (!channel) {
                return;
            }
            YIO_LOG_TRACE(*channel << ": on message");
            channel->handleIncomingPayload(payload);
        });
        return true;
    }

    YIO_LOG_TRACE(*this << ": on message: no handler");
    return false;
}

// Parses a serialized Message from the payload and hands it to the user's
// onMessage callback together with this channel's shared client connection.
// A payload that fails to parse is only logged.
void AsioServer::ServerChannel::handleIncomingPayload(std::string_view payload) {
    YIO_LOG_TRACE(*this << ": begin parse message (size = " << payload.size() << ")");
    auto message = UniqueMessage::create();
    const bool parsed = message->ParseFromArray(payload.data(), payload.size());
    if (!parsed) {
        YIO_LOG_WARN(*this << ": parse message: fail");
        // TODO: notify and break connection
        return;
    }

    YIO_LOG_TRACE(*this << ": parse message: ok");
    if (!callbacks_->onMessage) {
        return;
    }
    auto& connection = *sharedClientConnection();
    callbacks_->onMessage(std::move(message), connection);
}

//
// AsioServer
//

// Creates a server for serviceName bound to address; callbacks must be non-null.
// Nothing is started here: the acceptor is created in doAsyncStart().
AsioServer::AsioServer(std::string serviceName, AsioTcpAcceptor::Address address, std::shared_ptr<AsioAsyncWorker> worker, std::shared_ptr<Callbacks> callbacks)
    : AsioAsyncObject(std::move(worker))
    , serviceName_(std::move(serviceName))
    , address_(address)
    , callbacks_(std::move(callbacks))
{
    // The user callbacks are dereferenced unconditionally by every channel.
    Y_VERIFY(callbacks_ != nullptr);
}

// Starting the server consists solely of starting the TCP acceptor.
void AsioServer::doAsyncStart() {
    YIO_LOG_TRACE(*this << ": start");
    this->doStartAcceptor();
}

// Requests asynchronous shutdown of the acceptor (if any) and of every live
// channel, holding both acceptorMutex_ and channelsMutex_ throughout.
void AsioServer::doAsyncShutdown() {
    std::scoped_lock guard(acceptorMutex_, channelsMutex_);
    YIO_LOG_TRACE(*this << ": shutdown; acceptor = " << acceptor_.get() << "; channel count = " << channels_.size());

    if (acceptor_) {
        acceptor_->asyncShutdown();
    }
    for (const auto& liveChannel : channels_) {
        liveChannel->asyncShutdown();
    }
}

// Writes "<Server ADDRESS>".
void AsioServer::debugPrintDescription(std::ostream& out) const {
    out << "<Server ";
    out << this;
    out << ">";
}

// Number of channels currently registered with the server (snapshot under channelsMutex_).
int AsioServer::getConnectedClientCount() const {
    std::scoped_lock guard(channelsMutex_);
    return static_cast<int>(channels_.size());
}

// Port of the address passed to the constructor.
// NOTE(review): if port 0 was requested, this is presumably not the actually bound port — confirm.
int AsioServer::port() const {
    const auto& configured = address_;
    return configured.port;
}

// Blocks until at least `count` channels are registered or `timeout` elapses.
// Returns true if the condition was met.
bool AsioServer::waitConnectionsAtLeast(size_t count, std::chrono::milliseconds timeout) const {
    YIO_LOG_TRACE(*this << ": wait for connections at least: begin(count = " << count << ", timeout = " << timeout.count() << "ms)");
    const auto enoughConnections = [this, count] {
        const auto connected = channels_.size();
        YIO_LOG_TRACE(*this << ": wait for connections at least: check(" << connected << " >= " << count << ")");
        return connected >= count;
    };

    std::unique_lock lock(channelsMutex_);
    const bool reached = channelsCV_.wait_for(lock, timeout, enoughConnections);
    YIO_LOG_TRACE(*this << ": wait for connections at least: end(result = " << (reached ? "true" : "false") << ")");
    return reached;
}

// Blocks until at most `count` channels are registered or `timeout` elapses.
// Returns true if the condition was met.
bool AsioServer::waitConnectionsAtMost(size_t count, std::chrono::milliseconds timeout) const {
    YIO_LOG_TRACE(*this << ": wait for connections at most: begin(count = " << count << ", timeout = " << timeout.count() << "ms)");
    const auto fewEnoughConnections = [this, count] {
        const auto connected = channels_.size();
        YIO_LOG_TRACE(*this << ": wait for connections at most: check(" << connected << " <= " << count << ")");
        return connected <= count;
    };

    std::unique_lock lock(channelsMutex_);
    const bool reached = channelsCV_.wait_for(lock, timeout, fewEnoughConnections);
    YIO_LOG_TRACE(*this << ": wait for connections at most: end(result = " << (reached ? "true" : "false") << ")");
    return reached;
}

bool AsioServer::waitListening(bool targetStatus, std::chrono::milliseconds timeout) const {
    YIO_LOG_TRACE(*this << ": wait until listening: begin(timeout = " << timeout.count() << "ms)");
    auto lock = std::unique_lock{acceptorMutex_};
    const bool result = channelsCV_.wait_for(lock, timeout, [this, targetStatus] {
        const bool status = isListening_;
        YIO_LOG_TRACE(*this << ": wait until listening: check(listening = "
                            << (status ? "true" : "false") << "; target = "
                            << (targetStatus ? "true" : "false") << ")");
        return status == targetStatus;
    });
    YIO_LOG_TRACE(*this << ": wait until listening: end(result = " << (result ? "true" : "false") << ")");
    return result;
}

// Broadcasts a message to every registered channel. The message is serialized
// (with framing) once, outside the locks, and the same buffer is shared by all
// channels. Calling this on a server with neither an acceptor nor channels logs
// an error event and the message is dropped.
void AsioServer::sendToAll(const Message& message) {
    auto framed = std::make_shared<TString>();
    SerializeToStringWithFraming(message, framed.get());

    std::scoped_lock guard(acceptorMutex_, channelsMutex_);
    const bool inactive = acceptor_ == nullptr && channels_.empty();
    if (inactive) {
        YIO_LOG_ERROR_EVENT("AsioServer.SendOnInactiveServer", *this << ": sendToAll called on inactive server, message lost: " << message.Utf8DebugString());
        return;
    }
    for (const auto& target : channels_) {
        target->asyncSend(framed);
    }
}

// Marks the server as listening and wakes threads waiting on acceptorCV_.
// The mutex is released before notifying so woken waiters do not block on it.
void AsioServer::notifyListen() {
    std::unique_lock guard(acceptorMutex_);
    isListening_ = true;
    guard.unlock();

    acceptorCV_.notify_all();
}

// Drops the acceptor, marks the server as not listening, and wakes threads
// waiting on acceptorCV_ (after releasing the mutex).
void AsioServer::notifyStopListening() {
    std::unique_lock guard(acceptorMutex_);
    acceptor_ = nullptr;
    isListening_ = false;
    guard.unlock();

    acceptorCV_.notify_all();
}

// Wraps an accepted socket into a ServerChannel, registers it and starts it.
// Sockets arriving after the server is stopped are dropped.
void AsioServer::addTcpChannel(asio::ip::tcp::socket peer) {
    if (isStopped()) {
        YIO_LOG_TRACE(*this << ": stopped; drop tcp channel for " << AsioTcpEndpointsLog{peer});
        return;
    }

    YIO_LOG_TRACE(*this << ": add tcp channel for " << AsioTcpEndpointsLog{peer});

    // The channel holds the server only weakly, so it never keeps a dead server alive;
    // on disconnect it unregisters itself if the server still exists.
    auto onChannelDisconnect = [weakThis = weak_from_this()](auto* disconnected) {
        const auto server = weakThis.lock();
        if (Y_UNLIKELY(!server)) {
            YIO_LOG_TRACE(*disconnected << ": on disconnect: server impl is dead");
            return;
        }

        YIO_LOG_TRACE(*disconnected << ": on disconnect [in " << *server << "]");
        server->removeChannel(disconnected->shared_from_this());
    };

    auto channel = worker()->create<ServerChannel>(worker(), callbacks_, std::move(peer), serviceName_, std::move(onChannelDisconnect));

    YIO_LOG_DEBUG(*this << ": new channel: " << *channel);
    addChannel(channel);

    channel->asyncStart();
}

// Registers a channel and wakes threads waiting on channelsCV_
// (e.g. waitConnectionsAtLeast/AtMost) so they re-check the count.
void AsioServer::addChannel(std::shared_ptr<AsioChannel> channel) {
    std::unique_lock guard(channelsMutex_);
    channels_.emplace(std::move(channel));
    guard.unlock();

    channelsCV_.notify_all();
}

// Unregisters a channel and wakes threads waiting on channelsCV_
// (e.g. waitConnectionsAtLeast/AtMost) so they re-check the count.
void AsioServer::removeChannel(std::shared_ptr<AsioChannel> channel) {
    std::unique_lock guard(channelsMutex_);
    channels_.erase(channel);
    guard.unlock();

    channelsCV_.notify_all();
}

// Creates and starts a fresh acceptor. Requires that no acceptor currently exists.
void AsioServer::doStartAcceptor() {
    std::scoped_lock guard(acceptorMutex_);
    Y_VERIFY(acceptor_ == nullptr);

    // Reset the state; notifyListen() sets it back to true once the acceptor listens.
    isListening_ = false;

    acceptor_ = createAcceptor();
    auto& acceptor = *acceptor_;
    YIO_LOG_DEBUG(*this << ": start acceptor: " << acceptor);
    acceptor.asyncStart();
}

std::shared_ptr<AsioTcpAcceptor> AsioServer::createAcceptor() {
    auto weakThis = weak_from_this();
    return worker()->create<AsioTcpAcceptor>(worker(), serviceName_, address_, AsioTcpAcceptor::Callbacks{.onListen = [weakThis] {
            if (auto this_ = weakThis.lock()) {
                YIO_LOG_TRACE(*this_ << ": acceptor on listen");
                this_->notifyListen();
            } }, .onStopListening = [weakThis] {
            if (auto this_ = weakThis.lock()) {
                YIO_LOG_TRACE(*this_ << ": acceptor on stop listening");
                this_->notifyStopListening();
            } }, .onConnect = [weakThis](asio::ip::tcp::socket peer) {
            if (auto this_ = weakThis.lock()) {
                YIO_LOG_TRACE(*this_ << ": acceptor on connect");
                this_->addTcpChannel(std::move(peer));
            } }});
}
