#include "parser.h"

#include <yandex/maps/wiki/filters/exception.h>
#include <yandex/maps/wiki/revision/common.h>
#include <yandex/maps/wiki/common/string_utils.h>

#include <cstring>
#include <sstream>
#include <stdexcept>
#include <string>

namespace maps {
namespace wiki {
namespace filters {

namespace rf = revision::filters;

namespace {

// Reserved attribute name that selects filtering by object id
// (handled by ObjectIdClauseNode).
const std::string ID = "id";

// Punctuation tokens.
const std::string COMMA = ",";
const std::string LEFT_PAREN = "(";
const std::string RIGHT_PAREN = ")";

// Matching operators: `attr = v`, `attr in (v1, ...)` and the substring
// operator `attr % 'text'`.
const std::string EQUAL = "=";
const std::string IN = "in";
const std::string LIKE = "%";

// `attr <> v`, parsed as the negation of `attr = v`.
const std::string NOT_EQUAL = "<>";

// Integer comparison operators (see COMPARE_OPS).
const std::string LESS = "<";
const std::string LESS_EQUAL = "<=";
const std::string GREATER_EQUAL = ">=";
const std::string GREATER = ">";

// Logical operators. Binding strength, from the parser's call chain:
// NOT binds tightest, then AND, then OR.
const std::string AND = "and";
const std::string OR = "or";
const std::string NOT = "not";

// Attribute names starting with this prefix refer to service attributes;
// the prefix itself is stripped before lookup (see viewColKeyName).
const std::string SERVICE_ATTR_PREFIX = "s::";
// View column names holding service and domain attributes respectively.
const std::string SERVICE_ATTRS = "service_attrs";
const std::string DOMAIN_ATTRS = "domain_attrs";

// Kind of an n-ary LogicOpNode.
enum class LogicOp { And, Or };

// Integer comparison operator: its token text and the corresponding
// revision-filter operation.
struct CompareOp {
    const std::string str;
    const rf::BigintAttrFilterExpr::Operation revisionFilterOp;
};

// All supported integer comparisons. Parser::parseClause matches the current
// token against `str` via Token::equals().
const std::vector<CompareOp> COMPARE_OPS = {
    { LESS, rf::BigintAttrFilterExpr::Operation::Less },
    { LESS_EQUAL, rf::BigintAttrFilterExpr::Operation::LessEqual },
    { GREATER, rf::BigintAttrFilterExpr::Operation::Greater },
    { GREATER_EQUAL, rf::BigintAttrFilterExpr::Operation::GreaterEqual },
};

// A single value denotes "null" when it is the empty string.
bool isNullValue(const std::string& value)
{
    return value.size() == 0;
}

// A value list denotes "null" only when it consists of exactly one empty
// string (the `attr = ''` form).
bool isNullValue(const std::vector<std::string>& values)
{
    if (values.size() != 1) {
        return false;
    }
    return values[0].empty();
}

// True when the attribute name carries the service-attribute prefix.
bool isServiceAttr(const std::string& name)
{
    const auto prefixLength = SERVICE_ATTR_PREFIX.size();
    return name.compare(0, prefixLength, SERVICE_ATTR_PREFIX) == 0;
}

// Name of the view column that stores the given attribute: service
// attributes and domain attributes live in separate columns.
const std::string&
viewColName(const std::string& name)
{
    if (isServiceAttr(name)) {
        return SERVICE_ATTRS;
    }
    return DOMAIN_ATTRS;
}

// Key of the attribute inside its view column: the service-attribute
// prefix is stripped, domain attribute names are used as is.
std::string viewColKeyName(const std::string& name)
{
    if (!isServiceAttr(name)) {
        return name;
    }
    return name.substr(SERVICE_ATTR_PREFIX.size());
}

class ObjectIdClauseNode : public Node
{
public:
    ObjectIdClauseNode(const std::vector<std::string>& values)
    {
        values_.reserve(values.size());
        for (const auto& value : values) {
            if (isNullValue(value)) {
                throw ParseError() << "unexpected empty string";
            }
            values_.push_back(std::stoul(value));
        }
    }

