#include <maps/wikimap/mapspro/services/autocart/libs/auto_toloker/include/auto_toloker.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>

#include <algorithm>
#include <cstdint>

namespace maps::wiki::autocart {

using namespace tf_inferencer;

namespace {

// Resource path of the TensorFlow graph handed to
// TensorFlowInferencer::fromResource() by the AutoToloker constructor.
const std::string TF_MODEL_RESOURCE = "/maps/autocart/auto_toloker/models/tf_model.gdef";

// Class label searched for among the model's "class_names"; its position
// selects the softmax column that AutoToloker::classify() returns.
const std::string YES_ANSWER = "yes";

} // namespace

/// Loads the TensorFlow model from TF_MODEL_RESOURCE and finds the position
/// of the "yes" label among the model's class names.
///
/// The class labels are read from the "class_names" output of the graph.
/// REQUIRE fails if there are not exactly two labels, or if none of them
/// equals YES_ANSWER.
AutoToloker::AutoToloker()
    : classifier_(
        new TensorFlowInferencer(TensorFlowInferencer::fromResource(TF_MODEL_RESOURCE))
    )
{
    static const std::string TF_OUTPUT_CLASS_NAMES_LAYER_NAME = "class_names";

    // classify() reads the softmax column at the same position as the "yes"
    // label here, i.e. it assumes softmax columns follow this label order.
    const std::vector<TString> strAnswers = tensorToVector<TString>(
        classifier_->inference(TF_OUTPUT_CLASS_NAMES_LAYER_NAME)
    );
    REQUIRE(strAnswers.size() == 2, "Auto toloker should have two answers");

    const auto yesIt = std::find(
        strAnswers.begin(), strAnswers.end(),
        TString(YES_ANSWER)
    );
    REQUIRE(yesIt != strAnswers.end(), "Auto toloker should have yes answer");
    yesAnswerIndex_ = std::distance(strAnswers.begin(), yesIt);
}

// Defaulted destructor, defined in this translation unit rather than in the
// header (NOTE(review): presumably so classifier_'s pointee type only needs
// to be complete here — confirm against the header).
AutoToloker::~AutoToloker() = default;

/// Runs the model on one image/mask pair and returns the "yes" confidence.
///
/// @param imageBGR image fed to the "inference_image" graph input
///        (BGR channel order presumably, per the name — TODO confirm).
/// @param mask     mask fed to the "inference_mask" graph input.
/// @return element [0][yesAnswerIndex_] of the "inference_softmax" output.
///
/// REQUIRE fails if the model does not return exactly one float matrix
/// with at least one row and more than yesAnswerIndex_ columns.
float AutoToloker::classify(
    const cv::Mat& imageBGR, const cv::Mat& mask) const
{
    static const std::string TF_INPUT_IMAGE_NAME = "inference_image";
    static const std::string TF_INPUT_MASK_NAME = "inference_mask";
    static const std::string TF_OUTPUT_SOFTMAX_LAYER_NAME = "inference_softmax";

    auto outputs = classifier_->inference(
        {{TF_INPUT_IMAGE_NAME, imageBGR}, {TF_INPUT_MASK_NAME, mask}},
        std::vector<std::string>{TF_OUTPUT_SOFTMAX_LAYER_NAME}
    );

    REQUIRE(outputs.size() == 1, "Invalid output tensors number");

    const auto& softmax = outputs[0];
    // Tensor::matrix<float>() CHECK-aborts the whole process on a dtype or
    // rank mismatch, and Eigen element access is not bounds-checked in
    // release builds. Validate first so a model mismatch surfaces as a
    // REQUIRE failure instead of an abort or out-of-bounds read.
    REQUIRE(softmax.dtype() == tensorflow::DT_FLOAT, "Invalid softmax tensor type");
    REQUIRE(softmax.dims() == 2, "Invalid softmax tensor rank");
    REQUIRE(softmax.dim_size(0) >= 1, "Empty softmax tensor");
    REQUIRE(
        static_cast<std::int64_t>(yesAnswerIndex_) < softmax.dim_size(1),
        "Softmax tensor has too few classes"
    );

    const auto confidences = softmax.matrix<float>();
    return confidences(0, yesAnswerIndex_);
}

} // namespace maps::wiki::autocart
