from contextlib import closing
from datetime import timedelta
from numbers import Number
from statistics import mean, stdev

from airflow.hooks.postgres_hook import PostgresHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults

from piper.alerting.rollbar import LogLevel, rollbar_report_message


class BasePostgresDataValidator(BaseOperator):
    """
    Abstract base class to load and validate data from a Postgres connection.

    Operators derived from this class must override the query properties below and implement
    the validate method. validate inspects self.rows (a list of dicts, one per result row) and
    reports problems through add_error. After validate returns, any collected errors either fail
    the task (block_dag_if_invalid is true) or are sent to rollbar as a warning.
    """

    # Query properties to be overridden by subclasses.
    conn_id = ''       # connection identifier.
    query = ''         # query SQL string.
    query_params = {}  # query parameters.
    block_dag_if_invalid = False  # fail task and dag if true, or send a warning to rollbar if false.

    template_fields = ('query', 'query_params')
    template_ext = ('.sql', '.hql')
    ui_color = '#d9bcf5'

    @apply_defaults
    def __init__(self, *args, **kwargs):
        """Initialise the operator; all arguments are forwarded unchanged to BaseOperator."""
        super(BasePostgresDataValidator, self).__init__(*args, **kwargs)

        self.rows = []    # List of dictionaries to be validated (loaded from DB).
        self.errors = []  # List of errors that is checked after calling the validate method.
        self.row_id_field = ''  # Identifies the failing row in error messages, e.g. 'id', 'game' or 'extension_id'.

        # Properties to execute the query are taken from the subclass
        self.conn_id = self.__class__.conn_id
        self.query = self.__class__.query
        self.query_params = self.__class__.query_params
        self.block_dag_if_invalid = self.__class__.block_dag_if_invalid

    def execute(self, context):
        """
        Run the query, load the result into self.rows and call validate.

        :param context: task context; context['dag_run'].execution_date is used in the report.
        :raises AssertionError: if validation errors were collected and block_dag_if_invalid is true.
        """
        db_hook = PostgresHook.get_hook(self.conn_id)

        # Load rows from db
        with closing(db_hook.get_conn()) as conn:
            with closing(conn.cursor()) as cur:
                cur.execute(self.query, self.query_params)
                colnames = [desc[0] for desc in cur.description]
                results = cur.fetchall()
                raw_query = cur.query.decode("utf-8")  # query that was executed, including params

        self.rows = [to_dict(row_tupl, colnames) for row_tupl in results]  # convert tuples to dictionaries
        self.errors = []  # reset, so errors from a previous run are never reported again

        self.validate()

        if self.errors:
            self.log_info("Query:\n{}\n", raw_query)
            msg = "{validator} {date}:\n{errors}\n".format(
                validator=self.__class__.__name__,
                date=context['dag_run'].execution_date.strftime('%Y-%m-%d'),
                errors='\n'.join(self.errors))

            if self.block_dag_if_invalid:
                raise AssertionError(msg)  # Raised errors will fail the task and block the DAG

            rollbar_report_message(context, msg, LogLevel.WARNING)  # Non blocking

    def validate(self):
        """Main method to implement when creating a validator. Add errors to fail the validation."""
        raise NotImplementedError()

    def add_error(self, text, *args):
        """Record a validation error; text is a str.format template filled with args."""
        self.errors.append(text.format(*args))

    def log_info(self, text, *args):
        """Log an info message; text is a str.format template filled with args."""
        self.log.info(text.format(*args))

    # Validation Helpers
    # ------------------

    def validate_non_negative(self, field):
        """
        Add an error for every row whose field is not a number >= 0.

        Requires self.row_id_field to be set. Stops at the first row that lacks the field.
        """
        if not self.row_id_field:
            self.add_error("Please set self.row_id_field before calling validate_non_negative.")
            return

        for row in self.rows:
            if field not in row:
                self.add_error("Field {} is not in row: {}.", field, row.keys())
                return

            val = row[field]
            if not isinstance(val, Number) or val < 0:
                # .get: a missing id column must not hide the validation error behind a KeyError.
                self.add_error("Field {} is {}, expected non-negative number. In {}: {}.",
                               field, val, self.row_id_field, row.get(self.row_id_field))

    def validate_greather_or_equal(self, field_big, field_small):
        """
        Add an error for every row where field_big is not a number >= field_small.

        Requires self.row_id_field to be set. Stops at the first row that lacks either field.
        The misspelled method name is kept because subclasses call it.
        """
        if not self.row_id_field:
            self.add_error("Please set self.row_id_field before calling validate_greather_or_equal.")
            return

        for row in self.rows:
            if field_big not in row or field_small not in row:
                self.add_error("Field {} or {} is not in row: {}.", field_big, field_small, row.keys())
                return

            val_big = row[field_big]
            val_small = row[field_small]
            if not isinstance(val_big, Number) or not isinstance(val_small, Number) or val_big < val_small:
                # .get: a missing id column must not hide the validation error behind a KeyError.
                self.add_error("Field {} must be >= {}, but {} is not >= {}. In {}: {}.",
                               field_big, field_small, val_big, val_small,
                               self.row_id_field, row.get(self.row_id_field))

    def validate_rows_ordered_by_continuous_days(self, date_field):
        """
        Make sure that all rows are ordered by date (DESC) and include every day.

        Adjacent rows are compared pairwise and the check stops at the first problem found.
        With fewer than two rows there is nothing to compare and no error is added.
        """
        date_fmt = "%Y-%m-%d"

        # Rows are expected newest first, so the second row of each pair is the previous day.
        for curr_row, prev_row in zip(self.rows, self.rows[1:]):
            # Check both rows of the pair; otherwise a bad older row (e.g. the last one) would raise.
            for row in (curr_row, prev_row):
                if date_field not in row:
                    self.add_error("Field {} is not in row: {}.", date_field, row.keys())
                    return

                if not hasattr(row[date_field], 'strftime'):  # strftime works both on date and datetime
                    self.add_error("Field {} is {}, expected to be a date or datetime.", date_field, row[date_field])
                    return

            curr_date = curr_row[date_field]
            prev_date = prev_row[date_field]

            # Compare formatted days so that the time part of datetime values is ignored.
            curr_date_str = curr_date.strftime(date_fmt)
            prev_date_str = prev_date.strftime(date_fmt)
            curr_date_before_str = (curr_date - timedelta(days=1)).strftime(date_fmt)

            if curr_date_before_str != prev_date_str:
                self.add_error("Unordered or missing day: {} expected consecutive days but found {} => {}.",
                               date_field, curr_date_str, prev_date_str)
                return

    def validate_moving_avg_bottom(self, field):
        """
        Check if a field value in the first row (today) is above the lower boundary of an interval
        derived from the moving average and standard deviation of the previous rows.

        An error is added (instead of raising) when there are no rows, when a row lacks the field
        or holds a non-number, or when fewer than two previous rows are available.
        """
        if not self.rows:
            self.add_error("No rows to validate for field {}.", field)
            return

        # Every row takes part in the calculation, so every row must hold a number.
        for row in self.rows:
            if field not in row:
                self.add_error("Field {} is not in row: {}.", field, row.keys())
                return

            if not isinstance(row[field], Number):
                self.add_error("Field {} is {}, expected to be a Number.", field, row[field])
                return

        curr_row, *prev_rows = self.rows
        curr_val = curr_row[field]
        prev_vals = [row[field] for row in prev_rows]

        # statistics.stdev raises StatisticsError with fewer than two data points.
        if len(prev_vals) < 2:
            self.add_error("Field {} has {} previous rows, expected at least 2 to calculate the moving average.",
                           field, len(prev_vals))
            return

        # To calculate a reasonable interval, we assume that values can be approximated to normal distributions,
        # then we can follow the 68–95–99.7 rule (https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule)
        # and say a new value has a 99.85% probability >= mean - 3 standard deviations.
        avg = mean(prev_vals)   # moving average
        sig = stdev(prev_vals)  # sample standard deviation

        # We only check for bottom value and ignore top, because it turns out the top values are quite unpredictable
        # on weekends, events and new feature releases.
        min_val = avg - 3*sig
        if curr_val < min_val:
            self.add_error("Field {} is {:,.0f}, expected to be above {:,.0f} (moving-avg: {:,.0f} - 3*sigma).",
                           field, curr_val, min_val, avg)


def to_dict(tupl, keys):
    """
    Pair each key with the tuple element at the same position.

    Given a tuple like (1, 'foo') and a list of keys like ['id', 'name'],
    return a dict like {'id': 1, 'name': 'foo'}.
    Raises IndexError if the tuple has fewer elements than there are keys.
    """
    return {name: tupl[position] for position, name in enumerate(keys)}
