#include "glagol_ws_server.h"

#include <yandex_io/libs/base/retry_delay_counter.h>
#include <yandex_io/libs/device/device.h>

YIO_DEFINE_LOG_MODULE("glagol");

using namespace quasar;
using namespace std::chrono;
using namespace quasar::proto;

using std::placeholders::_1;
using std::placeholders::_2;
using std::placeholders::_3;
using std::placeholders::_4;

namespace {
    /**
     * Strips exactly one pair of enclosing square brackets from @p str in place,
     * e.g. "[::1]" becomes "::1" (a bracketed IPv6 literal host).
     * Strings that do not both start with '[' and end with ']' are left untouched.
     */
    void trimSquareBrackets(std::string& str) {
        const bool bracketed = str.size() >= 2 && str.front() == '[' && str.back() == ']';
        if (!bracketed) {
            return;
        }
        str = str.substr(1, str.size() - 2);
    }
} // namespace

/** Stores the remote host of a (verified) websocket connection. */
GlagolWsServer::ConnectionDetails::ConnectionDetails(std::string h)
    : host(std::move(h))
{
}

/**
 * Updates the connection type.
 * @return true if the stored type actually changed, false if @p newType was
 *         already set (callers use this to decide whether to publish state).
 */
bool GlagolWsServer::ConnectionDetails::changeType(proto::GlagoldState::ConnectionType newType) {
    if (type == newType) {
        return false;
    }
    type = newType;
    return true;
}

/**
 * Builds the server from its collaborators. Reads the backend URL from the
 * "common" service config, the device id/type from the device, and an
 * optional fixed listening port ("externalPort") from the "glagold" config.
 * Nothing is started here; start() subscribes to updates and spawns the
 * server thread.
 */
GlagolWsServer::GlagolWsServer(std::shared_ptr<YandexIO::IDevice> device,
                               std::shared_ptr<ipc::IIpcFactory> ipcFactory,
                               std::shared_ptr<IAuthProvider> authProvider,
                               std::shared_ptr<IDeviceStateProvider> deviceStateProvider)
    : device_(std::move(device))
    , ipcFactory_(std::move(ipcFactory))
    , authProvider_(std::move(authProvider))
    , deviceStateProvider_(std::move(deviceStateProvider))
    , backendUrl_(device_->configuration()->getServiceConfig("common")["backendUrl"].asString())
    , backendApi_(authProvider_, device_)
{
    deviceId_ = glagol::DeviceId{device_->deviceId(), device_->configuration()->getDeviceType()};

    auto glagoldConfig = device_->configuration()->getServiceConfig("glagold");
    if (!glagoldConfig.isMember("externalPort")) {
        // No fixed port: configuredPort_ stays unset.
        return;
    }
    configuredPort_ = glagoldConfig["externalPort"].asInt();
}

/** Thread-safe wrapper: takes mutex_ and publishes the current connections state. */
void GlagolWsServer::notifyGlagoldStatus() {
    std::scoped_lock guard(mutex_);
    notifyGlagoldStatusNoLock();
}

// NB: this function must be called with lock on mutex_;
void GlagolWsServer::notifyGlagoldStatusNoLock() {
    GlagoldState state;

    for (auto const& [hdl, details] : verifiedConnections_) {
        state.mutable_connections()->Add()->set_type(details.type);
    }

    if (onConnectionsChanged_) {
        auto callback = [onConnectionsChanged = onConnectionsChanged_, state]() {
            onConnectionsChanged(state);
        };
        /* Call user callback in separate thread (without any mutex locks) */
        userCallbackQueue_.add(callback);
    }
}

/**
 * Shuts down. lifetime_ is killed first (the signal connections made in
 * start() are bound to it), then the server thread is asked to quit and joined.
 */
GlagolWsServer::~GlagolWsServer() {
    lifetime_.die();

    {
        std::scoped_lock guard(mutex_);
        quit_ = true;
        // Wake the server thread from whichever wait it is blocked in.
        toServerThreadCV_.notify_all();
    }

    if (!serverThread_.joinable()) {
        return;
    }
    serverThread_.join();
}

