# -*- coding: utf-8 -*-

from enum import IntEnum
import json
import time

from passport.backend.core.lazy_loader import LazyLoader
from passport.backend.utils.common import classproperty
import passport.backend.utils.p3 as p3
from passport.backend.vault.api.db import get_db
from passport.backend.vault.api.errors import (
    AccessError,
    NonexistentEntityError,
)
from passport.backend.vault.api.utils import ulid
from passport.backend.vault.api.utils.json import (
    JsonSerializable,
    JsonSerializator,
)
import six
from sqlalchemy import (
    func,
    literal,
)
import sqlalchemy.dialects.mysql.types as mysql_types
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.inspection import inspect
from sqlalchemy.sql.expression import ClauseElement
import sqlalchemy.types as types


# Shared database handle for this module: supplies the declarative base
# (db.Model) and db.Column used by the models, mixins and serializator below.
db = get_db()


class ReprMixin(object):
    """Mixin providing a compact, debug-friendly ``__repr__`` for ORM models.

    The representation has the form ``<ClassName #<identity> <attrs>>``, where
    the identity comes from the SQLAlchemy instance state and ``<attrs>`` lists
    the attributes named in ``__repr_attrs__``.
    """
    # Names of attributes to show in repr(); subclasses override this.
    __repr_attrs__ = []

    @property
    def _id_str(self):
        """Identity key as a string: a single key as-is, composite keys joined
        with '-', or 'None' when the instance has no identity yet."""
        ids = inspect(self).identity
        if ids:
            return '-'.join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0])
        else:
            return 'None'  # pragma: no cover

    @property
    def _repr_attrs_str(self):
        """Render the attributes named in ``__repr_attrs__``.

        With exactly one attribute only its value is shown; otherwise each
        item is rendered as ``key:value``. String values are wrapped in double
        quotes. Items are joined with ', '.

        :raises KeyError: if a name in ``__repr_attrs__`` is not an attribute
            of the instance.
        """
        values = []
        single = len(self.__repr_attrs__) == 1
        for key in self.__repr_attrs__:
            if not hasattr(self, key):  # pragma: no cover
                raise KeyError('{} has incorrect attribute "{}" in '
                               '__repr_attrs__'.format(self.__class__, key))
            value = getattr(self, key)
            # Remember the type before encoding, since encoding may change it.
            wrap_in_quote = isinstance(value, six.string_types)
            value = p3.encode_value(value)

            if wrap_in_quote:
                value = '"{}"'.format(value)
            values.append(value if single else '{}:{}'.format(key, value))

        return ', '.join(values)

    def __repr__(self):
        # Evaluate each property once: both read instance state (identity
        # lookup, attribute access), so computing them twice is wasted work.
        id_str = self._id_str
        attrs_str = self._repr_attrs_str
        # Identity rendered like '#123'.
        id_part = ('#' + id_str) if id_str else ''
        # Join class name, identity and the rendered attributes.
        return '<{} {}{}>'.format(
            self.__class__.__name__,
            id_part,
            ' ' + attrs_str if attrs_str else '',
        )


class IdentityMixin(object):
    """Mixin exposing the primary-key column names and values of a model."""

    @property
    def pk_fileds(self):
        """Names of the columns in ``__table__.primary_key``.

        (sic: the attribute keeps this spelling because callers use it.)
        """
        key_columns = self.__table__.primary_key
        return list(column.name for column in key_columns)

    @property
    def identity(self):
        """Tuple of this instance's primary-key values, in ``pk_fileds`` order."""
        names = self.pk_fileds
        return tuple(getattr(self, name) for name in names)


