#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/log8/include/log8.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <list>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace {

/// Paths read from one line of the input list: (image file, markup file).
using StringPair = std::pair<std::string, std::string>;

/// How objects of a class are rasterized by drawAnnotation.
enum class AnnType : uint8_t
{
    FINEST = 0, ///< filled polygon
    COARSE = 1, ///< filled circle at the centroid of the vertices
    WIRE   = 2, ///< polygon outline
    VERT   = 3, ///< filled square around every vertex
};

/// Rasterization settings of one markup class.
struct AnnClassInfo
{
    uint8_t label_;    ///< pixel value written into the annotation image
    AnnType annType_;  ///< rasterization mode
};

/// Class name -> rasterization settings.
using AnnClassMap = std::map<std::string, AnnClassInfo>;


/// One markup object: its class name and outline.
struct Object
{
    std::string className_;
    /// Outline vertices in image pixel coordinates. loadMarkup stores the
    /// polygon closed: the last point repeats the first one.
    std::vector<cv::Point2i> polygon_;
};

/// All objects read from one markup file.
using Markup = std::vector<Object>;

/// Outcome of matching one markup cell against the image edge mask.
struct MatchTemplateResult
{
    double val_;         ///< peak cv::TM_CCORR_NORMED score of the cell
    cv::Point2f shift_;  ///< offset of that peak from the zero-shift position
};
using MatchTemplateResults = std::vector<MatchTemplateResult>;

void loadFileList(const std::string& inpListPath,
                   std::list<StringPair>& listFiles)
{
    std::ifstream ifs(inpListPath);
    if (!ifs.is_open())
        return;

    for (; !ifs.eof();) {
        std::string line; std::getline(ifs, line);
        if (line.empty())
            continue;
        std::stringstream ss(line);
        StringPair pairFiles;
        ss >> pairFiles.first >> pairFiles.second;
        listFiles.emplace_back(pairFiles);
    }
}

void loadAnnMap(const std::string& annMapPath,
                AnnClassMap& annMap)
{
    std::ifstream ifs(annMapPath);
    if (!ifs.is_open())
        return;

    std::string className;
    AnnClassInfo info;
    for (; !ifs.eof();) {
        std::string line; std::getline(ifs, line);
        if (line.empty())
            continue;
        std::stringstream ss(line);
        int temp;
        ss >> className >> temp;
        info.label_ = (uint8_t)temp;
        ss >> temp;
        info.annType_ = (AnnType)temp;
        annMap[className] = info;
    }
}

void loadMarkup(const std::string &markup_path, Markup &markup)
{
    std::ifstream ifs(markup_path);
    if (!ifs.is_open())
        return;

    for (; !ifs.eof();) {
        std::string line; std::getline(ifs, line);
        if (line.empty())
            continue;
        std::stringstream ss(line);
        Object object; int ptsCnt;
        ss >> object.className_ >> ptsCnt;
        object.polygon_.resize(ptsCnt + 1);
        for (int i = 0; i < ptsCnt; i++) {
            ss >> object.polygon_[i].x >> object.polygon_[i].y;
        }
        object.polygon_[ptsCnt] = object.polygon_[0];
        markup.emplace_back(std::move(object));
    }
}

