#include "command_spotter_capability.h"
#include "spotter_phrases.h"

#include <yandex_io/libs/base/directives.h>
#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/ete_metrics/ete_util.h>
#include <yandex_io/libs/json_utils/json_utils.h>
#include <yandex_io/libs/logging/logging.h>
#include <yandex_io/protos/quasar_proto.pb.h>

#include <json/json.h>

YIO_DEFINE_LOG_MODULE("command_spotter_capability");

using namespace quasar;
using namespace YandexIO;

namespace {

    // Name under which this capability registers itself in the remoting registry
    // (see init()/destructor) and which it stamps on outgoing listener messages.
    constexpr const char* REMOTE_OBJECT_NAME = "CommandSpotterCapability";

} // namespace

// Builds the capability from its collaborators. The weak pointers (directive processor,
// Alice/playback capabilities, remoting registry) are locked on demand at each use site in
// this file, so those objects may already be gone when they are needed.
// Registration with the remoting registry is NOT done here; it happens in init().
// NOTE(review): aliceConfig_ and deviceState_ are initialized from references; if the header
// declares them as reference members, the caller must keep both alive — confirm.
CommandSpotterCapability::CommandSpotterCapability(
    std::shared_ptr<quasar::ICallbackQueue> worker,
    const AliceConfig& aliceConfig,
    IAliceDeviceState& deviceState,
    std::weak_ptr<YandexIO::IDirectiveProcessor> directiveProcessor,
    std::shared_ptr<YandexIO::ITelemetry> telemetry,
    std::weak_ptr<YandexIO::IAliceCapability> aliceCapability,
    std::weak_ptr<YandexIO::IPlaybackControlCapability> playbackCapability,
    std::shared_ptr<QuasarVoiceDialog> voiceDialog,
    std::weak_ptr<YandexIO::IRemotingRegistry> remotingRegistry)
    : YandexIO::IRemoteObject(std::move(remotingRegistry))
    , worker_(std::move(worker))
    , directiveProcessor_(std::move(directiveProcessor))
    , telemetry_(std::move(telemetry))
    , aliceConfig_(aliceConfig)
    , deviceState_(deviceState)
    , aliceCapability_(std::move(aliceCapability))
    , playbackCapability_(std::move(playbackCapability))
    , voiceDialog_(std::move(voiceDialog))
{
    // Seed the disabled phrase/spotter sets from the current config;
    // onConfigUpdated() refreshes them later.
    initDisabledPhrases();
    initDisabledSpotters();
}

CommandSpotterCapability::~CommandSpotterCapability() {
    // Undo the registration performed in init(), if the registry still exists.
    const auto registry = getRemotingRegistry().lock();
    if (registry) {
        registry->removeRemoteObject(REMOTE_OBJECT_NAME);
    }
}

void CommandSpotterCapability::init() {
    if (auto remotingRegistry = getRemotingRegistry().lock()) {
        remotingRegistry->addRemoteObject(REMOTE_OBJECT_NAME, weak_from_this());
    }
}

void CommandSpotterCapability::onAlarmStarted() {
    // A ringing alarm changes the desired spotter type (see defineSpotterType()).
    onStateUpdated();
}

void CommandSpotterCapability::onAlarmStopped() {
    // The alarm is gone, so re-evaluate which spotter (if any) should run.
    onStateUpdated();
}

void CommandSpotterCapability::onAlarmEnqueued(const quasar::proto::Alarm& /* alarm */) {
    // Intentionally empty: a merely scheduled alarm does not affect the command spotter.
}

void CommandSpotterCapability::onCapabilityStateChanged(const std::shared_ptr<YandexIO::ICapability>& /*capability*/, const NAlice::TCapabilityHolder& state) {
    // Only device-state updates are relevant; by the time this fires deviceState_
    // is expected to already reflect the new state.
    if (state.HasDeviceStateCapability()) {
        onStateUpdated();
    }
}

void CommandSpotterCapability::onCapabilityEvents(const std::shared_ptr<YandexIO::ICapability>& /* capability */, const std::vector<NAlice::TCapabilityEvent>& /* events */) {
    // Intentionally empty: capability events are not used by the command spotter.
}

void CommandSpotterCapability::onCommandSpotterBegin() {
    // Purely informational: record which spotter type has just begun listening.
    YIO_LOG_INFO("onCommandSpotterBegin:" << spotterType_);
}

