"""Contains all logic for handling host rebooting via SSH."""

import logging

from walle.clients import ssh
from walle.expert.constants import HOST_BOOT_TIME
from walle.fsm_stages.common import (
    register_stage,
    generate_stage_handler,
    get_current_stage,
    commit_stage_changes,
    fail_current_stage,
    complete_parent_stage,
    fail_stage,
    get_parent_stage,
    complete_stage,
)
from walle.fsm_stages.power_post_util import STATUS_WAIT_POST_COMPLETE, wait_post_complete, goto_post_code_check
from walle.host_platforms.platform_manager import create_platform_for_host
from walle.stages import Stages

_STATUS_REBOOT = "reboot"
"""Host reboot is scheduled."""

_STATUS_WAITING_HOST_DOWN = "waiting-host-down"
"""Host waits shutdown"""

_STATUS_REBOOTING = "rebooting"
"""Host is currently rebooting."""

_CHECK_PERIOD = 5
"""Power on/off check period for ssh operations."""

_REBOOT_DELAY = 65
"""reboot via 'shutdown -r +1' so we need this delay to be at least 65 seconds"""

_POST_CHECK_DELAY = 90
"""Period before start polling POST code after we detected host gone down. 90 seconds for SSH reboot now."""

# This one is so big because it needs to wait until whole operation system shuts down and then boots up,
# not just hardware power on.
_REBOOT_TIMEOUT = 2 * HOST_BOOT_TIME
"""Host power state operation timeout for ssh operations."""

_REBOOT_TIMEOUT_WITH_POST_CHECK = 3 * HOST_BOOT_TIME
"""Host power state operation timeout for ssh operations for stages with POST code check"""

log = logging.getLogger(__name__)


def _reboot(host):
    """Issue the reboot command over SSH and switch to the next waiting status.

    The host's current boot ID is recorded in the stage temp data before
    rebooting, so later handlers can detect a completed reboot by a boot ID
    change.

    The handler moves to ``_STATUS_WAITING_HOST_DOWN`` when the stage requested
    POST code checking and the host platform provides POST codes. Otherwise it
    moves straight to ``_STATUS_REBOOTING``. Either way, the first check happens
    after ``_REBOOT_DELAY``.

    If an SSH operation fails, the current stage is failed.
    """
    stage = get_current_stage(host)
    try:
        with host.get_ssh_client() as client:
            boot_id = client.get_boot_id()
            client.issue_reboot_command()

        stage.set_temp_data("boot_id", boot_id)

        # The reboot is already in flight at this point. Only build the platform
        # object when POST checking was actually requested, so the common path
        # does no extra (possibly failing) work.
        wants_post_check = stage.get_param("check_post_code", False)
        if wants_post_check and create_platform_for_host(host).provides_post_code():
            return commit_stage_changes(host, status=_STATUS_WAITING_HOST_DOWN, check_after=_REBOOT_DELAY)

        return commit_stage_changes(host, status=_STATUS_REBOOTING, check_after=_REBOOT_DELAY)

    except ssh.SshError as error:
        message = "Failed to reboot host via ssh: {}"
        return fail_current_stage(host, message.format(error))


def _wait_for_host_down(host):
    """Poll the host over SSH until it goes down after the reboot command.

    This status is only entered when POST code checking is enabled.

    Outcomes:

    * The SSH connection fails: the host went down. Switch to POST code
      checking after ``_POST_CHECK_DELAY`` seconds.
    * The boot ID differs from the one recorded before the reboot: the host is
      already back up. Complete the parent stage.
    * The boot ID is unchanged: the host is still up. Poll again, but fail via
      ``_check_timeout`` once the reboot timeout expires, so a host that never
      goes down cannot keep the stage polling forever.
    """
    stage = get_current_stage(host)
    try:
        with host.get_ssh_client() as client:
            boot_id = client.get_boot_id()
    except ssh.SshConnectionFailedError:
        # The host went down; start POST code analysis after _POST_CHECK_DELAY.
        return goto_post_code_check(host, delay=_POST_CHECK_DELAY)

    if boot_id != stage.get_temp_data("boot_id"):
        # Whoops... the host is up already.
        return complete_parent_stage(host, stage)

    return _check_timeout(host, stage, "The host didn't go down after given timeout")


