#include "connector2.h"

#include "conversation_model.h"

#include <yandex_io/libs/base/utils.h>

YIO_DEFINE_LOG_MODULE("glagol_sdk");

using namespace std::chrono;
using namespace glagol;
using namespace glagol::ext;
using namespace glagol::model;

namespace {
    // Metrica context key under which updateMetricaContext() records the L3 protocol
    // ("ipv4" or "ipv6") selected for the websocket connection.
    const std::string WS_PEER_L3_PROTOCOL{"glagolConnectWsPeerL3Protocol"};
} // namespace

Connector::Connector(std::shared_ptr<IBackendApi> backendApi,
                     std::shared_ptr<YandexIO::ITelemetry> telemetry,
                     std::string origin)
    : metricaOrigin_(std::move(origin))
    , telemetry_(telemetry)
    , wsClient_(std::move(telemetry))
    , callbackQueue_(std::make_shared<quasar::NamedCallbackQueue>("GlagolConnector"))
    , backendApi_(std::move(backendApi))
{
    Y_VERIFY(telemetry_.get() != nullptr);
    stopped_ = false;

    // Websocket defaults: the certificate hostname is not verified, the client never
    // reconnects on its own (onWsDisconnected() schedules retries) and keep-alive pings are on.
    wsSettings_.tls.verifyHostname = false;
    wsSettings_.reconnect.enabled = false;
    wsSettings_.ping.enabled = true;

    // Each websocket event is forwarded to a member that re-posts the work onto callbackQueue_.
    wsClient_.setOnConnectHandler([this](auto&&... /*unused*/) { onWsConnected(); });
    wsClient_.setOnDisconnectHandler([this](auto info) { onWsDisconnected(std::move(info)); });
    // A failure is treated exactly like a disconnect, so it also leads to a retry.
    wsClient_.setOnFailHandler([this](auto info) { onWsDisconnected(std::move(info)); });
    wsClient_.setOnPongHandler([this](auto&&... /*unused*/) { onWsPong(); });
    wsClient_.setOnMessageHandler([this](const auto& message) { handleMessage(message); });
}

/*
 * Tears the connector down in a fixed order: stops discovery on the current destination
 * (if connect() has set one), kills lifetime_, marks the connector stopped and finally
 * disconnects the websocket synchronously with a 1000 ms timeout.
 *
 * NOTE(review): lifetime_.die() presumably prevents callbackQueue_ tasks that were added with
 * lifetime_ from running afterwards — confirm against the Lifetime implementation.
 * NOTE(review): destination_ is read here on the destroying thread, while connect() assigns it
 * inside a callbackQueue_ task; this assumes no connect() task is running concurrently.
 */
Connector::~Connector() {
    if (destination_) {
        destination_->stopDiscovery();
    }
    lifetime_.die();
    stopped_ = true;
    /* Disconnect in destructor, so wsClient_ won't call any callbacks that captured 'this' */
    wsClient_.disconnectSyncWithTimeout(1000 /*ms*/);
}

/*
 * Starts connecting to `device` using `settings`.
 *
 * All work happens on callbackQueue_:
 *  - apply the backend retry-delay settings;
 *  - remember the destination and switch to DISCOVERING;
 *  - store the settings and rebuild the metrica context;
 *  - configure the ping interval;
 *  - start discovery, forwarding every discovery result to onDiscovery().
 */
void Connector::connect(DiscoveredItemPtr device, Settings settings) {
    callbackQueue_->add([this, device{std::move(device)}, settings{std::move(settings)}]() mutable {
        // retryDelayCounter_ is otherwise only used from callbackQueue_ tasks (getToken,
        // onWsDisconnected, handleMessage); configuring it here as well avoids mutating it
        // concurrently from the caller's thread.
        retryDelayCounter_.setSettings(settings.backend.retryDelay);
        destination_ = std::move(device);
        state_.set(State::DISCOVERING);
        settings_ = std::move(settings);
        metricaContextJson_ = getConnectMetricaContext();
        metricaContext_ = quasar::jsonToString(metricaContextJson_);
        wsSettings_.ping.interval = settings_.websocket.pingTimeout;
        destination_->discover([this](auto connInfo) {
            onDiscovery(std::move(connInfo));
        });
    }, lifetime_);
}