void CommandSpotterCapability::onCommandPhraseSpotted(const std::string& phrase) {
    // Phrases spotted while command spotters are disabled in config are dropped silently.
    if (!aliceConfig_.getCommandSpottersEnabled()) {
        return;
    }

    YIO_LOG_INFO("onCommandPhraseSpotted " << phrase << ", state=" << spotterType_);
    handlePhrase(phrase);
}

void CommandSpotterCapability::onCommandSpotterError(const SpeechKit::Error& error)
{
    // Only model errors are handled here; any other SpeechKit error is ignored.
    if (!error.isSpotterModelError()) {
        return;
    }

    YIO_LOG_ERROR_EVENT("CommandSpotterCapability.CommandSpotterError", "Speechkit failed to start due to model error. Will fallback to default spotter model " << spotterType_);

    // Forget the model that failed so updateSpotterState() no longer considers it usable.
    spottersModels_.erase(spotterType_);

    // Notify remote listeners BEFORE re-evaluating the spotter. remoteNotifyModelError()
    // reports spotterType_, and onModelsUpdated() may call stop(), which clears
    // spotterType_. Notifying afterwards would report an empty (or different) type.
    remoteNotifyModelError();

    onModelsUpdated();
}

SpeechKit::PhraseSpotterSettings CommandSpotterCapability::getSpotterSettings() const {
    // No model registered for the current type: return settings built from an empty string.
    const auto modelEntry = spottersModels_.find(spotterType_);
    if (modelEntry == spottersModels_.end()) {
        return SpeechKit::PhraseSpotterSettings{""};
    }

    // Settings for the registered model, tagged with the spotter type as context.
    SpeechKit::PhraseSpotterSettings settings{modelEntry->second};
    settings.context = spotterType_;
    settings.resetLogsAfterTrigger = true;
    aliceConfig_.setSpotterLoggingSettings(settings);

    return settings;
}

bool CommandSpotterCapability::isSpotterEnabled(const std::string& spotterType) const {
    // A spotter type is enabled unless config lists it as disabled.
    return disabledSpotters_.count(spotterType) == 0;
}

bool CommandSpotterCapability::isPhraseEnabled(const std::string& phrase) const {
    // A phrase is enabled unless config lists it as disabled.
    return disabledPhrases_.count(phrase) == 0;
}

void CommandSpotterCapability::handlePhrase(const std::string& phrase) {
    if (!isPhraseEnabled(phrase)) {
        YIO_LOG_DEBUG("Phrase " << phrase << " is disabled.");
        return;
    }

    Json::Value event;
    event["spotted_phrase"] = phrase;
    telemetry_->reportEvent("customSpotter", jsonToString(event));
    logSpotterCommand(phrase, spotterType_, getSpotterSettings().getModel());

    if (SpotterPhrases::isTurnOn(phrase)) {
        if (auto aliceCapability = aliceCapability_.lock()) {
            aliceCapability->startVoiceInput(VinsRequest::createSoftwareSpotterEventSource());
        }
    } else if (SpotterPhrases::isContinuePlaying(phrase)) {
        if (auto playbackCapability = playbackCapability_.lock()) {
            playbackCapability->play();
        }
    } else if (SpotterPhrases::isStop(phrase)) {
        if (deviceState_.hasPlayingAlarm()) {
            pushCommandDirective(Directives::ALARM_STOP);
        } else if (aliceState_ == proto::AliceState::SPEAKING) {
            if (auto aliceCapability = aliceCapability_.lock()) {
                aliceCapability->stopConversation();
            }
        } else {
            if (auto playbackCapability = playbackCapability_.lock()) {
                playbackCapability->pause();
            }
        }
        // handle isGoForward first because it handles "дальше". isNextTrack is also "дальше"
    } else if (SpotterPhrases::isGoForward(phrase) && spotterType_ == SpotterTypes::NAVIGATION) {
        pushCommandDirective(Directives::GO_FORWARD);
    } else if (SpotterPhrases::isNextTrack(phrase)) {
        if (auto playbackCapability = playbackCapability_.lock()) {
            playbackCapability->next();
        }
    } else if (SpotterPhrases::isGoBackward(phrase)) {
        pushCommandDirective(Directives::GO_BACKWARD);
    } else if (SpotterPhrases::isGoUp(phrase)) {
        pushCommandDirective(Directives::GO_UP);
    } else if (SpotterPhrases::isGoDown(phrase)) {
        pushCommandDirective(Directives::GO_DOWN);
    } else if (SpotterPhrases::isMakeLouder(phrase)) {
        pushCommandDirective(Directives::SOUND_LOUDER);
    } else if (SpotterPhrases::isMakeQuieter(phrase)) {
        pushCommandDirective(Directives::SOUND_QUITER);
    }
}

