'''
Sandbox task tracing library implementation details: trace records buffer.

Symbols defined here are not part of public interface.
Do not use them directly or rely on specifics of their behaviour.
'''

from __future__ import absolute_import, division, print_function

import itertools
import logging
from six.moves import queue
import sys
import threading
import time

from sandbox.projects.yabs.sandbox_task_tracing.defaults import DEFAULTS
from sandbox.projects.yabs.sandbox_task_tracing.info import jsonified
from sandbox.projects.yabs.sandbox_task_tracing.util import coalesce
from sandbox.projects.yabs.sandbox_task_tracing.exceptions import BufferTimeout

from sandbox.projects.yabs.sandbox_task_tracing.writers.record import TraceRecord


# Logger named after this module.
logger = logging.getLogger(name=__name__)


# Unbounded FIFO shared by add_trace_record (producer) and TraceWritingThread (consumer).
TRACE_RECORD_QUEUE = queue.Queue(maxsize=0)


def add_trace_record(record, parent, started, finished, type, info):
    '''
    Creates and queues new trace record.

    Parameters coincide with those of TraceRecord constructor; ``info`` is
    passed through ``jsonified`` first. The record is put into the
    module-level TRACE_RECORD_QUEUE with ``put_nowait`` (never blocks).

    Note: the ``type`` parameter shadows the builtin inside this function; the
    name is kept because callers may pass it by keyword.
    '''
    # Arguments are forwarded explicitly instead of via ``**locals()``: with
    # ``locals()`` any local variable added here later would silently be
    # passed to the TraceRecord constructor as well.
    trace_record = TraceRecord(
        record=record,
        parent=parent,
        started=started,
        finished=finished,
        type=type,
        info=jsonified(info),
    )
    TRACE_RECORD_QUEUE.put_nowait(trace_record)


def clear_trace_record_queue():
    '''
    Clears the trace record queue.

    Removes, without blocking, every item currently in TRACE_RECORD_QUEUE
    (records and control tokens alike). Does not raise if another consumer,
    such as a running TraceWritingThread, takes items concurrently.
    '''
    # Checking empty() and then calling get_nowait() is racy: another consumer
    # may take the last item in between, making get_nowait() raise
    # queue.Empty. Try the removal and stop when the queue reports empty.
    while True:
        try:
            TRACE_RECORD_QUEUE.get_nowait()
        except queue.Empty:
            break