/*
 * Asynchronously shuts the connection down.
 *
 * On callbackQueue_, this switches to DISCONNECTING, sets stopped_ and asks wsClient_ to
 * disconnect. When wsClient_ reports completion, the final transition to STOPPED is posted
 * back onto callbackQueue_.
 */
void Connector::disconnectAsync() {
    auto markStopped = [this]() {
        state_.set(State::STOPPED);
    };
    callbackQueue_->add([this, markStopped]() {
        state_.set(State::DISCONNECTING);
        stopped_ = true;
        wsClient_.disconnectAsync([this, markStopped]() { // TODO remove onDone
            callbackQueue_->add(markStopped, lifetime_);
        });
    }, lifetime_);
}

/*
 * Handling events
 */

/*
 * Handles a discovery result; the work is re-posted onto callbackQueue_.
 *
 * Results are ignored while DISCONNECTING or STOPPED, and also when neither the endpoint URL
 * nor the TLS certificate changed. Otherwise the websocket settings and metrica context are
 * updated. Then, if a conversation token is already known, the connector goes straight to
 * CONNECTING; if not, it first requests a token (REQUESTING_BACKEND).
 */
void Connector::onDiscovery(DiscoveredItem::ConnectionData result) {
    callbackQueue_->add([this, connection{std::move(result)}]() {
        YIO_LOG_DEBUG("GSDK. onDiscoveryResultChanged");
        const bool shuttingDown = state_ == State::DISCONNECTING || state_ == State::STOPPED;
        if (shuttingDown) {
            YIO_LOG_DEBUG("GSDK. State STOPPED. Ignoring DiscoveryResult");
            return;
        }
        YIO_LOG_DEBUG("GSDK. device is visible");

        const auto url = "wss://" + connection.uri;
        const bool endpointChanged = url != wsSettings_.url || wsSettings_.tls.crtBuffer != connection.tlsCertificate;
        if (!endpointChanged) {
            return; // same URL and certificate: nothing to do
        }

        if (connection.protocol == DiscoveredItem::Protocol::IPV4) {
            wsSettings_.protocol = quasar::WebsocketClient::Settings::Protocol::ipv4;
        } else {
            wsSettings_.protocol = quasar::WebsocketClient::Settings::Protocol::ipv6;
        }
        wsSettings_.url = url;
        wsSettings_.tls.crtBuffer = connection.tlsCertificate;
        updateMetricaContext();

        if (!jwtToken_.empty()) {
            state_.set(State::CONNECTING);
            wsClient_.connectAsync(wsSettings_);
        } else {
            state_.set(State::REQUESTING_BACKEND);
            getToken();
        }
    }, lifetime_);
}
/**
 * Requests a conversation token for the current destination from the backend.
 * Must be called inside callbackQueue_ thread.
 *
 * Success: the token is stored in jwtToken_, the state becomes CONNECTING and the websocket
 * connection is started.
 * Failure (IBackendApi exceptions): telemetry is reported and another attempt is scheduled
 * after retryDelayCounter_'s current delay, which is then increased.
 */
void Connector::getToken() {
    bool success = false;
    try {
        jwtToken_ = backendApi_->getToken(destination_->getDeviceId());
        success = true;
    } catch (const IBackendApi::Non200ResponseCodeException& exception) {
        YIO_LOG_ERROR_EVENT("GSDKConnector2.GetToken.Non200ResponseCode", "GSDK. getToken" << exception.what() << " code: " << exception.getResponseCode());
        telemetry_->reportEvent("gsdkConnectBackendConversationTokenFailure", metricaContext_);
    } catch (const IBackendApi::Exception& exception) {
        telemetry_->reportError("gsdkConnectBackendConversationTokenError");
        YIO_LOG_ERROR_EVENT("GSDKConnector2.GetToken.Exception", "GSDK. getToken" << exception.what());
    }
    if (success) {
        YIO_LOG_DEBUG("GSDK. Successfully got token");
        /*
         * We reset retryDelayCounter_ only after successful websocket connection and first message
         */
        state_.set(State::CONNECTING);
        wsClient_.connectAsync(wsSettings_);
    } else {
        YIO_LOG_INFO("GSDK. Scheduled retry for backend request in " << milliseconds(retryDelayCounter_.get()).count());
        callbackQueue_->addDelayed([this]() {
            // Retry only while we are still waiting for a token. Stopping, or another path
            // having already moved on (CONNECTING/CONNECTED), means a retry would only issue a
            // redundant backend request followed by a second connectAsync().
            if (state_ == State::REQUESTING_BACKEND) {
                getToken();
            }
        }, retryDelayCounter_.get(), lifetime_);
        retryDelayCounter_.increase();
    }
}

