from __future__ import absolute_import

import logging
import requests
import time
from datetime import datetime, timedelta
from uuid import uuid4
from multiprocessing.pool import ThreadPool

from celery import shared_task
from mongoengine import Q

from .models import Event, Task
from .tools import round_time, is_for_scale, create_event_params_list, split_to_chunks, with_timer, generate_stats
from .enums import EventState, TaskState, EventStoryState
from .scheduling import schedule_task
from .event_processing import process_event_wrapper
from .. import settings as conf


# Logger shared by every task defined in this module.
_log = logging.getLogger(name='celery.tasks')


def _clock_event(rounded, scale):
    """Build a ``clock`` event for ``scale`` carrying the full ``rounded`` timestamp."""
    return Event(
        name='clock',
        parameters=create_event_params_list({
            'scale': scale,
            'timestamp': rounded.strftime(conf.CLOCK_EVENT_DATETIME_FORMAT),
            'weekday': rounded.weekday()
        })
    )


def _clock_periodic_event(rounded, scale):
    """Build a ``clock_periodic`` event: time of day in ``timestamp``, date stored separately."""
    return Event(
        name='clock_periodic',
        parameters=create_event_params_list({
            'scale': scale,
            'timestamp': rounded.strftime('0000-00-00 %H:%M:%S'),
            'date': rounded.strftime('%Y-%m-%d'),
            'weekday': rounded.weekday()
        })
    )


@shared_task
def generate_clock_events():
    """Emit ``clock`` / ``clock_periodic`` events for every minute since the last run.

    The newest 1-minute ``clock`` event is looked up and one timestamp per minute is
    generated from the minute after it up to and including the current minute (only
    the current minute when no such event exists yet). For each timestamp:

    * two events (``clock`` and ``clock_periodic``) are created for every scale in
      ``conf.CLOCK_EVENTS_SCALES`` that ``is_for_scale`` accepts;
    * on a ``1d`` boundary, ``1w`` events are added on Mondays and ``1month`` events
      on the first day of the month.

    Returns early with a warning when the newest 1m clock event already carries the
    current minute's timestamp, so a duplicate run does not clone events.
    """
    last_clock_event = Event.objects(
        Q(name='clock') & Q(parameters={'key': 'scale', 'value': '1m'})
    ).order_by('-_id').first()
    last_timestamp = last_clock_event.get_param('timestamp') if last_clock_event is not None else None

    now = datetime.now().replace(second=0, microsecond=0)

    timestamps = [now]
    if last_timestamp is not None:
        if now.strftime(conf.CLOCK_EVENT_DATETIME_FORMAT) == last_timestamp:
            _log.warning('Attempt to clone clock events detected')
            return
        # Backfill every minute missed since the last generated 1m clock event.
        missed = datetime.strptime(last_timestamp, conf.CLOCK_EVENT_DATETIME_FORMAT) + timedelta(minutes=1)
        while missed < now:
            timestamps.append(missed)
            missed += timedelta(minutes=1)

    new_events = []
    for tstamp in sorted(timestamps):
        rounded = round_time(tstamp, '1m')

        for scale in conf.CLOCK_EVENTS_SCALES:
            if is_for_scale(rounded, scale):
                _log.info('Generating %s %s', scale, rounded)
                new_events.append(_clock_event(rounded, scale))
                new_events.append(_clock_periodic_event(rounded, scale))

        if is_for_scale(rounded, '1d'):
            if rounded.weekday() == 0:  # Monday
                new_events.append(_clock_periodic_event(rounded, '1w'))
                new_events.append(_clock_event(rounded, '1w'))
            if rounded.day == 1:  # first day of the month
                new_events.append(_clock_periodic_event(rounded, '1month'))
                new_events.append(_clock_event(rounded, '1month'))

    # An empty bulk insert is rejected by the driver, so skip it when no scale matched.
    if new_events:
        Event.objects.insert(new_events)


@shared_task
def get_new_events():
    """Claim every NEW event for one processing run and hand the batch off.

    All NEW events are switched to PROCESSING and stamped with a freshly generated
    ``processor_guid``; if the update reports that any were claimed,
    ``process_new_events`` is queued with that guid.
    """
    processor_guid = uuid4()
    claimed = Event.objects(state=EventState.NEW).update(
        state=EventState.PROCESSING,
        processor_guid=processor_guid
    )
    if claimed == 0:
        return
    _log.info('Starting process_new_events %s', str(processor_guid))
    process_new_events.delay(processor_guid)


