import argparse
import json
import logging
import sys

import pymysql

from merge_report_tables import merge_report_tables_avg_day, merge_report_tables_count

# Half-open date window [DATE_FROM, DATE_TO) applied to the `day` column by
# the avg-day query (`day >= DATE_FROM AND day < DATE_TO`).
DATE_FROM, DATE_TO = '2009-01-01', '2016-04-01'


class ConfigError(Exception):
    """Raised when the MySQL connection settings cannot be read from the monitor's parameters."""


class TableCheckError(Exception):
    """Raised when a report table's measured value differs from the expected one."""


class ReportsMonitoring:
    """Sanity-check report tables against expected values from config.

    Two kinds of checks are run, driven by the ``merge_report_tables`` config:

    * avg-day: ``AVG(day)`` over ``[DATE_FROM, DATE_TO)`` must equal the
      table's ``expectedAvgDay``;
    * count: ``COUNT(1)`` over the whole table must equal ``expectedCount``.

    The first mismatch raises :class:`TableCheckError`.

    Keyword arguments are stored verbatim in ``self.params``; only ``mysql``
    (a mapping with ``host``, ``user``, ``password`` and ``db``) is read.
    """

    def __init__(self, **kwargs):
        self.params = kwargs
        self.conn = None  # opened lazily by get_db_conn()
        self.merge_table_names = None  # NOTE(review): not read anywhere in this file

    def proceed(self):
        """Run every avg-day check, then every count check.

        The database connection is closed afterwards whether or not a check
        failed.

        :raises TableCheckError: on the first table whose value mismatches.
        :raises ConfigError: if the MySQL settings are missing or incomplete.
        """
        try:
            self.checkAvgDay()
            self.checkCount()
        finally:
            self.close_db_conn()

        logging.info('Everything is OK!')

    # camelCase names are kept for backward compatibility with existing callers.
    def checkAvgDay(self):
        """Run check_avg_day() for every table in merge_report_tables_avg_day."""
        for table_name in merge_report_tables_avg_day:
            logging.info('Checking table %s avg day...', table_name)
            self.check_avg_day(table_name)
            logging.info('Table is ok!')

    def checkCount(self):
        """Run check_count() for every table in merge_report_tables_count."""
        for table_name in merge_report_tables_count:
            logging.info('Checking table %s count...', table_name)
            self.check_count(table_name)
            logging.info('Table is ok!')

    def check_avg_day(self, table_name):
        """Compare AVG(day) of *table_name* with its configured ``expectedAvgDay``.

        :returns: True when the values are equal.
        :raises TableCheckError: when they differ.
        """
        expected_avg = merge_report_tables_avg_day[table_name]['expectedAvgDay']
        got_avg = self.get_first_db_result(self.get_avg_day_sql(table_name))
        # NOTE(review): exact equality — confirm the driver's return type for
        # AVG() (e.g. Decimal) matches the config's expectedAvgDay type.
        if got_avg == expected_avg:
            return True

        message = 'Table {} avg day check failed! Expected avg: {!r}, got: {!r}!'.format(
            table_name, expected_avg, got_avg
        )
        logging.error(message)
        raise TableCheckError(message)

    def check_count(self, table_name):
        """Compare COUNT(1) of *table_name* with its configured ``expectedCount``.

        :returns: True when the values are equal.
        :raises TableCheckError: when they differ.
        """
        expected_count = merge_report_tables_count[table_name]['expectedCount']
        got_count = self.get_first_db_result(self.get_count_sql(table_name))
        if got_count == expected_count:
            return True

        message = 'Table {} count check failed! Expected count: {!r}, got: {!r}!'.format(
            table_name, expected_count, got_count
        )
        logging.error(message)
        raise TableCheckError(message)

    def get_first_db_result(self, query):
        """Execute *query* and return the first column of its first row.

        :raises IndexError: if the query returns no rows.
        """
        conn = self.get_db_conn()
        # The with-block closes the cursor even when execute() raises.
        with conn.cursor() as cur:
            cur.execute(query)
            row = cur.fetchone()
        if row is None:
            raise IndexError('Query returned no rows: {}'.format(query))
        return row[0]

    def get_avg_day_sql(self, table_name):
        """Return SQL computing AVG(day) of *table_name* within [DATE_FROM, DATE_TO).

        *table_name* is interpolated into a backtick-quoted identifier without
        escaping, so it must come from trusted configuration.
        """
        return """
            SELECT SQL_NO_CACHE
                avg(day)
            FROM `{table_name}`
            WHERE
                day >= '{date_from}'
                AND day < '{date_to}'
        """.format(table_name=table_name, date_from=DATE_FROM, date_to=DATE_TO)

    def get_count_sql(self, table_name):
        """Return SQL counting all rows of *table_name*.

        *table_name* is interpolated without escaping; it must be trusted.
        """
        return """
            SELECT SQL_NO_CACHE
                COUNT(1)
            FROM `{table_name}`
        """.format(table_name=table_name)

    def get_db_conn(self):
        """Return the cached connection, (re)connecting if absent or closed."""
        if self.conn is None or not self.conn.open:
            self.conn = pymysql.connect(**self.get_db_config())
        return self.conn

    def close_db_conn(self):
        """Close the cached connection if it is open; safe to call repeatedly."""
        if self.conn is not None and self.conn.open:
            self.conn.close()
        self.conn = None

    def get_db_config(self):
        """Build pymysql.connect() keyword arguments from ``self.params['mysql']``.

        :raises ConfigError: if the ``mysql`` settings are absent, are not a
            mapping, or lack one of ``host``/``user``/``password``/``db``.
        """
        # NOTE(review): 'passwd'/'db' are PyMySQL's legacy aliases for
        # 'password'/'database'; kept so the returned dict is unchanged.
        try:
            mysql = self.params['mysql']
            return {
                'host': mysql['host'],
                'user': mysql['user'],
                'passwd': mysql['password'],
                'db': mysql['db'],
                'autocommit': False,
            }
        except KeyError as err:
            raise ConfigError('Missing MySQL setting: {}'.format(err))
        except TypeError:
            raise ConfigError('MySQL settings must be a mapping')


def main():
    """Command-line entry point: set up logging, parse ``--mysql`` and run all checks.

    ``--mysql`` must be a JSON object with ``host``, ``user``, ``password`` and
    ``db`` keys. Invalid JSON or a non-object value is reported through argparse
    (usage message, exit status 2) rather than as a raw traceback.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    root_logger.addHandler(handler)
    logging.info('Starting script...')

    parser = argparse.ArgumentParser(description='Report monitoring')
    parser.add_argument('--mysql', dest='mysql', required=True,
                        help='MySQL settings as a JSON object with host, user, password and db keys')
    args = parser.parse_args()
    try:
        mysql = json.loads(args.mysql)
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        parser.error('--mysql is not valid JSON: {}'.format(err))
    if not isinstance(mysql, dict):
        parser.error('--mysql must be a JSON object')
    ReportsMonitoring(mysql=mysql).proceed()


if __name__ == '__main__':
    main()
