#include "filter_helpers.h"
#include "sql_strings.h"

#include <maps/libs/common/include/exception.h>

#include <algorithm>
#include <memory>
#include <stack>

namespace maps::wiki::revision::filters {

namespace {

/// Qualified column name "<object revision alias>.<master object id column>",
/// compared against TableAttrFilterExpr::name() below.
const std::string STR_OBJECT_REVISION_MASTER_OBJECT_ID =
    sql::alias::OBJECT_REVISION + "." + sql::col::MASTER_OBJECT_ID;

/// Qualified column name "<object revision alias>.<slave object id column>",
/// compared against TableAttrFilterExpr::name() below.
const std::string STR_OBJECT_REVISION_SLAVE_OBJECT_ID =
    sql::alias::OBJECT_REVISION + "." + sql::col::SLAVE_OBJECT_ID;

template <class TargetFilterClass>
TargetFilterClass tryCastFilter(const FilterExpr* expr)
{
    auto res = dynamic_cast<TargetFilterClass>(expr);
    if (res) {
        return res;
    }

    auto proxyRefPtr = dynamic_cast<const ProxyReferenceFilterExpr*>(expr);
    if (proxyRefPtr) {
        return tryCastFilter<TargetFilterClass>(&(proxyRefPtr->filter()));
    }

    auto proxyPtr = dynamic_cast<const ProxyFilterExpr*>(expr);
    if (proxyPtr) {
        return tryCastFilter<TargetFilterClass>(proxyPtr->filter().get());
    }

    return nullptr;
}

/// @brief Shared partition selection for a filter on an object id column.
///
/// The master and slave variants apply the same rules and differ only in the
/// partition set returned when the filter can match non-zero ids only.
///
/// @param op operation of the attribute filter
/// @param ids values the column is compared with; comparison operations
///        (Equal, NotEqual, GreaterEqual, Less, LessEqual) read only ids.at(0)
/// @param nonZeroPartitions partitions to return when the filter cannot
///        match a zero id
/// @return nonZeroPartitions, db::PARTITIONS_WITHOUT_SLAVES (the filter cannot
///         match a positive id) or db::ALL_PARTITIONS (no narrowing possible)
/// @throws RuntimeError for IsNull / NotNull
/// @throws std::out_of_range (from at(0)) when a comparison operation is
///         given an empty \p ids
db::Partitions evalConcernedTablesForObjectId(
    TableAttrFilterExpr::Operation op,
    const std::vector<DBID>& ids,
    const db::Partitions& nonZeroPartitions)
{
    using Operation = TableAttrFilterExpr::Operation;

    switch (op) {
    case Operation::In:
        // NB: std::all_of is true for an empty range, so an empty list
        // falls into the first branch.
        if (std::all_of(ids.begin(), ids.end(),
                        [](DBID id) { return id > 0; }))
        {
            return nonZeroPartitions;
        }
        if (std::all_of(ids.begin(), ids.end(),
                        [](DBID id) { return id == 0; }))
        {
            return db::PARTITIONS_WITHOUT_SLAVES;
        }
        // Mix of zero and non-zero ids: nothing can be excluded.
        return db::ALL_PARTITIONS;

    case Operation::NotNull:
    case Operation::IsNull:
        throw RuntimeError("Comparison with NULL is prohibited");

    case Operation::IsZero:
        return db::PARTITIONS_WITHOUT_SLAVES;

    case Operation::NotZero:
    case Operation::Greater:
        // Greater is treated like NotZero regardless of the compared value.
        return nonZeroPartitions;

    case Operation::Equal:
        if (ids.at(0) == 0) {
            return db::PARTITIONS_WITHOUT_SLAVES;
        }
        return nonZeroPartitions;

    case Operation::NotEqual:
        // "!= 0" is the same as NotZero; "!= v" matches both zero and
        // other non-zero ids.
        if (ids.at(0) == 0) {
            return nonZeroPartitions;
        }
        return db::ALL_PARTITIONS;

    case Operation::GreaterEqual:
        if (ids.at(0) > 0) {
            return nonZeroPartitions;
        }
        return db::ALL_PARTITIONS;

    case Operation::Less:
        // "id < v" with v <= 1 cannot match a positive id.
        if (ids.at(0) <= 1) {
            return db::PARTITIONS_WITHOUT_SLAVES;
        }
        return db::ALL_PARTITIONS;

    case Operation::LessEqual:
        // "id <= v" with v <= 0 cannot match a positive id.
        if (ids.at(0) <= 0) {
            return db::PARTITIONS_WITHOUT_SLAVES;
        }
        return db::ALL_PARTITIONS;

    default:
        return db::ALL_PARTITIONS;
    }
}

/// @brief Partitions concerned by a filter on the master object id column.
///
/// Filters that can match only non-zero ids are narrowed to
/// db::MASTER_POSSESSORS; see evalConcernedTablesForObjectId for the rules.
///
/// @param op operation of the attribute filter
/// @param masterIds values the column is compared with
/// @throws RuntimeError for IsNull / NotNull
db::Partitions evalConcernedTablesForMasterObjectId(
    TableAttrFilterExpr::Operation op,
    const std::vector<DBID>& masterIds)
{
    return evalConcernedTablesForObjectId(op, masterIds, db::MASTER_POSSESSORS);
}

/// @brief Partitions concerned by a filter on the slave object id column.
///
/// Filters that can match only non-zero ids are narrowed to the single
/// db::Partition::Relation partition; see evalConcernedTablesForObjectId
/// for the rules.
///
/// @param op operation of the attribute filter
/// @param slaveIds values the column is compared with
/// @throws RuntimeError for IsNull / NotNull
db::Partitions evalConcernedTablesForSlaveObjectId(
    TableAttrFilterExpr::Operation op,
    const std::vector<DBID>& slaveIds)
{
    return evalConcernedTablesForObjectId(
        op, slaveIds, db::Partitions{db::Partition::Relation});
}


// Defined below; needed here because a negation may wrap an arbitrary
// (possibly binary) expression.
db::Partitions evalConcernedTables(const FilterExpr* expr);

/// @brief Evaluates the partitions concerned by a non-binary expression.
///
/// - attribute filter on the master / slave object id column: delegated to
///   the dedicated evaluators; any other attribute filter concerns all
///   partitions;
/// - False: no partitions;
/// - geometry filter: only db::Partition::RegularWithGeometry;
/// - negation: db::ALL_PARTITIONS minus the partitions of the negated
///   expression, unless the latter already has as many elements as
///   db::ALL_PARTITIONS, in which case it is returned unchanged;
/// - anything else: all partitions.
///
/// @param expr expression to evaluate (proxies are looked through)
db::Partitions evalConcernedTablesValue(const FilterExpr* expr)
{
    if (auto attrFilter = tryCastFilter<const TableAttrFilterExpr*>(expr)) {
        if (attrFilter->name() == STR_OBJECT_REVISION_MASTER_OBJECT_ID) {
            return evalConcernedTablesForMasterObjectId(
                attrFilter->op(), attrFilter->values());
        }
        if (attrFilter->name() == STR_OBJECT_REVISION_SLAVE_OBJECT_ID) {
            return evalConcernedTablesForSlaveObjectId(
                attrFilter->op(), attrFilter->values());
        }
        return db::ALL_PARTITIONS;
    }

    if (tryCastFilter<const False*>(expr)) {
        return {};
    }
    if (tryCastFilter<const GeomFilterExpr*>(expr)) {
        return {db::Partition::RegularWithGeometry};
    }

    auto negation = tryCastFilter<const NegativeFilterExpr*>(expr);
    if (!negation) {
        return db::ALL_PARTITIONS;
    }

    auto negated = evalConcernedTables(negation->expr().get());
    if (negated.size() == db::ALL_PARTITIONS.size()) {
        return negated;
    }

    // NOTE(review): the complement is exact only if the negated expression
    // matches every row of the partitions it concerns. For something like
    // NOT (master_object_id = 5) it looks like partitions that still hold
    // matching rows would be excluded — confirm this is intended.
    db::Partitions complement;
    std::set_difference(
        db::ALL_PARTITIONS.begin(), db::ALL_PARTITIONS.end(),
        negated.begin(), negated.end(),
        std::inserter(complement, complement.begin())
    );
    return complement;
}

/// @brief Evaluates the partitions concerned by an arbitrary expression.
///
/// Binary expressions are evaluated without recursion: an explicit stack
/// holds the expressions still to be evaluated and a map holds the results
/// that are ready. A binary node stays on the stack until both of its
/// operands have results; these are then merged with a set union for Or and
/// a set intersection for any other operation.
///
/// @param expr root of the expression tree
/// @return partitions concerned by \p expr
db::Partitions evalConcernedTables(const FilterExpr* expr)
{
    std::map<const FilterExpr*, db::Partitions> evaluated;
    std::stack<const FilterExpr*> pending;
    pending.push(expr);

    while (!pending.empty()) {
        const FilterExpr* current = pending.top();
        auto binary = tryCastFilter<const BinaryFilterExpr*>(current);
        if (!binary) {
            // Leaf (or negation): evaluate right away.
            evaluated[current] = evalConcernedTablesValue(current);
            pending.pop();
            continue;
        }

        const FilterExpr* lhs = binary->leftExpr().get();
        const FilterExpr* rhs = binary->rightExpr().get();
        const bool lhsReady = evaluated.count(lhs) != 0;
        const bool rhsReady = evaluated.count(rhs) != 0;
        if (!lhsReady) {
            pending.push(lhs);
        }
        if (!rhsReady) {
            pending.push(rhs);
        }
        if (!lhsReady || !rhsReady) {
            // Revisit this node once the missing operands are evaluated.
            continue;
        }

        const db::Partitions& lhsParts = evaluated.at(lhs);
        const db::Partitions& rhsParts = evaluated.at(rhs);
        db::Partitions combined;
        auto out = std::inserter(combined, combined.begin());
        if (binary->op() == BinaryFilterExpr::Or) {
            std::set_union(
                lhsParts.begin(), lhsParts.end(),
                rhsParts.begin(), rhsParts.end(),
                out);
        } else {
            std::set_intersection(
                lhsParts.begin(), lhsParts.end(),
                rhsParts.begin(), rhsParts.end(),
                out);
        }

        // Operand results are consumed; the references above must not be
        // used past this point.
        evaluated.erase(lhs);
        evaluated.erase(rhs);
        evaluated[current] = combined;
        pending.pop();
    }

    auto rootIt = evaluated.find(expr);
    ASSERT(rootIt != evaluated.end());
    return rootIt->second;
}

} //anonymous namespace

/// @brief Searches the range [\p beg, \p end) for the first element whose
/// `first` member compares equal to \p value.
/// @param beg start of the range
/// @param end past-the-end of the range
/// @param value value to compare each element's `first` with
/// @return iterator to the first matching element, or \p end if none matches
template <class PairItr, class FirstType>
PairItr findByFirst(const PairItr beg, const PairItr end, const FirstType& value)
{
    const auto hasRequestedFirst = [&value](auto&& entry) {
        return entry.first == value;
    };
    return std::find_if(beg, end, hasRequestedFirst);
}

/// @brief Splits \p expr into groups of OR-operands that concern the same
/// set of partitions.
///
/// The top-level chain of Or operations is flattened into its operands
/// (left to right). Each operand is evaluated with evalConcernedTables;
/// operands concerning no partitions are dropped. Operands with equal
/// partition sets are joined back with Or into a single filter.
///
/// @param expr filter to classify; the returned filters wrap parts of it in
///        ProxyReferenceFilterExpr
/// @return (partitions, filter) entries in order of first appearance
ClassifiedFilters classifyFilters(const FilterExpr& expr)
{
    // Step 1: flatten nested Or nodes into a list of operands.
    std::vector<const FilterExpr*> disjuncts;
    std::stack<const FilterExpr*> pending;
    pending.push(&expr);
    while (!pending.empty()) {
        const FilterExpr* current = pending.top();
        pending.pop();

        auto binary = tryCastFilter<const BinaryFilterExpr*>(current);
        if (binary != nullptr && binary->op() == BinaryFilterExpr::Or) {
            // Right goes first so that left is popped (and visited) first.
            pending.push(binary->rightExpr().get());
            pending.push(binary->leftExpr().get());
        } else {
            disjuncts.push_back(current);
        }
    }

    // Step 2: group the operands by the partitions they concern.
    ClassifiedFilters classified;
    for (const FilterExpr* disjunct : disjuncts) {
        db::Partitions tables = evalConcernedTables(disjunct);
        if (tables.empty()) {
            continue;
        }

        auto group = findByFirst(classified.begin(), classified.end(), tables);
        if (group != classified.end()) {
            group->second = std::make_shared<BinaryFilterExpr>(
                BinaryFilterExpr::Operation::Or,
                group->second,
                std::make_shared<ProxyReferenceFilterExpr>(*disjunct)
            );
        } else {
            classified.push_back({
                tables,
                std::make_shared<ProxyReferenceFilterExpr>(*disjunct)
            });
        }
    }
    return classified;
}

} // namespace maps::wiki::revision::filters

