#define _GLIBCXX_USE_NANOSLEEP 1

#include "tcp_connector.h"

#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/errno/errno_exception.h>
#include <yandex_io/libs/logging/logging.h>

#include <yandex_io/external_libs/datacratic/soa/service/passive_endpoint.h>

#include <util/system/byteorder.h>

#include <cstring>
#include <memory>

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

YIO_DEFINE_LOG_MODULE("datacratic_ipc");

using namespace quasar;
using namespace Datacratic;
using namespace quasar::ipc::detail::datacratic;

namespace {

    /**
     * Returns true if `ip` is a numeric IPv6 address literal, as accepted by
     * inet_pton(AF_INET6, ...). This includes IPv4-embedded forms such as "::ffff:192.0.2.1".
     * Returns false for IPv4 literals, host names and malformed strings.
     *
     * Used to choose AF_INET6 vs AF_INET for the peer address. Bracketed forms ("[::1]") and
     * zone suffixes ("fe80::1%eth0") are not accepted.
     */
    bool isIpv6(const std::string& ip)
    {
        struct in6_addr parsed;
        return inet_pton(AF_INET6, ip.c_str(), &parsed) == 1;
    }

} // namespace

namespace quasar::ipc::detail::datacratic {

    struct ConnectManager: public ConnectionHandler {
        ConnectManager()
            : success_(false)
        {
        }
        void doError(const std::string& error) override {
            Y_VERIFY(dynamic_cast<TCPConnector*>(transport().get_endpoint()) != nullptr);
            TCPConnector* tcpConnector = static_cast<TCPConnector*>(transport().get_endpoint());
            YIO_LOG_DEBUG("ConnectManager::doError: " << error << " (" << tcpConnector->hostname() << ":" << tcpConnector->port() << ")");
            if (tcpConnector->onConnectionError_) {
                tcpConnector->onConnectionError_(error);
            }
        }

        void handleOutput() override {
            Y_VERIFY(false == success_); // We can be here only once

            transport().cancelTimer();
            stopWriting();
            // Connection finished or has an error; check which one
            int error = 0;
            socklen_t error_len = sizeof(int);
            int res = getsockopt(getHandle(), SOL_SOCKET, SO_ERROR, &error,
                                 &error_len);
            if (res == -1 || error_len != sizeof(int)) {
                throw std::runtime_error(std::string("error getting connect message: ") + strError(errno));
            }

            if (error != 0)
            {
                handleError(strError(error));
                return;
            }

            success_ = true;
            //        transport().hasConnection();

            Y_VERIFY(dynamic_cast<TCPConnector*>(transport().get_endpoint()) != nullptr);
            TCPConnector* endpoint = static_cast<TCPConnector*>(transport().get_endpoint());
            endpoint->onConnect(transport());
        }

    private:
        bool success_;
    };

} // namespace quasar::ipc::detail::datacratic

/// Creates a connector using the default endpoint name "TCPConnector".
TCPConnector::TCPConnector()
    : TCPConnector("TCPConnector")
{
}

/**
 * Creates a connector whose endpoint is called `name`.
 *
 * The name is also kept in testName_, which onConnect() uses as the exception text if a second
 * connection shows up.
 */
TCPConnector::TCPConnector(const std::string& name)
    : EndpointBase(name)
    , testName_(name)
{
    // Factory invoked once per established connection (see onConnect()). The default produces
    // a LineTokenizer; presumably it splits the stream into lines.
    getTokenizer = [] {
        return std::make_shared<LineTokenizer>();
    };
}

/// Installs the callback run after each successful connection (see onConnect()).
void TCPConnector::setConnectHandler(OnConnect handler)
{
    this->connectHandler_ = std::move(handler);
}

/// Installs the callback run after an established connection is closed (see notifyCloseTransport()).
void TCPConnector::setDisconnectHandler(OnDisconnect handler)
{
    this->disconnectHandler_ = std::move(handler);
}

/// Installs the callback that receives connection error messages.
void TCPConnector::setConnectionErrorHandler(OnConnectionError handler)
{
    this->onConnectionError_ = std::move(handler);
}

/**
 * One-time initialization.
 *
 * Remembers the peer address, calls spinup(1, true, threadName) and starts the first
 * connection attempt. Later attempts after a disconnect are scheduled by startReconnecting().
 *
 * @param host       peer host name or IP address literal.
 * @param port       peer TCP port; must be in [0, 65535].
 * @param threadName forwarded to spinup(), presumably as the event-loop thread name.
 * @throws std::runtime_error if the connector is already initialized or `port` is out of range.
 *         Also propagates socket-setup failures of the first connection attempt.
 */
