import logging
import time

from infra.ya_salt.lib.components import component
from infra.ya_salt.lib import pbutil
from infra.ya_salt.lib.packages import stateutil
from infra.ya_salt.lib import saltutil
from infra.ya_salt.lib.packages import action
from infra.ya_salt.lib.components import yasm
from infra.ya_salt.lib import constants

log = logging.getLogger('salt-component')


def extract_yasm_from_lo(lo, spec):
    """Pull the yasm agent version out of a lo state list and neutralize its install state.

    For every state in ``lo`` (the caller passes ``b.lo.states()``), look at the
    package Install actions it produces. The first Install of
    ``yasm.Yasm.AGENT_PACKAGE`` found in a state has its version copied into
    ``spec.yasm.agent_version``, and that state is mutated in place into a nop.

    :param lo: iterable of lo state entries
    :param spec: spec object whose ``yasm.agent_version`` is overwritten
    """
    for lo_state in lo:
        installs = (
            candidate
            for candidate in stateutil.lo_to_action(lo_state)
            if isinstance(candidate, action.Install)
        )
        for install in installs:
            if install.name != yasm.Yasm.AGENT_PACKAGE:
                continue
            # The agent itself is managed elsewhere, so record its version. The state
            # is turned into a noop instead of being dropped, because other states may
            # depend on it.
            spec.yasm.agent_version = install.version
            saltutil.mutate_lo_state_to_nop(lo_state)
            break


class SaltComponentsOrly(object):
    """Run an orly check function at most once and memoize its answer.

    The wrapped function returns None when the operation is allowed, or an
    error message otherwise. The first call to ``check_orly`` invokes it; every
    later call returns the cached ``message``.
    """

    def __init__(self, orly_check_fun):
        self._f = orly_check_fun
        self._triggered = None
        # Last answer of the check function (None means allowed).
        self.message = None

    def check_orly(self):
        """Return the memoized orly verdict, invoking the check on first use."""
        if self._triggered:
            return self.message
        self.message = self._f()
        self._triggered = True
        return self.message


class SaltComponent(object):
    """A compiled salt component together with its previous state and plan diff.

    ``compiled_a`` is the component built from the previously persisted lo,
    ``compiled_b`` is the freshly rendered one, and ``diff`` is their lo diff
    (None/empty when the execution plan did not change). A changed plan is only
    applied if the shared orly check allows it; otherwise a noop component is
    applied instead.
    """

    def __init__(self, selector, orly, compiled_a, compiled_b, diff):
        self.selector = selector
        # Component built from the previously persisted lo; not used by this class.
        self._a = compiled_a
        self._b = compiled_b
        self._diff = diff
        self._orly = orly

    def has_diff(self):
        """Return True when the execution plan changed since the last persisted lo."""
        return bool(self._diff)

    def _get_allowed_component(self):
        """Return the component to apply: ``b``, or a noop if orly denied a changed plan."""
        if not self.has_diff():
            return self._b
        log.info('Checking orly...')
        err = self._orly.check_orly()
        if err is not None:
            log.info('Orly check failed for {}: {}'.format(self.selector, err))
            return CompiledSaltComponent.noop_component(self._b)
        log.info('Orly allowed apply {}'.format(self.selector))
        return self._b

    def apply(self, ctx):
        """Apply the allowed component and persist its lo if a changed plan was really applied.

        The lo is persisted only when there was a diff, the apply succeeded, and
        the result is not a noop. A noop (orly denied) carries an empty lo, and
        persisting it would erase the record of what is actually applied on the host.
        """
        result = self._get_allowed_component().apply(ctx)
        if self.has_diff() and result.applied() and not result.is_noop():
            err = result.persist_lo()
            if err:
                log.error('Cannot persist lo state for {}: {}'.format(result.selector, err))

        return result

    def get_packages_actions(self):
        """Return the package actions of the freshly compiled component."""
        return self._b.get_packages_actions()