/*
 * Websocket "open" event. On callbackQueue_ it always reports gsdkConnectWsOpen. It moves to
 * CONNECTED only when a connection was actually being established; in any other state it
 * just logs a warning.
 */
void Connector::onWsConnected() {
    callbackQueue_->add([this]() {
        telemetry_->reportEvent("gsdkConnectWsOpen", metricaContext_);
        if (state_ != State::CONNECTING) {
            YIO_LOG_WARN("GSDK. websocket connected with state " << std::to_string(state_.getState()) << " wtf?!");
            return;
        }
        state_.set(State::CONNECTED);
    }, lifetime_);
}

/*
 * Builds the base metrica context for this connection:
 *  - the SDK flavour;
 *  - the origin, when one was given;
 *  - the destination's device id and platform.
 */
Json::Value Connector::getConnectMetricaContext() const {
    const auto& device = destination_->getDeviceId();
    Json::Value json(Json::objectValue);
    json["glagolsdk"] = "glagolsdk-cpp";
    if (!metricaOrigin_.empty()) {
        json["origin"] = metricaOrigin_;
    }
    json["deviceId"] = device.id;
    json["platform"] = device.platform;
    return json;
}

/*
 * Records the L3 protocol chosen in wsSettings_ ("ipv4" or "ipv6") in the metrica context
 * and refreshes its serialized copy, metricaContext_.
 */
void Connector::updateMetricaContext() {
    const bool isIpv4 = wsSettings_.protocol == quasar::WebsocketClient::Settings::Protocol::ipv4;
    const char* protocolName = isIpv4 ? "ipv4" : "ipv6";
    metricaContextJson_[WS_PEER_L3_PROTOCOL] = protocolName;
    metricaContext_ = quasar::jsonToString(metricaContextJson_);
}

/*
 * Handles a websocket disconnect or failure; both wsClient_ handlers are routed here.
 *
 * It acts only while CONNECTING or CONNECTED:
 *  - INVALID_TOKEN close code: the token is invalidated in backendApi_ and a new one is
 *    requested after the current retry delay (REQUESTING_BACKEND);
 *  - TLS handshake failure: the destination's certificate is marked invalid and discovery is
 *    restarted (REQUESTING_BACKEND);
 *  - anything else: a reconnect is scheduled after settings_.websocket.retryDelay (CONNECTING).
 *
 * Every follow-up task is bound to lifetime_, so it cannot run against a destroyed connector.
 * Delayed tasks also re-check the state when they fire, so a disconnectAsync() issued in the
 * meantime is not undone by a late reconnect or token request.
 */
