#include "websocket_server.h"

#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/logging/logging.h>

#include <exception>
#include <list>
#include <memory>

YIO_DEFINE_LOG_MODULE("websocket");

using namespace quasar;
using namespace quasar::Websocket;
using namespace std::chrono;

using websocketpp::lib::bind;
using websocketpp::lib::placeholders::_1;
using websocketpp::lib::placeholders::_2;

// PEM-encoded Diffie-Hellman parameters installed into every TLS context via use_tmp_dh() in onTlsInit().
// NOTE(review): the DER header (INTEGER length 0x81 bytes incl. sign byte) suggests a 1024-bit prime,
// which is considered weak nowadays (Logjam) — consider a >= 2048-bit group; confirm.
const char WebsocketServer::dh_params_[] = "-----BEGIN DH PARAMETERS-----\n"
                                           "MIGHAoGBAN8oCk/aFMGRVTaN3NuSxGt/BtFVWoCM4GORDIahPmBf8ReosFrUWkfI\n"
                                           "Nuw0PecXhgMmr3leEdCA8oL26MqyOETZCv5qeARvegzHL/2N0WX46Hj9wFSV2Y1r\n"
                                           "jLtrO4ffF6dF2OM5ex5EP6ZLIRfbf1Kc+guqjKQbIoJ5zYjJ9XkjAgEC\n"
                                           "-----END DH PARAMETERS-----";

// Non-owning asio buffer over dh_params_ (strlen() excludes the terminating NUL).
// Defined after dh_params_ in this translation unit, so the array is initialized before it is read.
const websocketpp::lib::asio::const_buffer WebsocketServer::dh_temp_params_(dh_params_, strlen(dh_params_));

// OpenSSL cipher list passed to SSL_CTX_set_cipher_list() in onTlsInit().
// NOTE(review): includes DES-CBC3-SHA (3DES), which is considered weak (SWEET32); the trailing ';'
// relies on OpenSSL accepting ';' as a list separator — confirm both are intended.
const std::string WebsocketServer::ciphers_ = "kEECDH+AESGCM+AES128:kEECDH+AES128:kRSA+AESGCM+AES128:kRSA+AES128:DES-CBC3-SHA:!RC4:!aNULL:!eNULL:!MD5:!EXPORT:!LOW:!SEED:!CAMELLIA:!IDEA:!PSK:!SRP:!SSLv2;";

// Creates a server with the given settings and telemetry sink.
// The settings are validated before being stored: Settings::ensureCorrect() throws
// std::runtime_error when TLS is enabled but the key or certificate buffer is empty.
WebsocketServer::WebsocketServer(Settings settings, std::shared_ptr<YandexIO::ITelemetry> telemetry)
    : telemetry_(std::move(telemetry))
{
    settings.ensureCorrect();
    settings_ = std::move(settings);
}

// Stops the websocket endpoint and joins the server thread launched by start().
// Does nothing if the server thread was never started (serverThread_ not joinable).
// Errors from server_.stop() are logged and reported to telemetry but not rethrown,
// because a destructor must not throw.
// NOTE(review): the periodic pong checker (lastPongCheckerPtr_) is not reset here; its shutdown
// relies on member destruction order — confirm it stops before server_ is destroyed.
WebsocketServer::~WebsocketServer() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (serverThread_.joinable()) {
        // Drop every tracked client before stopping the endpoint.
        clientsLastPongs_.clear();
        try {
            server_.stop();
        } catch (const std::exception& exception) {
            YIO_LOG_ERROR_EVENT("WebsocketServer.ServerStop.Exception", "WebsocketServer can't stop in destructor: " << exception.what());
            telemetry_->reportKeyValues("WSServerException", {{"destructor_stop", exception.what()}});
        }
        // Release the lock before joining: handlers dispatched on the server thread
        // (e.g. onOpen/onClose/onPong) lock mutex_, so joining while holding it could deadlock.
        lock.unlock();
        serverThread_.join();
    }
}

// Replaces the callback invoked for each incoming message. The assignment is guarded by mutex_.
void WebsocketServer::setOnMessageHandler(WebsocketServer::OnMessage onMessage) {
    const std::lock_guard<std::mutex> guard{mutex_};
    onMessage_ = std::move(onMessage);
}

