#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <maps/libs/http/include/http.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>

#include <util/stream/input.h>
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/thread/pool.h>
#include <util/thread/lfqueue.h>

#include <opencv2/opencv.hpp>

#include <cfloat>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iostream>
#include <list>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace tf = tensorflow;
namespace tfi = maps::wiki::tf_inferencer;

namespace {

// Resource key of the TensorFlow graph that FaceDetector loads via TensorFlowInferencer::fromResource.
const std::string TF_MODEL_RESOURCE{"/maps/mrc/facedetect/models/tf_model.gdef"};
// Sleep interval used by the polling (sleep-and-recheck) loops between the downloader and the workers.
constexpr std::chrono::milliseconds THREAD_WAIT_TIMEOUT{1};

/// One unit of work passed from the downloader to the detector threads.
struct ImageData {
    std::string featureID;  // first whitespace-separated token of the input line
    std::string url;        // second token of the input line; the image is fetched from here
    cv::Mat image;          // result of downloadImage(url)
};

/// Counter policy for TLockFreeQueue: tracks how many ImageData items are currently queued,
/// so the downloader can throttle its prefetching (see downloadImagesInQueue).
struct ImageDataCounter {
    TAtomic Counter = 0;

    void IncCount(const ImageData& /*item*/) { AtomicIncrement(Counter); }

    void DecCount(const ImageData& /*item*/) { AtomicDecrement(Counter); }
};

// Images produced by the downloader and consumed by Detector tasks; ImageDataCounter tracks its length.
using ImageQueue = TLockFreeQueue<ImageData, ImageDataCounter>;

/// Downloads the image at `imageURL` and decodes it with cv::IMREAD_COLOR, ignoring EXIF orientation.
///
/// Makes up to RETRY_NUMBER attempts spaced RETRY_TIMEOUT apart. An attempt counts as failed (and is logged)
/// when the request throws a maps::Exception, the HTTP status is not 200, the body is empty, or the body
/// cannot be decoded as an image.
///
/// @return a non-empty decoded image.
/// @throws maps::RuntimeError (naming the URL) when every attempt fails.
cv::Mat downloadImage(const std::string &imageURL) {
    constexpr int RETRY_NUMBER = 10;
    constexpr std::chrono::seconds RETRY_TIMEOUT(3);

    maps::http::Client  client;
    maps::http::URL     url(imageURL);
    maps::http::Request request(client, maps::http::GET, url);
    for (int attempt = 0; attempt < RETRY_NUMBER; attempt++) {
        if (0 != attempt)
            std::this_thread::sleep_for(RETRY_TIMEOUT);
        try {
            maps::http::Response response = request.perform();
            if (200 != response.status()) {
                WARN() << "Unexpected HTTP status " << response.status() << " for " << imageURL;
                continue;
            }
            std::string data = response.readBody();
            // cv::imdecode asserts on an empty buffer, so reject it here as a failed attempt.
            if (data.empty()) {
                WARN() << "Empty response body for " << imageURL;
                continue;
            }
            cv::Mat image = cv::imdecode(cv::Mat(1, static_cast<int>(data.size()), CV_8UC1, data.data()),
                                         cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION);
            // imdecode signals failure with an empty Mat; returning it would only fail later inside a worker thread.
            if (!image.empty())
                return image;
            WARN() << "Failed to decode image " << imageURL;
        } catch (const maps::Exception& e) {
            WARN() << e;
        }
    }
    throw maps::RuntimeError("Failed to download image " + imageURL);
}

void downloadImagesInQueue(const std::string &inputPath, ImageQueue &imgQueue, int prefetch) {
    std::ifstream ifs(inputPath);
    REQUIRE(ifs.is_open(), "Unable to open input file");
    int processItems = 0;
    for (; !ifs.eof();) {
        std::string inpline; std::getline(ifs, inpline);
        if (inpline.empty() || inpline[0] == '#')
            continue;
        std::stringstream ss(inpline);
        ImageData data;
        ss >> data.featureID >> data.url;
        data.image = downloadImage(data.url);
        imgQueue.Enqueue(data);
        processItems++;
        if (0 == (processItems % 1000))
            INFO() << "Process: " << processItems << " items";
        for (;;) {
            ImageDataCounter counter = imgQueue.GetCounter();
            if (prefetch > (int)AtomicGet(counter.Counter))
                break;
            std::this_thread::sleep_for(THREAD_WAIT_TIMEOUT);
        }
    }
}

/// A detected face: bounding box in original-image pixel coordinates and the network's score for it.
struct Object {
    cv::Rect rc;
    float score = 0.f;  // initialized so a default-constructed Object never carries an indeterminate value
};

class FaceDetector {
public:
    FaceDetector()
        : inferencer_(tfi::TensorFlowInferencer::fromResource(TF_MODEL_RESOURCE))
    {}
    ~FaceDetector(){}