void CommandSpotterCapability::pushCommandDirective(const std::string& directive) {
    if (const auto directiveProcessor = directiveProcessor_.lock()) {
        Directive::Data data(directive, "local_action");
        data.requestId = makeUUID();

        directiveProcessor->addDirectives({std::make_shared<Directive>(data)});
    }
}

void CommandSpotterCapability::onConfigUpdated() {
    // Refresh both disable-lists from the new config, then re-evaluate the spotter.
    initDisabledSpotters();
    initDisabledPhrases();

    onStateUpdated();
}

void CommandSpotterCapability::onStateUpdated() {
    // Models are unchanged, so the spotter is restarted only if its type must change.
    updateSpotterState(/* forceRecreate = */ false);
}

void CommandSpotterCapability::onModelsUpdated() {
    // Model paths changed: restart even for the same spotter type so the new model is used.
    updateSpotterState(/* forceRecreate = */ true);
}

void CommandSpotterCapability::startCommandSpotter(const std::string& newSpotterType, bool forceRecreate) {
    // Skip the restart when nothing would change, unless the caller insists.
    const bool sameType = (spotterType_ == newSpotterType);
    if (sameType && !forceRecreate) {
        return;
    }

    YIO_LOG_INFO("startCommandSpotter, spotterType=" << spotterType_ << ", newSpotterType=" << newSpotterType);

    // Settings are derived from spotterType_, so it must be updated first.
    spotterType_ = newSpotterType;
    voiceDialog_->startCommandSpotter(getSpotterSettings());
}

std::string CommandSpotterCapability::defineSpotterType() const {
    // While an alarm is ringing or Alice is speaking, only the "stop" spotter is wanted.
    if (deviceState_.hasPlayingAlarm() || aliceState_ == proto::AliceState::SPEAKING) {
        return SpotterTypes::STOP;
    }

    const auto& videoState = deviceState_.getVideoState();

    // A visible screen decides the spotter type on its own.
    if (videoState.HasCurrentScreen()) {
        if (deviceState_.isLongListenerScreen()) {
            return SpotterTypes::NAVIGATION;
        }

        const auto& screenName = videoState.GetCurrentScreen();
        if (screenName == "video_player") {
            return SpotterTypes::VIDEO;
        }
        if (screenName == "music_player" || screenName == "radio_player") {
            return SpotterTypes::MUSIC;
        }
        return SpotterTypes::GENERAL;
    }

    // No screen: fall back to what is currently playing.
    const auto& player = videoState.GetPlayer();
    if (player.HasPause() && !player.GetPause()) {
        return SpotterTypes::VIDEO;
    }
    if (deviceState_.isMediaPlaying()) {
        return SpotterTypes::MUSIC;
    }
    return SpotterTypes::GENERAL;
}

void CommandSpotterCapability::onAliceStateChanged(proto::AliceState state) {
    // Ignore duplicate notifications; only a real transition re-evaluates the spotter.
    const auto reportedState = state.state();
    if (aliceState_ == reportedState) {
        return;
    }

    aliceState_ = reportedState;
    onStateUpdated();
}

void CommandSpotterCapability::onAliceTtsCompleted() {
    // Intentionally empty: TTS completion is tracked via onAliceStateChanged() instead.
}

void CommandSpotterCapability::handleRemotingMessage(
    const quasar::proto::Remoting& message,
    std::shared_ptr<YandexIO::IRemotingConnection> connection)
{
    // Only SpotterCapabilityMethod::SET_MODEL_PATHS is handled; other messages are ignored.
    if (!message.has_spotter_capability_method()) {
        return;
    }

    const auto& method = message.spotter_capability_method();

    if (method.method() != quasar::proto::Remoting::SpotterCapabilityMethod::SET_MODEL_PATHS) {
        return;
    }

    // The task runs later on worker_, possibly after this object has been destroyed.
    // Keep only a weak reference: if it has expired, drop the task instead of touching
    // a dangling `this`. While `self` is held, the object is kept alive.
    worker_->add([this, weakSelf = weak_from_this(), method, connection{std::move(connection)}]() mutable {
        const auto self = weakSelf.lock();
        if (!self) {
            return;
        }

        // Remember the peer so model set/error notifications can be sent back to it.
        connection_ = std::move(connection);

        // Replace the whole model table with the paths from the message.
        std::set<std::string> types;
        spottersModels_.clear();
        for (const auto& modelPath : method.model_path()) {
            spottersModels_[modelPath.type()] = modelPath.path();
            types.insert(modelPath.type());
        }

        onModelsUpdated();

        remoteNotifyModelSet(types);
    });
}

