from __future__ import absolute_import, division, print_function

from sandbox.projects.yabs.sandbox_task_tracing.util import frozendict

from sandbox.projects.yabs.sandbox_task_tracing.writers.abstract import AbstractTraceWriter, CallSequenceError


class DictTraceWriter(AbstractTraceWriter):
    '''
    Trace writer storing trace records in provided dict.

    Primarily for tests.

    :ivar int task_id: task id set by `set_task`, ``None`` until then
    :ivar int iteration_id: iteration id set by `set_task`, ``None`` until then
    :ivar task_info: task info passed to `set_task`, ``None`` until then
    :ivar dict[int, TraceRecord] ids_record: records stored so far, keyed by record id
    '''

    def __init__(self, ids_record):
        '''
        Initializes new dict trace writer instance.

        :param dict ids_record: dict that written records are stored into;
            it is updated in place, so the caller can inspect it afterwards
        '''
        self.task_id = None
        self.iteration_id = None
        # Initialized here (not only in `set_task`) so the attribute always
        # exists, consistent with `task_id` / `iteration_id`.
        self.task_info = None
        self.ids_record = ids_record

    def set_task(self, task_id, iteration_id, task_info=frozendict()):
        '''
        Binds this writer to a task iteration; must be called exactly once,
        before `write_records`.

        :param task_id: task id, must not be ``None``
        :param iteration_id: iteration id, must not be ``None``
        :param task_info: extra task information, stored as-is
        :raises CallSequenceError: if `set_task` was already called
        '''
        if self.task_id is not None or self.iteration_id is not None:
            raise CallSequenceError('Method `set_task` should only be called once')

        # NOTE(review): asserts are stripped under `python -O`; kept to preserve
        # the exception type callers may currently observe.
        assert task_id is not None
        assert iteration_id is not None

        self.task_id = task_id
        self.iteration_id = iteration_id
        self.task_info = task_info

    def write_records(self, records):
        '''
        Stores the given records into `ids_record`, overwriting entries
        with the same key.

        :param records: iterable of trace records
        :raises CallSequenceError: if `set_task` has not been called yet
        '''
        if self.task_id is None or self.iteration_id is None:
            raise CallSequenceError('Method `set_task` must be called first')
        # NOTE(review): the key is `record.record`, while the class docstring
        # describes keys as record ids — confirm against TraceRecord's fields.
        self.ids_record.update((record.record, record) for record in records)

    def flush(self, timeout=None):
        '''
        No-op: records are written to the dict immediately in `write_records`.

        :param timeout: ignored
        '''
        # this writer is synchronous
        pass
