# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

from threading import current_thread

from django.db import connections
from django.db.backends.base.base import DEFAULT_DB_ALIAS
from django.utils.module_loading import import_string

from travel.rasp.library.python.common23.db.backends.dbwrapper_base import DatabaseWrapper as BaseDatabaseWrapper
from travel.rasp.library.python.common23.db.switcher import switcher


def in_main_thread():
    return current_thread().name == 'MainThread'


class DatabaseWrapper(BaseDatabaseWrapper):
    """ MysqlWrapper, позволяющий проксировать некоторые вызовы другим врапперам. """

    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, *args, **kwargs):
        self._settings_dict = settings_dict
        self._alias = alias
        self._get_alias = import_string(self._settings_dict['ALIAS_GETTER'])

        # После переключения необходимо заново установить соединение.
        # Не вешаем сигнал для DatabaseWrapper, созданных в потоках,
        # т.к. свитч им не важен, а из-за регистрации в сигнале такие врапперы утекают.
        # Использовать weak=True нельзя, т.к. bound-method сожрется garbage-collector'ом
        if self._settings_dict.get('CLOSE_ON_SWITCH', True) and in_main_thread():
            switcher.db_switched.connect(self.on_db_switched, weak=False)

        super(DatabaseWrapper, self).__init__(settings_dict, alias, *args, **kwargs)

    def on_db_switched(self, sender, **kwargs):
        if self.connection:
            self.close()

    def get_actual_db_wrapper(self):
        db_alias = self._get_alias()
        return connections[db_alias]

    @property
    def settings_dict(self):
        return self.get_actual_db_wrapper().settings_dict

    @settings_dict.setter
    def settings_dict(self, v):
        """Перегружаем, чтобы MysqlDatabaseWrapper при попытке записать property settings_dict"""

    def get_new_connection(self, conn_params):
        return self.get_actual_db_wrapper().get_new_connection(conn_params)

    def get_connection_params(self):
        return self.get_actual_db_wrapper().get_connection_params()

    def get_db_name(self):
        return self.get_actual_db_wrapper().get_db_name()

    def create_cluster(self):
        pass

    def get_cluster(self):
        return self.get_actual_db_wrapper().get_cluster()

    def get_hosts(self):
        return self.get_actual_db_wrapper().get_hosts()

    def get_all_hosts(self):
        return self.get_actual_db_wrapper().get_all_hosts()

    def get_connection_to_host(self, host):
        return self.get_actual_db_wrapper().get_connection_to_host(host)