def _wait_for_reboot_to_complete(host):
    """Poll the host over SSH until it reports a boot ID different from the pre-reboot one.

    Outcomes:

    * SSH connection failure: the host is presumably still down. Keep waiting,
      subject to the reboot timeout.
    * Any other SSH error: fail the current stage.
    * Boot ID unchanged: keep waiting, subject to the reboot timeout.
    * Boot ID changed: the reboot succeeded. Complete the parent stage, or, in
      POST code checking mode, the grandparent stage.
    """
    stage = get_current_stage(host)

    try:
        with host.get_ssh_client() as client:
            current_boot_id = client.get_boot_id()
    except ssh.SshConnectionFailedError:
        # Most likely the host has not come back yet.
        return _check_timeout(host, stage, "The host didn't power on after given timeout")
    except ssh.SshError as error:
        return fail_current_stage(host, "Failed to reboot host via ssh: {}".format(error))

    has_rebooted = current_boot_id != stage.get_temp_data("boot_id")
    if not has_rebooted:
        return _check_timeout(host, stage, "Timeout while waiting for the host to actually reboot")

    if not stage.get_param("check_post_code", False):
        return complete_parent_stage(host)

    # In POST-check mode this stage is nested one level deeper: complete the
    # parent of the enclosing composite stage (the reboot stage).
    composite_stage = get_parent_stage(host, stage)
    return complete_stage(host, get_parent_stage(host, composite_stage))


def _check_timeout(host, stage, error):
    """Fail the stage if the reboot took too long; otherwise schedule another check.

    :param host: host being rebooted.
    :param stage: current SSH reboot stage.
    :param error: human-readable reason appended to the failure message.

    In POST code checking mode, the longer timeout is used and the failure is
    reported on the parent stage rather than the current one.
    """
    with_post_check = stage.get_param("check_post_code", False)
    limit = _REBOOT_TIMEOUT_WITH_POST_CHECK if with_post_check else _REBOOT_TIMEOUT

    if not stage.timed_out(limit):
        return commit_stage_changes(host, check_after=_CHECK_PERIOD)

    reason = "Failed to reboot host via ssh: {}.".format(error)
    if not with_post_check:
        return fail_current_stage(host, reason)

    return fail_stage(host, get_parent_stage(host, stage), reason)


def _on_post_ok(host):
    """Callback passed to ``wait_post_complete``, presumably invoked once POST finishes successfully.

    Switches back to polling for a new boot ID every ``_CHECK_PERIOD`` seconds.
    """
    commit_stage_changes(host, check_after=_CHECK_PERIOD, status=_STATUS_REBOOTING)


def wait_post_complete_ssh(host):
    """Handle ``STATUS_WAIT_POST_COMPLETE`` for the SSH reboot stage.

    Delegates to the shared ``wait_post_complete`` helper, using ``_on_post_ok``
    as the continuation for this stage.
    """
    on_success = _on_post_ok
    wait_post_complete(host, on_success)


# Register the SSH reboot stage: each status maps to the handler that processes
# the host while it is in that status. The stage starts in _STATUS_REBOOT.
register_stage(
    Stages.SSH_REBOOT,
    generate_stage_handler({
        _STATUS_REBOOT: _reboot,  # issue the reboot command
        _STATUS_REBOOTING: _wait_for_reboot_to_complete,  # wait for a new boot ID
        _STATUS_WAITING_HOST_DOWN: _wait_for_host_down,  # POST mode: wait for SSH to drop
        STATUS_WAIT_POST_COMPLETE: wait_post_complete_ssh,  # POST mode: wait for POST to finish
    }),
    initial_status=_STATUS_REBOOT,
)
