# -*- coding: utf-8 -*-

from passport.backend.core.ydb.declarative.errors import ProgrammingError
from passport.backend.core.ydb.declarative.types import (
    init_type,
    pytype_to_data_type,
)
import six


class BaseElement(object):
    """Root class of every node in the declarative query tree."""

    @property
    def objid(self):
        """Identity-based key of this element (the builtin ``id`` of it)."""
        identity = id(self)
        return identity


class RenderableElement(BaseElement):
    """Element that can be turned into query text via ``_render``."""

    def _render(self, context, top=False):
        """Return the query text for this element; subclasses must implement it."""
        raise NotImplementedError

    @staticmethod
    def _wrap_top(string, top):
        """Return *string* unchanged at the top level, parenthesised otherwise."""
        return string if top else '(%s)' % string

    def _extract_columns(self):
        """Return the columns referenced by this element; none by default."""
        return set()


class SimpleRenderableElement(RenderableElement):
    """Literal fragment emitted into the query text verbatim."""

    def __init__(self, value):
        # Rendered as-is: no escaping and no parameter binding is applied.
        self.value = value

    def _render(self, context, top=False):
        """Return the stored fragment; *context* and *top* are ignored."""
        fragment = self.value
        return fragment

    def __str__(self):
        text = str(self.value)
        return text


# Literal ``NULL`` fragment; ``Expression._wrap_param`` substitutes it for a
# Python ``None`` operand instead of binding a parameter.
NULL = SimpleRenderableElement('NULL')


class BoundParameter(BaseElement):
    """Python value paired with an explicitly chosen data type.

    A bare Python operand is converted by the expression it is combined
    with. A ``BoundParameter`` instead carries its own type, which is used
    when it is turned into a query parameter.
    """

    def __init__(self, data_type, value):
        self.data_type = init_type(data_type)
        self.value = value

    def _to_param(self):
        """Return the ``(converted value, type annotation)`` pair for binding."""
        data_type = self.data_type
        converted = data_type.from_pyval(self.value)
        return converted, data_type.get_type_annotation()


