import argparse
from contextlib import closing
from datetime import datetime

import sqlalchemy as sa
from dateutil import tz
from lxml import etree
from six.moves.urllib.parse import urljoin
from six.moves.urllib.request import urlopen

from yandex.maps.wiki import config, db, release_stages as rs
from yandex.maps.wiki.tasks import states


def get_latest_stages(session, task_name):
    """Return release stages computed from the most recent task of a kind.

    For ``create_stable`` the task is looked up through the
    ``service.vrevisions_refresh_task`` table (``action = 'create-stable'``);
    for any other name the task is matched on ``service.task.type``.
    Only the newest task (highest ``t.id``) is used.

    Args:
        session: open SQLAlchemy session to execute the query with.
        task_name: task name / type to look for.

    Returns:
        Whatever ``rs.calc_repeated_stages`` produces for the collected task.
    """
    if task_name == 'create_stable':
        from_text = 'service.task t join service.vrevisions_refresh_task using (id)'
        where_clause = sa.text("action = 'create-stable'")
    else:
        from_text = 'service.task t'
        # Bind the task type as a parameter instead of interpolating it into
        # the SQL text, so a quote in task_name cannot alter the query.
        where_clause = sa.text('t.type = :task_type').bindparams(
            task_type=task_name)

    result = session.execute(
        sa.select([sa.text('t.id'), sa.text('t.created'), sa.text('t.log')])
        .select_from(sa.text(from_text))
        .where(where_clause)
        .order_by(sa.text('t.id desc'))
        .limit(1))
    return rs.calc_repeated_stages(task_name, rs.collect_tasks(result))


def get_task_status(config, task_id):
    """Fetch the status of a task from the tasks HTTP service.

    Args:
        config: object exposing ``get_config()`` whose ``.xml.get(path)``
            yields the tasks service base URL (``main`` passes the imported
            ``config`` module).
        task_id: task identifier appended to the ``/tasks/`` URL path.

    Returns:
        The ``status`` attribute of the ``<task>`` element, upper-cased, or a
        string starting with ``'UNKNOWN: '`` that describes why the status
        could not be obtained. Never raises: every failure is folded into the
        ``UNKNOWN`` result so a monitoring line can still be printed.
    """
    try:
        base_url = config.get_config().xml.get('/services/tasks/url')
        # NOTE(review): the leading '/' makes urljoin discard any path part of
        # base_url; confirm base_url is expected to be host-only.
        url = urljoin(base_url, '/tasks/' + str(task_id))
        # closing() releases the HTTP connection on both Python 2 and 3.
        with closing(urlopen(url, timeout=5)) as response:
            data = response.read()
        doc = etree.fromstring(data)
        ns = '{http://maps.yandex.ru/mapspro/tasks/1.x}'
        el = doc.find(ns + 'task')
        if el is None:
            return 'UNKNOWN: no task element in response'
        return el.attrib['status'].upper()
    except Exception as ex:
        return 'UNKNOWN: ' + str(ex)


def seconds(value):
    """Convert a duration threshold string to whole seconds.

    Accepts ``'<hours>h'``, ``'<minutes>m'``, ``'<seconds>s'`` or a bare
    ``'<seconds>'`` number, e.g. ``'2h'`` -> 7200, ``'15m'`` -> 900,
    ``'30'`` -> 30.

    Raises:
        ValueError: if ``value`` is empty or its numeric part is not an integer.
    """
    multipliers = {'h': 3600, 'm': 60, 's': 1}
    value = value.strip()
    if not value:
        raise ValueError('empty duration threshold')
    unit = value[-1]
    if unit in multipliers:
        return int(value[:-1]) * multipliers[unit]
    # A bare number is already in seconds; convert it so callers can compare
    # it with integer durations (a str cannot be compared meaningfully).
    return int(value)


def get_monitoring_status(args, task_id, status, duration):
    """Build the monitoring line ``<level>;Task <name> <id> <status> (<H:MM:SS>)``.

    Levels: 0 ok, 1 warn, 2 crit.

    Args:
        args: parsed command line; reads ``task_name``, ``warn``, ``crit``
            and ``crit_failure``.
        task_id: id of the task being reported.
        status: upper-case status string from ``get_task_status`` (may start
            with ``'UNKNOWN'``).
        duration: task duration in seconds.

    Returns:
        The formatted monitoring line. The duration is shown as
        hours:minutes:seconds, and hours may exceed 23.
    """
    if status.startswith('UNKNOWN'):
        level = 1  # warn: the status could not be determined
    elif status == states.SUCCESS:
        level = 0  # ok
    elif status == states.FAILURE:
        level = 2 if args.crit_failure else 1  # crit, or warn
    elif status not in [states.PENDING, states.STARTED]:
        level = 0  # FROZEN, REVOKED ok
    elif duration < seconds(args.warn):
        level = 0  # PENDING, STARTED ok
    elif duration < seconds(args.crit):
        level = 1  # PENDING, STARTED warn
    else:
        level = 2  # PENDING, STARTED crit

    # Format as elapsed time rather than a time of day, so durations of one
    # day or more do not wrap around (25h must not print as 01:00:00).
    minutes, secs = divmod(int(duration), 60)
    hours, minutes = divmod(minutes, 60)
    return '{level};Task {task_name} {task_id} {status} ({time})'.format(
        level=level,
        task_name=args.task_name,
        task_id=task_id,
        status=status.lower(),
        time='{0:02d}:{1:02d}:{2:02d}'.format(hours, minutes, secs))


