#include <maps/wikimap/mapspro/services/mrc/libs/number_recognizer/include/number_recognizer.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>
#include <maps/libs/common/include/exception.h>

#include <tensorflow/core/common_runtime/dma_helper.h>
#include <opencv2/opencv.hpp>
#include <library/cpp/resource/resource.h>

#include <algorithm>
#include <codecvt>
#include <cstddef>
#include <locale>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace maps::mrc::number_recognizer {

using namespace wiki::tf_inferencer;

/// Loads the recognition model from the resource `resourceName` and reads
/// the table of symbols the model is able to emit.
///
/// @param resourceName name of the resource holding the TensorFlow model.
NumberRecognizer::NumberRecognizer(const std::string& resourceName)
    : tfInferencer_(std::make_unique<TensorFlowInferencer>(
          TensorFlowInferencer::fromResource(resourceName))) {
    // Runs in the body (not the init list) so that tfInferencer_ is already
    // constructed when the symbol table is read.
    evalSupportedSymbols(resourceName);
}

// Defaulted out of line, in this translation unit, where TensorFlowInferencer
// is a complete type (its header is included above).
// NOTE(review): presumably this is needed because the header only
// forward-declares TensorFlowInferencer behind a smart-pointer member —
// confirm against number_recognizer.h before inlining it there.
NumberRecognizer::~NumberRecognizer() = default;

const std::vector<char16_t>& NumberRecognizer::getSupportedSymbols() const {
    return supportedSymbols_;
}

/// Fills supportedSymbols_ from the model's "symbols:0" output.
///
/// That output must be a rank-1 tensor of 16-bit code units. Entry k is the
/// symbol for output class k of the recognition softmax (see recognize()).
///
/// The model already loaded into tfInferencer_ is reused. The original
/// implementation loaded and parsed the whole model from the resource a
/// second time just to read this one constant output.
///
/// @param resourceName no longer used (the model is taken from
///        tfInferencer_); kept so the declared signature stays unchanged.
/// @throws on REQUIRE failure when the tensor rank or element width is
///         unexpected (exception type defined by maps/libs/common).
void NumberRecognizer::evalSupportedSymbols([[maybe_unused]] const std::string& resourceName) {
    static const std::string TF_OUTPUT_SYMBOLS_LAYER_NAME = "symbols:0";

    tensorflow::Tensor symbolCodes = tfInferencer_->inference(TF_OUTPUT_SYMBOLS_LAYER_NAME);
    REQUIRE((1 == symbolCodes.dims()), "Invalid codes of symbols tensor size");
    // The raw buffer is reinterpreted as char16_t below, so each element must
    // be exactly as wide as char16_t; anything else would yield garbage codes.
    REQUIRE((static_cast<std::size_t>(tensorflow::DataTypeSize(symbolCodes.dtype())) == sizeof(char16_t)),
            "Invalid codes of symbols tensor type");

    const auto symbolsCount = static_cast<std::size_t>(symbolCodes.dim_size(0));
    const auto* codes = static_cast<const char16_t*>(tensorflow::DMAHelper::base(&symbolCodes));
    supportedSymbols_.assign(codes, codes + symbolsCount);
}

/// Recognizes the symbol sequence shown in `image`.
///
/// The network output is a [1, seqLength, symbolsCount] float tensor, one
/// probability row per sequence position. It is decoded greedily: each
/// position takes its most probable class. The extra last class (index
/// supportedSymbols_.size()) means "no symbol" and is dropped from the text.
///
/// @param image input passed to the "inference_input" layer.
/// @return the decoded string in UTF-8, and the product of the per-position
///         maximum probabilities (0 when the decoded string is empty).
/// @throws on REQUIRE failure when the output tensor has an unexpected shape
///         or element type (exception type defined by maps/libs/common), and
///         std::range_error if a symbol cannot be converted to UTF-8.
std::pair<std::string, float> NumberRecognizer::recognize(const cv::Mat& image) const
{
    static const std::string TF_INPUT_LAYER = "inference_input";
    static const std::string TF_OUTPUT_LAYER = "inference_softmax";
    // Deliberately NOT a function-local static: the value depends on the model
    // loaded by *this* instance. A static would be frozen by whichever
    // NumberRecognizer called recognize() first and be wrong for the others.
    const int dummySymbolIndex = static_cast<int>(supportedSymbols_.size());

    tensorflow::Tensor outputTensor = tfInferencer_->inference(TF_INPUT_LAYER, image, TF_OUTPUT_LAYER);

    REQUIRE((3 == outputTensor.dims()), "Invalid output tensor size");
    REQUIRE((1 == outputTensor.dim_size(0)), "Invalid output tensor batch size");
    // The buffer is reinterpreted as float below.
    REQUIRE((tensorflow::DT_FLOAT == outputTensor.dtype()), "Invalid output tensor type");
    const int seqLength = static_cast<int>(outputTensor.dim_size(1));
    const int symbolsCount = static_cast<int>(outputTensor.dim_size(2));
    REQUIRE((dummySymbolIndex + 1 == symbolsCount), "Invalid symbols count in output tensor");

    const float* pConfidence = static_cast<const float*>(tensorflow::DMAHelper::base(&outputTensor));

    std::pair<std::string, float> result;
    result.second = 1.f;

    std::u16string resultString;
    for (int i = 0; i < seqLength; i++) {
        const float* row = pConfidence + static_cast<std::ptrdiff_t>(i) * symbolsCount;
        // symbolsCount >= 1 (checked above), so the row is never empty.
        // max_element yields the first maximum, i.e. ties go to the lower index.
        const float* best = std::max_element(row, row + symbolsCount);
        const int symbol = static_cast<int>(best - row);
        result.second *= *best;
        if (symbol == dummySymbolIndex) {
            continue;
        }
        resultString += supportedSymbols_.at(symbol);
    }
    // NOTE: std::wstring_convert / std::codecvt_utf8 are deprecated since
    // C++17. codecvt_utf8<char16_t> treats each code unit as UCS-2 (no
    // surrogate-pair handling), which suffices for single-code-unit symbols.
    result.first = std::wstring_convert<std::codecvt_utf8<char16_t>, char16_t>().to_bytes(resultString);
    if (result.first.empty()) {
        result.second = 0.f;
    }
    return result;
}

} // namespace maps::mrc::number_recognizer
