from .fields import Field
from . import lookups

import copy
import six

from yt.wrapper.errors import YtTabletNotMounted


class QueryException(Exception):
    """Error raised while building a YT query (e.g. an unusable LIMIT value)."""


# NOTE: Mostly adapted from django.utils.tree.Node and django.db.models.query_utils.Q
# TODO: Remove unused params & methods.
class Node(object):
    """
    A single internal node in the tree graph. A Node should be viewed as a
    connection (the root) with the children being either leaf nodes or other
    Node instances.

    Attributes:
        children -- list of children: leaf values or nested Node instances.
        connector -- name of the connector joining the children; falls back
            to the class attribute ``default`` when not given.
        negated -- whether the node as a whole is negated.
    """
    default = 'DEFAULT'

    def __init__(self, children=None, connector=None, negated=False):
        """
        Constructs a new Node. If no connector is given, the default will be
        used. ``children`` is shallow-copied, so later changes to this node
        never mutate the caller's list.
        """
        self.children = children[:] if children else []
        self.connector = connector or self.default
        self.negated = negated

    @classmethod
    def _new_instance(cls, children=None, connector=None, negated=False):
        """
        This is called to create a new instance of this class when we need new
        Nodes (or subclasses) in the internal code in this class. Normally, it
        just shadows __init__(). However, subclasses with an __init__ signature
        that is not an extension of Node.__init__ might need to implement this
        method to allow a Node to create a new instance of them (if they have
        any extra setting up to do).
        """
        # Build through Node.__init__ and then re-class, so subclasses with a
        # different __init__ signature (e.g. Q(*args, **kwargs)) still work.
        obj = Node(children, connector, negated)
        obj.__class__ = cls
        return obj

    def __str__(self):
        template = '(NOT (%s: %s))' if self.negated else '(%s: %s)'
        return template % (self.connector, ', '.join(str(c) for c in self.children))

    def __repr__(self):
        return str("<%s: %s>") % (self.__class__.__name__, self)

    def __deepcopy__(self, memodict):
        """
        Utility method used by copy.deepcopy().
        """
        obj = Node(connector=self.connector, negated=self.negated)
        obj.__class__ = self.__class__
        obj.children = copy.deepcopy(self.children, memodict)
        return obj

    def __len__(self):
        """
        The size of a node is the number of children it has.
        """
        return len(self.children)

    def __bool__(self):
        """
        For truth value testing: a node is truthy iff it has children.
        """
        return bool(self.children)

    def __nonzero__(self):
        # Python 2 spelling of __bool__.
        return type(self).__bool__(self)

    def __contains__(self, other):
        """
        Returns True if 'other' is a direct child of this instance.
        """
        return other in self.children

    def add(self, data, conn_type, squash=True):
        """
        Combines this tree and ``data`` using the connector ``conn_type``.
        When possible the other node is squashed away, i.e. its children are
        merged directly into this node.

        This tree (self) always stays the root of the result and keeps its
        ``negated`` flag, so for a negated self the negation also covers
        ``data``. Unless ``data`` is already a child or ``squash`` is False,
        self's connector ends up as ``conn_type``.

        Returns self when data's children were squashed into this node,
        otherwise ``data`` itself.

        If `squash` is False, data is appended as a child as-is, without
        further logic.
        """
        if data in self.children:
            return data
        if not squash:
            self.children.append(data)
            return data
        if self.connector == conn_type:
            # We can reuse self.children to append or squash the node other.
            if (isinstance(data, Node) and not data.negated and
                    (data.connector == conn_type or len(data) == 1)):
                # We can squash the other node's children directly into this
                # node. We are just doing (AB)(CD) == (ABCD) here, with the
                # addition that if the length of the other node is 1 the
                # connector doesn't matter. However, for the len(self) == 1
                # case we don't want to do the squashing, as it would alter
                # self.connector.
                self.children.extend(data.children)
                return self
            else:
                # We could use perhaps additional logic here to see if some
                # children could be used for pushdown here.
                self.children.append(data)
                return data
        else:
            # Push the current contents down into a child and make self the
            # new `conn_type` root. The child must NOT copy self's negation:
            # self keeps `negated`, so copying it would negate the original
            # contents twice.
            obj = self._new_instance(self.children, self.connector)
            self.connector = conn_type
            self.children = [obj, data]
            return data

    def negate(self):
        """
        Negate the sense of the root connector.
        """
        self.negated = not self.negated


