from __future__ import absolute_import

import json
import logging
import hashlib
import httplib
import datetime as dt

import requests

from .. import settings as conf
from .tools import Timer, no_semaphore_semaphore, if_true
from .enums import TaskState, ActionTypes
from . import http_client

from sandbox import common

import sandbox.common.types.misc as ctm
import sandbox.common.types.statistics as ctst

# Module-level logger; every function in this module logs through it.
_log = logging.getLogger('celery.tasks')
# Statistics signaler used by schedule_task to push an EXCEPTION_STATISTICS
# signal whenever scheduling fails. It is built once at import time from the
# Sandbox token / API URL found in the project settings and is tagged with the
# STEP component.
_stats = common.statistics.Signaler(
    common.statistics.ClientSignalHandler(token=conf.SANDBOX_TOKEN, url=conf.SANDBOX_API_URL),
    component=ctm.Component.STEP,
    enabled=True
)


def schedule_task(task):
    """Run the scheduler matching ``task``'s action type and record the outcome.

    Steps, in order:

    * a task that is already ``SCHEDULED`` is skipped with a warning;
    * every semaphore name is rendered with ``task.event_params`` (a template
      that references a missing parameter is kept verbatim);
    * the task is moved to ``SCHEDULING``, saved and handed to the handler
      selected by ``task.action_config.action_type`` (``schedule_sandbox_task``
      is the fallback for unknown types);
    * on success the task becomes ``SCHEDULED`` and ``time_scheduled`` is set;
      for ``WEB_HOOK`` actions the function returns right after the handler and
      does not touch the state itself;
    * on any exception the task becomes ``FAILED``; the exception is logged and
      pushed to the statistics signaler, never re-raised.

    :param task: STEP task document to schedule.
    """
    timer = Timer('Scheduling task %s' % task.id).start()
    if task.state == TaskState.SCHEDULED:
        # due to errors with broken connection to broker some tasks try to be scheduled again
        _log.warning('Already scheduled task in schedule_task (%s)', task.id)
        return

    def _render(template):
        # Leave the template untouched when it refers to an unknown event parameter.
        try:
            return template.format(**task.event_params)
        except KeyError:
            return template

    task.semaphores = [_render(template) for template in task.semaphores]

    handlers = {
        ActionTypes.SANDBOX_TASK: schedule_sandbox_task,
        ActionTypes.SANDBOX_SCHEDULER: schedule_sandbox_scheduler,
        ActionTypes.WEB_HOOK: schedule_web_hook
    }
    schedule = handlers.get(task.action_config.action_type, schedule_sandbox_task)

    # Remembers the exception (if any) so that the ``finally`` clause can report it.
    failure = None
    try:
        _log.info('Task %s is going to be scheduled', task.id)
        task.set_state(TaskState.SCHEDULING)
        if task.state == TaskState.REJECTED:
            return  # TODO: prettify
        task.save()
        schedule(task)
        if task.action_config.action_type == ActionTypes.WEB_HOOK:
            return
    except common.rest.Client.HTTPError as http_error:
        failure = http_error
        if http_error.status not in (httplib.BAD_REQUEST, httplib.FORBIDDEN):
            _log.exception('Failed to schedule task %s', task.id)
        else:
            # 400/403 responses are attributed to the task's own parameters.
            task.task_failed_due_to_params()
            _log.warning('Task %s returned %s: %s', task.id, http_error.status, http_error)
        task.set_state(TaskState.FAILED)
        task.save()
    except Exception as error:
        failure = error
        task.set_state(TaskState.FAILED)
        task.save()
        _log.exception('Failed to schedule task %s', task.id)
    else:
        task.set_state(TaskState.SCHEDULED)
        task.time_scheduled = dt.datetime.now()
        task.save()
        _log.info('Task %s has been scheduled', task.id)
    finally:
        if failure is not None:
            _stats.push({
                'type': ctst.SignalType.EXCEPTION_STATISTICS,
                'timestamp': dt.datetime.utcnow(),
                'exc_type': type(failure).__name__,
                'client_id': common.config.Registry().this.id,
                'component': ctm.Component.STEP,
            })
        timer.mark()