class TraceWritingThread(threading.Thread):
    '''
    Thread writing queued trace records in background.

    Items are taken from the module-level trace record queue. Consecutive
    records are grouped into batches (everything that can be taken without
    blocking) and handed to the writer's ``write_records``. Two control tokens
    may be put into the same queue:

    * ``FLUSH_TOKEN``  -- write everything queued before it, then call the
      writer's ``flush``;
    * ``FINISH_TOKEN`` -- the same as ``FLUSH_TOKEN``, then stop the thread.

    An exception raised in the thread is logged and stored in ``exc_info``
    (a ``sys.exc_info()`` triple); ``flush`` re-raises it in the caller.
    '''

    # Control tokens; recognised by identity only (see ``_is_token``) so no
    # record can be mistaken for one, whatever its ``__eq__`` does.
    FLUSH_TOKEN = object()
    FINISH_TOKEN = object()

    def __init__(self, writer_factory, set_task_kwargs):
        '''
        :param writer_factory: zero-argument callable returning the trace
            writer; it is called from ``run``, i.e. inside the thread.
        :param set_task_kwargs: keyword arguments for the writer's
            ``set_task``, called before any records are written; they are
            passed through ``jsonified`` here.
        '''
        super(TraceWritingThread, self).__init__(name=type(self).__name__)
        # Daemon thread: it does not keep the interpreter alive by itself.
        self.daemon = True
        self.queue = TRACE_RECORD_QUEUE
        self.writer_factory = writer_factory
        self.set_task_kwargs = jsonified(set_task_kwargs)
        # sys.exc_info() triple of the error that stopped the thread, if any.
        self.exc_info = (None, None, None)

    @classmethod
    def _is_token(cls, item):
        '''Returns True if ``item`` is one of the control tokens (identity check).'''
        return item is cls.FLUSH_TOKEN or item is cls.FINISH_TOKEN

    def run(self):
        try:
            trace_writer = self.writer_factory()
            trace_writer.set_task(**self.set_task_kwargs)
            while True:
                # Block until there is at least one item to process.
                record = self.queue.get()
                if not self._is_token(record):
                    # Drain whatever is immediately available into one batch,
                    # stopping when the queue is empty or at a control token.
                    records = [record]
                    while True:
                        try:
                            record = self.queue.get_nowait()
                        except queue.Empty:
                            break
                        if self._is_token(record):
                            break
                        records.append(record)
                    trace_writer.write_records(records)
                # Here ``record`` is either the last record of the batch or the
                # control token that ended it.
                if self._is_token(record):
                    trace_writer.flush()
                if record is self.FINISH_TOKEN:
                    break
        except BaseException:
            # Deliberately broad: this is the thread's top level, so nothing
            # else would see the error. It is stored for ``flush`` to re-raise.
            self.exc_info = sys.exc_info()
            logger.exception('Error in trace writing thread')

    def flush(self, timeout=None, __poll_period_seconds=0.1):
        '''
        Waits until everything queued before this call has been written and
        flushed by the thread.

        :param timeout: seconds to wait; defaults to
            ``DEFAULTS['entry_point_spec']['flush_timeout_seconds']``.
        :param __poll_period_seconds: seconds between queue checks (name-mangled
            by Python, so effectively positional-only / private).
        :raises: the exception that stopped the thread, if it died with one;
            ``BufferTimeout`` if the queue is not drained in time, or if the
            thread has stopped while items are still queued.
        '''
        # The wait is measured in poll iterations.
        max_iterations = coalesce(timeout, DEFAULTS['entry_point_spec']['flush_timeout_seconds']) / __poll_period_seconds
        # Queue.join does not have a timeout argument, so have to be creative.
        # The first token makes the thread write and flush everything queued
        # before it; the second one is taken from the queue only after that
        # flush has returned, so an empty queue means the flush is complete.
        self.queue.put_nowait(self.FLUSH_TOKEN)
        self.queue.put_nowait(self.FLUSH_TOKEN)
        for iteration_index in itertools.count():
            if not self.is_alive():
                if self.exc_info[1] is not None:
                    raise self.exc_info[1]
                if not self.queue.empty():
                    raise BufferTimeout()
                break
            if self.queue.empty():
                break
            if iteration_index >= max_iterations:
                raise BufferTimeout()
            time.sleep(__poll_period_seconds)

    def finish(self, timeout=None, __poll_period_seconds=0.1):
        '''
        Asks the thread to write and flush everything queued so far and to
        stop, then waits for it to exit.

        Unlike ``flush``, this does not re-raise an exception stored in
        ``exc_info``.

        :param timeout: seconds to wait for the thread to exit; defaults to
            ``DEFAULTS['entry_point_spec']['finish_timeout_seconds']``.
        :param __poll_period_seconds: no longer used (waiting is done with
            ``Thread.join``); kept so existing positional calls keep working.
        :raises BufferTimeout: if the thread is still alive after ``timeout``.
        '''
        timeout_seconds = coalesce(timeout, DEFAULTS['entry_point_spec']['finish_timeout_seconds'])
        # NOTE(review): the queue is module-global. If this thread has already
        # stopped, the token stays queued and will stop the next
        # TraceWritingThread reading the queue unless the queue is cleared first.
        self.queue.put_nowait(self.FINISH_TOKEN)
        # Unlike Queue.join, Thread.join accepts a timeout and returns as soon
        # as the thread exits. A thread that was never started or has already
        # stopped is not alive, so there is nothing to wait for.
        if self.is_alive():
            self.join(timeout_seconds)
        if self.is_alive():
            raise BufferTimeout()
