#include "tool.h"
#include "types.h"
#include "request.h"

#include <maps/libs/common/include/exception.h>
#include <maps/libs/geolib/include/intersection.h>
#include <maps/libs/geolib/include/contains.h>
#include <maps/libs/json/include/builder.h>
#include <maps/libs/log8/include/log8.h>

#include <boost/range/algorithm_ext/erase.hpp>

#include <cmath>
#include <fstream>

namespace maps {
namespace mrc {
namespace toloka {

namespace {

// Keys of the JSON document produced from merged task results.
const std::string FIELD_BBOX{"bbox"};
const std::string FIELD_DETECTION_TASK_RESULTS{"detection_task_results"};
const std::string FIELD_SOLUTIONS{"solutions"};
const std::string FIELD_SOURCE{"source"};

// A rectangle tagged with the ID of the user who drew it
// (first: user ID, second: the rectangle itself).
using UserRectangle = std::pair<std::string, Rectangle>;
// Rectangles, drawn by different users, that are considered to mark
// the same object.
using RectanglesGroup = std::vector<UserRectangle>;

// Pixel size limits for a rectangle to be "significant" (see isSignificant):
// its longer side must be at least MAX_SIDE_SIZE_THRESHOLD_PX long ...
constexpr size_t MAX_SIDE_SIZE_THRESHOLD_PX{15};
// ... and its shorter side at least MIN_SIDE_SIZE_THRESHOLD_PX long.
constexpr size_t MIN_SIDE_SIZE_THRESHOLD_PX{10};

/**
 * Tells whether a rectangle is big enough to be counted as a sign:
 * its longer side must be at least MAX_SIDE_SIZE_THRESHOLD_PX pixels and
 * its shorter side at least MIN_SIDE_SIZE_THRESHOLD_PX pixels.
 */
bool isSignificant(const Rectangle& r)
{
    const auto longerSide = std::max(r.width(), r.height());
    if (longerSide < MAX_SIDE_SIZE_THRESHOLD_PX) {
        return false;
    }
    const auto shorterSide = std::min(r.width(), r.height());
    return shorterSide >= MIN_SIDE_SIZE_THRESHOLD_PX;
}

/**
 * Match given rectangle with a group of rectangles.
 * Returns true if \p rect intersects every rectangle in the \p group and,
 * against each of them, no side (minX, maxX, minY, maxY) is displaced by
 * more than THRESH_DIFF_PIXELS pixels AND by more than
 * MAX_RELATIVE_COORD_DIFF (30%) of the larger of the two rectangles'
 * extents along that axis. An empty group matches any rectangle.
 */
bool rectangleMatchesGroup(const Rectangle& rect,
                           const RectanglesGroup& group)
{
    // Side offsets up to this many pixels are always tolerated
    constexpr int THRESH_DIFF_PIXELS = 8;
    // Larger offsets are tolerated up to this fraction of the extent
    constexpr double MAX_RELATIVE_COORD_DIFF = 0.3;

    // A side offset breaks the match only if it exceeds both tolerances
    const auto isTooFar = [&](double diff, double extent) {
        return diff > THRESH_DIFF_PIXELS &&
               diff / extent > MAX_RELATIVE_COORD_DIFF;
    };

    for (const auto& userAndRect : group) {
        const auto& other = userAndRect.second;
        if (!geolib3::intersection(rect, other)) {
            return false;
        }

        const double width = std::max(rect.maxX() - rect.minX(),
                                      other.maxX() - other.minX());
        const double height = std::max(rect.maxY() - rect.minY(),
                                       other.maxY() - other.minY());

        // std::abs, not unqualified abs: the latter may resolve to the
        // C abs(int) and silently truncate fractional coordinate differences.
        if (isTooFar(std::abs(rect.minX() - other.minX()), width) ||
            isTooFar(std::abs(rect.maxX() - other.maxX()), width) ||
            isTooFar(std::abs(rect.minY() - other.minY()), height) ||
            isTooFar(std::abs(rect.maxY() - other.maxY()), height)) {
            return false;
        }
    }
    return true;
}


/**
 * Tells if the given group contains a rectangle, selected by the given user.
 */
bool groupHasRectangleFromUser(const RectanglesGroup& group,
                               const std::string& uid)
{
    return std::find_if(
            group.begin(),
            group.end(),
            [&](const UserRectangle& pair) { return pair.first == uid; })
        != group.end();
}


/**
 * Removes, in place, every group holding fewer than \p majority rectangles.
 * Groups built by mergeSingleTaskResults hold at most one rectangle per
 * user, so the survivors are the groups confirmed by at least
 * \p majority users.
 */
void retainGroupsWithMajority(std::vector<RectanglesGroup>& groups,
                              size_t majority)
{
    const auto lacksMajority = [majority](const RectanglesGroup& group) {
        return group.size() < majority;
    };
    boost::range::remove_erase_if(groups, lacksMajority);
}


/**
 * Collapses a group of rectangles into a single rectangle.
 * Each of the four coordinates (minX, minY, maxX, maxY) of the result is
 * the median of that coordinate over all rectangles in the group. For an
 * even-sized group, the upper of the two middle values is used.
 * Throws (via REQUIRE) if the group is empty.
 */
Rectangle mergeRectanglesGroup(const RectanglesGroup& group)
{
    REQUIRE(!group.empty(), "Empty rectangles group");

    // Element at index size/2 of the sorted sequence; reorders `values`.
    const auto upperMedian = [](std::vector<double>& values) {
        const auto middle = values.begin() + values.size() / 2;
        std::nth_element(values.begin(), middle, values.end());
        return *middle;
    };

    std::vector<double> minXs;
    std::vector<double> maxXs;
    std::vector<double> minYs;
    std::vector<double> maxYs;
    minXs.reserve(group.size());
    maxXs.reserve(group.size());
    minYs.reserve(group.size());
    maxYs.reserve(group.size());

    for (const auto& userAndRect : group) {
        const Rectangle& box = userAndRect.second;
        minXs.push_back(box.minX());
        maxXs.push_back(box.maxX());
        minYs.push_back(box.minY());
        maxYs.push_back(box.maxY());
    }

    const geolib3::Point2 lowerCorner(upperMedian(minXs), upperMedian(minYs));
    const geolib3::Point2 upperCorner(upperMedian(maxXs), upperMedian(maxYs));
    return Rectangle(lowerCorner, upperCorner);
}

/**
 * Collapses every group into one rectangle using mergeRectanglesGroup().
 * This function does no filtering. Unconfirmed groups must already have
 * been dropped by the caller (e.g. with retainGroupsWithMajority()).
 * Returns one rectangle per group, in the order of \p groups.
 * Throws (via REQUIRE in mergeRectanglesGroup) if any group is empty.
 */
Rectangles mergeRectanglesGroups(const std::vector<RectanglesGroup>& groups)
{
    Rectangles result;
    result.reserve(groups.size());
    for (const auto& group : groups) {
        result.push_back(mergeRectanglesGroup(group));
    }
    return result;
}


/**
 * Credits users whose rectangles ended up in confirmed groups.
 * For every group whose merged rectangle is significant, each contributing
 * user gets one true positive. If that user's own rectangle is not
 * significant, it was not counted in submittedSignsCount when it was
 * submitted, so it is counted here to match the credited true positive.
 * Throws std::out_of_range if a user in a group has no entry in
 * \p userIdToStat.
 */
void updateUsersStatistics(UserIdToStat& userIdToStat,
                           const std::vector<RectanglesGroup>& rectangleGroups)
{
    for (const auto& group : rectangleGroups) {
        if (!isSignificant(mergeRectanglesGroup(group))) {
            continue;
        }
        for (const auto& userAndRect : group) {
            auto& stat = userIdToStat.at(userAndRect.first);
            ++stat.truePositives;
            if (!isSignificant(userAndRect.second)) {
                ++stat.submittedSignsCount;
            }
        }
    }
}


/**
 * Merge Toloka results for a single task received from multiple users.
 *
 * Rectangles are grouped greedily. Each rectangle joins the first existing
 * group that has no rectangle from the same user yet and that it matches
 * (see rectangleMatchesGroup()). Otherwise it starts a new group. Only
 * groups backed by a strict majority of the assignments are kept, and each
 * of them is collapsed into one output rectangle.
 *
 * \param results       all assignments of the task suite
 * \param taskIdx       index of the task in every assignment's outputs
 * \param userIdToStat  per-user statistics. An entry is created the first
 *                      time a user is seen and is updated with this task's
 *                      counts.
 * \returns the merged rectangles for the task
 * Throws (via REQUIRE) if some assignment has no output for \p taskIdx.
 */
TaskOutput mergeSingleTaskResults(const AssignmentResults& results,
                                  size_t taskIdx,
                                  UserIdToStat& userIdToStat)
{
    std::vector<RectanglesGroup> rectangleGroups;

    for (const auto& result : results) {
        REQUIRE(result.outputs.size() > taskIdx, "No result for task #" << taskIdx);
        const auto& uid = result.userId;
        const auto& rectangles = result.outputs[taskIdx].rectangles;

        if (!userIdToStat.count(uid)) {
            userIdToStat.emplace(uid, UserStat(uid, result.assignmentId, result.status));
        }
        // Only significant rectangles count as submitted signs here.
        // updateUsersStatistics() adds the insignificant ones that end up
        // in confirmed groups.
        userIdToStat.at(uid).submittedSignsCount +=
            std::count_if(rectangles.begin(), rectangles.end(), isSignificant);

        for (const auto& rect : rectangles) {
            // A group holds at most one rectangle per user. The cheaper
            // ownership check runs first; both predicates are side-effect free.
            auto groupIt = std::find_if(
                rectangleGroups.begin(),
                rectangleGroups.end(),
                [&](const RectanglesGroup& group) {
                    return !groupHasRectangleFromUser(group, uid) &&
                           rectangleMatchesGroup(rect, group);
                });
            if (groupIt != rectangleGroups.end()) {
                groupIt->push_back({uid, rect});
            } else {
                // No suitable group found: this rectangle starts a new one
                rectangleGroups.push_back({{uid, rect}});
            }
        }
    }

    // Keep only groups confirmed by a strict majority of the assignments
    retainGroupsWithMajority(rectangleGroups, results.size() / 2 + 1);
    updateUsersStatistics(userIdToStat, rectangleGroups);
    return TaskOutput{mergeRectanglesGroups(rectangleGroups)};
}


/**
 * Merges the results of a whole Toloka task suite received from multiple
 * users. Tasks are merged one by one, and task inputs are taken from the
 * first assignment. Per-user statistics accumulate over all tasks of the
 * suite. detectedSignsCount counts the significant merged rectangles.
 * Throws (via REQUIRE) if \p assignmentResults is empty.
 */
TaskSuiteResult mergeTaskSuiteResults(const AssignmentResults& assignmentResults)
{
    TaskSuiteResult merged;

    REQUIRE(!assignmentResults.empty(), "Empty task suite results");
    const auto& referenceInputs = assignmentResults[0].inputs;
    const size_t taskCount = referenceInputs.size();

    for (size_t taskIdx = 0; taskIdx != taskCount; ++taskIdx) {
        TaskResult task;
        task.input = referenceInputs[taskIdx];
        task.output = mergeSingleTaskResults(
                assignmentResults, taskIdx, merged.userIdToStat);

        const auto& rects = task.output.rectangles;
        DEBUG() << "Task " << taskIdx << " rectangles: " << rects;
        merged.detectedSignsCount +=
                std::count_if(rects.begin(), rects.end(), isSignificant);

        merged.taskResults.push_back(std::move(task));
    }

    return merged;
}


void toJson(const Rectangles& rectangles, json::ArrayBuilder b)
{
    for (const auto& r : rectangles) {
        b << [&](json::ObjectBuilder b) {
            b[FIELD_BBOX] << [&](json::ArrayBuilder b) {
                b << [&](json::ArrayBuilder b) {
                    b << round(r.minX()) << round(r.minY());
                };
                b << [&](json::ArrayBuilder b) {
                    b << round(r.maxX()) << round(r.maxY());
                };
            };
        };
    }
}

} // anonymous namespace


std::vector<TaskSuiteResult> mergeTasksResults(
        const IdToTaskSuite& idToTaskSuite,
        const TaskSuiteIdToResults& idToAssignmentResults)
{
    std::vector<TaskSuiteResult> tsResults;

    for (const auto& idAndResults : idToAssignmentResults) {
        const auto& id = idAndResults.first;
        const auto& assignmentResults = idAndResults.second;

        auto taskSuiteItr = idToTaskSuite.find(id);
        if (taskSuiteItr == idToTaskSuite.end()) {
            // Skip results for unknown task suite
            WARN() << "Skip unknown task suite: " << id;
            continue;
        }

        if (assignmentResults.size() < taskSuiteItr->second.overlap) {
            // Skip task suite if not enough results are ready
            continue;
        }

        INFO() << "Merging results for task suite " << id;
        auto tsResult = mergeTaskSuiteResults(assignmentResults);
        tsResults.push_back(std::move(tsResult));
    }
    return tsResults;
}


/**
 * Loads the task suites of pool \p poolId and their results through
 * \p tolokaClient, then merges them with mergeTasksResults().
 */
std::vector<TaskSuiteResult>
loadAndMergeTasksResults(const io::TolokaClient& tolokaClient,
                         const std::string& poolId)
{
    // Separate statements keep the two loads in a fixed order
    // (function-argument evaluation order would be unspecified).
    const auto taskSuites = loadTaskSuites(tolokaClient, poolId);
    const auto taskSuitesResults = loadTaskSuitesResults(tolokaClient, poolId);
    return mergeTasksResults(taskSuites, taskSuitesResults);
}


/**
 * Accepts or rejects every assignment of a merged task suite.
 *
 * For each user, recall = truePositives / detectedSignsCount and
 * precision = truePositives / submittedSignsCount. Special cases when a
 * count is zero:
 *  - nothing detected, nothing submitted: recall = precision = 1;
 *  - nothing detected, N signs submitted: recall = 1,
 *    precision = max(1 - N * EXTRA_SIGN_PENALTY, 0);
 *  - signs detected, nothing submitted: recall = 0, precision = 1.
 * Low recall is checked first, then low precision. Otherwise the
 * assignment is accepted.
 *
 * \param tolokaClient  client used to accept or reject assignments
 * \param tsResult      merged task suite result with per-user statistics
 * \param dryRun        if true, decisions are only logged, nothing is sent
 */
void evaluateAssignments(const io::TolokaClient& tolokaClient,
                         const TaskSuiteResult& tsResult,
                         bool dryRun)
{
    constexpr double THRESHOLD_RECALL = 0.65;
    constexpr double THRESHOLD_PRECISION = 0.6;
    // Penalty for one submitted sign when there are no detected signs
    constexpr double EXTRA_SIGN_PENALTY = 0.2;

    for (const auto& userAndStat : tsResult.userIdToStat) {
        const auto& stat = userAndStat.second;

        double recall = 0, precision = 0;
        if (!tsResult.detectedSignsCount) {
            if (!stat.submittedSignsCount) {
                // No signs detected and no signs submitted by user
                recall = 1;
                precision = 1;
            } else {
                // No signs detected, but some signs submitted by user
                recall = 1;
                precision = std::max(1 - stat.submittedSignsCount * EXTRA_SIGN_PENALTY, 0.0);
            }
        } else {
            if (!stat.submittedSignsCount) {
                // Some signs detected, but no signs submitted by user
                recall = 0;
                precision = 1;
            } else {
                // Some signs detected and some signs submitted by user
                recall = (double)stat.truePositives / tsResult.detectedSignsCount;
                precision = (double)stat.truePositives / stat.submittedSignsCount;
            }
        }
        DEBUG() << "Assignment " << stat.assignmentId
                << ": recall = " << recall << ", precision = " << precision;

        if (recall < THRESHOLD_RECALL) {
            INFO() << "REJECT assignment " << stat.assignmentId
                   << " from user " << stat.uid << ". Low recall: " << recall;
            if (!dryRun) {
                // "You did not mark all the signs present in the photos
                //  or violated the markup rules described in the instructions"
                tolokaClient.rejectAssignment(stat.assignmentId,
                    "Вы разметили не все знаки, которые есть на фотографиях "
                    "либо нарушили правила разметки, описанные в инструкции");
            }
        } else if (precision < THRESHOLD_PRECISION) {
            INFO() << "REJECT assignment " << stat.assignmentId
                   << " from user " << stat.uid << ". Low precision: " << precision;
            if (!dryRun) {
                // "You mistakenly outlined areas that contain no signs
                //  or violated the markup rules described in the instructions".
                // The trailing space matters: adjacent string literals are
                // concatenated with no separator.
                tolokaClient.rejectAssignment(stat.assignmentId,
                    "Вы ошибочно обвели области, в которых нет знаков "
                    "либо нарушили правила разметки, описанные в инструкции");
            }
        } else {
            INFO() << "ACCEPT assignment " << stat.assignmentId
                   << " from user " << stat.uid;
            if (!dryRun) {
                tolokaClient.acceptAssignment(stat.assignmentId);
            }
        }
    }
}


void writeTaskResults(const std::vector<TaskSuiteResult>& tsResults,
                      const std::string& outputFile)
{
    std::ofstream file(outputFile);
    json::Builder builder(file);

    builder << [&](json::ObjectBuilder b) {
        b[FIELD_DETECTION_TASK_RESULTS] << [&](json::ArrayBuilder b) {
            for (const auto& suite : tsResults) {
                for (const auto& res : suite.taskResults) {
                    b << [&](json::ObjectBuilder b) {
                        b[FIELD_SOURCE] = res.input.source;
                        b[FIELD_SOLUTIONS] << [&](json::ArrayBuilder b) {
                            toJson(res.output.rectangles, b);
                        };
                    };
                }
            }
        };
    };
}

} // toloka
} // mrc
} // maps
