import logging

from lexer import lex
from expression import Dnf, Clause, Literal


def is_terminal(symbol):
    """Return True if *symbol* names a terminal (token type) in the grammar.

    A symbol is a terminal when its first character is unchanged by
    ``str.upper()`` -- e.g. ``'NAME'``, ``'TIMES'``, ``'$END'`` -- while
    nonterminals start with a lowercase letter (``'expr'``, ``'term'``).
    """
    head = symbol[0]
    return head == head.upper()


def _trace(msg, context, stream):
    logging.getLogger('parser').debug('%s %s from %s', msg, context, stream)


def _parse_nonterminal(symbol, stream):
    """Parse *symbol* at the front of *stream* using its productions in ``rules``.

    Alternatives are tried in the order listed, and the first one that
    matches wins. Later alternatives are never consulted once one succeeds
    (ordered choice). Nothing is memoised, so the same nonterminal may be
    re-parsed at the same position by several alternatives.

    Returns:
        ``(values, remaining_stream)`` from the winning rule, or
        ``(None, None)`` if no alternative matches.

    Raises:
        KeyError: if *symbol* has no entry in ``rules``.
    """
    _trace('try nonterminal', symbol, stream)

    for alternative in rules[symbol]:
        outcome = _apply_rule(alternative, stream)
        if not outcome:
            continue
        _trace('nonterminal', symbol, stream)
        return outcome

    return None, None


def _parse_terminal(symbol, stream):
    """Consume one token from *stream* if its ``type`` equals *symbol*.

    Returns:
        ``([token], stream[1:])`` on a match, or ``(None, None)`` when the
        stream is empty or its first token has a different type.
    """
    _trace('try terminal', symbol, stream)

    matched = bool(stream) and symbol == stream[0].type
    if not matched:
        return None, None

    token = stream[0]
    _trace('terminal', token, stream)
    return [token], stream[1:]


def _apply_rule(rule, stream):
    """Match one production against the front of *stream*.

    *rule* is a ``(symbols, action)`` pair. Each symbol is matched in turn,
    as a terminal or a nonterminal according to ``is_terminal``. The values
    they produce are concatenated into one flat list and passed to *action*.

    Returns:
        ``(action(values), remaining_stream)`` if every symbol matches, or
        a bare ``None`` as soon as one fails (callers test truthiness).
    """
    _trace('try rule', rule, stream)

    symbols, build = rule

    collected = []
    remaining = stream
    for symbol in symbols:
        parse_one = _parse_terminal if is_terminal(symbol) else _parse_nonterminal
        matched, after = parse_one(symbol, remaining)
        if not matched:
            return None
        collected.extend(matched)
        remaining = after

    _trace('applied rule', symbols, remaining)
    return build(collected), remaining


def default(p):
    """Identity semantic action: return the matched value list unchanged."""
    return p


def term_prefix(p):
    """Semantic action for ``PREFIX term``.

    *p* is ``[prefix_token, term_value]``. The prefix token's ``value`` is
    passed to ``term_value.set_prefix``, and ``[term_value]`` is returned.
    The return value of ``set_prefix`` is ignored, so it is presumably an
    in-place mutator -- confirm against ``expression``.
    """
    prefix_token, operand = p[0], p[1]
    operand.set_prefix(prefix_token.value)
    return [operand]


def term_name(p):
    """Semantic action for a bare ``NAME`` token.

    Returns a one-element list holding a ``Dnf`` with a single ``Clause``,
    which in turn holds a single ``Literal`` built from the token's ``value``.
    """
    name_token = p[0]
    return [Dnf({Clause({Literal(name_token.value)})})]


# Semantic actions used by ``rules``. Each receives the flat list of values
# matched by one production (tokens for terminals, the list returned by each
# nonterminal's action) and returns a new one-element list.


def _keep_first(p):
    return [p[0]]


def _keep_second(p):
    return [p[1]]


def _minus(p):
    return [p[0] - p[2]]


def _plus(p):
    return [p[0] + p[1]]


def _times(p):
    return [p[0] * p[1]]


def _times_tail(p):
    return [p[1] * p[2]]


# Grammar, keyed by nonterminal. Each entry lists (symbols, action) pairs in
# priority order, and _parse_nonterminal uses the first one that matches.
# Symbols whose first character is unchanged by upper-casing (NAME, $END)
# are token types; lower-case symbols are nonterminals.
#
#   s           := expr $END
#   expr        := factor MINUS expr | factor expr | factor
#   factor      := term factor_tail | term
#   factor_tail := TIMES term factor_tail | TIMES term
#   term        := PREFIX term | NAME | LBRACKET expr RBRACKET
#
# NOTE(review): MINUS and juxtaposition are right-recursive, so
# ``a - b - c`` is built as ``a - (b - c)``. Confirm that grouping is intended.
rules = {
    's': [
        (('expr', '$END'), _keep_first),
    ],
    'expr': [
        (('factor', 'MINUS', 'expr'), _minus),
        (('factor', 'expr'), _plus),
        (('factor',), _keep_first),
    ],
    'factor': [
        (('term', 'factor_tail'), _times),
        (('term',), default),
    ],
    'factor_tail': [
        (('TIMES', 'term', 'factor_tail'), _times_tail),
        (('TIMES', 'term'), _keep_second),
    ],
    'term': [
        (('PREFIX', 'term'), term_prefix),
        (('NAME',), term_name),
        (('LBRACKET', 'expr', 'RBRACKET'), _keep_second),
    ],
}


def parse(expr):
    """Parse the expression string *expr* and return the value it builds.

    The whole token stream must match the start rule ``s := expr $END``.
    The value built for ``expr`` is returned; it is presumably a ``Dnf``,
    given the actions in ``rules``. Returns None if the input does not
    parse. This assumes ``lex`` emits a trailing ``$END`` token -- confirm
    in ``lexer``.
    """
    tokens = list(lex(expr))

    outcome = _parse_nonterminal('s', tokens)
    if outcome and outcome[0]:
        return outcome[0][0]
    return None
