import logging
import sys
import traceback
from contextlib import contextmanager

from walle.errors import WalleError, UnknownFixableError
from walle.fsm_stages.common import NEXT_CHECK_ADVANCE_TIME
from walle.models import timestamp
from walle.scenario.errors import ErrorList, DEFAULT_ERROR_ADVANCE_PERIOD

# Module logger; the error handlers below use log.exception() to record the active traceback.
log = logging.getLogger(__name__)


def get_current_stage_info(stage_info):
    """Return the innermost currently-selected stage below ``stage_info``.

    Starting at ``stage_info``, repeatedly step into ``stages[seq_num]`` while the
    current node has child stages. The first node without children is returned,
    which may be ``stage_info`` itself.
    """
    current = stage_info
    while len(current.stages) != 0:
        # Temporary fix for saved schedulers: earlier schedulers had no seq_num,
        # so a falsy value (e.g. None) falls back to the first child.
        current = current.stages[current.seq_num or 0]
    return current


@contextmanager
def scenario_root_stage_error_handler(stage_info, scenario):
    """Record errors raised by a scenario's root stage on the scenario itself.

    On exit, ``store_errors`` always overwrites ``scenario.errors`` and
    ``scenario.next_check_time``:

    - on success, with an empty dict and ``timestamp() + NEXT_CHECK_ADVANCE_TIME``;
    - on ``ErrorList``, with its already-built stage errors;
    - on any other ``Exception``, with one freshly built stage error.

    In these cases the exception is suppressed after being recorded. Exceptions
    that are not ``Exception`` subclasses still propagate. ``finally`` has still
    stored the initial (empty) values before they do.
    """
    active_stage_info = get_current_stage_info(stage_info)
    errors = {}
    next_check = timestamp() + NEXT_CHECK_ADVANCE_TIME

    def _register(raw_error, reported_error):
        # Called only from inside an ``except`` clause: transform_to_stage_error reads
        # the stacktrace from sys.exc_info(). The id/context are built from the raw
        # exception, while the stored record is built from ``reported_error``.
        exc_id, exc_context = collect_exc_context_for_scenario_root_stage(raw_error, active_stage_info)
        stage_error = transform_to_stage_error(scenario, exc_id, exc_context, reported_error)
        return collect_errors_and_next_check(stage_error)

    try:
        yield
    except ErrorList as error_list:
        errors, next_check = collect_errors_and_next_check(*error_list.errors)
    except WalleError as error:
        log.exception("Error processing scenario #%s:", scenario.scenario_id)
        errors, next_check = _register(error, error)
    except Exception as error:
        log.exception("Error processing scenario #%s:", scenario.scenario_id)
        # The wrapper is built before the context, matching the original call order.
        errors, next_check = _register(error, UnknownFixableError(error))
    finally:
        store_errors(scenario, errors, next_check)


class HssErrorHandler:
    """Collect per-host errors raised while a scenario stage processes hosts.

    Use an instance as a context-manager factory:
    ``with handler(stage_info, scenario, host): ...``. Exceptions derived from
    ``Exception`` raised in the body are recorded in ``self.errors`` and suppressed.
    Call ``raise_exception`` afterwards to re-raise everything collected as one
    ``ErrorList``.
    """

    def __init__(self):
        # Stage errors accumulated across every ``with`` block run through this handler.
        self.errors = []

    @contextmanager
    def __call__(self, stage_info, scenario, host):
        current_stage = get_current_stage_info(stage_info)
        try:
            yield
        except ErrorList as error_list:
            # Already-transformed stage errors: keep them unchanged.
            self.errors.extend(error_list.errors)
        except WalleError as exc:
            log.exception("Error processing scenario #%s host %s:", scenario.scenario_id, host.human_id())
            self.append_error(exc, current_stage, host, scenario)
        except Exception as exc:
            log.exception("Error processing scenario #%s host %s:", scenario.scenario_id, host.human_id())
            self.append_error(UnknownFixableError(exc, message=exc), current_stage, host, scenario)

    def append_error(self, error, active_stage_info, host, scenario):
        """Turn ``error`` into a host-scoped stage error and store it in ``self.errors``.

        Must run inside an ``except`` clause: transform_to_stage_error reads the
        stacktrace from sys.exc_info().
        """
        exc_id, exc_context = collect_exc_context_for_host_root_stage(error, active_stage_info, host)
        self.errors.append(transform_to_stage_error(scenario, exc_id, exc_context, error))

    def raise_exception(self):
        """Raise ``ErrorList`` with all collected errors; do nothing when none were collected."""
        if not self.errors:
            return
        raise ErrorList(self.errors)


