#include <yandex/maps/wiki/diffalert/revision/aoi_diff_loader.h>
#include "aoi_diff_context_impl.h"
#include "extractors.h"
#include "helpers.h"
#include "split_finder.h"

#include <maps/libs/geolib/include/common.h>
#include <yandex/maps/wiki/common/batch.h>

namespace maps {
namespace wiki {
namespace diffalert {

namespace rev = revision;
namespace rf = rev::filters;

/// Implementation part of AoiDiffLoader.
///
/// Holds everything needed to load the objects of requested categories that
/// lie inside an area of interest (AOI) in two snapshots and to compute the
/// per-object diffs between them.
///
/// Snapshots and config are stored by value; the transaction is stored by
/// reference and therefore must outlive this object.
class AoiDiffLoader::Impl
{
public:
    /// @param config       editor configuration (categories, roles, ...).
    /// @param aoi          area of interest; asserted to be non-null.
    /// @param fromSnapshot snapshot providing the "old" object versions.
    /// @param toSnapshot   snapshot providing the "new" object versions.
    /// @param fromBranchId presumably the branch of fromSnapshot (only used for split checks).
    /// @param toBranchId   presumably the branch of toSnapshot (only used for split checks).
    /// @param txn          transaction used by the split finder; kept by reference.
    Impl(
            EditorConfig config,
            Geom aoi,
            const revision::Snapshot& fromSnapshot,
            const revision::Snapshot& toSnapshot,
            revision::DBID fromBranchId,
            revision::DBID toBranchId,
            pqxx::transaction_base& txn)
        : config_(std::move(config))
        , aoi_(std::move(aoi))
        , fromSnapshot_(fromSnapshot)
        , toSnapshot_(toSnapshot)
        , fromBranchId_(fromBranchId)
        , toBranchId_(toBranchId)
        , txn_(txn)
    {
        ASSERT(!aoi_.isNull());
    }

    const EditorConfig& config() const { return config_; }

    /// Ratio applied to categories that have no explicit override.
    double defaultMinLinearObjectIntersectionRatio() const
    {
        return defaultMinLinearObjectIntersectionRatio_;
    }

    void setDefaultMinLinearObjectIntersectionRatio(double ratio)
    {
        defaultMinLinearObjectIntersectionRatio_ = ratio;
    }

    /// Minimal share of a linear object's length that must lie inside the
    /// AOI for the object to be loaded; falls back to the default ratio.
    double minLinearObjectIntersectionRatio(const std::string& categoryId) const
    {
        const auto custom = minLinearObjectIntersectionRatio_.find(categoryId);
        if (custom == minLinearObjectIntersectionRatio_.end()) {
            return defaultMinLinearObjectIntersectionRatio_;
        }
        return custom->second;
    }

    void setMinLinearObjectIntersectionRatio(const std::string& categoryId, double ratio)
    {
        minLinearObjectIntersectionRatio_.insert_or_assign(categoryId, ratio);
    }

    /// Loads and compares objects of `categoryIds` between the two snapshots.
    std::map<revision::DBID, AoiDiffData> loadDiffData(
            const std::set<std::string>& categoryIds,
            SplitPolicy splitPolicy) const;

private:
    enum { BATCH_SIZE = 1000 };
    enum SnapshotTime { From, To };

    std::map<revision::DBID, FullyLoadedObjectData> loadObjects(
            SnapshotTime snapshotTime,
            const std::set<std::string>& categoryIds,
            const std::map<rev::DBID, AoiDiffData>& geomPartsDiffs) const;

    /// Length of the part of `lineString` that lies inside the AOI.
    double intersectionLength(const Geom& lineString) const
    {
        const Geom clipped(aoi_->intersection(lineString.geosGeometryPtr()));
        return clipped->getLength();
    }

    void removeComplexObjectsOutside(
            std::map<rev::DBID, FullyLoadedObjectData>& objects,
            SnapshotTime snapshotTime,
            const std::map<revision::DBID, AoiDiffData>& geomPartsDiffs) const;

