#include "device_controller.h"

#include <yandex_io/libs/base/directives.h>
#include <yandex_io/libs/json_utils/json_utils.h>
#include <yandex_io/libs/logging/logging.h>
#include <yandex_io/libs/protobuf_utils/proto_trace.h>
#include <yandex_io/services/aliced/capabilities/alice_capability/directive_factory.h>

#include <alice/megamind/protos/common/atm.pb.h>
#include <alice/megamind/protos/common/frame.pb.h>
#include <alice/megamind/protos/common/iot.pb.h>

#include <google/protobuf/any.pb.h>

#include <algorithm>
#include <iterator>

YIO_DEFINE_LOG_MODULE("device_controller");

namespace {

    // Wraps a typed semantic-frame request into the JSON event object that is
    // passed to VinsRequest::createEventRequest:
    //   {"name": "@@mm_semantic_frame", "type": "server_action", "payload": <data as JSON>}
    // The conversion result is unwrapped with value(); no fallback is applied
    // if the protobuf-to-JSON conversion produces no value.
    Json::Value createSemanticFrame(const NAlice::TSemanticFrameRequestData& data) {
        Json::Value frame;
        frame["name"] = "@@mm_semantic_frame";
        frame["type"] = "server_action";
        frame["payload"] = quasar::convertMessageToJson(data, true).value();
        return frame;
    }

} // unnamed namespace

namespace YandexIO {

    // Stores the injected dependencies, creates the directive processor and the
    // remoting router, and configures the backoff policy that
    // onAliceRequestError / onAliceRequestCompleted adjust when event requests
    // fail or succeed.
    DeviceController::DeviceController(
        std::shared_ptr<quasar::ICallbackQueue> worker,
        const quasar::AliceConfig& aliceConfig,
        std::shared_ptr<quasar::LegacyIotCapability> legacyIotCapability,
        std::unique_ptr<quasar::IBackoffRetries> backoffer)
        : worker_(std::move(worker))
        , directiveProcessor_(std::make_shared<DirectiveProcessor>())
        , aliceConfig_(aliceConfig)
        , legacyIotCapability_(std::move(legacyIotCapability))
        , remotingMessageRouter_(std::make_shared<RemotingMessageRouter>())
        , backoffer_(std::move(backoffer))
    {
        directiveProcessor_->setDirectiveFactory(std::make_shared<DirectiveFactory>());

        using std::chrono::milliseconds;
        // NOTE(review): argument meanings are defined by IBackoffRetries::initCheckPeriod
        // (not visible here); the largest value is 600000 ms = 10 minutes.
        backoffer_->initCheckPeriod(
            milliseconds(1000),
            milliseconds(0),
            milliseconds(600000),
            milliseconds(1000));
    }

    // Creates the endpoint storage for this device and subscribes the
    // controller both to the storage (strong reference) and to the local
    // endpoint (weak reference). Must be called after the controller is owned
    // by a std::shared_ptr, because it uses shared_from_this().
    void DeviceController::init(std::string deviceId, NAlice::TEndpoint::EEndpointType type)
    {
        endpointStorage_ = std::make_shared<EndpointStorageHost>(remotingMessageRouter_, std::move(deviceId), type);
        endpointStorage_->init();

        endpointStorage_->addListener(shared_from_this());

        const auto& localEndpoint = endpointStorage_->getLocalEndpoint();
        localEndpoint->addListener(weak_from_this());
    }

    // Stores a non-owning reference to the Alice capability used to send
    // event requests; a later expired pointer simply disables sending.
    void DeviceController::setAliceCapability(std::weak_ptr<IAliceCapability> aliceCapability)
    {
        aliceCapability_.swap(aliceCapability);
    }

    const std::shared_ptr<EndpointStorageHost>& DeviceController::getEndpointStorage() const {
        return endpointStorage_;
    }

    const std::shared_ptr<DirectiveProcessor>& DeviceController::getDirectiveProcessor() const {
        return directiveProcessor_;
    }

    const std::shared_ptr<RemotingMessageRouter>& DeviceController::getRemotingMessageRouter() const {
        return remotingMessageRouter_;
    }

    void DeviceController::fireCapabilitiesConfig()
    {
        quasar::proto::CapabilityConfig config;
        config.set_iot_capability_enabled(aliceConfig_.getIotCapabilityEnabled());

        endpointStorage_->setCapabilityConfig(std::move(config));
    }

