"""Contains all logic for handling host rebooting via SSH."""

import logging

from walle.clients import ssh
from walle.expert.constants import HOST_BOOT_TIME
from walle.fsm_stages.common import (
    register_stage,
    generate_stage_handler,
    get_current_stage,
    commit_stage_changes,
    fail_current_stage,
    complete_stage,
)
from walle.stages import Stages

# Initial stage status, handled by _kexec_reboot(): the kexec reboot command has not been issued yet.
_STATUS_KEXEC_REBOOT = "reboot"
"""Host reboot is scheduled."""

# Stage status set after the reboot command was issued; handled by _wait_for_reboot_to_complete().
_STATUS_KEXEC_REBOOTING = "rebooting"
"""Host is currently rebooting."""

# Polling interval (passed as `check_after` in _check_timeout()) used while waiting for the reboot to finish.
# NOTE(review): units presumably seconds - confirm against commit_stage_changes().
_CHECK_PERIOD = 5
"""Power on/off check period for ssh operations."""

# Delay (passed as `check_after`) between issuing the kexec reboot command and the first completion check.
# NOTE(review): the reboot is performed via 'sudo -n /usr/sbin/wall-e.kexec-helper --reboot'; presumably the
# helper does not reboot immediately, hence the ">= 35 seconds" requirement - confirm against the helper.
_REBOOT_DELAY = 35
"""reboot via 'sudo -n /usr/sbin/wall-e.kexec-helper --reboot' so we need this delay to be at least 35 seconds"""

# This one is so big because it needs to wait until the whole operating system shuts down and then boots up,
# not just hardware power on. Passed to stage.timed_out() in _check_timeout().
_KEXEC_REBOOT_TIMEOUT = 2 * HOST_BOOT_TIME
"""Host power state operation timeout for ssh operations."""

log = logging.getLogger(__name__)


def _kexec_reboot(host):
    """Issue a kexec reboot over SSH and switch the stage to the "rebooting" status.

    The host's boot ID is read before the reboot command is sent and saved in the stage's temp
    data under "boot_id", so that _wait_for_reboot_to_complete() can later tell that the host
    really booted again (its boot ID changed).

    :param host: the host whose current stage is being processed.
    :return: the result of commit_stage_changes() on success, or of fail_current_stage() if any
        SSH operation raised ssh.SshError.
    """
    stage = get_current_stage(host)
    try:
        # Only the SSH interaction belongs in the try: stage bookkeeping below must not be
        # mistaken for a reboot failure.
        with host.get_ssh_client() as client:
            boot_id = client.get_boot_id()
            client.issue_kexec_reboot_command()
    except ssh.SshError as error:
        message = "Failed to reboot host via kexec: {}"
        # Return the result, consistent with the other handlers of this stage.
        return fail_current_stage(host, message.format(error))

    stage.set_temp_data("boot_id", boot_id)
    # First check is postponed by _REBOOT_DELAY (see the note on that constant).
    return commit_stage_changes(host, status=_STATUS_KEXEC_REBOOTING, check_after=_REBOOT_DELAY)


def _wait_for_reboot_to_complete(host):
    """Check over SSH whether the host has come back with a new boot ID.

    The boot ID recorded by _kexec_reboot() is compared with the one the host reports now:
    a different value means the reboot happened and the stage is completed. If the host is
    unreachable or still reports the old boot ID, the timeout is checked and polling continues.
    Any other SSH error fails the stage immediately.
    """
    current_stage = get_current_stage(host)

    try:
        with host.get_ssh_client() as ssh_client:
            current_boot_id = ssh_client.get_boot_id()
    except ssh.SshConnectionFailedError:
        # Cannot connect - the host is most likely still down; keep waiting until the timeout.
        return _check_timeout(host, current_stage, "The host didn't kexec on after given timeout")
    except ssh.SshError as ssh_error:
        return fail_current_stage(host, "Failed to reboot host via kexec: {}".format(ssh_error))

    saved_boot_id = current_stage.get_temp_data("boot_id")
    if current_boot_id == saved_boot_id:
        # Host answers, but it is still the same boot: the reboot has not happened yet.
        return _check_timeout(host, current_stage, "Timeout while waiting for the host to actually reboot via kexec")

    return complete_stage(host, current_stage)


def _check_timeout(host, stage, error):
    """Fail the stage once the kexec reboot timeout expired, otherwise schedule another check.

    :param host: the host being processed.
    :param stage: the host's current stage; its timed_out() decides whether to give up.
    :param error: description of what was being waited for, appended to the failure message.
    :return: the result of commit_stage_changes() while still within the timeout, or of
        fail_current_stage() after it expired.
    """
    if not stage.timed_out(_KEXEC_REBOOT_TIMEOUT):
        # Still within the allowed time - poll again after _CHECK_PERIOD.
        return commit_stage_changes(host, check_after=_CHECK_PERIOD)

    return fail_current_stage(host, "Failed to reboot host via kexec: {}.".format(error))


# Status -> handler table for the KEXEC_REBOOT stage: it starts in _STATUS_KEXEC_REBOOT
# (_kexec_reboot issues the command) and then moves to _STATUS_KEXEC_REBOOTING
# (_wait_for_reboot_to_complete polls until the host has rebooted).
_KEXEC_REBOOT_HANDLERS = {
    _STATUS_KEXEC_REBOOT: _kexec_reboot,
    _STATUS_KEXEC_REBOOTING: _wait_for_reboot_to_complete,
}

register_stage(
    Stages.KEXEC_REBOOT,
    generate_stage_handler(_KEXEC_REBOOT_HANDLERS),
    initial_status=_STATUS_KEXEC_REBOOT,
)
