# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import logging

from passport.backend.core import Undefined
from passport.backend.core.db.runner import get_id_from_query_result
from passport.backend.core.db.utils import insert_with_on_duplicate_key_update
from passport.backend.social.common.db.execute import execute
from passport.backend.social.common.serialize import (
    DatabaseQuery,
    register_database_serializer,
    ValueSerializationError,
)
from passport.backend.utils.common import remove_none_values
from sqlalchemy import (
    and_ as sql_and,
    or_ as sql_or,
    select as sql_select,
    union as sql_union,
)


logger = logging.getLogger(__name__)


def _type_attributes_to_name_attributes(eav_conf, entity_id, attrs):
    """
    Re-key ``attrs`` in place from attribute types to attribute names.

    Every key of ``attrs`` is treated as an attribute type and is looked up
    with ``eav_conf.get_name_from_type``. Known types are replaced by their
    names (the value is kept); types without a name are removed from the
    dict and reported in a single debug log record.

    :param eav_conf: object providing ``get_name_from_type(attr_type)``
    :param entity_id: entity identifier, used only in the log message
    :param attrs: dict ``{attr_type: value}``; modified in place
    :return: None
    """
    unknown_attr_types = []
    # The dict is modified inside the loop, so walk over a snapshot of its
    # keys. ``attrs.keys()`` is such a snapshot only on Python 2; on Python 3
    # it is a live view and mutating the dict during iteration fails.
    for attr_type in list(attrs):
        attr_value = attrs.pop(attr_type)
        attr_name = eav_conf.get_name_from_type(attr_type)
        if attr_name:
            attrs[attr_name] = attr_value
        else:
            unknown_attr_types.append(attr_type)
    if unknown_attr_types:
        # Pass the arguments to the logger instead of pre-formatting the
        # message: nothing is formatted when DEBUG records are discarded.
        logger.debug(
            'Unknown attributes on entity_id=%s found: %s',
            entity_id,
            ', '.join(map(str, unknown_attr_types)),
        )


def _EavDatabaseReaderLoaderMethod(eav_conf, columns_for_search):
    """
    Build a loader function suitable for use as a reader method.

    The returned function takes ``(reader, ids)`` and runs an ``EavSelector``
    created from the given arguments against ``reader.database``.

    :param eav_conf: EAV configuration handed over to ``EavSelector``
    :param columns_for_search: column names handed over to ``EavSelector``
    :return: function ``method(reader, ids)``
    """
    eav_selector = EavSelector(eav_conf, columns_for_search)

    def method(reader, ids):
        return eav_selector(reader.database, ids)

    # Keep the construction arguments on the function to make debugging easier
    method.selector = eav_selector
    method.columns_for_search = columns_for_search
    method.eav_conf = eav_conf
    return method


class EavConfiguration(object):
    """
    Description of how one kind of entity is stored in EAV form.

    It holds three tables together with the attributes kept in each of them
    and provides lookups between attribute names and attribute types.
    """
    # Class-level defaults; ``__init__`` always rebinds all of them on the
    # instance, so these shared objects are never mutated through an instance.
    eav_attributes = dict()
    eav_table = None

    eav_index_attributes = dict()
    eav_index_table = None

    index_attributes = list()
    index_table = None

    def __init__(self, eav_attributes=None, eav_table=None,
                 eav_index_attributes=None, eav_index_table=None,
                 index_attributes=None, index_table=None):
        """
        :param eav_attributes: dict ``{attr_name: attr_type}`` for ``eav_table``
        :param eav_table: table object for plain attributes
        :param eav_index_attributes: dict ``{attr_name: attr_type}`` for
            ``eav_index_table``
        :param eav_index_table: table object for indexed attributes
        :param index_attributes: list of column names of ``index_table``
        :param index_table: table object holding one row per entity
        """
        self.eav_attributes = eav_attributes or dict()
        self.eav_table = eav_table
        self.eav_index_attributes = eav_index_attributes or dict()
        self.eav_index_table = eav_index_table
        self.index_attributes = index_attributes or list()
        self.index_table = index_table

        # Build the lookup tables from the normalized instance attributes, not
        # from the raw arguments: the arguments default to None, and both
        # ``dict(None)`` and ``dict.update(None)`` raise TypeError.
        self._name_to_type = dict(self.eav_attributes)
        self._name_to_type.update(self.eav_index_attributes)
        self._type_to_name = dict((v, k) for k, v in self._name_to_type.items())

    def get_name_from_type(self, attr_type):
        """
        Return the attribute name registered for ``attr_type`` or None.
        """
        return self._type_to_name.get(attr_type)

    def get_type_from_name(self, attr_name):
        """
        Return the attribute type registered for ``attr_name`` or None.
        """
        return self._name_to_type.get(attr_name)

    @property
    def primary_key(self):
        """
        Name of the only primary key column of ``index_table``.

        Composite primary keys are rejected by the assertion below.
        """
        assert len(self.index_table.primary_key) == 1
        column = next(iter(self.index_table.primary_key.columns))
        return column.name