void TCPConnector::init(std::string host, int port, const std::string& threadName)
{
    // Largest value that fits in the 16-bit TCP port field.
    constexpr int kMaxPort = 65535;

    if (port_ != -1 || !host_.empty()) {
        throw std::runtime_error(
            "TCPConnector is already initialized. port: " + std::to_string(this->port()) + " host: '" + hostname() + "'");
    }
    // Larger values would presumably be truncated to an unrelated port when the socket address
    // is built, so reject them up front.
    if (port < 0 || port > kMaxPort) {
        throw std::runtime_error("TCPConnector init failed due invalid port number " + std::to_string(port));
    }

    host_ = std::move(host);
    port_ = port;

    spinup(1, true, threadName);
    startConnecting();
}

void TCPConnector::waitUntilConnected() const {
    std::unique_lock lock(connectionLock_);

    while (nullptr == connection_) {
        connectionCond_.wait(lock);
    }
}

void TCPConnector::waitUntilDisconnected() const {
    std::unique_lock lock(connectionLock_);

    while (connection_ != nullptr) {
        connectionCond_.wait(lock);
    }
}

/// Permanently stops reconnection attempts, joining a pending reconnect thread if there is one.
void TCPConnector::closePeer()
{
    const bool disableReconnect = true;
    stopReconnectingThread(disableReconnect);
}

/// True while a connection is established.
bool TCPConnector::isConnected() const {
    std::lock_guard guard(connectionLock_);
    return static_cast<bool>(connection_);
}

/// Snapshot of the current connection handler; null when not connected.
std::shared_ptr<TCPConnector::TCPConnectionHandler> TCPConnector::getConnection() const {
    std::lock_guard guard(connectionLock_);
    return connection_;
}

void TCPConnector::startConnecting()
{
    Y_VERIFY(!isConnected());
    int sock = createNonBlockingTCPSocket();
    std::shared_ptr<SocketTransport> newTransport =
        std::make_shared<SocketTransport>(this);

    newTransport->peer_ = SOCK_Stream(sock);
    newTransport->hasConnection();
    auto connectManager = std::make_shared<ConnectManager>();
    std::weak_ptr<ConnectManager> connectManagerWeak = connectManager;
    newTransport->associate(connectManager);

    newTransport->doAsync(std::bind(&TCPConnector::stopReconnectingThread, this, false), "stop reconnecting");

    Datacratic::INET_Addr addr;
    int family = (isIpv6(hostname()) ? AF_INET6 : AF_INET);
    notifyNewTransport(newTransport);
    if (addr.set(port(), hostname().c_str(), family) != 0)
    {
        connectManager->doAsync([this, connectManagerWeak]() {
            if (auto connectManager = connectManagerWeak.lock()) {
                connectManager->handleError("Bad address: " + hostname());
            }
        }, "handleError");
        return;
    }

    int result = -1;

    do {
        result = ::connect(newTransport->peer().get_handle(),
                           (struct sockaddr*)addr.get_addr(), addr.get_addr_size());
    } while (-1 == result && EINTR == errno);

    if (-1 == result)
    {
        if (EAGAIN == errno || EINPROGRESS == errno)
        {
            std::weak_ptr<SocketTransport> newTransportWeak = newTransport;
            auto finishSetup = [newTransportWeak]()
            {
                auto transport = newTransportWeak.lock();
                if (transport != nullptr)
                {
                    // Ready to write
                    transport->startWriting();
                }
            };

            // Call the rest in a handler context
            newTransport->doAsync(finishSetup, "connect");
        } else {
            std::string errorMessage = std::string("connect() error: ") + strError(errno);
            connectManager->doAsync([connectManagerWeak, errorMessage]() {
                if (auto connectManager = connectManagerWeak.lock()) {
                    connectManager->handleError(errorMessage);
                }
            }, "handleError");
            return;
        }
    } else {
        connectManager->handleOutput();
    }
}

/**
 * Schedules a single reconnection attempt.
 *
 * Any previous reconnect thread is stopped first. A new thread then waits up to one second
 * (or until stopReconnectingThread() signals it) and, unless stopped, calls startConnecting().
 * Does nothing once reconnection has been disabled (dontStartReconnect_).
 */