class Compiler(component.Compiler):
    """Build SaltComponent objects for selectors.

    For every selector two compiled components are produced:

    * ``a`` - from the lo state persisted by a previous run (or an empty lo if
      it cannot be loaded);
    * ``b`` - from the freshly rendered high state of the repo.

    Their lo diff, plus an orly checker shared by all components of this
    compiler, are handed to SaltComponent.
    """

    def __init__(self, repo, salt_status, orly, spec_b, hctl):
        self.spec_b = spec_b
        self.repo = repo
        self.status = salt_status
        self.hctl = hctl
        self.orly = orly
        # One memoizing checker for all components, so orly is asked at most once.
        self._salt_orly = SaltComponentsOrly(self._get_orly_check_fun())

    def _get_orly_check_fun(self):
        """Return a zero-arg callable giving None (allowed) or an orly error."""
        if self.orly is None:
            # Without orly every change is allowed.
            return lambda: None

        if self.spec_b.need_initial_setup:
            rule_id = constants.ORLY_SALT_RULE_INITIAL
        elif self.spec_b.env_type == 'prestable':
            rule_id = constants.ORLY_SALT_RULE_PRESTABLE
        else:
            rule_id = constants.ORLY_SALT_RULE

        def check_orly():
            return self.orly.start_operation(rule_id)

        return check_orly

    @staticmethod
    def _validate_high(selector, hi):
        """Return an error string if ``hi`` contains states of other components, else None.

        A state is accepted when its ``__sls__`` starts with
        ``components.<name>.`` or ``components.<name>-cleaner.``, or equals the
        selector match or ``packages``.
        """
        component_name = selector.get_name()
        own_prefixes = (
            'components.{}.'.format(component_name),
            'components.{}-cleaner.'.format(component_name),
        )
        foreign_components = []
        for state_id, state in hi.states().items():
            sls = state['__sls__']
            if sls.startswith(own_prefixes):
                continue
            if sls == selector.get_match() or sls == 'packages':
                continue
            foreign_components.append((state_id, sls))

        if not foreign_components:
            return None
        return 'foreign components found for component {}: {}'.format(
            selector,
            ', '.join('{}: {}'.format(sls, state_id) for state_id, sls in foreign_components)
        )

    @staticmethod
    def _get_version_from_lo(selector, lo):
        """Return the text after the first ``components.<name>.`` in any state's ``__sls__``, or None."""
        for state in lo.states():
            marker = 'components.{}.'.format(selector.get_name())
            _, found, tail = state['__sls__'].partition(marker)
            if found:
                return tail
        return None

    @staticmethod
    def _is_hostctl_component(lo):
        """Return True if any lo state is a ``hostctl.manage`` call."""
        return any(
            state['state'] == 'hostctl' and state['fun'] == 'manage'
            for state in lo.states()
        )

    @staticmethod
    def _validate_hostctl_lo(lo):
        """Return an error string unless lo is exactly one state with non-empty name and contents."""
        states = lo.states()
        if len(states) != 1:
            return 'hostctl component must have exactly and only one hostctl.manage state'
        single = states[0]
        if single.get('name') and single.get('contents'):
            return None
        return 'hostctl.manage should have "name" and "contents" parameters defined and not empty'

    def _compile_a(self, selector, component_status):
        """Compile the previously persisted lo (falling back to an empty lo)."""
        lo, err = self.repo.lo_from_selector(selector)
        if err:
            log.warning('Cannot load previous lo_state: {}'.format(err))
            lo = self.repo.empty_lo()

        version = self._get_version_from_lo(selector, lo)
        if version:
            component_status.prev_version = version

        if not self._is_hostctl_component(lo):
            return CompiledSaltComponent(selector, self.repo, lo, component_status, version)
        return CompiledHostctlComponent(selector, self.repo, lo, component_status, version, self.hctl)

    def _compile_b(self, selector, component_status):
        """Render, validate and compile the current state; return ``(compiled, err)``."""
        hi, err = self.repo.render_high_selector(selector)
        if err:
            return None, err
        err = self._validate_high(selector, hi)
        if err:
            return None, err
        lo, err = hi.render_lo()
        if err:
            return None, err

        version = self._get_version_from_lo(selector, lo)
        if version:
            component_status.new_version = version

        if not self._is_hostctl_component(lo):
            return CompiledSaltComponent(selector, self.repo, lo, component_status, version), None

        err = self._validate_hostctl_lo(lo)
        if err:
            return None, 'failed to compile hostctl component({}): {}'.format(selector, err)
        return CompiledHostctlComponent(selector, self.repo, lo, component_status, version, self.hctl), None

    @staticmethod
    def _diff_compiled(a, b):
        """Diff the lo states of ``a`` and ``b``, log the outcome and return the diff (None if equal)."""
        diff = saltutil.get_lo_diff(a.lo.states(), b.lo.states())
        revisions = (b.status.name, a.lo.get_rev(), b.lo.get_rev())
        if diff is None:
            log.info("Execution plan for '{}' *NOT* changed! (revision: prev='{}', cur='{}').".format(*revisions))
        else:
            log.warning("Execution plan for '{}' changed! (revision: prev='{}', cur='{}').".format(*revisions))
            log.info(diff)
        return diff

    def _apply_hacks(self, b):
        """Component-specific post-processing of the compiled ``b`` component."""
        # TODO: remove dirty hacks
        if b.selector.get_name() != 'yasm':
            return
        extract_yasm_from_lo(b.lo.states(), self.spec_b)

    def compile(self, selector):
        """Compile ``selector`` into ``(SaltComponent, None)`` or ``(None, err)``."""
        component_status = self.status.salt_components.add()
        component_status.name = selector.get_name()
        component_status.selector = selector.get_match()

        a = self._compile_a(selector, component_status)
        b, err = self._compile_b(selector, component_status)
        if err:
            pbutil.false_cond(component_status.initialized, err)
            return None, err
        pbutil.true_cond(component_status.initialized)

        self._apply_hacks(b)
        return SaltComponent(selector, self._salt_orly, a, b, self._diff_compiled(a, b)), None


