#include "fixture.h"

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/toloka/mds_file_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/toloka/task.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/toloka/task_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/toloka/task_type_info_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/toloka/toloka_task_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/toloka/toloka_task_suite_gateway.h>

#include <library/cpp/testing/gtest/gtest.h>
#include <maps/libs/introspection/include/comparison.h>
#include <maps/libs/introspection/include/stream_output.h>
#include <maps/libs/log8/include/log8.h>
#include <yandex/maps/wiki/common/date_time.h>
#include <yandex/maps/wiki/common/misc.h>

#include <algorithm>
#include <chrono>
#include <set>
#include <tuple>
#include <utility>
#include <vector>

using maps::chrono::parseSqlDateTime;

namespace maps::mrc::db::toloka::tests {
using namespace ::testing;

using introspection::operator<<;
using introspection::operator==;
using introspection::operator!=;

namespace {

// Expected state of a TaskTypeInfo record, checked with isEqual().
// Members are initialized positionally in the tests, so their order matters.
struct TaskTypeInfoParams {
    Platform platform;            // first half of the key passed to loadById()
    TaskType type;                // second half of the key passed to loadById()
    TId activePoolId;             // expected value of activePoolId()
    std::size_t activePoolSize;   // expected value of activePoolSize()
};

// Expected state of a Task, used in assertions against loaded tasks.
// Members are initialized positionally in the tests, so their order matters.
struct TaskParams {
    Platform platform;
    TId id;
    TaskType type;
    TaskStatus status;
    std::string inputValues;
    std::optional<std::string> outputValues;                // nullopt when not yet produced
    int overlap;
    chrono::TimePoint createdAt;
    OptionalTimePoint postedAt;                             // empty while unposted
    OptionalTimePoint solvedAt;                             // empty while unsolved
    std::optional<std::string> knownSolutions;              // optional known-solutions payload
    std::optional<std::string> messageOnUnknownSolution;    // optional accompanying message
};

// Expected state of a TolokaTask: its position inside a task suite and
// the Task it refers to. Positional aggregate init depends on member order.
struct TolokaTaskParams {
    Platform platform;
    TId taskSuiteId;                       // owning task suite
    int taskIndex;                         // index of the task within the suite
    TId taskId;                            // referenced Task id
    std::optional<std::string> tolokaId;   // external id, may be absent
};

// Expected state of an MdsFile record bound to a Task.
// Positional aggregate init depends on member order.
struct MdsFileParams {
    TId id;                  // MdsFile's own id
    TId taskId;              // owning Task id
    std::string mdsGroupId;  // MDS group the file lives in
    std::string mdsPath;     // path inside that group
};

// Expected state of a TolokaTaskSuite, checked with isEqual().
// Positional aggregate init depends on member order.
struct TolokaTaskSuiteParams {
    Platform platform;
    TId id;
    std::string tolokaId;        // external suite id
    TId tolokaPoolId;            // pool the suite was posted to
    int overlap;
    chrono::TimePoint createdAt;
    OptionalTimePoint solvedAt;  // empty while unsolved
};

bool isEqual(const TaskTypeInfo& ti, const TaskTypeInfoParams& tip) {
    return ti.type() == tip.type &&
        ti.platform() == tip.platform &&
        ti.activePoolId() == tip.activePoolId &&
        ti.activePoolSize() == tip.activePoolSize;
}

bool isEqual(const Task& t, const TaskParams& tp) {
    return t.id() == tp.id &&
        t.platform() == tp.platform &&
        t.type() == tp.type &&
        t.status() == tp.status &&
        t.inputValues() == tp.inputValues &&
        t.outputValues() == tp.outputValues &&
        t.overlap() == tp.overlap &&
        t.createdAt() == tp.createdAt &&
        t.postedAt() == tp.postedAt &&
        t.solvedAt() == tp.solvedAt;
}

bool isEqual(const TolokaTask& t, const TolokaTaskParams& tp) {
    return t.taskSuiteId() == tp.taskSuiteId &&
        t.platform() == tp.platform &&
        t.taskIndex() == tp.taskIndex &&
        t.taskId() == tp.taskId &&
        t.tolokaId() == tp.tolokaId;
}

bool isEqual(const TolokaTaskSuite& ts, const TolokaTaskSuiteParams& tsp) {
    return ts.id() == tsp.id &&
        ts.platform() == tsp.platform &&
        ts.tolokaId() == tsp.tolokaId &&
        ts.tolokaPoolId() == tsp.tolokaPoolId &&
        ts.overlap() == tsp.overlap &&
        ts.createdAt() == tsp.createdAt &&
        ts.solvedAt() == tsp.solvedAt;
}

bool isEqual(const MdsFile& mf, const MdsFileParams& mfp) {
    return mf.id() == mfp.id &&
        mf.taskId() == mfp.taskId &&
        mf.mdsGroupId() == mfp.mdsGroupId &&
        mf.mdsPath() == mfp.mdsPath;
}

// Sorts `vec` in place by ascending id(), so tests can index the elements
// in a deterministic order.
template <typename Record>
void sortByIds(std::vector<Record>& vec) {
    const auto byId = [](const Record& left, const Record& right) {
        return left.id() < right.id();
    };
    std::sort(vec.begin(), vec.end(), byId);
}

} // namespace

using Fixture = maps::mrc::db::tests::Fixture;

TEST_F(Fixture, toloka_tests_test_task_type_info) {
    // Both task types start with pool id 0 / size 0. Update them and commit.
    {
        auto txn = txnHandle();
        TaskTypeInfoGateway gateway{*txn};
        auto qualityInfo =
            gateway.loadById(Platform::Toloka, TaskType::ImageQualityClassification);
        auto trafficLightInfo =
            gateway.loadById(Platform::Toloka, TaskType::TrafficLightDetection);

        const TaskTypeInfoParams initialQuality{
            Platform::Toloka, TaskType::ImageQualityClassification, 0, 0};
        const TaskTypeInfoParams initialTrafficLight{
            Platform::Toloka, TaskType::TrafficLightDetection, 0, 0};
        EXPECT_TRUE(isEqual(qualityInfo, initialQuality));
        EXPECT_TRUE(isEqual(trafficLightInfo, initialTrafficLight));

        gateway.update(qualityInfo.setActivePoolId(1).setActivePoolSize(10));
        gateway.update(trafficLightInfo.setActivePoolId(2).setActivePoolSize(20));
        txn->commit();
    }

    // A fresh transaction must observe the committed pool settings.
    {
        auto txn = txnHandle();
        TaskTypeInfoGateway gateway{*txn};
        auto qualityInfo =
            gateway.loadById(Platform::Toloka, TaskType::ImageQualityClassification);
        auto trafficLightInfo =
            gateway.loadById(Platform::Toloka, TaskType::TrafficLightDetection);

        const TaskTypeInfoParams updatedQuality{
            Platform::Toloka, TaskType::ImageQualityClassification, 1, 10};
        const TaskTypeInfoParams updatedTrafficLight{
            Platform::Toloka, TaskType::TrafficLightDetection, 2, 20};
        EXPECT_TRUE(isEqual(qualityInfo, updatedQuality));
        EXPECT_TRUE(isEqual(trafficLightInfo, updatedTrafficLight));
    }
}

// Full lifecycle of Task rows through TaskGateway: insert, load by
// type/status, single and bulk updates, and removal.
TEST_F(Fixture, toloka_tests_test_task) {
    const auto CREATED_AT = parseSqlDateTime("2017-01-01 00:00:00+03");
    const auto POSTED_AT = parseSqlDateTime("2017-01-02 00:00:00+03");
    const auto SOLVED_AT = parseSqlDateTime("2017-01-03 00:00:00+03");
    const std::string KNOWN_SOLUTIONS = R"(
[
{
    "output_values":{
        "colour":"black"
    },
    "correctness_weight":0.7
},
{
     "output_values":{
        "colour":"gray"
     },
     "correctness_weight":0.95
  }
]
)";
    const std::string MESSAGE_ON_UNKNOWN_SOLUTION = "Слон серый";