// Replaces the callback invoked when a client connects. The assignment is guarded by mutex_.
void WebsocketServer::setOnOpenHandler(WebsocketServer::OnOpen onOpen) {
    const std::lock_guard<std::mutex> guard{mutex_};
    onOpen_ = std::move(onOpen);
}

// Replaces the callback invoked when a client connection closes. The assignment is guarded by mutex_.
void WebsocketServer::setOnCloseHandler(WebsocketServer::OnClose onClose) {
    const std::lock_guard<std::mutex> guard{mutex_};
    onClose_ = std::move(onClose);
}

// Configures the websocketpp endpoint, optionally starts the periodic pong checker,
// begins listening and launches the server thread (run()).
//
// Returns the port the server listens on: settings_.port if configured, otherwise the
// first free port found by findFirstFreePortAndListen() (testing only).
// Init/listen failures go through handleFatalException(), which rethrows to the caller.
int WebsocketServer::start() {
    std::lock_guard<std::mutex> lock(mutex_);
    try {
        server_.set_access_channels(websocketpp::log::alevel::all);

        // dont log frame headers and bodies
        server_.clear_access_channels(websocketpp::log::alevel::frame_payload);
        server_.clear_access_channels(websocketpp::log::alevel::frame_header);
        server_.get_elog().setErrorAsWarn(settings_.logErrorAsWarn);

        server_.init_asio();

        // All handlers use websocketpp::lib::bind so they match the placeholders
        // (websocketpp::lib::placeholders::_1/_2) imported at file scope.
        server_.set_message_handler(bind(&WebsocketServer::onMessage, this, ::_1, ::_2));
        server_.set_open_handler(bind(&WebsocketServer::onOpen, this, ::_1));
        server_.set_close_handler(bind(&WebsocketServer::onClose, this, ::_1));
        server_.set_tls_init_handler(bind(&WebsocketServer::onTlsInit, this, ::_1));
        server_.set_reuse_addr(true); // sets SO_REUSEADDR so we can restart with no delay

        server_.set_ping_handler(bind(&WebsocketServer::onPing, this, ::_1, ::_2));
        server_.set_pong_handler(bind(&WebsocketServer::onPong, this, ::_1, ::_2));
    } catch (const std::exception& exception) {
        handleFatalException(exception, "init");
    }
    if (settings_.ping.enabled) {
        lastPongCheckerPtr_ = std::make_unique<PeriodicExecutor>(
            [this]() {
                try {
                    checkClientsLastPongs();
                } catch (const std::exception& exception) {
                    YIO_LOG_ERROR_EVENT("WebsocketServer.Pong.Exception", "PONG unhandled exception: " << exception.what());
                } catch (...) {
                    YIO_LOG_ERROR_EVENT("WebsocketServer.Pong.UnknownError", "PONG unknown unhandled exception");
                }
            },
            settings_.ping.interval);
    }

    int port = 0;
    if (settings_.port) {
        port = *settings_.port;
        try {
            server_.listen(port);
        } catch (const std::exception& exception) {
            handleFatalException(exception, "listen");
        }
    } else {
        // No port configured: testing purposes only.
        port = findFirstFreePortAndListen();
    }
    YIO_LOG_INFO("Starting WebsocketServer on port: " << port);

    serverThread_ = std::thread(&WebsocketServer::run, this);
    return port;
}

// Replaces the callback consulted by onPing() for incoming pings. Its boolean result is what
// onPing() returns to websocketpp. The assignment is guarded by mutex_.
void WebsocketServer::setOnPingHandler(WebsocketServer::OnPing pingHandler) {
    const std::lock_guard<std::mutex> guard{mutex_};
    onPing_ = std::move(pingHandler);
}

// Testing helper: tries to listen on each port in [15000, 16000) in ascending order and
// returns the first one that succeeds. Failed attempts are logged at debug level.
//
// Throws std::runtime_error if no port in the range could be bound.
int WebsocketServer::findFirstFreePortAndListen() {
    constexpr int firstPort = 15000;
    constexpr int endPort = 16000; // exclusive
    for (int port = firstPort; port < endPort; ++port) {
        try {
            server_.listen(port);
            return port;
        } catch (const std::exception& e) {
            YIO_LOG_DEBUG(e.what());
        }
    }
    // The message previously said "[1500, 16000)", which did not match the scanned range.
    throw std::runtime_error("Free port not found in [15000, 16000) range");
}

