#!/usr/bin/env python2.7
# -*- coding: utf-8 -*-
import datetime
import logging
import os
import re
import sys

import yaml

parent_dir_path = '/'.join(os.path.abspath(__file__).split('/')[:-2])
sys.path = [parent_dir_path] + sys.path

from common.juggler import send_juggler, ok, warn, crit
from datetime import timedelta, datetime
from optparse import OptionParser, Option
from yql.api.v1.client import YqlClient

from checks import parse_check_or_none


# Weekday keys usable in check configs; the index matches datetime.weekday()
# (Monday == 0).
WEEKDAY_NAMES = ['MON', 'TUE', 'WED', 'THU', 'FRI', 'SAT', 'SUN']

# Juggler service name under which the aggregated status is reported.
SERVICE_NAME = 'ytLogJobs'

# Module-level logger so helpers that log also work when this file is imported
# instead of run as a script (getLogger returns the same object for this name).
logger = logging.getLogger('bazinga-jobs-counts')


class TaskStat(object):
    """Per-task execution and job counters built from one stats row."""

    def __init__(self, task_name, exec_fail_cnt, job_fail_cnt, job_success_cnt):
        self.task_name = task_name
        # Normalize every counter to int.
        self.exec_fail_cnt, self.job_fail_cnt, self.job_success_cnt = (
            int(exec_fail_cnt), int(job_fail_cnt), int(job_success_cnt))

    @classmethod
    def from_yql_row(cls, row):
        """Build a stat from a 6-column row; the last two columns are ignored."""
        (name, exec_fails, job_fails, job_successes, _unused_a, _unused_b) = row
        return cls(name, exec_fails, job_fails, job_successes)

    @property
    def job_total_cnt(self):
        """Number of jobs, failed plus successful."""
        return sum((self.job_fail_cnt, self.job_success_cnt))

    @property
    def job_success_ratio(self):
        """Share of successful jobs; the denominator falls back to 1 when the total is 0."""
        denominator = self.job_total_cnt or 1
        return float(self.job_success_cnt) / float(denominator)


class TaskChecker(object):
    """Runs a set of per-field checks against the stats of one task."""

    def __init__(self, task_name, field_checkers):
        self.task_name = task_name
        # Mapping: TaskStat attribute name -> object exposing check(value).
        self.field_checkers = field_checkers

    @classmethod
    def from_dict(cls, task_name, raw_checkers):
        """Build a checker from a ``{field_name: check_config}`` mapping.

        ``raw_checkers`` may be None (e.g. a YAML task entry without a value);
        that is treated as "no field checks", so only presence of the task
        is verified.
        """
        checkers = {fld: FieldChecker.from_dict_or_str(chk)
                    for fld, chk in (raw_checkers or {}).items()}
        return cls(task_name, checkers)

    def check(self, task):
        """Check ``task`` and return a list of ``(level, message, task_name, field_name)``.

        When ``task`` is None (no stats found for it) a single CRIT
        ``NO DATA`` result is returned with ``field_name`` set to None.
        """
        if task is None:
            return [crit('NO DATA') + (self.task_name, None)]

        return [self._check_field(task, fld, chk) for fld, chk in self.field_checkers.items()]

    def _check_field(self, task, field_name, checker):
        # A missing attribute yields None, which the checker can report as NO DATA.
        result = checker.check(getattr(task, field_name, None))
        return result + (self.task_name, field_name)

    def __repr__(self):
        return '%s checker' % self.task_name