    // Switches directive handling between the new IoT capability and the
    // legacy IoT capability according to AliceConfig.
    //  - enabled:  the legacy handler is removed and the config is republished;
    //  - disabled: the "IotCapabilityDirectiveHandler" handler (if registered)
    //              is removed and the config republished, then the legacy
    //              handler is added; Y_VERIFY aborts if adding it fails.
    void DeviceController::onIotCapabilityEnabledChanged()
    {
        YIO_LOG_INFO("onIotCapabilityEnabledChanged: " << aliceConfig_.getIotCapabilityEnabled());

        const auto& processor = getDirectiveProcessor();

        if (aliceConfig_.getIotCapabilityEnabled()) {
            processor->removeDirectiveHandler(legacyIotCapability_);
            fireCapabilitiesConfig();
            return;
        }

        if (const auto& iotHandler = processor->findDirectiveHandlerByName("IotCapabilityDirectiveHandler"); iotHandler) {
            processor->removeDirectiveHandler(iotHandler);
            fireCapabilitiesConfig();
        }

        Y_VERIFY(processor->addDirectiveHandler(legacyIotCapability_));
    }

    // A new endpoint appeared in the storage: subscribe to it (weakly) and
    // register every capability it currently exposes.
    void DeviceController::onEndpointAdded(const std::shared_ptr<IEndpoint>& endpoint)
    {
        YIO_LOG_INFO("onEndpointAdded " << endpoint.get() << ", " << endpoint->getId());

        endpoint->addListener(weak_from_this());

        const auto& capabilities = endpoint->getCapabilities();
        for (const auto& item : capabilities) {
            registerCapability(item);
        }
    }

    // An endpoint left the storage: unsubscribe from it and unregister every
    // capability it currently exposes (mirror of onEndpointAdded).
    void DeviceController::onEndpointRemoved(const std::shared_ptr<IEndpoint>& endpoint)
    {
        YIO_LOG_INFO("onEndpointRemoved " << endpoint.get() << ", " << endpoint->getId());

        endpoint->removeListener(weak_from_this());

        const auto& capabilities = endpoint->getCapabilities();
        for (const auto& item : capabilities) {
            unregisterCapability(item);
        }
    }

    // Intentionally ignored: the controller itself publishes the capability
    // config (fireCapabilitiesConfig) and does not react to changes of it.
    // NOTE(review): presumably overridden only to satisfy the listener interface.
    void DeviceController::onCapabilityConfigChanged(const quasar::proto::CapabilityConfig& /*config*/)
    {
    }

    // Called when a capability reports a state change. Finds the endpoint that
    // owns the capability and, if the capability's current state has
    // Meta.Reportable set, schedules a batched endpoint-state update for it.
    //
    // The state is re-read from the capability; the `state` argument is unused.
    // If the Reportable flag cannot be read, a warning is logged and nothing is
    // scheduled (the empty optional is never dereferenced).
    void DeviceController::onCapabilityStateChanged(const std::shared_ptr<ICapability>& changedCapability, const NAlice::TCapabilityHolder& /*state*/)
    {
        for (const auto& endpoint : endpointStorage_->getEndpoints()) {
            for (const auto& capability : endpoint->getCapabilities()) {
                if (capability != changedCapability) {
                    continue;
                }

                const auto& message = capability->getState();
                const auto reportable = getReportable(message);

                if (!reportable.has_value()) {
                    // Without a readable flag there is nothing to decide on.
                    YIO_LOG_WARN("Unable to get Reportable field from TCapabilityHolder");
                    return;
                }

                if (*reportable) {
                    scheduleEndpointUpdate(endpoint, changedCapability);
                }
                return;
            }
        }
    }

    // Forwards capability events to sendCapabilityEvents for the endpoint that
    // owns `sourceCapability`. Only the first owning endpoint is used; events
    // from a capability not found on any endpoint are dropped.
    void DeviceController::onCapabilityEvents(const std::shared_ptr<ICapability>& sourceCapability, const std::vector<NAlice::TCapabilityEvent>& events) {
        const auto& endpoints = endpointStorage_->getEndpoints();
        for (const auto& owner : endpoints) {
            for (const auto& candidate : owner->getCapabilities()) {
                if (candidate != sourceCapability) {
                    continue;
                }
                sendCapabilityEvents(owner, events);
                return;
            }
        }
    }