class ModelSerializator(JsonSerializator):
    """JsonSerializator that converts ``db.Model`` instances into dicts.

    Model instances are routed to :meth:`_process_model` via the
    ``custom_type_processors`` mapping; other values are left to the base class.
    """

    def __init__(self, max_serialization_depth=1):
        super(ModelSerializator, self).__init__(
            max_serialization_depth=max_serialization_depth,
            custom_type_processors={
                db.Model: self._process_model,
            },
        )

    def _process_model(self, model, depth, exclude=None, include=None, *args, **kwargs):
        """Serialize a single model instance into a dict.

        The dict is assembled from:
          * table columns named in ``model.default_serialization_columns``
            (plus ``include``), skipping names in ``exclude``;
          * entries named in ``model.default_serialization_pycolumns``
            (plus ``include``), skipping names in ``exclude``. Callables are
            invoked with no arguments and other attributes are read as-is;
            missing attributes yield ``None``;
          * relationships: all of them when no serialization columns are
            configured, otherwise only those named among the columns.

        The result is passed through ``self._process_dict``. Afterwards, keys
        listed in ``model.remove_columns_equals_zero`` are dropped when their
        value equals 0.

        Returns ``None`` when ``depth`` exceeds the maximum serialization depth.
        """
        if depth > self._max_serialization_depth:  # pragma: no cover
            return
        if exclude is None:
            exclude = []
        if include is None:
            include = []

        columns = set(model.__table__.columns.keys())

        serialization_columns = set(getattr(model, 'default_serialization_columns', []))
        serialization_columns.update(include)

        serialization_pycolumns = set(getattr(model, 'default_serialization_pycolumns', []))
        serialization_pycolumns.update(include)

        result_dict = {
            c: getattr(model, c)
            for c in serialization_columns
            if c in columns and c not in exclude
        }

        # serialization_pycolumns may name either a callable (invoked with no
        # arguments) or a plain field/property (read as-is). The attribute is
        # fetched once, so a property is not evaluated twice.
        for c in serialization_pycolumns:
            if c in exclude:
                continue
            value = getattr(model, c, None)
            result_dict[c] = value() if callable(value) else value

        for rel in inspect(model.__class__).relationships.keys():
            if serialization_columns and rel not in serialization_columns:
                continue
            result_dict[rel] = getattr(model, rel)
        result_dict = self._process_dict(result_dict, depth, include=include, exclude=exclude, *args, **kwargs)

        for k in getattr(model, 'remove_columns_equals_zero', []):
            if k in result_dict and result_dict[k] == 0:
                del result_dict[k]

        return result_dict


class ModelSerializable(JsonSerializable):
    """JsonSerializable whose serializator (ModelSerializator) also handles
    ``db.Model`` instances."""
    serializator = ModelSerializator


class BaseModel(db.Model, ModelSerializable, ReprMixin, IdentityMixin):
    """Abstract base class for the ORM models.

    Combines ``db.Model`` with JSON serialization (``ModelSerializable``),
    a compact ``__repr__`` (``ReprMixin``) and primary-key helpers
    (``IdentityMixin``).
    """
    __abstract__ = True
    # db.Model precedes ReprMixin among the bases, so ReprMixin.__repr__ is
    # assigned explicitly to make sure it is the one used.
    __repr__ = ReprMixin.__repr__
    # When True, get_by_id raises AccessError instead of
    # NonexistentEntityError for a missing row.
    raise_access_error = False

    @classmethod
    def get_by_id(cls, uuid):
        """Return the instance whose primary key is ``uuid``.

        :raises AccessError: if no row is found and ``raise_access_error`` is set.
        :raises NonexistentEntityError: if no row is found otherwise.
        """
        o = cls.query.get(uuid)
        # Query.get signals "not found" with None; test for it explicitly so a
        # found instance is never rejected merely for being falsy.
        if o is None:
            if cls.raise_access_error:
                raise AccessError()
            raise NonexistentEntityError(cls, uuid)
        return o

    @classproperty
    def config(cls):
        """Application config obtained from LazyLoader."""
        return LazyLoader.get_instance('config')

    @classmethod
    def get_count(cls, force=False):
        """
        Return the number of rows in the model's table.

        Optimized alternative to count(*) for statistics. On MySQL, unless
        ``force`` is true, the value comes from
        ``information_schema.tables.table_rows`` (table statistics, which may be
        an estimate rather than an exact count). Otherwise an exact
        ``COUNT(1)`` query is executed.
        """
        engine = get_db().engine
        if not force and engine.dialect.name == 'mysql':  # pragma: no cover
            # Read the row count from the table statistics.
            result = engine.execute('''
                SELECT table_rows
                  FROM information_schema.tables
                 WHERE table_schema = '{}' and table_name = '{}'
            '''.format(
                cls.config['database']['schema'],
                cls.__tablename__,
            ))
            count = 0
            for row in result:
                count = row[0] or 0
            return count
        else:
            return get_db().session.query(func.count(1)).select_from(cls).scalar()