    // Create tasks: two image-quality-classification tasks (New, InProgress)
    // and one finished traffic-light-detection task.
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};

        Task task1(Platform::Toloka);
        task1.setType(TaskType::ImageQualityClassification)
            .setStatus(TaskStatus::New)
            .setInputValues("input-values-1")
            .setOverlap(3)
            .setCreatedAt(CREATED_AT);

        Task task2(Platform::Toloka);
        task2.setType(TaskType::ImageQualityClassification)
            .setStatus(TaskStatus::InProgress)
            .setInputValues("input-values-2")
            .setOverlap(3)
            .setCreatedAt(CREATED_AT)
            .setPostedAt(POSTED_AT)
            .setKnownSolutions(KNOWN_SOLUTIONS)
            .setMessageOnUnknownSolution(MESSAGE_ON_UNKNOWN_SOLUTION);

        Task task3(Platform::Toloka);
        task3.setType(TaskType::TrafficLightDetection)
            .setStatus(TaskStatus::Finished)
            .setInputValues("input-values-3")
            .setOutputValues("output-values-3")
            .setOverlap(3)
            .setCreatedAt(CREATED_AT)
            .setPostedAt(POSTED_AT)
            .setSolvedAt(SOLVED_AT);

        Tasks tasks{task1, task2, task3};
        gtw.insertx(tasks);
        txn->commit();
    }

    // Load task ids and tasks
    TId taskId{};
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};
        auto classificationIds =
            gtw.loadIdsByType(Platform::Toloka, TaskType::ImageQualityClassification);
        EXPECT_EQ(classificationIds.size(), 2u);

        auto detectionIds =
            gtw.loadIdsByType(Platform::Toloka, TaskType::TrafficLightDetection);
        EXPECT_EQ(detectionIds.size(), 1u);

        auto classificationNewIds = gtw.loadIdsByTypeStatus(
            Platform::Toloka, TaskType::ImageQualityClassification, TaskStatus::New);
        // ASSERT (not EXPECT): indexing an empty vector would be UB.
        ASSERT_EQ(classificationNewIds.size(), 1u);
        taskId = classificationNewIds[0];

        auto tasks = gtw.loadByIds(classificationIds);

        ASSERT_EQ(tasks.size(), 2u);
        std::sort(tasks.begin(), tasks.end(),
            [](const Task& lhs, const Task& rhs) {
                return lhs.inputValues() < rhs.inputValues();
            });

        EXPECT_TRUE(isEqual(tasks[0], TaskParams{
            Platform::Toloka,
            tasks[0].id(), TaskType::ImageQualityClassification, TaskStatus::New,
            "input-values-1", std::nullopt, 3,
            CREATED_AT, std::nullopt, std::nullopt, std::nullopt, std::nullopt}));

        EXPECT_TRUE(isEqual(tasks[1], TaskParams{
            Platform::Toloka,
            tasks[1].id(), TaskType::ImageQualityClassification, TaskStatus::InProgress,
            "input-values-2", std::nullopt, 3,
            CREATED_AT, POSTED_AT, std::nullopt,
            KNOWN_SOLUTIONS, MESSAGE_ON_UNKNOWN_SOLUTION}));
    }

    // Update task
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};

        auto task = gtw.loadById(taskId);

        EXPECT_TRUE(isEqual(task, TaskParams{
            Platform::Toloka,
            taskId, TaskType::ImageQualityClassification, TaskStatus::New,
            "input-values-1", std::nullopt, 3,
            CREATED_AT, std::nullopt, std::nullopt, std::nullopt, std::nullopt}));

        task.setStatus(TaskStatus::InProgress)
            .setPostedAt(POSTED_AT)
            .setOverlap(5);

        gtw.updatex(task);
        txn->commit();
    }

    // Update task again
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};
        auto task = gtw.loadById(taskId);
        task.setStatus(TaskStatus::Finished).setSolvedAt(SOLVED_AT);
        gtw.updatex(task);
        txn->commit();
    }

    // Check that task was successfully updated in db
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};
        auto task = gtw.loadById(taskId);
        EXPECT_TRUE(isEqual(task, TaskParams{
            Platform::Toloka,
            taskId, TaskType::ImageQualityClassification, TaskStatus::Finished,
            "input-values-1", std::nullopt, 5,
            CREATED_AT, POSTED_AT, SOLVED_AT, std::nullopt, std::nullopt}));
    }

    // Update multiple tasks
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};
        auto ids = gtw.loadIdsByType(Platform::Toloka, TaskType::ImageQualityClassification);
        auto tasks = gtw.loadByIds(ids);
        sortByIds(tasks);
        ASSERT_EQ(tasks.size(), 2u);
        tasks[0].setOverlap(5).setStatus(TaskStatus::Free);
        tasks[1].setOverlap(5);
        gtw.updatex(tasks);
        txn->commit();
    }

    // Check the updates
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};
        auto ids = gtw.loadIdsByType(Platform::Toloka, TaskType::ImageQualityClassification);
        auto tasks = gtw.loadByIds(ids);
        sortByIds(tasks);
        ASSERT_EQ(tasks.size(), 2u);

        EXPECT_EQ(tasks[0].overlap(), 5);
        EXPECT_EQ(tasks[0].status(), TaskStatus::Free);
        EXPECT_EQ(tasks[1].overlap(), 5);
    }

    // Delete image-quality-classification tasks
    TIds taskIdsToDelete;
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};

        taskIdsToDelete = gtw.loadIdsByType(Platform::Toloka, TaskType::ImageQualityClassification);
        EXPECT_EQ(taskIdsToDelete.size(), 2u);

        gtw.removeByIds(taskIdsToDelete);
        txn->commit();
    }

    // Check that tasks were deleted from db
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};
        auto taskIds = gtw.loadIdsByType(Platform::Toloka, TaskType::ImageQualityClassification);
        EXPECT_TRUE(taskIds.empty());
    }

    // Try to delete already deleted tasks again
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};

        EXPECT_EQ(gtw.removeByIds(taskIdsToDelete), 0u);

        // The traffic-light-detection task has not been deleted:
        auto taskIds = gtw.loadIdsByType(Platform::Toloka, TaskType::TrafficLightDetection);
        ASSERT_EQ(taskIds.size(), 1u);

        gtw.removeById(taskIds[0]);
        txn->commit();
    }
}