void Connector::onWsDisconnected(quasar::Websocket::ConnectionInfo connectionInfo) {
    callbackQueue_->add([this, connectionInfo{std::move(connectionInfo)}]() {
        if (state_ == State::CONNECTED || state_ == State::CONNECTING) {
            Json::Value context = metricaContextJson_;
            if (connectionInfo.remote.closeCode == quasar::Websocket::StatusCode::INVALID_TOKEN ||
                connectionInfo.error == quasar::Websocket::Error::TLS_HANDSHAKE_FAILED) {
                YIO_LOG_DEBUG("GSDK. Invalid token or cert! Getting new one");
                state_.set(State::REQUESTING_BACKEND); // FIXME: for tls maybe better discovery state?
                if (connectionInfo.error == quasar::Websocket::Error::TLS_HANDSHAKE_FAILED) {
                    context["error"] = "TLS_HANDSHAKE_FAILED";
                    callbackQueue_->add([this]() {
                        destination_->invalidCert();
                        destination_->discover([this](auto connInfo) {
                            onDiscovery(std::move(connInfo));
                        });
                    }, lifetime_);
                } else {
                    backendApi_->invalidToken(destination_->getDeviceId());
                    context["error"] = "INVALID_TOKEN";
                    callbackQueue_->addDelayed([this]() {
                        // Skip if the connector was stopped or has moved on in the meantime.
                        if (state_ == State::REQUESTING_BACKEND) {
                            getToken();
                        }
                    }, retryDelayCounter_.get(), lifetime_);
                    retryDelayCounter_.increase();
                }
                telemetry_->reportEvent("gsdkConnectBackendConversationTokenRetry", quasar::jsonToString(context));
            } else { // TODO not CONNECTING?
                const bool wasConnected = state_ == State::CONNECTED;
                state_.set(State::CONNECTING);
                if (wasConnected) {
                    context["wsCloseCode"] = int(connectionInfo.remote.closeCode);
                    telemetry_->reportEvent("gsdkConnectWsClose",
                                            quasar::jsonToString(context));
                }
                callbackQueue_->addDelayed(
                    [this, wasConnected]() {
                        // A disconnectAsync(), a token/cert refresh or another successful
                        // connection may have happened while we waited; only reconnect if we
                        // are still the pending connection attempt.
                        if (state_ != State::CONNECTING) {
                            return;
                        }
                        if (wasConnected) {
                            telemetry_->reportEvent(
                                "gsdkConnectWsReconnect", metricaContext_);
                        }
                        wsClient_.connectAsync(wsSettings_);
                    }, settings_.websocket.retryDelay, lifetime_);
            }
        }
    }, lifetime_);
}

/*
 * Websocket pong event: invokes the user's pong callback (if any) on callbackQueue_.
 */
void Connector::onWsPong() {
    callbackQueue_->add([this]() {
        if (!onPong_) {
            return;
        }
        onPong_();
    }, lifetime_);
}

/*
 * Handling events done
 */

void Connector::setOnStateChangedCallback(Connector::OnStateChangedCallback onStateChanged, State knownState) {
    // The wrapper installs the callback and, when the caller's knownState differs from the
    // current state, immediately reports the current state to it.
    auto& stateWrapper = state_;
    stateWrapper.setOnChange(std::move(onStateChanged), knownState);
}

void Connector::setOnMessageCallback(std::function<void(const model::IncomingMessage&)> onMessage) {
    // Receives every parsed incoming message; handleMessage() invokes it on callbackQueue_.
    onMessage_.swap(onMessage);
}

void Connector::handleMessage(const std::string& message) {
    callbackQueue_->add([this, message]() {
        try {
            YIO_LOG_DEBUG("GSDK. Received message: " << message);
            /*
             * This means that both cert and conversationToken are correct
             */
            retryDelayCounter_.reset();
            model::IncomingMessage response(message);
            if (response.request) {
                YIO_LOG_DEBUG("GSDK. Response delay of " << response.id << ": " << duration_cast<milliseconds>(system_clock::now() - response.request->sentTime).count()
                                                         << "ms");
                if (pendingResponses_.count(response.request->id)) {
                    auto responseCallback = pendingResponses_[response.request->id];
                    YIO_LOG_DEBUG("GSDK. calling custom onResponse for " << response.request->id << " with delay " << duration_cast<milliseconds>(system_clock::now() - response.request->sentTime).count() << "ms");
                    responseCallback(response);
                    pendingResponses_.erase(response.request->id);
                }
            }

            if (onMessage_) {
                if (response.request) {
                    YIO_LOG_DEBUG("GSDK. calling onMessage for " << response.request->id << " with delay "
                                                                 << duration_cast<milliseconds>(system_clock::now() -
                                                                                                response.request->sentTime)
                                                                        .count()
                                                                 << "ms");
                }
                onMessage_(response);
            }

        } catch (const std::exception& e) {
            telemetry_->reportEvent("gsdkConnectWsError", metricaContext_);
            throw;
        }
    }, lifetime_);
}