class FieldChecker(object):
    """Evaluates one field value against a CRIT check and an optional WARN check.

    The checks come from ``parse_check_or_none``; this class relies only on
    their ``check(value)`` returning a truthy failure description when the
    check fails and a falsy value otherwise.
    """

    def __init__(self, crit_checker, warn_checker=None):
        # Either checker may be None, e.g. when a config entry only has WARN
        # (parse_check_or_none is given None for the missing level).
        self.crit_checker = crit_checker
        self.warn_checker = warn_checker

    @classmethod
    def from_dict_or_str(cls, dict_or_str):
        """Build a checker from one config entry.

        The entry may be:
          * a plain check spec, used as the CRIT check;
          * a dict with ``WARN`` and/or ``CRIT`` keys;
          * a dict keyed by weekday name (``MON`` .. ``SUN``) whose values are
            either of the above. Today's entry is used when present; otherwise
            the dict itself is read as a WARN/CRIT dict.
        """
        source = dict_or_str
        if isinstance(source, dict):
            today = WEEKDAY_NAMES[datetime.today().weekday()]
            if today in source:
                source = source[today]

        if isinstance(source, dict):
            warn_spec, crit_spec = source.get('WARN'), source.get('CRIT')
        else:
            warn_spec, crit_spec = None, source

        return cls(parse_check_or_none(crit_spec), parse_check_or_none(warn_spec))

    def check(self, value):
        """Return a juggler-style result: NO DATA/CRIT first, then WARN, else OK."""
        if value is None:
            return crit('NO DATA')

        # Guard both levels: a WARN-only config leaves crit_checker as None.
        crit_failure = self.crit_checker.check(value) if self.crit_checker else None
        if crit_failure:
            return crit(crit_failure)

        warn_failure = self.warn_checker.check(value) if self.warn_checker else None
        if warn_failure:
            return warn(warn_failure)

        return ok()


def read_sibling_file(filename):
    """Return the contents of ``filename`` located in this script's directory.

    The path is resolved relative to the script file, not the current working
    directory, so the result does not depend on where the script is run from.
    """
    # os.path.dirname is correct for root-level paths and non-'/' separators,
    # unlike splitting the path on '/' by hand.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(script_dir, filename), 'r') as f:
        return f.read()


def fetch_jobs_stats_rows():
    """Run ``count_jobs.yql`` for yesterday and return the rows of all result tables.

    The ``%date%`` placeholder in the query file is replaced with yesterday's
    date formatted as ``YYYY-MM-DD``. Rows of every result table are
    concatenated in order.
    """
    yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')
    query = read_sibling_file('count_jobs.yql').replace('%date%', yesterday)

    # Log before running so the exact query is visible even if execution fails.
    logger.info(query)

    request = YqlClient().query(query, syntax_version=1)
    request.run()

    rows = []
    for table in request.get_results():
        # NOTE(review): presumably loads rows beyond the default preview; confirm in YQL client docs.
        table.fetch_full_data()
        rows.extend(table.rows)

    return rows


class CheckConfig(object):
    """Check configuration read from a YAML file next to this script.

    Layout read by ``load``::

        default:
          checks: {field_name: check_spec, ...}    # applied to every found task
          excludes: [regex, ...]                   # tasks skipped by defaults
        jobs:
          task_name: {field_name: check_spec, ...}  # task-specific checks
    """

    def __init__(self, default_checks, default_excludes, checks):
        checks = checks or {}
        self.default_checks = default_checks
        # Excludes are regexes; a task is excluded when its whole name matches
        # one of them. An empty list only matches the empty name.
        self.default_excludes = re.compile(u'^(' + u'|'.join(default_excludes) + u')$')
        self.specific_tasks = list(checks)
        self.specific_checkers = [TaskChecker.from_dict(task_name, task_checks)
                                  for task_name, task_checks in checks.items()]

    @classmethod
    def load(cls, path):
        """Load the config from ``path`` (resolved next to this script).

        Missing or null sections are treated as empty.
        """
        # safe_load: the config only needs plain YAML types, and yaml.load
        # without an explicit Loader is unsafe (deprecated in PyYAML 5.1,
        # rejected in 6.0).
        config = yaml.safe_load(read_sibling_file(path)) or {}
        default = config.get('default') or {}
        return cls(default.get('checks') or {},
                   default.get('excludes') or [],
                   config.get('jobs') or {})

    def get_checkers(self, found_tasks):
        """Return default checkers for found, unconfigured tasks plus all specific checkers.

        Specific checkers are always returned, so a configured task missing
        from ``found_tasks`` is still checked (and reported as NO DATA).
        """
        unconfigured = set(found_tasks) - set(self.specific_tasks)
        return self._cons_default_checkers(unconfigured) + self.specific_checkers

    def _cons_default_checkers(self, tasks):
        # One checker per task that is not excluded by the default excludes.
        return [TaskChecker.from_dict(task, self.default_checks)
                for task in tasks if self._is_included(task)]

    def _is_included(self, task_name):
        if not isinstance(task_name, basestring):
            logger.warning('Got non-string task name: %s', task_name)
            return False

        return not self.default_excludes.match(task_name)