@shared_task
def process_new_events(guid):
    """Split the events claimed under ``guid`` into chunks and queue one task per chunk.

    :param guid: the ``processor_guid`` that ``get_new_events`` stamped on the events
        it claimed.
    """
    _log.info(' process_new_events %s has started', str(guid))
    claimed = Event.objects(processor_guid=guid).only('id').all()
    events_ids = [claimed_event.id for claimed_event in claimed]
    _log.info('Processing events %s with guid %s', events_ids, guid)
    # Each chunk of at most EVENTS_NUMBER_IN_CHUNK ids becomes its own async task.
    for chunk in split_to_chunks(events_ids, conf.EVENTS_NUMBER_IN_CHUNK):
        process_events.delay(chunk)


@shared_task
def process_events(events_ids):
    """Process a batch of events and schedule whatever tasks they produce.

    Each event id is passed to ``process_event_wrapper``; the task ids it returns
    are collected. A single chunk of task ids is scheduled inline in this worker;
    several chunks are fanned out as separate ``schedule_tasks`` subtasks.

    :param events_ids: ids of the events to process; empty input is a no-op.
    """
    if not events_ids:
        return

    collected = []
    for event_id in events_ids:
        produced = process_event_wrapper(event_id)
        if produced:
            collected.extend(produced)

    if not collected:
        return

    chunks = split_to_chunks(collected, conf.TASKS_NUMBER_IN_CHUNK)
    if len(chunks) <= 1:
        # Small enough for one chunk: schedule synchronously instead of queueing.
        schedule_tasks(collected)
        return
    for chunk in chunks:
        schedule_tasks.delay(chunk)


@shared_task
@with_timer('Scheduling tasks chunk {}')
def schedule_tasks(tasks_ids):
    """Load the given tasks and run ``schedule_task`` on each of them concurrently.

    :param tasks_ids: ids of ``Task`` documents to schedule; empty input is a no-op.

    Exceptions from ``Task.objects.get`` (e.g. for an unknown id) propagate before
    any scheduling starts; the first exception raised by ``schedule_task`` is
    re-raised by ``pool.map``.
    """
    # Fetch before creating the pool so a failed lookup cannot leak its threads.
    tasks = [Task.objects.get(id=task_id) for task_id in tasks_ids]
    if not tasks:
        # Nothing to schedule; ThreadPool(0) would also raise ValueError.
        return

    # Never start more threads than there are tasks to schedule.
    pool = ThreadPool(min(len(tasks), conf.TASKS_NUMBER_IN_CHUNK, 20))
    try:
        pool.map(schedule_task, tasks)
    finally:
        # Without close()/join() the pool's worker threads outlive this call and
        # accumulate in the long-running worker process.
        pool.close()
        pool.join()


@shared_task
def recover_events():
    """Re-process failed events and events that look stuck in PROCESSING.

    Selects up to ``conf.EVENTS_NUMBER_TO_RECOVER`` events, oldest ``time_created``
    first, that are either FAILED, or PROCESSING and created more than
    ``conf.EVENT_TIMEOUT_TO_RECOVER`` seconds ago. It then calls ``process_events``
    on their ids directly (not via ``.delay``).

    NOTE(review): the selected events are not claimed, so overlapping runs could
    pick the same events. Staleness of PROCESSING events is also judged by
    ``time_created``, not by when processing started. Confirm both are acceptable.
    """
    stuck_before = datetime.now() - timedelta(seconds=conf.EVENT_TIMEOUT_TO_RECOVER)

    failed = Q(state=EventState.FAILED)
    stuck = Q(state=EventState.PROCESSING) & Q(time_created__lt=stuck_before)
    candidates = (
        Event.objects(failed | stuck)
        .order_by('time_created')
        .only('id')
        .limit(conf.EVENTS_NUMBER_TO_RECOVER)
    )

    ids = [candidate.id for candidate in candidates]
    if not ids:
        return
    _log.warning('Going to recover events %s', ids)
    process_events(ids)


