# coding: utf-8

import json
import logging
import collections
import yt.wrapper as yt
import yt.yson as yson

from irt.bannerland.options import get_option
from bannerland.archive_workers.full_state import FSWorker
from irt.utils import get_duration

# Date format configured for pocket names; this module also passes it to
# get_duration() to interpret timing timestamps.
DEFAULT_DATE_FMT = get_option('bannerland_pocket_name_format')

# Timing-collection config. Keys read in this module: 'attributes' (groups read
# from YT node attributes), 'table' (groups read from task rows), 'end_time'.
timings_config = get_option('timings_config')

# Step names replaced by more descriptive ones in the keys of computed durations.
fields_renaming = dict(
    tao_table_time='waiting_upload_and_merge_tao',
)


class CalculateTimingsWorker(FSWorker):
    """Collect pipeline step timings for a full-state directory and its last pocket.

    Timings stored in YT attributes (groups listed in ``timings_config['attributes']``)
    are read once, then handed to ``TimingsMapper``, which maps
    ``<pocket>/tasks.final`` into ``<pocket>/tasks_timings``.
    """

    def __init__(self, **kwargs):
        super(CalculateTimingsWorker, self).__init__(**kwargs)
        # Both paths are populated by do_work().
        self.pocket = ''
        self.fs = ''

    def _get_source(self, source_type):
        """Translate a config ``source`` value ('fs' or 'pocket') into a YT path.

        :raises ValueError: for any other value.
        """
        known_sources = (('fs', self.fs), ('pocket', self.pocket))
        for name, path in known_sources:
            if source_type == name:
                return path
        raise ValueError("Unknown source_type: {}".format(source_type))

    def _get_timings_from_attributes(self, steps_group):
        """Read one group of step timings from YT attributes.

        Step ``s`` is read from the attribute named
        ``steps_group.get('steps_prefix', '') + s`` on the group's source node.

        :returns: OrderedDict mapping step name to attribute value, in config order.
        """
        node = self._get_source(steps_group.get('source'))
        prefix = steps_group.get('steps_prefix', '')
        return collections.OrderedDict(
            (step, self.yt_client.get_attribute(node, prefix + step))
            for step in steps_group['steps']
        )

    def do_work(self, fs_dir):
        """Compute per-task timings for the last pocket of ``fs_dir``.

        :param fs_dir: full-state directory whose ``last_pocket`` attribute names the pocket.
        :raises ValueError: if ``last_pocket`` is empty.
        """
        logging.info('do_work at %s ...', fs_dir)
        self.fs = fs_dir
        self.pocket = self.yt_client.get_attribute(fs_dir, 'last_pocket')
        if not self.pocket:
            raise ValueError("Attribute last_pocket is not set!")

        attribute_timings = collections.OrderedDict(
            (group['name'], self._get_timings_from_attributes(group))
            for group in timings_config['attributes']
        )

        columns = [
            {'name': 'OrderID', 'type': 'uint64', 'required': True},
            {'name': 'task_id', 'type': 'string', 'required': True},
            {'name': 'timings', 'type': 'any', 'required': False},
            {'name': 'durations', 'type': 'any', 'required': False},
        ]
        schema = yson.YsonList(columns)
        # Deliberately non-strict so that columns can later be dropped painlessly.
        schema.attributes['strict'] = False

        mapper = TimingsMapper(attribute_timings)
        source_table = yt.TablePath(yt.ypath_join(self.pocket, 'tasks.final'))
        destination_table = yt.TablePath(
            yt.ypath_join(self.pocket, 'tasks_timings'),
            attributes={'schema': schema},
        )
        self.yt_client.run_map(mapper, source_table, destination_table)
        logging.info('last_pocket %s ...', self.pocket)