// Lifecycle of TolokaTaskSuite rows: insert, load, mark solved, remove.
TEST_F(Fixture, toloka_tests_test_toloka_task_suite) {
    const auto CREATED_AT = parseSqlDateTime("2017-01-01 00:00:00+03");
    const auto SOLVED_AT = parseSqlDateTime("2017-01-03 00:00:00+03");

    // Create two unsolved task suites in different pools
    {
        auto txn = txnHandle();
        TolokaTaskSuiteGateway gtw{*txn};

        TolokaTaskSuite taskSuite1(Platform::Toloka);
        taskSuite1.setTolokaId("toloka-task-suite-1")
            .setTolokaPoolId(1)
            .setOverlap(3)
            .setCreatedAt(CREATED_AT);

        TolokaTaskSuite taskSuite2(Platform::Toloka);
        taskSuite2.setTolokaId("toloka-task-suite-2")
            .setTolokaPoolId(2)
            .setOverlap(3)
            .setCreatedAt(CREATED_AT);

        gtw.insert(taskSuite1);
        gtw.insert(taskSuite2);
        txn->commit();
    }

    // Load task suite ids and task suites
    TId taskSuiteId{};
    {
        auto txn = txnHandle();
        TolokaTaskSuiteGateway gtw{*txn};
        auto taskSuiteIds = gtw.loadIds();
        EXPECT_EQ(taskSuiteIds.size(), 2u);

        auto taskSuites = gtw.loadByIds(taskSuiteIds);
        // ASSERT (not EXPECT): elements are indexed below.
        ASSERT_EQ(taskSuites.size(), 2u);
        std::sort(taskSuites.begin(), taskSuites.end(),
            [](const TolokaTaskSuite& lhs, const TolokaTaskSuite& rhs) {
                return lhs.tolokaId() < rhs.tolokaId();
            });

        taskSuiteId = taskSuites[0].id();

        EXPECT_TRUE(isEqual(taskSuites[0], TolokaTaskSuiteParams{
            Platform::Toloka,
            taskSuites[0].id(), "toloka-task-suite-1", 1, 3, CREATED_AT, std::nullopt}));

        EXPECT_TRUE(isEqual(taskSuites[1], TolokaTaskSuiteParams{
            Platform::Toloka,
            taskSuites[1].id(), "toloka-task-suite-2", 2, 3, CREATED_AT, std::nullopt}));
    }

    // Update task suite
    {
        auto txn = txnHandle();
        TolokaTaskSuiteGateway gtw{*txn};

        auto taskSuite = gtw.loadById(taskSuiteId);
        taskSuite.setSolvedAt(SOLVED_AT);
        gtw.update(taskSuite);
        txn->commit();
    }

    // Check that task suite was successfully updated in db
    {
        auto txn = txnHandle();
        TolokaTaskSuiteGateway gtw{*txn};
        auto solvedTaskSuiteIds = gtw.loadIds(table::TolokaTaskSuite::solvedAt.isNotNull());
        ASSERT_EQ(solvedTaskSuiteIds.size(), 1u);

        auto taskSuite = gtw.loadById(solvedTaskSuiteIds[0]);
        EXPECT_TRUE(isEqual(taskSuite, TolokaTaskSuiteParams{
            Platform::Toloka,
            taskSuiteId, "toloka-task-suite-1", 1, 3, CREATED_AT, SOLVED_AT}));
    }

    // Delete task suites
    {
        auto txn = txnHandle();
        TolokaTaskSuiteGateway gtw{*txn};
        auto taskSuiteIds = gtw.loadIds();

        gtw.removeByIds(taskSuiteIds);
        txn->commit();
    }

    // Check that task suites were deleted from db
    {
        auto txn = txnHandle();
        TolokaTaskSuiteGateway gtw{*txn};
        auto taskSuiteIds = gtw.loadIds();
        EXPECT_TRUE(taskSuiteIds.empty());
    }
}