class EavSelector(object):
    """
    Callable loading entities by values of the given index table columns.

    A call first selects matching rows from the index table, then selects the
    attributes of the found entities from the eav and eav index tables, and
    finally merges everything into one dict of named attributes per entity.
    """
    def __init__(self, eav_conf, columns_to_search):
        """
        :param eav_conf: configuration providing ``index_table``,
            ``eav_table``, ``eav_index_table`` and ``primary_key``
        :param columns_to_search: names of index table columns to search by
        """
        self.eav_conf = eav_conf
        self.index_table = eav_conf.index_table
        self.eav_table = eav_conf.eav_table
        self.eav_index_table = eav_conf.eav_index_table

        self.columns_to_search = [
            self.index_table.c[col_name] for col_name in columns_to_search
        ]

        self._id_name = eav_conf.primary_key

    def __call__(self, db, ids):
        """
        Load the entities matching ``ids``.

        :param db: database handle passed through to ``execute``
        :param ids: values to search for; one tuple per id when several
            columns are searched
        :return: collection of dicts with named attributes; an empty list
            when ``ids`` is empty or nothing is found
        """
        if not ids:
            return list()
        index_rows = self._select_index(db, ids)
        found_entity_ids = [row[self._id_name] for row in index_rows]
        if not found_entity_ids:
            return list()
        eav_rows = self._select_eav(db, found_entity_ids)
        return self._named_attrs_from_rows(index_rows, eav_rows)

    def _select_index(self, db, ids):
        # Fetch all index table rows matching the ids
        return execute(db, self.index_query(ids)).fetchall()

    def index_query(self, ids):
        """
        Build the select over the index table.

        A single search column is matched with IN; several columns are
        matched as an OR over per-id AND conditions.
        """
        search_columns = self.columns_to_search
        if len(search_columns) < 2:
            condition = search_columns[0].in_(ids)
        else:
            alternatives = []
            for composite_id in ids:
                # Every id must carry exactly one value per search column
                assert len(search_columns) == len(composite_id)
                column_matches = [
                    column == value
                    for column, value in zip(search_columns, composite_id)
                ]
                alternatives.append(sql_and(*column_matches))
            condition = sql_or(*alternatives)
        return sql_select([self.index_table]).where(condition)

    def _select_eav(self, db, ids):
        # Fetch all attribute rows of the given entities
        return execute(db, self.attrs_query(ids)).fetchall()

    def attrs_query(self, ids):
        """
        Build the union of selects over the eav and eav index tables.
        """
        per_table_queries = [
            sql_select([table]).where(table.c[self._id_name].in_(ids))
            for table in (self.eav_table, self.eav_index_table)
        ]
        return sql_union(*per_table_queries)

    def _named_attrs_from_rows(self, index_rows, eav_rows):
        """
        Merge index rows and attribute rows into dicts of named attributes.
        """
        id_column = self._id_name

        # One empty dict per entity found in the index table
        attrs_by_id = dict((row[id_column], dict()) for row in index_rows)

        # Collect ``{type: value}`` for known entities; other rows are skipped
        for row in eav_rows:
            type_attrs = attrs_by_id.get(row[id_column])
            if type_attrs is not None:
                type_attrs[row.type] = row.value

        # Replace attribute types with attribute names
        for entity_id in attrs_by_id:
            _type_attributes_to_name_attributes(
                self.eav_conf,
                entity_id,
                attrs_by_id[entity_id],
            )

        # Add non-null index table columns, converting non-strings with str()
        for row in index_rows:
            name_attrs = attrs_by_id[row[id_column]]
            for index_attr in self.eav_conf.index_attributes:
                value = row[index_attr]
                if value is None:
                    continue
                if isinstance(value, basestring):
                    name_attrs[index_attr] = value
                else:
                    name_attrs[index_attr] = str(value)
        return attrs_by_id.values()