void TCPConnector::startReconnecting()
{
    stopReconnectingThread(false);

    std::lock_guard guard(reconnectThreadLock_);
    Y_VERIFY(nullptr == reconnectThread_);
    if (dontStartReconnect_) {
        YIO_LOG_DEBUG("TCPConnector: Don't starting reconnect because of shutdown.");
        return;
    }

    reconnectThreadStopped_.store(false);
    reconnectThread_ = std::make_unique<std::thread>([this]() {
        // NOTE(review): this mutex is private to the thread, so stopReconnectingThread() does not
        // synchronize with it. A notify issued between the predicate check and the start of the
        // wait can be missed. The wait then lasts until the one-second timeout, after which the
        // predicate is re-checked, so the stop is delayed but never lost.
        std::mutex waitMutex;
        std::unique_lock waitLock(waitMutex);
        const auto stopRequested = [this] {
            return reconnectThreadStopped_.load();
        };
        if (reconnectThreadCond_.wait_for(waitLock, std::chrono::seconds(1), stopRequested)) {
            return;
        }
        startConnecting();
    });
}

/**
 * Stops and joins the pending reconnect thread, if any.
 *
 * @param disableReconnect when true, also sets dontStartReconnect_, so later startReconnecting()
 *        calls do nothing (used by closePeer()).
 */
void TCPConnector::stopReconnectingThread(bool disableReconnect)
{
    std::lock_guard guard(reconnectThreadLock_);

    if (reconnectThread_) {
        YIO_LOG_DEBUG("Waiting for reconnect thread to stop.");
        // If the thread is still waiting, it wakes up, sees the flag and skips startConnecting().
        reconnectThreadStopped_.store(true);
        reconnectThreadCond_.notify_one();
        reconnectThread_->join();
        YIO_LOG_DEBUG("Reconnect thread has stopped");
        reconnectThread_.reset();
    }

    if (disableReconnect) {
        dontStartReconnect_ = true;
    }
}

int TCPConnector::createNonBlockingTCPSocket() const {
    Y_VERIFY(port_ != -1);
    int sock = -1;

    // Assign local port to the socket. Do not allow to assign the port to which we will be connecting.
    while (true)
    {
        sock = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0);
        if (sock < 0)
        {
            throw std::runtime_error(
                std::string("Cannot create socket: ") + strError(errno));
        }
        struct sockaddr_in sin;
        memset(&sin, 0, sizeof(sin));
        sin.sin_family = AF_INET;
        sin.sin_port = 0;
        sin.sin_addr.s_addr = INADDR_ANY;
        int res = ::bind(sock, (struct sockaddr*)&sin, sizeof(sin));
        if (-1 == res) {
            throw std::runtime_error(std::string("TCPConnector: Cannot bind address: ") + strError(errno));
        }
        socklen_t addrLen = sizeof(sin);
        res = getsockname(sock, (struct sockaddr*)&sin, &addrLen);
        if (-1 == res) {
            throw std::runtime_error(std::string("TCPConnector: Cannot get local socket port:) ") + strError(errno));
        }
        const int assignedPort = InetToHost<uint16_t>(sin.sin_port);
        if (assignedPort != port_) {
            break;
        }
        ::close(sock);
    }

    Y_VERIFY(sock > 0);
    int res = fcntl(sock, F_SETFL, O_NONBLOCK);
    if (res != 0) {
        throw std::runtime_error(
            "Cannot set socket " + std::to_string(sock) + std::string(" to nonblocking mode: ") + strError(errno));
    }

    return sock;
}

/**
 * Called by ConnectManager once connect() has succeeded on `transport`.
 *
 * Creates the connection handler with a fresh tokenizer from getTokenizer(), routes tokens to
 * handleMessageReceived() and associates the handler with the transport. Then, with
 * connectionLock_ released, it runs the user connect handler and wakes waitUntilConnected()
 * callers.
 *
 * @throws std::runtime_error (message = testName_) if a connection already exists.
 */
void TCPConnector::onConnect(TransportBase& transport)
{
    std::unique_lock guard(connectionLock_);

    if (connection_) {
        throw std::runtime_error(testName_);
    }
    Y_VERIFY(nullptr == connection_);

    connection_ = std::make_shared<TCPConnectionHandler>();
    auto& tokenizer = connection_->tokenizer_;
    tokenizer = getTokenizer();
    tokenizer->onToken = [this](const std::string& response) {
        handleMessageReceived(response);
    };

    transport.associate(connection_);

    // The user callback runs without connectionLock_, presumably so that it may call back into
    // locking methods such as sendMessage().
    guard.unlock();

    if (connectHandler_) {
        connectHandler_();
    }

    connectionCond_.notify_all();
}

/**
 * Drops the current connection, wakes waiters and schedules a reconnect.
 *
 * Not synchronized itself. It is expected to be called with connectionLock_ held, as
 * notifyCloseTransport() does.
 */