class CompiledSaltComponent(component.CompiledComponent):
    """A salt component compiled to a lo state that is run by the repo executor.

    ``noop`` marks components substituted for a denied apply; it is passed on
    to the AppliedSaltComponent produced by ``apply``.
    """

    def __init__(self, selector, repo, lo, component_status, version, noop=False):
        self.selector = selector
        self.repo = repo
        self.status = component_status
        self.lo = lo
        self.version = version
        self._noop = noop

    @staticmethod
    def _component_empty(lo):
        """Return True when ``lo`` has no states."""
        return not lo.states()

    def get_packages_actions(self):
        """Return the package actions extracted from this component's lo states."""
        return stateutil.extract_lo_package_actions(self.lo.states())

    def apply(self, ctx):
        """Execute the lo (unless empty), update the ``applied`` condition and wrap the result.

        An empty lo is treated as a successful no-op without calling the executor.
        """
        if self.version:
            self.status.applied_version = self.version

        results = {}
        err = None
        if not self._component_empty(self.lo):
            results, err = self.repo.get_executor().execute(ctx, self.lo)

        if err is None:
            pbutil.set_condition(self.status.applied, 'True')
        else:
            log.error('Salt component {} apply failed: {}'.format(self.status.name, err))
            pbutil.set_condition(self.status.applied, 'False', err)
        return AppliedSaltComponent(self.selector, self.repo, self.lo, self.status, results, self._noop)

    @classmethod
    def noop_component(cls, component):
        """Build a noop twin of ``component``: same selector/repo/status, empty lo, no version."""
        return cls(
            component.selector,
            component.repo,
            component.repo.empty_lo(),
            component.status,
            None,
            True,
        )