    std::string viewFilterClause(pqxx::transaction_base& /*txn*/, const std::string& viewAlias) const override
    {
        if (values_.size() == 1) {
            return "(" + viewAlias + "id = " + std::to_string(values_.front()) + ")";
        }
        return "(" + viewAlias + "id IN (" + common::join(values_, ',') + "))";
    }

    rf::ProxyFilterExpr revisionFilter() const override
    {
        return rf::ObjRevAttr::objectId().in(values_);
    }

    bool hasServiceAttrsClause() const override
    {
        return false;
    }

private:
    std::vector<revision::DBID> values_;
};

/// Equality clause on a domain or service attribute:
///  - `attr = ''`             -> the attribute is not set;
///  - `attr = v`              -> the attribute is set and equals v;
///  - `attr in (v1, v2, ...)` -> the attribute is set and equals one of the values.
/// Names with the "s::" prefix are looked up in the service_attrs view column,
/// all others in domain_attrs.
class EqualsClauseNode : public Node
{
public:
    EqualsClauseNode(std::string name, std::vector<std::string> values)
        : name_(std::move(name))
        , values_(std::move(values))
    { }

    std::string viewFilterClause(pqxx::transaction_base& txn, const std::string& viewAlias) const override
    {
        const std::string column = viewAlias + viewColName(name_);
        const std::string quotedKey = txn.quote(viewColKeyName(name_));
        const std::string hasKey = column + " ? " + quotedKey;

        if (isNullValue(values_)) {
            return "(NOT " + hasKey + ")";
        }

        // Every value goes through txn.quote(); nothing user-provided is
        // spliced into the SQL raw.
        std::string clause = "(" + hasKey + " AND " + column + "->" + quotedKey;
        if (values_.size() == 1) {
            clause += " = " + txn.quote(values_.front());
        } else {
            clause += " IN (";
            clause += common::join(
                values_,
                [&](const std::string& item) { return txn.quote(item); },
                ',');
            clause += ")";
        }
        clause += ")";
        return clause;
    }

    bool hasServiceAttrsClause() const override
    {
        return isServiceAttr(name_);
    }

    rf::ProxyFilterExpr revisionFilter() const override
    {
        REQUIRE(!hasServiceAttrsClause(),
            "Shouldn't be called for service attr filter.");
        if (!isNullValue(values_)) {
            return rf::Attr(name_).defined() && rf::Attr(name_).in(values_);
        }
        return !rf::Attr(name_).defined();
    }

private:
    std::string name_;
    std::vector<std::string> values_;
};

/// Substring clause `attr % 'text'`: the attribute is set and its value
/// matches the LIKE pattern `%text%`. LIKE metacharacters inside `text`
/// are not escaped here, so they keep their wildcard meaning.
class LikeClauseNode : public Node
{
public:
    LikeClauseNode(std::string name, std::string value)
        : name_(std::move(name))
        , value_(std::move(value))
    { }

    std::string viewFilterClause(pqxx::transaction_base& txn, const std::string& viewAlias) const override
    {
        const std::string column = viewAlias + viewColName(name_);
        const std::string quotedKey = txn.quote(viewColKeyName(name_));

        std::string clause = "(";
        clause += column + " ? " + quotedKey;
        clause += " AND ";
        clause += column + "->" + quotedKey;
        clause += " LIKE ";
        clause += txn.quote(containsPattern());
        clause += ")";
        return clause;
    }

    bool hasServiceAttrsClause() const override
    {
        return isServiceAttr(name_);
    }

    rf::ProxyFilterExpr revisionFilter() const override
    {
        REQUIRE(!hasServiceAttrsClause(),
            "Shouldn't be called for service attr filter.");
        return rf::Attr(name_).defined() && rf::Attr(name_).like(containsPattern());
    }

private:
    // Wraps the user value into a "contains" LIKE pattern.
    std::string containsPattern() const
    {
        return "%" + value_ + "%";
    }