    EditorConfig config_;
    Geom aoi_;
    revision::Snapshot fromSnapshot_;
    revision::Snapshot toSnapshot_;
    revision::DBID fromBranchId_;
    revision::DBID toBranchId_;
    pqxx::transaction_base& txn_;
    double defaultMinLinearObjectIntersectionRatio_ = 1.0 - geolib3::EPS;
    std::map<std::string, double> minLinearObjectIntersectionRatio_;
};

namespace {

/// Fills the added/deleted relation and table-attribute sets of `diff` by
/// comparing the old and the new object versions (a missing version counts
/// as having no relations).
///
/// When `geomPartsDiffs` is non-empty, added/deleted relations in which this
/// object is the master, the role is a geom-part role of its category, and
/// the slave is absent from `geomPartsDiffs` (i.e. the part is outside the
/// AOI) are dropped from the result.
void calcRelationDiffs(
    AoiDiffData& diff,
    const std::map<rev::DBID, AoiDiffData>& geomPartsDiffs,
    const EditorConfig& config)
{
    const Relations noRelations;
    const auto& oldObj = diff.oldObject;
    const auto& newObj = diff.newObject;

    // Puts elements only in `before` into `removed`, elements only in
    // `after` into `added`. Inputs are sorted sets, as set_difference requires.
    auto diffSets = [](const auto& before, const auto& after, auto& removed, auto& added) {
        std::set_difference(
                before.begin(), before.end(),
                after.begin(), after.end(),
                std::inserter(removed, removed.end()));
        std::set_difference(
                after.begin(), after.end(),
                before.begin(), before.end(),
                std::inserter(added, added.end()));
    };

    diffSets(
        oldObj ? oldObj->masterRelations : noRelations,
        newObj ? newObj->masterRelations : noRelations,
        diff.relationsDeleted, diff.relationsAdded);
    diffSets(
        oldObj ? oldObj->slaveRelations : noRelations,
        newObj ? newObj->slaveRelations : noRelations,
        diff.relationsDeleted, diff.relationsAdded);
    diffSets(
        oldObj ? oldObj->tableAttrs : noRelations,
        newObj ? newObj->tableAttrs : noRelations,
        diff.tableAttrsDeleted, diff.tableAttrsAdded);

    // Without geom parts there is nothing outside the AOI to filter out.
    if (geomPartsDiffs.empty()) {
        return;
    }

    const auto& categoryId = oldObj ? oldObj->categoryId : newObj->categoryId;
    const auto& geomPartRoles = config.geomPartRoles(categoryId);

    auto dropOutsideGeomParts = [&](Relations& relations) {
        auto it = relations.begin();
        while (it != relations.end()) {
            const bool partOutsideAoi = it->masterId == diff.objectId
                && geomPartRoles.count(it->role)
                && !geomPartsDiffs.count(it->slaveId);
            if (partOutsideAoi) {
                it = relations.erase(it);
            } else {
                ++it;
            }
        }
    };

    dropOutsideGeomParts(diff.relationsAdded);
    dropOutsideGeomParts(diff.relationsDeleted);
}

constexpr double GEOMETRY_COMPARE_TOLERANCE = 0.001; // mercator meters

/// Tells whether the geometry of the object described by `diff` changed.
///
/// - An object missing in one of the snapshots always counts as changed.
/// - If both versions carry their own geometry, they are compared with
///   GEOMETRY_COMPARE_TOLERANCE, except for objects with a split status,
///   which are reported as unchanged.
/// - Otherwise the geometry is composed of geom parts, and it is changed when
///   a geom-part relation to a part inside the AOI was added or deleted, or
///   when one of the old version's geom parts has `geomChanged` set.
///
/// Expects calcRelationDiffs() to have been applied to `diff` already.
bool isGeomChanged(
        const AoiDiffData& diff,
        const std::map<rev::DBID, AoiDiffData>& geomPartsDiffs,
        const EditorConfig& config)
{
    const auto& before = diff.oldObject;
    const auto& after = diff.newObject;

    if (!(before && after)) {
        return true;
    }

    const bool bothHaveOwnGeom = !before->geom.isNull() && !after->geom.isNull();
    if (bothHaveOwnGeom) {
        return diff.splitStatus == SplitStatus::None
            && !before->geom.equal(after->geom, GEOMETRY_COMPARE_TOLERANCE);
    }

    const auto& roles = config.geomPartRoles(before->categoryId);

    // True for a relation "this object -> geom part located inside the AOI".
    auto linksPartInsideAoi = [&](const auto& rel) {
        return rel.masterId == diff.objectId
            && roles.count(rel.role)
            && geomPartsDiffs.count(rel.slaveId);
    };

    for (const auto& rel: diff.relationsAdded) {
        if (linksPartInsideAoi(rel)) {
            return true;
        }
    }
    for (const auto& rel: diff.relationsDeleted) {
        if (linksPartInsideAoi(rel)) {
            return true;
        }
    }

    // The set of geom parts inside the AOI is the same in both versions;
    // check whether any of those parts changed its own geometry.
    for (const auto& rel: before->slaveRelations) {
        if (!roles.count(rel.role)) {
            continue;
        }
        const auto partIt = geomPartsDiffs.find(rel.slaveId);
        if (partIt != geomPartsDiffs.end() && partIt->second.geomChanged) {
            return true;
        }
    }

    return false;
}

} // namespace

std::map<revision::DBID, AoiDiffData> AoiDiffLoader::Impl::loadDiffData(
        const std::set<std::string>& categoryIds,
        SplitPolicy splitPolicy) const
{
    if (categoryIds.empty()) {
        return {};
    }

    std::set<std::string> geomPartCats;
    for (const auto& cat: categoryIds) {
        auto geomParts = config_.geomPartCategories(cat);
        if (!geomParts.empty()) {
            geomPartCats.insert(geomParts.begin(), geomParts.end());
        }
    }
    auto geomPartsDiffs = loadDiffData(geomPartCats, splitPolicy);

    auto fromObjects = loadObjects(SnapshotTime::From, categoryIds, geomPartsDiffs);
    auto toObjects = loadObjects(SnapshotTime::To, categoryIds, geomPartsDiffs);

    std::map<revision::DBID, AoiDiffData> result;
    for (auto& pair: fromObjects) {
        result[pair.first].oldObject = std::move(pair.second);
        result[pair.first].objectId = pair.first;
    }
    for (auto& pair: toObjects) {
        result[pair.first].newObject = std::move(pair.second);
        result[pair.first].objectId = pair.first;
    }

    if (splitPolicy == SplitPolicy::Check) {
        ASSERT(fromBranchId_ == toBranchId_);
        SplitFinder finder(txn_, fromBranchId_, fromSnapshot_.maxCommitId(), toSnapshot_.maxCommitId(), result);
        finder.run();
    }

    for (auto& pair: result) {
        auto& diff = pair.second;

        calcRelationDiffs(diff, geomPartsDiffs, config_);

        diff.categoryChanged =
            (diff.oldObject && diff.newObject)
            ? (diff.oldObject->categoryId != diff.newObject->categoryId)
            : true;
        diff.attrsChanged =
            (diff.oldObject && diff.newObject)
            ? !areAttributesEqual(
                    diff.oldObject->attrs, diff.newObject->attrs,
                    diff.oldObject->categoryId, diff.newObject->categoryId)
            : true;
        diff.geomChanged = isGeomChanged(diff, geomPartsDiffs, config_);
    }

    return result;
}

/// Loads the objects of `categoryIds` that lie inside the AOI in the snapshot
/// selected by `snapshotTime`, keyed by object id.
///
/// Candidates come from two sources:
///  - categories without geom-part categories ("simple" objects) are found by
///    a filter on the AOI bounding box and then checked against the exact AOI
///    geometry;
///  - masters of the geom parts listed in `geomPartsDiffs` (through the part
///    category's geom-part master roles) are loaded by id.
/// Then the slave relations are attached (split into `tableAttrs` and
/// `slaveRelations` by role). Complex objects outside the AOI are removed,
/// and finally the master relations of the remaining objects are attached.
///
/// @param snapshotTime   selects fromSnapshot_ or toSnapshot_.
/// @param categoryIds    categories to load; revisions of other categories are skipped.
/// @param geomPartsDiffs diffs of geom-part objects inside the AOI.
/// A REQUIRE check fails if a loaded relation references an object that was not loaded.
std::map<rev::DBID, FullyLoadedObjectData> AoiDiffLoader::Impl::loadObjects(
        SnapshotTime snapshotTime,
        const std::set<std::string>& categoryIds,
        const std::map<rev::DBID, AoiDiffData>& geomPartsDiffs) const
{
    const auto& snapshot = snapshotTime == SnapshotTime::From ? fromSnapshot_ : toSnapshot_;

    std::map<rev::DBID, FullyLoadedObjectData> objects;

    // Adds the object of `rev` to `objects` if it is alive, belongs to one of
    // `categoryIds` and (when it has geometry) lies inside the AOI.
    // An already loaded object is not replaced (emplace).
    auto loadObjectFromRevision = [&](const auto& rev) {
        if (rev.data().deleted) {
            return;
        }
        ASSERT(rev.data().attributes);

        auto categoryId = extractCategoryId(*rev.data().attributes);
        if (!categoryIds.count(categoryId)) {
            return;
        }

        // Revisions without geometry skip the AOI check here; complex
        // objects are filtered later by removeComplexObjectsOutside().
        Geom geom;
        if (rev.data().geometry) {
            geom = Geom(*rev.data().geometry);
            if (isLinear(geom)) {
                // Linear objects: enough of their length must be inside the AOI.
                if (intersectionLength(geom) < minLinearObjectIntersectionRatio(categoryId) * geom->getLength()) {
                    return;
                }
            } else if (!aoi_->covers(geom.geosGeometryPtr())) {
                // Other geometries must be fully covered by the AOI.
                return;
            }
        }
        auto oid = rev.id().objectId();
        auto isFaceElement = config_.isFaceElementCategory(categoryId);
        auto isFaceJunction = config_.isFaceJunctionCategory(categoryId);
        objects.emplace(
            oid,
            FullyLoadedObjectData{
                oid,
                std::move(categoryId),
                *rev.data().attributes,
                std::move(geom),
                {}, {}, {}, // presumably the relation sets, filled in below — verify against the struct
                isFaceElement,
                isFaceJunction});
    };

    // Simple categories: select revisions by "cat:<id>" attribute and AOI bbox.
    std::set<std::string> simpleCatAttrs;
    for (const auto& catId: categoryIds) {
        if (config_.geomPartCategories(catId).empty()) {
            simpleCatAttrs.insert("cat:" + catId);
        }
    }
    if (!simpleCatAttrs.empty()) {
        const auto* aoiBbox = aoi_->getEnvelopeInternal();
        ASSERT(aoiBbox);
        auto filter = rf::Attr::definedAny(simpleCatAttrs)
            && rf::ObjRevAttr::isNotDeleted()
            && rf::Geom::intersects(
                    aoiBbox->getMinX(), aoiBbox->getMinY(), aoiBbox->getMaxX(), aoiBbox->getMaxY());

        auto revisionIds = snapshot.revisionIdsByFilter(filter);
        common::applyBatchOp<decltype(revisionIds)>(
            revisionIds,
            BATCH_SIZE,
            [&](const auto& ids) {
                for (const auto& rev : snapshot.reader().loadRevisions(ids, rf::Geom::defined())) {
                    loadObjectFromRevision(rev);
                }
            }
        );
    }

    // Complex objects: masters of the geom parts known to be inside the AOI
    // in this snapshot.
    {
        std::set<rev::DBID> geomPartMasters;
        for (const auto& [_, diff]: geomPartsDiffs) {
            const auto& obj = snapshotTime == SnapshotTime::From ? diff.oldObject : diff.newObject;
            if (obj) {
                for (const auto& rel: obj->masterRelations) {
                    if (config_.geomPartMasterRoles(obj->categoryId).count(rel.role)) {
                        geomPartMasters.insert(rel.masterId);
                    }
                }
            }
        }
        if (!geomPartMasters.empty()) {
            common::applyBatchOp<std::set<rev::DBID>>(
                geomPartMasters,
                BATCH_SIZE,
                [&](const auto& ids) {
                    auto revisionsById = snapshot.objectRevisions(ids);
                    for (const auto& [key, rev]: revisionsById) {
                        loadObjectFromRevision(rev);
                    }
                }
            );
        }
    }

    // Snapshot of the current keys of `objects`; called again after removal
    // so that later steps only see the surviving objects.
    auto getObjectIds = [&] {
        std::vector<rev::DBID> result;
        result.reserve(objects.size());
        for (const auto& [key, _]: objects) {
            result.emplace_back(key);
        }
        return result;
    };

    // Slave relations must be loaded before removeComplexObjectsOutside(),
    // which inspects `slaveRelations` to find the geom parts.
    common::applyBatchOp<std::vector<rev::DBID>>(
        getObjectIds(),
        BATCH_SIZE,
        [&](const auto& ids) {
            for (const auto& rev: snapshot.loadSlaveRelations(ids)) {
                auto rel = extractRelation(rev);

                auto objectIt = objects.find(rel.masterId);
                REQUIRE(objectIt != objects.end(),
                    "loaded relation id " << rev.id() << " with unexpected master id " << rel.masterId);
                auto& object = objectIt->second;

                // Table-attribute roles are kept apart from regular slave relations.
                if (config_.tableAttrRoles(object.categoryId).count(rel.role)) {
                    object.tableAttrs.insert(std::move(rel));
                } else {
                    object.slaveRelations.insert(std::move(rel));
                }
            }
        }
    );

    removeComplexObjectsOutside(objects, snapshotTime, geomPartsDiffs);

    // Master relations are loaded only for objects that survived the removal.
    common::applyBatchOp<std::vector<rev::DBID>>(
        getObjectIds(),
        BATCH_SIZE,
        [&](const auto& ids) {
            for (const auto& rev: snapshot.loadMasterRelations(ids)) {
                auto rel = extractRelation(rev);

                auto objectIt = objects.find(rel.slaveId);
                REQUIRE(objectIt != objects.end(),
                    "loaded relation id " << rev.id() << " with unexpected slave id " << rel.slaveId);
                auto& object = objectIt->second;

                object.masterRelations.insert(std::move(rel));
            }
        }
    );
    return objects;
}

/// Removes from `objects` the complex objects (objects whose geometry is
/// built from geom-part slaves) that do not lie inside the AOI in the
/// snapshot selected by `snapshotTime`.
///
/// A geom part counts as "inside" when it is present in `geomPartsDiffs`;
/// every other geom part counts as "outside".
///  - Objects with no linear inside parts, and all face objects, are removed
///    as soon as any of their parts is outside.
///  - For the remaining objects with linear inside parts, the geometries of
///    the outside parts are loaded. The object is removed if an outside part
///    has no geometry or a non-linear one. It is also removed if the share of
///    the total parts length inside the AOI is below
///    minLinearObjectIntersectionRatio(category). Face objects are
///    additionally required to be fully inside, within geolib3::EPS.
///
/// Requires `slaveRelations` of `objects` to be loaded already.
///
/// NOTE(review): a part present in `geomPartsDiffs` whose version for this
/// snapshot is missing is neither counted as outside nor included in the
/// lengths — confirm this is intended.
void AoiDiffLoader::Impl::removeComplexObjectsOutside(
        std::map<rev::DBID, FullyLoadedObjectData>& objects,
        SnapshotTime snapshotTime,
        const std::map<revision::DBID, AoiDiffData>& geomPartsDiffs) const
{
    const auto& snapshot = snapshotTime == SnapshotTime::From ? fromSnapshot_ : toSnapshot_;

    // Version of a geom part matching the snapshot being filtered. The
    // condition must compare `snapshotTime`: the bare enumerator
    // `SnapshotTime::From` is 0 and would always select the new version.
    auto partAtSnapshot = [&](const AoiDiffData& partDiff) -> const auto& {
        return snapshotTime == SnapshotTime::From
            ? partDiff.oldObject
            : partDiff.newObject;
    };

    std::set<rev::DBID> objectsWithLinearParts;
    std::set<rev::DBID> partsOutsideToLoad;
    for (auto objectIt = objects.begin(); objectIt != objects.end(); ) {
        const auto& object = objectIt->second;

        bool hasLinearParts = false;
        std::set<rev::DBID> partsOutside;

        for (const auto& rel: object.slaveRelations) {
            if (!config_.geomPartRoles(object.categoryId).count(rel.role)) {
                continue;
            }

            auto partIt = geomPartsDiffs.find(rel.slaveId);
            if (partIt != geomPartsDiffs.end()) {
                const auto& part = partAtSnapshot(partIt->second);
                if (part && isLinear(part->geom)) {
                    hasLinearParts = true;
                }
            } else {
                partsOutside.insert(rel.slaveId);
            }
        }

        if (hasLinearParts) {
            objectsWithLinearParts.insert(object.id);
        }

        if (!hasLinearParts || config_.isFaceCategory(object.categoryId)) {
            // Any outside part disqualifies such objects immediately.
            if (!partsOutside.empty()) {
                objectIt = objects.erase(objectIt);
            } else {
                ++objectIt;
            }
        } else {
            // Linear objects may extend outside; decide after loading the
            // geometries of their outside parts.
            partsOutsideToLoad.insert(partsOutside.begin(), partsOutside.end());
            ++objectIt;
        }
    }

    std::map<rev::DBID, Geom> geomsOfPartsOutside;
    common::applyBatchOp<std::set<rev::DBID>>(
        partsOutsideToLoad,
        BATCH_SIZE,
        [&](const auto& ids) {
            for (const auto& [id, rev]: snapshot.objectRevisions(ids)) {
                if (rev.data().geometry) {
                    geomsOfPartsOutside.emplace(
                        id,
                        Geom(*rev.data().geometry));
                }
            }
        }
    );
    // Release the memory; the ids are no longer needed.
    std::set<rev::DBID>().swap(partsOutsideToLoad);

    for (auto id: objectsWithLinearParts) {
        auto objectIt = objects.find(id);
        if (objectIt == objects.end()) {
            continue; // already removed in the first pass
        }
        const auto& object = objectIt->second;

        bool hasNotLinearPartsOutside = false;
        double totalLength = 0.0;
        double insideLength = 0.0;

        for (const auto& rel: object.slaveRelations) {
            if (!config_.geomPartRoles(object.categoryId).count(rel.role)) {
                continue;
            }

            auto partIt = geomPartsDiffs.find(rel.slaveId);
            if (partIt != geomPartsDiffs.end()) {
                const auto& part = partAtSnapshot(partIt->second);
                if (part && isLinear(part->geom)) {
                    totalLength += part->geom->getLength();
                    insideLength += intersectionLength(part->geom);
                }
            } else {
                auto outsideIt = geomsOfPartsOutside.find(rel.slaveId);
                if (outsideIt == geomsOfPartsOutside.end()) {
                    hasNotLinearPartsOutside = true;
                    break;
                }
                const auto& partGeom = outsideIt->second;
                if (!isLinear(partGeom)) {
                    hasNotLinearPartsOutside = true;
                    break;
                }

                totalLength += partGeom->getLength();
                insideLength += intersectionLength(partGeom);
            }
        }

        if (hasNotLinearPartsOutside
                || (config_.isFaceCategory(object.categoryId)
                    && insideLength < totalLength - geolib3::EPS)
                || insideLength < minLinearObjectIntersectionRatio(object.categoryId) * totalLength) {
            objects.erase(objectIt);
        }
    }
}

// Project macro providing the pimpl boilerplate for AoiDiffLoader
// (presumably move operations and destructor — see its definition).
MOVABLE_PIMPL_DEFINITIONS(AoiDiffLoader)

/// Creates a loader for the area `aoi` comparing `fromSnapshot` against
/// `toSnapshot`. All arguments are forwarded to Impl, which asserts that
/// `aoi` is non-null and keeps `txn` by reference (it must outlive the loader).
AoiDiffLoader::AoiDiffLoader(
        EditorConfig config,
        Geom aoi,
        const revision::Snapshot& fromSnapshot,
        const revision::Snapshot& toSnapshot,
        revision::DBID fromBranchId,
        revision::DBID toBranchId,
        pqxx::transaction_base& txn)
    : impl_(new Impl(
        std::move(config),
        std::move(aoi),
        fromSnapshot,
        toSnapshot,
        fromBranchId,
        toBranchId,
        txn))
{}

std::vector<AoiDiffContext>
AoiDiffLoader::loadDiffContexts(const std::set<std::string>& categoryIds, SplitPolicy splitPolicy) const
{
    auto id2DiffData = impl_->loadDiffData(categoryIds, splitPolicy);
    std::vector<AoiDiffContext> result;
    for (auto& kv: id2DiffData) {
        auto& diffData = kv.second;
        auto diffContext = PImplFactory::create<AoiDiffContext>(std::move(diffData));
        if (diffContext.changed()) {
            result.push_back(std::move(diffContext));
        }
    }
    return result;
}

/// Editor configuration the loader was created with.
const EditorConfig& AoiDiffLoader::config() const
{
    return impl_->config();
}

/// Minimal inside-AOI length ratio used for categories without an override.
double AoiDiffLoader::defaultMinLinearObjectIntersectionRatio() const
{
    return impl_->defaultMinLinearObjectIntersectionRatio();
}

/// Sets the ratio used for categories without an explicit override.
void AoiDiffLoader::setDefaultMinLinearObjectIntersectionRatio(double ratio)
{
    impl_->setDefaultMinLinearObjectIntersectionRatio(ratio);
}

/// Minimal share of a linear object's length that must lie inside the AOI
/// for objects of `categoryId`; falls back to the default ratio.
double AoiDiffLoader::minLinearObjectIntersectionRatio(const std::string& categoryId) const
{
    return impl_->minLinearObjectIntersectionRatio(categoryId);
}

/// Overrides the minimal inside-AOI length ratio for `categoryId`.
void AoiDiffLoader::setMinLinearObjectIntersectionRatio(const std::string& categoryId, double ratio)
{
    impl_->setMinLinearObjectIntersectionRatio(categoryId, ratio);
}

} // namespace diffalert
} // namespace wiki
} // namespace maps