class TimingsMapper:
    """YT mapper that attaches step timings and durations to every task row.

    For each input row one output row is yielded with ``OrderID``, ``task_id``,
    the collected ``timings`` and the derived ``durations``. Timings taken from
    the row are described by ``timings_config['table']``. The ``other_timings``
    passed to the constructor are appended after them (shared by all rows).
    """

    def __init__(self, other_timings):
        """
        :param other_timings: ordered mapping ``group name -> {step name -> timing}``
            appended to every row's timings.
        """
        self.other_timings = other_timings
        # Resolved once from the config path; not read by the methods of this class.
        self.end_time = self._get_value_from_path(other_timings, timings_config['end_time'])

    def __call__(self, row):
        """Yield a single output row with timings and durations for ``row``."""
        result = collections.OrderedDict()
        for steps_group in timings_config['table']:
            result[steps_group.get('name')] = self._collect_group(row, steps_group)

        result.update(self.other_timings)
        durations = self._process_durations(result)

        yield {
            'OrderID': row['OrderID'],
            'task_id': row['task_id'],
            'timings': result,
            'durations': durations,
        }

    @staticmethod
    def _collect_group(row, steps_group):
        """Return an OrderedDict ``step -> timing`` for one ``timings_config['table']`` group.

        * ``type == 'from_column'``: the column holds a JSON object (optionally
          nested under ``key``), and each step is looked up in it.
        * ``type == 'columns'``: each step is a column of the row.
        * any other type: empty mapping.
        Missing steps map to None.
        """
        steps_group_type = steps_group.get('type')
        aggregated = collections.OrderedDict()
        if steps_group_type == 'from_column':
            # A null (None) or empty cell is treated like a missing column;
            # json.loads(None) would raise TypeError and json.loads('') ValueError.
            data = json.loads(row.get(steps_group['column']) or '{}')
            if 'key' in steps_group:
                # `or {}` also covers an explicit JSON null under the key.
                data = data.get(steps_group['key']) or {}
            for step_name in steps_group['steps']:
                aggregated[step_name] = data.get(step_name)
        elif steps_group_type == 'columns':
            for step_name in steps_group['steps']:
                aggregated[step_name] = row.get(step_name)
        return aggregated

    @staticmethod
    def _get_value_from_path(data, path):
        """Walk nested mappings along a '/'-separated ``path``.

        Returns None as soon as a component is missing, instead of raising
        AttributeError on ``None.get``.
        """
        for step_key in path.split('/'):
            if data is None:
                return None
            data = data.get(step_key)
        return data

    @staticmethod
    def _process_durations(data):
        """Compute per-step durations plus ``total_time``.

        :param data: ordered mapping ``group name -> {step name -> timing}``. A timing is
            * a dict with optional ``start_time``/``end_time``/``duration`` keys,
              whose ``duration`` is used as is;
            * a non-empty string timestamp (parsed with ``DEFAULT_DATE_FMT``),
              whose duration is the time since the previous known end time;
            * anything else (e.g. None), which gives duration 0.
        :returns: OrderedDict keyed ``group/step`` plus ``total_time``. ``total_time``
            spans the first seen start time to the last seen end time, or is 0 if
            either is unknown.
        """
        first_step_time = ''
        last_step_end_time = ''
        steps_durations = collections.OrderedDict()
        for group_name, steps in data.items():
            for step_name, step_timing in steps.items():
                duration = 0
                if isinstance(step_timing, dict):
                    duration = step_timing.get('duration', 0)
                    end_time = step_timing.get('end_time', '')
                    if not first_step_time:
                        first_step_time = step_timing.get('start_time', '')
                    if end_time:
                        last_step_end_time = end_time
                # NOTE(review): on Python 2, json.loads yields `unicode`, which this
                # `str` check does not match. Confirm the job runs on Python 3.
                elif isinstance(step_timing, str) and step_timing:
                    if not first_step_time:
                        first_step_time = step_timing
                    if last_step_end_time:
                        duration = get_duration(last_step_end_time, step_timing, DEFAULT_DATE_FMT)
                    # Always advance the reference point. Previously this happened only
                    # when a reference already existed, so a leading timestamp never
                    # became one and every following timestamp got duration 0.
                    last_step_end_time = step_timing
                if group_name == 'make_banners':
                    # Only make_banners/start_time keeps a distinct name. Other steps of
                    # the group collapse onto the bare 'make_banners' key, and the last
                    # one wins.
                    step_name = 'waiting_for_merge_pocket_tao' if step_name == 'start_time' else ''
                step_name = fields_renaming.get(step_name, step_name)
                full_step_name = '/'.join(filter(None, [group_name, step_name]))
                steps_durations[full_step_name] = duration
        if first_step_time and last_step_end_time:
            total_time = get_duration(first_step_time, last_step_end_time, DEFAULT_DATE_FMT)
        else:
            # No reference points, so there is nothing to measure. Presumably
            # get_duration cannot parse an empty timestamp anyway.
            total_time = 0
        steps_durations['total_time'] = total_time
        return steps_durations