void GlagolWsServer::start() {
    authProvider_->ownerAuthInfo().connect([this](std::shared_ptr<const AuthInfo2> authInfo) {
        handleAuthInfo(std::move(authInfo));
    }, lifetime_);

    deviceStateProvider_->deviceState().connect([this](std::shared_ptr<const DeviceState> deviceState) {
        handleDeviceState(std::move(deviceState));
    }, lifetime_);

    serverThread_ = std::thread(&GlagolWsServer::serverThreadFunc, this);
}

/**
 * Tells whether this device's entry differs between @p oldDevices and the
 * current accountDevices_: it appeared, disappeared, or its settings changed.
 * Reads accountDevices_; updateAccountDevices() calls it with
 * accountDevicesMutex_ held.
 */
bool GlagolWsServer::myCertsChanged(const glagol::BackendApi::DevicesMap& oldDevices) const {
    const auto previous = oldDevices.find(deviceId_);
    const auto current = accountDevices_.find(deviceId_);
    const bool hadPrevious = previous != oldDevices.end();
    const bool hasCurrent = current != accountDevices_.end();

    if (!hadPrevious || !hasCurrent) {
        // Appeared or disappeared counts as a change; absent in both does not.
        // (A disappeared entry is probably moot: the server cannot start without it.)
        return hadPrevious != hasCurrent;
    }
    return current->second != previous->second;
}

/**
 * Replaces the cached account devices map. If this device's own settings
 * (certificates) changed, wakes the server thread so it can (re)start the
 * websocket server with the new TLS material.
 *
 * The server thread evaluates its wait predicates while holding mutex_ and
 * only briefly takes accountDevicesMutex_ inside them. A notification sent
 * without mutex_ could land between a predicate check and the actual wait and
 * be lost, so mutex_ is acquired before notifying. The two mutexes are never
 * held at the same time here, so this cannot invert the
 * mutex_ -> accountDevicesMutex_ order used by the waiters.
 * Must not be called with mutex_ held.
 */
void GlagolWsServer::updateAccountDevices(glagol::BackendApi::DevicesMap newDevices) {
    const bool certsChanged = [this, &newDevices] {
        std::scoped_lock lock(accountDevicesMutex_);
        accountDevices_.swap(newDevices);
        // After the swap newDevices holds the previous map.
        return myCertsChanged(newDevices);
    }();

    if (certsChanged) {
        YIO_LOG_INFO("My certs changed. Notifying server thread.");
        std::lock_guard<std::mutex> lock(mutex_);
        toServerThreadCV_.notify_all();
    }
}

/**
 * Signal handler for owner auth info. On a token change it stores the new
 * auth info, points backendApi_ at the new token and wakes the server thread.
 * With an unchanged token it still wakes the thread if we are authorized.
 */
void GlagolWsServer::handleAuthInfo(std::shared_ptr<const AuthInfo2> authInfo) {
    std::lock_guard<std::mutex> lock(mutex_);
    const bool tokenChanged = authInfo_.authToken != authInfo->authToken;
    if (tokenChanged) {
        authInfo_ = *authInfo;
        YIO_LOG_INFO("Auth token changed. Notify server thread.");
        backendApi_.setSettings({.url = backendUrl_, .token = authInfo_.authToken});
    }
    if (tokenChanged || authInfo_.isAuthorized()) {
        toServerThreadCV_.notify_all();
    }
}

/**
 * True when the device is CONFIGURED and no critical update is pending, i.e.
 * the websocket server is allowed to run. Reads state guarded by mutex_; every
 * caller in this file holds it.
 */
bool GlagolWsServer::configured() const {
    if (hasCriticalUpdate_) {
        return false;
    }
    return configurationState_ == DeviceState::Configuration::CONFIGURED;
}

/**
 * Signal handler for device state. Tracks the "has critical update" flag and
 * the configuration state. Wakes the server thread when the critical-update
 * flag flips, or when the configuration state changes and the device is
 * configured afterwards.
 */
