#include "road_classifier.h"

#include <library/cpp/resource/resource.h>
#include <opencv2/imgcodecs/imgcodecs_c.h>

namespace maps::mrc::classifiers {

namespace {

const std::string ROAD_CLASSIFIER_RESOURCE_KEY("/maps/mrc/classifiers/models/road_classifier.gdef");

} // anonymous namespace

// Default constructor.
//
// Initializes TensorFlowClassifier (presumably the base class - see the
// header) with an inferencer created by TensorFlowInferencer::fromResource()
// for ROAD_CLASSIFIER_RESOURCE_KEY. Any failure reported by fromResource()
// propagates to the caller of this constructor.
RoadClassifier::RoadClassifier()
    : TensorFlowClassifier{
        wiki::tf_inferencer::TensorFlowInferencer::fromResource(
            ROAD_CLASSIFIER_RESOURCE_KEY
        )
    }
{}

// Estimates how likely it is that the given image shows a road.
//
// Runs callClassifier() on the image and interprets its (label, value) result:
//   - "road"     -> the value is returned as is;
//   - "not_road" -> the complement (1 - value) is returned.
// NOTE(review): the value is presumably the confidence of the predicted
// label - confirm against callClassifier().
//
// Throws Exception if the classifier reports any other label.
double RoadClassifier::estimateRoadProbability(const cv::Mat& image) const
{
    const auto classification = callClassifier(image);
    const auto& label = classification.first;
    const auto& confidence = classification.second;

    if (label == "road") {
        return confidence;
    }

    if (label == "not_road") {
        // The classifier scored the opposite class, so take the complement.
        return 1. - confidence;
    }

    throw Exception() << "Unknown road type: " << label;
}

// Convenience overload for an encoded image supplied as common::Bytes.
//
// Decodes the input with common::decodeImage() and forwards the decoded
// image to the estimateRoadProbability() overload that accepts it.
double RoadClassifier::estimateRoadProbability(const common::Bytes& encodedImage) const
{
    const auto decodedImage = common::decodeImage(encodedImage);
    return estimateRoadProbability(decodedImage);
}

// Convenience overload for an encoded image supplied as common::Blob.
//
// Decodes the input with common::decodeImage() and forwards the decoded
// image to the estimateRoadProbability() overload that accepts it.
double RoadClassifier::estimateRoadProbability(const common::Blob& encodedImage) const
{
    const auto decodedImage = common::decodeImage(encodedImage);
    return estimateRoadProbability(decodedImage);
}

} // namespace maps::mrc::classifiers
