# -*- coding: utf-8 -*-

from collections import defaultdict
from contextlib import contextmanager
import json

import mock
from nose.tools import (
    eq_,
    ok_,
)
from passport.backend.core.builders.blackbox.faker.blackbox import get_parsed_blackbox_response
from passport.backend.core.builders.blackbox.parsers import parse_blackbox_family_info_response
from passport.backend.core.db import schemas
from passport.backend.core.db.faker.db_utils import _eq_queries
from passport.backend.core.db.utils import (
    encode_params_for_db,
    insert_with_on_duplicate_key_append,
    insert_with_on_duplicate_key_increment,
    insert_with_on_duplicate_key_update,
    insert_with_on_duplicate_key_update_if_equals,
    with_ignore_prefix,
)
from passport.backend.core.dbmanager.manager import get_dbm
from passport.backend.core.dbmanager.sharder import (
    _Sharder,
    build_range_shard_function,
    get_sharder,
)
from passport.backend.core.dbmanager.transaction_manager import TransactionAdapter
from passport.backend.core.differ import diff
from passport.backend.core.eav_type_mapping import (
    alias_name_exists,
    attr_name_exists,
    ext_attr_name_exists,
)
from passport.backend.core.models.account import Account
from passport.backend.core.models.family import FamilyInfo
from passport.backend.core.processor import run_eav
from passport.backend.core.serializers.eav.base import EavSerializer
from passport.backend.core.test.test_utils import (
    iterdiff,
    single_entrant_patch,
)
from passport.backend.core.types.account.account import (
    KINOPOISK_UID_BOUNDARY,
    PDD_UID_BOUNDARY,
)
import six
from six import iteritems
from sqlalchemy.sql.expression import and_


# Тестовая база работает на основе моков метода execute engine-ов SQLAlchemy. Набор engine-ов
# после конфигурации может расширяться, это происходит при создании транзакции.
# Моки execute настраиваются в _setup_engine_execute_patch, сохраняют переданный запрос,
# а также имя БД (_save_query). Таким образом можно восстановить последовательность запросов во все базы
# и в конкретную базу.
#
# При работе с базой в тестовом коде для того, чтобы обращения не сохранялись,
# используется менеджер _pause_patches.


@contextmanager
def _pause_patches(patches):
    for patch in reversed(patches):
        patch.stop()
    try:
        yield
    finally:
        for patch in patches:
            patch.start()


class TrackingTransactionAdapter(TransactionAdapter):
    """
    Данный класс предназначен для отслеживания работы с транзакциями в тестах, подменяет собой
    исходный TransactionAdapter. Сохраняет обращения к BEGIN/COMMIT/ROLLBACK, обеспечивает отслеживание
    engine-ов транзакции, позволяет выполнять сайд-эффекты (отдельно для BEGIN/COMMIT/ROLLBACK).

    Сайд-эффекты выполняются на уровне методов класса.
    """
    fake_db_obj = None  # Объект типа FakeDB для связи с контекстом тестирования

    def _check_for_side_effect_and_save_query(self, dbname, query):
        # Т.к. класс наследует и подменяет методы TransactionAdapter, сайд-эффекты вызываем явно
        trx_side_effects = self.fake_db_obj._transaction_side_effects.get(dbname, {})
        if query in trx_side_effects:
            side_effect = trx_side_effects[query]
            mock.Mock(side_effect=side_effect)(query)
        # Сайд-эффект не задан, сохраним запрос
        self.fake_db_obj._save_query(dbname, query)

    def begin(self, engine):
        dbname = self.fake_db_obj.get_dbname(engine)
        self._check_for_side_effect_and_save_query(dbname, 'BEGIN')

        # Сайд-эффектов нет - начнем транзакцию
        super(TrackingTransactionAdapter, self).begin(engine)

        # Настраиваем патч на engine.execute транзакции
        new_engine = self.connection
        self.fake_db_obj._setup_engine_execute_patch(new_engine, dbname)

        # Если для БД выставлен глобальный сайд-эффект, настроим его для execute нового engine
        if dbname in self.fake_db_obj._db_execute_side_effects:
            new_engine.execute.side_effect = self.fake_db_obj._db_execute_side_effects[dbname]

    def commit_and_close(self):
        if self.is_started:
            dbname = self.fake_db_obj.get_dbname(self.connection)
            self._check_for_side_effect_and_save_query(dbname, 'COMMIT')
        return super(TrackingTransactionAdapter, self).commit_and_close()

    def rollback_and_close(self):
        if self.is_started:
            dbname = self.fake_db_obj.get_dbname(self.connection)
            try:
                self._check_for_side_effect_and_save_query(dbname, 'ROLLBACK')
            except:
                # Для корректности отрабатываем так же, как в базовом методе rollback_and_close
                self.connection.close()
                self.connection = None
                raise
        return super(TrackingTransactionAdapter, self).rollback_and_close()