void GlagolWsServer::handleDeviceState(std::shared_ptr<const DeviceState> deviceState) {
    YIO_LOG_DEBUG("Configuration state came");
    std::lock_guard<std::mutex> lock(mutex_);

    bool wakeServerThread = false;

    const bool newHasCritical = deviceState->update == DeviceState::Update::HAS_CRITICAL;
    if (hasCriticalUpdate_ != newHasCritical) {
        YIO_LOG_DEBUG("Critical updates state changed from " << hasCriticalUpdate_ << " to " << newHasCritical);
        hasCriticalUpdate_ = newHasCritical;
        wakeServerThread = true;
    }

    const auto previousConfigurationState = std::exchange(configurationState_, deviceState->configuration);
    if (configured() && configurationState_ != previousConfigurationState) {
        wakeServerThread = true;
    }

    if (wakeServerThread) {
        toServerThreadCV_.notify_all();
    }
}

/**
 * Blocks the server thread until the websocket server can be started, or
 * until quit_ is set, and returns the settings to start it with.
 *
 * The server can start when all of the following hold:
 *  - the owner is authorized;
 *  - configured() is true;
 *  - accountDevices_ holds this device's server key and certificate.
 *
 * On quit the returned settings carry no TLS material; serverThreadFunc()
 * re-checks quit_ afterwards.
 */
WebsocketServer::Settings GlagolWsServer::waitForWsSettings() {
    WebsocketServer::Settings settings;
    settings.logErrorAsWarn = true;

    std::unique_lock<std::mutex> lock(mutex_);
    if (configuredPort_) {
        settings.port = *configuredPort_;
    }

    // Evaluated with mutex_ held; additionally takes accountDevicesMutex_ to read the certificates.
    const auto readyToStart = [this, &settings]() -> bool {
        if (quit_) {
            return true;
        }
        if (!(authInfo_.isAuthorized() && configured())) {
            return false;
        }
        std::lock_guard<std::mutex> devicesLock(accountDevicesMutex_);
        const auto found = accountDevices_.find(deviceId_);
        if (found == accountDevices_.end()) {
            return false;
        }
        const auto& security = found->second.glagol.security;
        if (!security.filledForServer()) {
            return false;
        }
        settings.tls.keyPemBuffer = security.serverPrivateKey;
        settings.tls.crtPemBuffer = security.serverCertificate;
        return true;
    };
    toServerThreadCV_.wait(lock, readyToStart);
    return settings;
}

/**
 * Wait predicate for a running server. Returns true when the server must be
 * stopped or restarted:
 *  - quit was requested;
 *  - configuration or authorization was lost;
 *  - this device's entry disappeared from accountDevices_;
 *  - its TLS key or certificate differ from @p oldSettings.
 *
 * Called with mutex_ held; takes accountDevicesMutex_ internally.
 */
bool GlagolWsServer::wsSettingsChanged(const WebsocketServer::Settings& oldSettings) {
    // FIXME: what about port?
    const bool mustStop = quit_ || !configured() || !authInfo_.isAuthorized();
    if (mustStop) {
        // Stop, then go back to waiting for auth and configuration.
        return true;
    }

    std::lock_guard<std::mutex> devicesLock(accountDevicesMutex_);
    const auto found = accountDevices_.find(deviceId_);
    if (found == accountDevices_.end()) {
        // We were running, so a vanished key means: stop and wait for a new one.
        return true;
    }
    const auto& security = found->second.glagol.security;
    return oldSettings.tls.keyPemBuffer != security.serverPrivateKey ||
           oldSettings.tls.crtPemBuffer != security.serverCertificate;
}

/** Thread-safe read of the quit_ flag. */
bool GlagolWsServer::quit() const {
    std::scoped_lock guard(mutex_);
    return quit_;
}

/**
 * Body of the server thread. Loops until quit_ is set; each iteration:
 *  1. waits until the websocket server can be started (waitForWsSettings());
 *  2. creates and starts it, publishing the listening port to getPort() waiters;
 *  3. waits until its settings become stale or quit is requested (wsSettingsChanged());
 *  4. closes all verified connections and tears the server down.
 */