@shared_task
def recover_tasks():
    """Claim a bounded batch of failed or stuck tasks and schedule them again.

    A task is recoverable when it is enabled and is one of:

    * FAILED;
    * NEW or SCHEDULING, and created more than ``conf.TASK_TIMEOUT_TO_RECOVER``
      seconds ago;
    * RECOVERING, with ``time_scheduling_started`` older than
      ``conf.TASK_TIMEOUT_TO_RECOVER_FROM_RECOVER`` seconds (an earlier recovery
      that did not complete).

    At most ``conf.TASKS_NUMBER_TO_RECOVER`` tasks are claimed per run. Claimed tasks
    are set to RECOVERING, stamped with a fresh ``processor_guid`` and
    ``time_scheduling_started``, and passed to ``schedule_tasks`` directly.
    """
    now = datetime.now()
    edge_datetime = now - timedelta(seconds=conf.TASK_TIMEOUT_TO_RECOVER)
    recovering_edge_datetime = now - timedelta(seconds=conf.TASK_TIMEOUT_TO_RECOVER_FROM_RECOVER)

    recoverable = Task.objects(
        (
            Q(state=TaskState.FAILED) |
            (Q(state__in=[TaskState.NEW, TaskState.SCHEDULING]) & Q(time_created__lt=edge_datetime)) |
            (Q(state=TaskState.RECOVERING) & Q(time_scheduling_started__lt=recovering_edge_datetime))
        ) &
        Q(enabled=True)
    )

    # QuerySet.update() issues a multi-document update that ignores .limit(), so the
    # bounded batch has to be selected explicitly before claiming it.
    candidate_ids = [
        task.id for task in
        recoverable.order_by('time_created').only('id').limit(conf.TASKS_NUMBER_TO_RECOVER)
    ]
    if not candidate_ids:
        return

    guid = uuid4()
    # Re-applying the recoverable criteria makes the claim conditional: a task whose
    # state changed (or that a concurrent run claimed) since the select is skipped.
    claimed_count = recoverable.filter(id__in=candidate_ids).update(
        state=TaskState.RECOVERING, processor_guid=guid, time_scheduling_started=datetime.now())
    if claimed_count == 0:
        return

    tasks_ids = [
        task.id for task in
        Task.objects(state=TaskState.RECOVERING, processor_guid=guid).only('id')
    ]
    if tasks_ids:
        _log.warning('Going to recover tasks %s', tasks_ids)
        schedule_tasks(tasks_ids)


@shared_task
def delete_old_events():
    """Delete PROCESSED events created more than ``conf.OLD_EVENTS_TIMEOUT`` days ago.

    The ids of the removed events are logged before and after the deletion.
    """
    edge_datetime = datetime.now() - timedelta(days=conf.OLD_EVENTS_TIMEOUT)
    old_events = Event.objects.filter(
        state=EventState.PROCESSED, time_created__lt=edge_datetime
    )
    # Only the ids are needed, so avoid loading whole documents.
    ids = [event.id for event in old_events.only('id')]
    if not ids:
        return
    id_str = ', '.join(str(event_id) for event_id in ids)
    _log.warning('Removing events %s', id_str)
    # Restrict the delete to the logged ids so the log matches what was removed.
    old_events.filter(id__in=ids).delete()
    _log.warning('Removing events %s has been done', id_str)


@shared_task
def delete_old_tasks():
    """Delete SCHEDULED tasks scheduled more than ``conf.OLD_TASKS_TIMEOUT`` days ago.

    The ids of the removed tasks are logged before and after the deletion.
    """
    edge_datetime = datetime.now() - timedelta(days=conf.OLD_TASKS_TIMEOUT)
    old_tasks = Task.objects.filter(
        state=TaskState.SCHEDULED, time_scheduled__lt=edge_datetime
    )
    # Only the ids are needed, so avoid loading whole documents.
    ids = [task.id for task in old_tasks.only('id')]
    if not ids:
        return
    id_str = ', '.join(str(task_id) for task_id in ids)
    _log.warning('Removing tasks %s', id_str)
    # Restrict the delete to the logged ids so the log matches what was removed.
    old_tasks.filter(id__in=ids).delete()
    _log.warning('Removing tasks %s has been done', id_str)


