#!/usr/bin/env python2.7
import pandas
import pymysql
import logging
import sys
import mysql_queries
import time
import warnings
from datetime import datetime, timedelta


class ConfigError(Exception):
    """Raised when a required connection section is missing from the parameters."""


class NewHourError(Exception):
    """Raised when a new hour has started before an update could be committed."""


class CHHorizonError(Exception):
    """Raised when the RTD event horizon is behind the script's current hour."""


class ScriptHourError(Exception):
    """Raised when the local date-hour differs from the one reported by MySQL."""


class ReportsStatusHourlyError(Exception):
    """Raised when the 'reports_status_hourly' check does not return 1."""


class HourlyUpdateCounters:
    def __init__(self, **kwargs):
        """Store the configuration; connections and start hour are set later.

        :param kwargs: configuration sections. Methods of this class read the
            'mysql', 'mysql_write', 'rtd' and 'settings' keys.
        """
        self.start_date_hour = None
        self.read_conn = self.write_conn = None
        self.params = kwargs

    def get_db_read_conn(self):
        """Return the read connection, connecting first if none is open."""
        existing = self.read_conn
        if existing is not None and existing.open:
            return existing
        self.read_conn = pymysql.connect(**self.get_db_read_config())
        return self.read_conn

    def get_db_write_conn(self):
        """Return the write connection, connecting first if none is open."""
        existing = self.write_conn
        if existing is not None and existing.open:
            return existing
        self.write_conn = pymysql.connect(**self.get_db_write_config())
        return self.write_conn

    def close_db_conn(self):
        """Close the read and write connections that exist and are still open."""
        for conn in (self.read_conn, self.write_conn):
            if conn is not None and conn.open:
                conn.close()

    def get_db_read_config(self):
        """Build the pymysql.connect keyword arguments for the read connection.

        :return: dict with host, user, passwd and db taken from the 'mysql'
            section of the parameters, plus autocommit=False.
        :raises ConfigError: when the parameters have no 'mysql' section.
        """
        if 'mysql' not in self.params:
            raise ConfigError
        section = self.params['mysql']
        config = dict(autocommit=False)
        config['host'] = section['host']
        config['user'] = section['user']
        config['passwd'] = section['password']
        config['db'] = section['reports_db']
        return config

    def get_db_write_config(self):
        """Build the pymysql.connect keyword arguments for the write connection.

        No default database is set here; callers select one with select_db().

        :return: dict with host, user and passwd taken from the 'mysql_write'
            section of the parameters, plus autocommit=False.
        :raises ConfigError: when the parameters have no 'mysql_write' section.
        """
        if 'mysql_write' not in self.params:
            raise ConfigError
        section = self.params['mysql_write']
        config = dict(autocommit=False)
        config['host'] = section['host']
        config['user'] = section['user']
        config['passwd'] = section['password']
        return config

    def get_main_db(self):
        """Return the 'main_db' name from the 'mysql_write' section of the parameters."""
        write_section = self.params['mysql_write']
        return write_section['main_db']

    def get_reports_db(self):
        """Return the 'reports_db' name from the 'mysql' section of the parameters."""
        read_section = self.params['mysql']
        return read_section['reports_db']

    def get_curr_date_hour(self):
        """Return the MySQL server's current date and hour as 'YYYY-MM-DD HH'.

        The query runs on the write connection through an explicit cursor.
        ``with conn as cursor`` is deliberately avoided: PyMySQL releases
        before 1.0 commit the connection when that block exits, which would
        commit pending counter updates while check_if_not_new_hour() is still
        deciding whether they may be committed; newer releases bind the
        connection itself instead of a cursor.

        :return: the first column of the first row returned by the query.
        """
        query = "select date_format(now(),'%Y-%m-%d %H')"
        cursor = self.get_db_write_conn().cursor()
        try:
            cursor.execute(query)
            result = cursor.fetchall()
        finally:
            # Release the cursor without touching the open transaction.
            cursor.close()
        return result[0][0]

    def get_today(self):
        """Return the date part of the start hour as 'YYYY-MM-DD'."""
        started = self.start_date_hour
        return started.strftime("%Y-%m-%d")

    def get_hour(self):
        """Return the hour part of the start hour as a zero-padded 'HH' string."""
        started = self.start_date_hour
        return started.strftime("%H")

    def init_start_time(self):
        """Remember the current MySQL date-hour as the start time of this run."""
        current = self.get_curr_date_hour()
        self.start_date_hour = datetime.strptime(current, '%Y-%m-%d %H')

    def get_start_date_hour(self):
        """Return the start time as 'YYYY-MM-DD HH'."""
        started = self.start_date_hour
        return started.strftime("%Y-%m-%d %H")

    def get_ch_prev_dates(self):
        """Return the dates of the previous 'ch_read_back_days' days, oldest first.

        :return: list of 'YYYY-MM-DD' strings counted back from the start
            hour; the start day itself is not included.
        """
        days_back = self.params['settings']['ch_read_back_days']
        prev_dates = []
        while days_back > 0:
            day = self.start_date_hour - timedelta(days=days_back)
            prev_dates += [day.strftime("%Y-%m-%d")]
            days_back -= 1
        return prev_dates

    def get_old_reports_date(self):
        """Return the date 'ch_read_back_days' days before the start hour as 'YYYY-MM-DD'."""
        back_days = self.params['settings']['ch_read_back_days']
        oldest = self.start_date_hour - timedelta(days=back_days)
        return oldest.strftime("%Y-%m-%d")

    def check_reports_status_hourly(self):
        """Check the hourly reports status and fail when it is not 1.

        Runs the 'reports_status_hourly' query from mysql_queries, formatted
        with the 'ch_read_back_days' setting, on the read connection and then
        closes that connection.

        :raises ReportsStatusHourlyError: when the first column of the first
            returned row is not 1.
        """
        read_conn = self.get_db_read_conn()
        # The database name is hard-coded here instead of using get_main_db().
        read_conn.select_db('adfox')
        template = mysql_queries.queries.get('reports_status_hourly')
        sql = template.format(ch_read_back_days=self.params['settings']['ch_read_back_days'])
        with read_conn as cursor:
            cursor.execute(sql)
            rows = cursor.fetchall()
        read_conn.close()
        status = rows[0][0]
        if status != 1:
            raise ReportsStatusHourlyError
        logging.info('reports_status_hourly status is ok')

    def validate_hour(self):
        """Check that the script, MySQL and the RTD event horizon agree on the hour.

        :raises ScriptHourError: when the local date-hour differs from the one
            reported by MySQL.
        :raises CHHorizonError: when the 'event_horizon' value returned by the
            RTD service compares lower than the local date-hour.
        """
        db_hour = self.get_curr_date_hour()
        local_hour = datetime.now().strftime("%Y-%m-%d %H")
        if local_hour != db_hour:
            logging.error('script local date_hour differ from mysql date_hour')
            raise ScriptHourError
        horizon = self.ch_rtd_request('event_horizon')
        logging.info('event_horizon is {}'.format(horizon))
        if not horizon < local_hour:
            return
        logging.error('event_horizon is not reached')
        raise CHHorizonError

    def check_if_not_new_hour(self):
        """Return True while MySQL still reports the hour in which the script started."""
        current = self.get_curr_date_hour()
        started = self.get_start_date_hour()
        return current == started

    def make_update_query(self, data, tablename, begin):
        query_list = []
        for row in data:
            query = """
                UPDATE {tablename} SET impressions = impressions_hour + {impressions_total_today} + {impressions_total_prev},
                    clicks = clicks_hour + {clicks_total_today} + {clicks_total_prev},
                    impressions_today = impressions_hour + {impressions_total_today},
                    clicks_today = clicks_hour + {clicks_total_today}
                    WHERE id = {id}
                """.format(tablename=tablename, **row)
            query_list.append(query)
        operation = ";".join(query_list)
        conn = self.get_db_write_conn()
        conn.select_db(self.get_main_db())
        result = conn.cursor().execute(operation, multi=True)
        for res in result:
            pass
        if not self.check_if_not_new_hour():
            logging.error('new hour started, script will be terminated')
            raise NewHourError
        conn.commit()

    def make_ivodku_query(self, data, tablename, key_columns):
        """Upsert counters into ``tablename`` with INSERT ... ON DUPLICATE KEY UPDATE.

        :param data: list of row dicts holding the key columns plus
            'impressions_total_today', 'impressions_total_prev',
            'clicks_total_today' and 'clicks_total_prev'.
        :param tablename: target table in the main database.
        :param key_columns: names of the key columns, in insert order.
        :raises NewHourError: when a new hour started before the commit.
        """
        if not data:
            # An INSERT with an empty VALUES list is a SQL syntax error.
            logging.info('no data to update for {} table'.format(tablename))
            return
        columns = list(key_columns)
        columns.extend(['impressions', 'clicks', 'impressions_today', 'clicks_today'])
        values = []
        for row in data:
            res = [row[col] for col in key_columns]
            res.append(row['impressions_total_today'] + row['impressions_total_prev'])
            res.append(row['clicks_total_today'] + row['clicks_total_prev'])
            res.append(row['impressions_total_today'])
            res.append(row['clicks_total_today'])
            values.append('({})'.format(','.join(str(r) for r in res)))
        query = """
            INSERT INTO {tablename}
            ({column})
            VALUES
            {values}
            ON DUPLICATE KEY UPDATE
                impressions = impressions_hour + VALUES(impressions),
                clicks = clicks_hour + VALUES(clicks),
                impressions_today = impressions_hour + VALUES(impressions_today),
                clicks_today = clicks_hour + VALUES(clicks_today)
        """.format(tablename=tablename, column=",".join(columns), values=",".join(values))
        conn = self.get_db_write_conn()
        conn.select_db(self.get_main_db())
        cursor = conn.cursor()
        try:
            cursor.execute(query)
        finally:
            cursor.close()
        # Commit only if the hour did not change while the statement was running.
        if not self.check_if_not_new_hour():
            logging.error('new hour started, script will be terminated')
            raise NewHourError
        conn.commit()
        logging.info('update completed for {} table'.format(tablename))

    def make_ivodku_offer_query(self, data):
        """Upsert the loads counters into the 'billing_offer_counters' table.

        After a successful commit the write connection is closed.

        :param data: list of row dicts with 'owner_id', 'loads_total_today'
            and 'loads_total_prev'.
        :raises NewHourError: when a new hour started before the commit.
        """
        tablename = 'billing_offer_counters'
        if not data:
            # An INSERT with an empty VALUES list is a SQL syntax error.
            logging.info('no data to update for {} table'.format(tablename))
            return
        key_columns = ['owner_id']
        columns = key_columns + ['loads', 'loads_today']
        values = []
        for row in data:
            res = [row[col] for col in key_columns]
            res.append(row['loads_total_today'] + row['loads_total_prev'])
            res.append(row['loads_total_today'])
            values.append('({})'.format(','.join(str(r) for r in res)))
        query = """
            INSERT INTO {tablename}
            ({column})
            VALUES
            {values}
            ON DUPLICATE KEY UPDATE
                loads = loads_hour + VALUES(loads),
                loads_today = loads_hour + VALUES(loads_today)
        """.format(tablename=tablename, column=",".join(columns), values=",".join(values))
        conn = self.get_db_write_conn()
        conn.select_db(self.get_main_db())
        cursor = conn.cursor()
        try:
            cursor.execute(query)
        finally:
            cursor.close()
        # Commit only if the hour did not change while the statement was running.
        if not self.check_if_not_new_hour():
            logging.error('new hour started, script will be terminated')
            raise NewHourError
        conn.commit()
        conn.close()
        logging.info('update completed for {} table'.format(tablename))

    def get_db_report_data(self, date, report_name, ids):
        """Read one report from the reports database into a pandas DataFrame.

        :param date: value substituted for ``date_to`` in the query template.
        :param report_name: key of the query template in mysql_queries.queries.
        :param ids: iterable of ids, joined with commas and substituted for ``ids``.
        :return: the result of pandas.read_sql for the formatted query.
        """
        read_conn = self.get_db_read_conn()
        read_conn.select_db(self.get_reports_db())
        template = mysql_queries.queries.get(report_name)
        sql = template.format(date_to=date, ids=",".join(map(str, ids)))
        frame = pandas.read_sql(sql, read_conn)
        logging.info('Got {} report data'.format(report_name))
        return frame

    def ch_rtd_request(self, method, **kwargs):
        """Start an RTD task over JSON-RPC and poll until its result is available.

        Both polling phases share one time budget ('max_wait_result') and
        sleep 'repeat_check' between attempts; both values come from the 'rtd'
        section of the parameters.

        :param method: name of the remote method that starts the task.
        :param kwargs: keyword arguments forwarded to that method.
        :return: the last value returned by 'get_task_result' (it may still be
            empty), or the initial empty dict when the budget was already
            spent before a result could be requested.
        :raises Exception: when the task state becomes 'FAILURE'.
        """
        import pyjsonrpc

        rtd = self.params['rtd']
        client = pyjsonrpc.HttpClient(rtd['host'])
        task_id = client.call(method, **kwargs)
        waited = 0
        state = ''
        # Phase 1: wait until the task reports success (or fails).
        while state != 'SUCCESS' and waited < rtd['max_wait_result']:
            state = client.get_task_state(task_id)
            if state == 'FAILURE':
                logging.error(client.get_task_result(task_id))
                raise Exception
            time.sleep(rtd['repeat_check'])
            waited += rtd['repeat_check']

        def still_empty(value):
            if len(value) == 0:
                return True
            return type(value) is dict and len(value['fields']) == 0

        # Phase 2: celery sets status 'SUCCESS' before result data is actually
        # available, so keep asking until the result is non-empty.
        result = dict()
        while still_empty(result) and waited < rtd['max_wait_result']:
            result = client.call("get_task_result", task_id)
            time.sleep(rtd['repeat_check'])
            waited += rtd['repeat_check']
        return result

    def get_ch_data(self, **kwargs):
        """Run the 'hourly_update_counters' RTD request with the given arguments."""
        method = 'hourly_update_counters'
        return self.ch_rtd_request(method, **kwargs)

    """
    Get the set of unique ids found in two pandas DataFrames.
    """
    @staticmethod
    def get_ids(data1, data2, field_name):
        ids = set(data1[field_name]) | set(data2[field_name])
        ids.discard(0)  # builder_basic doesn't filter zero id's
        return ids

    @staticmethod
    def make_single_object_data(column_name, data_prev, data_today, db_data):
        data_obj_prev = data_prev.groupby([column_name], as_index=False)['impressions_total', 'clicks_total'].sum()
        data_obj_prev = data_obj_prev.query('{} > 0'.format(column_name))
        data_prev_total = pandas.concat([data_obj_prev, db_data], sort=False) \
            .groupby([column_name], as_index=False)['impressions_total', 'clicks_total'].sum()
        data_obj_today = data_today.groupby([column_name], as_index=False)['impressions_total', 'clicks_total'].sum()
        data_obj_today = data_obj_today.query('{} > 0'.format(column_name))
        data = pandas.merge(data_prev_total, data_obj_today, on=[column_name],
                            how='outer', suffixes=['_prev', '_today'], copy=False)
        data = data.fillna(0)
        data[['impressions_total_prev', 'clicks_total_prev', 'impressions_total_today', 'clicks_total_today']] = \
            data[['impressions_total_prev', 'clicks_total_prev', 'impressions_total_today', 'clicks_total_today']] \
                .astype(int)
        data.rename(columns=lambda x: 'id' if x == column_name else x, inplace=True)
        logging.info('make data completed for {}'.format(column_name))
        return data.to_dict('records')

    @staticmethod
    def make_double_object_data(column_name, data_prev, data_today, db_data):
        data_obj_prev = data_prev.groupby(column_name, as_index=False)['impressions_total', 'clicks_total'].sum()
        # Keep zero id's if second column is section_id, filter zero id's for other cases (adnetwork)
        if column_name[1] == 'section_id':
            filter_query = '{} > 0'.format(column_name[0])
        else:
            filter_query = '{} > 0 and {} > 0'.format(column_name[0], column_name[1])
        data_obj_prev = data_obj_prev.query(filter_query)
        data_prev_total = pandas.concat([data_obj_prev, db_data], sort=False) \
            .groupby(column_name, as_index=False)['impressions_total', 'clicks_total'].sum()
        data_obj_today = data_today.groupby(column_name, as_index=False)['impressions_total', 'clicks_total'].sum()
        data_obj_today = data_obj_today.query(filter_query)
        data = pandas.merge(data_prev_total, data_obj_today, on=column_name,
                            how='outer', suffixes=['_prev', '_today'], copy=False)
        data = data.fillna(0)
        data[['impressions_total_prev', 'clicks_total_prev', 'impressions_total_today', 'clicks_total_today']] = \
            data[['impressions_total_prev', 'clicks_total_prev', 'impressions_total_today', 'clicks_total_today']] \
                .astype(int)
        data.rename(columns=lambda x: 'zone_id' if x == 'section_id' else x, inplace=True)
        logging.info('make data completed for {}_{}'.format(column_name[0], column_name[1]))
        return data.to_dict('records')

    @staticmethod
    def make_campaign_adnetwork_category_data(column_name, data_prev, data_today, db_data):
        data_obj_prev = data_prev.groupby(column_name, as_index=False)['impressions_total', 'clicks_total'].sum()
        filter_query = '{} > 0 and {} > 0 and {} > 0'.format(column_name[0], column_name[1], column_name[2])
        data_obj_prev = data_obj_prev.query(filter_query)
        data_prev_total = pandas.concat([data_obj_prev, db_data], sort=False) \
            .groupby(column_name, as_index=False)['impressions_total', 'clicks_total'].sum()
        data_obj_today = data_today.groupby(column_name, as_index=False)['impressions_total', 'clicks_total'].sum()
        data_obj_today = data_obj_today.query(filter_query)
        data = pandas.merge(data_prev_total, data_obj_today, on=column_name,
                            how='outer', suffixes=['_prev', '_today'], copy=False)
        data = data.fillna(0)
        data[['impressions_total_prev', 'clicks_total_prev', 'impressions_total_today', 'clicks_total_today']] = \
            data[['impressions_total_prev', 'clicks_total_prev', 'impressions_total_today', 'clicks_total_today']] \
                .astype(int)
        data.rename(columns=lambda x: 'category_id' if x == 'placecategory_id' else x, inplace=True)
        logging.info('make data completed for {}_{}_{}'.format(column_name[0], column_name[1], column_name[2]))
        return data.to_dict('records')

    @staticmethod
    def make_offer_object_data(data_prev, data_today, db_data, ids_str):
        column_name = 'owner_id'
        data_obj_prev = data_prev.query('is_turbo == 0 and flag_virtual == 0 and is_rsya == 0')
        data_obj_prev = data_obj_prev.groupby([column_name], as_index=False)['loads_total'].sum()
        data_obj_prev = data_obj_prev.query('{} > 0'.format(column_name))
        data_prev_total = pandas.concat([data_obj_prev, db_data], sort=False) \
            .groupby([column_name], as_index=False)['loads_total'].sum()
        data_obj_today = data_today.query('is_turbo == 0 and flag_virtual == 0 and is_rsya == 0')
        data_obj_today = data_obj_today.groupby([column_name], as_index=False)['loads_total'].sum()
        data_obj_today = data_obj_today.query('{} > 0'.format(column_name))
        data = pandas.merge(data_prev_total, data_obj_today, on=[column_name],
                            how='outer', suffixes=['_prev', '_today'], copy=False)
        data = data.fillna(0)
        data = data.query('owner_id in ({})'.format(ids_str))
        data[['loads_total_prev', 'loads_total_today']] = data[['loads_total_prev', 'loads_total_today']].astype(int)
        logging.info('make data completed for {}'.format('offer clients'))
        return data.to_dict('records')

    def proceed(self):
        """Run one hourly counters update.

        Steps: configure logging, remember the start hour, read the offer
        owners list, validate the hour and the reports status, load the
        previous-days and today's data from the RTD service, merge it with
        the reports database and upsert every counters table.

        :return: True when the run completed; False when validation failed or
            a new hour started in the middle of the run.
        """
        # ignore sql warnings like "1364, Field 'name' doesn't have a default value"
        warnings.simplefilter("ignore", category=pymysql.Warning)
        logging.getLogger().setLevel(logging.DEBUG)
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logging.getLogger().addHandler(handler)
        logging.info('Starting script...')
        self.init_start_time()
        # NOTE(review): the result is discarded; the call only fails early when
        # the 'ch_read_back_days' setting is missing.
        self.get_ch_prev_dates()
        logging.info('Current hour is {}'.format(self.get_start_date_hour()))

        owners_str = ''
        try:
            # The owners query runs inside the try so that a failure still
            # closes the connections opened so far.
            mydb = self.get_db_write_conn()
            query = mysql_queries.queries.get('offer_owners')
            cur = mydb.cursor()
            cur.execute(query)
            for row in cur:
                owners_str = row[0]
            self.validate_hour()
            self.check_reports_status_hourly()
        except (ReportsStatusHourlyError, CHHorizonError, ScriptHourError):
            return False
        finally:
            self.close_db_conn()
        try:
            res = self.get_ch_data(date=self.get_ch_prev_dates())
            data_prev = pandas.DataFrame(res['table'], columns=res['fields'])\
                .query('impressions_total > 0 or clicks_total > 0')
            columns_prev = res['fields']
            logging.info('Got ch prev data')
            date = self.get_today()
            to_hour = int(self.get_hour())
            # if to_hour == 0 no data is taken for today from ch, so we make empty df
            if to_hour > 0:
                res = self.get_ch_data(date=date, to_hour=to_hour)
                data_today = pandas.DataFrame(res['table'], columns=res['fields']) \
                    .query('impressions_total > 0 or clicks_total > 0')
            else:
                data_today = pandas.DataFrame([], columns=columns_prev)
            logging.info('Got ch today data')
            owner_ids = self.get_ids(data_prev, data_today, 'owner_id')
            supercampaign_ids = self.get_ids(data_prev, data_today, 'supercampaign_id')
            campaign_ids = self.get_ids(data_prev, data_today, 'campaign_id')
            banner_ids = self.get_ids(data_prev, data_today, 'banner_id')
            date = self.get_old_reports_date()
            # Truthiness also covers a NULL (None) value from the owners query.
            if owners_str:
                db_offer_owner_data = self.get_db_report_data(date=date, report_name='merge_report_offer_owner',
                                                              ids=owner_ids)
                data_offer_owner_data = self.make_offer_object_data(data_prev=data_prev, data_today=data_today,
                                                                    db_data=db_offer_owner_data, ids_str=owners_str)
                self.make_ivodku_offer_query(data=data_offer_owner_data)

            data_sc_db = self.get_db_report_data(date=date, report_name='supercampaign', ids=supercampaign_ids)
            data_sc = self.make_single_object_data(column_name='supercampaign_id', data_prev=data_prev,
                                                   db_data=data_sc_db, data_today=data_today)
            self.make_ivodku_query(data=data_sc, tablename='supercampaign', key_columns=['id'])
            data_camp_db = self.get_db_report_data(date=date, report_name='campaign', ids=campaign_ids)
            data_cam = self.make_single_object_data(column_name='campaign_id', data_prev=data_prev,
                                                    db_data=data_camp_db, data_today=data_today)
            self.make_ivodku_query(data=data_cam, tablename='campaign', key_columns=['id'])

            data_banner_db = self.get_db_report_data(date=date, report_name='banner', ids=banner_ids)
            data_banner = self.make_single_object_data(column_name='banner_id', data_prev=data_prev,
                                                       db_data=data_banner_db, data_today=data_today)
            self.make_ivodku_query(data=data_banner, tablename='banner', key_columns=['id'])

            # One 'campaign_place' report feeds the site, zone and place tables.
            db_camp_place_data = self.get_db_report_data(date=date, report_name='campaign_place', ids=campaign_ids)
            data_camp_site_data = self.make_double_object_data(column_name=['campaign_id', 'site_id'],
                                                               data_prev=data_prev, data_today=data_today,
                                                               db_data=db_camp_place_data)
            self.make_ivodku_query(data=data_camp_site_data, tablename='campaign_site',
                                   key_columns=['campaign_id', 'site_id'])
            data_camp_section_data = self.make_double_object_data(column_name=['campaign_id', 'section_id'],
                                                                  data_prev=data_prev, data_today=data_today,
                                                                  db_data=db_camp_place_data)
            self.make_ivodku_query(data=data_camp_section_data, tablename='campaign_zone',
                                   key_columns=['campaign_id', 'zone_id'])
            data_camp_place_data = self.make_double_object_data(column_name=['campaign_id', 'place_id'],
                                                                data_prev=data_prev, data_today=data_today,
                                                                db_data=db_camp_place_data)
            self.make_ivodku_query(data=data_camp_place_data, tablename='campaign_place',
                                   key_columns=['campaign_id', 'place_id'])

            db_camp_adnetwork_site_data = self.get_db_report_data(date=date, report_name='campaign_adnetwork_site',
                                                                  ids=campaign_ids)
            data_camp_adnetwork_data = self.make_double_object_data(column_name=['campaign_id', 'adnetwork_id'],
                                                                    data_prev=data_prev, data_today=data_today,
                                                                    db_data=db_camp_adnetwork_site_data)
            self.make_ivodku_query(data=data_camp_adnetwork_data, tablename='campaign_adnetwork',
                                   key_columns=['campaign_id', 'adnetwork_id'])
            db_camp_adnetwork_category_data = self.get_db_report_data(date=date,
                                                                      report_name='campaign_adnetwork_category',
                                                                      ids=campaign_ids)
            data_camp_adnetwork_category_data = self.make_campaign_adnetwork_category_data(
                column_name=['campaign_id', 'adnetwork_id', 'placecategory_id'],
                data_prev=data_prev, data_today=data_today,
                db_data=db_camp_adnetwork_category_data)
            self.make_ivodku_query(data=data_camp_adnetwork_category_data, tablename='campaign_adnetwork_category',
                                   key_columns=['campaign_id', 'adnetwork_id', 'category_id'])
            data_camp_adnetwork_site_data = self.make_campaign_adnetwork_category_data(
                column_name=['campaign_id', 'site_id', 'adnetwork_id'],
                data_prev=data_prev, data_today=data_today,
                db_data=db_camp_adnetwork_site_data)
            self.make_ivodku_query(data=data_camp_adnetwork_site_data, tablename='campaign_adnetwork_site',
                                   key_columns=['campaign_id', 'site_id', 'adnetwork_id'])
            # Success is informational, not an error.
            logging.info('Script has been successfully completed. ')
            return True
        except NewHourError:
            # New hour started, so just terminating script, this is not task error.
            return False
        finally:
            self.close_db_conn()


if __name__ == "__main__":
    if len(sys.argv) == 1:
        # Executed from the command line with no args: take the settings from the config module.
        import config as conf
        mysql = conf.MYSQL
        mysql_write = conf.MYSQL_WRITE
        rtd = conf.RTD
        settings = conf.SETTINGS

    else:
        # Args given: every section is passed as a json.dumps'ed dict.
        import argparse
        import json
        parser = argparse.ArgumentParser()
        # All four options are required so that a missing one is reported by
        # argparse instead of failing later in json.loads(None).
        for option in ("--mysql", "--mysql_write", "--rtd", "--settings"):
            parser.add_argument(option, required=True)
        args = parser.parse_args()
        mysql = json.loads(args.mysql)
        mysql_write = json.loads(args.mysql_write)
        rtd = json.loads(args.rtd)
        settings = json.loads(args.settings)
    HourlyUpdateCounters(mysql=mysql, mysql_write=mysql_write, rtd=rtd, settings=settings).proceed()