def _create_sandbox_task(sandbox_api, task, task_data):
    """Return the Sandbox task id for ``task``, creating the Sandbox task if needed.

    When ``task.sandbox_task_id`` is already set it is returned as is and no
    request is made. Otherwise a Sandbox task is created from ``task_data`` and
    its id is stored with a conditional update that only matches while the
    stored ``sandbox_task_id`` is still ``None``. If that update matches
    nothing, the task is reloaded and the id found in storage is returned
    instead of the freshly created one.

    :param sandbox_api: Sandbox REST client; ``sandbox_api.task(task_data)`` must
        return a mapping with an ``'id'`` key.
    :param task: STEP task document.
    :param task_data: payload describing the Sandbox task to create.
    :return: id of the Sandbox task associated with ``task``.
    """
    known_id = task.sandbox_task_id
    if known_id:
        return known_id

    _log.info('Going to create task with params %s', task_data)
    created_id = sandbox_api.task(task_data)['id']
    _log.info(
        'Sandbox task #%s with params %s has been created, STEP task %s',
        created_id, task_data, task.id
    )

    unclaimed = type(task).objects(id=task.id, sandbox_task_id=None)
    if unclaimed.update_one(set__sandbox_task_id=created_id):
        task.sandbox_task_id = created_id
        task.save()
        return created_id

    # The conditional update matched nothing: an id is already stored, so use that one.
    task.reload()
    stored_id = task.sandbox_task_id
    _log.warning('Sandbox task for STEP task %s has already been created with id #%s', task.id, stored_id)
    return stored_id


def _schedule_in_statinfra(task):
    """Create (or reuse) and start the Sandbox task that runs a Statinfra action.

    The action id is read from ``task.custom_fields`` and, failing that, from
    ``task.task_params``. Semaphores are fetched from the Statinfra API when
    the ``dyn_sem`` custom parameter is set; if the task ends up with no
    semaphores, ``no_semaphore_semaphore(action_id)`` supplies the default.
    The Sandbox task is made unique per (action id, event parameters) through
    an MD5 key.

    :param task: STEP task document of type STATINFRA_TASK / STATINFRA_TASK_BETA.
    :raises Exception: when no action id can be found for the task.
    :raises KeyError: when ``task.task_params`` has no ``'priority'`` entry.
    :raises requests.RequestException: when a Statinfra API request fails or
        times out.
    """
    # Statinfra API calls must never block the worker indefinitely.
    http_timeout = 10

    # .get() on both sources: a missing key has to reach the explicit check
    # below instead of escaping as a bare KeyError without the task id.
    action_id = task.custom_fields.get('action_id') or task.task_params.get('action_id')
    if not action_id:
        raise Exception('No action_id in task %s' % task.id)

    # NOTE(review): flags are read from ``task.custom_params`` while action_id
    # and prio come from ``task.custom_fields`` -- confirm both attributes exist.
    if task.custom_params.get('dyn_sem'):
        r = requests.get(
            '{infra_host}/api/v1/actions/semaphores'.format(infra_host=conf.STATINFRA_API_HOST),
            params=dict(
                action_id=action_id,
                **task.event_params
            ),
            timeout=http_timeout,
        )
        task.semaphores = r.json()['result']

    task.semaphores = task.semaphores or no_semaphore_semaphore(action_id)

    prio = int(task.custom_fields.get('prio', 110))
    params = dict(task.event_params)

    step_events = [e.to_dict() for e in task.events]
    event_deps_case = task.custom_params.get('case')

    # Any case other than 'default' (including an absent one) becomes part of
    # the parameters and therefore of the uniqueness key computed below.
    if event_deps_case != 'default':
        params['__event_deps_case'] = event_deps_case

    unique_key = hashlib.md5('{}-{}'.format(action_id, json.dumps(params, sort_keys=True))).hexdigest()

    action_record = requests.get(
        '{}/api/v1/action_registry/{}'.format(
            conf.STATINFRA_API_HOST,
            action_id
        ),
        timeout=http_timeout,
    ).json()

    max_cont_failures = action_record.get('max_cont_failures', -1)
    failure_interval_factor = action_record.get('failure_interval_factor', 1)

    def get_cf(name, value):
        # Shape of a single entry of the Sandbox ``custom_fields`` list.
        return {'name': name, 'value': value}

    custom_fields = [
        get_cf('action_id', action_id),
        get_cf('event_params', params),
        get_cf('prio', prio),
        get_cf('step_events', json.dumps(step_events)),
        get_cf('event_deps_case', event_deps_case),
        get_cf('max_cont_failures', max_cont_failures),
        get_cf('failure_interval_factor', failure_interval_factor),
    ]

    task_data = {
        'type': task.task_type,
        'custom_fields': custom_fields,
        'priority': task.task_params['priority'],
        'owner': task.task_params.get('owner', 'STATINFRA'),
        'uniqueness': {
            'key': unique_key,
            'excluded_statuses': ['BREAK', 'FINISH', 'EXECUTE'],
        },
        'requirements': {
            'semaphores': {
                'acquires': [
                    {'name': sem_name, 'capacity': 1}
                    for sem_name in task.semaphores
                ]
            }
        },
        'description': json.dumps({
            'action_id': action_id,
            'params': params,
        }, separators=(',', ':'), sort_keys=True)
    }

    sandbox_api = http_client.Sandbox()

    sandbox_task_id = _create_sandbox_task(sandbox_api, task, task_data)

    # A non-successful start is only logged; it does not fail the scheduling.
    result = sandbox_api.batch.tasks.start.update([sandbox_task_id])[0]
    if result['status'] != 'SUCCESS':
        (_log.warning if result['status'] == 'WARNING' else _log.error)(
            'Starting status for task #%s: %s', result['id'], result['message']
        )


