# -*- coding: utf-8 -*-

import time
import random
from datetime import datetime, timedelta

import sandbox.projects.common.yabs.lbyt_reader.utils as lbyt_utils
from sandbox.projects.common.yabs.lbyt_reader.sandbox_task import YabsLbYtSandboxTask
from sandbox.projects.common.yabs.lbyt_reader.job_yt_client import MANAGER_STATE, get_task_offsets, DataProxy
from sandbox.projects.common.yabs.lbyt_reader.config import (
    LOGTYPES, MANAGER_JOBS, READER_JOBS,
    PARTITIONS_OVERLAP, LB_MERGE_GAP_COEFFICIENT,
)

# Key of the YT control-state entry holding external offset caps, shaped
# {logtype: {toppart: offset}}; the caps are honoured only when exactly one
# YT cluster is configured.
EXTERNAL_OFFSET_LIMITATION = "offset_limitation"


class TaskState:
    """One version (task number + time mark) of a task as seen across YT clusters."""

    def __init__(self, data_proxy):
        # Cluster name -> state that cluster reports for this task (e.g. 'sync', 'ready').
        self.cluster_state = dict()
        # Handle used to request and fetch the task payload.
        self.data_proxy = data_proxy


class YabsLbYtTaskManager(YabsLbYtSandboxTask):
    """Sandbox task that plans reader tasks for the LB -> YT pipeline.

    Each run reads the task lists of all YT clusters, copies recent tasks to
    clusters that lack them, builds the next reader task from the offsets
    published by the LB clusters and, when more than one YT cluster is in use,
    promotes tasks from "sync" to "ready" once a quorum of clusters has them.
    """
    type = 'YABS_LBYT_TASK_MANAGER'
    _lbyt_jobs = MANAGER_JOBS

    cores = 1
    required_ram = 8072
    execution_space = 4096

    @staticmethod
    def is_task_ready(task_state):
        """Return True if at least one cluster holds the task in a non-"sync" state.

        :param task_state: TaskState as built by merge_tasks_info.
        """
        return any(state != 'sync' for state in task_state.cluster_state.values())

    def need_work(self):
        """Return True if any YT cluster's control state enables the manager.

        Control states are re-read cluster by cluster; the scan stops at the
        first cluster whose MANAGER_STATE entry is truthy.
        """
        for yt_cluster in self.yt_clusters:
            yt_cluster.read_control_state()
            if yt_cluster.control_state.get(MANAGER_STATE):
                return True
        return False

    def do_job_work(self):
        """Run one planning iteration.

        :returns: ``(success, next_run_timestamp)``. ``(False, 0)`` is returned
            when there is no YT quorum, when requested task data could not be
            loaded, or when no new task could be made (error, or a task is still
            in "sync" state).
        """
        lbyt_utils.do_for_clusters(self.yt_clusters, 'read_tasks')

        # Clusters whose ``tasks`` is None are counted as not answered.
        successful = 0
        for yt_cluster in self.yt_clusters:
            if yt_cluster.tasks is not None:
                successful += 1
        # Strict majority of YT clusters. Floor division keeps the quorum an
        # integer regardless of the interpreter's division semantics.
        yt_quorum = len(self.yt_clusters) // 2 + 1
        self.logger.info(
            'Current yt clusters quorum state: %d of %d answered, %d needed',
            successful, len(self.yt_clusters), yt_quorum,
        )
        if successful < yt_quorum:
            self.logger.error('No yt clusters quorum, can not proceed')
            return False, 0
        if successful < len(self.yt_clusters):
            self.logger.warn('Not all yt clusters available')

        # Copy the newest version of recent tasks to answering clusters that
        # lack it, walking from the highest task number down.
        tasks = self.merge_tasks_info()
        skip_older_than = self.get_skip_older_than()
        for task_num in sorted(tasks, reverse=True):
            task_time_mark = max(tasks[task_num])
            # Time marks are milliseconds since the epoch.
            if datetime.utcfromtimestamp(task_time_mark / 1000) < skip_older_than:
                continue
            all_clusters_synced = True
            task_state = tasks[task_num][task_time_mark]
            for yt_cluster in self.yt_clusters:
                if yt_cluster.tasks is None:
                    # To avoid rewriting all tasks when cluster is gone
                    continue
                if yt_cluster.name in task_state.cluster_state:
                    continue
                all_clusters_synced = False
                yt_cluster.add_task({
                    'num': task_num,
                    'time_mark': task_time_mark,
                    'state': 'sync',
                    'data': task_state.data_proxy.request(),
                })
            if all_clusters_synced:
                # Stop at the first task present on every answering cluster:
                # only the unsynced top of the list is copied.
                break
        # Resolve the data requested via data_proxy.request() above; abort on failure.
        if not DataProxy.load_requested():
            return False, 0

        # new_task ends up as a task dict, None (nothing to read) or False
        # (making failed, or the last task is still in "sync" state).
        new_task = None
        if not tasks:
            new_task = self.make_new_reader_task(1, {})
        else:
            last_task_num = max(tasks)
            last_task_time_mark = max(tasks[last_task_num])
            last_task_info = tasks[last_task_num][last_task_time_mark]
            if self.is_task_ready(last_task_info):
                # Last task is ready (or even done) somewhere, so it's time to make a new one.
                last_task_info.data_proxy.request()
                if not DataProxy.load_requested():
                    return False, 0
                offsets = get_task_offsets(last_task_info.data_proxy.get(), for_next=True)
                new_task = self.make_new_reader_task(last_task_num + 1, offsets)
            else:
                self.logger.warn('There is a task in "sync" state, can not make a new one')
                new_task = False
        limited_share = 0
        if new_task:
            limited_share = new_task.pop('limited_share')
            for yt_cluster in self.yt_clusters:
                if yt_cluster.tasks is None:
                    # To avoid creating a gap (see comment about rewriting above)
                    continue
                yt_cluster.add_task(new_task)

        self.logger.info('Prepared data: %s', ', '.join(
            '{} tasks for {}'.format(yt_cluster.prepared_tasks, yt_cluster.name)
            for yt_cluster in self.yt_clusters
        ))
        lbyt_utils.do_for_clusters(self.yt_clusters, 'flush_tasks')

        if yt_quorum > 1:
            # Re-read the task lists after flushing and promote "sync" copies to
            # "ready" for tasks that reached the quorum or are ready somewhere.
            lbyt_utils.do_for_clusters(self.yt_clusters, 'read_tasks')
            tasks = self.merge_tasks_info()
            for task_num in tasks:
                task_time_mark = max(tasks[task_num])
                task_state = tasks[task_num][task_time_mark]
                if (
                    len(task_state.cluster_state) < yt_quorum
                    and not self.is_task_ready(task_state)
                ):
                    # Task is not ready yet
                    continue
                for yt_cluster in self.yt_clusters:
                    if task_state.cluster_state.get(yt_cluster.name) == 'sync':
                        yt_cluster.mark_tasks([task_num], 'ready')
            self.logger.info('Quorum checker marked as ready: %s', ', '.join(
                '{} tasks at {}'.format(yt_cluster.prepared_tasks, yt_cluster.name)
                for yt_cluster in self.yt_clusters
            ))
            lbyt_utils.do_for_clusters(self.yt_clusters, 'flush_tasks')

        if new_task is False:
            # Making the task failed, or a task still in "sync" state blocked it.
            return False, 0

        run_freq = min(LOGTYPES[logtype]['frequency'] for logtype in self.config['logtypes'])
        if not limited_share:
            next_task_timestamp = self.start_timestamp + run_freq
            if self.config['use_time_binding']:
                # Align the next run down to a multiple of run_freq.
                next_task_timestamp -= next_task_timestamp % run_freq
            return True, next_task_timestamp

        # Empiric formula to make required reaction profile: the larger the
        # share of capped reads, the sooner the next run.
        good_share_coef = max(1.0 - limited_share, 0.0) ** 5.0
        next_period = good_share_coef * run_freq
        self.logger.info(
            'Share of limited readings %.3f, run period reduced to %.2f seconds',
            limited_share, next_period,
        )
        return True, self.start_timestamp + next_period

    def merge_tasks_info(self):
        """Group the task lists of all answering YT clusters.

        :returns: ``{task_num: {time_mark: TaskState}}``. Each TaskState maps
            cluster name to that cluster's state of the task; it keeps the data
            proxy of the first cluster that reported the given (num, time_mark).
        """
        tasks = {}
        for yt_cluster in self.yt_clusters:
            if yt_cluster.tasks is None:
                continue
            for task_info in yt_cluster.tasks:
                num = task_info.num
                if num not in tasks:
                    tasks[num] = {}
                time_mark = task_info.time_mark
                if time_mark not in tasks[num]:
                    tasks[num][time_mark] = TaskState(task_info.data_proxy)
                tasks[num][time_mark].cluster_state[yt_cluster.name] = task_info.state
        return tasks

    def get_skip_older_than(self):
        """Return the naive UTC cut-off before which tasks are not re-synced.

        The cut-off lies one hour less than the smallest ``keep_completed_tasks``
        (in hours) among the READER_JOBS configured for this job, and at least
        one hour in the past.

        :raises ValueError: if no READER_JOBS entry belongs to ``self.job``.
        """
        keep_times = [
            config['keep_completed_tasks']
            for config in READER_JOBS.itervalues()
            if config['job'] == self.job
        ]
        if not keep_times:
            raise ValueError('No reader jobs configured for job {!r}'.format(self.job))
        min_keep_time = min(keep_times)
        return datetime.utcnow() - timedelta(hours=max(1, min_keep_time - 1))

    def make_new_reader_task(self, num, first_offsets_to_read):
        """Build reader task ``num`` from fresh LB offsets.

        :param num: number of the task to create.
        :param first_offsets_to_read: mapping ``{logtype: {toppart: offset}}``
            to start reading from (empty for the very first task).
        :returns: a task dict with ``num``, ``time_mark`` (ms since the epoch),
            ``state``, ``data`` and ``limited_share``; None when there is nothing
            to read and no logtype is new; False when no LB offsets are available.
        """
        self.logger.info('Making new task #%d', num)

        lbyt_utils.do_for_clusters(self.lb_clusters, 'read_offsets', self.config['logtypes'])
        lb_info = self.merge_lb_info()

        if not lb_info:
            self.logger.error('No offsets info available, can not make a new task')
            return False
        self.logger.info('Got offsets for logtypes: %s', ', '.join(lb_info))

        new_task_data = {}
        overall_chunks = 0
        average_limited_share = 0
        has_new_logtype = False
        for logtype in self.config['logtypes']:
            if logtype not in first_offsets_to_read:
                has_new_logtype = True  # force make task to record current logtype offsets
            logtype_chunks, logtype_limited_share = self.fill_task_data(
                new_task_data, logtype, first_offsets_to_read, lb_info,
            )
            average_limited_share += logtype_limited_share
            overall_chunks += logtype_chunks
        average_limited_share /= float(max(len(self.config['logtypes']), 1))
        if overall_chunks == 0 and not has_new_logtype:
            self.logger.warn('Nothing to read, no need to make a new task')
            return None
        time_mark = int(1000 * time.time())
        self.logger.info(
            'Task #%d generated with time mark %d, overall chunks to read: %d',
            num, time_mark, overall_chunks,
        )
        return {
            'num': num,
            'time_mark': time_mark,
            # With several YT clusters a task starts in "sync" until the quorum
            # checker promotes it.
            'state': 'sync' if len(self.config['yt']) > 1 else 'ready',
            'data': new_task_data,
            'limited_share': average_limited_share,
        }

    def _get_limitations_dict(self):
        """Return external offset caps ``{logtype: {toppart: offset}}``.

        Caps are read from the YT control state only when exactly one YT cluster
        is configured; otherwise no caps apply and ``{}`` is returned.
        """
        if len(self.yt_clusters) == 1:
            return self.yt_clusters[0].control_state.get(EXTERNAL_OFFSET_LIMITATION, {})
        return {}

    def fill_task_data(self, task_data, logtype, first_offsets_to_read, lb_info):
        """Fill ``task_data[logtype]`` with per-partition read ranges.

        Each entry is ``{'topic', 'partition', 'offset', 'limit'}`` (partition
        keys are ``"topic:partition"`` strings); entries are sorted by
        (topic, partition). Reads are capped per task by the logtype's
        ``max_offsets_per_task`` and per partition by a share of it. Chunk
        counts are reported as metrics.

        :returns: ``(chunks_to_read, limited_share)``: ``limited_share`` is 1 when
            the whole per-task cap was used, otherwise the number of partitions
            that hit the per-partition cap divided by the number of partitions
            whose accepted offset moved.
        """
        if logtype in first_offsets_to_read:
            topparts_to_process = first_offsets_to_read[logtype]
        else:
            self.logger.warn('Not found logtype %s in last task info (new logtype?)', logtype)
            topparts_to_process = []
        topparts_to_process = list(set(topparts_to_process).union(lb_info.get(logtype, [])))
        random.shuffle(topparts_to_process)  # read evenly when big bulk of data is available

        chunks_to_read = 0
        chunks_accepted = 0
        chunks_available = 0
        limited_share = 0
        task_data[logtype] = []
        # Partitions with something new to read share the per-task budget.
        partitions_count = sum(
            1 for toppart, toppart_info in lb_info.get(logtype, {}).iteritems()
            if toppart_info.get('accepted') != first_offsets_to_read.get(logtype, {}).get(toppart)
        )
        max_read = LOGTYPES[logtype]['max_offsets_per_task']
        max_read_part = max_read * PARTITIONS_OVERLAP / max(partitions_count, 1)
        limitations_dict = self._get_limitations_dict()
        for toppart in topparts_to_process:
            if logtype not in first_offsets_to_read:
                offset = lb_info[logtype][toppart]['accepted']  # read from current position
            else:
                offset = first_offsets_to_read[logtype].get(toppart)
                if offset is None:
                    self.logger.warn(
                        'Not found partition %s for logtype %s in last task info (new partition?)',
                        toppart, logtype,
                    )
                    offset = 0  # read from the beginning
            available_until = lb_info.get(logtype, {}).get(toppart, {}).get('available', 0)
            limitation = limitations_dict.get(logtype, {}).get(toppart)
            accepted_until = lb_info.get(logtype, {}).get(toppart, {}).get('accepted', 0)
            if limitation is not None and limitation < accepted_until:
                accepted_until = limitation
            # Read not less than zero and not more than any of current restrictions.
            limit = max(0, min(accepted_until - offset, max_read - chunks_to_read, max_read_part))
            chunks_available += max(0, available_until - offset)
            chunks_accepted += max(0, accepted_until - offset)
            chunks_to_read += limit
            topic, partition = toppart.split(':')
            task_data[logtype].append({
                'topic': topic,
                'partition': int(partition),
                'offset': offset,
                'limit': limit,
            })
            if limit == max_read_part:
                limited_share += 1
        task_data[logtype].sort(key=(lambda part: (part['topic'], part['partition'])))

        rel_diff = 0
        if chunks_to_read > 0:
            rel_diff = float(chunks_accepted - chunks_to_read) / chunks_to_read

        rel_unread = 0
        if chunks_to_read > 0:
            rel_unread = float(chunks_available - chunks_to_read) / chunks_to_read

        self.add_metric(logtype + '.available', chunks_available)
        self.add_metric(logtype + '.accepted', chunks_accepted)
        self.add_metric(logtype + '.read', chunks_to_read)
        self.add_metric(logtype + '.relative_diff', rel_diff)
        self.add_metric(logtype + '.relative_unread', rel_unread)
        if chunks_to_read == max_read:
            limited_share = 1
        else:
            limited_share /= float(max(partitions_count, 1))
        return chunks_to_read, limited_share

    def merge_lb_info(self):
        """Merge the offsets reported by all LB clusters into one view per partition.

        :returns: ``{logtype: {toppart: {'available': ..., 'accepted': ...}}}``.
            ``available`` is the largest reported offset. ``accepted`` equals it
            when only one replica reported. Otherwise it is the smallest of the
            remaining offsets that lies within ``max_gap`` of the largest of them.
        """
        lb_info = {}
        # Merge info from all LB clusters to one structure with list of offsets
        for lb_cluster in self.lb_clusters:
            for logtype, logtype_info in lb_cluster.offsets.iteritems():
                topparts = lb_info.setdefault(logtype, {})
                for toppart, offset in logtype_info.iteritems():
                    offsets = topparts.setdefault(toppart, [])
                    offsets.append(offset)
        # Select the best candidate from each list
        for logtype, topparts in lb_info.iteritems():
            # Empiric formula to select not very far replica
            max_gap = (
                LOGTYPES[logtype]['max_offsets_per_task']
                * LB_MERGE_GAP_COEFFICIENT
                * PARTITIONS_OVERLAP
                / max(len(topparts), 1)
            )
            # Values of existing keys are replaced in place; the key set is unchanged.
            for toppart in topparts:
                known_offsets = sorted(topparts[toppart])
                topparts[toppart] = {'available': known_offsets.pop()}
                if not known_offsets:
                    # No more info so accepting what we only know
                    topparts[toppart]['accepted'] = topparts[toppart]['available']
                else:
                    # Several replicas reported. The best value is the minimum of the
                    # regular replicas, ignoring values too far behind the rest
                    # (e.g. from a DC that has just returned after downtime).
                    min_acceptable = known_offsets[-1] - max_gap
                    self.logger.debug(
                        'For partition %s (logtype %s) we know %s (gap is %d so %d is acceptable)',
                        toppart, logtype, known_offsets, max_gap, min_acceptable,
                    )
                    topparts[toppart]['accepted'] = min(
                        candidate
                        for candidate in known_offsets
                        if candidate >= min_acceptable
                    )
                self.logger.debug(
                    'Accepted %d of %d for partition %s (logtype %s)',
                    topparts[toppart]['accepted'],
                    topparts[toppart]['available'],
                    toppart, logtype,
                )
        return lb_info


# Module-level alias exposing the task class; presumably the name the Sandbox
# task loader looks up — confirm.
__Task__ = YabsLbYtTaskManager