class _EavDatabaseReaderMetaClass(type):
    """
    Metaclass adding ``load_by_<index name>s`` methods to reader classes.

    When the class body defines a truthy ``eav_configuration``, one loader
    method is generated for every index of its index table; the method
    searches by the columns of that index.
    """
    def __new__(cls, name, bases, attrs):
        eav_conf = attrs.get('eav_configuration')
        indexes = eav_conf.index_table.indexes if eav_conf else ()
        for index in indexes:
            loader_name = 'load_by_%ss' % index.name
            loader = _EavDatabaseReaderLoaderMethod(
                eav_conf,
                [column.name for column in index.columns],
            )
            loader.__name__ = str(loader_name)
            attrs[loader_name] = loader
        return type.__new__(cls, name, bases, attrs)


class EavDatabaseReader(object):
    """
    Base class of readers loading EAV entities from a database.

    A subclass sets ``eav_configuration``; the metaclass then generates the
    ``load_by_...`` methods for it.
    """
    __metaclass__ = _EavDatabaseReaderMetaClass

    # EAV configuration of the entity; None means no loaders are generated
    eav_configuration = None

    def __init__(self, database):
        """
        :param database: database handle stored as ``self.database``
        """
        self.database = database


class EavDatabaseSerializer(object):
    """
    Turns a change of an entity's attribute dict into database queries.

    The entity is stored in three tables described by ``eav_configuration``:
    the index table (one row per entity), the eav table and the eav index
    table (one row per attribute). ``serialize`` returns a generator; the
    caller is expected to execute the yielded queries one by one.
    """
    # EAV configuration of the entity; must be set by subclasses
    eav_configuration = None

    def serialize(self, old, new, difference=None):
        """
        Pick the generator of queries matching the kind of change.

        :param old: attribute dict before the change, None on creation
        :param new: attribute dict after the change, None on deletion
        :param difference: not used
        :return: generator of queries
        :raises AssertionError: both ``old`` and ``new`` are None
        :raises NotImplementedError: the primary key is missing in one of
            the dicts or differs between them
        """
        if old is None and new is None:
            # Raised explicitly: a bare ``assert`` is stripped under ``-O``
            raise AssertionError('Neither old nor new attribute dict is given')  # pragma: no cover
        if old is None:
            return self._serialize_create(new)
        if new is None:
            return self._serialize_delete(old)

        pk_name = self.eav_configuration.primary_key
        old_pk_db_value = old.get(pk_name)
        new_pk_db_value = new.get(pk_name)
        is_same_entity = (
            old_pk_db_value is not None and
            new_pk_db_value is not None and
            old_pk_db_value == new_pk_db_value
        )
        if not is_same_entity:
            raise NotImplementedError()  # pragma: no cover
        return self._serialize_change(old, new)

    def on_create(self, attribute_dict):
        """
        Hook called after the last creation query was consumed. No-op here.
        """
        pass

    def on_delete(self, attribute_dict):
        """
        Hook called with the cleared dict after deletion. No-op here.
        """
        pass

    def _resolve_attr_type(self, attr_name):
        """
        Return the type of ``attr_name``.

        Only None means "unknown", so a type equal to 0 is accepted both on
        creation and on change.

        :raises EavUnknownAttributeNameError: no type is configured
        """
        attr_type = self.eav_configuration.get_type_from_name(attr_name)
        if attr_type is None:
            raise EavUnknownAttributeNameError(attr_name)
        return attr_type

    def _serialize_create(self, attribute_dict):
        """
        Yield the queries creating an entity.

        The first item is a pair ``(insert query, callback)``. The callback
        takes the result of the insert and stores the new primary key in
        ``attribute_dict``; the following queries use that key, so the
        callback has to be called before the generator is advanced
        (unless the key was already present in ``attribute_dict``).
        """
        conf = self.eav_configuration
        pk_name = conf.primary_key

        index_values = dict()
        pk_db_value = attribute_dict.get(pk_name)
        if pk_db_value is not None:
            index_values[pk_name] = pk_db_value
        for column_name in conf.index_attributes:
            if column_name != pk_name:
                index_values[column_name] = attribute_dict.get(column_name)

        # One-item list: lets the callback below rebind the value
        # (there is no ``nonlocal`` on Python 2)
        pk_db_value_container = [pk_db_value]

        def _set_pk_to_attribute_dict(result):
            new_pk_db_value = get_id_from_query_result(result)
            assert not isinstance(new_pk_db_value, (list, tuple))
            pk_db_value_container[0] = new_pk_db_value
            attribute_dict[pk_name] = new_pk_db_value

        yield (
            self._insert_to_index_query(index_values),
            _set_pk_to_attribute_dict,
        )
        # By now the caller may have run the callback
        pk_db_value = pk_db_value_container[0]

        attribute_groups = (
            (conf.eav_attributes, self._insert_to_attribute_query),
            (conf.eav_index_attributes, self._insert_to_index_attribute_query),
        )
        for attr_names, build_insert in attribute_groups:
            for attr_name in attr_names:
                db_value = attribute_dict.get(attr_name)
                if db_value is None:
                    continue
                attr_type = self._resolve_attr_type(attr_name)
                yield build_insert(pk_db_value, attr_type, db_value)

        self.on_create(attribute_dict)

    def _serialize_change(self, old_attribute_dict, new_attribute_dict):
        """
        Yield the queries bringing the stored entity from old to new state.

        Changed index table columns go into one update. Every changed
        attribute is written when its new value is not None and deleted
        otherwise.
        """
        conf = self.eav_configuration

        changed_index_values = dict()
        for column_name in conf.index_attributes:
            old_db_value = old_attribute_dict.get(column_name)
            new_db_value = new_attribute_dict.get(column_name)
            if old_db_value != new_db_value:
                changed_index_values[column_name] = new_db_value

        pk_db_value = new_attribute_dict[conf.primary_key]

        if changed_index_values:
            yield self._update_index_query(pk_db_value, changed_index_values)

        attribute_groups = (
            (
                conf.eav_attributes,
                self._update_attribute_query,
                self._delete_one_attribute_query,
            ),
            (
                conf.eav_index_attributes,
                self._update_index_attribute_query,
                self._delete_one_index_attribute_query,
            ),
        )
        for attr_names, build_update, build_delete in attribute_groups:
            for attr_name in attr_names:
                old_db_value = old_attribute_dict.get(attr_name)
                new_db_value = new_attribute_dict.get(attr_name)
                if old_db_value != new_db_value:
                    attr_type = self._resolve_attr_type(attr_name)
                    if new_db_value is not None:
                        yield build_update(pk_db_value, attr_type, new_db_value)
                    else:
                        yield build_delete(pk_db_value, attr_type)

    def _serialize_delete(self, attribute_dict):
        """
        Yield the queries deleting an entity from all three tables.

        Nothing is yielded when the dict has no primary key. In any case the
        dict is cleared and then passed to ``on_delete``.
        """
        pk_db_value = attribute_dict.get(self.eav_configuration.primary_key)
        if pk_db_value is not None:
            yield self._delete_index_query(pk_db_value)
            yield self._delete_all_attribute_query(pk_db_value)
            yield self._delete_all_index_attribute_query(pk_db_value)

        attribute_dict.clear()
        self.on_delete(attribute_dict)

    # Queries to the index table

    def _insert_to_index_query(self, values):
        index_table = self.eav_configuration.index_table
        return DatabaseQuery(index_table.insert().values(**values))

    def _update_index_query(self, pk, values):
        index_table = self.eav_configuration.index_table
        index_pk_column = index_table.c[self.eav_configuration.primary_key]
        return DatabaseQuery(
            index_table.update()
            .values(**values)
            .where(index_pk_column == pk),
        )

    def _delete_index_query(self, pk):
        return self._delete_entity_values_query(
            self.eav_configuration.index_table,
            pk,
        )

    # Shared builders for the two attribute tables

    def _upsert_typed_value_query(self, table, pk, attr_type, attr_value):
        """
        Build an insert of one ``(pk, type, value)`` row into ``table``.

        The statement comes from ``insert_with_on_duplicate_key_update`` with
        ``['value']``; presumably an existing row gets its value replaced,
        which is why the same query serves both insert and update.
        """
        values = {
            self.eav_configuration.primary_key: pk,
            'type': attr_type,
            'value': attr_value,
        }
        return DatabaseQuery(
            insert_with_on_duplicate_key_update(table, ['value'])
            .values(**values)
        )

    def _delete_typed_value_query(self, table, pk, attr_type):
        """
        Build a delete of the row of one attribute type of one entity.
        """
        pk_column = table.c[self.eav_configuration.primary_key]
        return DatabaseQuery(
            table.delete()
            .where(
                sql_and(
                    pk_column == pk,
                    table.c.type.in_([attr_type]),
                ),
            )
        )

    def _delete_entity_values_query(self, table, pk):
        """
        Build a delete of all rows of one entity from ``table``.
        """
        pk_column = table.c[self.eav_configuration.primary_key]
        return DatabaseQuery(table.delete().where(pk_column == pk))

    # Queries to the eav table

    def _insert_to_attribute_query(self, pk, attr_type, attr_value):
        return self._upsert_typed_value_query(
            self.eav_configuration.eav_table,
            pk,
            attr_type,
            attr_value,
        )

    def _update_attribute_query(self, pk, attr_type, attr_value):
        return self._insert_to_attribute_query(pk, attr_type, attr_value)

    def _delete_one_attribute_query(self, pk, attr_type):
        return self._delete_typed_value_query(
            self.eav_configuration.eav_table,
            pk,
            attr_type,
        )

    def _delete_all_attribute_query(self, pk):
        return self._delete_entity_values_query(
            self.eav_configuration.eav_table,
            pk,
        )

    # Queries to the eav index table

    def _insert_to_index_attribute_query(self, pk, attr_type, attr_value):
        return self._upsert_typed_value_query(
            self.eav_configuration.eav_index_table,
            pk,
            attr_type,
            attr_value,
        )

    def _update_index_attribute_query(self, pk, attr_type, attr_value):
        return self._insert_to_index_attribute_query(pk, attr_type, attr_value)

    def _delete_one_index_attribute_query(self, pk, attr_type):
        return self._delete_typed_value_query(
            self.eav_configuration.eav_index_table,
            pk,
            attr_type,
        )

    def _delete_all_index_attribute_query(self, pk):
        return self._delete_entity_values_query(
            self.eav_configuration.eav_index_table,
            pk,
        )