/*
 * Wraps `value` into the outgoing envelope:
 *  - "payload": the value itself;
 *  - "sentTime": wall-clock time the envelope was built, in milliseconds since the epoch;
 *  - "id": a new id from quasar::makeUUID();
 *  - "conversationToken": the current conversation token.
 *
 * NOTE(review): jwtToken_ is written by getToken() on callbackQueue_, but send() calls this on
 * the caller's thread; if send() can run concurrently with getToken(), this read is a data race.
 */
Json::Value Connector::makeMessage(const Json::Value& value) {
    const auto sentTimeMs = static_cast<int64_t>(
        duration_cast<milliseconds>(system_clock::now().time_since_epoch()).count());
    Json::Value message;
    message["payload"] = value;
    message["sentTime"] = sentTimeMs;
    message["id"] = quasar::makeUUID();
    message["conversationToken"] = jwtToken_;
    return message;
}

void Connector::setOnPongCallback(std::function<void()> onPong) {
    // Invoked by onWsPong() on callbackQueue_ whenever the websocket receives a pong.
    // NOTE(review): assigned on the caller's thread but read on callbackQueue_ — set it before connect().
    onPong_.swap(onPong);
}

/*
 * Fire-and-forget send: wraps `payload` into an envelope and queues it for the websocket.
 * Returns the generated message id.
 */
std::string Connector::send(const Json::Value& payload) {
    Json::Value envelope = makeMessage(payload);
    std::string messageId = envelope["id"].asString();
    std::string serialized = quasar::jsonToString(envelope);
    callbackQueue_->add([this, serialized{std::move(serialized)}]() {
        wsClient_.unsafeSend(serialized);
    }, lifetime_);
    return messageId;
}

/*
 * Sends `payload` and registers `onResponse` as a one-shot callback. handleMessage() invokes
 * it when a message answering this request id arrives.
 *
 * Returns the generated message id. The id must stay valid after being handed to the queued
 * task, so the task captures a copy of it; moving it would return an empty id.
 */
std::string Connector::send(const Json::Value& payload, std::function<void(const model::IncomingMessage&)> onResponse) {
    Json::Value jMessage = makeMessage(payload);
    std::string id = jMessage["id"].asString();
    std::string message = quasar::jsonToString(jMessage);
    callbackQueue_->add([this, onResponse{std::move(onResponse)}, id, message{std::move(message)}]() mutable {
        // An empty callback would only throw std::bad_function_call once the reply arrives.
        if (onResponse) {
            pendingResponses_[id] = std::move(onResponse);
        }
        wsClient_.unsafeSend(message);
    }, lifetime_);
    return id;
}

/*
 * Sends `payload` and blocks for up to `timeOut` waiting for the matching response.
 *
 * On timeout, it returns a message whose request info has responseStatus TIMEOUT. The
 * response callback holds only a weak reference to the promise, so once this function
 * returns, a late reply is dropped.
 */
model::IncomingMessage Connector::sendSync(const Json::Value& payload,
                                           milliseconds timeOut) {
    using ResponsePromise = std::promise<model::IncomingMessage>;
    auto promise = std::make_shared<ResponsePromise>();
    auto future = promise->get_future();
    send(payload, [weakPromise = std::weak_ptr<ResponsePromise>(promise)](const model::IncomingMessage& message) {
        if (auto alive = weakPromise.lock()) {
            alive->set_value(message);
        }
    });

    if (future.wait_for(timeOut) == std::future_status::ready) {
        return future.get();
    }

    model::IncomingMessage timedOut;
    timedOut.request = RequestInfo();
    timedOut.request->responseStatus = model::ResponseStatus::TIMEOUT;
    return timedOut;
}

/*
 * Returns the settings stored by the most recent connect() task.
 * Must be called inside main thread (in any callback).
 */
const Connector::Settings& Connector::getSettings() const {
    const Settings& current = settings_;
    return current;
}

/*
 * Returns the connector's current state.
 * Must be called inside main thread (in any callback).
 */
Connector::State Connector::getState() const {
    const State current = state_.getState();
    return current;
}

