#include "../include/verify_detection_pair_match.h"
#include "../include/metadata.h"
#include "eye/verification_source.h"
#include "eye/verified_detection_pair_match.h"
#include "toloka/platform.h"

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/collection.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/pg_locks.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/object_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/verified_detection_pair_match_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/recognition_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/toloka_manager/include/detection_pair_match.h>

#include <maps/wikimap/mapspro/services/mrc/libs/toloka_manager/include/toloka_manager.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/toloka/task_gateway.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/common/include/id.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/store.h>

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/sql_chemistry/include/exists.h>

#include <util/generic/iterator_range.h>

#include <pqxx/pqxx>

#include <algorithm>
#include <iterator>
#include <list>
#include <map>
#include <set>
#include <unordered_map>

namespace maps::mrc::eye {

namespace {

// A half-open range of transaction ids [beginTxnId, endTxnId) together with the
// ids of the verified matches touched in that range (directly, or through their
// linked toloka task). See loadBatch / loadMatchIds.
struct Batch {
    // Value-initialized so a default-constructed Batch never carries garbage ids.
    db::TId beginTxnId{};
    db::TId endTxnId{};

    db::TIds matchIds;
};


// Returns the distinct ids of both detections referenced by every match, in the
// iteration order of db::TIdSet.
db::TIds collectDetectionIds(const db::eye::VerifiedDetectionPairMatches& matches) {
    db::TIdSet uniqueIds;
    for (const auto& pairMatch : matches) {
        const db::TId firstId = pairMatch.detectionId1();
        const db::TId secondId = pairMatch.detectionId2();
        uniqueIds.insert(firstId);
        uniqueIds.insert(secondId);
    }

    return db::TIds(uniqueIds.begin(), uniqueIds.end());
}

// Loads the matches among `matchIds` that have no verdict yet (`approved` IS NULL).
db::eye::VerifiedDetectionPairMatches getUnprocessedMatches(
    pqxx::transaction_base& txn,
    const db::TIds& matchIds)
{
    namespace table = db::eye::table;

    db::eye::VerifiedDetectionPairMatchGateway gateway(txn);
    return gateway.load(
        table::VerifiedDetectionPairMatch::id.in(matchIds)
            && table::VerifiedDetectionPairMatch::approved.isNull()
    );
}


// Returns match id -> toloka task id for those `matches` that already have a
// linked toloka task. Matches without a link are absent from the result.
db::IdTo<db::TId> loadMatchIdToTolokaTaskIdMap(
    pqxx::transaction_base& txn,
    const db::eye::VerifiedDetectionPairMatches& matches)
{
    const db::TIds ids = collectIds(matches);

    const auto links = db::eye::VerifiedDetectionPairMatchToTolokaTaskGateway(txn).load(
        db::eye::table::VerifiedDetectionPairMatchToTolokaTask::verifiedTaskId.in(ids)
    );

    db::IdTo<db::TId> taskIdByMatchId;
    for (const auto& link : links) {
        const db::TId matchId = link.verifiedTaskId();
        taskIdByMatchId[matchId] = link.tolokaTaskId();
    }
    return taskIdByMatchId;
}

// Builds the input of a new pair-match task. For each detection of `match` it
// resolves the frame image URL (using the frame's privacy type) and transforms
// the detection box by the frame's original size and orientation.
toloka::DetectionPairMatchInput makeNewTaskInput(
    const DetectionPairMatchVerifierConfig& config,
    const DetectionStore& store,
    const db::eye::VerifiedDetectionPairMatch& match)
{
    // Bind frames/detections by const reference to avoid copying them out of the
    // store. If the store returns by value, the temporary's lifetime is
    // extended, so this is safe either way.
    const auto& frame1 = store.frameByDetectionId(match.detectionId1());
    const auto privacy1 = store.privacyByFrameId(frame1.id()).type();
    const auto& detection1 = store.detectionById(match.detectionId1());
    const auto& frame2 = store.frameByDetectionId(match.detectionId2());
    const auto privacy2 = store.privacyByFrameId(frame2.id()).type();
    const auto& detection2 = store.detectionById(match.detectionId2());

    return {
        config.frameUrlResolver->image(frame1, privacy1),
        common::transformByImageOrientation(
            detection1.box(), frame1.originalSize(), frame1.orientation()
        ),
        config.frameUrlResolver->image(frame2, privacy2),
        common::transformByImageOrientation(
            detection2.box(), frame2.originalSize(), frame2.orientation()
        )
    };
}


// Maps a verification source to the crowdsourcing platform its tasks go to.
// Any source other than Toloka or Yang fails the REQUIRE check.
db::toloka::Platform evalTolokaPlatform(db::eye::VerificationSource source)
{
    switch (source) {
        case db::eye::VerificationSource::Toloka:
            return db::toloka::Platform::Toloka;
        case db::eye::VerificationSource::Yang:
            return db::toloka::Platform::Yang;
        default:
            break;
    }
    REQUIRE(false, "Unsupported source " << source);
}

// Creates a task on the platform matching `source` for every match that has no
// linked task yet, and stores the match -> task links.
//
// matchIdToTolokaTaskId: existing links; matches present here are skipped.
// Returns the number of tasks created.
size_t createNewTasks(
    pqxx::transaction_base& txn,
    const DetectionPairMatchVerifierConfig& config,
    const DetectionStore& store,
    db::eye::VerificationSource source,
    const db::eye::VerifiedDetectionPairMatches& matches,
    const db::IdTo<db::TId>& matchIdToTolokaTaskId)
{
    // Validate the source up front, even if no task ends up being created.
    const auto tolokaPlatform = evalTolokaPlatform(source);

    toloka::DetectionPairMatchInputs newTaskInputs;
    // newTaskMatchIds[i] is the id of the match described by newTaskInputs[i].
    db::TIds newTaskMatchIds;
    for (const auto& match : matches) {
        if (matchIdToTolokaTaskId.count(match.id()) != 0) {
            continue; // already has a task
        }

        newTaskMatchIds.push_back(match.id());
        newTaskInputs.push_back(makeNewTaskInput(config, store, match));
    }

    if (newTaskInputs.empty()) {
        return 0;
    }

    auto tolokaTasks = toloka::createTasks<toloka::DetectionPairMatchTask>(
        txn, tolokaPlatform, newTaskInputs
    );
    // Tasks are paired with matches by position; a size mismatch would make the
    // indexing below go out of range.
    REQUIRE(tolokaTasks.size() == newTaskMatchIds.size(),
            "Expected " << newTaskMatchIds.size() << " toloka tasks, got "
            << tolokaTasks.size());

    db::eye::VerifiedDetectionPairMatchesToTolokaTasks matchesToTolokaTasks;
    for (size_t i = 0; i < newTaskMatchIds.size(); ++i) {
        matchesToTolokaTasks.emplace_back(newTaskMatchIds[i], tolokaTasks[i].id());
    }

    db::eye::VerifiedDetectionPairMatchToTolokaTaskGateway(txn).insert(matchesToTolokaTasks);

    return tolokaTasks.size();
}

// For every match with a linked task whose output is present, parses the output
// and sets `approved` (true iff the answer is Yes). The updated matches are
// written back.
//
// Returns the number of matches updated. Throws (via `at`) if a linked task id
// is not found among the loaded tasks.
size_t checkCompletedTasks(
    pqxx::transaction_base& txn,
    const db::eye::VerifiedDetectionPairMatches& matches,
    const db::IdTo<db::TId>& matchIdToTolokaTaskId)
{
    if (matchIdToTolokaTaskId.empty()) {
        return 0; // no linked tasks, nothing can be completed
    }

    db::TIds tolokaTaskIds;
    tolokaTaskIds.reserve(matchIdToTolokaTaskId.size());
    for (const auto& [matchId, tolokaTaskId] : matchIdToTolokaTaskId) {
        tolokaTaskIds.push_back(tolokaTaskId);
    }

    const auto tolokaTaskById = byId(db::toloka::TaskGateway(txn).loadByIds(tolokaTaskIds));

    db::eye::VerifiedDetectionPairMatches completedMatches;
    for (const auto& match : matches) {
        const auto taskIdIt = matchIdToTolokaTaskId.find(match.id());
        if (taskIdIt == matchIdToTolokaTaskId.end()) {
            continue; // task was not created before this run
        }

        const auto& tolokaTask = tolokaTaskById.at(taskIdIt->second);
        if (!tolokaTask.outputValues().has_value()) {
            continue; // task not answered yet
        }

        const auto output = toloka::parseJson<toloka::DetectionPairMatchOutput>(
            json::Value::fromString(tolokaTask.outputValues().value())
        );

        db::eye::VerifiedDetectionPairMatch completedMatch = match;
        completedMatch.setApproved(
            output.result() == toloka::DetectionPairMatchResult::Yes
        );
        completedMatches.push_back(std::move(completedMatch));
    }

    db::eye::VerifiedDetectionPairMatchGateway(txn).updatex(completedMatches);

    return completedMatches.size();
}

// Runs one verification step for matches sharing the same verification source:
// creates tasks for matches that have none, then records verdicts of answered
// tasks. Returns tasks created plus matches updated.
size_t processMatchesInToloka(
    pqxx::transaction_base& txn,
    const DetectionPairMatchVerifierConfig& config,
    const DetectionStore& store,
    db::eye::VerificationSource source,
    const db::eye::VerifiedDetectionPairMatches& matches)
{
    // Links are loaded once, before new tasks are created, so tasks created in
    // this call are not checked for completion until a later run.
    const auto matchIdToTolokaTaskId = loadMatchIdToTolokaTaskIdMap(txn, matches);

    // Kept as separate statements: the order of evaluation of `a() + b()` is
    // unspecified, and task creation must happen before the completion check.
    size_t updatesNumber = createNewTasks(txn, config, store, source, matches, matchIdToTolokaTaskId);
    updatesNumber += checkCompletedTasks(txn, matches, matchIdToTolokaTaskId);

    return updatesNumber;
}

// Returns up to `limit` txn ids of verified-match rows with txnId >= beginTxnId,
// in ascending txnId order.
db::TIds loadMatchTxnIdsBatch(
    pqxx::transaction_base& txn,
    db::TId beginTxnId,
    size_t limit)
{
    namespace table = db::eye::table;

    db::eye::VerifiedDetectionPairMatchGateway gateway(txn);
    return gateway.loadTxnIds(
        table::VerifiedDetectionPairMatch::txnId >= beginTxnId,
        sql_chemistry::limit(limit).orderBy(table::VerifiedDetectionPairMatch::txnId)
    );
}

// Returns up to `limit` txn ids of toloka task rows with txnId >= beginTxnId,
// in ascending txnId order.
db::TIds loadTaskTxnIdsBatch(
    pqxx::transaction_base& txn,
    db::TId beginTxnId,
    size_t limit)
{
    namespace table = db::toloka::table;

    db::toloka::TaskGateway gateway(txn);
    return gateway.loadTxnIds(
        table::Task::txnId >= beginTxnId,
        sql_chemistry::limit(limit).orderBy(table::Task::txnId)
    );
}

// Returns the ids of matches to process for the txn range [beginTxnId, endTxnId):
// matches whose own row changed in the range, plus matches linked to toloka
// tasks that changed in the range. The result is sorted and free of duplicates
// (a match and its task may both change in the same range).
db::TIds loadMatchIds(
    pqxx::transaction_base& txn,
    db::TId beginTxnId,
    db::TId endTxnId)
{
    db::TIds matchIds = db::eye::VerifiedDetectionPairMatchGateway(txn).loadIds(
        db::eye::table::VerifiedDetectionPairMatch::txnId >= beginTxnId
        && db::eye::table::VerifiedDetectionPairMatch::txnId < endTxnId
    );

    const db::TIds taskIds = db::toloka::TaskGateway(txn).loadIds(
        db::toloka::table::Task::txnId >= beginTxnId
        && db::toloka::table::Task::txnId < endTxnId
    );
    if (!taskIds.empty()) {
        const auto matchesToTasks = db::eye::VerifiedDetectionPairMatchToTolokaTaskGateway(txn).load(
            db::eye::table::VerifiedDetectionPairMatchToTolokaTask::tolokaTaskId.in(taskIds)
        );
        for (const auto& matchToTask : matchesToTasks) {
            matchIds.push_back(matchToTask.verifiedTaskId());
        }
    }

    std::sort(matchIds.begin(), matchIds.end());
    matchIds.erase(std::unique(matchIds.begin(), matchIds.end()), matchIds.end());

    return matchIds;
}

// Picks the next batch starting at `beginTxnId`. It merges the first `limit` txn
// ids from the match table and from the task table, and ends the batch right
// after the `limit`-th smallest of them (or after the largest, if fewer). An
// empty batch has endTxnId == beginTxnId.
Batch loadBatch(
    pqxx::transaction_base& txn,
    db::TId beginTxnId,
    size_t limit)
{
    db::TIds txnIds = loadMatchTxnIdsBatch(txn, beginTxnId, limit);
    const db::TIds taskTxnIds = loadTaskTxnIdsBatch(txn, beginTxnId, limit);
    txnIds.insert(txnIds.end(), taskTxnIds.begin(), taskTxnIds.end());
    std::sort(txnIds.begin(), txnIds.end());

    Batch batch;
    batch.beginTxnId = beginTxnId;

    if (txnIds.empty()) {
        batch.endTxnId = beginTxnId;
        return batch;
    }

    const size_t lastIndex = txnIds.size() > limit ? limit - 1 : txnIds.size() - 1;
    batch.endTxnId = txnIds[lastIndex] + 1;
    batch.matchIds = loadMatchIds(txn, batch.beginTxnId, batch.endTxnId);

    return batch;
}

// Splits `matches` (taken by value and moved from) into groups keyed by
// verification source. Each group keeps the input order.
std::map<db::eye::VerificationSource, db::eye::VerifiedDetectionPairMatches>
groupByVerificationSource(db::eye::VerifiedDetectionPairMatches matches)
{
    std::map<db::eye::VerificationSource, db::eye::VerifiedDetectionPairMatches> bySource;

    for (auto& match : matches) {
        // Look up the bucket before moving `match` out.
        auto& bucket = bySource[match.source()];
        bucket.push_back(std::move(match));
    }

    return bySource;
}

} // namespace

// Processes the matches among `matchIds` that have no verdict yet. Matches are
// grouped by verification source; for each group, tasks are created or
// completed verdicts recorded. Returns the total number of updates
// (tasks created + matches resolved).
size_t DetectionPairMatchVerifier::processMatches(
    pqxx::transaction_base& txn,
    const db::TIds& matchIds)
{
    if (matchIds.empty()) {
        return 0;
    }

    auto matches = getUnprocessedMatches(txn, matchIds);
    if (matches.empty()) {
        return 0; // everything already has a verdict; skip loading the store
    }

    DetectionStore store;
    store.extendByDetectionIds(txn, collectDetectionIds(matches));

    const auto sourceToMatchesMap = groupByVerificationSource(std::move(matches));

    size_t updatesNumber = 0;
    for (const auto& [source, sourceMatches] : sourceToMatchesMap) {
        updatesNumber += processMatchesInToloka(
            txn,
            config_,
            store,
            source,
            sourceMatches
        );
    }

    return updatesNumber;
}

// Processes the given match ids in a single master write transaction, holding
// the lock returned by lockIfNeed(). The commit goes through commitIfNeed().
void DetectionPairMatchVerifier::processBatch(const db::TIds& matchIds) {
    const auto lock = lockIfNeed();

    auto masterTxn = getMasterWriteTxn(*config_.mrc.pool);
    processMatches(*masterTxn, matchIds);
    commitIfNeed(*masterTxn);
}

bool DetectionPairMatchVerifier::processBatchInLoopMode(size_t batchSize) {
    const auto lock = lockIfNeed();

    Batch batch;
    {
        auto readTxn = getSlaveTxn();
        auto metadata = verifyDetectionPairMatchMetadata(*readTxn);

        batch = loadBatch(*readTxn, metadata.getTxnId(), batchSize);
        INFO() << "Batch match size " << batch.matchIds.size()
               << " [" << batch.beginTxnId << ", " << batch.endTxnId << ")";
    }

    auto writeTxn = getMasterWriteTxn(*(config_.mrc.pool));
    const size_t updatesNumber = processMatches(*writeTxn, batch.matchIds);

    auto metadata = verifyDetectionPairMatchMetadata(*writeTxn);
    metadata.updateTime();
    metadata.updateTxnId(batch.endTxnId);

    commitIfNeed(*writeTxn);

    return updatesNumber > 0;
}

} // namespace maps::mrc::eye