class Q(Node):
    """
    Encapsulates filters as objects that can then be combined logically (using
    `&` and `|`) and negated (using `~`).

    Positional arguments are stored as children verbatim; keyword arguments
    become ``(lookup, value)`` leaf pairs, e.g. ``Q(name='x')`` holds
    ``('name', 'x')``.
    """
    # Connection types
    AND = 'AND'
    OR = 'OR'
    default = AND

    def __init__(self, *args, **kwargs):
        super(Q, self).__init__(children=list(args) + list(kwargs.items()))

    def _combine(self, other, conn):
        """Return a new Q joining `self` and `other` with connector `conn`.

        Empty operands are skipped: an empty Q renders as an empty string, so
        keeping it as a child would leave a dangling connector in the SQL
        (e.g. ``( OR a = 1)``).

        :raises TypeError: if `other` is not a Q.
        """
        if not isinstance(other, Q):
            raise TypeError(other)
        obj = type(self)()
        obj.connector = conn
        for operand in (self, other):
            if operand:
                obj.add(operand, conn)
        return obj

    def __or__(self, other):
        return self._combine(other, self.OR)

    def __and__(self, other):
        return self._combine(other, self.AND)

    def __invert__(self):
        # Wrap self in a fresh node and negate the wrapper, leaving self intact.
        obj = type(self)()
        obj.add(self, self.AND)
        obj.negate()
        return obj

    def as_sql(self, table):
        """Render this filter tree as a SQL condition for `table`.

        Children exposing ``as_sql`` (nested Q objects) are rendered
        recursively; leaf pairs are resolved through the lookups module.
        Children that render to an empty string are dropped. Returns ``u''``
        when nothing remains, otherwise ``(...)`` or ``NOT (...)``.
        """
        parts = []
        for child in self.children:
            if hasattr(child, 'as_sql'):
                sql = child.as_sql(table)
            else:
                sql = self._resolve_lookup(child, table)
            if sql:
                parts.append(sql)

        if not parts:
            return u''
        sql_string = u' {} '.format(self.connector).join(parts)
        template = u'NOT ({})' if self.negated else u'({})'
        return template.format(sql_string)

    def _resolve_lookup(self, child, table):
        """Turn a ``(clause, value)`` leaf into SQL via a lookup class.

        `clause` is split on ``lookups.LOOKUP_SEP``. Without a separator the
        ``Exact`` lookup is used; otherwise the last part names the lookup
        (resolved with ``lookups.get_lookup``), and the remaining parts are
        passed on to it.
        """
        clause, value = child
        parts = clause.split(lookups.LOOKUP_SEP)
        if len(parts) == 1:
            lookup_cls = lookups.Exact
        else:
            lookup_cls = lookups.get_lookup(parts.pop())
        return lookup_cls(table, parts, value).resolve()


