import re
from datetime import datetime
from collections import namedtuple
from dateutil import tz

import sqlalchemy as sa


# Minimum number of checks a validation task must contain to be selected by
# get_full_validations (tasks without any AOI restriction).
FULL_VALIDATION_MIN_CHECKS_COUNT = 65

# AOI id and minimum number of checks used by get_rkub_validations.
RKUB_AOI_ID = 2065384494  # old id 1567907042
RKUB_VALIDATION_MIN_CHECKS_COUNT = 20

# AOI id used by get_ukr_validations (which keeps the default minimum of
# checks from get_validation_tasks).
UKRAINE_AOI_ID = 1488418547

# AOI id and minimum number of checks used by get_france_validations.
FRANCE_AOI_ID = 1748690958
FRANCE_VALIDATION_MIN_CHECKS_COUNT = 20


# One task row: its id, creation timestamp and log text.
Task = namedtuple("Task", "id created log")
# Summary of one stage as produced by calc_first_stage: the stage name, the
# creation time of its first task, the start time of the last task examined,
# the completion time (or the current time when not completed), the completion
# flag and the ids of the tasks examined.
Stage = namedtuple("Stage", "name start last_task_start end is_completed task_ids")


def json_object(stages):
    """Convert stages into a JSON-serialisable dictionary.

    Args:
        stages: iterable of objects exposing the ``Stage`` attributes
            (``name``, ``start``, ``last_task_start``, ``end``,
            ``is_completed``, ``task_ids``).

    Returns:
        ``{'stages': [...]}`` with one dict per stage. Timestamps are
        rendered as ``%Y-%m-%dT%H:%M:%S%z``; a missing (``None``) timestamp
        is emitted as ``None`` instead of raising.
    """
    DATE_FORMAT = '%Y-%m-%dT%H:%M:%S%z'

    def format_time(moment):
        # calc_first_stage produces None timestamps for an empty task list,
        # so a stage without timestamps must not crash serialisation.
        return None if moment is None else moment.strftime(DATE_FORMAT)

    return {
        'stages': [
            {
                'name': s.name,
                'start': format_time(s.start),
                'lastTaskStart': format_time(s.last_task_start),
                'end': format_time(s.end),
                'isCompleted': s.is_completed,
                'taskIds': s.task_ids,
            }
            for s in stages
        ]}


def get_time(log):
    """Extract the bracketed timestamp from a log line.

    The first ``[...]`` group in ``log`` is parsed as ``%d/%m/%Y %H:%M:%S``
    and tagged with the Europe/Moscow time zone.

    Args:
        log: a single log line.

    Returns:
        A timezone-aware ``datetime``, or ``None`` when the line contains no
        bracketed group.

    Raises:
        ValueError: if the bracketed text is not in the expected format.
    """
    # Non-greedy match: a greedy ``.+`` would run up to the LAST ']' on the
    # line and break parsing whenever the message itself contains brackets.
    match = re.search(r'\[(.+?)\]', log)
    if match is None:
        return None
    time_str = match.group(1)
    tzinfo = tz.gettz('Europe/Moscow')
    return datetime.strptime(time_str, '%d/%m/%Y %H:%M:%S').replace(tzinfo=tzinfo)


def calc_start_time(log, start_marker='Starting task'):
    """Return the timestamp of the first log line if it carries the start marker.

    Args:
        log: sequence of log lines.
        start_marker: text expected in the first line.

    Returns:
        The result of ``get_time`` for the first line, or ``None`` when the
        log is empty or its first line lacks the marker.
    """
    if not log:
        return None
    first_line = log[0]
    if start_marker not in first_line:
        return None
    return get_time(first_line)


def calc_completion_time(log, completion_marker='Task finished'):
    """Return the timestamp of the last log line if it carries the completion marker.

    Args:
        log: sequence of log lines.
        completion_marker: text expected in the last line.

    Returns:
        The result of ``get_time`` for the last line, or ``None`` when the
        log is empty or its last line lacks the marker.
    """
    if not log:
        return None
    final_line = log[-1]
    if completion_marker not in final_line:
        return None
    return get_time(final_line)