@single_entrant_patch
class FakeDB(object):
    # Одна metadata к одному имени базы
    schemes_metadata = (
        # Новая схема таблиц с шардами
        (schemas.central_metadata, 'passportdbcentral'),
        (schemas.shard_metadata, 'passportdbshard1'),
        (schemas.shard_metadata, 'passportdbshard2'),
    )

    default_sharding_config = dict.fromkeys(
        [
            'attributes',
            'extended_attributes',
            'password_history',
            'phone_operations',
            'phone_bindings',
            'phone_bindings_history',
            'tracks',
            'email_bindings',
            'account_deletion_operations',
            'passman_recovery_keys',
        ],
        {
            1: 'passportdbshard1',
            2: 'passportdbshard2',
        },
    )

    # Атрибуты обычных пользователей - в первый шард, ПДД и КП - во второй
    default_shard_ranges = [(1, 0), (2, KINOPOISK_UID_BOUNDARY), (2, PDD_UID_BOUNDARY)]

    # Настройки для движка sqlalchemy
    default_db_config = {
        'master': {
            'driver': 'sqlite',
            'database': ':memory:',
        },
        'master_with_low_timeout': {
            'driver': 'sqlite',
            'database': ':memory:',
            'has_low_timeout': True,
        },
    }

    def __init__(self, db_config=None, sharding_config=None):
        self.config = db_config or self.default_db_config
        self.sharding_config = sharding_config or self.default_sharding_config

        self._sharders_patch = mock.patch(
            'passport.backend.core.dbmanager.sharder._sharders',
            defaultdict(lambda: mock.Mock(wraps=_Sharder())),
        )
        self.shard_function = build_range_shard_function(self.default_shard_ranges)

    def start(self):
        TrackingTransactionAdapter.fake_db_obj = self
        self._transaction_adapter_patch = mock.patch(
            'passport.backend.core.dbmanager.transaction_manager.TransactionAdapter',
            TrackingTransactionAdapter,
        )
        self._transaction_adapter_patch.start()
        self._sharders_patch.start()
        self._execute_patches = []
        # Маппинги для отслеживания соответствий engine-ов и имен БД
        self._dbname_to_engines = {}
        self._engine_to_dbname = {}
        self._db_execute_side_effects = {}
        self._transaction_side_effects = {}
        self._total_query_order = []

        for metadata, dbname in self.schemes_metadata:
            dbm = get_dbm(dbname)
            dbm.configure(self.config)
            dbm_engines = [engine for router in (dbm._master, dbm._slave) if router for engine in router.engines]
            for engine in dbm_engines:
                metadata.create_all(engine)
                self._setup_engine_execute_patch(engine, dbname)

        for table_name, config in iteritems(self.sharding_config):
            sharder = get_sharder(table_name)
            sharder.configure(config, shard_function=self.shard_function)
            for key, dbname in iteritems(config):
                dbm = get_dbm(dbname)
                if not dbm.is_configured():
                    dbm.configure(self.config)

    def stop(self):
        for patch in reversed(self._execute_patches):
            patch.stop()
        for metadata, dbname in self.schemes_metadata:
            dbm = get_dbm(dbname)
            dbm_engines = [engine for router in (dbm._master, dbm._slave) if router for engine in router.engines]
            for engine in dbm_engines:
                metadata.drop_all(engine)
        self._sharders_patch.stop()
        self._transaction_adapter_patch.stop()
        TrackingTransactionAdapter.fake_db_obj = None

    def _setup_engine_execute_patch(self, engine, dbname):
        execute = engine.execute
        assert not isinstance(execute, mock.Mock)

        def patched_execute(query, *args, **kwargs):
            self._save_query(dbname, query)
            return execute(query, *args, **kwargs)

        patch = mock.patch.object(engine, 'execute', mock.Mock(wraps=patched_execute))
        patch.start()

        self._execute_patches.append(patch)
        self._dbname_to_engines.setdefault(dbname, set()).add(engine)
        self._engine_to_dbname[engine] = dbname

    def _save_query(self, dbname, query):
        # Note: при необходимости можно начать хранить и соответствующий engine
        self._total_query_order.append((dbname, query))

    def get_dbname(self, engine):
        return self._engine_to_dbname[engine]

    def _wrap_side_effect(self, db, side_effect):
        """
        Приходится делать такое дублирование кода из библиотеки mock, т.к. хотим знать
        порядок всех запросов ко всем БД.
        Обычный мок на execute не помог бы восстановить порядок, т.к. engine-ов много, а также
        есть BEGIN/ROLLBACK/COMMIT.
        """
        def _get_wrapper(effect):
            if isinstance(effect, (list, tuple)):
                effect = iter(effect)

            def _wrapper(query, *args, **kwargs):
                if not isinstance(effect, (type, Exception)):
                    result = next(effect)
                    if isinstance(result, Exception) or isinstance(result, type) and issubclass(result, Exception):
                        self._save_query(db, query)
                        raise result
                    return result
                self._save_query(db, query)
                raise effect

            return _wrapper

        if side_effect is not None:
            return _get_wrapper(side_effect)

    def set_side_effect_for_db(self, db, side_effect):
        """
        Установка побочного эффекта для запросов в данную БД (за исключением BEGIN/COMMIT/ROLLBACK).
        Если передается список эффектов - он общий на все engine-ы (включая engine транзакции).
        """
        side_effect = self._wrap_side_effect(db, side_effect)
        # Необходимо настроить эффект для текущих engine-ов
        for engine in self._dbname_to_engines[db]:
            engine.execute.side_effect = side_effect
        # А также записать для случая создания транзакций
        self._db_execute_side_effects[db] = side_effect

    def set_side_effect_for_transaction(self, db, begin=None, commit=None, rollback=None):
        """
        Установка побочных эффектов на вызовы BEGIN/COMMIT/ROLLBACK.
        """
        self._transaction_side_effects[db] = {
            'BEGIN': self._wrap_side_effect(db, begin),
            'COMMIT': self._wrap_side_effect(db, commit),
            'ROLLBACK': self._wrap_side_effect(db, rollback),
        }

    def assert_executed_queries_equal(self, expected_queries, db=None, row_count=None):
        actual_db_queries = [query for (dbname, query) in self._total_query_order if db is None or dbname == db]
        _eq_queries(
            actual_db_queries,
            [query.to_query() if hasattr(query, 'to_query') else query for query in expected_queries],
            row_count=row_count,
        )

    def assert_transaction_queries_equal(self, db, expected_queries):
        actual_queries = [
            query
            for (dbname, query) in self._total_query_order
            if dbname == db and query in {'COMMIT', 'ROLLBACK', 'BEGIN'}
        ]
        eq_(actual_queries, expected_queries)

    def reset_mocks(self, *dbs):
        """
        Сбрасывает моки, очищает историю запросов для заданных БД.
        """
        dbs = dbs or [
            'passportdbcentral',
            'passportdbshard1',
            'passportdbshard2',
        ]
        for dbname in dbs:
            for engine in self._dbname_to_engines[dbname]:
                engine.execute.reset_mock()

        self._total_query_order = [k for k in self._total_query_order if k[0] not in dbs]

    def _serialize_to_eav(self, instance, old_instance=None, reset_mocks=True):
        run_eav(old_instance, instance, diff(old_instance, instance))
        if reset_mocks:
            self.reset_mocks()

    def serialize(self, data, build_model=None):
        """
        Сохраняет в БД данные
        """
        if build_model is None:
            build_model = modelizer.userinfo
        elif isinstance(build_model, six.string_types):
            build_model = getattr(modelizer, build_model)

        model_object = build_model(data)
        self._serialize_to_eav(model_object)

    def serialize_sessionid(self, data):
        self.serialize(data, modelizer.sessionid)

    def get_table_and_db(self, table_name, dbname):
        for metadata, name in self.schemes_metadata:
            if dbname == name:
                return metadata.tables[table_name], dbname
        raise KeyError('No such table %s in db %s' % (table_name, dbname))

    def insert(self, table_name, db=None, master=True, **kwargs):
        table, db = self.get_table_and_db(table_name, db)
        query = table.insert().values(**encode_params_for_db(kwargs))
        with self.no_recording():
            get_dbm(db).get_engine(force_master=master).execute(query)

    def no_recording(self):
        return _pause_patches(self._execute_patches)

    def _query_eav(self, table, query, attr, entity_type=None):
        """Поиск по имени атрибута в таблицах attributes & aliases"""
        if table.name == 'attributes' and attr_name_exists(attr):
            eav_type = EavSerializer.attr_name_to_type(attr)
            query = query.where(table.c.type == eav_type)

        elif table.name == 'extended_attributes' and ext_attr_name_exists(entity_type, attr):

            eav_type = EavSerializer.ext_attr_name_to_type(entity_type, attr)
            query = query.where(and_(table.c.type == eav_type, table.c.entity_type == entity_type))

        elif table.name in ['aliases', 'removed_aliases'] and alias_name_exists(attr):
            eav_type = EavSerializer.alias_name_to_type(attr)
            query = query.where(table.c.type == eav_type)

        return query

    def select(self, table_name, field=None, db=None, limit=None, master=True, **kwargs):
        """
        Выходные параметры

          Список кортежей-табличных-строк (ещё у кортежей есть lookup как у
          словарей)

        Поясняющие примеры использования входных параметров table_name, kwargs, field и entity_type

          Для краткости опускаются параметры db, limit и master

          * Общий

            Вызов select(table_name, crit1=val1, crit2=val2) равносилен

            SELECT * FROM table_name WHERE crit1 = val1 AND crit2 = value2

          * Таблица attributes и параметр field

            Вызов select('attributes', field=attr_name) равносилен

            SELECT * FROM attributes WHERE type = attr_type

          * Таблицы aliases и removed_aliases, а также параметр field

            Вызов select('aliases', field=alias_name) равносилен

            SELECT * FROM aliases WHERE type = alias_type

          * Таблица extended_attributes и параметры field и entity_type

            Вызов select('extended_attributes', field=entity_attr_name, entity_type=entity_type) равносилен

            SELECT * FROM extended_attributes WHERE entity_type = entity_type AND type = entity_attr_type

          Все специальные способы можно комбинировать с общим
        """
        table, db = self.get_table_and_db(table_name, db)
        query = table.select()
        query = self._query_eav(table, query, field, kwargs.get('entity_type'))
        for key, val in iteritems(encode_params_for_db(kwargs)):
            query = query.where(table.c[key] == val)

        if limit:
            query = query.limit(limit)

        with self.no_recording():
            return get_dbm(db).get_engine(force_master=master).execute(query).fetchall()

    def get(self, table_name, field=None, **kwargs):
        result = self.select(table_name, field=field, limit=1, **kwargs)
        if not result:
            return None

        result = result[0]

        if field is not None:
            # Если указано имя атрибута или алиаса, вернем его Значение
            if table_name == 'attributes' and attr_name_exists(field):
                value = result['value']
            elif table_name == 'extended_attributes' and ext_attr_name_exists(kwargs['entity_type'], field):
                value = result['value']
            elif table_name in ['aliases', 'removed_aliases'] and alias_name_exists(field):
                value = result['value']
            else:
                value = result[field]

            if six.PY3 and isinstance(value, six.binary_type):
                return value.decode('utf8')

            return value

        return result

    def query_count(self, db):
        return sum(engine.execute.call_count for engine in self._dbname_to_engines[db])

    def check(self, table_name, field_name, expected_value, **kwargs):
        """Проверим что в БД записано ожидаемое значение в указанной таблице и столбце"""
        value = self.get(table_name, field_name, **kwargs)
        if six.PY3 and isinstance(value, six.binary_type):
            value = value.decode('utf-8')

        eq_(
            value,
            expected_value,
            'Field %r in %r equals %r, but should be %r' % (
                field_name,
                table_name,
                value,
                expected_value,
            ),
        )

    def check_line(self, table_name, expected_data, **kwargs):
        data = self.get(table_name, **kwargs)
        if not data:
            raise AssertionError('No DB records found')

        data = dict(data)
        result = {}
        for key, value in iteritems(data):
            if isinstance(value, six.binary_type):
                value = value.decode('utf-8')
            result[key] = value

        iterdiff(eq_)(expected_data, result)

    def check_missing(self, table_name, attr=None, **kwargs):
        """Проверим отсутствие записи в БД"""
        value = self.get(table_name, field=attr, **kwargs)
        ok_(
            value is None,
            '%s%r found, but should be missing.' % ('%r=' % attr if attr else '', value),
        )

    def check_db_attr(self, uid, attr_name, value, db='passportdbshard1'):
        self.check(
            'attributes',
            attr_name,
            value,
            db=db,
            uid=uid,
        )

    def check_db_attr_missing(self, uid, attr_name):
        self.check_missing(
            'attributes',
            attr_name,
            db='passportdbshard1',
            uid=uid,
        )

    def check_db_ext_attr(self, uid, entity_type, entity_id, field_name, value):
        self.check(
            'extended_attributes',
            field_name,
            value,
            entity_type=entity_type,
            entity_id=entity_id,
            uid=uid,
            db='passportdbshard1',
        )

    def check_db_ext_attr_missing(self, uid, entity_type, entity_id, field_name):
        self.check_missing(
            'extended_attributes',
            field_name,
            entity_type=entity_type,
            entity_id=entity_id,
            uid=uid,
            db='passportdbshard1',
        )

    def check_table_contents(self, table_name, db_name, expected, master=True):
        table, db = self.get_table_and_db(table_name, db_name)
        query = table.select()
        column_names = [c.name for c in table.columns]
        with self.no_recording():
            records = get_dbm(db).get_engine(force_master=master).execute(query).fetchall()
        actual = [
            dict(zip(column_names, record))
            for record in records
        ]
        eq_(
            len(expected),
            len(actual),
            'Mismatched number of records: %s exists, but %s expected\n'
            'actual:\n%s\n\nexpected:\n%s' % (
                len(actual),
                len(expected),
                actual,
                expected
            ),
        )
        for actual_record, expected_record in zip(actual, expected):
            decoded_actual_record = {
                key: value.decode('utf8') if six.PY3 and isinstance(value, six.binary_type) else value
                for key, value in expected_record.items()
            }
            iterdiff(eq_)(decoded_actual_record, expected_record)

    def check_query_counts(self, central=0, shard=0):
        eq_(self.query_count('passportdbcentral'), central)
        eq_(self.query_count('passportdbshard1'), shard)