def find_longest_task(stages, task_name):
    """Pick the stage named ``task_name`` whose last task has run the longest.

    Args:
        stages: iterable of stage objects (as returned by ``load_stages``)
            exposing ``name``, ``is_completed``, ``end``, ``last_task_start``
            and ``task_ids``.
        task_name: stage name to look for.

    Returns:
        ``(task_id, duration)``: the id of the last task of the longest-running
        matching stage, and its duration in whole seconds (never negative).
        Stages that are still running are measured up to the current UTC
        time. On ties the first matching stage wins. Returns ``(None, 0)``
        when no stage matches.
    """
    task_id = None
    max_duration = 0
    for stage in stages:
        if stage.name != task_name:
            continue
        end = stage.end if stage.is_completed else datetime.now(tz.tzutc())
        # Clamp at 0 so clock skew between the recorded start and the local
        # clock cannot produce a negative duration.
        duration = max(0, int((end - stage.last_task_start).total_seconds()))
        # Accept the first match even at 0 seconds. A task that finished
        # (e.g. failed) within its first second must still be reported
        # instead of being mistaken for "no task".
        if task_id is None or duration > max_duration:
            max_duration = duration
            task_id = stage.task_ids[-1]
    return task_id, max_duration


def parse_args():
    """Parse command-line options for the release-task monitoring check.

    Returns:
        argparse.Namespace with ``config`` (str or None), ``task_name`` (str),
        ``warn`` and ``crit`` (raw threshold strings, later converted with
        ``seconds``) and ``crit_failure`` (bool).
    """
    parser = argparse.ArgumentParser(
        description='Calculate latest release metric for monitoring')

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--config', dict(
            type=str,
            help='full path to services.xml')),
        ('--task-name', dict(
            type=str, required=True,
            help='task name (example: create_stable)')),
        ('--warn', dict(
            type=str, required=True,
            help='warn duration threshold (<hours>h | <minutes>m | <seconds>)')),
        ('--crit', dict(
            type=str, required=True,
            help='crit duration threshold (<hours>h | <minutes>m | <seconds>)')),
        ('--crit-failure', dict(
            default=False, action='store_true',
            help='fire critical if task failure')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


def load_stages(config, task_name):
    """Load the release stages to inspect for ``task_name``.

    ``create_stable`` and ``prepare_stable_branch`` are read from the latest
    task of that kind via ``get_latest_stages``. Any other name reads all
    stages of the latest branch, skipping branch ids <= 1.

    Args:
        config: unused here; kept for interface compatibility (it shadows the
            module-level ``config`` import).
        task_name: task name from the command line.

    Returns:
        The stages found, or None when no qualifying branch exists.
    """
    db.init_pool(['core'], 'core')
    with db.get_read_session('core') as session:
        if task_name not in ('create_stable', 'prepare_stable_branch'):
            for candidate in rs.get_latest_branch_ids(session, limit=1):
                # NOTE(review): branch id 1 is presumably a special (trunk)
                # branch; confirm why it is excluded.
                if candidate <= 1:
                    continue
                return rs.get_all_stages(session, candidate)
            return None
        return get_latest_stages(session, task_name)


def main():
    """Run the check and print exactly one ``<level>;<message>`` line.

    Levels: 0 ok, 1 warn, 2 crit. Failures to load configuration or stages
    are reported as level 1 ("Monitoring broken") instead of crashing.
    """
    args = parse_args()

    try:
        # Config init sits inside the guard so a broken config still yields a
        # parseable monitoring line instead of a bare traceback.
        config.init_config(args.config)
        stages = load_stages(config, args.task_name)
    except Exception as ex:
        # Pass one pre-formatted argument. This file targets Python 2 too
        # (via six), and under print-statement semantics print('a', ex)
        # would emit a tuple repr instead of the message.
        print('1;Monitoring broken: {0}'.format(ex))
        return

    if stages is None:
        # load_stages found nothing to inspect.
        print('0;Ok')
        return

    task_id, duration = find_longest_task(stages, args.task_name)
    if task_id is None:
        print('0;Ok')
        return

    status = get_task_status(config, task_id)
    print(get_monitoring_status(args, task_id, status, duration))


# Run the check only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