@if_true(conf.REAL_SCHEDULING)
def schedule_sandbox_task(task):
    """Create (or reuse) a Sandbox task for ``task`` and start it.

    Tasks of type STATINFRA_TASK / STATINFRA_TASK_BETA are delegated to
    ``_schedule_in_statinfra``. For every other type the Sandbox payload is
    built from ``task.task_params`` (with the default owner and priority
    filled in), the task's semaphores and its custom fields.

    Semaphores: each name in ``task.semaphores`` is acquired with capacity 1;
    entries already listed under ``requirements.semaphores.acquires`` in the
    task parameters are kept and override the capacity for the same name.

    :param task: STEP task document to schedule.
    """
    _log.info('Scheduling %s %s %s %s', task.task_type, task.task_params, task.event_params, task.semaphores)
    if task.task_type in ('STATINFRA_TASK', 'STATINFRA_TASK_BETA'):
        _schedule_in_statinfra(task)
        return

    sandbox_api = http_client.Sandbox()

    task_params = dict(task.task_params)
    task_params.update({
        'owner': task_params.get('owner', conf.DEFAULT_GROUP),
        'priority': task_params.get('priority', conf.DEFAULT_PRIORITY)
    })

    semaphores = {sem_name: 1 for sem_name in task.semaphores}
    requirements = task_params.get('requirements', {})
    for sem_dict in requirements.get('semaphores', {}).get('acquires', []):
        semaphores[sem_dict['name']] = sem_dict['capacity']

    if semaphores:
        # ``task_params`` is only a shallow copy, so the nested dicts are copied
        # before being modified; otherwise the merged list would be written
        # into ``task.task_params`` itself.
        requirements = dict(requirements)
        semaphores_section = dict(requirements.get('semaphores', {}))
        # Iterate over (name, capacity) pairs; iterating the dict directly
        # yields only the names and breaks the tuple unpacking.
        semaphores_section['acquires'] = [
            {'name': name, 'capacity': capacity} for name, capacity in semaphores.items()
        ]
        requirements['semaphores'] = semaphores_section
        task_params['requirements'] = requirements

    input_parameters = dict(event_params=task.event_params, **dict(task.custom_fields.items()))
    input_parameters['step_events'] = [e.to_dict() for e in task.events]
    task_data = dict(
        type=task.task_type,
        custom_fields=[dict(name=k, value=v) for k, v in input_parameters.items()],
        **task_params
    )

    sandbox_task_id = _create_sandbox_task(sandbox_api, task, task_data)

    # A non-successful start is only logged; it does not fail the scheduling.
    result = sandbox_api.batch.tasks.start.update([sandbox_task_id])[0]
    if result['status'] != 'SUCCESS':
        (_log.warning if result['status'] == 'WARNING' else _log.error)(
            'Starting status for task #%s: %s', result['id'], result['message']
        )