def calc_first_stage(name, tasks):
    """Summarise a run of tasks as a single ``Stage``.

    Tasks are examined in order until one whose log ends with the completion
    marker is found; that task closes the stage.

    Args:
        name: stage name to store in the result.
        tasks: sequence of ``Task`` objects.

    Returns:
        A ``Stage``. For an empty ``tasks`` all timestamps are ``None``, the
        stage is not completed and has no task ids. Otherwise ``start`` is the
        creation time of the first task, ``last_task_start`` is the start of
        the last task examined (its creation time when the log has no start
        line), and ``end`` is the completion time, or the current UTC time
        when no examined task has completed.
    """
    if not tasks:
        return Stage(name, None, None, None, False, [])

    stage_start = tasks[0].created
    seen_ids = []
    latest_start = None
    finished_at = None
    for task in tasks:
        seen_ids.append(task.id)
        lines = task.log.splitlines()
        logged_start = calc_start_time(lines)
        # Fall back to the row's creation time when the log has no start line.
        latest_start = logged_start if logged_start is not None else task.created
        finished_at = calc_completion_time(lines)
        if finished_at:
            break

    completed = bool(finished_at)
    # An unfinished stage is reported as still running "now".
    stage_end = finished_at if completed else datetime.now(tz=tz.tzutc())
    return Stage(
        name=name,
        start=stage_start,
        last_task_start=latest_start,
        end=stage_end,
        is_completed=completed,
        task_ids=seen_ids,
    )


def calc_repeated_stages(name, tasks):
    """Split tasks into consecutive runs and build one stage per run.

    A run is closed by every task whose log contains ``'Task finished'``; any
    trailing tasks after the last such task form a final run.

    NOTE(review): the marker is searched anywhere in the log here, while
    calc_completion_time only inspects the last log line - confirm that this
    difference is intended.

    Args:
        name: stage name given to every produced stage.
        tasks: iterable of ``Task`` objects, in the order to process them.

    Returns:
        List of stages, one per run, produced by ``calc_first_stage``.
    """
    runs = []
    current_run = []
    for task in tasks:
        current_run.append(task)
        if 'Task finished' in task.log:
            runs.append(current_run)
            current_run = []
    if current_run:
        runs.append(current_run)
    return [calc_first_stage(name, run) for run in runs]


def collect_tasks(result):
    """Turn query rows into ``Task`` objects ordered by creation time.

    Args:
        result: iterable of rows whose first three items are the task id,
            its creation timestamp and its log.

    Returns:
        List of ``Task`` sorted ascending by ``created``.
    """
    tasks = [Task(id=row[0], created=row[1], log=row[2]) for row in result]
    tasks.sort(key=lambda task: task.created)
    return tasks


def get_create_stable_stages(session, branch_id):
    """Build 'create_stable' stages for a branch.

    Selects the most recently created 'create-stable' refresh task whose
    creation time precedes the creation time of the given branch.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: id of the branch in ``revision.branch``.

    Returns:
        List of stages named ``'create_stable'``.
    """
    columns = [sa.text('t.id'), sa.text('t.created'), sa.text('log')]
    task_source = sa.text('service.task t join service.vrevisions_refresh_task using (id)')
    branch_source = sa.text('revision.branch b')
    branch_filter = sa.literal_column('b.id').label('branch_id') == branch_id

    query = (
        sa.select(columns)
        .select_from(task_source)
        .select_from(branch_source)
        .where(sa.text("action = 'create-stable'"))
        .where(sa.text('t.created < b.created'))
        .order_by(sa.text('t.created desc'))
        .limit(1)
        .where(branch_filter)
    )
    rows = session.execute(query)
    return calc_repeated_stages('create_stable', collect_tasks(rows))