    void detect(const cv::Mat &image, std::list<Object> &objects) const {
        constexpr float NMS_THRESHOLD = 0.1f;

        REQUIRE(!image.empty(), "Unable to load image");
        std::vector<float> scales;
        calcScales(image.size(), scales);
        for (size_t i = 0; i < scales.size(); i++) {
            inference(image, scales[i], objects);
        }
        suppressObjects(objects, NMS_THRESHOLD);
    }
private:
    tfi::TensorFlowInferencer inferencer_;

    void calcScales(const cv::Size &imageSize, std::vector<float> &scales) const {
        const cv::Size INPUT_MIN_SIZES = cv::Size(82, 110);

        REQUIRE(imageSize.area() > 0, "Sizes of image must be great of zero");

        const float minScale = std::min( floor(log2f((float)INPUT_MIN_SIZES.width  / (float)imageSize.width)),
                                         floor(log2f((float)INPUT_MIN_SIZES.height / (float)imageSize.height)));

        for (float scale = minScale; scale < FLT_EPSILON; scale += 1.f){
            scales.emplace_back(powf(2.f, scale));
        }
        scales.emplace_back(sqrtf(2.f));
        scales.emplace_back(2.f);
    }

    void inference(const cv::Mat &image, const float scale, std::list<Object> &objects) const {
        const std::string TF_LAYER_IMAGE_NAME          = "inference_input";
        const std::string TF_LAYER_SMALL_SCALE_RESULTS = "small_scale_results:0";
        const std::string TF_LAYER_BIG_SCALE_RESULTS   = "big_scale_results:0";

        cv::Rect imgRect(cv::Point(0, 0), image.size());

        cv::Mat scaledImage;
        cv::resize(image, scaledImage, cv::Size(), scale, scale);

        tf::Tensor result =
            inferencer_.inference(TF_LAYER_IMAGE_NAME,
                                  scaledImage,
                                  (scale < 1.f) ? TF_LAYER_SMALL_SCALE_RESULTS : TF_LAYER_BIG_SCALE_RESULTS);

        const size_t objCnt = result.dim_size(0);
        if (0 == objCnt)
            return;
        const float *pResults = static_cast<const float*>(tf::DMAHelper::base(&result));
        REQUIRE(0 != pResults,  "Invalid results");
        for (size_t i = 0; i < objCnt; i++, pResults += 5) {
            Object object;
            object.rc = cv::Rect(pResults[0] / scale, pResults[1] / scale, (pResults[2] - pResults[0] + 1) / scale, (pResults[3] - pResults[1] + 1) / scale);
            object.rc = object.rc & imgRect;

            if (object.rc.area() <= 0)
                continue;

            object.score = pResults[4];
            objects.emplace_back(object);
        }
    }