class CompiledHostctlComponent(component.CompiledComponent):
    """A component whose lo is a single ``hostctl.manage`` state.

    The lo is not run through the salt executor. Instead, the unit contents
    are passed directly to ``hctl.manage_inline``, and the outcome is wrapped
    into a salt-like result mapping.
    """

    def __init__(self, selector, repo, lo, component_status, version, hctl, noop=False):
        self.selector = selector
        self.repo = repo
        self.status = component_status
        self.lo = lo
        self.version = version
        self.hctl = hctl
        self._noop = noop
        # NOTE(review): presumably excludes hostctl components from regular status reporting — confirm.
        self.status.skip_reporting = True

    def get_packages_actions(self):
        """Return empty (install, remove)-style action tuples; hostctl specs are not parsed."""
        # TODO: parse hostctl spec for packages?
        return tuple(), tuple()

    def _wrap_ret(self, ret):
        """Wrap a single state result into a salt-style ``{state_key: ret}`` mapping."""
        return {
            'hostctl_|-{0}_|-{0}_|-manage'.format(self.selector.get_match()): ret
        }

    def apply(self, ctx):
        """Apply the hostctl unit (if any) and return an AppliedHostctlComponent.

        The ``applied`` condition is set to 'True' on success or when the lo is
        empty. It is set to 'False' with the error when hostctl fails, matching
        CompiledSaltComponent, so a failed unit is not reported as applied.
        """
        if self.version:
            self.status.applied_version = self.version
        start_time = time.time()
        ret = {
            'name': 'manage',
            'changes': {},
            'result': False,
            'comment': '',
            'started': int(start_time),
            'start_time': int(start_time),
            '__run_num__': 0,
        }
        states = self.lo.states()
        if states:
            state = states[0]
            unit_content = state['contents']
            unit_name = state['name']
            ret['__sls__'] = state['__sls__']
            ret['__id__'] = state['__id__']
            # hostctl applies its own overrides, so the default local repo path is passed.
            err = self.hctl.manage_inline(unit_name, constants.LOCAL_REPO_CURRENT, unit_content)
            ret['comment'] = 'hostctl manage unit {}: {}'.format(unit_name, err or 'success')
            ret['duration'] = time.time() - start_time
            ret['result'] = err is None
            if err is None:
                ctx.ok(self.selector)
                pbutil.set_condition(self.status.applied, 'True')
            else:
                log.error('Hostctl component {} apply failed: {}'.format(self.status.name, err))
                ctx.fail(self.selector, err)
                pbutil.set_condition(self.status.applied, 'False', err)
        else:
            ret['comment'] = 'noop hostctl component'
            ret['duration'] = time.time() - start_time
            ret['result'] = True
            pbutil.set_condition(self.status.applied, 'True')
        return AppliedHostctlComponent(self.selector, self.repo, self.lo, self.status, self._wrap_ret(ret), self._noop)

    @classmethod
    def noop_component(cls, component):
        """Build a noop twin of ``component``: same selector/repo/status/hctl, empty lo, no version.

        ``hctl`` must be passed explicitly. The previous positional call put
        ``True`` into ``hctl`` and left ``noop`` False.
        """
        empty_lo = component.repo.empty_lo()
        return cls(
            component.selector,
            component.repo,
            empty_lo,
            component.status,
            None,
            getattr(component, 'hctl', None),
            noop=True,
        )


class AppliedSaltComponent(component.AppliedComponent):
    """Outcome of applying a CompiledSaltComponent.

    Keeps the lo that was run, the status object, and the raw executor results.
    """

    def __init__(self, selector, repo, lo, status, results, noop=False):
        self.selector, self.repo, self.lo = selector, repo, lo
        self.status, self.results = status, results
        self._noop = noop

    def is_noop(self):
        """Return True when this result comes from a noop substitute component."""
        return self._noop

    def persist_lo(self):
        """Persist the applied lo for this selector; returns the repo's result (truthy means error)."""
        return self.repo.persist_lo_for_selector(self.lo, self.selector)

    def process_results(self):
        """Copy the executor results into the component's salt status."""
        pbutil.update_status_from_result(self.results, self.status.salt)

    def applied(self):
        """Return True when the ``applied`` condition status is 'True'."""
        condition = self.status.applied
        return condition.status == 'True'


class AppliedHostctlComponent(component.AppliedComponent):
    """Outcome of applying a CompiledHostctlComponent.

    ``results`` holds the salt-style wrapped hostctl result mapping.
    """

    def __init__(self, selector, repo, lo, status, results, noop=False):
        self.selector, self.repo, self.lo = selector, repo, lo
        self.status, self.results = status, results
        self._noop = noop

    def is_noop(self):
        """Return True when this result comes from a noop substitute component."""
        return self._noop

    def persist_lo(self):
        """Persist the applied lo for this selector; returns the repo's result (truthy means error)."""
        return self.repo.persist_lo_for_selector(self.lo, self.selector)

    def process_results(self):
        """Copy the wrapped hostctl result into the component's salt status."""
        pbutil.update_status_from_result(self.results, self.status.salt)

    def applied(self):
        """Return True when the ``applied`` condition status is 'True'."""
        condition = self.status.applied
        return condition.status == 'True'