// TolokaTask rows linking Tasks to a task suite: insert, load by task ids,
// and remove by task ids.
TEST_F(Fixture, toloka_tests_test_toloka_task) {
    const auto CREATED_AT = parseSqlDateTime("2017-01-01 00:00:00+03");
    const auto POSTED_AT = parseSqlDateTime("2017-01-02 00:00:00+03");

    TId taskId1{};
    TId taskId2{};
    TId taskSuiteId{};
    const std::string tolokaId1 = "toloka_1";
    const std::string tolokaId2 = "toloka_2";
    const std::string tolokaId3 = "toloka_3";

    // Preparation: create tasks and task suites
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};

        // Both tasks are InProgress, so both get a posted-at timestamp
        // (previously task1 mistakenly called setCreatedAt twice).
        Task task1(Platform::Toloka);
        task1.setType(TaskType::ImageQualityClassification)
            .setStatus(TaskStatus::InProgress)
            .setInputValues("input-values-1")
            .setOverlap(3)
            .setCreatedAt(CREATED_AT)
            .setPostedAt(POSTED_AT);

        Task task2(Platform::Toloka);
        task2.setType(TaskType::ImageQualityClassification)
            .setStatus(TaskStatus::InProgress)
            .setInputValues("input-values-2")
            .setOverlap(3)
            .setCreatedAt(CREATED_AT)
            .setPostedAt(POSTED_AT);

        gtw.insertx(task1);
        gtw.insertx(task2);

        TolokaTaskSuite taskSuite(Platform::Toloka);
        taskSuite.setTolokaId("toloka-task-suite-1")
            .setTolokaPoolId(1)
            .setOverlap(3)
            .setCreatedAt(CREATED_AT);

        TolokaTaskSuiteGateway{*txn}.insert(taskSuite);
        txn->commit();

        taskId1 = task1.id();
        taskId2 = task2.id();
        taskSuiteId = taskSuite.id();
    }

    // Create toloka tasks
    {
        auto txn = txnHandle();
        TolokaTaskGateway gtw{*txn};

        TolokaTasks tolokaTasks{
            TolokaTask(Platform::Toloka, taskSuiteId, 1, taskId1, tolokaId1),
            TolokaTask(Platform::Toloka, taskSuiteId, 2, taskId2, tolokaId2),
            TolokaTask(Platform::Toloka, taskSuiteId, 3, taskId2, tolokaId3)
        };

        gtw.insert(tolokaTasks);
        txn->commit();
    }

    // Load toloka tasks
    {
        auto txn = txnHandle();
        TolokaTaskGateway gtw{*txn};

        auto tolokaTasks = gtw.loadByTaskIds(TIds{taskId1, taskId2});

        // ASSERT (not EXPECT): elements are indexed below.
        ASSERT_EQ(tolokaTasks.size(), 3u);
        std::sort(tolokaTasks.begin(), tolokaTasks.end(),
            [](const TolokaTask& lhs, const TolokaTask& rhs) {
                return lhs.taskIndex() < rhs.taskIndex();
            });

        EXPECT_TRUE(isEqual(tolokaTasks[0],
            TolokaTaskParams{Platform::Toloka, taskSuiteId, 1, taskId1, tolokaId1}));
        EXPECT_TRUE(isEqual(tolokaTasks[1],
            TolokaTaskParams{Platform::Toloka, taskSuiteId, 2, taskId2, tolokaId2}));
        EXPECT_TRUE(isEqual(tolokaTasks[2],
            TolokaTaskParams{Platform::Toloka, taskSuiteId, 3, taskId2, tolokaId3}));
    }

    // Delete toloka tasks
    {
        auto txn = txnHandle();
        TolokaTaskGateway gtw{*txn};

        gtw.removeByTaskIds(TIds{taskId1, taskId2});
        txn->commit();
    }

    // Check that toloka tasks were deleted from db
    {
        auto txn = txnHandle();
        TolokaTaskGateway gtw{*txn};
        auto tolokaTasks = gtw.loadByTaskIds(TIds{taskId1, taskId2});
        EXPECT_TRUE(tolokaTasks.empty());
    }
}

