from __future__ import unicode_literals

import calendar
import logging
from collections import defaultdict, namedtuple

import time
from datetime import datetime, timedelta
from six.moves.urllib_parse import urljoin, urlencode

import requests
import six
from infra.watchdog.src.lib.metrics import Counter

# Filter (an urlencoded query string) selecting task groups that reached a
# terminal status but are not yet marked as processed by Nanny.
UNPROCESSED_TERMINAL_TASKGROUPS_QS = urlencode({
    'statuses': 'REJECTED,CANCELLED,FAILED,DONE',
    'nanny_processed': 'false',
})

# Scheduler view of one task: its status, completion state and the ids of the
# sibling tasks it depends on.
Task = namedtuple('Task', 'status completion_status dependencies')
# Target state requested for a service snapshot by a SetSnapshotTargetState task.
Snapshot = namedtuple('Snapshot', 'service_id id target_state')


def get_value(struct, path, default=None):
    """Look up a '/'-separated key path in nested dicts.

    Intermediate levels that are missing are treated as empty dicts. If any
    level on the way is not dict-like (has no ``.get``), ``default`` is returned.

    :param struct: nested mapping to search; ``None`` yields ``default``.
    :param path: key path such as ``'a/b/c'``; an empty path yields ``default``.
    :param default: value returned when the leaf key (or the path) is absent.
    """
    if struct is None or not path:
        return default
    keys = path.split('/')
    node = struct
    try:
        for key in keys[:-1]:
            node = node.get(key, {})
        return node.get(keys[-1], default)
    except AttributeError:
        # A non-mapping (None, list, scalar) was reached before the leaf.
        return default


