#include "data_source.h"
#include "category_revisions_loader.h"
#include "async_results_queue.h"
#include <maps/wikimap/mapspro/libs/validator/common/exception.h>
#include <maps/wikimap/mapspro/libs/validator/common/utils.h>
#include <maps/wikimap/mapspro/libs/validator/content/inmemory_objects_collection-inl.h>

#include <yandex/maps/wiki/validator/common.h>
#include <yandex/maps/wiki/validator/result.h>
#include <yandex/maps/wiki/common/batch.h>
#include <yandex/maps/wiki/common/retry_duration.h>
#include <yandex/maps/wiki/revision/filters.h>
#include <yandex/maps/wiki/revision/revisionsgateway.h>
#include <yandex/maps/wiki/threadutils/threadpool.h>
#include <yandex/maps/wiki/configs/editor/config_holder.h>
#include <yandex/maps/wiki/configs/editor/categories.h>

namespace maps {
namespace wiki {
namespace validator {

namespace rf = revision::filters;

namespace {

// Shared empty set that DataSource::selectedObjectIds() returns by reference
// when no object ids are recorded for the requested category.
const ObjectIdSet EMPTY_OBJECT_ID_SET;

// Batch size passed to common::applyBatchOp() both when distributing object
// ids by category and when loading revisions of a category.
const size_t REVISION_LOAD_BATCH_SIZE = 10000;

// Everything loadBatch() fetches for one batch of revision ids.
// Field order matters: loadBatch() fills the struct with positional
// brace-initialization.
struct LoadedBatch
{
    // Revisions of the batch, each paired with its master and slave relations.
    std::vector<ObjectDatum> objectData;
    // Commit id per object id, covering the batch objects and the objects
    // they are related to.
    ObjectIdToCommitId objectCommitIds;
    // Ids of related objects that matched the "deleted, not a relation"
    // revision filter.
    ObjectIdSet deletedRelatedObjectIds;
};

// Wraps every revision into an ObjectDatum together with the master and slave
// relations the gateway reports for the revision's object.
//
// txn        - transaction used for both relation queries
// dbGateway  - gateway queried through mastersOf() and slavesOf()
// revisions  - revisions to wrap
//
// Returns one ObjectDatum per input revision, in the input order.
std::vector<ObjectDatum> loadAllRelations(
        pqxx::transaction_base& txn,
        const DBGateway& dbGateway,
        const Revisions& revisions)
{
    // The final size is known up front, so allocate once instead of growing
    // the vectors element by element.
    std::vector<TId> objectIds;
    objectIds.reserve(revisions.size());
    for (const auto& revision : revisions) {
        objectIds.push_back(revision.id().objectId());
    }

    auto mastersMap = dbGateway.mastersOf(objectIds, txn);
    auto slavesMap = dbGateway.slavesOf(objectIds, txn);

    std::vector<ObjectDatum> ret;
    ret.reserve(revisions.size());
    for (const auto& revision : revisions) {
        const TId oid = revision.id().objectId();
        // Relation lists are copied rather than moved out of the maps: it is
        // not visible here whether object ids in `revisions` are unique.
        // NOTE(review): operator[] presumably default-constructs an empty
        // relation list for ids the gateway returned nothing for - confirm.
        ret.push_back(ObjectDatum{revision, mastersMap[oid], slavesMap[oid]});
    }

    return ret;
}

// Result of loadObjectCommitIds().
// Field order matters: the function returns the struct via positional
// brace-initialization.
struct CommitIdInfo {
    // Commit id per object id, for the inspected objects and for every object
    // referenced by their master/slave relations.
    ObjectIdToCommitId objectIdToCommitId;
    // Ids of referenced objects that matched the "deleted, not a relation"
    // revision filter.
    ObjectIdSet deletedRelatedObjectIds;
};

// Gathers commit ids for the given objects and for all objects they reference
// through master/slave relations, and finds which of the referenced objects
// match the "deleted, not a relation" revision filter.
//
// txn        - transaction used for both gateway queries
// dbGateway  - gateway queried through revisionIdsByObjectIds()
// objectData - objects with their already loaded relations
CommitIdInfo loadObjectCommitIds(
        pqxx::transaction_base& txn,
        const DBGateway& dbGateway,
        const std::vector<ObjectDatum>& objectData)
{
    CommitIdInfo info;

    // Ids found on the other end of any relation of the inspected objects.
    ObjectIdSet neighbourIds;
    const auto collectOthers = [&neighbourIds](const auto& relations) {
        for (const Relation& relation : relations) {
            neighbourIds.insert(relation.other);
        }
    };

    for (const auto& datum : objectData) {
        const RevisionID& ownRevisionId = datum.revision.id();
        info.objectIdToCommitId[ownRevisionId.objectId()] = ownRevisionId.commitId();

        collectOthers(datum.masters);
        collectOthers(datum.slaves);
    }

    // Entries for referenced objects are written after the objects' own
    // entries, so they win when an id occurs in both groups.
    const std::vector<RevisionID> neighbourRevisionIds =
        dbGateway.revisionIdsByObjectIds(neighbourIds, txn);
    for (const RevisionID& neighbourRevisionId : neighbourRevisionIds) {
        info.objectIdToCommitId[neighbourRevisionId.objectId()] =
            neighbourRevisionId.commitId();
    }

    // Second query: only the referenced objects matching the filter.
    const RevisionIds deletedRevisionIds = dbGateway.revisionIdsByObjectIds(
        neighbourIds,
        rf::ObjRevAttr::isDeleted() && rf::ObjRevAttr::isNotRelation(),
        txn
    );
    for (const auto& deletedRevisionId : deletedRevisionIds) {
        info.deletedRelatedObjectIds.insert(deletedRevisionId.objectId());
    }

    return info;
}

// Loads one batch: the revisions for the given ids, their relations, and the
// commit id information for the objects and the objects related to them.
// The whole attempt, including obtaining the transaction, is executed through
// common::retryDuration().
//
// revisionIds - ids of the revisions to fetch
// hasGeom     - forwarded to DBGateway::revisionsByRevisionIds()
// dataSource  - provides the transaction and the gateway
LoadedBatch loadBatch(
        const RevisionIds& revisionIds,
        bool hasGeom,
        DataSource& dataSource)
{
    auto attempt = [&] {
        auto txn = dataSource.getTransaction();
        const DBGateway& gateway = dataSource.dbGateway();

        LoadedBatch batch;
        batch.objectData = loadAllRelations(
            *txn,
            gateway,
            gateway.revisionsByRevisionIds(revisionIds, hasGeom, *txn));

        auto commitInfo = loadObjectCommitIds(*txn, gateway, batch.objectData);
        batch.objectCommitIds = std::move(commitInfo.objectIdToCommitId);
        batch.deletedRelatedObjectIds = std::move(commitInfo.deletedRelatedObjectIds);
        return batch;
    };

    return common::retryDuration(attempt);
}

// Splits the given object ids by the category of their revisions.
//
// The ids are processed in batches of REVISION_LOAD_BATCH_SIZE. For each
// batch the revisions are fetched twice (with and without geometry) inside a
// single retried transaction, and every fetched revision contributes its
// object id to the set of its category.
//
// Returns an empty map when the gateway has a zero commit id or when
// objectIds is empty.
CategoryIdToObjectIds distributeObjectIdsByCategory(const ObjectIdSet& objectIds, DBGateway& gateway)
{
    // The zero commit id check protects the tests, which run without a commit.
    if (!gateway.commitId() || objectIds.empty()) {
        return {};
    }

    INFO() << "Start distributing object ids by category";

    CategoryIdToObjectIds idsByCategory;

    const auto processBatch = [&](const ObjectIdSet& batchIds) {
        auto batchRevisions = common::retryDuration([&] {
            auto txn = gateway.getTransaction();
            auto revIds = gateway.revisionIdsByObjectIds(batchIds, *txn);

            // Objects with geometry first, then objects without it.
            auto withGeom = gateway.revisionsByRevisionIds(revIds, true, *txn);
            auto withoutGeom = gateway.revisionsByRevisionIds(revIds, false, *txn);
            withGeom.splice(withGeom.end(), std::move(withoutGeom));
            return withGeom;
        });

        for (const auto& rev : batchRevisions) {
            // NOTE(review): attributes is dereferenced unconditionally -
            // presumably always set for revisions returned here; confirm.
            auto categoryId = extractCategoryId(*rev.data().attributes);
            idsByCategory[categoryId].insert(rev.id().objectId());
        }
    };

    common::applyBatchOp<ObjectIdSet>(
        objectIds,
        REVISION_LOAD_BATCH_SIZE,
        processBatch);

    INFO() << "Finish distributing object ids by category";

    for (const auto& [categoryId, categoryObjectIds] : idsByCategory) {
        INFO() << "Selected " << categoryObjectIds.size() << " objects of category " << categoryId;
    }

    return idsByCategory;
}

} // namespace

// Creates a data source bound to the given branch and commit.
//
// Construction may hit the database: distributeObjectIdsByCategory() queries
// revisions for `objectIds` unless the set is empty or the gateway reports a
// zero commit id.
//
// NOTE(review): several initializers read other members (dbGateway_, aoi_,
// collectionsTuple_). C++ initializes members in class declaration order, not
// in the order written below, so this relies on the header declaring those
// members earlier - confirm against data_source.h.
DataSource::DataSource(
        const ValidatorConfig& validatorConfig,
        CheckCardinality checkCardinality,
        pgpool3::Pool& pgPool,
        DBID branchId,
        DBID commitId,
        AreaOfInterest aoi,
        const ObjectIdSet& objectIds)
    : validatorConfig_(validatorConfig)
    , checkCardinality_(checkCardinality)
    , dbGateway_(pgPool, branchId, commitId)
    // `aoi` is taken by value and moved into the member; the initializers
    // below use the member aoi_, not the moved-from parameter.
    , aoi_(std::move(aoi))
    , selectedObjectIdsByCategory_(distributeObjectIdsByCategory(objectIds, dbGateway_))
    , importantRegions_(dbGateway_, aoi_)
    // collectionsTuple_ has no initializer in this list, so it is
    // default-initialized (or uses an in-class initializer, if the header
    // provides one); collections_ is then built from it.
    , collections_(collectionsTupleToCollectionsMap(collectionsTuple_))
{
}

// Loads the objects of `category` into the category's collection.
//
// Flow:
//   1. The revision ids are produced by the loader configured for
//      (category, loaderType); the loader runs as a task on dbReadWorkers.
//   2. The ids are split into batches of REVISION_LOAD_BATCH_SIZE and every
//      batch is loaded by loadBatch() as a separate task on dbReadWorkers.
//   3. Loaded batches are added to the collection on the calling thread, then
//      the collection is finalized.
//
// category      - category to load; looked up with collections_.at()
// loaderType    - selects the revision ids loader from the validator config
// dbReadWorkers - pool executing the database reads
//
// Returns the messages that addRevisions() appended while the batches were
// being added to the collection.
std::list<Message>
DataSource::load(const TCategoryId& category, LoaderType loaderType, ThreadPool& dbReadWorkers)
{
    Timer timer;
    INFO() << "loading category " << category << "...";

    // A single task; its result is taken from the queue right below.
    // NOTE(review): popSome() presumably blocks until a result is available -
    // confirm in async_results_queue.h.
    AsyncResultsQueue<RevisionIds> revIdsQueue;
    dbReadWorkers.push(
        revIdsQueue.task(
            [&, this]() {
                const auto& revIdsLoader = validatorConfig_.categoryRevisionIdsLoader(category, loaderType);
                return revIdsLoader(*this);
            }));
    RevisionIds revisionIds = revIdsQueue.popSome().front().value();

    INFO() << revisionIds.size() << " revision ids for category " << category
           << " loaded in " << timer.elapsed().count() << "s.";

    ObjectsCollectionBase& collection = collections_.at(category);
    std::list<Message> baseCheckMessages;

    // NOTE(review): the batch tasks capture `category` by reference and
    // report into the local loadedQueue. This is safe while this function
    // waits for all of them; if result.value() below can throw, tasks might
    // still be pending when the locals are destroyed - confirm how
    // AsyncResultsQueue handles that.
    AsyncResultsQueue<LoadedBatch> loadedQueue;
    int batches = 0;
    common::applyBatchOp<RevisionIds>(
        revisionIds,
        REVISION_LOAD_BATCH_SIZE,
        [&](const RevisionIds& revisionIdsBatch) {
            dbReadWorkers.push(
                loadedQueue.task(
                    // The batch is captured by value: the task runs after
                    // this callback has returned.
                    [&, this, revisionIdsBatch]() {
                        return loadBatch(
                            revisionIdsBatch,
                            categoryGeomType(category) != GeomType::None,
                            *this);
                    }));
            ++batches;
        });

    // Every task owns a copy of its batch, so the full id list is no longer
    // needed; swapping with an empty temporary releases its memory now.
    RevisionIds().swap(revisionIds);

    // The editor category config is passed on only when cardinality checks
    // are enabled; otherwise addRevisions() receives nullptr.
    const configs::editor::Category* configCategory =
        checkCardinality_ == CheckCardinality::Yes
        ? &validatorConfig_.editorConfig().categories()[category]
        : nullptr;

    // Consume results as they arrive until every scheduled batch is added.
    while (batches > 0) {
        for (auto& result: loadedQueue.popSome()) {
            collection.addRevisions(
                std::move(result.value().objectData),
                std::move(result.value().objectCommitIds),
                std::move(result.value().deletedRelatedObjectIds),
                aoi_,
                configCategory,
                importantRegions_,
                &baseCheckMessages);
            --batches;
        }
    }
    ASSERT(batches == 0);

    auto size = collection.finalize();

    INFO() << size << " objects of"
           << " category " << category
           << " loaded in " << timer.elapsed().count() << "s.";

    return baseCheckMessages;
}

void
DataSource::unload(const TCategoryId& category)
{
    auto collection = collections_.find(category);
    REQUIRE(collection != collections_.end(),
            "collection for category: " << category << " not found");
    ObjectsCollectionBase& collectionBase = collection->second;
    collectionBase.unload();
}

// Gives read-only access to the collection of the given category.
// Fails the REQUIRE check when no collection is registered for the category.
const ObjectsCollectionBase&
DataSource::collection(const TCategoryId& category) const
{
    const auto found = collections_.find(category);
    REQUIRE(found != collections_.end(),
            "collection for category: " << category << " not found");

    const ObjectsCollectionBase& requested = found->second;
    return requested;
}

const std::set<TCategoryId>&
DataSource::dependencies(const TCategoryId& category, LoaderType loaderType) const
{
    return validatorConfig_.loaderDependencies(category, loaderType);
}

// Looks the object up in the collections of the candidate categories, in the
// given order, and returns the first revision id found.
// Categories without a registered collection are skipped.
// Throws maps::RuntimeError when none of the collections knows the object.
RevisionID DataSource::objectRevisionId(
        TId objectId, const std::vector<TCategoryId>& possibleCategoryIds) const
{
    for (const auto& categoryId : possibleCategoryIds) {
        const auto found = collections_.find(categoryId);
        if (found == collections_.end()) {
            continue;
        }

        const ObjectsCollectionBase& candidates = found->second;
        if (auto revId = candidates.revisionId(objectId)) {
            return *revId;
        }
    }

    throw maps::RuntimeError() << "Object id: " << objectId << " not found";
}

// Returns the object ids selected for the given category, or a reference to
// the shared empty set when nothing is recorded for it.
const ObjectIdSet& DataSource::selectedObjectIds(const TCategoryId& categoryId) const
{
    const auto found = selectedObjectIdsByCategory_.find(categoryId);
    if (found == selectedObjectIdsByCategory_.end()) {
        return EMPTY_OBJECT_ID_SET;
    }
    return found->second;
}

} // namespace validator
} // namespace wiki
} // namespace maps