// MdsFile rows attached to Tasks: insert, load by task ids, remove by task ids.
TEST_F(Fixture, toloka_tests_test_mds_file) {
    const auto CREATED_AT = parseSqlDateTime("2017-01-01 00:00:00+03");
    const auto POSTED_AT = parseSqlDateTime("2017-01-02 00:00:00+03");

    const std::string MDS_GROUP_1 = "mds-group-1";
    const std::string MDS_GROUP_2 = "mds-group-2";
    const std::string MDS_GROUP_3 = "mds-group-3";

    const std::string MDS_PATH_1 = "path/1";
    const std::string MDS_PATH_2 = "path/2";
    const std::string MDS_PATH_3 = "path/3";

    TId taskId1{};
    TId taskId2{};
    // Preparation: create the tasks the MDS files will belong to
    {
        auto txn = txnHandle();
        TaskGateway gtw{*txn};

        // Both tasks are InProgress, so both get a posted-at timestamp
        // (previously task1 mistakenly called setCreatedAt twice).
        Task task1(Platform::Toloka);
        task1.setType(TaskType::ImageQualityClassification)
            .setStatus(TaskStatus::InProgress)
            .setInputValues("input-values-1")
            .setOverlap(3)
            .setCreatedAt(CREATED_AT)
            .setPostedAt(POSTED_AT);

        Task task2(Platform::Toloka);
        task2.setType(TaskType::ImageQualityClassification)
            .setStatus(TaskStatus::InProgress)
            .setInputValues("input-values-2")
            .setOverlap(3)
            .setCreatedAt(CREATED_AT)
            .setPostedAt(POSTED_AT);

        gtw.insertx(task1);
        gtw.insertx(task2);
        txn->commit();

        taskId1 = task1.id();
        taskId2 = task2.id();
    }

    // Create mds files: two for task1, one for task2
    {
        auto txn = txnHandle();
        MdsFileGateway gtw{*txn};

        MdsFiles mdsFiles{
            MdsFile(taskId1, MDS_GROUP_1, MDS_PATH_1),
            MdsFile(taskId1, MDS_GROUP_2, MDS_PATH_2),
            MdsFile(taskId2, MDS_GROUP_3, MDS_PATH_3)
        };

        gtw.insert(mdsFiles);
        txn->commit();
    }

    // Load mds files
    {
        auto txn = txnHandle();
        MdsFileGateway gtw{*txn};

        auto mdsFiles = gtw.loadByTaskIds(TIds{taskId1, taskId2});

        // ASSERT (not EXPECT): elements are indexed below.
        ASSERT_EQ(mdsFiles.size(), 3u);
        std::sort(mdsFiles.begin(), mdsFiles.end(),
            [](const MdsFile& lhs, const MdsFile& rhs) {
                return lhs.mdsGroupId() < rhs.mdsGroupId();
            });

        EXPECT_TRUE(isEqual(mdsFiles[0],
            MdsFileParams{mdsFiles[0].id(),
                taskId1,
                MDS_GROUP_1,
                MDS_PATH_1}));

        EXPECT_TRUE(isEqual(mdsFiles[1],
            MdsFileParams{mdsFiles[1].id(),
                taskId1,
                MDS_GROUP_2,
                MDS_PATH_2}));
        EXPECT_TRUE(isEqual(mdsFiles[2],
            MdsFileParams{mdsFiles[2].id(),
                taskId2,
                MDS_GROUP_3,
                MDS_PATH_3}));
    }

    // Delete mds files
    {
        auto txn = txnHandle();
        MdsFileGateway gtw{*txn};

        gtw.removeByTaskIds(TIds{taskId1, taskId2});
        txn->commit();
    }

    // Check that mds files were deleted from db
    {
        auto txn = txnHandle();
        MdsFileGateway gtw{*txn};
        auto mdsFiles = gtw.loadByTaskIds(TIds{taskId1, taskId2});
        EXPECT_TRUE(mdsFiles.empty());
    }
}

TEST_F(Fixture, toloka_tests_statuses)
{
    // Builds one task for every TaskStatus value in [Min, Max] and inserts
    // them all. The test fails if insertx() throws for any of the statuses.
    // No commit is issued.
    const int firstStatus = static_cast<int>(TaskStatus::Min);
    const int lastStatus = static_cast<int>(TaskStatus::Max);

    std::vector<Task> tasks;
    int current = firstStatus;
    while (current <= lastStatus) {
        Task task(Platform::Toloka);
        task.setType(TaskType::ImageQualityClassification)
            .setStatus(static_cast<TaskStatus>(current));
        tasks.push_back(task);
        ++current;
    }

    auto txn = txnHandle();
    TaskGateway gtw{*txn};
    gtw.insertx(tasks);
}

} // namespace maps::mrc::db::toloka::tests