    std::string name_;
    std::string value_;
};


/// Integer comparison clause `attr OP N`, OP being one of COMPARE_OPS: the
/// attribute is set and its value, cast to bigint, satisfies the comparison.
class CompareClauseNode : public Node
{
public:
    CompareClauseNode(std::string name, CompareOp op, int64_t value)
        : name_(std::move(name))
        , op_(op)
        , value_(value)
    { }

    std::string viewFilterClause(pqxx::transaction_base& txn, const std::string& viewAlias) const override
    {
        const std::string column = viewAlias + viewColName(name_);
        const std::string quotedKey = txn.quote(viewColKeyName(name_));

        std::ostringstream clause;
        clause << "(" << column << " ? " << quotedKey
               << " AND (" << column << "->" << quotedKey << ")::bigint "
               << op_.str << " " << value_ << ")";
        return clause.str();
    }

    bool hasServiceAttrsClause() const override
    {
        return isServiceAttr(name_);
    }

    rf::ProxyFilterExpr revisionFilter() const override
    {
        REQUIRE(!hasServiceAttrsClause(),
            "Shouldn't be called for service attr filter.");
        return rf::Attr(name_).defined()
            && rf::BigintAttrFilterExpr(op_.revisionFilterOp, name_, value_);
    }

private:
    std::string name_;
    CompareOp op_;
    int64_t value_;
};

class LogicOpNode : public Node
{
public:
    LogicOpNode(LogicOp op, std::vector<NodePtr> children)
        : op_(op), children_(std::move(children))
    { }

    std::string viewFilterClause(pqxx::transaction_base& txn, const std::string& viewAlias) const override
    {
        return "(" +
            common::join(
                children_,
                [&](const NodePtr& child) { return child->viewFilterClause(txn, viewAlias); },
                op_ == LogicOp::And ? " AND " : " OR ") +
            ")";
    }

    bool hasServiceAttrsClause() const override
    {
        return std::any_of(children_.begin(),  children_.end(),
            [](const auto& node) {
                return node->hasServiceAttrsClause();
            });
    }

    rf::ProxyFilterExpr revisionFilter() const override
    {
        REQUIRE(!hasServiceAttrsClause(),
            "Shouldn't be called for service attr filter.");
        rf::ProxyFilterExpr ret = rf::False();
        bool first = true;
        for (const auto& child : children_) {
            if (first) {
                first = false;
                ret = child->revisionFilter();
            } else if (op_ == LogicOp::And) {
                ret &= child->revisionFilter();
            } else {
                ret |= child->revisionFilter();
            }
        }
        return ret;
    }

private:
    LogicOp op_;
    std::vector<NodePtr> children_;
};

/// Logical negation of a single child clause.
class NegationNode : public Node
{
public:
    NegationNode(NodePtr child)
        : child_(std::move(child))
    { }

    std::string viewFilterClause(pqxx::transaction_base& txn, const std::string& viewAlias) const override
    {
        std::string clause = "(NOT ";
        clause += child_->viewFilterClause(txn, viewAlias);
        clause += ")";
        return clause;
    }

    bool hasServiceAttrsClause() const override
    {
        return child_->hasServiceAttrsClause();
    }