@single_entrant_patch
class IdGeneratorFaker(object):
    def __init__(self):
        self._mock = mock.Mock(name='id_generator')
        self._patch = mock.patch(
            'passport.backend.core.db.runner.runner._run_incr_id_query',
            self._mock,
        )

    def start(self):
        self._patch.start()

    def stop(self):
        self._patch.stop()

    def set_list(self, ids):
        self._mock.side_effect = ids

    @property
    def call_count(self):
        return self._mock.call_count


class FakeTransaction(object):
    def begin(self):
        return 'BEGIN'

    def commit(self):
        return 'COMMIT'


class BaseModelizer(object):
    """
    Класс для преобразования чего-нибудь в объект модели
    """
    def __call__(self, data):
        """
        Этот метод возвращать объект модели описываемый в data
        """
        raise NotImplementedError()


class AccountBlackboxResponseModelizer(BaseModelizer):
    def __init__(self, blackbox_method):
        self._blackbox_method = blackbox_method

    def __call__(self, data):
        return Account().parse(get_parsed_blackbox_response(self._blackbox_method, data))


class FamilyInfoBlackboxResponseModelizer(BaseModelizer):
    def __call__(self, data):
        data = json.loads(data)
        family_id = list(data['family'].keys())[0]
        data = parse_blackbox_family_info_response(family_id, data)
        return FamilyInfo().parse(data)


