#include "verification.h"
#include "../include/verification_policy.h"
#include "common.h"
#include "passage.h"

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/verified_detection_pair_match_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/verified_detection_missing_on_frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/for_each_batch.h>

#include <util/generic/iterator_range.h>

#include <optional>
#include <set>

namespace maps::mrc::eye {

namespace {


// Computes the verification action that applies to `object` located at
// `geodeticPos`.
//
// The region ids for the position are obtained from `geoidProvider.load`.
// Every rule whose geoId is one of those regions, whose objectType equals the
// object's type and whose objectPredicate accepts the object contributes its
// verificationAction; contributions are combined with `merge`. When no rule
// matches, a value-initialized VerificationAction is returned.
VerificationAction
evalVerificationAction(const VerificationRules& rules,
                       const privacy::GeoIdProvider& geoidProvider,
                       const geolib3::Point2& geodeticPos,
                       const db::eye::Object& object)
{
    const db::TIds regionIds = geoidProvider.load(geodeticPos);

    VerificationAction action{};
    for (const auto regionId : regionIds) {
        for (const auto& rule : rules) {
            // Cheap equality checks first; the predicate is only invoked
            // for rules that already match region and type.
            const bool ruleApplies = rule.geoId == regionId
                && rule.objectType == object.type()
                && rule.objectPredicate(object);
            if (!ruleApplies) {
                continue;
            }
            action = merge(action, rule.verificationAction);
        }
    }
    return action;
}

// Evaluates the verification rules for every object and indexes the resulting
// action by object id.
//
// Each object is located by the geodetic position of its primary detection,
// as reported by `detectionStore.locationByDetectionId`.
db::IdTo<VerificationAction>
makeObjectIdToVerificationActionMap(
    const VerificationRules& verificationRules,
    const privacy::GeoIdProvider& geoidProvider,
    const DetectionStore& detectionStore,
    const db::eye::Objects& objects
)
{
    db::IdTo<VerificationAction> actionByObjectId;
    for (const auto& object : objects) {
        const auto primaryPos = detectionStore
            .locationByDetectionId(object.primaryDetectionId())
            .geodeticPos();

        auto action = evalVerificationAction(
            verificationRules, geoidProvider, primaryPos, object);
        actionByObjectId.emplace(object.id(), std::move(action));
    }
    return actionByObjectId;
}

// Builds an index from detection id to the object that owns the detection.
//
// An object owns its primary detection and every detection in the cluster
// keyed by that primary detection id. Objects whose primary detection id has
// no entry in `clusters` are not indexed at all.
// `emplace` never overwrites, so if a detection id is reachable from several
// objects, the first one in `objects` order wins.
// The stored references point into `objects`, which must outlive the result.
db::IdTo<const db::eye::Object&>
makeDetectionIdToObjectMap(
    const db::IdTo<db::TIdSet>& clusters,
    const db::eye::Objects& objects
)
{
    db::IdTo<const db::eye::Object&> objectByDetectionId;
    for (const auto& object : objects) {
        const auto primaryId = object.primaryDetectionId();
        if (const auto cluster = clusters.find(primaryId); cluster != clusters.end()) {
            objectByDetectionId.emplace(primaryId, object);
            for (const auto memberId : cluster->second) {
                objectByDetectionId.emplace(memberId, object);
            }
        }
    }
    return objectByDetectionId;
}


// Chooses the verification service for a feature with the given privacy:
// public features are verified on Toloka, every other privacy level on Yang.
db::eye::VerificationSource
selectVerificationSource(db::FeaturePrivacy privacy)
{
    return privacy == db::FeaturePrivacy::Public
        ? db::eye::VerificationSource::Toloka
        : db::eye::VerificationSource::Yang;
}

// Removes from `requests` every request that is already stored in the
// database, leaving only the ones that still need to be inserted.
//
// Requests are identified by the (source, detectionId, frameId) triple.
// The database is queried in batches of BATCH_SIZE requests; each batch is a
// single query OR-ing the per-request key conditions.
// Despite the name, nothing is deleted from the database: only the in-memory
// vector is filtered. The remaining requests keep their relative order.
void deleteExistingRequests(
    pqxx::transaction_base& txn,
    db::eye::VerifiedDetectionMissingOnFrames& requests)
{
    static constexpr size_t BATCH_SIZE = 1000;
    using table = db::eye::table::VerifiedDetectionMissingOnFrame;
    using RequestKey = std::tuple<db::eye::VerificationSource, db::TId, db::TId>;

    const auto keyOf = [](const db::eye::VerifiedDetectionMissingOnFrame& request) {
        return std::make_tuple(
            request.source(), request.detectionId(), request.frameId());
    };

    std::set<RequestKey> storedKeys;
    common::forEachBatch(requests, BATCH_SIZE,
        [&](auto batchBegin, auto batchEnd) {
            sql_chemistry::FiltersCollection anyOfBatch{sql_chemistry::op::Logical::Or};
            for (const auto& candidate : MakeIteratorRange(batchBegin, batchEnd)) {
                anyOfBatch.add(table::source == candidate.source() &&
                    table::detectionId == candidate.detectionId() &&
                    table::frameId == candidate.frameId());
            }

            const auto storedRequests =
                db::eye::VerifiedDetectionMissingOnFrameGateway{txn}.load(anyOfBatch);
            for (const auto& stored : storedRequests) {
                storedKeys.insert(keyOf(stored));
            }
        }
    );

    const auto isAlreadyStored = [&](const auto& request) {
        return storedKeys.count(keyOf(request)) > 0;
    };
    requests.erase(
        std::remove_if(requests.begin(), requests.end(), isAlreadyStored),
        requests.end());
}

// Creates a pair-match verification request for two detections.
//
// The verification source is selected from the stricter of the two
// detections' privacy levels, as returned by db::selectStricterPrivacy.
db::eye::VerifiedDetectionPairMatch
makeVerifiedDetectionPairMatch(
    const DetectionStore& detectionStore,
    db::TId detectionId1,
    db::TId detectionId2)
{
    const auto firstPrivacy = detectionStore.privacyByDetectionId(detectionId1).type();
    const auto secondPrivacy = detectionStore.privacyByDetectionId(detectionId2).type();
    const auto source =
        selectVerificationSource(db::selectStricterPrivacy(firstPrivacy, secondPrivacy));

    return db::eye::VerifiedDetectionPairMatch{source, detectionId1, detectionId2};
}


// Builds pair-match verification requests for matched detections. The
// requests are stored via VerifiedDetectionPairMatchGateway::insertx and also
// returned.
//
// Pairs whose two detection ids are equal are skipped. For every other
// matched pair (d1, d2), both detections are resolved to their objects:
//  - same object, and that object's action has verifyObjectDetections: each
//    of d1, d2 that is not the object's primary detection is paired with the
//    primary detection;
//  - different objects, and merge() of both actions has
//    verifyObjectDuplication: the two primary detections are paired.
// A pair is not requested again if `existingRequests` or the requests made
// earlier in this call already contain it. VerificationRequestsIndex::set
// records both orderings, so indexes built with set() match the pair in
// either order.
// NOTE(review): `.at()` throws if a matched detection is missing from
// `detectionIdToObject`; confirm callers guarantee full coverage.
db::eye::VerifiedDetectionPairMatches createDetectionPairMatchingRequests(
    pqxx::transaction_base& txn,
    const db::IdTo<VerificationAction>& objectIdToVerificationAction,
    const db::IdTo<const db::eye::Object&>& detectionIdToObject,
    const DetectionStore& detectionStore,
    const VerificationRequestsIndex& existingRequests,
    const std::vector<MatchedObjects>& matchedObjectsVec
)
{
    db::eye::VerifiedDetectionPairMatches requests;
    VerificationRequestsIndex requestedInThisCall;

    // Emits a request for (lhs, rhs) unless one is already known.
    const auto requestPair = [&](db::TId lhs, db::TId rhs) {
        if (existingRequests.contains(lhs, rhs) || requestedInThisCall.contains(lhs, rhs)) {
            return;
        }
        requests.push_back(makeVerifiedDetectionPairMatch(detectionStore, lhs, rhs));
        requestedInThisCall.set(lhs, rhs, std::nullopt);
    };

    for (const auto& match : matchedObjectsVec) {
        const db::TId firstDetectionId = match.objectPassageIndx1.first;
        const db::TId secondDetectionId = match.objectPassageIndx2.first;
        if (firstDetectionId == secondDetectionId) {
            continue;
        }

        const auto& firstObject = detectionIdToObject.at(firstDetectionId);
        const auto& secondObject = detectionIdToObject.at(secondDetectionId);

        const auto& firstAction = objectIdToVerificationAction.at(firstObject.id());
        const auto& secondAction = objectIdToVerificationAction.at(secondObject.id());
        const auto combinedAction = merge(firstAction, secondAction);

        if (firstObject.id() != secondObject.id()) {
            if (combinedAction.verifyObjectDuplication) {
                requestPair(firstObject.primaryDetectionId(), secondObject.primaryDetectionId());
            }
            continue;
        }

        if (!firstAction.verifyObjectDetections) {
            continue;
        }
        // Both detections belong to the same object: check each non-primary
        // one against the primary detection.
        const db::TId primaryId = firstObject.primaryDetectionId();
        for (const db::TId detectionId : {firstDetectionId, secondDetectionId}) {
            if (detectionId != primaryId) {
                requestPair(detectionId, primaryId);
            }
        }
    }

    db::eye::VerifiedDetectionPairMatchGateway{txn}.insertx(requests);
    return requests;
}

// Builds "detection missing on frame" verification requests, inserts the ones
// not yet stored in the database, and returns the inserted requests.
//
// For every detection in `detectionIdToFrameIds` whose owning object's action
// has verifyObjectMissingness, one request per listed frame id is created. All
// of a detection's requests share one verification source, derived from that
// detection's privacy.
// NOTE(review): `.at()` throws if a detection or its object is missing from
// the supplied indexes; confirm callers guarantee coverage.
db::eye::VerifiedDetectionMissingOnFrames createDetectionMissingOnFrameRequests(
    pqxx::transaction_base& txn,
    const DetectionStore& detectionStore,
    const db::IdTo<VerificationAction>& objectIdToVerificationAction,
    const db::IdTo<const db::eye::Object&>& detectionIdToObject,
    const db::IdTo<db::TIds>& detectionIdToFrameIds)
{
    db::eye::VerifiedDetectionMissingOnFrames requests;

    for (const auto& [detectionId, frameIds] : detectionIdToFrameIds) {
        const auto& owner = detectionIdToObject.at(detectionId);
        const auto& ownerAction = objectIdToVerificationAction.at(owner.id());
        if (ownerAction.verifyObjectMissingness) {
            const auto source = selectVerificationSource(
                detectionStore.privacyByDetectionId(detectionId).type());
            for (const auto frameId : frameIds) {
                requests.emplace_back(source, detectionId, frameId);
            }
        }
    }

    // Drop requests that already exist, then persist the remainder.
    deleteExistingRequests(txn, requests);
    db::eye::VerifiedDetectionMissingOnFrameGateway{txn}.insertx(requests);
    return requests;
}


} // namespace

// Records `value` for the detection pair under both orderings, (id1, id2) and
// (id2, id1), so that later lookups do not depend on argument order.
// Any previously stored value for the pair is overwritten.
void VerificationRequestsIndex::set(db::TId id1, db::TId id2, std::optional<bool> value)
{
    const auto forwardKey = std::make_pair(id1, id2);
    const auto reverseKey = std::make_pair(id2, id1);
    detectionPairToResult_[forwardKey] = value;
    detectionPairToResult_[reverseKey] = std::move(value);
}

// Returns whether an entry exists for the ordered key (id1, id2).
// set() always writes both orderings, so for pairs added through set() the
// argument order does not matter.
bool VerificationRequestsIndex::contains(db::TId id1, db::TId id2) const
{
    const auto key = std::make_pair(id1, id2);
    return detectionPairToResult_.contains(key);
}

// Returns the stored result for the ordered key (id1, id2).
// The lookup uses `.at()`, so it throws when the pair is absent; call
// contains() first if absence is possible. The returned reference points into
// the index and is valid only while the entry is not modified.
const std::optional<bool>& VerificationRequestsIndex::get(db::TId id1, db::TId id2) const
{
    const auto key = std::make_pair(id1, id2);
    return detectionPairToResult_.at(key);
}

std::set<db::TId> VerificationRequestsIndex::referencedDetections() const
{
    std::set<db::TId> result;
    for (const auto& [pair, _] : detectionPairToResult_) {
        result.insert(pair.first);
        result.insert(pair.second);
    }
    return result;
}

// Returns every key stored in the index, in the container's iteration order.
// Because set() writes both orderings, each pair added through set() appears
// twice: as (a, b) and as (b, a).
std::vector<std::pair<db::TId, db::TId>> VerificationRequestsIndex::referencedPairs() const
{
    std::vector<std::pair<db::TId, db::TId>> pairs;
    pairs.reserve(detectionPairToResult_.size());
    for (const auto& entry : detectionPairToResult_) {
        pairs.emplace_back(entry.first);
    }
    return pairs;
}


// Loads every pair-match verification involving any of `detectionIds` into a
// VerificationRequestsIndex.
//
// A stored verification is included when either of its two detections is in
// `detectionIds`. Each verification's approved() value is recorded through
// VerificationRequestsIndex::set, which indexes both orderings of the pair.
// Ids are queried in batches of BATCH_SIZE to bound the size of each IN-list.
VerificationRequestsIndex loadVerifiedMatches(
    pqxx::transaction_base& txn,
    const db::TIds& detectionIds
)
{
    static constexpr size_t BATCH_SIZE = 1000;
    using table = db::eye::table::VerifiedDetectionPairMatch;

    VerificationRequestsIndex result;
    db::eye::VerifiedDetectionPairMatchGateway gtw{txn};
    common::forEachBatch(detectionIds, BATCH_SIZE,
        [&](auto startIt, auto endIt) {
            // Distinct name so the batch does not shadow the outer
            // `detectionIds` parameter captured by this lambda.
            const db::TIds batchIds(startIt, endIt);
            const auto verifications = gtw.load(
                table::detectionId1.in(batchIds) || table::detectionId2.in(batchIds));

            // A verification whose two detections fall into different batches
            // is loaded twice; set() simply rewrites the same value.
            for (const auto& verification : verifications) {
                result.set(verification.detectionId1(), verification.detectionId2(),
                    verification.approved());
            }
        }
    );

    return result;
}

// Entry point: evaluates verification rules for `objects` and creates, in
// this transaction, both kinds of verification requests:
//  1. pair-match requests for `matchedObjectsVec`, skipping pairs already in
//     `existingRequests`;
//  2. "detection missing on frame" requests for
//     `missingDetectionIdToFrameIds`, skipping ones already in the database.
// The number of generated requests of each kind is logged.
void createVerificationRequests(
    pqxx::transaction_base& txn,
    const VerificationRules& verificationRules,
    const privacy::GeoIdProvider& geoidProvider,
    const DetectionStore& detectionStore,
    const VerificationRequestsIndex& existingRequests,
    const std::vector<MatchedObjects>& matchedObjectsVec,
    const db::IdTo<db::TIdSet>& clusters,
    const db::eye::Objects& objects,
    const db::IdTo<db::TIds>& missingDetectionIdToFrameIds
)
{
    // Shared lookups used by both request generators.
    const auto actionByObjectId = makeObjectIdToVerificationActionMap(
        verificationRules, geoidProvider, detectionStore, objects);
    const auto objectByDetectionId = makeDetectionIdToObjectMap(clusters, objects);

    const auto pairMatchRequests = createDetectionPairMatchingRequests(
        txn,
        actionByObjectId,
        objectByDetectionId,
        detectionStore,
        existingRequests,
        matchedObjectsVec);
    INFO() << "Generated " << pairMatchRequests.size() << " requests for matching detection pairs";

    const auto missingnessRequests = createDetectionMissingOnFrameRequests(
        txn,
        detectionStore,
        actionByObjectId,
        objectByDetectionId,
        missingDetectionIdToFrameIds);
    INFO() << "Generated " << missingnessRequests.size()
           << " requests for checking detection missingness";
}

} // namespace maps::mrc::eye
