#include <maps/wikimap/mapspro/services/mrc/libs/classifiers/include/rotation_classifier.h>

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/http/include/http.h>
#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>

#include <util/stream/input.h>
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/thread/pool.h>
#include <util/thread/lfqueue.h>


#include <opencv2/opencv.hpp>
#include <opencv2/imgcodecs/imgcodecs_c.h>

#include <atomic>
#include <chrono>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

using namespace maps::mrc::classifiers;
using namespace maps::mrc::common;

namespace {

// Polling interval used by every busy-wait loop in this tool (workers waiting
// for input, producer back-pressure, and the shutdown waits in main).
constexpr std::chrono::milliseconds THREAD_WAIT_TIMEOUT{1};

// Rotates `image` in place by 90 degrees counter-clockwise.
// A transpose followed by a flip around the x-axis (flipCode == 0) moves the
// top-right corner to the top-left, which is exactly a CCW quarter turn.
void rotateImageCCW(cv::Mat &image) {
    constexpr int FLIP_AROUND_X_AXIS = 0;
    cv::transpose(image, image);
    cv::flip(image, image, FLIP_AROUND_X_AXIS);
}

/// Downloads an image over HTTP and decodes it with OpenCV as a colour image,
/// ignoring any EXIF orientation tag (IMREAD_IGNORE_ORIENTATION) so the
/// classifier sees the pixels exactly as stored.
///
/// Up to RETRY_NUMBER attempts are made, RETRY_TIMEOUT apart. An attempt fails
/// on a transport error (maps::Exception), a non-200 HTTP status, or a body
/// that OpenCV cannot decode (e.g. a truncated transfer).
///
/// @param imageURL  URL of the image to fetch.
/// @return          The decoded, non-empty image.
/// @throws maps::RuntimeError if every attempt fails.
cv::Mat downloadImage(const std::string &imageURL) {
    constexpr int RETRY_NUMBER = 10;
    constexpr std::chrono::seconds RETRY_TIMEOUT(3);

    maps::http::Client  client;
    maps::http::URL     url(imageURL);
    maps::http::Request request(client, maps::http::GET, url);
    for (int i = 0; i < RETRY_NUMBER; i++) {
        if (0 != i)
            std::this_thread::sleep_for(RETRY_TIMEOUT);
        try {
            maps::http::Response response = request.perform();
            if (200 != response.status()) {
                WARN() << "Unexpected HTTP status " << response.status()
                       << " for " << imageURL;
                continue;
            }
            std::string data = response.readBody();
            // cv::Mat's user-data constructor takes a non-const pointer;
            // imdecode only reads from the buffer, so the const_cast is safe.
            cv::Mat image = cv::imdecode(
                cv::Mat(1, static_cast<int>(data.size()), CV_8UC1,
                        const_cast<char*>(data.data())),
                cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION);
            if (!image.empty())
                return image;
            // imdecode reports failure with an empty Mat, not an exception;
            // returning it would hand an empty image to the classifier.
            WARN() << "Failed to decode image " << imageURL;
        } catch (const maps::Exception& e) {
            WARN() << e;
        }
    }
    throw maps::RuntimeError("Failed to download image");
}

// One unit of work handed from the downloader to the classifier threads.
struct ImageData {
    std::string url{};  // source URL; echoed as the first field of the output line
    cv::Mat image{};    // decoded pixels of that URL
};

// Per-item counter hook for TLockFreeQueue (IncCount/DecCount). The producer
// reads Counter through the queue's GetCounter() to bound how many decoded
// images are waiting in memory.
struct ImageDataCounter {
    TAtomic Counter{0};

    void IncCount(const ImageData& /*item*/) {
        AtomicIncrement(Counter);
    }