def get_apply_shadow_attributes(session, branch_id):
    """Build 'apply_shadow_attributes' stages for a branch.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: value matched against the ``branch_id`` column.

    Returns:
        List of stages named ``'apply_shadow_attributes'``.
    """
    source = sa.text('service.task join service.apply_shadow_attributes_task using (id)')
    query = (
        sa.select([sa.column('id'), sa.column('created'), sa.column('log')])
        .select_from(source)
        .where(sa.sql.column('branch_id') == branch_id)
    )
    tasks = collect_tasks(session.execute(query))
    return calc_repeated_stages('apply_shadow_attributes', tasks)


def get_exports(session, branch_id):
    """Build 'export' stages for a branch from export tasks of the 'domain' subset.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: value matched against the ``branch_id`` column.

    Returns:
        List of stages named ``'export'``.
    """
    source = sa.text('service.task join service.export_task using (id)')
    query = (
        sa.select([sa.column('id'), sa.column('created'), sa.column('log')])
        .select_from(source)
        .where(sa.text("subset = 'domain'"))
        .where(sa.sql.column('branch_id') == branch_id)
    )
    tasks = collect_tasks(session.execute(query))
    return calc_repeated_stages('export', tasks)


def get_diffalerts(session, branch_id):
    """Build 'diffalert' stages for a branch.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: value matched against the ``new_branch_id`` column.

    Returns:
        List of stages named ``'diffalert'``.
    """
    source = sa.text('service.task t join service.diffalert_task using (id)')
    query = (
        sa.select([sa.column('id'), sa.column('created'), sa.column('log')])
        .select_from(source)
        .where(sa.sql.column('new_branch_id') == branch_id)
    )
    tasks = collect_tasks(session.execute(query))
    return calc_repeated_stages('diffalert', tasks)


def get_prepare_stable_branch(session, branch_id):
    """Build 'prepare_stable_branch' stages for a branch.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: value matched against the ``branch_id`` column.

    Returns:
        List of stages named ``'prepare_stable_branch'``.
    """
    source = sa.text('service.task join service.prepare_stable_branch_working_data ON id=task_id')
    query = (
        sa.select([sa.column('id'), sa.column('created'), sa.column('log')])
        .select_from(source)
        .where(sa.sql.column('branch_id') == branch_id)
    )
    tasks = collect_tasks(session.execute(query))
    return calc_repeated_stages('prepare_stable_branch', tasks)


def get_validation_export(session, branch_id):
    """Build 'validation_export' stages for a branch.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: value matched against the ``branch_id`` column.

    Returns:
        List of stages named ``'validation_export'``.
    """
    source = sa.text('service.task join service.validation_export_task using (id)')
    query = (
        sa.select([sa.column('id'), sa.column('created'), sa.column('log')])
        .select_from(source)
        .where(sa.sql.column('branch_id') == branch_id)
    )
    tasks = collect_tasks(session.execute(query))
    return calc_repeated_stages('validation_export', tasks)


def get_validation_tasks(session, branch_id, aoi_id, min_checks_count=0):
    """Fetch validation tasks of a branch, filtered by AOI and number of checks.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: value matched against the ``branch_id`` column.
        aoi_id: when truthy, only tasks whose ``aoi_ids`` holds exactly this
            single id are selected; otherwise only tasks with a NULL
            ``aoi_ids`` are selected.
        min_checks_count: lower bound for ``array_length(checks, 1)``.

    Returns:
        List of ``Task`` sorted by creation time.

    NOTE(review): ``aoi_id`` is tested for truthiness, so ``0`` behaves like
    ``None``. Presumably (PostgreSQL semantics - confirm) ``array_length``
    yields NULL for an empty or NULL ``checks`` array, which would exclude
    such rows even with the default bound of 0.
    """
    checks_count = sa.literal_column('array_length(checks, 1)').label('checks_count')
    query = (
        sa.select([sa.column('id'), sa.column('created'), sa.column('log')])
        .select_from(sa.text('service.task join service.validation_task using (id)'))
        .where(sa.sql.column('branch_id') == branch_id)
        .where(checks_count >= min_checks_count)
    )

    if not aoi_id:
        query = query.where(sa.text('aoi_ids is null'))
    else:
        aois_count = sa.literal_column('array_length(aoi_ids, 1)').label('aois_count')
        first_aoi_id = sa.literal_column('aoi_ids[1]').label('first_aoi_id')
        query = (
            query
            .where(sa.text('aoi_ids is not null'))
            .where(aois_count == 1)
            .where(first_aoi_id == aoi_id)
        )

    return collect_tasks(session.execute(query))