class StringyJSON(types.TypeDecorator):
    """JSON stored in a TEXT column.

    Values are encoded with ``json.dumps`` on write and decoded with
    ``json.loads`` on read; ``None`` passes through unchanged both ways.
    """
    impl = types.TEXT

    def process_bind_param(self, value, dialect):
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        return None if value is None else json.loads(value)


class UUIDType(types.TypeDecorator):
    """Column type holding ``ulid.ULID`` values in a CHAR(26) column.

    On write, values are converted to ``ulid.ULID`` if needed and stored as
    ``ULID.str(skip_prefix=True)``. On read, strings are parsed back into
    ``ulid.ULID`` using this column's ``prefix``/``ignore_prefix`` settings.
    ``default`` is written in place of ``None``, and a stored value equal to
    it is returned as-is without being parsed.
    """
    impl = types.CHAR(26)
    python_type = ulid.ULID
    prefix = None
    ignore_prefix = False

    def __init__(self, prefix=None, ignore_prefix=False, default=None, *args, **kwargs):
        super(UUIDType, self).__init__()
        # Falsy arguments fall back to the class-level settings.
        self.prefix = prefix or self.prefix
        self.ignore_prefix = ignore_prefix or self.ignore_prefix
        self.default = default
        # Extra positional/keyword arguments are accepted but ignored.

    def _to_ulid(self, raw):
        """Build a ULID from ``raw`` using this column's prefix settings."""
        return ulid.ULID(raw, prefix=self.prefix, ignore_prefix=self.ignore_prefix)

    def process_bind_param(self, value, dialect):
        if value is None:
            return self.default
        as_ulid = value if isinstance(value, ulid.ULID) else self._to_ulid(value)
        return as_ulid.str(skip_prefix=True)

    def process_result_value(self, value, dialect):
        # NULLs and the default sentinel come back untouched.
        passthrough = value is None or value == self.default
        if passthrough:
            return value
        return self._to_ulid(value)

    @classmethod
    def create_ulid(cls, value=None):
        """Wrap ``value`` (or a freshly generated ULID when falsy) in ``ulid.ULID``
        with the class-level prefix settings."""
        if not value:
            value = ulid.create_ulid()
        return ulid.ULID(value, prefix=cls.prefix, ignore_prefix=cls.ignore_prefix)


# Пишем наследников для TypeDecorator вместо короткого with_variant,
# чтобы Алембик сгенерировал нормальный вызов поля в миграции.

class MagicInteger(types.TypeDecorator):
    """Integer column: UNSIGNED INTEGER on MySQL, the generic Integer elsewhere."""
    impl = types.Integer

    def load_dialect_impl(self, dialect):
        on_mysql = dialect.name == 'mysql'
        chosen = mysql_types.INTEGER(unsigned=True) if on_mysql else self.impl
        return dialect.type_descriptor(chosen)


class MagicBigInteger(types.TypeDecorator):
    """Big-integer column: UNSIGNED BIGINT on MySQL, the generic BigInteger elsewhere."""
    impl = types.BigInteger

    def load_dialect_impl(self, dialect):
        on_mysql = dialect.name == 'mysql'
        chosen = mysql_types.BIGINT(unsigned=True) if on_mysql else self.impl
        return dialect.type_descriptor(chosen)


class Timestamp(types.TypeDecorator):
    """
    MySQL adds a lot of magic to the DDL of TIMESTAMP columns. To keep the
    models free of non-obvious behavior, NUMERIC(15, 3) is used instead of
    TIMESTAMP. On SQLite the column becomes Float to suppress warnings in tests.

    The problem as described in the SQLAlchemy docs:
    http://docs.sqlalchemy.org/en/latest/dialects/mysql.html#timestamp-columns-and-null

    Switched to numbers in ticket PASSP-19883
    """
    scale = 3
    precision = 15
    impl = types.Numeric(precision=precision, scale=scale)

    def __init__(self, current_timestamp=False):
        super(Timestamp, self).__init__()
        # When True, a None value is replaced by the current time on write.
        self.current_timestamp = current_timestamp

    def load_dialect_impl(self, dialect):
        chosen = types.Float() if dialect.name == 'sqlite' else self.impl
        return dialect.type_descriptor(chosen)

    def process_bind_param(self, value, dialect):
        # Values are rounded to `scale` decimal places before being stored.
        if value is not None:
            return round(value, self.scale)
        if self.current_timestamp:
            return round(time.time(), self.scale)
        return None