# This would be base class for all queries, similar to django QuerySet
class Query(object):
    """Class that encapsulates query to yt. Similar to django QuerySet

    Allows to build a query lazily, adding filters with filter/exclude, joins, ordering and limit.
    All builder methods (filter, exclude, limit, join, order_by) return a copy of the Query
    object with the added clauses and leave the original untouched. Thus you can reuse the same
    base Query by expanding it with different filters. add_q() is the exception: it mutates
    the query in place.

    .eval() returns a lazy generator of deserialized table objects.
    """

    def __init__(self, table, store, where=None, joins=None, limit=None, orders=None):
        self.table = table
        self.store = store
        # An empty Q is falsy, so it is replaced by a fresh Q as well.
        self.where = where or Q()
        self._limit = limit
        # Each join is a (table, on_left, on_right, use_left_join) tuple.
        self.joins = joins or []
        self.orders = orders or []
        # Joined table -> SQL alias "j<index in self.joins>".
        self._join_aliases = {}
        for idx, join in enumerate(self.joins):
            self._join_aliases[join[0]] = u"j{}".format(idx)

    @classmethod
    def _copy_where(cls, node):
        """Copy the Node structure of a where-tree; leaf (lookup, value) pairs are shared."""
        clone = copy.copy(node)
        clone.children = [cls._copy_where(child) if isinstance(child, Node) else child
                          for child in node.children]
        return clone

    def _clone(self):
        """Return an independent copy of this query.

        The where-tree and the joins/orders lists are copied, so adding clauses
        to the clone never leaks back into this query (and vice versa).
        """
        return self.__class__(table=self.table,
                              store=self.store,
                              where=self._copy_where(self.where),
                              limit=self._limit,
                              joins=list(self.joins),
                              orders=list(self.orders))

    def add_q(self, q_object):
        """AND `q_object` into this query's where-tree, in place."""
        if self.where.negated:
            # Node.add() would put q_object inside the negation, turning
            # NOT(a) AND b into NOT(a AND b); wrap the negated tree first.
            wrapper = Q()
            wrapper.add(self.where, Q.AND)
            self.where = wrapper
        self.where.add(q_object, Q.AND)

    def filter(self, *args, **kwargs):
        """Return a copy with ``Q(*args, **kwargs)`` ANDed into the WHERE clause."""
        clone = self._clone()
        clone.add_q(Q(*args, **kwargs))
        return clone

    def exclude(self, *args, **kwargs):
        """Return a copy with ``NOT Q(*args, **kwargs)`` ANDed into the WHERE clause."""
        clone = self._clone()
        clone.add_q(~Q(*args, **kwargs))
        return clone

    def limit(self, limit=None):
        """Add a LIMIT clause to the query

        0 or None limit results in removal of LIMIT.

        :raises QueryException: if `limit` can't be converted to int or is negative.
        """
        if not limit:
            limit = None
        else:
            try:
                limit = int(limit)
            except (TypeError, ValueError):
                raise QueryException("Can't use {!r} as LIMIT in YT query".format(limit))
            if limit < 0:
                raise QueryException("Can't use {!r} as LIMIT in YT query".format(limit))

        clone = self._clone()
        clone._limit = limit
        return clone

    def _infer_join_fields(self, other_table, on_left=None, on_right=None):
        """Return (left_fields, right_fields) for joining `other_table`.

        If both sides are given they are returned as sequences (a single Field
        is wrapped in a tuple). If only one side is given, each field's
        counterpart is looked up by name on the other table.

        :raises NotImplementedError: if neither side is given.
        :raises ValueError: if a counterpart field can't be found.
        """
        def ensure_sequence(val):
            return (val, ) if isinstance(val, Field) else val

        if on_left and on_right:
            return ensure_sequence(on_left), ensure_sequence(on_right)

        if on_left is None and on_right is None:
            raise NotImplementedError("Automatic ON clause inferring is not implemented yet")

        clause_fields = ensure_sequence(on_left or on_right)
        left, right = [], []
        for field in clause_fields:
            if field.table == self.table:
                lf = field
                rf = other_table._fields.all.get(field.name)
            elif field.table == other_table:
                lf = self.table._fields.all.get(field.name)
                rf = field
            else:
                lf = rf = None
            if lf is None or rf is None:
                raise ValueError("Don't know how to join {} and {} on {}".format(
                    self.table, other_table, field))
            left.append(lf)
            right.append(rf)
        return left, right

    def join(self, table, on_left=None, on_right=None, use_left_join=False):
        """Add a join clause to the query

        on_left/on_right may be a single field or a tuple of fields to be passed to ON clause
        you may omit on_right. In this case field names will be inferred from on_left (or vice versa)
        """
        on_left, on_right = self._infer_join_fields(table, on_left, on_right)

        clone = self._clone()
        clone.joins.append(
            (table, on_left, on_right, use_left_join)
        )
        # Same alias scheme as __init__: "j" + position in self.joins.
        clone._join_aliases[table] = u"j{}".format(len(clone.joins) - 1)
        return clone

    def _expand_join_clause(self, fields):
        """Expand left/right clause, respecting table aliases

        Fields of joined tables are prefixed with their alias; fields of the
        main table and plain strings are used as-is. Names are comma-joined.
        """
        names = []
        for field in fields:
            if isinstance(field, Field):
                alias = self._join_aliases.get(field.table, '')
                full_name = alias + '.' + field.name if alias else field.name
            elif isinstance(field, six.string_types):
                full_name = field
            else:
                raise TypeError("Don't know how to use {} in join".format(field))
            names.append(full_name)
        return ','.join(names)

    def order_by(self, *fields):
        """Return a copy ordered by `fields`; a leading '-' means descending."""
        clone = self._clone()
        clone.orders.extend(fields)
        return clone

    def as_sql(self, *fields):
        """Build the query string, selecting `fields` (or ``*`` when none are given).

        The string starts with the selector (no SELECT keyword), followed by
        FROM, then the optional JOIN, WHERE, ORDER BY and LIMIT parts.
        """
        selector = ",".join(fields) if fields else '*'

        query = u"{} FROM [{}]".format(selector, self.table._table_path)

        if self.joins:
            clauses = []
            for join in self.joins:
                joined_table, on_left, on_right, use_left_join = join
                alias = self._join_aliases[joined_table]

                join_clause = u"{} [{}] AS {} ON ({}) = ({})".format(
                    'LEFT JOIN' if use_left_join else 'JOIN',
                    joined_table._table_path, alias,
                    self._expand_join_clause(on_left),
                    self._expand_join_clause(on_right),
                )
                clauses.append(join_clause)
            query += u" " + u" ".join(clauses)

        where = self.where.as_sql(self.table)
        if where.strip():
            query += u" WHERE {}".format(where)

        if self.orders:
            # TODO: Allow Fields; Inspect _join_aliases, to allow fields from other Tables
            # or other way to implement joins
            orders = (
                field_name[1:] + u" DESC" if field_name.startswith('-')
                else field_name
                for field_name in self.orders
            )
            query += u" ORDER BY {}".format(", ".join(orders))

        if self._limit:
            query += u" LIMIT {}".format(self._limit)
        return query.strip()

    def only(self, *fields):
        """Return only fields mentioned

        Arguments are validated eagerly, so a bad call raises right away rather
        than on first iteration. The store is still queried lazily, when the
        returned generator is first advanced. Rows are yielded exactly as
        ``store.select`` produces them (not deserialized).

        :raises ValueError: if no fields are given or a field is not in ``table._fields.all``.
        """
        if not fields:
            raise ValueError("No fields specified for 'only' query on table {}".format(self.table))
        for field in fields:
            if field not in self.table._fields.all:
                raise ValueError("Can't find field '{}' on table {}, while executing only-clause".format(field, self.table))
        sql = self.as_sql(*fields)

        def _rows():
            for row in self.store.select(sql):
                yield row
        return _rows()

    # TODO: add caching
    def eval(self, allow_join_without_index=False):
        """Run the query and yield deserialized objects of ``self.table``.

        On YtTabletNotMounted the main table and all joined tables are mounted
        and the select is retried once. Columns of joined tables are moved into
        ``obj._extra``, keyed by the joined table's class name.
        """
        sql = self.as_sql()
        try:
            objects = self.store.select(sql, allow_join_without_index, remount=False)
        except YtTabletNotMounted:
            self.table.mount_table()
            for table in self._join_aliases:
                table.mount_table()
            objects = self.store.select(sql, allow_join_without_index, remount=False)
        for row in objects:
            extra = {}
            # If there were joins row contains 'j0.foo', 'j1.bar', etc.
            # Will instantiate these objects and store in ._extra property
            for table, alias in self._join_aliases.items():
                prefix = alias + '.'
                sub_table_args = {}
                for key in list(row.keys()):
                    if key.startswith(prefix):
                        # Strip only the leading alias, not every occurrence of it.
                        sub_table_args[key[len(prefix):]] = row.pop(key)
                extra[table.__name__] = table.deserialize(**sub_table_args)
            obj = self.table.deserialize(**row)
            obj._extra = extra
            yield obj

    def __iter__(self):
        return self.eval()

    def list(self):
        """Evaluate the query and return all objects as a list."""
        return list(self.eval())