void GlagolWsServer::serverThreadFunc() {
    while (!quit()) {
        auto settings = waitForWsSettings(); // do not start until the owner is authorized, the device is configured and the certificate is available
        std::unique_lock<std::mutex> lock(mutex_);
        // waitForWsSettings() also returns when quit_ is set; re-check under the lock.
        if (quit_) {
            break;
        }
        try {
            YIO_LOG_INFO("Restarting WS server due to new auth data. Settings.port = " << settings.port);
            server_ = std::make_unique<WebsocketServer>(settings, device_->telemetry());
            server_->setOnCloseHandler(std::bind(&GlagolWsServer::onClose, this, ::_1, ::_2));
            server_->setOnMessageHandler(std::bind(&GlagolWsServer::onMessage, this, ::_1, ::_2));
            port_ = server_->start();
            YIO_LOG_INFO("Restarting WS server is done");
            // Wake threads blocked in getPort() / waitServerStart().
            fromServerThreadCV_.notify_all();
        } catch (const std::exception& exception) {
            // NOTE(review): on failure port_ stays unset, while server_ may be left non-null
            // (construction succeeded, start() threw); it is torn down below once settings change.
            YIO_LOG_ERROR_EVENT("GlagolWsServer.FailedStart.UnknownError", "exception: " << exception.what());
        } catch (...) {
            YIO_LOG_ERROR_EVENT("GlagolWsServer.FailedStart.UnknownError", "unknown exception");
        };
        YIO_LOG_INFO("Waiting for ws settings changes...");
        toServerThreadCV_.wait(lock,
                               [this, &settings]() {
                                   return wsSettingsChanged(settings);
                               });

        YIO_LOG_INFO("Notified to reconfigure or stop. Close " << verifiedConnections_.size() << " clients");

        // NOTE(review): because the map is cleared here, a later onClose() for these handles
        // will not find them, so onClose_ is not invoked for connections closed by a restart.
        if (!verifiedConnections_.empty()) {
            for (auto& [connection, details] : verifiedConnections_) {
                closeConnection(connection, "restarting", Websocket::StatusCode::SERVICE_RESTART);
            }
            verifiedConnections_.clear();
            notifyGlagoldStatusNoLock();
        }
        port_.reset();
        // Move the server out under the lock, but destroy it (end of this iteration) without mutex_ held.
        auto destroyingServer = std::move(server_);
        lock.unlock(); // onClose can happen on destructing WebsocketServer
    }
    YIO_LOG_INFO("Server thread completed");
}

/** Returns the fixed port read from the "glagold" externalPort config, if any. */
std::optional<int> GlagolWsServer::getConfiguredPort() const {
    std::scoped_lock guard(mutex_);
    return configuredPort_;
}

/** Blocks until the websocket server has been started by the server thread. */
void GlagolWsServer::waitServerStart() const {
    static_cast<void>(getPort());
}

/**
 * Blocks until the server thread has started the websocket server and returns
 * the port it listens on. NOTE: waits indefinitely if the server never starts.
 */
int GlagolWsServer::getPort() const {
    YIO_LOG_DEBUG("Waiting for start");
    std::unique_lock<std::mutex> lock(mutex_);
    const auto started = [this] { return static_cast<bool>(port_); };
    fromServerThreadCV_.wait(lock, started);
    YIO_LOG_DEBUG("Waiting for start .. DONE. port: " << *port_);
    return *port_;
}

void GlagolWsServer::setGuestMode(bool newValue) {
    guestMode_ = newValue;
}

/** Installs the handler for messages arriving on token-verified connections. */
void GlagolWsServer::setOnMessage(OnMessageCallback onMessage) {
    std::scoped_lock guard(mutex_);
    onMessage_ = std::move(onMessage);
}

/** Installs the handler invoked when a verified connection is closed. */
void GlagolWsServer::setOnClose(OnCloseCallback cb) {
    std::scoped_lock guard(mutex_);
    onClose_ = std::move(cb);
}

/**
 * Installs the listener for changes of the verified-connection set. It is
 * invoked asynchronously through userCallbackQueue_.
 * @throws std::runtime_error if @p onConnectionsChanged is empty.
 */