class EavModelDatabaseSerializerMetaClass(type):
    """
    Metaclass registering every serializer class that declares a model.

    A class whose ``model`` attribute is not None is passed to
    ``register_database_serializer`` under the name of the model class.
    """
    def __new__(cls, name, bases, attrs):
        serializer_cls = type.__new__(cls, name, bases, attrs)
        model_cls = serializer_cls.model
        if model_cls is None:
            # Nothing to register, e.g. for a base class without a model
            return serializer_cls
        register_database_serializer(model_cls.__name__, serializer_cls)
        return serializer_cls


class EavModelDatabaseSerializer(object):
    """
    Serializes changes of a model through an attribute dict serializer.

    Models are converted to attribute dicts with ``eav_model_converters``;
    the dicts are handed over to an instance of ``database_serializer_class``.
    """
    __metaclass__ = EavModelDatabaseSerializerMetaClass

    # Model class served by the serializer; None disables registration
    model = None
    # Iterable of callables, each turning a model into an attribute dict
    eav_model_converters = None
    # Class of the serializer working with attribute dicts
    database_serializer_class = None

    def __init__(self):
        self._db_serializer = self.database_serializer_class()

    def serialize(self, old, new, difference=None):
        """
        Return the queries turning the stored ``old`` model into ``new``.

        A falsy model is treated as absent. Before delegating, the hooks of
        the underlying serializer are replaced so that the created model
        receives its primary key and the deleted model gets its fields reset.

        :param old: model before the change or a falsy value
        :param new: model after the change or a falsy value
        :param difference: not used
        :return: result of ``serialize`` of the underlying serializer
        """
        old_attribute_dict = self._build_attribute_dict(old) if old else None
        new_attribute_dict = self._build_attribute_dict(new) if new else None

        def _set_pk_to_model(attribute_dict):
            pk_name = self._db_serializer.eav_configuration.primary_key
            new.parse({pk_name: attribute_dict[pk_name]})

        def _clear_model(attribute_dict):
            for field_name in old._meta.fields.iterkeys():
                setattr(old, field_name, Undefined)
            old.parse(attribute_dict)

        self._db_serializer.on_create = _set_pk_to_model
        self._db_serializer.on_delete = _clear_model

        return self._db_serializer.serialize(old_attribute_dict, new_attribute_dict)

    def _build_attribute_dict(self, model):
        """
        Merge the outputs of all converters into a single attribute dict.

        :raises NotImplementedError: two converters produced the same key
        :return: merged dict passed through ``remove_none_values``
        """
        merged = dict()
        for converter in self.eav_model_converters:
            converted = converter(model)
            duplicated_names = set(merged.keys()).intersection(converted.keys())
            if duplicated_names:
                raise NotImplementedError()
            merged.update(converted)
        return remove_none_values(merged)