@if_true(conf.REAL_SCHEDULING)
def schedule_sandbox_scheduler(task):
    """Create a Sandbox task from the scheduler ``task.scheduler_id`` and start it.

    The event parameters are passed to Sandbox in the task context. The id of
    the created Sandbox task is stored on ``task`` and saved only after the
    start request has returned; a non-successful start status is logged, not
    raised.

    :param task: STEP task document referring to a Sandbox scheduler.
    """
    _log.info('Scheduling scheduler %s %s %s', task.scheduler_id, task.event_params, task.semaphores)
    api = http_client.Sandbox()

    created = api.task(scheduler_id=task.scheduler_id, context={'event_params': task.event_params})
    new_task_id = created['id']

    start_report = api.batch.tasks.start.update([new_task_id])[0]
    status = start_report['status']
    if status != 'SUCCESS':
        # WARNING statuses are reported as warnings, everything else as errors.
        report = _log.warning if status == 'WARNING' else _log.error
        report('Starting status for task #%s: %s', start_report['id'], start_report['message'])

    task.sandbox_task_id = new_task_id
    task.save()


@if_true(conf.REAL_SCHEDULING)
def schedule_web_hook(task):
    """POST the task's payload to its web hook URL and record the outcome.

    The payload is ``task.body`` parsed as JSON when a body is set, otherwise a
    dict with the task's events and event parameters. The request carries the
    task id in the ``X-Request-Id`` header and uses ``task.response_timeout``
    as its timeout (1 second when that is not set).

    Outcome handling:

    * success -> ``SCHEDULED`` with ``time_scheduled`` set;
    * request error while no ``response_timeout`` is configured, or an HTTP
      error with a status below 500 -> also ``SCHEDULED`` (no retry);
    * any other request error -> ``FAILED``; the task is additionally disabled
      once ``retry_timeout`` seconds have passed since the first attempt or
      ``retry_attempts`` attempts have been made.

    The task is saved at the end in all of the cases above. Errors that are
    not ``requests.RequestException`` (for example an invalid JSON body)
    propagate to the caller without saving.

    :param task: STEP task document describing the web hook call.
    """
    task_id = str(task.pk)
    _log.info('Scheduling task %s: web hook %s with parameters %s', task_id, task.url, task.event_params)
    _log.info('Calling web hook %s from task %s', task.url, task_id)
    if not task.first_attempt:
        task.first_attempt = dt.datetime.now()
    task.attempts += 1

    session = requests.Session()
    session.headers['X-Request-Id'] = task_id
    try:
        data = (
            json.loads(task.body)
            if task.body else
            dict(events=[e.to_dict() for e in task.events], params=task.event_params)
        )
        response = session.request('POST', task.url, timeout=task.response_timeout or 1, json=data)
        response.raise_for_status()
    except requests.RequestException as ex:
        client_error = isinstance(ex, requests.HTTPError) and ex.response.status_code < 500
        if not task.response_timeout or client_error:
            _log.info('Web hook %s called from task %s with result: %s', task.url, task_id, ex)
            task.time_scheduled = dt.datetime.now()
            task.set_state(TaskState.SCHEDULED)
        else:
            _log.error('Error in web hook %s: %s', task_id, ex)
            total_time = (dt.datetime.now() - task.first_attempt).total_seconds()
            if total_time >= task.retry_timeout or task.attempts >= task.retry_attempts:
                _log.warning(
                    'Disabling task %s after %s attempt(s) and total time %ss',
                    task_id, task.attempts, total_time
                )
                task.enabled = False
            else:
                _log.warning('Task %s will be retried later', task_id)
            task.set_state(TaskState.FAILED)
    else:
        _log.info('Web hook %s successfully called with result: %s', task_id, response.text)
        task.time_scheduled = dt.datetime.now()
        task.set_state(TaskState.SCHEDULED)
    finally:
        # Release the session's pooled connections whatever the outcome; the
        # response body has already been read at this point (no streaming).
        session.close()
    task.save()