void GlagolWsServer::setOnConnectionsChanged(std::function<void(quasar::proto::GlagoldState)> onConnectionsChanged) {
    if (!onConnectionsChanged) {
        YIO_LOG_ERROR_EVENT("GlagolWsServer.InvalidOnConnectionsChangedCallback", "Setting not valid (nullptr) callback!");
        throw std::runtime_error("Setting not valid (nullptr) callback!");
    }
    std::scoped_lock guard(mutex_);
    onConnectionsChanged_ = std::move(onConnectionsChanged);
}

/**
 * Websocket close handler. If the handle belongs to a verified connection, it:
 *  - removes the connection from verifiedConnections_;
 *  - publishes the new connection state;
 *  - calls onClose_ with its details, outside mutex_ to avoid recursive locking.
 * A "glagold_client_disconnected" telemetry event is reported in every case.
 */
void GlagolWsServer::onClose(WebsocketServer::ConnectionHdl hdl, Websocket::ConnectionInfo connectionInfo) {
    std::optional<ConnectionDetails> detailsForCallback;
    OnCloseCallback callback;
    {
        std::lock_guard<std::mutex> guard(mutex_);
        if (const auto found = verifiedConnections_.find(hdl); found != verifiedConnections_.end()) {
            if (onClose_) {
                detailsForCallback = std::move(found->second);
                callback = onClose_;
            }
            verifiedConnections_.erase(found);
            notifyGlagoldStatusNoLock(); // a connection was closed, tell everybody!
        }
    }

    // Invoked after the lock is released, so the callback may call back into this object.
    if (detailsForCallback) {
        callback(*detailsForCallback);
    }

    Json::Value event;
    event["message"] = connectionInfo.toString();
    device_->telemetry()->reportEvent("glagold_client_disconnected", jsonToString(event));
}

/**
 * Closes @p hdl with the given reason and status code.
 * websocketpp and std exceptions are logged rather than propagated.
 * Run under lock: must be called with mutex_ held (it uses server_).
 */
void GlagolWsServer::closeConnection(WebsocketServer::ConnectionHdl hdl, const std::string& error_string, Websocket::StatusCode error_code) {
    try {
        if (!server_) {
            YIO_LOG_WARN("Cannot close. Server has disappeared");
            return;
        }
        server_->close(hdl, error_string, error_code);
    } catch (const websocketpp::exception& exception) {
        YIO_LOG_ERROR_EVENT("GlagolWsServer.FailedClose.WebsocketError", "Cannot close: " << exception.what());
    } catch (const std::exception& exception) {
        YIO_LOG_ERROR_EVENT("GlagolWsServer.FailedClose.UnknownError", "Cannot close: " << exception.what());
    }
}

/**
 * Websocket message handler.
 *
 * Parses @p msgStr as JSON and authenticates it by its "conversationToken"
 * field (see processToken()). If the token is accepted and onMessage_ is set,
 * the message is passed to onMessage_ together with the connection type and
 * guest flag; the first such message registers the connection in
 * verifiedConnections_. Otherwise the connection is closed with an error
 * reason and status code.
 *
 * The payload comes from a remote client. A non-object payload or a
 * non-string token is therefore treated as "no token" (connection closed as
 * invalid) instead of letting jsoncpp throw out of the handler. If server_ is
 * already gone (being torn down by the server thread), the message is handled
 * like a missing onMessage_: the connection is closed with SERVICE_RESTART.
 */
