from collections import defaultdict
from itertools import chain
from .errors import HrSyntaxError

import six


def resolvePrefixes(a, b):
    """
    Bring two expressions into agreement about prefixing.

    When neither ``a`` nor ``b`` is prefixed, nothing is touched and False is
    returned.  Otherwise each operand that still lacks a prefix gets the
    default one via ``set_prefix()`` (``a`` first, then ``b``) and True is
    returned.
    """
    if a.has_prefix or b.has_prefix:
        for operand in (a, b):
            # Re-read has_prefix per operand: a and b may be the same object.
            if not operand.has_prefix:
                operand.set_prefix()
        return True
    return False


def getNamesByPrefixes(literals):
    """
    Group literal names by their prefix.

    Returns a ``defaultdict(set)`` mapping each ``literal.prefix`` (which may
    be None) to the set of ``literal.name`` values seen with that prefix.
    """
    names_by_prefix = defaultdict(set)
    pairs = ((lit.prefix, lit.name) for lit in literals)
    for prefix, name in pairs:
        names_by_prefix[prefix].add(name)
    return names_by_prefix


class Literal(object):
    """
    A single, optionally prefixed, variable of an expression.

    Rendered as ``prefix@name`` when a prefix is set, otherwise as ``name``.
    Equality and hashing use ``(prefix, name)``.  Ordering treats a missing
    prefix/name as ``''``.
    """
    __slots__ = ('name', 'prefix')

    def __init__(self, name=None, prefix=None):
        self.name = name
        self.prefix = prefix

    def set_prefix(self, prefix=None):
        """
        Attach ``prefix`` (``'h'`` when not given or falsy) in place.

        Raises HrSyntaxError if this literal already has a prefix.

        NOTE: the hash depends on ``prefix``, so any set/dict already holding
        this literal must be rehashed afterwards.
        """
        if self.prefix is not None:
            # Previously written as 'Can''t ...', which Python concatenates
            # into "Cant ..."; use a real apostrophe.
            raise HrSyntaxError("Can't add a prefix to a prefixed expression")

        self.prefix = prefix or 'h'

    @staticmethod
    def _order_key(literal):
        # None is mapped to '' so unprefixed literals sort first and None is
        # never compared against a str (a TypeError on Python 3).
        return (literal.prefix or '', literal.name or '')

    def __str__(self):
        if self.prefix:
            return '{0}@{1}'.format(self.prefix, self.name)
        return self.name

    def __hash__(self):
        return hash((self.prefix, self.name))

    def __cmp__(self, other):
        # Python 2 only; Python 3 uses the rich comparisons below.
        return cmp((self.prefix, self.name), (other.prefix, other.name))  # noqa

    def __lt__(self, other):
        return self._order_key(self) < self._order_key(other)

    def __gt__(self, other):
        return self._order_key(self) > self._order_key(other)

    def __le__(self, other):
        return self._order_key(self) <= self._order_key(other)

    def __ge__(self, other):
        return self._order_key(self) >= self._order_key(other)

    def __eq__(self, other):
        # Comparing with a non-Literal (e.g. None) used to raise AttributeError.
        if not isinstance(other, Literal):
            return NotImplemented
        return (self.prefix, self.name) == (other.prefix, other.name)

    def __ne__(self, other):
        # Spelled out explicitly because Python 2 does not derive __ne__.
        if not isinstance(other, Literal):
            return NotImplemented
        return (self.prefix, self.name) != (other.prefix, other.name)


class Clause(object):
    """
    Simple Clause

    Literal & Literal & ... & Literal

    ``positive_literals`` are rendered as-is and ``negative_literals`` are
    rendered with a leading ``~``.  Equality, ordering and hashing are all
    based on ``str(self)``.
    """
    __slots__ = ('positive_literals', 'negative_literals')

    def __init__(self, positive_literals=None, negative_literals=None):
        # `or set()` also swaps an empty container passed in for a fresh set.
        self.positive_literals = positive_literals or set()
        self.negative_literals = negative_literals or set()

    def __mul__(self, other):
        """Conjunction: a new Clause with the union of both clauses' literals."""
        # Literals are copied rather than shared.  This is presumably so that
        # an in-place set_prefix on the product cannot touch the operands.
        positive_literals = set(
            Literal(l.name, l.prefix)
            for l in self.positive_literals | other.positive_literals
        )
        negative_literals = set(
            Literal(l.name, l.prefix)
            for l in self.negative_literals | other.negative_literals
        )
        return Clause(positive_literals, negative_literals)

    def apply_demorgan_rule(self):
        """
        Negate this conjunction via De Morgan.

        Returns a set of single-literal Clauses (each positive literal
        becomes negative and vice versa).  Their disjunction is the negation
        of this clause.
        """
        clauses = set()

        for pl in self.positive_literals:
            clauses.add(Clause(None, set([Literal(pl.name, pl.prefix)])))

        for nl in self.negative_literals:
            clauses.add(Clause(set([Literal(nl.name, nl.prefix)]), None))

        return clauses

    def get_positive_names_by_prefixes(self):
        return getNamesByPrefixes(self.positive_literals)

    def get_negative_names_by_prefixes(self):
        return getNamesByPrefixes(self.negative_literals)

    def get_p_literals_by_prefix_set(self, prefixes):
        """Return the positive literals whose prefix is in ``prefixes``."""
        return [l for l in self.positive_literals if l.prefix in prefixes]

    def get_n_literals_by_prefix_set(self, prefixes):
        """Return the negative literals whose prefix is in ``prefixes``."""
        return [l for l in self.negative_literals if l.prefix in prefixes]

    @staticmethod
    def _rehash(items):
        # Re-insert every element of the set ``items`` in place so its buckets
        # match the elements' current hashes.  The set object itself is kept
        # so other holders of the same set stay consistent.
        snapshot = list(items)
        items.clear()
        items.update(snapshot)

    def set_prefix(self, prefix=None):
        """
        Prefix every literal of this clause in place (see Literal.set_prefix).

        Literals hash on ``(prefix, name)``, so mutating them leaves the sets
        holding them with stale hashes.  Membership tests would then silently
        fail.  Both sets are therefore rehashed afterwards, even if a literal
        raises part-way through.
        """
        try:
            for pl in self.positive_literals:
                pl.set_prefix(prefix)

            for nl in self.negative_literals:
                nl.set_prefix(prefix)
        finally:
            self._rehash(self.positive_literals)
            self._rehash(self.negative_literals)

    def __str__(self):
        return ' & '.join(chain(
            map(str, sorted(self.positive_literals)),
            map(lambda x: '~' + str(x), sorted(self.negative_literals))
        ))

    def __hash__(self):
        return hash(str(self))

    def __cmp__(self, other):
        # Python 2 only; Python 3 uses the rich comparisons below.
        return cmp(str(self), str(other))  # noqa

    def __lt__(self, other):
        return str(self) < str(other)

    def __gt__(self, other):
        return str(self) > str(other)

    def __le__(self, other):
        return str(self) <= str(other)

    def __ge__(self, other):
        return str(self) >= str(other)

    def __eq__(self, other):
        return str(self) == str(other)

    def __ne__(self, other):
        return str(self) != str(other)