// Entry point of serverThread_. It starts accepting connections and then blocks in
// server_.run(), dispatching the handlers registered in start(), until the endpoint is stopped.
// Any exception is passed to handleFatalException(), which rethrows. An exception escaping a
// std::thread entry function terminates the process.
void WebsocketServer::run() try {
    server_.start_accept();
    server_.run();
} catch (const std::exception& exception) {
    handleFatalException(exception, "start_accept_run");
}

// TLS-init handler registered in start(). It builds an SSL context with:
//  - the TLS 1.2 method,
//  - the default_workarounds and single_dh_use options,
//  - the DH parameters from dh_params_,
//  - the certificate chain and PEM private key taken from settings_.tls,
//  - the ciphers_ cipher list.
// A rejected cipher list is only logged; the context is returned regardless.
WebsocketServer::ContextPtr WebsocketServer::onTlsInit(ConnectionHdl /* hdl */) {
    namespace asio = websocketpp::lib::asio;
    using SslContext = asio::ssl::context;

    const auto& tlsSettings = settings_.tls;
    const asio::const_buffer certificateChain(tlsSettings.crtPemBuffer.data(), tlsSettings.crtPemBuffer.size());
    const asio::const_buffer privateKey(tlsSettings.keyPemBuffer.data(), tlsSettings.keyPemBuffer.size());

    ContextPtr ctx = websocketpp::lib::make_shared<SslContext>(SslContext::tlsv12);
    ctx->set_options(SslContext::default_workarounds | SslContext::single_dh_use);
    ctx->use_tmp_dh(dh_temp_params_);
    ctx->use_certificate_chain(certificateChain);
    ctx->use_private_key(privateKey, SslContext::pem);

    const bool cipherListAccepted = SSL_CTX_set_cipher_list(ctx->native_handle(), ciphers_.c_str()) == 1;
    if (!cipherListAccepted) {
        YIO_LOG_ERROR_EVENT("WebsocketServer.SetCipherError", "Error setting cipher list");
    }

    return ctx;
}

// Message handler registered in start(). It forwards the payload of every incoming message
// to the user callback set via setOnMessageHandler().
//
// onMessage_ is written under mutex_ by setOnMessageHandler(), so it is read under the same
// lock here. Reading it unlocked was a data race. As in onOpen()/onClose(), the callback is
// copied under the lock and invoked after releasing it, so user code may call back into the
// server (e.g. send()/close()) without deadlocking.
void WebsocketServer::onMessage(ConnectionHdl hdl, MessagePtr msg) {
    YIO_LOG_DEBUG("onMessage called with hdl: " << hdl.lock().get() << " and message with size {" << msg->get_payload().length() << "} bytes");
    std::unique_lock<std::mutex> lock(mutex_);
    if (onMessage_) {
        auto onMessageCopy = onMessage_;
        lock.unlock();
        onMessageCopy(hdl, msg->get_payload());
    }
}

// Open handler registered in start(). It starts tracking the new client, using the current
// time as its initial "last pong" timestamp. It then invokes the user callback, if one is set.
// The callback is copied under mutex_ and called after the lock is released.
void WebsocketServer::onOpen(ConnectionHdl hdl) {
    YIO_LOG_DEBUG("New WS connection: " << hdl.lock().get());
    std::unique_lock<std::mutex> lock(mutex_);
    clientsLastPongs_[hdl] = steady_clock::now();
    if (!onOpen_) {
        return;
    }
    auto callback = onOpen_;
    lock.unlock();
    callback(hdl);
}

// Close handler registered in start(). It collects connection info (on failure, the error text
// is used as both close reasons and is logged and reported to telemetry), then stops tracking
// the client. Finally it invokes the user callback, if one is set. The callback is copied under
// mutex_ and called after the lock is released.
void WebsocketServer::onClose(ConnectionHdl hdl) {
    ConnectionInfo connectionInfo;
    try {
        connectionInfo = getConnectionInfo(server_.get_con_from_hdl(hdl), hdl);
    } catch (const std::exception& exception) {
        connectionInfo.local.closeReason = exception.what();
        connectionInfo.remote.closeReason = exception.what();
        // Separator added: the message used to run straight into exception.what().
        YIO_LOG_ERROR_EVENT("WebsocketServer.OnClose.Exception", "WebsocketServer cant get connection info: " << exception.what());
        telemetry_->reportKeyValues("WSServerException", {{"get_con_from_hdl", exception.what()}});
    }

    YIO_LOG_DEBUG("Closed WS connection: " << hdl.lock().get() << " " << connectionInfo.toString());
    std::unique_lock<std::mutex> lock(mutex_);
    clientsLastPongs_.erase(hdl);
    if (onClose_) {
        auto onCloseCopy = onClose_;
        lock.unlock();
        onCloseCopy(hdl, connectionInfo);
    }
}