class _Modelizer(object):
    family_info = FamilyInfoBlackboxResponseModelizer()
    oauth = AccountBlackboxResponseModelizer('oauth')
    sessionid = AccountBlackboxResponseModelizer('sessionid')
    userinfo = AccountBlackboxResponseModelizer('userinfo')


def attribute_table_insert():
    return schemas.attributes_table.insert()


def attribute_table_insert_on_duplicate_update_key():
    return insert_with_on_duplicate_key_update(schemas.attributes_table, ['value'])


def attribute_table_insert_on_duplicate_update_if_value_equals(expected_value, else_null=False):
    return insert_with_on_duplicate_key_update_if_equals(schemas.attributes_table, ['value'], 'value', expected_value, else_null=else_null)


def attribute_table_insert_on_duplicate_append_key():
    return insert_with_on_duplicate_key_append(schemas.attributes_table, ['value'])


def attribute_table_insert_on_duplicate_increment_key():
    return insert_with_on_duplicate_key_increment(schemas.attributes_table, 'value')


def extended_attribute_table_insert_on_duplicate_key():
    return insert_with_on_duplicate_key_update(schemas.extended_attributes_table, ['value'])


def extended_attribute_table_delete():
    return schemas.extended_attributes_table.delete()


def aliases_insert():
    return schemas.aliases_table.insert()


