'''
Sandbox task tracing library implementation details: context managers.

Symbols defined here are not part of public interface.
Do not use them directly or rely on specifics of their behaviour.
'''

from __future__ import absolute_import, division, print_function

import itertools
import logging
import threading
import traceback

import sandbox.sdk2

from sandbox.projects.yabs.sandbox_task_tracing.defaults import DEFAULTS
from sandbox.projects.yabs.sandbox_task_tracing.info import (
    exception_info,
    resource_info,
    stack_info,
    task_info,
)
from sandbox.projects.yabs.sandbox_task_tracing.util import (
    coalesce,
    frozendict,
    microseconds_from_utc_iso,
    pairwise,
    time_microseconds,
)

from sandbox.projects.yabs.sandbox_task_tracing.impl.buffer import add_trace_record, clear_trace_record_queue, TraceWritingThread


logger = logging.getLogger(__name__)


def new_record_id(__counter=itertools.count(start=1), __lock=threading.Lock()):
    '''
    Return the next trace record id.

    Both default arguments are evaluated once, when the function is defined,
    so every call that relies on the defaults draws from one shared counter
    starting at 1. The lock serialises access to that counter; this is not
    strictly required in CPython, but is kept to be safe elsewhere.
    '''
    __lock.acquire()
    try:
        record_id = next(__counter)
    finally:
        __lock.release()
    return record_id


# Guards creation, flushing and teardown of TRACE_WRITING_THREAD.
# It is reentrant (RLock); presumably so that code already holding it can call
# helpers that acquire it again -- TODO confirm.
INITIALIZATION_LOCK = threading.RLock()
# The module-wide TraceWritingThread, or None while tracing is not initialized
# (or after an entry point has finished it on exit).
TRACE_WRITING_THREAD = None
# Record id of the current task iteration. It is the first id drawn from
# new_record_id(), and it is the default parent of audit records and of the
# outermost TraceContext in every thread.
ITERATION_RECORD_ID = new_record_id()


def flush_trace(timeout=None, ignore_errors=None):
    '''
    Flush the global trace writing thread, if tracing has been initialized.

    :param timeout: passed through to the writing thread's ``flush``.
    :param ignore_errors: whether a flush failure is swallowed after logging;
        None falls back to ``DEFAULTS['entry_point_spec']['ignore_exit_errors']``.
    :raises: whatever ``flush`` raised, unless errors are ignored.
    '''
    with INITIALIZATION_LOCK:
        writer = TRACE_WRITING_THREAD
        if writer:
            try:
                writer.flush(timeout=timeout)
            except:
                logger.exception('Flush error')
                swallow_errors = coalesce(ignore_errors, DEFAULTS['entry_point_spec']['ignore_exit_errors'])
                if not swallow_errors:
                    raise


class UsageError(Exception):
    '''Raised when the tracing helpers are misused (wrong entry point or task, re-entry, invalid nesting).'''


class EmptyContextManager(object):
    '''
    Context manager that does nothing.

    Entering yields the manager itself; exiting never suppresses exceptions.
    '''

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # A falsy result lets any exception raised in the body propagate.
        return None