/// Rasterizes markup objects into a single-channel 8-bit annotation image.
///
/// ann is (re)created as a zero-filled CV_8UC1 image of imgSize. Every object
/// whose class is present in annMap is drawn with the class label as pixel
/// value, according to the class's AnnType:
///   FINEST - the polygon is filled;
///   COARSE - a filled circle of radius coarseRadius at the vertex centroid;
///   WIRE   - the polygon outline with lines of wireThickness;
///   VERT   - a filled square spanning +-verticesSize/2 around each vertex.
/// Polygons are expected to be closed (last point == first point, as stored
/// by loadMarkup), so the last point is not counted as a separate vertex.
/// Objects with fewer than two points are skipped.
void drawAnnotation(const cv::Size &imgSize, const Markup &markup, const AnnClassMap &annMap, cv::Mat &ann, int coarseRadius = 0, int verticesSize = 0, int wireThickness = 1)
{
    ann = cv::Mat::zeros(imgSize, CV_8UC1);
    for (const auto& object : markup) {
        const auto it = annMap.find(object.className_);
        if (it == annMap.end())
            continue;

        const auto& polygon = object.polygon_;
        // Guards the `size() - 1` below against unsigned underflow (empty
        // polygon) and COARSE against dividing by a zero vertex count.
        if (polygon.size() < 2)
            continue;
        const size_t vertCnt = polygon.size() - 1;

        const auto& annClassInfo = it->second;
        const cv::Scalar color = annClassInfo.label_;
        switch (annClassInfo.annType_) {
        case AnnType::FINEST: {
            const std::vector<std::vector<cv::Point2i>> polys(1, polygon);
            cv::fillPoly(ann, polys, color);
            break;
        }
        case AnnType::COARSE: {
            cv::Point2f mean(0.f, 0.f);
            for (size_t j = 0; j < vertCnt; j++)
                mean += cv::Point2f(polygon[j]);
            mean /= static_cast<float>(vertCnt);
            cv::circle(ann, mean, coarseRadius, color, -1);
            break;
        }
        case AnnType::WIRE: {
            for (size_t j = 0; j < vertCnt; j++)
                cv::line(ann, polygon[j], polygon[j + 1], color, wireThickness);
            break;
        }
        case AnnType::VERT: {
            const cv::Point half(verticesSize / 2, verticesSize / 2);
            for (size_t j = 0; j < vertCnt; j++)
                cv::rectangle(ann, polygon[j] - half, polygon[j] + half, color, -1);
            break;
        }
        }
    }
}

/// Builds a binary edge mask of a BGR image.
///
/// The image is converted to grayscale and histogram-equalized. The Sobel
/// gradient magnitude sqrt(gx^2 + gy^2) is then thresholded: pixels whose
/// magnitude is above sobelThreshold become 255, all others 0. The result is
/// written to `sobel` as CV_8U. When `sobel` already has the image size and
/// CV_8UC1 type (e.g. a ROI view of a larger buffer), convertTo reuses its
/// buffer, so the view is filled in place.
void sobelMask(const cv::Mat &img, cv::Mat &sobel, double sobelThreshold = 250.)
{
    cv::Mat gray;
    cv::cvtColor(img, gray, cv::COLOR_BGR2GRAY);
    cv::equalizeHist(gray, gray);

    cv::Mat dx;
    cv::Mat dy;
    cv::Sobel(gray, dx, CV_32F, 1, 0);
    cv::Sobel(gray, dy, CV_32F, 0, 1);

    const cv::Mat sumOfSquares = dx.mul(dx) + dy.mul(dy);
    cv::Mat edges;
    cv::sqrt(sumOfSquares, edges);
    cv::threshold(edges, edges, sobelThreshold, 255, cv::THRESH_BINARY);

    edges.convertTo(sobel, CV_8U);
}