    // Reads <active "Capability" oneof field>.Meta.Reportable from a
    // TCapabilityHolder via protobuf reflection.
    //
    // Returns std::nullopt when no capability is set in the oneof, or when the
    // schema does not have the expected shape: a "Capability" oneof holding a
    // message, a singular message field "Meta", and a singular bool field
    // "Reportable". The type checks prevent the reflection accessors
    // (GetMessage / GetBool) from being called on a mismatched field, which
    // protobuf treats as a fatal usage error.
    std::optional<bool> DeviceController::getReportable(const NAlice::TCapabilityHolder& message) {
        using google::protobuf::FieldDescriptor;

        // Schema lookups do not depend on `message`, so resolve them once.
        static const auto reflection = NAlice::TCapabilityHolder::GetReflection();
        static const auto oneofDescriptor = NAlice::TCapabilityHolder::GetDescriptor()->FindOneofByName("Capability");
        if (!oneofDescriptor) {
            return std::nullopt;
        }

        // Field of the oneof that is currently set; null when none is set.
        const auto capabilityDescriptor = reflection->GetOneofFieldDescriptor(message, oneofDescriptor);
        if (!capabilityDescriptor || capabilityDescriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
            return std::nullopt;
        }

        const auto& capability = reflection->GetMessage(message, capabilityDescriptor);
        const auto metaDescriptor = capability.GetDescriptor()->FindFieldByName("Meta");
        if (!metaDescriptor || metaDescriptor->is_repeated() || metaDescriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
            return std::nullopt;
        }

        const auto& meta = capability.GetReflection()->GetMessage(capability, metaDescriptor);
        const auto reportableDescriptor = meta.GetDescriptor()->FindFieldByName("Reportable");
        if (!reportableDescriptor || reportableDescriptor->is_repeated() || reportableDescriptor->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
            return std::nullopt;
        }

        return meta.GetReflection()->GetBool(meta, reportableDescriptor);
    }

    // A capability was added to an endpoint we listen to: register it.
    void DeviceController::onCapabilityAdded(
        const std::shared_ptr<IEndpoint>& /*endpoint*/,
        const std::shared_ptr<ICapability>& capability)
    {
        const auto& added = capability;
        YIO_LOG_INFO("onCapabilityAdded " << added->getId());
        registerCapability(added);
    }

    // A capability was removed from an endpoint we listen to: unregister it.
    void DeviceController::onCapabilityRemoved(
        const std::shared_ptr<IEndpoint>& /*endpoint*/,
        const std::shared_ptr<ICapability>& capability)
    {
        const auto& removed = capability;
        YIO_LOG_INFO("onCapabilityRemoved " << removed->getId());
        unregisterCapability(removed);
    }

    void DeviceController::onEndpointStateChanged(
        const std::shared_ptr<IEndpoint>& endpoint)
    {
        YIO_LOG_DEBUG("onEndpointStateChanged " << endpoint->getId());
        sendCapabilityEvents(endpoint, {});
    }

    void DeviceController::sendUpdatedEndpoints()
    {
        NAlice::TSemanticFrameRequestData semanticFrame;

        auto analytics = semanticFrame.mutable_analytics();
        analytics->set_purpose("endpoint_state_updates");
        analytics->set_origin(NAlice::TAnalyticsTrackingModule::SmartSpeaker);

        auto endpointStateUpdatesTsf = semanticFrame.mutable_typedsemanticframe()->mutable_endpointstateupdatessemanticframe();

        auto requestValue = endpointStateUpdatesTsf->mutable_request()->mutable_requestvalue();

        for (const auto& [weakEndpoint, capabilities] : updatedEndpoints_) {
            auto endpoint = weakEndpoint.lock();
            if (!endpoint) {
                continue;
            }

            YIO_LOG_DEBUG("Updating endpoint " << endpoint->getId());

            auto state = endpoint->getState();
            state.mutable_capabilities()->Clear();

            for (const auto& weakCapability : capabilities) {
                auto capability = weakCapability.lock();
                if (!capability) {
                    continue;
                }

                YIO_LOG_DEBUG("Updating capability " << capability->getId() << " of endpoint " << endpoint->getId());

                try {
                    const auto capabilityState = capability->getState();
                    const auto& message = ICapability::getCapabilityFromHolder(capabilityState);
                    state.add_capabilities()->PackFrom(message);
                } catch (const std::invalid_argument& e) {
                    YIO_LOG_WARN("Failed to add capability: " << e.what());
                }
            }

            requestValue->add_endpointupdates()->CopyFrom(state);
        }

        if (auto aliceCapability = aliceCapability_.lock()) {
            auto request = VinsRequest::createEventRequest(createSemanticFrame(semanticFrame), VinsRequest::createSoftwareDirectiveEventSource());

            request->setIsParallel(aliceConfig_.getUseParallelRequests());
            request->setEnqueued(!request->getIsParallel());
            request->setIsSilent(true);
            request->setIsIgnoreCriticalUpdate(true);
            request->setVoiceSession(false);

            aliceCapability->startRequest(std::move(request), shared_from_this());
        }
    }