def insert_ignore_into_removed_aliases(select):
    return with_ignore_prefix(
        schemas.removed_aliases_table.insert().from_select(
            schemas.removed_aliases_table.c.keys(),
            select,
        ),
    )


def uid_table_insert():
    return schemas.uid_table.insert()


def pdduid_table_insert():
    return schemas.pdduid_table.insert()


def suid_table_insert():
    return schemas.suid_table.insert()


def suid_table_delete():
    return schemas.suid_table.delete()


def pddsuid_table_insert():
    return schemas.pddsuid_table.insert()


def totp_secret_id_table_insert():
    return schemas.totp_secret_id_table.insert()


def get_transaction(*queries):
    return (transaction.begin(),) + queries + (transaction.commit(),)


def yakey_backup_insert():
    return schemas.yakey_backups_table.insert()


def yakey_backup_update(previous_update):
    return insert_with_on_duplicate_key_update_if_equals(
        schemas.yakey_backups_table,
        ['backup', 'device_name', 'updated'],
        'updated',
        previous_update,
    )


def passman_recovery_key_insert():
    return insert_with_on_duplicate_key_update(
        schemas.passman_recovery_keys_table,
        ['recovery_key'],
        binary_value_ids=[1, 2],
    )


modelizer = _Modelizer
transaction = FakeTransaction()