/// Estimates, cell by cell, how the rasterized markup should be shifted to
/// best match the image's edges.
///
/// Objects whose class is in annMap are drawn with drawAnnotation's default
/// parameters, and the image is turned into an edge mask with sobelMask.
/// The image area is then split into a grid of cellSize tiles; the last
/// row/column may be smaller. For every tile that contains at least one
/// vertex of such an object, the markup tile is slid over the edge mask by
/// up to +-border pixels in each direction using cv::TM_CCORR_NORMED. The
/// best offset and its score are appended to the result of every object
/// with a vertex in that tile.
///
/// @return one entry per markup object, in markup order. Each entry lists
///         the results of the tiles touching the object, in row-major tile
///         order. Objects of classes missing from annMap get an empty entry.
std::vector<MatchTemplateResults> calcMarkupShift(const cv::Mat &img,
                     const Markup &markup,
                     const AnnClassMap &annMap,
                     int border = 50,
                     const cv::Size &cellSize = cv::Size(512, 512)
                     )
{
    std::vector<MatchTemplateResults> mtrList(markup.size());

    cv::Mat markupBW;
    drawAnnotation(img.size(), markup, annMap, markupBW);

    // Edge mask surrounded by a zero frame of `border` pixels, so a search
    // window of "tile + border on every side" always fits, even at image edges.
    cv::Mat sobelBordered = cv::Mat::zeros(img.rows + 2 * border, img.cols + 2 * border, CV_8UC1);
    cv::Mat sobelInner = sobelBordered(cv::Rect(border, border, img.cols, img.rows));
    // sobelInner already has the image size and CV_8UC1 type, so sobelMask
    // writes straight into this view of sobelBordered.
    sobelMask(img, sobelInner);

    // Only objects of mapped classes are drawn and reported. That does not
    // depend on the tile, so resolve the class lookups once.
    std::vector<size_t> mapped;
    mapped.reserve(markup.size());
    for (size_t i = 0; i < markup.size(); i++) {
        if (annMap.find(markup[i].className_) != annMap.end())
            mapped.push_back(i);
    }

    const int cellH = (markupBW.cols + cellSize.width - 1) / cellSize.width;
    const int cellV = (markupBW.rows + cellSize.height - 1) / cellSize.height;
    std::vector<size_t> touching;
    for (int cellY = 0; cellY < cellV; cellY++) {
        cv::Rect roi;
        roi.y = cellY * cellSize.height;
        roi.height = std::min((cellY + 1) * cellSize.height, markupBW.rows) - roi.y;
        for (int cellX = 0; cellX < cellH; cellX++) {
            roi.x = cellX * cellSize.width;
            roi.width = std::min((cellX + 1) * cellSize.width, markupBW.cols) - roi.x;

            // Objects with at least one vertex inside this tile.
            touching.clear();
            for (const size_t i : mapped) {
                const auto &polygon = markup[i].polygon_;
                const bool hit = std::any_of(polygon.begin(), polygon.end(),
                    [&roi](const cv::Point2i &pt) { return pt.inside(roi); });
                if (hit)
                    touching.push_back(i);
            }
            // No object would receive this tile's result: skip matchTemplate.
            if (touching.empty())
                continue;

            // In sobelBordered coordinates this window is the tile expanded
            // by `border` on every side.
            const cv::Mat sobelBorderedROI = sobelBordered(cv::Rect(roi.x, roi.y, roi.width + 2 * border, roi.height + 2 * border));
            const cv::Mat markupBWROI = markupBW(roi);

            cv::Mat mt;
            cv::matchTemplate(sobelBorderedROI, markupBWROI, mt, cv::TM_CCORR_NORMED);
            double maxVal = 0.;
            cv::Point maxLoc;
            cv::minMaxLoc(mt, nullptr, &maxVal, nullptr, &maxLoc);

            MatchTemplateResult mtr;
            mtr.val_ = maxVal;
            // maxLoc == (border, border) corresponds to no shift.
            mtr.shift_ = maxLoc - cv::Point(border, border);
            for (const size_t i : touching)
                mtrList[i].push_back(mtr);
        }
    }
    return mtrList;
}

/// Snaps the polygons of the given classes onto image edges.
///
/// The selected classes are rasterized as wires (label 255) and matched
/// against the image's edge mask cell by cell via calcMarkupShift. Each
/// matched object is then translated by the score-weighted mean of the
/// shifts reported for it, rounded to whole pixels. Objects are left
/// untouched when:
///   - their class is not listed;
///   - no results were reported for them;
///   - their total score is not above FLT_EPSILON, i.e. there is no usable
///     edge correlation and the reported shift is meaningless.
void refineMarkup(const cv::Mat &img, Markup &markup, const std::vector<std::string> &classNames)
{
    INFO() << "refine markup classes:";
    AnnClassMap annMap;
    for (const auto& clName : classNames) {
        INFO() << clName;
        auto& info = annMap[clName];
        info.label_ = 255;
        info.annType_ = AnnType::WIRE;
    }

    const std::vector<MatchTemplateResults> mtrList = calcMarkupShift(img, markup, annMap);

    for (size_t i = 0; i < markup.size(); i++) {
        auto &object = markup[i];
        if (annMap.find(object.className_) == annMap.end())
            continue;
        const MatchTemplateResults &results = mtrList[i];
        if (results.empty())
            continue;

        double sumWeights = 0.;
        for (const auto& mtr : results)
            sumWeights += mtr.val_;
        // Skip instead of asserting: an exception here would abort processing
        // of every remaining image, just because one object had no edges to match.
        if (std::fabs(sumWeights) <= FLT_EPSILON)
            continue;

        // For a single result this reduces exactly to that result's shift.
        cv::Point2f objShift(0.f, 0.f);
        for (const auto& mtr : results)
            objShift += mtr.shift_ * (mtr.val_ / sumWeights);
        const cv::Point shift(cvRound(objShift.x), cvRound(objShift.y));

        for (auto &pt : object.polygon_)
            pt += shift;
    }
}