// Ping handler registered in start(). If a user callback was set via setOnPingHandler(), its
// result is returned. Otherwise settings_.pong.enabled is returned, i.e. a pong is sent back
// only when pongs are enabled.
//
// onPing_ is written under mutex_ by setOnPingHandler(), so it is read under the same lock here
// (reading it unlocked was a data race). The callback is copied and then invoked without the lock,
// matching onOpen()/onClose().
bool WebsocketServer::onPing(ConnectionHdl hdl, const std::string& payload) {
    YIO_LOG_TRACE("PING");
    std::unique_lock<std::mutex> lock(mutex_);
    if (onPing_) {
        auto onPingCopy = onPing_;
        lock.unlock();
        return onPingCopy(hdl, payload);
    }
    // Just send pong back
    return settings_.pong.enabled;
}

// Pong handler registered in start(). It refreshes the client's last-pong timestamp, which
// checkClientsLastPongs() compares against the pong timeout.
void WebsocketServer::onPong(ConnectionHdl hdl, const std::string& /* payload */) {
    YIO_LOG_TRACE("PONG");
    const std::lock_guard<std::mutex> guard{mutex_};
    auto& lastPong = clientsLastPongs_[hdl];
    lastPong = steady_clock::now();
}

// Sends `msg` to one client as a text frame.
// websocketpp exceptions are logged as warnings and swallowed. Any other exception is treated
// as fatal via handleFatalException(), which rethrows.
void WebsocketServer::send(ConnectionHdl hdl, const std::string& msg) {
    const auto textFrame = websocketpp::frame::opcode::text;
    try {
        server_.send(hdl, msg, textFrame);
    } catch (const websocketpp::exception& wsError) {
        YIO_LOG_WARN("websocket exception: " << wsError.what());
    } catch (const std::exception& otherError) {
        handleFatalException(otherError, "send");
    }
}

void WebsocketServer::sendAll(const std::string& msg) {
    std::lock_guard<std::mutex> lock(mutex_);

    for (auto const& it : clientsLastPongs_) {
        send(it.first, msg);
    }
}

void WebsocketServer::checkClientsLastPongs() {
    std::lock_guard<std::mutex> lock(mutex_);
    const auto now = std::chrono::steady_clock::now();
    std::list<ConnectionHdl> connectionsToDelete;
    for (const auto& item : clientsLastPongs_) {
        if (now - item.second > pingIntervalCoef_ * settings_.ping.interval) {
            connectionsToDelete.push_back(item.first);
        }
    }
    for (const auto& connection : connectionsToDelete) {
        YIO_LOG_WARN("WsServer, closed connection because of client pong timeout");
        try { //  FIXME[katayad] https://st.yandex-team.ru/QUASAR-4520
            server_.close(connection, (websocketpp::close::status::value)StatusCode::CLIENT_PONG_TIMEOUT, "no pong for long");
        } catch (const std::exception& exception) {
            handleFatalException(exception, "close");
        }

        clientsLastPongs_.erase(connection); // TODO check if close calls onclose handler
    }
    for (const auto& item : clientsLastPongs_) {
        YIO_LOG_TRACE("SENDING_PING");
        try {
            server_.ping(item.first, "server_ping");
        } catch (websocketpp::exception& exception) {
            YIO_LOG_ERROR_EVENT("WebsocketServer.Ping.WebsocketppException", "websocket exception: " << exception.what());
        } catch (const std::exception& exception) {
            handleFatalException(exception, "ping");
        }
    }
}

int WebsocketServer::getConnectionsNumber() {
    std::lock_guard<std::mutex> lock(mutex_);
    return clientsLastPongs_.size();
}