class AlemateClient(object):
    """HTTP client for the Nanny Alemate API (``/v1/alemate/...``) used by the watchdog.

    Every request goes through one ``requests.Session`` whose ``Authorization``
    header carries the OAuth token passed to the constructor.
    """

    def __init__(self, nanny_url, token):
        """
        :param nanny_url: base URL of the Nanny service. The API paths below are
            absolute, so ``urljoin`` discards any path component of this URL.
        :param token: OAuth token sent as ``Authorization: OAuth <token>``.
        """
        self.session = requests.Session()
        self.session.headers['Authorization'] = 'OAuth {}'.format(token)
        self.url = nanny_url
        self.log = logging.getLogger('watchdog-alemate')
        self._taskgroups_url = urljoin(self.url, '/v1/alemate/task_groups/')
        # '{}' is filled with a task group id via str.format().
        self._children_url_template = urljoin(self.url, '/v1/alemate/task_groups/{}/children/')
        self._workers_url = urljoin(self.url, '/v1/alemate/status/workers/')
        self._taskgroup_status_url = urljoin(self.url, '/v1/alemate/task_groups/{}/status/')

    def get_unprocessed_taskgroups_count(self):
        """Return how many terminal task groups are not yet processed by Nanny.

        Counts task groups with status REJECTED, CANCELLED, FAILED or DONE and
        ``nanny_processed=false``.

        :raises requests.HTTPError: on a non-2xx response.
        """
        # The filter is an already-urlencoded query string sent as one parameter value.
        resp = self.session.get(url=urljoin(self.url, '/v1/alemate/task_groups/count/'),
                                params={'filter': UNPROCESSED_TERMINAL_TASKGROUPS_QS})
        resp.raise_for_status()
        return int(resp.json()['count'])

    def get_taskgroup_status(self, taskgroup_id):
        """Return the ``status`` field reported for ``taskgroup_id``.

        :raises requests.HTTPError: on a non-2xx response.
        """
        resp = self.session.get(url=self._taskgroup_status_url.format(taskgroup_id))
        resp.raise_for_status()
        return resp.json()['status']

    @staticmethod
    def _get_expected_snapshot_state(task):
        """Return the Snapshot a waiting SetSnapshotTargetState task requests, else None.

        :param task: task dict from the task group children endpoint.
        :returns: ``Snapshot`` built from the task's processor options, or None when
            the task is not a SetSnapshotTargetState task or its dispatcher state is
            not WAITING.
        """
        # ``or {}`` also covers keys that are present but null.
        processor_options = task.get('processorOptions') or {}
        if processor_options.get('type') != 'SetSnapshotTargetState':
            return None
        if get_value(task, 'dispatcherTaskInfo/metaTask/state/state') != 'WAITING':
            return None
        opts = processor_options.get('options') or {}
        return Snapshot(service_id=opts.get('service_id'),
                        id=opts.get('snapshot_id'),
                        target_state=opts.get('target_state'))

    @staticmethod
    def _get_task_info(task):
        """Extract the scheduler status, completion state and dependency ids of a task dict."""
        return Task(status=get_value(task, 'schedulerOptions/status'),
                    completion_status=get_value(task, 'schedulerOptions/state/state'),
                    dependencies=get_value(task, 'schedulerOptions/dependencies', default=[]),
                    )

    @staticmethod
    def _is_runnable(task_info, children):
        """Return True if a task is queued, not finished, and all its dependencies are done.

        :param task_info: ``Task`` to check.
        :param children: mapping of task id -> ``Task`` for the same task group.
        """
        if task_info.status not in ('NEW', 'ENQUEUED') or task_info.completion_status == 'DONE':
            return False
        for dep_id in task_info.dependencies or ():
            dep = children.get(dep_id)
            # A dependency that is not among the group's children cannot be
            # confirmed as finished, so the task is not counted as runnable.
            if dep is None or 'DONE' not in (dep.completion_status, dep.status):
                return False
        return True

    @staticmethod
    def _add_snapshot_state(statuses, service_id, snapshot_id, target_state, count=1):
        """Add ``count`` to ``statuses[service_id][snapshot_id][target_state]``, creating levels as needed."""
        per_service = statuses.setdefault(service_id, {})
        per_snapshot = per_service.setdefault(snapshot_id, defaultdict(int))
        per_snapshot[target_state] += count

    def _get_taskgroups_tasks_info(self, skip, limit, reqid, query):
        """Fetch one page of task groups matching ``query`` and summarise their tasks.

        :param skip: number of task groups to skip (page offset).
        :param limit: page size.
        :param reqid: value for the ``X-Req-Id`` header of the page request.
        :param query: urlencoded filter string passed as the ``filter`` parameter.
        :returns: ``(expected_snapshot_statuses, enqueued_tasks)`` where the first
            maps service_id -> snapshot_id -> target_state -> number of waiting
            SetSnapshotTargetState tasks, and the second is a ``Counter`` of
            runnable tasks (see ``_is_runnable``).
        :raises requests.HTTPError: if the page or any children request fails.
        """
        resp = self.session.get(url=self._taskgroups_url,
                                headers={'X-Req-Id': reqid},
                                params={'filter': query, 'skip': skip, 'limit': limit})
        resp.raise_for_status()
        taskgroup_ids = {tg['id'] for tg in resp.json()}
        expected_snapshot_statuses = {}
        enqueued_tasks = Counter()
        for tg_id in taskgroup_ids:
            children_resp = self.session.get(url=self._children_url_template.format(tg_id))
            # Validate the children response itself, not the page response above.
            children_resp.raise_for_status()
            children = {}
            for task in children_resp.json():
                children[task.get('id')] = self._get_task_info(task)
                snapshot = self._get_expected_snapshot_state(task)
                if snapshot is not None and snapshot.id is not None:
                    self._add_snapshot_state(expected_snapshot_statuses,
                                             snapshot.service_id, snapshot.id, snapshot.target_state)
            for task_info in six.itervalues(children):
                if self._is_runnable(task_info, children):
                    enqueued_tasks.inc()
        return expected_snapshot_statuses, enqueued_tasks

    @staticmethod
    def calculate_worker_params(total_tasks, pool_size, max_per_worker=100):
        """Split ``total_tasks`` items into pages for at most ``pool_size`` workers.

        :param total_tasks: number of items to page through.
        :param pool_size: number of available workers.
        :param max_per_worker: upper bound on the page size; must be positive.
        :returns: ``(per_worker_limit, worker_skips)``, i.e. the page size and the
            offset of every page; ``(0, [])`` when there is nothing to do.
        """
        if not total_tasks or not pool_size:
            return 0, []
        workers_count = min(total_tasks, pool_size)
        # Ceiling division: the smallest page size that covers everything with
        # workers_count pages, then capped at max_per_worker.
        per_worker_limit = min(-(-total_tasks // workers_count), max_per_worker)
        return per_worker_limit, list(range(0, total_tasks, per_worker_limit))

    @staticmethod
    def _period_one_week():
        """Return the Unix timestamp (whole seconds) of the moment seven days ago."""
        # Epoch arithmetic is timezone-independent. The previous
        # calendar.timegm(datetime.now().timetuple()) treated local time as UTC
        # and skewed the cutoff by the host's UTC offset.
        return int(time.time()) - int(timedelta(days=7).total_seconds())

    def collect_tasks_info(self, pool):
        """Summarise tasks of running task groups created within the last week.

        :param pool: worker pool exposing ``size`` and ``imap_unordered``
            (presumably a gevent/thread pool -- confirm against the caller).
        :returns: ``(total_expected_snapshot_statuses, total_enqueued_tasks)``,
            with the same shapes as ``_get_taskgroups_tasks_info`` but summed
            over all pages.
        :raises requests.HTTPError: if any request fails.
        """
        # One request id shared by every page request of this run.
        reqid = str(hash(time.time()))
        running_taskgroups_query = urlencode({'statuses': 'NEW,MERGED,COMMITTED',
                                              'creation_timestamp_gte': self._period_one_week()})
        resp = self.session.get(url=self._taskgroups_url,
                                headers={'X-Req-Id': reqid},
                                params={'filter': running_taskgroups_query, 'limit': 1})
        # Without this check a failed request looks like "0 task groups" and
        # the metrics silently drop to zero.
        resp.raise_for_status()
        total_tasks = int(resp.headers.get('X-Total-Items', 0))
        per_worker_limit, worker_skips = self.calculate_worker_params(total_tasks, pool.size)

        def fetch_page(skip):
            return self._get_taskgroups_tasks_info(skip, per_worker_limit, reqid, running_taskgroups_query)

        total_expected_snapshot_statuses = {}
        total_enqueued_tasks = Counter()
        for expected_snapshot_statuses, enqueued_tasks in pool.imap_unordered(fetch_page, worker_skips):
            for service_id, snapshots in six.iteritems(expected_snapshot_statuses):
                for snapshot_id, target_states in six.iteritems(snapshots):
                    for target_state, count in six.iteritems(target_states):
                        self._add_snapshot_state(total_expected_snapshot_statuses,
                                                 service_id, snapshot_id, target_state, count)
            if enqueued_tasks:
                total_enqueued_tasks.inc(enqueued_tasks.get())
        return total_expected_snapshot_statuses, total_enqueued_tasks