/// Cuts `data` into a grid of cellSize tiles and appends the raw bytes of the
/// selected tiles to ofsDump.
///
/// The grid holds floor(cols / width) x floor(rows / height) tiles. It is
/// centered, so the leftover margin is split between opposite sides. Tiles
/// are visited row by row, left to right. When check_not_empty is true only
/// tiles with at least one non-zero pixel are written (cv::countNonZero then
/// needs a single-channel `data`); otherwise every tile is written. For each
/// written tile its grid position (column, row) is appended to `cells`.
void dumpAnnotationCells(const cv::Mat &data, const cv::Size &cellSize, std::ofstream &ofsDump, std::vector<cv::Point> &cells, bool check_not_empty = true)
{
    const int gridCols = data.cols / cellSize.width;
    const int gridRows = data.rows / cellSize.height;
    const int bytesPerCell = static_cast<int>(data.elemSize()) * cellSize.width * cellSize.height;
    const cv::Point origin((data.cols - gridCols * cellSize.width) / 2,
                           (data.rows - gridRows * cellSize.height) / 2);

    // Reused buffer: a tight copy of each tile so its bytes are contiguous.
    cv::Mat tile(cellSize, data.type());
    for (int r = 0; r < gridRows; ++r) {
        for (int c = 0; c < gridCols; ++c) {
            const cv::Rect area(origin.x + c * cellSize.width,
                                origin.y + r * cellSize.height,
                                cellSize.width, cellSize.height);
            data(area).copyTo(tile);
            CV_Assert(tile.isContinuous());

            const bool keep = !check_not_empty || cv::countNonZero(tile) != 0;
            if (!keep)
                continue;
            ofsDump.write(tile.ptr<char>(0), bytesPerCell);
            cells.push_back(cv::Point(c, r));
        }
    }
}

/// Equalizes the brightness of a BGR image.
///
/// The image is converted to HSV and only the V (value) channel is
/// histogram-equalized, leaving hue and saturation unchanged. The result is
/// converted back to BGR into dst.
void equalizeHistBGR(const cv::Mat &src, cv::Mat &dst)
{
    cv::Mat hsv;
    cv::cvtColor(src, hsv, cv::COLOR_BGR2HSV);

    std::vector<cv::Mat> planes;
    cv::split(hsv, planes);
    cv::Mat &value = planes[2];
    cv::equalizeHist(value, value);
    cv::merge(planes, hsv);

    cv::cvtColor(hsv, dst, cv::COLOR_HSV2BGR);
}

/// Appends to ofsDump the raw bytes of the tiles of `data` listed in `cells`.
///
/// The tile grid is the same centered grid of cellSize tiles that
/// dumpAnnotationCells uses; `cells` holds (column, row) grid positions.
/// Tiles are written in the order of `cells`. When eqHist is true each tile
/// goes through equalizeHistBGR first, so `data` must then be a BGR image.
void dumpImageCells(const cv::Mat &data, const cv::Size &cellSize, std::ofstream &ofsDump, std::vector<cv::Point> &cells, bool eqHist)
{
    const int gridCols = data.cols / cellSize.width;
    const int gridRows = data.rows / cellSize.height;
    const int bytesPerCell = static_cast<int>(data.elemSize()) * cellSize.width * cellSize.height;
    const cv::Point origin((data.cols - gridCols * cellSize.width) / 2,
                           (data.rows - gridRows * cellSize.height) / 2);

    // Reused buffer holding a contiguous copy of the current tile.
    cv::Mat tile(cellSize, data.type());
    for (const cv::Point &gridPos : cells) {
        const cv::Rect area(origin.x + gridPos.x * cellSize.width,
                            origin.y + gridPos.y * cellSize.height,
                            cellSize.width, cellSize.height);
        const cv::Mat source = data(area);
        if (eqHist) {
            equalizeHistBGR(source, tile);
        } else {
            source.copyTo(tile);
        }
        CV_Assert(tile.isContinuous());
        ofsDump.write(tile.ptr<char>(0), bytesPerCell);
    }
}

} //namespace