// Extracts the host part from an asio endpoint string.
//   "[::ffff:IPV4]:PORT" -> "IPV4"    (IPv4-mapped IPv6: prefix and brackets stripped)
//   "[IPV6]:PORT"        -> "[IPV6]"  (only the port is stripped)
// Input without ']' (e.g. an empty string) is returned unchanged.
std::string WebsocketServer::extractAsioHost(std::string src) {
    const auto closingBracket = src.rfind(']');
    if (closingBracket == std::string::npos) { // wrong format (mostly empty src)
        return src;
    }
    src.resize(closingBracket + 1); // strip ":PORT"

    constexpr std::string_view ipv4MappedPrefix("[::ffff:");
    // substr() clamps to the string's length, so a src shorter than the prefix simply doesn't match.
    // The previous std::string_view(src.c_str(), prefix.size()) read past the end of short strings.
    if (std::string_view(src).substr(0, ipv4MappedPrefix.size()) == ipv4MappedPrefix) {
        src.erase(0, ipv4MappedPrefix.size()); // strip prefix
        src.pop_back();                         // strip trailing ]
    }
    return src;
}

// Returns the remote host of the connection behind `hdl`, normalized by extractAsioHost().
// Returns an empty string if server_.get_con_from_hdl() yields a null connection.
// Errors thrown by get_con_from_hdl() propagate to the caller.
std::string WebsocketServer::getRemoteHost(ConnectionHdl hdl) {
    const auto connection = server_.get_con_from_hdl(hdl);
    if (!connection) {
        return {};
    }
    return extractAsioHost(connection->get_remote_endpoint());
}

// Closes a tracked client connection with the given status code and reason.
// Untracked handles (already closed or unknown) are ignored. The client is removed from
// tracking before server_.close() is called. Close errors go through handleFatalException(),
// which rethrows. Runs under mutex_.
void WebsocketServer::close(WebsocketServer::ConnectionHdl hdl, const std::string& reason, StatusCode code) {
    std::lock_guard<std::mutex> lock(mutex_);
    const auto tracked = clientsLastPongs_.find(hdl);
    if (tracked == clientsLastPongs_.end()) {
        return;
    }
    clientsLastPongs_.erase(tracked);

    const auto status = static_cast<websocketpp::close::status::value>(code);
    try { //  FIXME[katayad] https://st.yandex-team.ru/QUASAR-4520
        server_.close(hdl, status, reason);
    } catch (const std::exception& closeError) {
        handleFatalException(closeError, "close");
    }
}

// Logs and reports (telemetry key "WSServerException") a fatal error, sleeps 5 seconds, then
// rethrows. NOTE(review): the sleep presumably throttles crash/restart loops — confirm.
//
// Intended to be called from inside a catch handler. In that case the in-flight exception is
// rethrown unchanged, keeping its dynamic type and what() message. The previous
// `throw exception;` copied only the std::exception base subobject (slicing), replacing what()
// with std::exception's implementation-defined text. If called outside a handler, a
// std::runtime_error carrying exception.what() is thrown instead.
void WebsocketServer::handleFatalException(const std::exception& exception, const std::string& msg) {
    YIO_LOG_ERROR_EVENT("WebsocketServer.FatalException", "WebsocketServer " << msg << ": " << exception.what());
    telemetry_->reportKeyValues("WSServerException", {{msg, exception.what()}});
    std::this_thread::sleep_for(std::chrono::seconds(5));
    if (const std::exception_ptr inFlight = std::current_exception()) {
        std::rethrow_exception(inFlight);
    }
    throw std::runtime_error(exception.what());
}

// Validates the server settings.
//  - Warns when no port is configured; start() then falls back to the first free port.
//    This check does not depend on TLS, so it now runs before the TLS early return.
//    Previously it was silently skipped whenever TLS was disabled.
//  - Throws std::runtime_error when TLS is enabled but the key or certificate PEM buffer is empty.
void WebsocketServer::Settings::ensureCorrect() const {
    if (!port) {
        YIO_LOG_WARN("Port is not specified for WebsocketServer. Using first free one");
    }
    if (tls.disabled) {
        return;
    }
    if (tls.keyPemBuffer.empty() || tls.crtPemBuffer.empty()) {
        throw std::runtime_error("Incorrect tls settings for Websocket server. Key or cert is empty");
    }
}
