#include "tool.h"
#include "types.h"
#include "request.h"

#include <maps/libs/common/include/exception.h>
#include <maps/libs/json/include/builder.h>
#include <maps/libs/log8/include/log8.h>

#include <fstream>
#include <stdexcept>

namespace maps {
namespace mrc {
namespace toloka {

namespace {

// Names of the JSON fields emitted by toJson() and writeTaskResults()
const std::string FIELD_ANSWER = "answer";
const std::string FIELD_BBOX = "bbox";
const std::string FIELD_CLASSIFIED_SIGNS = "classified_signs";
const std::string FIELD_SIGN_ID = "sign_id";
const std::string FIELD_SIGNS = "signs";
const std::string FIELD_SOURCE = "source";


// Find the longest run of equivalent elements in [first, last) and return it
// as an iterator pair [runFirst, runLast). The range must be sorted, because
// runs are located with std::equal_range. When several runs share the maximum
// length the earliest one is returned; for an empty range both iterators
// equal `last`.
template <typename Iterator>
std::pair<Iterator, Iterator> longestEqualRange(Iterator first, Iterator last)
{
    std::pair<Iterator, Iterator> best{last, last};
    size_t bestLength = 0;

    for (auto cursor = first; cursor != last; ) {
        const auto run = std::equal_range(cursor, last, *cursor);
        const size_t runLength = std::distance(run.first, run.second);
        // Strict comparison keeps the first of equally long runs
        if (runLength > bestLength) {
            bestLength = runLength;
            best = run;
        }
        // Continue right after the run that was just measured
        cursor = run.second;
    }
    return best;
}


/**
 * Merge results for a single task received from multiple users.
 *
 * The merged answer is the output given by a strict majority (more than
 * half) of the users. If no output reaches a majority, the function returns
 * TaskOutput{TaskAnswer::NotRecognized, ""}.
 *
 * When a majority exists, every user's entry in userIdToStat is created on
 * first sight (from that user's assignment result) and its correctCount is
 * incremented if the user's output equals the majority output or falls under
 * one of the tolerated mismatches handled below.
 *
 * @param results       assignment results of all users for one task suite
 * @param taskIdx       index of the task in each result's outputs
 * @param userIdToStat  per-user statistics, updated in place
 * @return the merged output for the task
 *
 * The outputs are read with at(), so a result that has no output at taskIdx
 * is rejected by the container instead of being read out of bounds.
 */
TaskOutput mergeSingleTaskResults(const AssignmentResults& results,
                                  size_t taskIdx,
                                  UserIdToStat& userIdToStat)
{
    std::vector<TaskOutput> outputs;
    outputs.reserve(results.size());
    for (const auto& res : results) {
        // Bounds-checked: results may not all carry the same number of outputs
        outputs.push_back(res.outputs.at(taskIdx));
    }

    // longestEqualRange needs sorted input to find runs of equal outputs
    std::sort(outputs.begin(), outputs.end());
    const auto majorityRange = longestEqualRange(outputs.begin(), outputs.end());
    const auto majoritySize = static_cast<size_t>(
        std::distance(majorityRange.first, majorityRange.second));
    if (majoritySize <= outputs.size() / 2) {
        // NOTE(review): returning here skips the statistics update below, so
        // for a task without a majority no user gets a stat entry or credit.
        // Confirm that this is intended.
        return TaskOutput{TaskAnswer::NotRecognized, ""};
    }
    TaskOutput majorityOutput = *majorityRange.first;

    for (const auto& res : results) {
        // Single lookup; the stat is constructed only for a new user
        auto statIt = userIdToStat.find(res.userId);
        if (statIt == userIdToStat.end()) {
            statIt = userIdToStat.emplace(res.userId, UserStat(res)).first;
        }
        auto& stat = statIt->second;
        // Index was validated by at() in the first loop
        const auto& userOutput = res.outputs[taskIdx];

        if (userOutput == majorityOutput) {
            ++stat.correctCount;
            continue;
        }

        // ad hoc handling of some special cases: these disagreements with
        // the majority are still counted as correct
        const bool toleratedAgainstNotRecognized =
            majorityOutput.answer == TaskAnswer::NotRecognized &&
            (userOutput.answer == TaskAnswer::Ok ||
             userOutput.answer == TaskAnswer::NotSign ||
             userOutput.answer == TaskAnswer::NotClassified);
        const bool toleratedAgainstNotSign =
            majorityOutput.answer == TaskAnswer::NotSign &&
            userOutput.answer == TaskAnswer::NotRecognized;

        if (toleratedAgainstNotRecognized || toleratedAgainstNotSign) {
            ++stat.correctCount;
        }
    }
    return majorityOutput;
}


/**
 * Merge the results that several users submitted for one task suite.
 *
 * Task inputs are taken from the first assignment result; the output of each
 * task is produced by mergeSingleTaskResults, which also accumulates the
 * per-user statistics stored in the returned TaskSuiteResult.
 *
 * Fails the REQUIRE check if assignmentResults is empty.
 */
TaskSuiteResult mergeTaskSuiteResults(const AssignmentResults& assignmentResults)
{
    REQUIRE(!assignmentResults.empty(), "Empty task suite results");

    const auto& referenceInputs = assignmentResults[0].inputs;
    const size_t taskTotal = referenceInputs.size();

    TaskSuiteResult merged;
    for (size_t taskIdx = 0; taskIdx != taskTotal; ++taskIdx) {
        auto taskInput = referenceInputs[taskIdx];
        auto mergedOutput = mergeSingleTaskResults(
            assignmentResults, taskIdx, merged.userIdToStat);
        DEBUG() << "Task " << taskIdx << " sign ID: " << mergedOutput.signId;
        merged.taskResults.push_back(
            TaskResult{std::move(taskInput), std::move(mergedOutput)});
    }

    return merged;
}


void toJson(const Bbox& bbox, json::ObjectBuilder b)
{
    b[FIELD_BBOX] = [&](json::ArrayBuilder b) {
        b << [&](json::ArrayBuilder b) {
            b << round(bbox.minX()) << round(bbox.minY());
        };
        b << [&](json::ArrayBuilder b) {
            b << round(bbox.maxX()) << round(bbox.maxY());
        };
    };
}

} // anonymous namespace


std::vector<TaskSuiteResult> mergeTasksResults(
        const IdToTaskSuite& idToTaskSuite,
        const TaskSuiteIdToResults& idToAssignmentResults)
{
    std::vector<TaskSuiteResult> tsResults;
    tsResults.reserve(idToTaskSuite.size());

    for (auto& idAndResults : idToAssignmentResults) {
        const auto& id = idAndResults.first;
        const auto& assignmentResults = idAndResults.second;

        auto taskSuiteItr = idToTaskSuite.find(id);
        if (taskSuiteItr == idToTaskSuite.end()) {
            // Skip results for unknown task suite
            WARN() << "Skip unknown task suite: " << id;
            continue;
        }

        // This can only happen if running the tool on an unfinished pool,
        // which we normally should NOT do.
        if (assignmentResults.size() < taskSuiteItr->second.overlap) {
            INFO() << "Results not ready for task suite " << id;
            continue;
        }
        INFO() << "Merge results for task suite " << id;
        tsResults.push_back(mergeTaskSuiteResults(assignmentResults));
    }
    return tsResults;
}


// Load the task suites of the given pool and their results through the
// Toloka client (in that order), then merge them with mergeTasksResults.
std::vector<TaskSuiteResult>
loadAndMergeTasksResults(const io::TolokaClient& tolokaClient,
                         const std::string& poolId)
{
    const auto taskSuitesById = loadTaskSuites(tolokaClient, poolId);
    const auto resultsBySuiteId = loadTaskSuitesResults(tolokaClient, poolId);

    return mergeTasksResults(taskSuitesById, resultsBySuiteId);
}


// Accept or reject every assignment of the task suite that is still in the
// Submitted state, based on the user's precision (correctCount / tasksCount).
// Assignments with precision below the threshold are rejected, all others are
// accepted. With dryRun set, decisions are only logged and the Toloka client
// is not called.
void evaluateAssignments(const io::TolokaClient& tolokaClient,
                         const TaskSuiteResult& tsResult,
                         bool dryRun)
{
    constexpr double THRESHOLD_PRECISION = 0.8;

    for (const auto& entry : tsResult.userIdToStat) {
        const auto& userStat = entry.second;

        // Only assignments awaiting a decision are evaluated
        if (userStat.assignmentStatus != io::AssignmentStatus::Submitted) {
            continue;
        }

        const auto precision =
            static_cast<double>(userStat.correctCount) / userStat.tasksCount;
        // Deliberately expressed as "not below the threshold" so that any
        // value that does not compare less than the threshold is accepted
        const bool accepted = !(precision < THRESHOLD_PRECISION);

        if (accepted) {
            INFO() << "ACCEPT assignment " << userStat.assignmentId
                   << " from user " << userStat.userId;
            if (!dryRun) {
                tolokaClient.acceptAssignment(userStat.assignmentId);
            }
        } else {
            INFO() << "REJECT assignment " << userStat.assignmentId
                   << " from user " << userStat.userId << ", precision " << precision;
            if (!dryRun) {
                tolokaClient.rejectAssignment(userStat.assignmentId,
                    "Некоторые знаки классифицированы неверно");
            }
        }
    }
}


void writeTaskResults(const std::vector<TaskSuiteResult>& tsResults,
                      const std::string& outputFile)
{
    std::unordered_map<std::string, std::vector<TaskResult>> resultMap;

    for (const auto& tsResult : tsResults) {
        for (const auto& taskResult : tsResult.taskResults) {
            resultMap[taskResult.input.source].push_back(taskResult);
        }
    }

    std::ofstream file(outputFile);
    json::Builder builder(file);

    builder << [&](json::ObjectBuilder b) {
        b[FIELD_CLASSIFIED_SIGNS] << [&](json::ArrayBuilder b) {
            for (const auto& item : resultMap) {
                b << [&](json::ObjectBuilder b) {
                    b[FIELD_SOURCE] = item.first;
                    b[FIELD_SIGNS] = [&](json::ArrayBuilder b) {
                        for (const auto& sign : item.second) {
                            b << [&](json::ObjectBuilder b) {
                                toJson(sign.input.bbox, b);
                                b[FIELD_ANSWER] = toString(sign.output.answer);
                                b[FIELD_SIGN_ID] = sign.output.signId;
                            };
                        }
                    };
                };
            }
        };
    };
}

} // toloka
} // mrc
} // maps