def get_ukr_validations(session, branch_id):
    """Build 'validation_ukr' stages from validation tasks limited to UKRAINE_AOI_ID.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: id of the branch whose tasks are inspected.

    Returns:
        List of stages named ``'validation_ukr'``.
    """
    tasks = get_validation_tasks(session, branch_id, aoi_id=UKRAINE_AOI_ID)
    return calc_repeated_stages('validation_ukr', tasks)


def get_rkub_validations(session, branch_id):
    """Build 'validation_rkub' stages from validation tasks limited to RKUB_AOI_ID.

    Only tasks with at least RKUB_VALIDATION_MIN_CHECKS_COUNT checks are used.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: id of the branch whose tasks are inspected.

    Returns:
        List of stages named ``'validation_rkub'``.
    """
    tasks = get_validation_tasks(
        session,
        branch_id,
        aoi_id=RKUB_AOI_ID,
        min_checks_count=RKUB_VALIDATION_MIN_CHECKS_COUNT,
    )
    return calc_repeated_stages('validation_rkub', tasks)


def get_france_validations(session, branch_id):
    """Build 'validation_france' stages from validation tasks limited to FRANCE_AOI_ID.

    Only tasks with at least FRANCE_VALIDATION_MIN_CHECKS_COUNT checks are used.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: id of the branch whose tasks are inspected.

    Returns:
        List of stages named ``'validation_france'``.
    """
    tasks = get_validation_tasks(
        session,
        branch_id,
        aoi_id=FRANCE_AOI_ID,
        min_checks_count=FRANCE_VALIDATION_MIN_CHECKS_COUNT,
    )
    return calc_repeated_stages('validation_france', tasks)


def get_full_validations(session, branch_id):
    """Build 'validation_full' stages from validation tasks without an AOI restriction.

    Only tasks with at least FULL_VALIDATION_MIN_CHECKS_COUNT checks are used.

    Args:
        session: object whose ``execute`` runs the query.
        branch_id: id of the branch whose tasks are inspected.

    Returns:
        List of stages named ``'validation_full'``.
    """
    tasks = get_validation_tasks(
        session,
        branch_id,
        aoi_id=None,
        min_checks_count=FULL_VALIDATION_MIN_CHECKS_COUNT,
    )
    return calc_repeated_stages('validation_full', tasks)


def get_all_stages(session, branch_id):
    """Collect the stages of every known kind for a branch.

    Args:
        session: object passed to each stage getter.
        branch_id: id of the branch whose stages are collected.

    Returns:
        Flat list of stages, grouped by getter in the order listed below.
    """
    stage_getters = (
        get_create_stable_stages,
        get_apply_shadow_attributes,
        get_exports,
        get_diffalerts,
        get_prepare_stable_branch,
        get_validation_export,
        get_full_validations,
        get_rkub_validations,
        get_france_validations,
        get_ukr_validations,
    )
    return [
        stage
        for getter in stage_getters
        for stage in getter(session, branch_id)
    ]


def get_latest_branch_ids(session, limit):
    """Return the highest branch ids from ``revision.branch``.

    Args:
        session: object whose ``execute`` runs the query.
        limit: maximum number of ids to return.

    Returns:
        List of ids in descending order.
    """
    query = (
        sa.select([sa.column('id')])
        .select_from(sa.text('revision.branch'))
        .order_by(sa.text('id desc'))
        .limit(limit)
    )
    return [row[0] for row in session.execute(query)]