void CommandSpotterCapability::logSpotterCommand(
    const std::string& command, std::string_view spotterType, const std::string& spotterModel)
{
    if (!aliceConfig_.getSendNavigationSpotterLog() && spotterType == SpotterTypes::NAVIGATION_OLD) {
        YIO_LOG_DEBUG("Skip logging navigation spotter");
        return;
    }

    if (!aliceConfig_.getSendCommandSpotterLog()) {
        YIO_LOG_DEBUG("Skip logging command spotter");
        return;
    }

    const auto eventType = (spotterType == SpotterTypes::NAVIGATION_OLD) ? "navigationCommand" : "spotterCommand";

    Json::Value args;
    args["device_state"] = deviceState_.formatJson();
    args["environment_state"] = deviceState_.getEnvironmentState().formatJson();
    args["spotter_phrase"] = command;
    args["spotter_model"] = spotterModel;
    telemetry_->reportEvent(eventType, jsonToString(args));
}

void CommandSpotterCapability::initDisabledPhrases() {
    // Rebuild from scratch; non-string entries in the config list are ignored.
    disabledPhrases_.clear();

    const auto& configured = aliceConfig_.getDisabledCommandPhrases();
    for (const auto& entry : configured) {
        if (!entry.isString()) {
            continue;
        }
        disabledPhrases_.emplace(entry.asString());
    }
}

void CommandSpotterCapability::initDisabledSpotters() {
    // Rebuild from scratch; non-string entries in the config list are ignored.
    disabledSpotters_.clear();

    const auto& configured = aliceConfig_.getDisabledCommandSpotterTypes();
    for (const auto& entry : configured) {
        if (!entry.isString()) {
            continue;
        }
        disabledSpotters_.emplace(entry.asString());
    }
}

void CommandSpotterCapability::stop() {
    YIO_LOG_INFO("Stop command spotter");

    // Reset bookkeeping before stopping. An empty spotterType_ ensures a later
    // startCommandSpotter() with any non-empty type is not skipped as "unchanged".
    isActive_ = false;
    spotterType_.clear();

    voiceDialog_->stopCommandSpotter();
}

bool CommandSpotterCapability::isConfigEnabled() {
    // Both the global spotter switch and the command-spotter switch must be on.
    if (!aliceConfig_.getSpottersEnabled()) {
        return false;
    }
    return aliceConfig_.getCommandSpottersEnabled();
}

bool CommandSpotterCapability::isAliceInCommandState() {
    // Command phrases are only listened for while Alice is idle or speaking.
    const bool idle = (aliceState_ == proto::AliceState::IDLE);
    const bool speaking = (aliceState_ == proto::AliceState::SPEAKING);
    return idle || speaking;
}

void CommandSpotterCapability::updateSpotterState(bool forceRecreate)
{
    if (isConfigEnabled() && isAliceInCommandState()) {
        const auto newSpotterType = defineSpotterType();
        if (isSpotterEnabled(newSpotterType)) {
            auto it = spottersModels_.find(newSpotterType);
            if (it != spottersModels_.end() && !it->second.empty()) {
                startCommandSpotter(newSpotterType, forceRecreate);
                isActive_ = true;
                return;
            }
        }
    }

    if (isActive_) {
        stop();
    }
}

void CommandSpotterCapability::remoteNotifyModelSet(const std::set<std::string>& types) {
    if (connection_ == nullptr) {
        return;
    }

    quasar::proto::Remoting remoting;
    remoting.set_remote_object_id(TString(REMOTE_OBJECT_NAME));

    auto method = remoting.mutable_spotter_capability_listener_method();
    method->set_method(quasar::proto::Remoting::SpotterCapabilityListenerMethod::ON_MODEL_SET);

    for (const auto& type : types) {
        method->add_spotter_type(TString(type));
    }

    connection_->sendMessage(remoting);
}

void CommandSpotterCapability::remoteNotifyModelError() {
    if (connection_ == nullptr) {
        return;
    }

    quasar::proto::Remoting remoting;
    remoting.set_remote_object_id(TString(REMOTE_OBJECT_NAME));

    auto method = remoting.mutable_spotter_capability_listener_method();
    method->set_method(quasar::proto::Remoting::SpotterCapabilityListenerMethod::ON_MODEL_ERROR);
    method->add_spotter_type(TString(spotterType_));

    connection_->sendMessage(remoting);
}