    void DecCount(const ImageData& /*item*/) {
        AtomicDecrement(Counter);
    }
};

// Downloaded images waiting to be classified; length tracked by ImageDataCounter.
using ImageQueue = TLockFreeQueue<ImageData, ImageDataCounter>;
// Finished result lines waiting to be written to the output file.
using LinesQueue = TLockFreeQueue<std::string>;

class ObjectInQueueWithData
    : public IObjectInQueue {
public:
    ObjectInQueueWithData() {
        AtomicSet(waitData_, 1);
    }
    void DataEnded() {
        AtomicSet(waitData_, 0);
    }
    bool isWaitData() {
        return (0 != AtomicGet(waitData_));
    }
protected:
    TAtomic waitData_;
};

class Classifier
    : public ObjectInQueueWithData {

public:
    Classifier(ImageQueue* inpQueue,
               LinesQueue* outQueue,
               RotationClassifier *classifier)
        : inpQueue_(inpQueue)
        , outQueue_(outQueue)
        , classifier_(classifier)
        , running_(true) {
    }
    void Process(void* /*threadSpecificResource*/) override {
        for (;;) {
            ImageData data;
            if (inpQueue_->Dequeue(&data)) {
                std::string line = processImage(data.url, data.image);
                outQueue_->Enqueue(line);
            }
            else if (!isWaitData())
                break;
            else
                std::this_thread::sleep_for(THREAD_WAIT_TIMEOUT);
        }
        running_ = false;
    }
    bool isRunning() const {
        return running_;
    }
private:
    ImageQueue* inpQueue_;
    LinesQueue* outQueue_;
    RotationClassifier *classifier_;
    bool running_;

    std::string processImage(const std::string &url, cv::Mat &image) {
        std::stringstream result;
        result << url;
        for (int i = 0; i < 4; i++) {
            ImageOrientation orientation = classifier_->detectImageOrientation(image);
            result << " " << (int)orientation.rotation() / 90;
            if (3 == i)
                break;
            rotateImageCCW(image);
        }
        return result.str();
    }
};

/// Reads image URLs from `inputPath` and feeds the decoded images to `imgQueue`.
///
/// The input has one URL per line. Blank lines and lines starting with '#' are
/// skipped, and a trailing '\r' (CRLF files) is stripped. After each push the
/// function blocks until the queue holds fewer than `prefetch` images (at least
/// 1), so downloading never runs far ahead of classification.
///
/// @throws if the input file cannot be opened (REQUIRE) or an image cannot be
///         downloaded (maps::RuntimeError from downloadImage).
void downloadImagesInQueue(const std::string &inputPath, ImageQueue &imgQueue, int prefetch) {
    std::ifstream ifs(inputPath);
    REQUIRE(ifs.is_open(), "Unable to open input file");
    // A non-positive limit would make the back-pressure wait below spin forever.
    const int maxQueued = prefetch > 0 ? prefetch : 1;
    int processItems = 0;
    std::string inpline;
    // Loop on getline's result rather than eof(): a read error sets failbit
    // without eofbit, which made a `while (!eof())` loop spin forever.
    while (std::getline(ifs, inpline)) {
        if (!inpline.empty() && inpline.back() == '\r')
            inpline.pop_back();
        if (inpline.empty() || inpline[0] == '#')
            continue;
        ImageData data;
        data.url = inpline;
        data.image = downloadImage(inpline);
        imgQueue.Enqueue(std::move(data));
        processItems++;
        if (0 == (processItems % 1000))
            INFO() << "Process: " << processItems << " items";
        // Back-pressure: wait for the classifiers to catch up.
        for (;;) {
            ImageDataCounter counter = imgQueue.GetCounter();
            if (maxQueued > static_cast<int>(AtomicGet(counter.Counter)))
                break;
            std::this_thread::sleep_for(THREAD_WAIT_TIMEOUT);
        }
    }
}

// Tells every worker that no more images will be enqueued.
void dataEnded(std::vector<TAutoPtr<Classifier>> &classifiers) {
    for (auto& worker : classifiers)
        worker->DataEnded();
}

// Blocks, polling every THREAD_WAIT_TIMEOUT, until imgQueue reports empty.
void waitQueueEmpty(ImageQueue &imgQueue) {
    while (!imgQueue.IsEmpty())
        std::this_thread::sleep_for(THREAD_WAIT_TIMEOUT);
}

// Blocks until no worker reports isRunning(). Sleeps THREAD_WAIT_TIMEOUT
// after every pass, including the final one that finds all workers stopped.
void waitClassifiers(const std::vector<TAutoPtr<Classifier>> &classifiers) {
    bool anyRunning = false;
    do {
        anyRunning = false;
        for (const auto& worker : classifiers) {
            if (worker->isRunning())
                anyRunning = true;
        }
        std::this_thread::sleep_for(THREAD_WAIT_TIMEOUT);
    } while (anyRunning);
}

/// Drains `linesQueue` into `outputPath`, one entry per line.
///
/// Lines appear in dequeue order. This need not match the input order because
/// several classifier threads feed the queue concurrently.
///
/// @throws (via REQUIRE) if the file cannot be opened or the write fails.
void saveResults(const std::string &outputPath, LinesQueue &linesQueue) {
    std::ofstream ofs(outputPath);
    REQUIRE(ofs.is_open(), "Unable to open output file");
    std::string line;
    while (linesQueue.Dequeue(&line)) {
        ofs << line << '\n';
    }
    // Without this check, a full disk or I/O error would silently leave a
    // truncated result file while the program still reported success.
    ofs.flush();
    REQUIRE(ofs.good(), "Failed to write output file");
}

}// namespace

int main(int argc, const char** argv) try {
    maps::cmdline::Parser parser("Detect traffic signs on the images from tfrecord");

    maps::cmdline::Option<std::string> inputPath = parser.string("input")
        .required()
        .help("Path to input file with list of URLs");

    maps::cmdline::Option<std::string> outputPath = parser.string("output")
        .required()
        .help("Path to output file");

    maps::cmdline::Option<int> threads = parser.num("threads")
        .defaultValue(24)
        .help("Threads count");

    parser.parse(argc, const_cast<char**>(argv));

    int threadsCount = threads;
    if (threadsCount <= 0)
        threadsCount = 1;


    TAutoPtr<IThreadPool> mtpQueue = CreateThreadPool(threadsCount);
    ImageQueue imgQueue;
    LinesQueue linesQueue;
    std::vector<TAutoPtr<Classifier>> classifiers;
    RotationClassifier rotClassifier;
    for (int i = 0; i < threadsCount; i++) {
        TAutoPtr<Classifier> classifier(new Classifier(&imgQueue, &linesQueue, &rotClassifier));
        mtpQueue->SafeAdd(classifier.Get());
        classifiers.push_back(classifier);
    }

    downloadImagesInQueue(inputPath, imgQueue, threadsCount);
    dataEnded(classifiers);
    waitQueueEmpty(imgQueue);
    waitClassifiers(classifiers);
    saveResults(outputPath, linesQueue);

    mtpQueue->Stop();
    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    FATAL() << "Application failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Application failed: " << e.what();
    return EXIT_FAILURE;
}