void GlagolWsServer::onMessage(WebsocketServer::ConnectionHdl hdl, const std::string& msgStr) {
    Json::Value msg;
    try {
        msg = parseJson(msgStr);
    } catch (const Json::Exception& exception) {
        YIO_LOG_WARN("GlagolWSServer: Can't parse glagol incoming message: " << exception.what());
        return;
    }

    // Json::Value::operator[] throws on non-object values and asString() throws on
    // arrays/objects, so only a string token inside an object is taken.
    std::string conversationToken;
    if (msg.isObject() && msg.isMember("conversationToken")) {
        const Json::Value& tokenValue = msg["conversationToken"];
        if (tokenValue.isString()) {
            conversationToken = tokenValue.asString();
        }
    }

    auto getOnMessage = [this, &hdl](const std::string& token) {
        std::lock_guard<std::mutex> lock(mutex_);
        proto::GlagoldState::ConnectionType connType = proto::GlagoldState::GLAGOL_APP;

        auto jwtToken = processToken(token);
        if (!jwtToken) {
            YIO_LOG_DEBUG("glagol. Invalid token");
            device_->telemetry()->reportEvent("invalid_jwt_token");
            return std::make_tuple(OnMessageCallback(), connType, false, "Invalid token", Websocket::StatusCode::INVALID_TOKEN);
        }

        if (!onMessage_ || !server_) {
            // No handler installed, or server_ already moved out for teardown
            // (nothing to ask for the remote host): ask the client to reconnect.
            return std::make_tuple(OnMessageCallback(), connType, jwtToken->isGuest(), "Reconfiguring", Websocket::StatusCode::SERVICE_RESTART);
        }

        auto [iter, inserted] = verifiedConnections_.emplace(hdl, server_->getRemoteHost(hdl));
        if (inserted) { // if new connection
            // FIXME: change to OTHER when implementing QUASAR-3805
            // for now no other clients exist
            notifyGlagoldStatusNoLock();
        }
        connType = iter->second.type;
        return std::make_tuple(onMessage_, connType, jwtToken->isGuest(), "Reconfiguring", Websocket::StatusCode::SERVICE_RESTART);
    };

    auto [currentOnMessage, connType, guestMode, error_string, error_code] = getOnMessage(conversationToken);
    if (currentOnMessage) {
        // Called without mutex_ held.
        currentOnMessage(hdl, {.type = connType, .guestMode = guestMode}, msg);
    } else {
        YIO_LOG_INFO("Wrong client: " << error_string);
        std::lock_guard<std::mutex> lock(mutex_);
        closeConnection(hdl, error_string, error_code);
    }
}

/**
 * Evicts cached tokens whose expiration time has passed, from both
 * verifiedJwtTokens_ and jwtTokensExpire_. The loop stops at the first
 * non-expired entry, so it relies on jwtTokensExpire_ iterating in ascending
 * expiration order. NOTE(review): that ordering is defined in the header;
 * confirm. Called with mutex_ held (from processToken()).
 */
void GlagolWsServer::removeExpiredTokens() {
    const auto now = system_clock::now();
    auto cursor = jwtTokensExpire_.begin();
    const auto last = jwtTokensExpire_.end();
    while (cursor != last) {
        const auto tokenIt = *cursor;
        if (!(tokenIt->second.getExpiration() < now)) {
            break;
        }
        verifiedJwtTokens_.erase(tokenIt);
        cursor = jwtTokensExpire_.erase(cursor);
    }
}

std::optional<GlagolWsServer::JwtToken> GlagolWsServer::processToken(const std::string& token) {
    if (token.empty()) {
        YIO_LOG_DEBUG("glagol. Received empty token");
        return {};
    }
    removeExpiredTokens();
    {
        auto iter = verifiedJwtTokens_.find(token);
        if (iter != verifiedJwtTokens_.end()) {
            return iter->second;
        }
    }
    YIO_LOG_DEBUG("glagol. New jwt token");
    try {
        JwtToken newToken(token);
        YIO_LOG_DEBUG("glagol. correctDevice: " << newToken.isForDevice(deviceId_) << " isExpired: " << newToken.isExpired());
        if (newToken.isForDevice(deviceId_) && !newToken.isExpired()) {
            auto checkResult = backendApi_.checkToken2(token);
            if (checkResult.owner || (checkResult.guest && newToken.isGuest())) {
                auto [iter, inserted] = verifiedJwtTokens_.emplace(token, std::move(newToken));
                jwtTokensExpire_.insert(iter);
                YIO_LOG_DEBUG("glagol. Token is ok, saving token");
                if (iter->second.isGuest() && !guestMode_) { // check here to avoid requesting backend again
                    YIO_LOG_DEBUG("Guest-token in non guest mode is not accepted");
                    return {};
                }
                return iter->second;
            };
        }
    } catch (const glagol::BackendApi::Exception& exception) {
        YIO_LOG_WARN("GlagolWSServer: Can't check token: " << exception.what());
    }
    return {};
}

