# coding=utf-8

import logging
from sandbox import sdk2
from sandbox.sandboxsdk import environments
from datetime import datetime


class NoMbiMappingTasksParking(sdk2.Task):
    """Park execution-queue tasks that keep failing because of a missing feed-id mapping.

    On the SS database master, selects ``execution_queue`` rows that
      * have ``attempt_number >= MAX_ATTEMPTS_NUMBER``,
      * are not parked yet (``execute_after < DATE_IN_FUTURE``),
      * have a ``type`` matching ``TASK_TYPE_LIKE``,
      * have a ``uuid`` ending with ``-<warehouseId>`` for one of the monitored warehouses,
      * have a ``fail_reason`` matching ``TASK_FAIL_REASON_LIKE``,
    and "parks" them by setting ``execute_after`` to ``DATE_IN_FUTURE`` and
    ``attempt_number`` to ``MAX_ATTEMPTS_NUMBER``.
    """

    # Parked tasks are rescheduled to this far-future moment.
    DATE_IN_FUTURE = datetime(3000, 1, 1)
    # Attempt threshold used to select tasks; parked tasks get this attempt_number.
    MAX_ATTEMPTS_NUMBER = 14
    # Port passed to psycopg2.connect() for every SS database host.
    SS_DATABASE_PORT = "6432"

    # Raw LIKE patterns, bound as query parameters by _failed_tasks_query().
    TASK_TYPE_LIKE = 'CHANGE%'
    TASK_FAIL_REASON_LIKE = '%Feed id mapping not found%'
    # Quoted SQL-literal forms of the patterns above, kept for get_failed_tasks_sql().
    TASK_TYPE_PATTERN = "'" + TASK_TYPE_LIKE + "'"
    TASK_FAIL_REASON_PATTERN = "'" + TASK_FAIL_REASON_LIKE + "'"

    class Parameters(sdk2.Task.Parameters):
        warehouses_for_monitoring = sdk2.parameters.String('Warehouses for monitoring: '
                                                           'String in format warehouseId[;warehouseId]', required=True)
        ss_database_user_vault_key = sdk2.parameters.String('SS database user vault key', required=True)
        ss_database_password_vault_key = sdk2.parameters.String('SS database password vault key', required=True)
        ss_database_hosts = sdk2.parameters.String('SS database hosts: '
                                                   'String in format host[;host]', required=True)
        ss_database_name = sdk2.parameters.String('SS database name', required=True)

    class Requirements(sdk2.Requirements):
        disk_space = 1024 * 5
        environments = (environments.PipEnvironment('psycopg2-binary'),)

    def on_execute(self):
        """Connect to each configured host until the master is found, then park stuck tasks.

        Hosts that refuse the connection are skipped, as are replicas. Every
        connection that gets opened is closed again.
        """
        import psycopg2
        from psycopg2.extras import RealDictCursor

        try:
            user = sdk2.Vault.data(self.Parameters.ss_database_user_vault_key)
            password = sdk2.Vault.data(self.Parameters.ss_database_password_vault_key)
            hosts = self._split_parameter(self.Parameters.ss_database_hosts)
            logging.info('Connecting to SS database...')

            for host in hosts:
                try:
                    conn = psycopg2.connect(host=host, port=self.SS_DATABASE_PORT,
                                            database=self.Parameters.ss_database_name,
                                            user=user, password=password, cursor_factory=RealDictCursor)
                except psycopg2.OperationalError as error:
                    # Fall through to the next host instead of aborting the whole run.
                    logging.warning('Failed to connect to SS db host %s: %s', host, error)
                    continue
                try:
                    if self._park_on_master(conn):
                        return
                    logging.info('SS db host %s is a replica, skipping', host)
                finally:
                    # Always release the connection, including replicas we skip.
                    conn.close()
                    logging.info('SS database connection closed.')

            logging.warning('No master found among SS database hosts: %s', ', '.join(hosts))
        except Exception:
            # NOTE(review): errors are logged, not re-raised, so the task still
            # finishes successfully; consider re-raising if failures must be visible.
            logging.exception('Failed to park tasks without MBI mapping')

    def _park_on_master(self, conn):
        """Park stuck tasks through ``conn`` if it points at the master.

        :param conn: open psycopg2 connection (RealDictCursor rows).
        :return: True if ``conn`` was the master and the parking ran, else False.
        """
        cursor = conn.cursor()
        try:
            if not self.is_master(cursor):
                return False
            logging.info('Connected to SS db on master')
            sql, params = self._failed_tasks_query()
            cursor.execute(sql, params)
            self.process_results(cursor.fetchall(), cursor, conn)
            return True
        finally:
            cursor.close()

    def process_results(self, raw_result, cursor, conn):
        """Park every task in ``raw_result``.

        :param raw_result: rows from the failed-tasks query, each with an ``'id'`` key.
        :param cursor: open cursor on the master connection.
        :param conn: master connection; committed by park_tasks().
        """
        if not raw_result:
            logging.info('No tasks found for parking')
            return
        logging.info('Fetched tasks for parking')
        # executemany() expects one parameter tuple per statement.
        ids_to_park = [(row['id'],) for row in raw_result]
        self.park_tasks(cursor, conn, ids_to_park)

    def park_tasks(self, cursor, conn, ids):
        """Move the given tasks to DATE_IN_FUTURE and commit.

        :param ids: sequence of 1-tuples ``(task_id,)``.
        """
        sql = self.get_updated_time_sql(self.DATE_IN_FUTURE, self.MAX_ATTEMPTS_NUMBER)
        cursor.executemany(sql, ids)
        conn.commit()
        logging.info('Parked %d tasks', len(ids))

    def _failed_tasks_query(self):
        """Build the parameterized SELECT for tasks that should be parked.

        :return: ``(sql, params)`` tuple suitable for ``cursor.execute``.
        :raises ValueError: if no warehouse ids are configured (would yield ``AND ()``).
        """
        warehouse_ids = self._split_parameter(self.Parameters.warehouses_for_monitoring)
        if not warehouse_ids:
            raise ValueError('warehouses_for_monitoring contains no warehouse ids')
        uuid_condition = ' OR '.join(['uuid LIKE %s'] * len(warehouse_ids))
        sql = ("SELECT id FROM execution_queue"
               " WHERE attempt_number >= %s"
               " AND execute_after < %s"
               " AND type LIKE %s"
               " AND (" + uuid_condition + ")"
               " AND fail_reason LIKE %s;")
        params = ([self.MAX_ATTEMPTS_NUMBER, self.DATE_IN_FUTURE, self.TASK_TYPE_LIKE]
                  + ['%-' + warehouse_id for warehouse_id in warehouse_ids]
                  + [self.TASK_FAIL_REASON_LIKE])
        return sql, params

    def get_failed_tasks_sql(self):
        """Return the failed-tasks SELECT with all values inlined as SQL literals.

        Legacy helper: on_execute() uses the parameterized _failed_tasks_query().
        """
        return "SELECT id FROM execution_queue" + \
               " WHERE attempt_number >= " + str(self.MAX_ATTEMPTS_NUMBER) + \
               " AND execute_after < '" + str(self.DATE_IN_FUTURE) + "'" + \
               " AND TYPE LIKE " + self.TASK_TYPE_PATTERN + \
               " AND (" + self.warehouse_ids_to_sql() + ")" + \
               " AND fail_reason LIKE " + self.TASK_FAIL_REASON_PATTERN + ";"

    def warehouse_ids_to_sql(self):
        """Render the warehouse filter as an inline ``uuid like ... OR ...`` expression.

        Blank entries and surrounding whitespace in the parameter are ignored.
        Legacy helper: ids are inlined, not bound as parameters.
        """
        warehouse_ids = self._split_parameter(self.Parameters.warehouses_for_monitoring)
        return ' OR '.join("uuid like '%-" + warehouse_id + "'" for warehouse_id in warehouse_ids)

    @staticmethod
    def _split_parameter(value):
        """Split a ``;``-separated parameter into stripped, non-empty items."""
        return [item.strip() for item in (value or '').split(';') if item.strip()]

    @staticmethod
    def get_updated_time_sql(date, attempt_number):
        """Return the per-task UPDATE; ``%s`` is the task id bound by executemany()."""
        return "UPDATE execution_queue" + \
               " SET execute_after = '" + str(date) + "', attempt_number = " + str(attempt_number) + \
               " WHERE id = %s"

    @staticmethod
    def is_master(cursor):
        """Return True if the server behind ``cursor`` is the primary (not in recovery).

        Expects a dict-row cursor (e.g. RealDictCursor), since the row is read by column name.
        """
        cursor.execute("SELECT pg_is_in_recovery();")
        return not cursor.fetchone()['pg_is_in_recovery']
