#include "unix_socket_transport.h"

#include <yandex_io/libs/errno/errno_exception.h>
#include <yandex_io/libs/logging/logging.h>

#include <cerrno>
#include <chrono>
#include <cstring>
#include <stdexcept>
#include <thread>

#include <poll.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>

YIO_DEFINE_LOG_MODULE("keymaster_proxy");

using namespace quasar;
using namespace quasar::keymaster_proxy_client;

UnixSocketTransport::UnixSocketTransport(std::string socketPath)
    : socketPath_(std::move(socketPath))
{
}

/**
 * Releases the socket descriptor, if one is held, by delegating to close().
 */
UnixSocketTransport::~UnixSocketTransport()
{
    this->close();
}

void UnixSocketTransport::connect() {
    int sock;

    if ((sock = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) {
        int socketErrno = errno;
        YIO_LOG_ERROR_EVENT("UnixSocketTransport.OpenSocketFail", "Socket creation error");
        throw ErrnoException(socketErrno, "Failed to create socket");
    } else {
        YIO_LOG_DEBUG("Socket created: " << sock);
    }

    struct sockaddr_un addr;
    addr.sun_family = AF_UNIX;
    strcpy(addr.sun_path, socketPath_.c_str());

    bool success = false;
    int retries = CONNECT_TIMEOUT / WAIT_BETWEEN_CONNECT_RETRIES;
    int connectErrno;
    while (!success && retries > 0) {
        int res = ::connect(sock, (struct sockaddr*)&addr, sizeof(addr));
        connectErrno = errno;
        if (res < 0) {
            YIO_LOG_ERROR_EVENT("UnixSocketTransport.SocketConnectFail", "Failed to connect socket: " << std::string(strerror(connectErrno)));
            if (connectErrno != EINTR) {
                std::this_thread::sleep_for(WAIT_BETWEEN_CONNECT_RETRIES);
                retries--;
            }
        } else {
            sock_ = sock;
            YIO_LOG_DEBUG("Socket " << sock << " connected successfully to path " << socketPath_);
            success = true;
        }
    }
    if (!success) {
        throw ErrnoException(connectErrno, "Failed to connect to keymaster socket");
    }
}

/**
 * Closes the held socket, if any, and resets sock_ to -1.
 *
 * Calling it again after the socket has been closed does nothing.
 */
void UnixSocketTransport::close() {
    // Descriptor 0 is a valid socket too (e.g. when stdin was closed), so only negative values mean "no socket".
    // NOTE(review): relies on sock_ being initialised to -1 in the header — confirm.
    if (sock_ >= 0) {
        ::close(sock_);
        sock_ = -1;
    }
}

/**
 * Writes every byte of `data` to the connected socket within WRITE_TIMEOUT.
 *
 * Waits for writability with poll() and writes with non-blocking send() calls,
 * resuming after partial writes, until the whole buffer has been sent.
 * Empty input returns immediately.
 *
 * @param data bytes to transmit.
 * @throws ErrnoException if poll() or send() fails, or if the descriptor is invalid (POLLNVAL).
 * @throws std::runtime_error if send() reports 0 bytes written or WRITE_TIMEOUT elapses first.
 */
void UnixSocketTransport::send(const std::vector<char>& data) {
    if (data.empty()) {
        // ::send() of zero bytes returns 0, which the loop below would misreport as EOF.
        return;
    }

#ifdef MSG_NOSIGNAL
    // A peer that has gone away must surface as EPIPE, not as a process-killing SIGPIPE.
    constexpr int sendFlags = MSG_DONTWAIT | MSG_NOSIGNAL;
#else
    constexpr int sendFlags = MSG_DONTWAIT;
#endif

    struct pollfd pollfds[1];
    pollfds[0].fd = sock_;
    pollfds[0].events = POLLOUT;

    // A fixed deadline (instead of subtracting per-iteration durations truncated to whole
    // milliseconds) guarantees termination even when poll() keeps returning immediately.
    const auto deadline = std::chrono::steady_clock::now() + WRITE_TIMEOUT;
    size_t offset = 0;
    while (true) {
        const auto remaining = std::chrono::duration_cast<std::chrono::milliseconds>(
            deadline - std::chrono::steady_clock::now());
        if (remaining.count() <= 0) {
            break;
        }

        pollfds[0].revents = 0;
        const int result = ::poll(pollfds, 1, static_cast<int>(remaining.count()));
        if (result < 0) {
            const int pollErrno = errno;
            if (pollErrno == EINTR) {
                continue; // interrupted by a signal: wait again for the remaining time
            }
            throw ErrnoException(pollErrno, "Poll failed");
        }
        if (result == 0) {
            continue; // poll() timed out; the deadline check ends the loop
        }

        const auto revents = pollfds[0].revents;
        if (revents & POLLNVAL) {
            throw ErrnoException(EBADF, "Poll failed");
        }
        if ((revents & (POLLOUT | POLLERR | POLLHUP)) == 0) {
            YIO_LOG_WARN("Unexpected poll event");
            continue;
        }

        // On POLLERR/POLLHUP this send() fails and reports the actual socket error through errno.
        const auto res = ::send(sock_, data.data() + offset, data.size() - offset, sendFlags);
        if (res < 0) {
            const int sendErrno = errno;
            // EINTR: interrupted before anything was written; EAGAIN: buffer filled up again since poll()
            // (EWOULDBLOCK has the same value as EAGAIN on Linux).
            if (sendErrno != EINTR && sendErrno != EAGAIN) {
                throw ErrnoException(sendErrno, "Failed to send data to keymaster");
            }
            continue;
        }
        if (res == 0) {
            throw std::runtime_error("EOF received from socket");
        }

        offset += static_cast<size_t>(res);
        YIO_LOG_DEBUG("Bytes sent: " << res);
        if (offset == data.size()) {
            YIO_LOG_DEBUG("Data sent to keymaster successfully");
            return;
        }
    }

    throw std::runtime_error("Timeout exceeded while sending request");
}

/**
 * Reads exactly `size` bytes from the connected socket within READ_TIMEOUT.
 *
 * Waits for readability with poll() and reads with non-blocking recv() calls,
 * resuming after partial reads, until the requested number of bytes has arrived.
 *
 * @param size number of bytes to read; 0 returns an empty vector immediately.
 * @return a vector holding exactly `size` received bytes.
 * @throws std::invalid_argument if `size` is negative.
 * @throws ErrnoException if poll() or recv() fails, or if the descriptor is invalid (POLLNVAL).
 * @throws std::runtime_error if the peer closes the connection (EOF) or READ_TIMEOUT elapses first.
 */
std::vector<char> UnixSocketTransport::receive(int size) {
    if (size < 0) {
        throw std::invalid_argument("Negative receive size");
    }
    std::vector<char> data(static_cast<size_t>(size));
    if (data.empty()) {
        // ::recv() into a zero-length buffer returns 0, which the loop below would misreport as EOF.
        return data;
    }

    struct pollfd pollfds[1];
    pollfds[0].fd = sock_;
    pollfds[0].events = POLLIN;

    // A fixed deadline (instead of subtracting per-iteration durations truncated to whole
    // milliseconds) guarantees termination even when poll() keeps returning immediately.
    const auto deadline = std::chrono::steady_clock::now() + READ_TIMEOUT;
    size_t offset = 0;
    while (true) {
        const auto remaining = std::chrono::duration_cast<std::chrono::milliseconds>(
            deadline - std::chrono::steady_clock::now());
        if (remaining.count() <= 0) {
            break;
        }

        pollfds[0].revents = 0;
        const int result = ::poll(pollfds, 1, static_cast<int>(remaining.count()));
        if (result < 0) {
            const int pollErrno = errno;
            if (pollErrno == EINTR) {
                continue; // interrupted by a signal: wait again for the remaining time
            }
            throw ErrnoException(pollErrno, "Poll failed");
        }
        if (result == 0) {
            continue; // poll() timed out; the deadline check ends the loop
        }

        const auto revents = pollfds[0].revents;
        if (revents & POLLNVAL) {
            throw ErrnoException(EBADF, "Poll failed");
        }
        if ((revents & (POLLIN | POLLERR | POLLHUP)) == 0) {
            YIO_LOG_WARN("Unexpected poll event");
            continue;
        }

        // On POLLHUP recv() drains buffered bytes and then returns 0 (EOF); on POLLERR it
        // fails and reports the actual socket error through errno.
        const auto res = ::recv(sock_, data.data() + offset, data.size() - offset, MSG_DONTWAIT);
        if (res < 0) {
            const int recvErrno = errno;
            // EINTR: interrupted before anything was read; EAGAIN: no data after all (spurious wakeup)
            // (EWOULDBLOCK has the same value as EAGAIN on Linux).
            if (recvErrno != EINTR && recvErrno != EAGAIN) {
                throw ErrnoException(recvErrno, "Failed to receive data from keymaster");
            }
            continue;
        }
        if (res == 0) {
            throw std::runtime_error("EOF received from socket");
        }

        offset += static_cast<size_t>(res);
        YIO_LOG_DEBUG("Bytes received: " << res);
        if (offset == data.size()) {
            YIO_LOG_DEBUG("Data received from keymaster successfully");
            return data;
        }
    }

    throw std::runtime_error("Timeout exceeded while receiving response");
}
