# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from retry import retry

import MySQLdb


log = logging.getLogger(__name__)


class DefaultReplicaChecker(object):
    """Replica checker that performs no checks.

    Every host is reported as healthy.
    """

    def check_replica(self, host):
        """Return ``True`` for any ``host``; the argument is not inspected."""
        return True


class RaspReplicaChecker(DefaultReplicaChecker):
    """Health checker for a MySQL replica.

    A host is reported healthy when it is not in the suspended set, a
    connection to it can be opened, and every check returned by
    ``get_connection_checks`` passes.
    """

    # Zero-based position of the ``Seconds_Behind_Master`` value in a
    # ``SHOW SLAVE STATUS`` row.
    # NOTE(review): positional access depends on the server's column layout;
    # confirm it still holds when the MySQL version changes.
    SECONDS_BEHIND_MASTER_COLUMN = 32

    def __init__(self, user, password, database, max_replication_lag=60 * 5):
        """Store connection credentials and the lag threshold.

        :param user: MySQL user name used for health-check connections.
        :param password: password for ``user``.
        :param database: database name to connect to.
        :param max_replication_lag: largest accepted replication lag, in
            seconds (five minutes by default).
        """
        self.user = user
        self.password = password
        self.database = database
        self.suspended_replicas = set()
        self.max_replication_lag = max_replication_lag

    def set_suspended_replicas(self, blacklist):
        """Replace the set of hosts that are always reported unhealthy.

        :param blacklist: iterable of host names; it is copied into a set.
        """
        self.suspended_replicas = set(blacklist)

    @retry(tries=2, delay=3)
    def get_connection(self, host):
        """Open a MySQL connection to ``host`` using the stored credentials.

        The call is wrapped in ``retry(tries=2, delay=3)``.
        """
        return MySQLdb.connect(host=host, user=self.user, passwd=self.password, db=self.database)

    def get_connection_checks(self):
        """Return the list of callables run against an open connection.

        Each callable takes the connection and returns a truthy value when
        the replica passes that check.
        """
        return [self.check_replication_lag]

    def check_replica(self, host):
        """Return ``True`` when ``host`` passes every health check.

        Suspended hosts are rejected without connecting.  Any exception
        raised while connecting or running a check is logged and reported
        as an unhealthy replica instead of being propagated.
        """
        if host in self.suspended_replicas:
            return False

        conn = None
        try:
            conn = self.get_connection(host)
            for check in self.get_connection_checks():
                if not check(conn):
                    return False

            return True
        except Exception as ex:
            log.error("Failed to check replica {} health status: {}".format(host, ex))
            return False
        finally:
            # Compare against None explicitly so that closing does not depend
            # on the truthiness of the connection object.
            if conn is not None:
                conn.close()

    def check_replication_lag(self, conn):
        """Return ``True`` when the replication lag is within the limit.

        Runs ``show slave status`` on ``conn``.  An empty result is treated
        as a connection to the master and passes.  A NULL lag value or a lag
        above ``max_replication_lag`` fails the check and is logged.
        """
        conn.query("""show slave status;""")

        result = conn.store_result()
        rows = result.fetch_row()
        if not rows:
            # No slave status rows means we are connected to the master.
            return True

        raw_lag = rows[0][self.SECONDS_BEHIND_MASTER_COLUMN]
        if raw_lag is None:
            # A NULL lag cannot be compared with the threshold; report the
            # replica as unhealthy explicitly instead of failing in int().
            log.error("Replication lag is unknown: seconds_behind_master is NULL "
                      "(replication is stopped or broken)")
            return False

        seconds_behind_master = int(raw_lag)
        if seconds_behind_master > self.max_replication_lag:
            log.error("Replication lag is to large. seconds_behind_master={}s, MAX_REPLICATION_LAG={}s"
                      .format(seconds_behind_master, self.max_replication_lag))

            return False

        return True


class RaspMainReplicaChecker(RaspReplicaChecker):
    """Replica checker that also verifies the number of tables.

    Adds ``check_tables`` to the checks inherited from the parent class.
    """

    # Smallest number of rows ``show tables`` must return for the check
    # to pass.
    MIN_TABLES_COUNT = 250

    def get_connection_checks(self):
        """Return the parent's checks followed by ``check_tables``."""
        inherited_checks = super(RaspMainReplicaChecker, self).get_connection_checks()
        return inherited_checks + [RaspMainReplicaChecker.check_tables]

    @staticmethod
    def check_tables(conn):
        """Return ``True`` when ``conn`` sees at least ``MIN_TABLES_COUNT`` tables.

        Runs ``show tables`` on ``conn`` and counts the rows of the stored
        result; a shortfall is logged and fails the check.
        """
        conn.query("""show tables;""")

        found = int(conn.store_result().num_rows())
        required = RaspMainReplicaChecker.MIN_TABLES_COUNT
        if found >= required:
            return True

        log.error("Missing some tables. tables_count={}, MIN_TABLES_COUNT={}"
                  .format(found, required))
        return False