    void suppressObjects(std::list<Object>& objects, float threshold) const {
        if (objects.empty())
            return;
        objects.sort([](const Object& a, const Object& b) {return a.score > b.score;});

        for (std::list<Object>::iterator it1 = objects.begin(); it1 != objects.end(); ++it1) {
            for (std::list<Object>::iterator it2 = std::next(it1); it2 != objects.end();) {
                const float intersection = (float)(it1->rc & it2->rc).area();
                if (intersection > threshold * (float)(it1->rc | it2->rc).area()) {
                    it2 = objects.erase(it2);
                }
                else {
                    it2++;
                }
            }
        }
    }
};

/// Thread-pool task base with a "more data may still arrive" flag.
/// The flag starts set; DataEnded() clears it and isWaitData() reports it. Both use atomic operations,
/// so the flag can be cleared from one thread and polled from another.
class ObjectInQueueWithData
    : public IObjectInQueue {
protected:
    TAtomic waitData_;

public:
    ObjectInQueueWithData() {
        AtomicSet(waitData_, 1);
    }

    /// Signals that no further input will be produced.
    void DataEnded() {
        AtomicSet(waitData_, 0);
    }

    /// True until DataEnded() has been called.
    bool isWaitData() {
        const auto flag = AtomicGet(waitData_);
        return flag != 0;
    }
};

// Formatted result lines produced by Detector tasks and written out by saveResults.
using LinesQueue = TLockFreeQueue<std::string>;

class Detector
    : public ObjectInQueueWithData {

public:
    Detector(ImageQueue* inpQueue,
             LinesQueue* outQueue,
             FaceDetector *detector)
        : inpQueue_(inpQueue)
        , outQueue_(outQueue)
        , detector_(detector)
        , running_(true) {
    }
    void Process(void* /*threadSpecificResource*/) override {
        for (;;) {
            ImageData data;
            if (inpQueue_->Dequeue(&data)) {
                std::string line = processImage(data);
                outQueue_->Enqueue(line);
            }
            else if (!isWaitData())
                break;
            else
                std::this_thread::sleep_for(THREAD_WAIT_TIMEOUT);
        }
        running_ = false;
    }
    bool isRunning() const {
        return running_;
    }
private:
    ImageQueue* inpQueue_;
    LinesQueue* outQueue_;
    FaceDetector *detector_;
    bool running_;

    std::string processImage(const ImageData &data) {
        std::stringstream result;
        result << data.featureID << " " << data.url;
        std::list<Object> objects;
        detector_->detect(data.image, objects);
        result << " " << objects.size();
        for (std::list<Object>::iterator it = objects.begin(); it != objects.end(); ++it) {
            result << " " << it->rc.x
                   << " " << it->rc.y
                   << " " << it->rc.width
                   << " " << it->rc.height;
        }
        return result.str();
    }
};

/// Drains every line currently in `linesQueue` into `outputPath`, one line per entry.
/// Several detector threads fill the queue concurrently, so lines are not necessarily in input order.
/// @throws (via REQUIRE) if the file cannot be opened or writing/flushing it fails.
void saveResults(const std::string &outputPath, LinesQueue &linesQueue) {
    std::ofstream ofs(outputPath);
    REQUIRE(ofs.is_open(), "Unable to open output file");
    std::string line;
    while (linesQueue.Dequeue(&line)) {
        ofs << line << "\n";
    }
    // Without this check a full disk or I/O error would silently produce a truncated result file.
    ofs.flush();
    REQUIRE(!ofs.fail(), "Failed to write output file");
}

/// Tells every detector that no more images will be enqueued; each one then exits its
/// Process() loop as soon as it finds the input queue empty.
void dataEnded(std::vector<TAutoPtr<Detector>> &detectors) {
    for (auto& detector : detectors)
        detector->DataEnded();
}

/// Blocks, polling every THREAD_WAIT_TIMEOUT, until the detectors have dequeued every image.
void waitQueueEmpty(ImageQueue &imgQueue) {
    while (!imgQueue.IsEmpty())
        std::this_thread::sleep_for(THREAD_WAIT_TIMEOUT);
}

/// Blocks until no detector reports isRunning(). Every round queries all detectors and then sleeps
/// THREAD_WAIT_TIMEOUT; that includes the final round, in which all of them are found finished.
void waitDetectors(const std::vector<TAutoPtr<Detector>> &detectors) {
    for (;;) {
        bool anyRunning = false;
        for (const auto& detector : detectors) {
            if (detector->isRunning())
                anyRunning = true;
        }
        std::this_thread::sleep_for(THREAD_WAIT_TIMEOUT);
        if (!anyRunning)
            break;
    }
}

} // namespace

/// Entry point: reads "<featureID> <url>" records from --input, detects faces on each downloaded image
/// using --threads worker tasks, and writes one result line per processed image to --output.
int main(int argc, const char** argv) try {
    maps::cmdline::Parser parser("Detect faces on the image");

    maps::cmdline::Option<std::string> inputPath = parser.string("input")
        .required()
        .help("Path to input file with list of featureIDs and URLs");

    maps::cmdline::Option<std::string> outputPath = parser.string("output")
        .required()
        .help("Path to output file");

    maps::cmdline::Option<int> threads = parser.num("threads")
        .defaultValue(24)
        .help("Threads count");

    parser.parse(argc, const_cast<char**>(argv));

    int threadsCount = threads;
    if (threadsCount <= 0)
        threadsCount = 1;

    TAutoPtr<IThreadPool> mtpQueue = CreateThreadPool(threadsCount);
    ImageQueue imgQueue;
    LinesQueue linesQueue;
    std::vector<TAutoPtr<Detector>> detectors;
    FaceDetector faceDetector;
    for (int i = 0; i < threadsCount; i++) {
        TAutoPtr<Detector> detector(new Detector(&imgQueue, &linesQueue, &faceDetector));
        mtpQueue->SafeAdd(detector.Get());
        detectors.push_back(detector);
    }

    try {
        downloadImagesInQueue(inputPath, imgQueue, threadsCount);
    } catch (...) {
        // The workers hold raw pointers to imgQueue, linesQueue and faceDetector, and run on Detector
        // objects owned by `detectors`; all of these are destroyed during unwinding before mtpQueue.
        // Signal end-of-data and stop the pool first, so no worker outlives what it uses
        // or keeps polling for data that will never come.
        dataEnded(detectors);
        mtpQueue->Stop();
        throw;
    }
    dataEnded(detectors);
    waitQueueEmpty(imgQueue);
    waitDetectors(detectors);
    saveResults(outputPath, linesQueue);

    mtpQueue->Stop();
    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