int main(int argc, char** argv)
try {
    maps::cmdline::Parser parser;
    auto inpListParam = parser.string("input_list")\
        .required()\
        .help("path to the text file, every line consists of:\n\t<path to image file> <path to markup file>");

    auto annMapParam = parser.string("ann_map")\
        .required()\
        .help("path to the text file with mapping class name to the 8 bit number");

    auto annCoarseRadiusParam = parser.num("ann_coarse_radius")\
        .defaultValue(3)\
        .help("radius of circle for coarse markup rasterize");

    auto annWireThicknessParam = parser.num("ann_wire_thickness")\
        .defaultValue(1)\
        .help("thickness of lines for wire markup rasterize");

    auto annVerticesSizeParam = parser.num("ann_vertex_size")\
        .defaultValue(3)\
        .help("size of square for vertices markup rasterize");

    auto imgEqHistParam = parser.flag("img_eq_hist")\
        .defaultValue(false)\
        .help("apply equalization histogram to the image");

    auto checkNotEmptyParam = parser.flag("check_not_empty")\
        .defaultValue(false)\
        .help("save only not empty cells");

    auto refineMarkupParam = parser.string("refine_markup_classes")\
        .defaultValue("")\
        .help("names of classes separated by comma for refine markup. If parameter abcents - no markup refine");

    auto outWidthParam = parser.num("output_width")\
        .required()\
        .help("width of the output image");

    auto outHeightParam = parser.num("output_height")\
        .required()\
        .help("height of the output image");

    auto outImgFileParam = parser.string("output_img_dat")\
        .defaultValue("")
        .help("path to the output file with images (can be empty, then fill will not created)");

    auto outAnnFileParam = parser.string("output_ann_dat")\
        .required()\
        .help("path to the output file with annotations");

    parser.parse(argc, argv);

    if (0 >= outWidthParam || 0 >= outHeightParam) {
        ERROR() << "Output width and height must be great than 0";
    }

    std::list<StringPair> listFiles;
    loadFileList(inpListParam, listFiles);

    std::map<std::string, AnnClassInfo> annMap;
    loadAnnMap(annMapParam, annMap);

    std::ofstream ofsImg;
    if (!((std::string)outImgFileParam).empty())
        ofsImg.open(outImgFileParam, std::ofstream::binary);
    std::ofstream ofsAnn;
    if (!((std::string)outAnnFileParam).empty())
        ofsAnn.open(outAnnFileParam, std::ofstream::binary);

    std::vector<std::string> refineMarkupClasses;
    std::stringstream ss((std::string)refineMarkupParam);
    std::string item;
    while (std::getline(ss, item, ',')) {
        refineMarkupClasses.push_back(item);
    }

    cv::Size cellSize(outWidthParam, outHeightParam);
    for (const auto& item : listFiles) {
        INFO() << "image file: " << item.first;
        cv::Mat img = cv::imread(item.first);
        if (img.empty()) {
            ERROR() << "Unable to load image from file: " << item.first;
            continue;
        }
        if (img.cols < cellSize.width || img.rows < cellSize.height) {
            ERROR() << "Size of image from file: " << item.first
                    << "is too small for output size";
            continue;
        }

        std::vector<cv::Point> cells;
        if (ofsAnn.is_open()) {
            Markup markup;
            loadMarkup(item.second, markup);

            if (!refineMarkupClasses.empty())
                refineMarkup(img, markup, refineMarkupClasses);

            cv::Mat ann;
            drawAnnotation(img.size(), markup, annMap, ann, annCoarseRadiusParam, annVerticesSizeParam, annWireThicknessParam);

            dumpAnnotationCells(ann, cellSize, ofsAnn, cells, checkNotEmptyParam);
            if (0 == cells.size())
                continue;
        }

        if (ofsImg.is_open()) {
            INFO() << "equalize histogram " << (imgEqHistParam ? "enabled" : "disabled");
            dumpImageCells(img, cellSize, ofsImg, cells, imgEqHistParam);
        }
    }
    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    INFO() << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    INFO() << e.what();
    return EXIT_FAILURE;
}
catch (...) {
    INFO() << "Caught unknown exception";
    return EXIT_FAILURE;
}