class Dnf(object):
    """
    Disjunctive normal form

    Clause | Clause | ... | Clause

    ``+`` is disjunction, ``*`` is conjunction, ``~`` is negation and ``-``
    is ``self * ~other``.  A Dnf with no clauses is the empty disjunction
    (False).  A Dnf holding a single empty Clause is True.
    """

    __slots__ = ('clauses', 'has_prefix')

    def __init__(self, clauses=None, has_prefix=False):
        self.clauses = clauses or set()
        self.has_prefix = has_prefix

    def __add__(self, other):
        """Disjunction.  May prefix either operand in place (see resolvePrefixes)."""
        has_prefix = resolvePrefixes(self, other)
        return Dnf(self.clauses | other.clauses, has_prefix)

    def __sub__(self, other):
        """Difference: ``self & ~other``."""
        return self * (~other)

    def __invert__(self):
        """Negate via De Morgan: ~(c1 | c2 | ...) == ~c1 & ~c2 & ..."""
        if not self.clauses:
            # An empty disjunction is False, so its negation is True: a single
            # empty conjunction.  Without this, reduce() fails on an empty
            # sequence.
            return Dnf(set([Clause()]), self.has_prefix)

        negated = [Dnf(c.apply_demorgan_rule(), self.has_prefix)
                   for c in self.clauses]
        return six.moves.reduce(lambda x, y: x * y, negated)

    def __mul__(self, other):
        """Conjunction: the pairwise product of both operands' clauses."""
        has_prefix = resolvePrefixes(self, other)

        clauses = set(c1 * c2 for c1 in self.clauses for c2 in other.clauses)

        return Dnf(clauses, has_prefix)

    def set_prefix(self, prefix=None):
        """
        Prefix every clause in place and mark this Dnf as prefixed.

        Raises HrSyntaxError if this Dnf is already prefixed.
        """
        if self.has_prefix:
            # Previously 'Can''t ...', which Python concatenates into "Cant ...".
            raise HrSyntaxError("Can't add a prefix to a prefixed expression")

        try:
            for c in self.clauses:
                c.set_prefix(prefix)
        finally:
            # Clause hashes on str(self), which set_prefix just changed.
            # Re-insert in place so membership tests on self.clauses still work.
            snapshot = list(self.clauses)
            self.clauses.clear()
            self.clauses.update(snapshot)

        self.has_prefix = True

    def __str__(self):
        return ' | '.join(map(str, sorted(self.clauses)))

    def __cmp__(self, other):
        # Python 2 only; Python 3 uses the rich comparisons below.
        return cmp(str(self), str(other))  # noqa

    def __lt__(self, other):
        return str(self) < str(other)

    def __gt__(self, other):
        return str(self) > str(other)

    def __le__(self, other):
        return str(self) <= str(other)

    def __ge__(self, other):
        return str(self) >= str(other)

    def __eq__(self, other):
        return str(self) == str(other)

    def __ne__(self, other):
        return str(self) != str(other)

    def __hash__(self):
        # NOTE(review): identity hash while __eq__ is string-based, so equal
        # Dnfs can hash differently.  It is kept as-is because callers may
        # rely on identity semantics in sets/dicts.  Confirm before changing.
        return id(self)


def transform_literals(dnf, f):
    """
    Map ``f`` over every literal of ``dnf``, collecting failures.

    Returns ``(new_dnf, errors)``.  ``new_dnf`` is built from fresh Clause
    objects holding ``f(literal)`` for each literal and keeps
    ``dnf.has_prefix``.  Results go into sets, so they must be hashable.
    When ``f`` raises, the original literal is kept and ``errors`` maps
    ``str(literal)`` to ``str(exception)``.
    """
    errors = {}

    def _apply(literal):
        try:
            return f(literal)
        except Exception as exc:
            # Best-effort: record the failure and keep the literal unchanged.
            errors[str(literal)] = str(exc)
            return literal

    def _rebuild(clause):
        # Positive literals are transformed before negative ones.
        positives = set(_apply(lit) for lit in clause.positive_literals)
        negatives = set(_apply(lit) for lit in clause.negative_literals)
        return Clause(positives, negatives)

    rebuilt = set(_rebuild(clause) for clause in dnf.clauses)
    return (Dnf(rebuilt, dnf.has_prefix), errors)