/** Thread-safe send of @p msg to one connection (takes mutex_, then sendNoLock()). */
void GlagolWsServer::send(WebsocketServer::ConnectionHdl hdl, const std::string& msg) {
    std::scoped_lock guard(mutex_);
    YIO_LOG_TRACE(msg);
    sendNoLock(hdl, msg);
}

/**
 * Sends @p msg to @p hdl through server_. The caller must hold mutex_.
 * websocketpp and std exceptions are logged, not propagated.
 */
void GlagolWsServer::sendNoLock(WebsocketServer::ConnectionHdl hdl, const std::string& msg) {
    if (!server_) {
        YIO_LOG_WARN("server is not initialized");
        return;
    }
    try {
        server_->send(hdl, msg);
    } catch (const websocketpp::exception& exception) {
        YIO_LOG_ERROR_EVENT("GlagolWsServer.FailedSend.WebsocketError", "glagol. cant send: " << exception.what());
    } catch (const std::exception& exception) {
        YIO_LOG_ERROR_EVENT("GlagolWsServer.FailedSend.UnknownError", "glagol. cant send: " << exception.what());
    }
}

void GlagolWsServer::sendAll(const std::string& msg, const SendFilterPredicate& pred) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (verifiedConnections_.empty()) {
        YIO_LOG_TRACE(msg);
        return;
    }

    std::vector<std::tuple<WebsocketServer::ConnectionHdl, proto::GlagoldState::ConnectionType>> hdls;
    hdls.reserve(verifiedConnections_.size());
    // send only to verified connections
    for (const auto& [connection, details] : verifiedConnections_) {
        hdls.emplace_back(connection, details.type);
    }

    if (pred) {
        lock.unlock();
        hdls.erase(std::remove_if(std::begin(hdls),
                                  std::end(hdls),
                                  [&pred](auto& connectionType) -> bool {
                                      auto& [connection, type] = connectionType;
                                      return !pred(connection, type);
                                  }),
                   std::end(hdls));
        lock.lock();
    }

    if (server_) {
        for (auto const& [connection, type] : hdls) {
            sendNoLock(connection, msg);
        }
    } else {
        YIO_LOG_WARN("server is not initialized");
    }
}

/** Type of the verified connection @p hdl, or OTHER if the handle is not verified. */
proto::GlagoldState::ConnectionType GlagolWsServer::connectionType(WebsocketServer::ConnectionHdl hdl) const {
    std::scoped_lock guard(mutex_);
    const auto found = verifiedConnections_.find(hdl);
    if (found == verifiedConnections_.end()) {
        return proto::GlagoldState::OTHER;
    }
    return found->second.type;
}

/**
 * Sets the type of connection @p hdl through applyConnectionDetails() and
 * publishes the connections state if the type actually changed.
 * NOTE(review): the callback uses notifyGlagoldStatusNoLock(), so
 * applyConnectionDetails() presumably runs it under mutex_; confirm in the header.
 */
void GlagolWsServer::setConnectionType(WebsocketServer::ConnectionHdl hdl, const GlagoldState::ConnectionType& connectionType) {
    applyConnectionDetails(hdl, [this, &connectionType](ConnectionDetails& details) {
        if (!details.changeType(connectionType)) {
            return;
        }
        notifyGlagoldStatusNoLock();
    });
}

/**
 * Records the type and device id of connection @p hdl (publishing the state if
 * the type changed). When a fixed external port is configured, it also
 * returns how to reach the peer:
 *  - address: its host with square brackets stripped;
 *  - protocol: ResolveInfo::protoByAddress(host);
 *  - port: the configured port;
 *  - cluster: true.
 * Returns nullopt without a configured port, or if applyConnectionDetails()
 * did not run the callback.
 */
std::optional<glagol::ResolveInfo> GlagolWsServer::setConnectionDetailsExt(WebsocketServer::ConnectionHdl hdl,
                                                                           const proto::GlagoldState::ConnectionType& connectionType,
                                                                           std::string deviceId) {
    std::optional<glagol::ResolveInfo> resolveInfo;

    applyConnectionDetails(hdl, [&, this](ConnectionDetails& details) {
        if (details.changeType(connectionType)) {
            notifyGlagoldStatusNoLock();
        }
        details.deviceId = std::move(deviceId);

        if (!configuredPort_) {
            return;
        }
        std::string address = details.host;
        trimSquareBrackets(address);
        resolveInfo = glagol::ResolveInfo{
            .address = std::move(address),
            .protocol = glagol::ResolveInfo::protoByAddress(details.host),
            .port = *configuredPort_,
            .cluster = true,
        };
    });
    return resolveInfo;
}

