from sandbox import sdk2
from sandbox.projects.adfox.adfox_ui.util.duty_tools import DutyTools
from sandbox.projects.adfox.adfox_ui.util.date import current_mysql_datetime, mysql_string_to_datetime, now
import sandbox.common.types.task as ctt
import uuid
import logging

# Values written to the 'result' field of a metric row.
TASK_RESULT_OK = "ok"
TASK_RESULT_FAIL = "fail"
# Track name of the timer that measures the whole task's processing time.
TASK_DURATION_TRACK = "task-duration"


class Analyzable(object):
    """Mixin that measures how long a Sandbox task actually works and reports it.

    A "track" is a named timer stored in ``Context.metrics_rows`` under the key
    ``'<task name>|<track name>'`` (see ``get_current_task_track_code``).  Time
    between ``track_pause`` and the following ``track_resume`` is excluded from
    the row's ``processing_time``.  When the task finishes, the collected rows are
    sent through ``DutyTools.send_common_metrics`` unless the ``send_metrics``
    parameter is off.

    The ``metrics_on_*`` methods are meant to be called from the matching task
    lifecycle hooks (``AnalyzableTask`` wires them up).
    """

    class Context(sdk2.Task.Context):
        # Identifier of the whole run; inherited from the parent task when it has one.
        run_id = None
        metrics_release_id = None
        metrics_id = None
        # track code -> metric row dict (task_id, run_id, started_at, processing_time, ...)
        metrics_rows = {}
        # track code -> {'paused_at': <mysql datetime>[, 'resumed_at': <mysql datetime>]}
        metrics_pauses = {}

    class Parameters(sdk2.Parameters):
        with sdk2.parameters.Group('Metrics settings', description='Metrics settings') as metrics_settings:
            send_metrics = sdk2.parameters.Bool(
                'Send metrics',
                description='Send metrics',
                default=True
            )
            metrics_id = sdk2.parameters.String(
                'Metrics ID',
                description='Unique identifier of the release. If empty takes parent\'s metrics_id if exists or generates'
            )
            metrics_release_id = sdk2.parameters.String(
                'Release ID',
                description='Symbolic representation of the release. '
                            'If empty takes parent\'s metrics_release_id if exists or generates'
            )

    # ----- lifecycle entry points -------------------------------------------

    def metrics_on_execute(self):
        """Call from ``on_execute``: start the duration track once, resume it on later iterations."""
        # memoize_stage makes the start happen only on the first execution.
        with self.memoize_stage.task_duration_metrics:
            self.track_task_start()

        # Any later iteration means the task is coming back from a wait/pause.
        if self.agentr.iteration > 1:
            self.track_task_resume()

    def metrics_on_wait(self, prev_status, status):
        """Call from ``on_wait``: pause the duration track while the task is idle."""
        self.track_task_pause()

    def metrics_on_finish(self, prev_status, status):
        """Call from ``on_finish``: close the duration track for FAILURE/SUCCESS and send metrics."""
        if status == ctt.Status.FAILURE:
            self.track_task_finish(TASK_RESULT_FAIL)
        elif status == ctt.Status.SUCCESS:
            self.track_task_finish()

    def metrics_on_failure(self):
        """Call from a failure hook: close the duration track as failed and send metrics."""
        self.track_task_finish(TASK_RESULT_FAIL)

    def metrics_on_break(self, prev_status, status):
        """Call from ``on_break``: close the duration track as failed and send metrics."""
        self.track_task_finish(TASK_RESULT_FAIL)

    def metrics_on_timeout(self):
        """Call from a timeout hook: close the duration track as failed and send metrics."""
        self.track_task_finish(TASK_RESULT_FAIL)

    # ----- lazily initialised context values --------------------------------

    @property
    def metrics_pauses(self):
        """Pause bookkeeping dict stored in the context; created on first access."""
        if not self.Context.metrics_pauses:
            self.Context.metrics_pauses = {}
        return self.Context.metrics_pauses

    @property
    def metrics_rows(self):
        """Metric rows dict stored in the context; created on first access."""
        if not self.Context.metrics_rows:
            self.Context.metrics_rows = {}
        return self.Context.metrics_rows

    @property
    def run_id(self):
        """Run identifier: the parent's ``run_id`` if available, otherwise a new UUID4 string."""
        if not self.Context.run_id:
            if self.parent and hasattr(self.parent, 'run_id'):
                self.Context.run_id = self.parent.run_id
            else:
                self.Context.run_id = str(uuid.uuid4())

        return self.Context.run_id

    @property
    def metrics_id(self):
        """External metrics id: parameter, else the parent's ``metrics_id``, else a new UUID4 string."""
        if not self.Context.metrics_id:
            if self.Parameters.metrics_id:
                self.Context.metrics_id = self.Parameters.metrics_id
            elif self.parent and hasattr(self.parent, 'metrics_id'):
                self.Context.metrics_id = self.parent.metrics_id
            else:
                self.Context.metrics_id = str(uuid.uuid4())

        return self.Context.metrics_id

    def get_metrics_release_id(self):
        """Fallback release id used when neither the parameter nor the parent supplies one.

        Subclasses may override this to compute a meaningful value.
        """
        return 'not defined'

    @property
    def metrics_release_id(self):
        """Release id: parameter, else the parent's ``metrics_release_id``, else ``get_metrics_release_id()``."""
        if not self.Context.metrics_release_id:
            if self.Parameters.metrics_release_id:
                self.Context.metrics_release_id = self.Parameters.metrics_release_id
            elif self.parent and hasattr(self.parent, 'metrics_release_id'):
                self.Context.metrics_release_id = self.parent.metrics_release_id
            else:
                self.Context.metrics_release_id = self.get_metrics_release_id()

        return self.Context.metrics_release_id

    # ----- generic track operations -----------------------------------------

    @staticmethod
    def _metrics_seconds_since(mysql_datetime):
        """Return whole seconds from ``mysql_datetime`` (a MySQL datetime string) until ``now()``.

        Uses ``total_seconds()`` rather than ``timedelta.seconds``: the latter
        drops the days component, so intervals longer than 24h would be
        under-counted.  Negative intervals (clock skew) are clamped to 0.
        """
        elapsed = now() - mysql_string_to_datetime(mysql_datetime)
        return max(0, int(elapsed.total_seconds()))

    def track_start(self, metric_type):
        """Create (or overwrite) the metric row for ``metric_type`` with ``processing_time`` = 0."""
        metric_type = self.get_current_task_track_code(metric_type)
        logging.info('Track start: {}'.format(metric_type))
        self.metrics_rows[metric_type] = {
            'task_id': self.id,
            'task_type': '{}'.format(self.type),
            'run_id': self.run_id,
            'external_id': self.metrics_id,
            'metric_type': metric_type,
            'processing_time': 0,
            'started_at': current_mysql_datetime()
        }

    def track_finish(self, metric_type, result=TASK_RESULT_OK):
        """Close the track: record ``result``/``finished_at`` and add the last running interval.

        Does nothing when ``result`` is not a known result value or the track was
        never started.  If the track is currently paused, no time is added: the
        running interval up to the pause was already counted by ``track_pause``,
        and the time since then is idle.
        """
        metric_type = self.get_current_task_track_code(metric_type)
        if result not in (TASK_RESULT_OK, TASK_RESULT_FAIL) or metric_type not in self.metrics_rows:
            return

        logging.info('Track finish: {}'.format(metric_type))
        row = self.metrics_rows[metric_type]
        row.update({
            'result': result,
            'finished_at': current_mysql_datetime()
        })

        pause = self.metrics_pauses.get(metric_type)
        if pause is not None and 'resumed_at' not in pause:
            # Finished while paused: counting from started_at would double-count
            # time that track_pause already added.
            return

        last_resume = pause['resumed_at'] if pause is not None else row['started_at']
        row['processing_time'] += self._metrics_seconds_since(last_resume)

    def track_pause(self, metric_type):
        """Add the current running interval to ``processing_time`` and mark the track as paused.

        Logs an error and changes nothing if the track is already paused.  Does
        nothing if the track was never started.
        """
        metric_type = self.get_current_task_track_code(metric_type)
        if metric_type not in self.metrics_rows:
            return

        pause = self.metrics_pauses.get(metric_type)
        # A pause entry without 'resumed_at' means we are still paused.
        if pause is not None and 'resumed_at' not in pause:
            logging.error('Cannot pause {}: already paused'.format(metric_type))
            return

        logging.info('Track pause: {}'.format(metric_type))
        last_resume = pause['resumed_at'] if pause is not None else self.metrics_rows[metric_type]['started_at']
        self.metrics_rows[metric_type]['processing_time'] += self._metrics_seconds_since(last_resume)
        # Replace (not update) so a stale 'resumed_at' cannot survive into this pause.
        self.metrics_pauses[metric_type] = {
            'paused_at': current_mysql_datetime()
        }

    def track_resume(self, metric_type):
        """Record ``resumed_at`` for a started track that has a pause entry; otherwise log an error."""
        metric_type = self.get_current_task_track_code(metric_type)
        if metric_type in self.metrics_rows and metric_type in self.metrics_pauses:
            logging.info('Track resume: {}'.format(metric_type))
            self.metrics_pauses[metric_type]['resumed_at'] = current_mysql_datetime()
        else:
            logging.error('Cannot resume {}: already running'.format(metric_type))

    def get_current_task_name(self):
        """Name used as the prefix of track codes; the task type by default."""
        return self.type

    def get_current_task_track_code(self, track_name):
        """Return the storage key for ``track_name``: ``'<task name>|<track name>'``."""
        return '{}|{}'.format(self.get_current_task_name(), track_name)

    # ----- id initialisation -------------------------------------------------

    def init_metrics_on_execute(self):
        """Resolve and store metrics ids when they could not be set up in ``on_enqueue``."""
        if not self.Context.metrics_id:
            self.Context.metrics_id = self.metrics_id
        if not self.Context.metrics_release_id:
            self.Context.metrics_release_id = self.metrics_release_id

    def on_task_enqueue(self):
        """Resolve metrics ids at enqueue time and persist them into both Parameters and Context."""
        if not self.Parameters.metrics_id:
            self.Parameters.metrics_id = self.metrics_id

        if not self.Context.metrics_id:
            self.Context.metrics_id = self.Parameters.metrics_id

        if not self.Parameters.metrics_release_id:
            self.Parameters.metrics_release_id = self.metrics_release_id

        if not self.Context.metrics_release_id:
            self.Context.metrics_release_id = self.Parameters.metrics_release_id

    # ----- task-duration track shortcuts -------------------------------------

    def track_task_start(self):
        """Start the whole-task duration track."""
        self.track_start(TASK_DURATION_TRACK)

    def track_task_pause(self):
        """Pause the whole-task duration track (the task is waiting, not working)."""
        self.track_pause(TASK_DURATION_TRACK)

    def track_task_resume(self):
        """Resume the whole-task duration track."""
        self.track_resume(TASK_DURATION_TRACK)

    def track_task_finish(self, result=TASK_RESULT_OK):
        """Finish the whole-task duration track with ``result`` and send all collected metrics."""
        self.track_finish(TASK_DURATION_TRACK, result)
        self.send_collected_metrics()

    def send_collected_metrics(self):
        """Stamp every collected row with the release id and send them via ``DutyTools``.

        Does nothing when there are no rows or the ``send_metrics`` parameter is off.
        """
        if not self.metrics_rows or not self.Parameters.send_metrics:
            return

        release_id = self.metrics_release_id
        for row in self.metrics_rows.values():
            row['release_id'] = release_id

        logging.info('Sending collected metrics')
        logging.info('Collected metrics: {}'.format(self.metrics_rows))
        duty_tools = DutyTools()
        # Pass a concrete list rather than a dict view: on Python 3 a view is
        # neither indexable nor JSON-serializable, in case the sender needs either.
        duty_tools.send_common_metrics(list(self.metrics_rows.values()))


class AnalyzableTask(sdk2.Task, Analyzable):
    """Ready-made sdk2 task that forwards its lifecycle hooks to the ``Analyzable`` metrics mixin."""

    class Parameters(Analyzable.Parameters):
        """Task parameters: exactly the metrics settings from ``Analyzable``."""

    class Context(Analyzable.Context):
        """Task context: exactly the metrics bookkeeping fields from ``Analyzable``."""

    # Hooks are listed in the order a task normally passes through them.

    def on_enqueue(self):
        # Resolve metrics ids as early as possible.
        self.on_task_enqueue()

    def pseudo_on_enqueue(self):
        # Helper for setups where on_enqueue cannot initialise the ids.
        self.init_metrics_on_execute()

    def on_execute(self):
        self.metrics_on_execute()

    def on_wait(self, prev_status, status):
        self.metrics_on_wait(prev_status, status)

    def on_break(self, prev_status, status):
        self.metrics_on_break(prev_status, status)

    def on_finish(self, prev_status, status):
        self.metrics_on_finish(prev_status, status)