    // Intentionally a no-op: nothing is tracked when a request starts.
    void DeviceController::onAliceRequestStarted(std::shared_ptr<VinsRequest> /* request */) {
        // nothing to do
    }

    // A request finished successfully. Stop tracking its events batch (if it
    // carried one). Once no event batches remain in flight, reset the backoff
    // delay to its default.
    void DeviceController::onAliceRequestCompleted(std::shared_ptr<VinsRequest> request, const Json::Value& /* response */) {
        const auto& requestId = request->getId();

        const bool hadEvents = eventsBeingSent_.erase(requestId) > 0;
        if (hadEvents) {
            YIO_LOG_DEBUG("Events for request " << requestId << " were sent");
        }

        if (!eventsBeingSent_.empty()) {
            return;
        }

        // Nothing is in flight any more, so the backoff can be turned off.
        backoffer_->resetDelayBetweenCallsToDefault();
        YIO_LOG_DEBUG("Event sending backoff reset to default " << backoffer_->getDelayBetweenCalls().count());
    }

    void DeviceController::onAliceRequestError(std::shared_ptr<VinsRequest> request, const std::string& errorCode,
                                               const std::string& errorText) {
        YIO_LOG_WARN("Unable to send request " << request->getId() << ", " << errorCode << ": " << errorText);

        if (!eventsBeingSent_.contains(request->getId())) {
            YIO_LOG_WARN("No events are known for requestId " << request->getId() << ", skipping");
            return;
        }

        backoffer_->increaseDelayBetweenCalls();
        YIO_LOG_DEBUG("Current delay between calls: " << backoffer_->getDelayBetweenCalls().count());

        const auto endpointId = eventsBeingSent_[request->getId()].endpointId;
        auto& endpointEventsToBeSent = eventsToBeSent_[endpointId];
        const auto endpointEventsBeingSent = eventsBeingSent_[request->getId()].events;
        endpointEventsToBeSent.reserve(endpointEventsToBeSent.size() + endpointEventsBeingSent.size());
        endpointEventsToBeSent.insert(endpointEventsToBeSent.end(), endpointEventsBeingSent.begin(), endpointEventsBeingSent.end());
        eventsBeingSent_.erase(request->getId());

        worker_->addDelayed([this, a = shared_from_this(), request, endpointId]() {
            const auto endpoints = endpointStorage_->getEndpoints();
            const auto endpoint = std::find_if(endpoints.begin(), endpoints.end(),
                                               [&endpointId](const std::shared_ptr<IEndpoint>& endpoint) { return endpoint->getId() == endpointId; });

            if (endpoint == endpoints.end()) {
                YIO_LOG_WARN("Trying to send events for missing endpoint " << endpointId);
                return;
            }

            if (eventsToBeSent_.contains((*endpoint)->getId())) {
                sendCapabilityEvents(*endpoint, {});
            }
        }, backoffer_->getDelayBetweenCalls());
    }

    // Runs on the worker after the endpoint-update timeout. Flushes the
    // accumulated endpoint updates (if any), clears them, then allows a new
    // update to be scheduled.
    void DeviceController::handleEndpointsUpdate() {
        if (updatedEndpoints_.empty()) {
            hasScheduledEndpointUpdate_ = false;
            return;
        }

        sendUpdatedEndpoints();
        updatedEndpoints_.clear();
        hasScheduledEndpointUpdate_ = false;
    }