/**
 * Adds connection telemetry to @p result and returns it:
 *  - result["stats"]["websocket_is_active"]: whether a websocket server object exists;
 *  - result[<deviceId>]["pilot_host" | "cluster_host"]: the remote host of every
 *    verified connection whose device id is known. "pilot_host" is used for
 *    YANDEXIO_DEVICE connections, "cluster_host" for all other types.
 *
 * server_ is replaced by the server thread under mutex_, so it must also be
 * read under mutex_ to avoid a data race.
 */
Json::Value GlagolWsServer::connectedDevicesTelemetry(Json::Value result) const {
    auto telemetryConnectionName = [](proto::GlagoldState::ConnectionType type) {
        return type == proto::GlagoldState::YANDEXIO_DEVICE ? "pilot_host" : "cluster_host";
    };
    std::lock_guard<std::mutex> lock(mutex_);
    result["stats"]["websocket_is_active"] = bool(server_);
    // Iterate by reference: copying each entry would copy its strings.
    for (const auto& [hdl, details] : verifiedConnections_) {
        if (details.deviceId) {
            result[*details.deviceId][telemetryConnectionName(details.type)] = details.host;
        }
    }
    return result;
}

/****************************************************************
 *               JwtToken code
 ****************************************************************/

// Names of the JWT claims read by JwtToken(JwtPtr):
// target device id, device platform, expiration time, and guest flag.
const std::string GlagolWsServer::JwtToken::DEVICE_ID_GRANT{"sub"};
const std::string GlagolWsServer::JwtToken::DEVICE_PLATFORM_GRANT{"plt"};
const std::string GlagolWsServer::JwtToken::TOKEN_EXPIRATION_GRANT{"exp"};
const std::string GlagolWsServer::JwtToken::GUEST_GRANT{"gst"};

/**
 * Decodes @p token with decodeJWT() and extracts its claims via JwtToken(JwtPtr).
 * NOTE(review): failure behavior on malformed tokens depends on decodeJWT(); confirm.
 */
GlagolWsServer::JwtToken::JwtToken(const std::string& token)
    : JwtToken(decodeJWT(token))
{
}

/**
 * Extracts from a decoded JWT the claims this server needs:
 *  - the target device (id + platform);
 *  - the guest flag (presumably false when the claim is absent, going by
 *    getBoolOrFalseGrantFromJWT's name; TODO confirm);
 *  - the expiration time.
 */
GlagolWsServer::JwtToken::JwtToken(JwtPtr jwt)
    : deviceId_{
          getStringGrantFromJWT(jwt.get(), DEVICE_ID_GRANT),
          getStringGrantFromJWT(jwt.get(), DEVICE_PLATFORM_GRANT)}
    , guest_(getBoolOrFalseGrantFromJWT(jwt.get(), GUEST_GRANT))
    // The expiration claim is converted with system_clock::from_time_t, i.e. treated as a time_t.
    , expirationTimePoint_(system_clock::from_time_t(getLongGrantFromJWT(jwt.get(), TOKEN_EXPIRATION_GRANT)))
{
}

/** True if the token was issued for @p deviceId (device id and platform must match). */
bool GlagolWsServer::JwtToken::isForDevice(const glagol::DeviceId& deviceId) const {
    const bool sameDevice = deviceId_ == deviceId;
    return sameDevice;
}

/** True once the expiration time lies in the past. */
bool GlagolWsServer::JwtToken::isExpired() const {
    const auto now = system_clock::now();
    return now > expirationTimePoint_;
}

/** Value of the guest claim. */
bool GlagolWsServer::JwtToken::isGuest() const {
    return guest_;
}

/** Expiration time taken from the token's expiration claim. */
std::chrono::system_clock::time_point GlagolWsServer::JwtToken::getExpiration() const {
    return expirationTimePoint_;
}