class _EavModelConverter(object):
    """
    Interface of converters: a callable taking a model.

    Subclasses have to override ``__call__``.
    """
    def __call__(self, model):
        raise NotImplementedError()  # pragma: no cover


class EavSimpleModelConverter(_EavModelConverter):
    """
    Converts one model attribute into a one-item attribute dict.
    """
    def __init__(self, model_attr_name, model_value_to_db_value, db_attr_name=None):
        """
        :param model_attr_name: name of the attribute read from the model
        :param model_value_to_db_value: callable converting the model value
            into the database value
        :param db_attr_name: key of the resulting dict; ``model_attr_name``
            is used when it is falsy
        """
        self._model_attr_name = model_attr_name
        self._model_value_to_db_value = model_value_to_db_value
        self._db_attr_name = db_attr_name or model_attr_name

    def __call__(self, model):
        """
        :return: dict ``{db_attr_name: db_value}``
        :raises ValueSerializationError: the value could not be converted;
            the message names the model attribute and its value
        """
        model_value = getattr(model, self._model_attr_name)
        to_db_value = self._model_value_to_db_value
        try:
            db_value = to_db_value(model_value)
        except ValueSerializationError:
            message = "Invalid value: %s = %r" % (self._model_attr_name, model_value)
            raise ValueSerializationError(message)
        return {self._db_attr_name: db_value}


class EavError(Exception):
    """
    Base class of the errors defined in this module.
    """


class EavUnknownAttributeNameError(EavError):
    """
    Unknown attribute:
        - no type (numeric identifier) is known for the attribute
    """