def get_send_method(virtual_host, dry_run=False):
    """Return a ``(lvl, message)`` callable that reports for ``virtual_host``.

    Every call is logged. The event is pushed to juggler under SERVICE_NAME
    only when ``dry_run`` is false.
    """
    def do_send(lvl, message):
        log_line = '%s, %s, %s' % (virtual_host, lvl, message)
        logger.info(log_line)
        if dry_run:
            return
        send_juggler(virtual_host, SERVICE_NAME, lvl, message)

    return do_send


def _format_check_results(flat_results, level):
    """Return sorted ``"task[.field]: message"`` lines for results at ``level``."""
    lines = []
    for (task, field), (lvl, message) in flat_results.items():
        if lvl != level:
            continue
        check_name = '%s.%s' % (task, field) if field else task
        lines.append('%s: %s' % (check_name, message))
    # Sort so the reported message is stable between runs (dict order is arbitrary).
    return sorted(lines)


def check_stats(rows, checks_filename, send_status):
    """Check job stats rows against the config and report one aggregated status.

    :param rows: stats rows; the first four columns are
        (task_name, exec_fail_cnt, job_fail_cnt, job_success_cnt).
    :param checks_filename: YAML check config, resolved next to this script.
    :param send_status: callable ``(level, message)`` that receives the aggregate.

    The aggregate level is CRIT if any check is CRIT, WARN if any check is
    WARN, and OK otherwise. The message lists CRIT findings, then WARN findings.
    """
    config = CheckConfig.load(checks_filename)
    stats = [TaskStat(*row[:4]) for row in rows]
    task_stats = {stat.task_name: stat for stat in stats}
    checkers = config.get_checkers(task_stats.keys())

    flat_results = {}
    for checker in checkers:
        for lvl, message, task, field in checker.check(task_stats.get(checker.task_name)):
            flat_results[(task, field)] = (lvl, message)

    crit_results = _format_check_results(flat_results, 'CRIT')
    warn_results = _format_check_results(flat_results, 'WARN')

    if crit_results:
        aggr_level = 'CRIT'
    elif warn_results:
        # WARN-only findings must be reported as WARN, not escalated to CRIT.
        aggr_level = 'WARN'
    else:
        aggr_level = 'OK'

    message_lines = []
    if crit_results:
        message_lines += ['CRIT:'] + crit_results
    if warn_results:
        message_lines += ['WARN:'] + warn_results

    send_status(aggr_level, '\n'.join(message_lines))


def _load_test_data():
    """Return ``TEST_DATA`` from the local ``test_data`` module.

    Used instead of fetch_jobs_stats_rows() output when --test is given.
    """
    from test_data import TEST_DATA as canned_rows
    return canned_rows


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Module-global: helpers above log through this name.
    logger = logging.getLogger('bazinga-jobs-counts')

    opts_list = (
        Option(
            '-v', '--virtual-host',
            action='store',
            dest='virtual_host',
            type='string',
            default=None,
            help='Golem virtual host'
        ),
        Option(
            '-c', '--checks-filename',
            action='store',
            dest='checks_filename',
            type='string',
            default=None,
            help='Checks path'
        ),
        Option(
            '-t', '--test',
            action='store_true',
            dest='test',
            default=False,
            help='For testing'
        ),
    )

    # -v is the virtual host option; -h is reserved by optparse for --help.
    opts_parser = OptionParser('Usage: ./%s -v {virtual_host} -c {checks_filename} [-t]'
                               % os.path.basename(__file__),
                               option_list=opts_list)
    (opts, args) = opts_parser.parse_args()

    # The checks file is always required. The virtual host is required unless
    # running in --test mode, where nothing is sent to juggler.
    if not opts.checks_filename or not (opts.virtual_host or opts.test):
        opts_parser.print_usage()
        sys.exit(1)

    send_method = get_send_method(opts.virtual_host, dry_run=opts.test)
    data = fetch_jobs_stats_rows() if not opts.test else _load_test_data()
    check_stats(data, opts.checks_filename, send_method)