class Expression(RenderableElement):
    """Abstract value expression that overloads Python operators.

    Operators compute nothing. Each one builds a ``BinaryExpression`` node,
    or a ``UnaryExpression`` node for ``~``. The right-hand operand is first
    normalised by ``_wrap_param``.

    Because ``__eq__`` is overridden, Python 3 makes instances unhashable
    unless a subclass restores ``__hash__``.
    """

    def _extract_columns(self):
        raise NotImplementedError()

    def _render(self, context, top=False):
        raise NotImplementedError()

    @property
    def name(self):
        """Generic display name; concrete columns override it."""
        return 'expression'

    def from_pyval(self, value):
        """Convert *value* using a data type inferred from its Python type.

        Returns a ``(db_value, type_annotation)`` pair.

        Raises:
            ProgrammingError: the inferred data type rejects *value*.
        """
        data_type = pytype_to_data_type(type(value))
        try:
            converted = data_type.from_pyval(value)
        except (ValueError, TypeError) as err:
            raise ProgrammingError(
                'Wrong data provided for %s value: %s' % (data_type.name, err),
            )
        annotation = data_type.get_type_annotation()
        return converted, annotation

    def to_pyval(self, value):
        """Convert a raw result value back to Python; identity by default."""
        return value

    def _get_param_prefix(self):
        """Prefix passed to the render context for parameters bound here."""
        return 'param'

    def _wrap_param(self, value):
        """Normalise an operand into something renderable.

        * Expressions and literal fragments pass through unchanged.
        * ``None`` becomes ``NULL``.
        * A ``BoundParameter`` becomes a parameter using its own data type.
        * Any other ``BaseElement`` is rejected with ``TypeError``.
        * Every other value becomes a parameter converted by ``from_pyval``.
        """
        if isinstance(value, (Expression, SimpleRenderableElement)):
            return value
        if value is None:
            return NULL
        if isinstance(value, BoundParameter):
            bound = value._to_param()
        elif isinstance(value, BaseElement):
            raise TypeError('Cannot use %r as parameter' % value)
        else:
            bound = self.from_pyval(value)
        return ParametrizableValue(self._get_param_prefix(), *bound)

    def _make_binary(self, other, operator, postfix=None):
        """Build ``self <operator> other`` after normalising *other*."""
        return BinaryExpression(self, self._wrap_param(other), operator, postfix)

    # Comparison operators.
    def __eq__(self, other):
        return self._make_binary(other, '=')

    def __ne__(self, other):
        return self._make_binary(other, '!=')

    def __lt__(self, other):
        return self._make_binary(other, '<')

    def __le__(self, other):
        return self._make_binary(other, '<=')

    def __gt__(self, other):
        return self._make_binary(other, '>')

    def __ge__(self, other):
        return self._make_binary(other, '>=')

    # Arithmetic operators.
    def __add__(self, other):
        return self._make_binary(other, '+')

    def __sub__(self, other):
        return self._make_binary(other, '-')

    def __mul__(self, other):
        return self._make_binary(other, '*')

    def __div__(self, other):
        return self._make_binary(other, '/')

    # Python 3 spells ``/`` as ``__truediv__``; share the Python 2 method.
    __truediv__ = __div__

    def __mod__(self, other):
        return self._make_binary(other, '%')

    # ``&``, ``|``, ``^`` and ``~`` are boolean AND/OR/XOR/NOT here. The
    # bitwise forms are exposed as ``bin_and_``, ``bin_or_`` and ``bin_xor_``.
    def __and__(self, other):
        return self._make_binary(other, 'AND')

    def __or__(self, other):
        return self._make_binary(other, 'OR')

    def __xor__(self, other):
        return self._make_binary(other, 'XOR')

    def __invert__(self):
        return UnaryExpression(self, 'NOT')

    # Shift operators.
    def __lshift__(self, other):
        return self._make_binary(other, '<<')

    def __rshift__(self, other):
        return self._make_binary(other, '>>')

    # Named operators with no Python operator counterpart.
    def is_(self, other):
        return self._make_binary(other, 'IS')

    def is_not_(self, other):
        return self._make_binary(other, 'IS NOT')

    def in_(self, values_list):
        candidates = InListExpression(values_list, self._wrap_param)
        return BinaryExpression(self, candidates, 'IN')

    def not_in_(self, values_list):
        candidates = InListExpression(values_list, self._wrap_param)
        return BinaryExpression(self, candidates, 'NOT IN')

    def like(self, other, case_sensitive=True, escape=None):
        """Pattern match; ``ILIKE`` when *case_sensitive* is false.

        *escape* is inserted verbatim into an ``ESCAPE '...'`` clause.
        """
        pattern = self._wrap_param(other)
        operator = 'LIKE' if case_sensitive else 'ILIKE'
        if escape is None:
            postfix = None
        else:
            postfix = 'ESCAPE \'%s\'' % escape
        return BinaryExpression(self, pattern, operator, postfix)

    def concat(self, other):
        return self._make_binary(other, '||')

    def bin_and_(self, other):
        return self._make_binary(other, '&')

    def bin_or_(self, other):
        return self._make_binary(other, '|')

    def bin_xor_(self, other):
        return self._make_binary(other, '^')

    def desc(self):
        """Wrap this expression for descending ordering."""
        return DescOrderBy(self)


class ParametrizableValue(Expression):
    """Value rendered as a bound query parameter rather than inline text."""

    def __init__(self, param_prefix, value, param_type):
        self.param_prefix = param_prefix
        self.value = value
        self.param_type = param_type

    def _render(self, context, top=False):
        """Register the value in *context* and return what it hands back.

        The returned text is presumably the parameter placeholder; that is
        decided by ``context.bind_parameter``.
        """
        prefix, value, param_type = self.param_prefix, self.value, self.param_type
        return context.bind_parameter(prefix, value, param_type)

    def _extract_columns(self):
        """A bound value references no columns."""
        return set()


class Function(Expression):
    """Call of a named function with rendered arguments: ``NAME(arg, ...)``.

    Positional *args* are normalised like any other operand:

    * expressions are kept as-is;
    * ``None`` becomes ``NULL``;
    * plain Python values become bound parameters with the ``function``
      prefix.
    """

    # Optional data type used by ``to_pyval`` to convert results of this
    # call. ``None`` means the raw value is returned unchanged.
    base_type = None

    def __init__(self, function, *args):
        self.function = function
        self.args = [self._wrap_param(arg) for arg in args]

    def _render(self, context, top=False):
        return '%s(%s)' % (
            self.function,
            ', '.join(arg._render(context) for arg in self.args),
        )

    def _get_param_prefix(self):
        return 'function'

    def _extract_columns(self):
        """Return the union of the columns referenced by all arguments.

        The ``set()`` initial value makes a call without arguments report no
        columns. Without it, ``reduce`` raises TypeError on an empty sequence.
        """
        return six.moves.reduce(
            lambda acc, columns: acc | columns,
            (arg._extract_columns() for arg in self.args),
            set(),
        )

    def to_pyval(self, value):
        """Convert a result via ``base_type`` when set, else return it as-is."""
        if self.base_type is None:
            return super(Function, self).to_pyval(value)
        else:
            return self.base_type.to_pyval(value)

    def __str__(self):
        return 'Function(%s, %s)' % (
            self.function,
            ', '.join(map(str, self.args)),
        )