// Testing purposes only: blocks up to `timeOut` until the connector reaches `state`.
// Returns whether the state was reached.
bool Connector::waitForState(Connector::State state, std::chrono::milliseconds timeOut) const {
    const bool reached = state_.waitFor(state, timeOut);
    return reached;
}

/*
 * Changes the state. If it actually changed, this wakes a waitFor() caller and invokes
 * onChanged_.
 *
 * The callback is copied under mutex_ and invoked only after the lock is released.
 * NOTE(review): an earlier comment described onChanged_ as a Connector-internal forwarder onto
 * callbackQueue_, but setOnStateChangedCallback() stores the caller's callback directly —
 * confirm which callbacks are installed and on which thread they expect to run.
 */
void Connector::StateWrapper::set(Connector::State state) {
    if (state == state_) {
        return;
    }
    YIO_LOG_DEBUG("GSDK. Connector::State changed to " << std::to_string(state));
    state_ = state;

    decltype(onChanged_) callback;
    {
        std::scoped_lock<std::mutex> lock(mutex_);
        wakeUpVar_.notify_one();
        callback = onChanged_;
    }
    if (callback) {
        callback(state);
    }
}

/*
 * Installs `cb` as the state-change callback. If the caller's `knownState` differs from the
 * current state, `cb` is invoked right away with the current state.
 *
 * The current state is read under mutex_ after `cb` is installed. set() stores the new state
 * before taking mutex_ to read onChanged_, so a concurrent set() is handled in one of two ways:
 * either its new state is already visible here, or it will see `cb` and notify it itself.
 * Reading the state before installing `cb` could lose such an update entirely.
 */
void Connector::StateWrapper::setOnChange(Connector::OnStateChangedCallback cb, State knownState) {
    State curState{};
    {
        std::scoped_lock<std::mutex> lock(mutex_);
        onChanged_ = cb;
        curState = state_.load();
    }
    if (cb && knownState != curState) {
        cb(curState);
    }
}

/*
 * Blocks up to `timeOut` until the state equals `state`; returns whether it did.
 *
 * This waits on mutex_, the same mutex set() holds while notifying wakeUpVar_. The previous
 * function-local mutex let a notification slip in between the predicate check and the wait,
 * which cost the full timeout. It was also not a valid way for concurrent waiters to share
 * one condition variable.
 * const_cast: locking does not change the wrapper's logical state. The cast is valid whether or
 * not mutex_ is declared mutable in the header.
 */
bool Connector::StateWrapper::waitFor(Connector::State state, std::chrono::milliseconds timeOut) const {
    std::unique_lock<std::mutex> lock(const_cast<std::mutex&>(mutex_));
    return wakeUpVar_.wait_for(lock, timeOut, [this, state]() { return state_ == state; });
}

bool Connector::StateWrapper::operator==(Connector::State state) const {
    // Single atomic read of the current state, then compare.
    const Connector::State current = state_;
    return current == state;
}

bool Connector::StateWrapper::operator!=(Connector::State state) const {
    // Defined in terms of operator== so the two can never disagree.
    return !(*this == state);
}

Connector::State Connector::StateWrapper::getState() const {
    // Atomic snapshot of the current state.
    return state_.load();
}

/*
 * Human-readable name of a Connector::State, used in logs.
 * Throws std::out_of_range (via std::array::at) for values outside the table.
 * NOTE(review): the table assumes the enumerators have values 0..5 in exactly this order;
 * the static_assert only checks the count.
 * NOTE(review): adding overloads to namespace std is formally undefined behaviour — consider a
 * glagol-namespace toString() in a future cleanup.
 */
std::string std::to_string(glagol::ext::Connector::State state) {
    using ConnectorState = glagol::ext::Connector::State;
    constexpr std::size_t kStateCount = 6;
    static_assert(static_cast<std::size_t>(ConnectorState::LAST_ENUM) == kStateCount, "glagol::Connector::State has changed!");
    static constexpr std::array<const char*, kStateCount> kStateNames{{
        "STOPPED",
        "DISCOVERING",
        "REQUESTING_BACKEND",
        "CONNECTING",
        "CONNECTED",
        "DISCONNECTING",
    }};
    return kStateNames.at(static_cast<std::size_t>(state));
}