    void DeviceController::scheduleEndpointUpdate(const std::shared_ptr<IEndpoint>& endpoint, const std::shared_ptr<ICapability>& capability)
    {
        auto [it, _] = updatedEndpoints_.emplace(endpoint, UpdatedCapabilities{});

        if (capability != nullptr) {
            it->second.insert(capability);
        }

        if (!hasScheduledEndpointUpdate_) {
            hasScheduledEndpointUpdate_ = true;

            worker_->addDelayed(
                [this]() {
                    handleEndpointsUpdate();
                },
                aliceConfig_.getEndpointsUpdateTimeout());

            YIO_LOG_INFO("EndpointUpdate scheduled");
        }
    }

    // Subscribes to the capability (weak reference) and registers its
    // directive handler, if it has one, with the directive processor. A failed
    // registration is reported as an error event but is not fatal.
    void DeviceController::registerCapability(const std::shared_ptr<ICapability>& capability)
    {
        capability->addListener(weak_from_this());

        if (const auto& handler = capability->getDirectiveHandler(); handler && !directiveProcessor_->addDirectiveHandler(handler)) {
            std::stringstream errorText;
            errorText << "Failed to register DirectiveHandler "
                      << handler->getHandlerName()
                      << " for capability " << quasar::shortUtf8DebugString(capability->getState());
            YIO_LOG_ERROR_EVENT("DeviceController.FailedToRegisterCapabilityDirectiveHandler", errorText.str());
        }
    }

    // Reverse of registerCapability: unsubscribe from the capability and drop
    // its directive handler, if it has one, from the directive processor.
    void DeviceController::unregisterCapability(const std::shared_ptr<ICapability>& capability)
    {
        capability->removeListener(weak_from_this());

        const auto& handler = capability->getDirectiveHandler();
        if (!handler) {
            return;
        }
        directiveProcessor_->removeDirectiveHandler(handler);
    }

    // Sends an "endpoint_events_batch" semantic frame for `endpoint`. The frame
    // contains the endpoint status plus all events queued for it: events left
    // over from earlier failed sends, followed by `events`.
    //
    // If the Alice capability is alive, the queued batch is moved into
    // eventsBeingSent_ under the request id and the queue is cleared;
    // onAliceRequestError re-queues it on failure. If the capability is gone,
    // the events stay queued for a later attempt.
    void DeviceController::sendCapabilityEvents(const std::shared_ptr<IEndpoint>& endpoint, const std::vector<NAlice::TCapabilityEvent>& events) {
        NAlice::TSemanticFrameRequestData semanticFrame;

        auto analytics = semanticFrame.mutable_analytics();
        analytics->set_purpose("endpoint_events_batch");
        analytics->set_origin(NAlice::TAnalyticsTrackingModule::SmartSpeaker);

        auto endpointEvents = semanticFrame.mutable_typedsemanticframe()
                                  ->mutable_endpointeventsbatchsemanticframe()
                                  ->mutable_batch()
                                  ->mutable_batchvalue()
                                  ->add_batch();

        const std::string endpointId = endpoint->getId();

        // Single lookup; creates an empty queue for the endpoint if none exists.
        auto& pendingEvents = eventsToBeSent_[endpointId];

        // Events still waiting to be retried after an earlier failure.
        if (!pendingEvents.empty()) {
            YIO_LOG_DEBUG("Events to be retried for endpoint " << endpointId << ": " << pendingEvents.size());
        }

        pendingEvents.reserve(pendingEvents.size() + events.size());
        pendingEvents.insert(pendingEvents.end(), events.begin(), events.end());

        endpointEvents->set_endpointid(TString{endpointId});
        endpointEvents->mutable_endpointstatus()->CopyFrom(endpoint->getStatus());

        for (const auto& event : pendingEvents) {
            endpointEvents->add_capabilityevents()->CopyFrom(event);
        }

        if (auto aliceCapability = aliceCapability_.lock()) {
            auto request = VinsRequest::createEventRequest(createSemanticFrame(semanticFrame), VinsRequest::createSoftwareDirectiveEventSource());

            request->setIsParallel(aliceConfig_.getUseParallelRequests());
            request->setEnqueued(!request->getIsParallel());
            request->setIsSilent(true);
            request->setIsIgnoreCriticalUpdate(true);
            request->setVoiceSession(false);

            // Hand the batch over to the in-flight table (moved, not copied);
            // `pendingEvents` must not be used after the erase below.
            eventsBeingSent_[request->getId()] = {.endpointId = endpointId, .events = std::move(pendingEvents)};
            eventsToBeSent_.erase(endpointId);

            aliceCapability->startRequest(std::move(request), shared_from_this());
        }
    }

} // namespace YandexIO
