# -*- coding: utf-8 -*-
import mock
from passport.backend.core.test.test_utils import single_entrant_patch
from passport.backend.core.ydb.faker.stubs import FakeResultSet


@single_entrant_patch
class FakeYdbKeyValue(object):
    """
    In-memory stand-in for the storage primitives of ``YdbKeyValue``.

    While started, ``YdbKeyValue._kikimr_get``, ``YdbKeyValue._kikimr_set``
    and ``YdbKeyValue._kikimr_delete`` are patched with functions that read
    and write plain dictionaries kept on this object.

    Rows are grouped into tables. A table is addressed by the tuple
    (endpoint, database, table name) taken from the ``YdbKeyValue`` instance
    that performs the operation.
    """

    def __init__(self):
        # table address -> {row key -> stored value}
        self._tables = dict()

        # Plain functions (not bound methods of this fake) are installed on
        # the YdbKeyValue class on purpose: a plain function found on a class
        # gets bound to the instance it is looked up on, so the YdbKeyValue
        # instance arrives as the first positional argument.
        def _fake_kikimr_get(*args, **kwargs):
            return self._fake_kikimr_get(*args, **kwargs)

        def _fake_kikimr_set(*args, **kwargs):
            return self._fake_kikimr_set(*args, **kwargs)

        def _fake_kikimr_delete(*args, **kwargs):
            return self._fake_kikimr_delete(*args, **kwargs)

        self.__patches = [
            mock.patch(
                'passport.backend.core.ydb.ydb.YdbKeyValue._kikimr_get',
                _fake_kikimr_get,
            ),
            mock.patch(
                'passport.backend.core.ydb.ydb.YdbKeyValue._kikimr_set',
                _fake_kikimr_set,
            ),
            mock.patch(
                'passport.backend.core.ydb.ydb.YdbKeyValue._kikimr_delete',
                _fake_kikimr_delete,
            ),
        ]

        # Optional mock invoked before every fake operation;
        # see set_response_side_effect().
        self._side_effect = None

    def start(self):
        """Activate all patches, in the order they were declared."""
        for patch in self.__patches:
            patch.start()

    def stop(self):
        """Deactivate all patches, in reverse order of activation."""
        for patch in reversed(self.__patches):
            patch.stop()

    def set_response_side_effect(self, side_effect):
        """
        Install a side effect triggered before every get/set/delete.

        ``side_effect`` is handed to ``mock.Mock(side_effect=...)`` as is, so
        it follows the mock library rules for that argument (for example an
        exception to raise or a callable to run). The mock is called without
        arguments and its return value is ignored.
        """
        self._side_effect = mock.Mock(side_effect=side_effect)

    def _call_side_effect(self):
        """Trigger the configured side effect, if there is one."""
        if self._side_effect is not None:
            self._side_effect()

    def _fake_kikimr_get(self, ydb_key_value, session, keys, prepared_read_statement):
        """
        Return the stored values whose row key matches ``keys``.

        A row matches when every (field, value) pair of ``keys`` is present
        in its row key, so ``keys`` may name only a subset of the fields the
        row was stored with.

        The result is always a one-element list holding a ``FakeResultSet``
        built from a list of ``{value column: stored value}`` dictionaries.
        A table that has never been written to is treated as an empty table,
        so the shape of the result does not depend on whether the table is
        known to the fake.

        ``session`` and ``prepared_read_statement`` are accepted for
        signature compatibility and are not used.
        """
        self._call_side_effect()

        table = self._tables.get(self._build_table_address(ydb_key_value), {})
        standard = self._build_row_key(keys)
        rows = [
            {ydb_key_value._value_column: value}
            for row_key, value in table.items()
            if self._row_key_matches(standard, row_key)
        ]
        return [FakeResultSet(rows)]

    def _fake_kikimr_set(self, ydb_key_value, session, keys, value):
        """
        Store ``value`` under the row key built from ``keys``.

        The table is created on first write; an existing row with the same
        row key is overwritten. ``session`` is not used.
        """
        self._call_side_effect()

        table_address = self._build_table_address(ydb_key_value)
        table = self._tables.setdefault(table_address, dict())
        table[self._build_row_key(keys)] = value

    def _fake_kikimr_delete(self, ydb_key_value, session, keys, prepared_delete_statement):
        """
        Remove every row whose row key matches ``keys``.

        Matching follows the same rule as in ``_fake_kikimr_get``; a row
        stored under exactly ``keys`` always matches. Deleting from a table
        that has never been written to is a no-op.

        ``session`` and ``prepared_delete_statement`` are not used.
        """
        self._call_side_effect()

        table = self._tables.get(self._build_table_address(ydb_key_value))
        if table is None:
            return

        standard = self._build_row_key(keys)
        # Iterate over a snapshot of the keys: the table is mutated in the loop.
        for row_key in list(table):
            if self._row_key_matches(standard, row_key):
                del table[row_key]

    def _build_table_address(self, ydb_key_value):
        """Return the (endpoint, database, table name) tuple identifying a table."""
        return (
            ydb_key_value._ydb._endpoint,
            ydb_key_value._ydb._database,
            ydb_key_value._table_name,
        )

    def _build_row_key(self, keys):
        """
        Turn the ``keys`` mapping into a hashable row key.

        The result is a tuple of (field name, field value) pairs ordered by
        field name, so it does not depend on the iteration order of ``keys``.
        """
        return tuple((k, keys[k]) for k in sorted(keys))

    def _row_key_matches(self, standard, row_key):
        """
        Tell whether ``row_key`` contains every pair listed in ``standard``.

        Both arguments are row keys as produced by ``_build_row_key``.
        An empty ``standard`` matches any row key.
        """
        row_fields = dict(row_key)
        return all(
            field_name in row_fields and row_fields[field_name] == field_value
            for field_name, field_value in standard
        )