class UnaryExpression(Expression):
    """Prefix-operator node rendered as ``<operator> <operand>``."""

    def __init__(self, operand, operator):
        self.operand = operand
        self.operator = operator

    def _render(self, context, top=False):
        # Parenthesised unless this node is the top-level expression.
        return self._wrap_top('%s %s', top) % (
            self.operator,
            self.operand._render(context),
        )

    def _extract_columns(self):
        """Return the columns referenced by the operand.

        Delegates to the operand, as ``BinaryExpression`` does, so nested
        expressions such as ``~(column == 1)`` report their columns. A bare
        column operand still yields ``{column}``.
        """
        return self.operand._extract_columns()

    def __str__(self):
        return 'UnaryExpression(%s, %s)' % (
            self.operator,
            self.operand,
        )


class BinaryExpression(Expression):
    """Infix node ``left <operator> right``, optionally followed by a postfix."""

    def __init__(self, left, right, operator, postfix=None):
        self.left, self.right = left, right
        self.operator = operator
        # Extra text appended after the right operand (e.g. ``like``'s
        # ESCAPE clause); ``None`` means nothing is appended.
        self.postfix = postfix

    def _render(self, context, top=False):
        """Render both operands left to right; parenthesise unless top level."""
        template = self._wrap_top('%s %s %s%s', top)
        lhs = self.left._render(context)
        rhs = self.right._render(context)
        if self.postfix is None:
            tail = ''
        else:
            tail = ' %s' % self.postfix
        return template % (lhs, self.operator, rhs, tail)

    def _extract_columns(self):
        """Union of the columns referenced by both operands."""
        columns = self.left._extract_columns()
        return columns | self.right._extract_columns()

    def __str__(self):
        parts = (self.operator, self.left, self.right)
        return 'BinaryExpression(%s, %s, %s)' % parts


class InListExpression(RenderableElement):
    """Parenthesised, comma-separated operand list for ``IN`` / ``NOT IN``.

    Each value is normalised at render time by *value_renderer*. When built
    by ``in_``/``not_in_``, that is the left-hand expression's
    ``_wrap_param``.
    """

    def __init__(self, values, value_renderer):
        # Materialise once. A generator would otherwise be exhausted by the
        # first render (or by ``str()``), and later renders would give ``()``.
        self.values = list(values)
        self.value_renderer = value_renderer

    def _render(self, context, top=False):
        return '(%s)' % ', '.join(
            self.value_renderer(value)._render(context)
            for value in self.values
        )

    def __str__(self):
        # Values are arbitrary Python objects or expressions, not strings.
        return 'InListExpression(%s)' % ', '.join(map(str, self.values))


class OperationsList(RenderableElement):
    """Operands joined by one operator, e.g. ``a AND b AND c``."""

    def __init__(self, values, operator):
        self.values = values
        self.operator = operator

    def _render(self, context, top=False):
        separator = ' %s ' % self.operator
        body = separator.join(value._render(context) for value in self.values)
        # Parenthesised unless this list is the top-level expression.
        return self._wrap_top('%s', top) % body

    def _extract_columns(self):
        """Return the union of the columns referenced by all operands.

        Consistent with ``BinaryExpression``. Without this override the
        inherited empty set would hide every column inside ``and_``,
        ``or_`` and ``xor_``.
        """
        columns = set()
        for value in self.values:
            columns.update(value._extract_columns())
        return columns

    def __str__(self):
        # Operands are expressions, not strings, so convert them explicitly.
        return 'OperationsList(%s, %s)' % (
            self.operator,
            ', '.join(map(str, self.values)),
        )


def and_(*operations):
    """Combine *operations* into a single ``AND`` chain."""
    return OperationsList(operator='AND', values=operations)


def or_(*operations):
    """Combine *operations* into a single ``OR`` chain."""
    return OperationsList(operator='OR', values=operations)


def xor_(*operations):
    """Combine *operations* into a single ``XOR`` chain."""
    return OperationsList(operator='XOR', values=operations)


class DescOrderBy(RenderableElement):
    """Ordering wrapper rendered as ``<expression> DESC``."""

    def __init__(self, expression):
        self.expression = expression

    def _extract_columns(self):
        """Columns of the wrapped expression."""
        return self.expression._extract_columns()

    def _render(self, context, top=False):
        rendered = self.expression._render(context)
        return '%s DESC' % rendered

    def __str__(self):
        return 'Desc(%s)' % self.expression


class FromElement(RenderableElement):
    """Abstract base for query sources (``Table`` in this module)."""

    @property
    def name(self):
        """Name of the source; subclasses must provide it."""
        raise NotImplementedError

    def _render(self, context, top=False):
        raise NotImplementedError

    def _extract_columns(self):
        raise NotImplementedError