void TCPConnector::onDisconnectThreadUnsafe()
{
    Y_VERIFY(connection_ != nullptr);

    connection_ = nullptr;
    // waitUntilDisconnected() callers re-check their predicate once connectionLock_ is released.
    connectionCond_.notify_all();

    startReconnecting();
}

/**
 * Hook invoked by notifyCloseTransport() just before a closed connection is dropped.
 *
 * The returned callable, if non-empty, is run after the disconnect handler. The default
 * returns an empty function, so nothing extra happens.
 */
std::function<void()> TCPConnector::doBeforeDisconnect(std::shared_ptr<TCPConnectionHandler> /* connection */)
{
    return {};
}

/// Disables reconnection for good (closePeer()) and then shuts down the underlying endpoint.
void TCPConnector::shutdown()
{
    closePeer();
    EndpointBase::shutdown();
}

TCPConnector::~TCPConnector()
{
    // Qualified to make explicit which shutdown() runs. During destruction a virtual call would
    // not reach derived overrides anyway.
    TCPConnector::shutdown();
}

/**
 * Tell the endpoint that a connection has been closed.
 *
 * If a connection was established, onDisconnectThreadUnsafe() drops it and schedules a
 * reconnect, all under connectionLock_. Then, with the lock released, the user disconnect
 * handler and the continuation returned by doBeforeDisconnect() are run.
 *
 * If no connection was established (presumably a failed connection attempt), only a reconnect
 * is scheduled.
 *
 * Exceptions thrown by the disconnect handler are logged and do not propagate.
 */
void TCPConnector::notifyCloseTransport(
    const std::shared_ptr<Datacratic::TransportBase>& transport)
{
    EndpointBase::notifyCloseTransport(transport);

    std::unique_lock lock(connectionLock_);
    if (connection_ != nullptr)
    {
        auto doAfterDisconnect = doBeforeDisconnect(connection_);
        onDisconnectThreadUnsafe(); // It will call startReconnecting()
        Y_VERIFY(nullptr == connection_);
        lock.unlock();
        if (disconnectHandler_) {
            // Best effort: keep going so that doAfterDisconnect still runs, but do not hide the
            // failure.
            try {
                disconnectHandler_();
            } catch (const std::exception& e) {
                YIO_LOG_ERROR_EVENT("DatacraticTcpConnector.DisconnectHandlerException",
                                    "TCPConnector: disconnect handler threw: " << e.what());
            } catch (...) {
                YIO_LOG_ERROR_EVENT("DatacraticTcpConnector.DisconnectHandlerException",
                                    "TCPConnector: disconnect handler threw an unknown exception");
            }
        }
        if (doAfterDisconnect) {
            doAfterDisconnect();
        }
    } else {
        startReconnecting();
    }
}

bool TCPConnector::sendMessage(std::string data)
{
    std::lock_guard lock(connectionLock_);
    return sendMessageUnlocked(std::move(data));
}

bool TCPConnector::sendMessageUnlocked(std::string data)
{
    if (nullptr == connection_) {
        return false;
    }

    connection_->send(data);
    return true;
}

void TCPConnector::TCPConnectionHandler::handleData(const std::string& data)
{
    Y_VERIFY(tokenizer_ != nullptr);
    tokenizer_->pushData(data);
}

/// Reports `errorMessage` on the current connection via its doError(); does nothing when disconnected.
void TCPConnector::doError(const std::string& errorMessage)
{
    std::lock_guard guard(connectionLock_);
    if (!connection_) {
        return;
    }
    connection_->doError(errorMessage);
}

/**
 * Error on an established connection.
 *
 * Logs the error, forwards it to the connector's error callback (if set) and asks for the
 * connection to be closed once the current handler finishes. Must be called on the thread that
 * holds the transport lock.
 */
void TCPConnector::TCPConnectionHandler::handleError(const std::string& message)
{
    transport().assertLockedByThisThread();
    YIO_LOG_ERROR_EVENT("DatacraticTcpConnector.HandleError", "TCPConnector: Error: " << message);

    Y_VERIFY(dynamic_cast<TCPConnector*>(transport().get_endpoint()) != nullptr);
    auto* const owner = static_cast<TCPConnector*>(transport().get_endpoint());

    const auto& errorCallback = owner->onConnectionError_;
    if (errorCallback) {
        errorCallback(message);
    }

    closeWhenHandlerFinished();
}

void TCPConnector::TCPConnectionHandler::handleDisconnect()
{
    PassiveConnectionHandler::handleDisconnect();
}

void TCPConnector::TCPConnectionHandler::handleTimeout(Date /* time */, size_t /* cookie */)
{
    doError("Timeout occurred");
}

void TCPConnector::TCPConnectionHandler::onGotTransport()
{
    startReading();
}
