# -*- coding: utf-8 -*-

import mock
from nose.tools import eq_
from passport.backend.core.test.test_utils import (
    iterdiff,
    single_entrant_patch,
)
from passport.backend.core.ydb.faker.stubs import FakeYdbCommit
import passport.backend.core.ydb_client as ydb


@single_entrant_patch
class FakeYdb(object):
    """Test double that replaces the YDB driver and session pool with mocks.

    While started, ``passport.backend.core.ydb.ydb.ydb.Driver`` returns a mock
    driver, and ``passport.backend.core.ydb.ydb.Ydb.session_pool`` is replaced
    by a mock whose ``retry_operation_sync`` calls the callee once with a fake
    session. Every ``execute`` on a transaction obtained from that session is
    recorded (see ``executed_queries`` / ``assert_queries_executed``). Each
    call returns whatever was configured through ``set_execute_return_value``
    or ``set_execute_side_effect``.
    """

    def __init__(self):
        # Raw records of every fake_tx_execute call, in call order.
        self._executed_queries = list()

        # session.transaction(...) returns self._transaction.return_value,
        # which acts as the fake transaction object.
        self._transaction = mock.Mock(
            name='fake_ydb_session_transaction',
        )
        self._session_prepare = mock.Mock(
            name='fake_ydb_session_prepare',
        )
        self._session = mock.Mock(
            name='fake_ydb_session',
            spec=ydb.Session,
        )
        self._session.transaction = self._transaction
        self._session.prepare = self._session_prepare
        # Supplies the result (or raises the side effect) of every fake execute.
        self._execute_mock = mock.Mock(name='execute_sideffect')

        self._session_prepare.side_effect = self.fake_session_prepare
        self._transaction.return_value.execute = self.fake_tx_execute
        # `with session.transaction() as tx:` binds tx to the same fake
        # transaction, because calling self._transaction returns its
        # return_value.
        self._transaction.return_value.__enter__ = self._transaction
        # __exit__ returns None (falsy), so exceptions inside the `with` propagate.
        self._transaction.return_value.__exit__ = mock.Mock(name='fake_ydb_transaction_exit', return_value=None)
        # A safeguard: this line has been mistyped many times (calling
        # session.transaction.execute instead of session.transaction().execute),
        # and the resulting test failures were hard to track down.
        self._transaction.execute.side_effect = Exception('session.transaction must be called before using execute')

        # mock.patch(new_callable=...) calls this PropertyMock with no arguments,
        # so Ydb.session_pool is patched with self.session_pool.return_value.
        self.session_pool = mock.PropertyMock(
            name='fake_session_pool',
            spec=ydb.SessionPool,
        )
        self.session_pool.return_value.retry_operation_sync = self.fake_retry_operation_sync

        self.driver = mock.Mock(
            name='fake_ydb_driver',
            spec=ydb.pool.ConnectionPool,
        )
        self.driver.table_client = mock.Mock()

        self.patches = [
            mock.patch(
                target='passport.backend.core.ydb.ydb.ydb.Driver',
                return_value=self.driver,
            ),
            mock.patch(
                target='passport.backend.core.ydb.ydb.Ydb.session_pool',
                new_callable=self.session_pool,
            ),
        ]

    def fake_retry_operation_sync(self, f, retry_settings=None, *args, **kwargs):
        """Stand-in for ``retry_operation_sync``: call ``f`` exactly once.

        ``f`` receives the fake session followed by ``*args`` / ``**kwargs``.
        ``retry_settings`` is accepted for signature compatibility but ignored;
        no retries are performed.
        """
        return f(self._session, *args, **kwargs)

    def start(self):
        """Start every patch, in declaration order."""
        for patch in self.patches:
            patch.start()

    def stop(self):
        """Stop every patch in reverse order, then drop the patch list.

        Because ``self.patches`` is deleted, ``start`` cannot be called again
        on this instance.
        """
        for patch in reversed(self.patches):
            patch.stop()
        del self.patches

    def set_execute_side_effect(self, side_effect):
        """Set the side effect (exception, callable, or iterable) applied to every execute."""
        self._execute_mock.side_effect = side_effect

    def set_execute_return_value(self, return_value):
        """Set the value returned by every execute (unless a side effect overrides it)."""
        self._execute_mock.return_value = return_value

    @staticmethod
    def fake_session_prepare(query):
        """Stand-in for ``session.prepare``: return the query unchanged."""
        return query

    def fake_tx_execute(self, query, parameters=None, commit_tx=False, session=None, transaction=None, settings=None):
        """Stand-in for ``transaction.execute``: record the call and return the configured result.

        The query, parameters, ``commit_tx``, session and transaction are
        recorded; ``settings`` is accepted but not recorded. The return value
        (or raised side effect) comes from ``self._execute_mock``.
        """
        self._executed_queries.append({
            'query': {
                'query': query,
                'parameters': parameters,
                'commit_tx': commit_tx,
            },
            'session': session,
            'transaction': transaction,
        })
        return self._execute_mock()

    def executed_queries(self):
        """Return one ``{'query', 'parameters', 'commit_tx'}`` dict per recorded execute, in order."""
        return [record['query'] for record in self._executed_queries]

    def assert_queries_executed(self, queries):
        """Compare the recorded queries against ``queries`` using ``iterdiff(eq_)``.

        :param queries: a sized sequence of query objects (anything accepted by
            ``query_to_fake_query``), optionally interleaved with the
            ``FakeYdbCommit`` marker. The final item, if it is a query, is
            expected with ``commit_tx=True``. A ``FakeYdbCommit`` marks the
            query immediately before it as committed.
        :raises ValueError: if a ``FakeYdbCommit`` is not preceded by a query.
        """
        expected = []
        last_index = len(queries) - 1
        for i, query in enumerate(queries):
            if query is FakeYdbCommit:
                # The marker modifies the previous query, so one must exist.
                if not expected:
                    raise ValueError('FakeYdbCommit must follow a query')
                expected[-1]['commit_tx'] = True
                continue
            expected.append(self.query_to_fake_query(
                query=query,
                commit_tx=i == last_index,
            ))
        iterdiff(eq_)(self.executed_queries(), expected)

    def assert_fake_query_equals(self, fake_query, query, commit_tx=False):
        """Assert with ``eq_`` that a recorded query dict matches ``query``."""
        query = self.query_to_fake_query(query, commit_tx)
        eq_(fake_query, query)

    def query_to_fake_query(self, query, commit_tx):
        """Build the dict shape that ``executed_queries`` produces for ``query``.

        ``query`` must provide ``get_raw_statement()`` and ``get_parameters()``
        (presumably a project query object; verify against callers).
        """
        return {
            'query': query.get_raw_statement(),
            'parameters': query.get_parameters(),
            'commit_tx': commit_tx,
        }