class MagicJSON(types.TypeDecorator):
    """JSON column: the native JSON type, or TEXT-backed StringyJSON on SQLite."""
    impl = types.JSON

    def load_dialect_impl(self, dialect):
        on_sqlite = dialect.name == 'sqlite'
        chosen = StringyJSON() if on_sqlite else self.impl
        return dialect.type_descriptor(chosen)


class MagicBLOB(types.TypeDecorator):
    """Binary column: MEDIUMBLOB on MySQL, the generic BLOB elsewhere."""
    impl = types.BLOB

    def load_dialect_impl(self, dialect):
        on_mysql = dialect.name == 'mysql'
        chosen = mysql_types.MEDIUMBLOB() if on_mysql else self.impl
        return dialect.type_descriptor(chosen)


class UpdatableMixin(object):
    """Mixin adding non-nullable ``updated_at``/``updated_by`` columns and a
    ``touch`` helper that fills them in."""
    updated_at = db.Column(Timestamp, nullable=False)
    updated_by = db.Column(MagicBigInteger, nullable=False)

    def touch(self, updated_by, updated_at=None):
        """Record a modification.

        :param updated_by: value stored in ``updated_by`` (presumably the id of
            the modifying user; confirm against callers).
        :param updated_at: Unix timestamp of the change; ``time.time()`` is used
            when omitted. Any explicit value, including 0, is kept as given.
        """
        # Compare with None so a legitimate falsy timestamp (e.g. 0) is not
        # silently replaced by the current time.
        self.updated_at = time.time() if updated_at is None else updated_at
        self.updated_by = updated_by


class State(IntEnum):
    """Integer-valued state with members ``normal`` (0) and ``hidden`` (1).

    Being an IntEnum, members compare equal to the plain ints they wrap.
    """
    normal = 0
    hidden = 1


class TokenState(IntEnum):
    """Integer-valued token state with members ``normal`` (0) and ``revoked`` (1).

    Being an IntEnum, members compare equal to the plain ints they wrap.
    """
    normal = 0
    revoked = 1


class ExternalRecordState(IntEnum):
    """Integer-valued state of an external record with members ``normal`` (0)
    and ``inactive`` (1).

    Being an IntEnum, members compare equal to the plain ints they wrap.
    """
    normal = 0
    inactive = 1


class MySQLMatch(ClauseElement):
    """SQL element rendered as MySQL's ``MATCH (columns) AGAINST (value mode)``.

    :param columns: iterable of column expressions to search in.
    :param value: search value; wrapped in ``literal()`` so it is compiled as
        a bound parameter.
    :param mode: one of the ``*_MODE`` constants; a falsy value means no
        modifier is rendered.
    """
    BOOLEAN_MODE = 'IN BOOLEAN MODE'
    NATURAL_MODE = 'IN NATURAL LANGUAGE MODE'
    QUERY_EXPANSION_MODE = 'WITH QUERY EXPANSION'

    def __init__(self, columns, value, mode=None):
        self.mode = mode if mode else ''
        self.value = literal(value)
        self.columns = columns


@compiles(MySQLMatch)
def _mysql_match(element, compiler, **kwargs):
    """Compile a :class:`MySQLMatch` into ``MATCH (cols) AGAINST (value mode)``.

    Compiler keyword arguments are forwarded to every sub-expression, the
    search value included.
    """
    columns_sql = ', '.join(compiler.process(c, **kwargs) for c in element.columns)
    # Forward kwargs to the value as well, so compile options such as
    # literal_binds apply to it just like to the columns.
    value_sql = compiler.process(element.value, **kwargs)
    return 'MATCH ({}) AGAINST ({} {})'.format(
        columns_sql,
        value_sql,
        element.mode,
    )