class EntryPoint(object):
    '''
    Context manager wrapping a Sandbox task entry point with tracing.

    On enter it starts the module-wide trace writing thread (unless one is
    already running), emits 'audit' records for the status changes of the
    current iteration that preceded EXECUTING, and replaces
    ``trace_current_audit_record`` with a context manager for the EXECUTING
    audit record. On exit it emits the 'iteration' record, flushes (and,
    if configured, finishes) the writing thread, and optionally clears the
    record queue.

    ``spec`` is combined with ``DEFAULTS['entry_point_spec']`` (presumably with
    keys from ``spec`` taking precedence -- see frozendict's ``+``). Keys read
    here: ``ignore_initialization_errors``, ``ignore_exit_errors``,
    ``finish_on_exit``, ``clear_queue_on_exit``, ``flush_timeout_seconds``,
    ``finish_timeout_seconds``.
    '''

    def __init__(
        self,
        task,
        writer_factory,
        spec=frozendict(),
    ):
        self.task = task
        self.writer_factory = writer_factory
        self.spec = frozendict(spec)
        # Raw task requirements; filled in during initialization.
        self.iteration_requirements = None
        # Stays a no-op until the audit records have been read successfully.
        self.trace_current_audit_record = EmptyContextManager()
        # Microsecond timestamp of the earliest audit record of this iteration, or None.
        self.iteration_started = None

    def __effective_spec(self):
        # The entry point spec combined with the library defaults. Every option,
        # including the timeouts, is read through this so that defaults apply uniformly.
        return DEFAULTS['entry_point_spec'] + self.spec

    def __write_earlier_audit_records(self, audit_records):
        '''
        Emit 'audit' records for the status changes of the current iteration that
        happened before its EXECUTING status change.

        The current iteration is the tail of the status-change history following
        the previous EXECUTING record, if any. Each emitted record spans from its
        own time to the time of the next status change.

        :raises UsageError: if the latest status change is not EXECUTING.
        '''
        # Filter out non-status-change audit records (e.g. manual priority changes)
        audit_records = [audit_record for audit_record in audit_records if audit_record['status']]

        if not audit_records or audit_records[-1]['status'] != 'EXECUTING':
            logger.warning('audit_records=%s', audit_records)
            raise UsageError('Last audit status is not EXECUTING (perhaps the entry point is not on_execute)')

        # Walk backwards from the current EXECUTING record as (later, earlier) pairs,
        # stopping at the previous EXECUTING record, which belongs to an earlier iteration.
        earlier_audit_record_pairs = list(itertools.takewhile(
            lambda audit_record_pair: audit_record_pair[1]['status'] != 'EXECUTING',
            pairwise(reversed(audit_records)),
        ))

        # Loop-invariant; .get() tolerates requirements without a 'semaphores' key
        # instead of failing part-way through writing the records.
        semaphores = self.iteration_requirements.get('semaphores')
        for next_audit_record, audit_record in reversed(earlier_audit_record_pairs):
            info = dict(record=audit_record)
            if audit_record['status'] == 'ENQUEUED' and semaphores:
                info.update(semaphores=semaphores)
            add_trace_record(
                record=new_record_id(),
                parent=ITERATION_RECORD_ID,
                started=microseconds_from_utc_iso(audit_record['time']),
                finished=microseconds_from_utc_iso(next_audit_record['time']),
                type='audit',
                info=info,
            )

        # The oldest pre-EXECUTING record of this iteration, or the EXECUTING record itself.
        earliest_record = earlier_audit_record_pairs[-1][1] if earlier_audit_record_pairs else audit_records[-1]
        self.trace_current_audit_record = TraceCurrentAuditRecord(audit_records[-1])
        self.iteration_started = microseconds_from_utc_iso(earliest_record['time'])

    # called under lock
    def __initialize(self):
        '''Start the trace writing thread and write earlier audit records; no-op if already started.'''
        global TRACE_WRITING_THREAD
        if TRACE_WRITING_THREAD:
            logger.warning('Tracing has already been initialized')
            return

        task_full_info = task_info(self.task, format_='full')

        if self.task.id != task_full_info['id']:
            raise UsageError('Entry point called for task object {}, but current task is {}'.format(
                self.task.id,
                task_full_info['id'],
            ))

        self.iteration_requirements = task_full_info['_raw']['requirements']

        thread = TraceWritingThread(self.writer_factory, dict(
            task_id=self.task.id,
            iteration_id=sandbox.sdk2.Task.current.agentr.iteration,
            task_info=task_full_info,
        ))
        thread.start()
        TRACE_WRITING_THREAD = thread

        # If we fail here, the writer would still be initialized,
        # but some audit records might be lost.
        # There is no easy way to avoid this without risking duplication.
        self.__write_earlier_audit_records(audit_records=task_full_info['audit'])

    def __enter__(self):
        try:
            with INITIALIZATION_LOCK:
                self.__initialize()
        except:
            logger.exception('Error initializing trace writer')
            if not self.__effective_spec()['ignore_initialization_errors']:
                raise
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        global TRACE_WRITING_THREAD
        try:
            # Only an entry point that read the audit records knows when the iteration started.
            if self.iteration_started is not None:
                add_trace_record(
                    record=ITERATION_RECORD_ID,
                    parent=0,
                    started=self.iteration_started,
                    finished=time_microseconds(),
                    type='iteration',
                    info=dict(requirements=self.iteration_requirements),
                )
            try:
                with INITIALIZATION_LOCK:
                    if TRACE_WRITING_THREAD:
                        try:
                            TRACE_WRITING_THREAD.flush(timeout=self.__effective_spec().get('flush_timeout_seconds'))
                        finally:
                            spec = self.__effective_spec()
                            if spec['finish_on_exit']:
                                # NOTE(review): if finish() raises, the global keeps pointing at
                                # the old thread and later entry points skip initialization.
                                TRACE_WRITING_THREAD.finish(timeout=spec.get('finish_timeout_seconds'))
                                TRACE_WRITING_THREAD = None
            finally:
                if self.__effective_spec()['clear_queue_on_exit']:
                    clear_trace_record_queue()
        except:
            logger.exception('Error leaving entry point')
            if not self.__effective_spec()['ignore_exit_errors']:
                raise


