#!/usr/bin/env python2.7
import pymysql
import logging
import sys
import time
from datetime import datetime, timedelta
from pymysql.err import MySQLError

# Total number of attempts CountersBase.transaction_repeat_decorator makes
# (the first try included) before re-raising the MySQLError.
MAX_TRANSACTION_REPEAT_COUNT = 5
# Pause, in seconds (passed to time.sleep), between those retry attempts.
TRANSACTION_REPEAT_SLEEP = 1


class ConfigError(Exception):
    """Raised when required script configuration is missing.

    For example, CountersBase.get_db_config raises it when no ``mysql``
    section was passed in.
    """


class CHHorizonError(Exception):
    """Error concerning the ClickHouse event horizon.

    NOTE(review): the code in this module never raises it as reviewed;
    presumably subclasses or callers do - confirm before relying on it.
    """


class ScriptHourError(Exception):
    """Raised when the script's notion of the current hour cannot be trusted.

    CountersBase.get_event_horizon raises it when the ClickHouse event horizon
    is more than five minutes ahead of local time, which suggests a wrong
    event horizon value.
    """


class CountersBase(object):
    """Shared plumbing for jobs that sync ClickHouse counters into MySQL.

    Counter data is fetched through an RTD JSON-RPC service (see
    ``ch_rtd_request``). It is written to per-object MySQL tables named
    ``<object>_hourly_counters``, ``<object>_daily_counters`` and
    ``<object>_total_counters``.

    All keyword arguments are kept in ``self.params``. The keys read here are:

    * ``settings['object']`` - object name, either ``<primary>`` or
      ``<primary>_<secondary>``;
    * ``mysql`` - dict with ``host``, ``user``, ``password`` and ``db``;
    * ``rtd`` - dict with ``host``, ``max_wait_result`` and ``repeat_check``.
    """

    def __init__(self, **kwargs):
        self.params = kwargs
        self.db_conn = None
        self.object = self.params['settings']['object']
        # "<primary>_<secondary>" objects are keyed by two id columns.
        # Anything after a second underscore is ignored.
        # NOTE(review): confirm object names never contain more than one '_'.
        parts = self.object.split('_')
        self.object_primary = parts[0]
        self.object_secondary = parts[1] if len(parts) > 1 else None

    def get_db_conn(self):
        """Return an open MySQL connection, (re)connecting when necessary.

        A reused connection is rolled back first. That way, uncommitted changes
        left over from earlier work are never committed by accident.
        """
        if self.db_conn is None or not self.db_conn.open:
            self.db_conn = pymysql.connect(**self.get_db_config())
        else:
            self.db_conn.rollback()
        return self.db_conn

    def close_db_conn(self):
        """Close the cached MySQL connection if it is currently open."""
        if self.db_conn is not None and self.db_conn.open:
            self.db_conn.close()

    def get_db_config(self):
        """Build ``pymysql.connect`` keyword arguments from ``params['mysql']``.

        Autocommit is disabled, so changes must be committed explicitly.

        :raises ConfigError: if no ``mysql`` section was supplied.
        """
        mysql = self.params.get('mysql')
        if mysql is None:
            raise ConfigError("missing 'mysql' section in script parameters")
        return {
            'host': mysql['host'],
            'user': mysql['user'],
            'passwd': mysql['password'],
            'db': mysql['db'],
            'autocommit': False,
        }

    def get_db_hourly_table(self):
        """Return the hourly counters table name for this object."""
        return '{}_hourly_counters'.format(self.object)

    def get_db_daily_table(self):
        """Return the daily counters table name for this object."""
        return '{}_daily_counters'.format(self.object)

    def get_db_total_table(self):
        """Return the total counters table name for this object."""
        return '{}_total_counters'.format(self.object)

    def get_id_fields(self):
        """Return the id column names (a fresh list) that key this object's rows.

        The result is ``['id']`` for a single object and
        ``['<primary>_id', '<secondary>_id']`` for a compound one.
        """
        if self.object_secondary is None:
            return ['id']
        return ['{}_id'.format(self.object_primary), '{}_id'.format(self.object_secondary)]

    # NOTE(review): the query builders below interpolate day/hour values
    # directly into SQL. That is only safe as long as those values are
    # generated by the script itself and never come from external input.

    def get_hourly_delete_day_query(self, day):
        """Return SQL deleting hourly rows with ``day <= time < day + 1 day``."""
        query = """
            DELETE FROM {hourly_table}
            WHERE time >= TIMESTAMP('{date}') AND time < TIMESTAMP(DATE_ADD('{date}', INTERVAL 1 DAY))
        """.format(hourly_table=self.get_db_hourly_table(), date=day)
        return query

    def get_nullify_hourly_hour_query(self, hour):
        """Return SQL zeroing one hour's hourly counters.

        The query also replaces ``flags`` with ``'flag_ch_updated'``.
        """
        query = """
            UPDATE {hourly_table}
            SET impressions = 0, clicks = 0, flags = 'flag_ch_updated'
            WHERE time ='{hour}'
        """.format(hourly_table=self.get_db_hourly_table(), hour=hour)
        return query

    def get_nullify_daily_day_query(self, day):
        """Return SQL zeroing the daily counters of ``day``."""
        query = """
            UPDATE {daily_table}
            SET impressions = 0, clicks = 0
            WHERE day = '{day}'
        """.format(daily_table=self.get_db_daily_table(), day=day)
        return query

    def get_hourly_mark_merged_day_query(self, day):
        """Return SQL zeroing a day's hourly rows and appending 'flag_merged'."""
        query = """
            UPDATE {hourly_table}
            SET impressions = 0, clicks = 0, flags = CONCAT_WS(',', flags, 'flag_merged')
            WHERE time >= TIMESTAMP('{date}') AND time < TIMESTAMP(DATE_ADD('{date}', INTERVAL 1 DAY))
        """.format(hourly_table=self.get_db_hourly_table(), date=day)
        return query

    def get_remove_old_data_from_hourly_table_query(self):
        """Return SQL deleting merged hourly rows last updated over a day ago."""
        query = """
            DELETE FROM {hourly_table}
            WHERE FIND_IN_SET('flag_merged', flags)
            AND last_update < DATE_SUB(now(), INTERVAL 1 DAY)
        """.format(hourly_table=self.get_db_hourly_table())
        return query

    def get_daily_mark_merged_day_query(self, day):
        """Return SQL zeroing the daily row(s) of ``day`` and appending 'flag_merged'."""
        query = """
            UPDATE {daily_table}
            SET impressions = 0, clicks = 0, flags = CONCAT_WS(',', flags, 'flag_merged')
            WHERE day = '{day}'
        """.format(daily_table=self.get_db_daily_table(), day=day)
        return query

    def get_remove_old_data_from_daily_table_query(self):
        """Return SQL deleting merged daily rows last updated over a day ago."""
        query = """
            DELETE FROM {daily_table}
            WHERE FIND_IN_SET('flag_merged', flags)
            AND last_update < DATE_SUB(now(), INTERVAL 1 DAY)
        """.format(daily_table=self.get_db_daily_table())
        return query

    def get_daily_delete_query(self, day):
        """Return SQL deleting the daily row(s) of ``day``."""
        query = """
            DELETE FROM {daily_table}
            WHERE day = '{day}'
        """.format(daily_table=self.get_db_daily_table(), day=day)
        return query

    def make_update_data_from_ch(self, data, update_hour):
        """Convert ClickHouse counter rows into SQL column names and value tuples.

        :param data: iterable of row dicts. Each row must hold
            ``<primary>_id`` (and ``<secondary>_id`` for compound objects),
            ``impressions_total``, ``clicks_total``, and either ``date_hour``
            (hourly) or ``date`` (daily).
        :param update_hour: True to target the hourly table (``time`` and
            ``flags`` columns), False for the daily table (``day`` column).
        :returns: ``{'columns': [...], 'values': ['(v1,v2,...)', ...]}``.
            Date values and the flag are single-quoted; other values go
            through ``str``.
        """
        # Reuse get_id_fields so the id columns always match the table layout.
        columns = self.get_id_fields()
        if update_hour:
            columns.extend(['time', 'flags'])
        else:
            columns.append('day')
        columns.extend(['impressions', 'clicks'])

        # ClickHouse rows always name ids "<object>_id", even when the MySQL
        # column is plain "id".
        id_keys = ['{}_id'.format(self.object_primary)]
        if self.object_secondary is not None:
            id_keys.append('{}_id'.format(self.object_secondary))

        values = []
        for row in data:
            res = [row[key] for key in id_keys]
            if update_hour:
                res.append("'{}'".format(row['date_hour']))
                res.append("'flag_ch_updated'")
            else:
                res.append("'{}'".format(row['date']))
            res.append(row['impressions_total'])
            res.append(row['clicks_total'])
            values.append('({})'.format(','.join(str(r) for r in res)))
        return {'columns': columns, 'values': values}

    def ch_rtd_request(self, method, **kwargs):
        """Run ``method`` as an RTD task over JSON-RPC and return its result.

        The task state is polled every ``rtd['repeat_check']`` seconds.
        Polling stops once the accumulated sleep time reaches
        ``rtd['max_wait_result']``; time spent inside the RPC calls is not
        counted. After the task reports SUCCESS, its result is fetched at
        least once, then re-fetched while it is still empty.

        :returns: the task result. If the task never reaches SUCCESS within
            the budget, an empty dict is returned. If data never appears, the
            last (empty) result fetched is returned.
        :raises Exception: if the task state becomes FAILURE.
        """
        import pyjsonrpc
        rtd = self.params['rtd']
        max_wait = rtd['max_wait_result']
        step = rtd['repeat_check']
        client = pyjsonrpc.HttpClient(rtd['host'])
        task_id = client.call(method, **kwargs)
        waited = 0

        # Phase 1: wait for the task to finish. Stop as soon as SUCCESS is
        # seen so that the remaining budget is available for fetching data.
        while True:
            state = client.get_task_state(task_id)
            if state == 'SUCCESS':
                break
            if state == 'FAILURE':
                logging.error(client.get_task_result(task_id))
                raise Exception('RTD task {} ({}) failed'.format(task_id, method))
            time.sleep(step)
            waited += step
            if waited >= max_wait:
                logging.warning('RTD task {} ({}) not finished after {}s, last state {!r}'.format(
                    task_id, method, max_wait, state))
                return dict()

        # Phase 2: celery sets status 'SUCCESS' before result data is actually
        # available, so keep fetching while the result is still empty.
        result = client.call("get_task_result", task_id)
        while self._is_empty_result(result) and waited < max_wait:
            time.sleep(step)
            waited += step
            result = client.call("get_task_result", task_id)
        if self._is_empty_result(result):
            logging.warning('RTD task {} ({}) returned no data after {}s'.format(
                task_id, method, waited))
        return result

    @staticmethod
    def _is_empty_result(result):
        """Return True if an RTD task result carries no data yet.

        ``None`` and other falsy values count as empty, as does a dict whose
        ``fields`` entry is empty. A dict without a ``fields`` key is treated
        as real data rather than crashing with a KeyError.
        """
        if not result:
            return True
        return isinstance(result, dict) and 'fields' in result and not result['fields']

    def get_ch_data(self, day, hour=None):
        """Fetch this object's counters for ``day`` (one ``hour`` if given) via RTD.

        Rows include both virtual and non-virtual events (``flag_virtual`` 0
        and 1). They carry ``date_hour`` when ``hour`` is given and ``date``
        otherwise.
        """
        where = {
            "date": day,
            "flag_virtual": [0, 1],
        }
        select = [
            "impressions_total",
            "clicks_total",
            "{}_id".format(self.object_primary),
        ]
        if hour is not None:
            where["hour"] = hour
            select.append("date_hour")
        else:
            select.append("date")
        if self.object_secondary is not None:
            select.append("{}_id".format(self.object_secondary))
        return self.ch_rtd_request('get_counters', where=where, select=select)

    def get_event_horizon(self):
        """Return the ClickHouse event horizon as a naive ``datetime``.

        :raises ScriptHourError: if the horizon is more than 5 minutes ahead
            of local time, which suggests a wrong value. (NOTE(review):
            CHHorizonError may have been intended here; ScriptHourError is
            kept because callers may catch it.)
        """
        ch_event_horizon = self.ch_rtd_request('event_horizon')
        logging.info('event_horizon is {}'.format(ch_event_horizon))
        event_horizon = datetime.strptime(ch_event_horizon, '%Y-%m-%d %H:%M:%S')
        if datetime.now() < event_horizon - timedelta(minutes=5):
            raise ScriptHourError(
                'event_horizon {} is more than 5 minutes ahead of local time'.format(event_horizon))
        return event_horizon

    @staticmethod
    def set_logger():
        """Send DEBUG-and-above root logger records to stdout.

        Each call adds another stdout handler.
        """
        logging.getLogger().setLevel(logging.DEBUG)
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logging.getLogger().addHandler(handler)
        logging.info('Starting script...')

    @staticmethod
    def parse_args(*args):
        """Parse required ``--<name> <json>`` command-line options.

        :param args: option names. Each becomes a required ``--<name>`` option
            whose value is decoded with ``json.loads``.
        :returns: dict mapping each name to its decoded value.
        """
        import argparse
        import json
        parser = argparse.ArgumentParser()
        for name in args:
            # Without required=True a missing option arrives as None, and
            # json.loads(None) fails with an unhelpful TypeError.
            parser.add_argument("--{}".format(name), required=True)
        parsed_args = parser.parse_args()
        return {name: json.loads(getattr(parsed_args, name)) for name in args}

    @staticmethod
    def transaction_repeat_decorator(sql_update_func):
        """Retry ``sql_update_func`` whenever it raises ``MySQLError``.

        The function is attempted at least once and at most
        MAX_TRANSACTION_REPEAT_COUNT times. Consecutive attempts are
        separated by TRANSACTION_REPEAT_SLEEP seconds. The last error is
        re-raised without a trailing sleep, and the function's return value
        is passed through.
        """
        import functools

        @functools.wraps(sql_update_func)
        def wrapper(*args, **kwargs):
            attempt = 1
            while True:
                try:
                    return sql_update_func(*args, **kwargs)
                except MySQLError as e:
                    # An error built without args must not mask itself with an IndexError.
                    errno = e.args[0] if e.args else None
                    logging.warning('Got error {!r}, errno is {}'.format(e, errno))
                    attempt += 1
                    if attempt > MAX_TRANSACTION_REPEAT_COUNT:
                        raise
                    time.sleep(TRANSACTION_REPEAT_SLEEP)
        return wrapper