    rf::ProxyFilterExpr revisionFilter() const override
    {
        // The child performs its own service-attribute REQUIRE check.
        auto inner = child_->revisionFilter();
        return !inner;
    }

private:
    NodePtr child_;
};

} // namespace

// Tokenizes the whole input up front; parsing starts at the first token.
// The token stream must be non-empty and terminated by an End token, which
// advance() relies on to never step past the last token.
Parser::Parser(const std::string& text)
    : tokens_(tokenize(text))
    , curToken_(tokens_.begin())
{
    ASSERT(!tokens_.empty());
    ASSERT(tokens_.back().type() == Token::Type::End);
}

// Parses the whole input as a single or-expression.
// Throws ParseError if tokens remain after the expression.
NodePtr Parser::parse()
{
    NodePtr root = parseOrExpr();
    if (match(Token::Type::End)) {
        return root;
    }
    throw ParseError()
        << "not fully parsed at position: "
        << curToken_->startPos();
}

// or-expr := and-expr ("or" and-expr)*
// A single operand is returned as is, without a wrapping LogicOpNode.
NodePtr Parser::parseOrExpr()
{
    std::vector<NodePtr> operands;
    do {
        if (!operands.empty()) {
            advance(); // skip the "or" that matched the loop condition
        }
        operands.push_back(parseAndExpr());
    } while (match(Token::Type::Operator, OR));

    if (operands.size() == 1) {
        return std::move(operands.front());
    }
    return NodePtr(new LogicOpNode(LogicOp::Or, std::move(operands)));
}

// and-expr := not-expr ("and" not-expr)*
// A single operand is returned as is, without a wrapping LogicOpNode.
NodePtr Parser::parseAndExpr()
{
    std::vector<NodePtr> operands;
    do {
        if (!operands.empty()) {
            advance(); // skip the "and" that matched the loop condition
        }
        operands.push_back(parseNotExpr());
    } while (match(Token::Type::Operator, AND));

    if (operands.size() == 1) {
        return std::move(operands.front());
    }
    return NodePtr(new LogicOpNode(LogicOp::And, std::move(operands)));
}

// not-expr := ["not"] term
// Only a single "not" is accepted: its operand is parsed as a term.
NodePtr Parser::parseNotExpr()
{
    if (!match(Token::Type::Operator, NOT)) {
        return parseTerm();
    }
    advance();
    return NodePtr(new NegationNode(parseTerm()));
}

// term := "(" or-expr ")" | clause
NodePtr Parser::parseTerm()
{
    if (!match(Token::Type::Punctuation, LEFT_PAREN)) {
        return parseClause();
    }
    advance();
    NodePtr inner = parseOrExpr();
    consume(Token::Type::Punctuation, RIGHT_PAREN);
    return inner;
}

/// Parses a single clause `<attr> <op> <value(s)>`:
///  - `id = N` / `id in (N, ...)`        -> ObjectIdClauseNode
///  - `attr <|<=|>|>= N` (integer)        -> CompareClauseNode
///  - `attr % 'text'`                     -> LikeClauseNode
///  - `attr <> v`                         -> NOT (attr = v)
///  - `attr = v` / `attr in (v, ...)`     -> EqualsClauseNode
///
/// @throws ParseError on unexpected tokens or an integer operand that does
///         not fit into int64_t.
NodePtr Parser::parseClause()
{
    expect(Token::Type::Identifier);
    auto attrId = curToken_->str();
    advance();

    // Right-hand side of `=` (one value) or `in` (parenthesized list).
    auto parseEqualsValues = [this](boost::optional<Token::Type> expectedValueType)
    {
        std::vector<std::string> result;
        if (match(Token::Type::Operator, EQUAL)) {
            advance();
            result.push_back(parseVal(expectedValueType));
        } else if (match(Token::Type::Operator, IN)) {
            advance();
            result = parseValues(expectedValueType);
        } else {
            throwUnexpected();
        }
        return result;
    };

    if (attrId == ID) {
        return NodePtr(
                new ObjectIdClauseNode(parseEqualsValues(Token::Type::IntLiteral)));
    }

    expect(Token::Type::Operator);

    auto compareOp = std::find_if(
            COMPARE_OPS.begin(), COMPARE_OPS.end(),
            [this](const CompareOp& op) { return curToken_->equals(op.str); });
    if (compareOp != COMPARE_OPS.end()) {
        advance();
        // Keep the literal text: parseVal() has already advanced past the
        // token, so curToken_ would name the *following* token in the error.
        const std::string literal = parseVal(Token::Type::IntLiteral);
        int64_t value = 0;
        try {
            value = std::stoll(literal);
        } catch (const std::out_of_range&) {
            throw ParseError() << "int out of range: " << literal;
        }

        return NodePtr(
                new CompareClauseNode(std::move(attrId), *compareOp, value));
    }

    if (match(Token::Type::Operator, LIKE)) {
        advance();
        return NodePtr(
                new LikeClauseNode(std::move(attrId), parseVal(Token::Type::StringLiteral)));
    }

    if (match(Token::Type::Operator, NOT_EQUAL)) {
        advance();
        // `attr <> v` is represented as NOT (attr = v).
        return NodePtr(
            new NegationNode(
                NodePtr(
                    new EqualsClauseNode(std::move(attrId), {parseVal(boost::none)}))));
    }

    return NodePtr(
            new EqualsClauseNode(std::move(attrId), parseEqualsValues(boost::none)));
}

/// Parses a parenthesized, comma-separated list of literals: `(v1, v2, ...)`.
///
/// @param expectedValueType required token type of every element; when unset,
///        the type of the first element is required for the whole list.
/// @return the element values (string literals without quotes), in order.
/// @throws ParseError on malformed lists, mixed literal types, or empty-string
///         elements. A lone empty value is only meaningful as `attr = ''`
///         ("attribute not set"), so it is rejected inside lists.
std::vector<std::string> Parser::parseValues(
        boost::optional<Token::Type> expectedValueType)
{
    consume(Token::Type::Punctuation, LEFT_PAREN);
    if (!expectedValueType) {
        expectedValueType = curToken_->type();
    }

    // Parses one element and rejects empty strings. parseVal() has already
    // advanced past the literal, so rewind to it before reporting; otherwise
    // the error would name the following ',' or ')' and its position.
    auto parseNonEmptyVal = [&]
    {
        const auto valueToken = curToken_;
        auto value = parseVal(expectedValueType);
        if (isNullValue(value)) {
            curToken_ = valueToken;
            throwUnexpected();
        }
        return value;
    };

    std::vector<std::string> ret;
    ret.push_back(parseNonEmptyVal());
    while (!match(Token::Type::Punctuation, RIGHT_PAREN)) {
        consume(Token::Type::Punctuation, COMMA);
        ret.push_back(parseNonEmptyVal());
    }

    advance();

    return ret;
}

/// Parses one literal value and moves past it.
///
/// @param expectedValueType if set, the current token must have this type
///        (checked via expect()).
/// @return the literal text; for string literals the first and last
///         characters (the quotes) are stripped, no unescaping is done.
/// Anything other than an int or string literal is reported via throwUnexpected().
std::string Parser::parseVal(boost::optional<Token::Type> expectedValueType)
{
    if (expectedValueType) {
        expect(*expectedValueType);
    }

    std::string literal;
    if (match(Token::Type::StringLiteral)) {
        const auto& withQuotes = curToken_->str();
        ASSERT(withQuotes.length() >= 2);
        literal = withQuotes.substr(1, withQuotes.length() - 2);
    } else if (match(Token::Type::IntLiteral)) {
        literal = curToken_->str();
    } else {
        throwUnexpected();
    }

    advance();
    return literal;
}

// True when the current token has the given type.
bool Parser::match(Token::Type type)
{
    const auto currentType = curToken_->type();
    return currentType == type;
}

// True when the current token has the given type and equals the given text.
bool Parser::match(Token::Type type, const std::string& str)
{
    if (!match(type)) {
        return false;
    }
    return curToken_->equals(str);
}

// Always throws ParseError describing the current token, or the end of
// input when the End token has been reached.
void Parser::throwUnexpected()
{
    if (!match(Token::Type::End)) {
        throw ParseError()
            << "unexpected " << curToken_->type()
            << " at position: " << curToken_->startPos();
    }
    throw ParseError() << "unexpected end of input";
}

// Moves to the next token. The stream ends with an End token; trying to move
// past it is reported as an unexpected end of input instead.
void Parser::advance()
{
    const bool atEnd = curToken_->type() == Token::Type::End;
    if (atEnd) {
        throwUnexpected();
    }
    ++curToken_;
}

// Requires the current token to have the given type and text (checked via
// expect()), then moves past it.
void Parser::consume(Token::Type type, const std::string& str)
{
    expect(type, str);
    advance();
}

} // namespace filters
} // namespace wiki
} // namespace maps