def transform_to_stage_error(scenario, exc_id, exc_context, error):
    """Refresh and return the scenario's error record for ``exc_id``.

    The record comes from ``scenario.get_scenario_error(exc_id, error)``. This function:

    - increments its ``repeats`` counter;
    - refreshes ``stacktrace``, ``is_visible``, ``exc_info`` (set to ``exc_context``)
      and ``retry_period`` (copied from ``error``).

    Must be called while an exception is being handled. The stacktrace comes from
    ``sys.exc_info()``, not from ``error``, because callers may pass a wrapper such
    as ``UnknownFixableError`` that was never raised.
    """
    record = scenario.get_scenario_error(exc_id, error)
    exc_type, exc_value, exc_tb = sys.exc_info()
    record.stacktrace = traceback.format_exception(exc_type, exc_value, exc_tb)
    record.repeats = record.repeats + 1
    # Visibility depends on the updated repeat count, so compute it after the increment.
    record.is_visible = is_visible(record, error)
    record.exc_info = exc_context
    record.retry_period = error.retry_period
    return record


def collect_errors_and_next_check(*errors):
    """Index stage errors by id and compute when the scenario should be checked next.

    Returns ``(errors_by_id, next_check)``:

    - ``errors_by_id`` maps ``error.id`` to the error; a later error with the same
      id replaces an earlier one.
    - ``next_check`` is ``timestamp()`` plus the smallest ``retry_period`` among the
      errors, capped at ``DEFAULT_ERROR_ADVANCE_PERIOD``.

    With no errors, the dict is empty and the default period is used.
    """
    errors_by_id = {err.id: err for err in errors}
    delay = min([DEFAULT_ERROR_ADVANCE_PERIOD] + [err.retry_period for err in errors])
    return errors_by_id, timestamp() + delay


def store_errors(scenario, errors, next_check):
    """Overwrite the scenario's error mapping and its next check time in place."""
    scenario.errors, scenario.next_check_time = errors, next_check


def is_visible(error, raw_error):
    """Decide whether a stage error record should be shown.

    A recoverable ``WalleError`` stays hidden while the record's ``repeats`` is
    below the error's ``retry_limit``. Anything else is visible immediately.
    """
    hidden = (
        isinstance(raw_error, WalleError)
        and raw_error.is_recoverable()
        and error.repeats < raw_error.retry_limit
    )
    return not hidden


def collect_exc_context_for_scenario_root_stage(exception, stage_info):
    """Build ``(exc_id, exc_context)`` for an error raised by a scenario stage.

    ``exc_id`` is ``"<stage uid> <-> <message>"`` with every "." replaced by "_".
    Presumably this is because the id is used as a mapping key where dots are not
    allowed; confirm with the storage layer. ``exc_context`` records the stage, the
    message, a timestamp and the id itself.
    """
    message = str(exception)
    raw_id = "{} <-> {}".format(stage_info.uid, message)
    exc_id = raw_id.replace(".", "_")
    context = {
        "timestamp": timestamp(),
        "stage_uid": stage_info.uid,
        "stage_name": stage_info.name,
        "stage_action": stage_info.action_type,
        "error_message": message,
        "exc_id": exc_id,
    }
    return exc_id, context


def collect_exc_context_for_approve_stage_stage(exception, stage_info, error_seq_num):
    """Build ``(exc_id, exc_context)`` for one of several errors of an approve stage.

    Same as ``collect_exc_context_for_scenario_root_stage``, except the id gets a
    ``" <-> <error_seq_num>"`` suffix. Presumably this keeps otherwise identical
    errors of one stage apart.

    ``exc_context["exc_id"]`` equals the returned ``exc_id``, as in the other
    collect_exc_context_* helpers.
    """
    base_exc_id, exc_context = collect_exc_context_for_scenario_root_stage(exception, stage_info)
    # Sanitise the suffix with the same "." -> "_" rule as the base id
    # (no effect for integer sequence numbers).
    suffix = str(error_seq_num).replace(".", "_")
    exc_id = "{} <-> {}".format(base_exc_id, suffix)
    # Keep the stored context consistent with the suffixed id instead of the base one.
    return exc_id, dict(exc_context, exc_id=exc_id)