class Table(FromElement):
    """Named table together with its column definitions.

    Columns are reachable by name as ``table.c.<column_name>``. Primary-key
    columns are collected in ``primary_key`` (name -> column). The first
    declared one is also kept in ``first_primary_key`` /
    ``first_primary_key_name``.
    """

    class Columns(object):
        """Attribute-style access to a ``{name: column}`` dict."""

        def __init__(self, columns):
            self.columns = columns

        def __getattr__(self, item):
            # ``__getattr__`` runs only for attributes that are missing.
            # Read ``columns`` from ``__dict__`` directly: an instance made
            # without ``__init__`` (as ``copy``/``pickle`` do) then raises
            # AttributeError instead of re-entering this method until
            # RecursionError.
            columns = self.__dict__.get('columns', {})
            try:
                return columns[item]
            except KeyError:
                raise AttributeError('Unknown column %s' % item)

    def __init__(self, name, *columns):
        self._name = name
        self.primary_key = {}
        self.first_primary_key = None
        self.first_primary_key_name = None
        self.column_dict = {}
        # Shares ``column_dict``, so columns registered below are visible.
        self.c = Table.Columns(self.column_dict)
        self.column_list = list(columns)

        for column in columns:
            column._attach_to_table(self)
            self.column_dict[column.name] = column
            if column.primary_key:
                if self.first_primary_key is None:
                    self.first_primary_key = column
                    self.first_primary_key_name = column.name
                self.primary_key[column.name] = column

    def _render(self, context, top=False):
        return self.name

    @property
    def name(self):
        return self._name

    def _extract_columns(self):
        # NOTE(review): returns the internal list in declaration order, unlike
        # the sets returned by expressions; callers must not mutate it.
        return self.column_list

    def __str__(self):
        return 'Table(%s)' % self.name


class ColumnElement(Expression):
    """Abstract column-like expression; concrete columns implement the hooks."""

    @property
    def name(self):
        raise NotImplementedError

    @property
    def table(self):
        raise NotImplementedError

    def _render(self, context, top=False):
        raise NotImplementedError

    def from_pyval(self, value):
        raise NotImplementedError

    def to_pyval(self, value):
        raise NotImplementedError

    def _wrap_opposite_value(self, value):
        raise NotImplementedError

    def _extract_columns(self):
        """A column references exactly itself."""
        return set([self])

    # Identity hash: restores the hashability lost by overriding ``__eq__``
    # in ``Expression``, so columns can be placed in sets.
    def __hash__(self):
        return object.__hash__(self)


class Label(BaseElement):
    """Element paired with an alias; the alias is stored as ``name``."""

    def __init__(self, element, alias):
        self.name = alias
        self.element = element

    def __str__(self):
        alias, element = self.name, self.element
        return 'Label(%s, %s)' % (alias, element)


class Column(ColumnElement):
    """Typed table column; operands compared with it use its data type."""

    def __init__(self, name, data_type, primary_key=False):
        self._name = name
        self.data_type = init_type(data_type)
        self.primary_key = primary_key
        # Filled in by ``_attach_to_table``; ``None`` while the column is unbound.
        self._table = None

    @property
    def name(self):
        return self._name

    @property
    def table(self):
        return self._table

    def _attach_to_table(self, table):
        self._table = table

    def from_pyval(self, value):
        """Convert *value* with this column's data type.

        Returns a ``(db_value, type_annotation)`` pair.

        Raises:
            ProgrammingError: the column's data type rejects *value*.
        """
        data_type = self.data_type
        try:
            converted = data_type.from_pyval(value)
        except (TypeError, ValueError) as err:
            raise ProgrammingError(
                'Wrong data provided for column %s: %s' % (self.name, err),
            )
        return converted, data_type.get_type_annotation()

    def to_pyval(self, value):
        return self.data_type.to_pyval(value)

    def _get_param_prefix(self):
        # Parameters compared against this column are named after it.
        return self.name

    def _render(self, context, top=False):
        """Render ``alias.name``, or just ``name`` when the table has no alias."""
        table = self.table
        if table is None:
            raise ProgrammingError('Unable to render unbound column')
        alias = context.get_table_alias(table)
        qualifier = '' if alias is None else '%s.' % alias
        return '%s%s' % (qualifier, self.name)

    def label(self, alias):
        """Return a ``Label`` pairing this column with *alias*."""
        return Label(self, alias)

    def __str__(self):
        qualifier = '%s.' % self.table.name if self.table else ''
        return 'Column(%s%s)' % (qualifier, self.name)

    def _wrap_opposite_value(self, value):
        return self.data_type._wrap_opposite_value(value)