class TraceContext(object):
    '''
    Context manager that emits one trace record covering its ``with`` body.

    Records nest per thread: a context's parent is the innermost context that
    was entered (and not yet exited) in the same thread, or ITERATION_RECORD_ID
    when there is none. Each instance may be entered only once. If the body
    raises, exception details are stored under ``info['exception']``.
    '''

    # Per-thread stack of open record ids; its bottom element is ITERATION_RECORD_ID.
    __thread_local_storage = threading.local()

    def __init__(self, type_, info=frozendict(), stack_info_enabled=None, stack_info_spec=frozendict()):
        self.record = None
        self.parent = None
        self.started = None
        self.finished = None
        self.type = type_
        self.info = dict(info)  # make a shallow copy of info since we are going to change it

        self.stack_info_enabled = coalesce(stack_info_enabled, DEFAULTS['stack_info_enabled'])
        self.stack_info_spec = stack_info_spec

    def __enter__(self):
        if self.record is not None:
            raise UsageError('This context manager can only be entered once')

        # Collect stack info before touching the per-thread stack. If this raised after
        # the push, __exit__ would never run and the dangling record id would become the
        # parent of every later record in this thread.
        if self.stack_info_enabled:
            self.info.update(stack=stack_info(traceback.extract_stack(), self.stack_info_spec))

        storage = self.__thread_local_storage
        if not hasattr(storage, 'stack'):
            storage.stack = [ITERATION_RECORD_ID]
        stack = storage.stack
        self.record = new_record_id()
        self.parent = stack[-1]
        stack.append(self.record)

        self.started = time_microseconds()

        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.finished = time_microseconds()

        stack = getattr(self.__thread_local_storage, 'stack', [])
        if stack[-2:] != [self.parent, self.record]:
            # The UsageError below replaces any in-flight exception, so log that exception
            # first. Take it from the __exit__ arguments rather than sys.exc_info().
            logger.error(
                'Current exception, if any:',
                exc_info=(exception_type, exception_value, exception_traceback) if exception_type else None,
            )
            raise UsageError('Invalid nesting: expected={}, found={}'.format([self.parent, self.record], stack[-2:]))
        stack.pop()

        if exception_type:
            self.info.update(exception=exception_info(
                exception_type,
                exception_value,
                exception_traceback,
                self.stack_info_enabled,
                self.stack_info_spec,
            ))

        add_trace_record(
            record=self.record,
            parent=self.parent,
            started=self.started,
            finished=self.finished,
            type=self.type,
            info=self.info,
        )


class TraceCurrentAuditRecord(TraceContext):
    '''
    'audit' trace context for the audit record that is current when an entry point runs.

    The record is stored under ``info['record']``. The trace's start time is the
    audit record's own 'time' rather than the moment the context is entered.
    Stack info is never collected.
    '''

    def __init__(self, audit_record):
        super(TraceCurrentAuditRecord, self).__init__(
            type_='audit',
            info={'record': audit_record},
            stack_info_enabled=False,
        )

    def __enter__(self):
        entered = super(TraceCurrentAuditRecord, self).__enter__()
        audit_time = self.info['record']['time']
        self.started = microseconds_from_utc_iso(audit_time)
        return entered


class TraceNewResources(TraceContext):
    '''
    Trace context that also reports resources registered inside its body.

    Values added through append()/extend() are described with resource_info()
    and stored under ``info['resources']`` on exit. The base trace record is
    emitted even if describing them fails.
    '''

    def __init__(self, *args, **kwargs):
        super(TraceNewResources, self).__init__(*args, **kwargs)
        self.resources = []

    def append(self, value):
        self.resources.append(value)

    def extend(self, values):
        self.resources.extend(values)

    def __exit__(self, exception_type, exception_value, exception_traceback):
        try:
            self.info['resources'] = list(map(resource_info, self.resources))
        finally:
            super(TraceNewResources, self).__exit__(exception_type, exception_value, exception_traceback)


class TraceNewTasks(TraceContext):
    '''
    Trace context that also reports tasks registered inside its body.

    Values added through append()/extend() are described with task_info()
    and stored under ``info['tasks']`` on exit. The base trace record is
    emitted even if describing them fails.
    '''

    def __init__(self, *args, **kwargs):
        super(TraceNewTasks, self).__init__(*args, **kwargs)
        self.tasks = []

    def append(self, value):
        self.tasks.append(value)

    def extend(self, values):
        self.tasks.extend(values)

    def __exit__(self, exception_type, exception_value, exception_traceback):
        try:
            self.info['tasks'] = list(map(task_info, self.tasks))
        finally:
            super(TraceNewTasks, self).__exit__(exception_type, exception_value, exception_traceback)