def collect_exc_context_for_host_root_stage(exception, stage_info, host):
    """Build ``(exc_id, exc_context)`` for an error raised while processing ``host``.

    ``exc_id`` is ``"<stage uid> <-> <host uuid> <-> <message>"`` with every "."
    replaced by "_". ``exc_context`` holds:

    - the stage fields, the message, a timestamp and the id;
    - ``host.uuid``, ``host.inv`` and ``host.name``.
    """
    message = str(exception)
    raw_id = "{} <-> {} <-> {}".format(stage_info.uid, host.uuid, message)
    exc_id = raw_id.replace(".", "_")
    context = {
        "timestamp": timestamp(),
        "stage_uid": stage_info.uid,
        "stage_name": stage_info.name,
        "stage_action": stage_info.action_type,
        "error_message": message,
        "exc_id": exc_id,
        "host_uuid": host.uuid,
        "host_inv": host.inv,
        "host_name": host.name,
    }
    return exc_id, context


class GroupStageErrorHandler:
    """Collect errors raised while a scenario stage processes host groups.

    Works like ``HssErrorHandler``, but errors are scoped by ``group_id`` instead
    of by host. Inside ``with handler(stage_info, scenario, group_id): ...``,
    exceptions derived from ``Exception`` are recorded in ``self.errors`` and
    suppressed. ``raise_exception`` re-raises everything collected as one
    ``ErrorList``.
    """

    def __init__(self):
        # Stage errors accumulated across every ``with`` block run through this handler.
        self.errors = []

    @contextmanager
    def __call__(self, stage_info, scenario, group_id):
        active_stage_info = get_current_stage_info(stage_info)
        try:
            yield
        except ErrorList as error:
            # Already-transformed stage errors: keep them unchanged.
            self.errors.extend(error.errors)
        except WalleError as error:
            log.exception("Error processing scenario #%s group id %s:", scenario.scenario_id, group_id)
            self.append_error(error, active_stage_info, group_id, scenario)
        except Exception as error:
            # Same wording as the WalleError branch, so one log search matches both paths.
            log.exception("Error processing scenario #%s group id %s:", scenario.scenario_id, group_id)
            unknown_error = UnknownFixableError(error, message=error)
            self.append_error(unknown_error, active_stage_info, group_id, scenario)

    def append_error(self, error, active_stage_info, group_id, scenario):
        """Turn ``error`` into a group-scoped stage error and store it in ``self.errors``.

        Must run inside an ``except`` clause: transform_to_stage_error reads the
        stacktrace from sys.exc_info().
        """
        exc_id, exc_context = collect_exc_context_for_group_stage(error, active_stage_info, group_id)
        stage_error = transform_to_stage_error(scenario, exc_id, exc_context, error)
        self.errors.append(stage_error)

    def raise_exception(self):
        """Raise ``ErrorList`` with all collected errors; do nothing when none were collected."""
        if self.errors:
            raise ErrorList(self.errors)


def collect_exc_context_for_group_stage(exception, stage_info, group_id):
    """Build ``(exc_id, exc_context)`` for an error raised while processing a host group.

    ``exc_id`` is ``"<stage uid> <-> <group_id> <-> <message>"`` with every "."
    replaced by "_". ``exc_context`` holds the stage fields, the message,
    ``group_id``, a timestamp and the id.
    """
    message = str(exception)
    raw_id = "{} <-> {} <-> {}".format(stage_info.uid, group_id, message)
    exc_id = raw_id.replace(".", "_")
    context = {
        "timestamp": timestamp(),
        "stage_uid": stage_info.uid,
        "stage_name": stage_info.name,
        "stage_action": stage_info.action_type,
        "error_message": message,
        "group_id": group_id,
        "exc_id": exc_id,
    }
    return exc_id, context