@shared_task
def send_stat_to_solomon(timeout=10):
    """Push the statistics of the last minute to Solomon (best effort).

    Does nothing unless ``conf.SEND_STAT_TO_SOLOMON`` is set. Failures to build or
    push the stats are logged as warnings and otherwise swallowed, so this periodic
    task does not fail.

    :param timeout: seconds to wait for the Solomon push endpoint before giving up.
    """
    if not conf.SEND_STAT_TO_SOLOMON:
        return
    now = round_time(datetime.now(), '1m')
    # Epoch seconds for the rounded minute; time.mktime treats `now` as local time.
    ts = time.mktime(now.timetuple())
    # Anchor the window on the rounded timestamp so it always spans exactly one
    # minute. A fresh datetime.now() made its length depend on when the task ran.
    start_time = now - timedelta(minutes=1)

    try:
        stats = generate_stats(start_time, now)
    except Exception as e:
        _log.warning('Failed to generate stats for solomon: %s', e)
        return

    resp = None
    try:
        payload = {
            'commonLabels': {
                'project': conf.SOLOMON_PROJECT,
                'cluster': conf.SOLOMON_CLUSTER,
                'service': conf.SOLOMON_SERVICE,
            },
            'sensors': [
                {
                    'labels': {'sensor': sensor_name},
                    'ts': ts,
                    'value': sensor_value
                } for sensor_name, sensor_value in stats.items()
            ]
        }
        # Without a timeout a stalled endpoint would block this worker indefinitely.
        resp = requests.post(conf.SOLOMON_PUSH_URL, json=payload, timeout=timeout)
        resp.raise_for_status()
    except Exception as e:
        # resp is still None when the request itself failed.
        reason = getattr(resp, 'content', '')
        _log.warning('Failed to send stat to solomon: %s %s', str(e), reason)


@shared_task
def travel_through_the_event(
        event_id,
        event_story=None,
        reject=False,
        external=True,
        _visited=None
):
    """Walk the graph of events and the tasks they produced, starting at ``event_id``.

    For the visited event, every task whose ``events`` contain it is handled. For
    each such task with a ``sandbox_task_id``, the events whose ``source.task_id``
    equals that id are visited recursively.

    :param event_id: id of the event to visit.
    :param event_story: optional story document. It is set to RUNNING at the start of
        the top-level call and to DONE at its end. It collects nested events in
        ``produced_events`` and tasks in ``produced_tasks``. When rejecting, it also
        collects per-task stop results in ``task_stopping_info``. It is saved after
        every visit.
    :param reject: when true, every visited event and task is marked REJECTED and
        ``task.stop_if_running()`` is attempted for each task. A failure to stop is
        recorded as the exception text instead of being raised.
    :param external: true for the top-level call; recursive calls pass False.
    :param _visited: internal set of already-handled event/task ids, shared across
        the recursion. A task listing several events, and so reachable along several
        paths, is handled only once. Any cycle terminates.
    """
    if _visited is None:
        _visited = set()
    event_key = ('event', str(event_id))
    if event_key in _visited:
        # Already reached through another path in the event/task graph.
        return
    _visited.add(event_key)

    if external:
        _log.info('event_story: starting for event %s. Rejecting: %s', event_id, reject)
    _log.info('event_story: visiting event %s', event_id)

    if external and event_story:
        event_story.state = EventStoryState.RUNNING
        event_story.save()
    event = Event.objects.get(id=event_id)

    if event_story and not external:
        event_story.produced_events.append(event)

    if reject:
        event.state = EventState.REJECTED
        event.save()

    produced_tasks = Task.objects(events__contains=event_id).all()
    for task in produced_tasks:
        task_key = ('task', str(task.id))
        if task_key in _visited:
            # Handled already via another of its events; avoid duplicate story
            # entries and repeated stop attempts.
            continue
        _visited.add(task_key)

        answer = None
        if reject:
            task.state = TaskState.REJECTED
            task.save()
            _log.info('event_story: gonna stop task %s', task.id)
            try:
                answer = task.stop_if_running()
            except Exception as e:
                _log.exception('event_story: failed to stop task %s', task.id)
                answer = str(e)

        if event_story:
            event_story.produced_tasks.append(task)

        if reject and event_story:
            event_story.task_stopping_info[str(task.id)] = answer

        # TODO: different behavior for different source_type
        if task.sandbox_task_id:
            produced_events = Event.objects(source__task_id=task.sandbox_task_id).all()
            for ev in produced_events:
                travel_through_the_event(
                    ev.id,
                    event_story=event_story,
                    reject=reject,
                    external=False,
                    _visited=_visited
                )

    if event_story and external:
        event_story.state = EventStoryState.DONE

    if event_story:
        event_story.save()

    if external:
        _log.info('Done: event_story for event %s. Rejecting: %s', event_id, reject)
