import collections
import copy
import json
import logging
import os
import pprint
import socket
import time

import yaml

import salt.fileclient
import salt.state

from infra.ya_salt.lib import constants
from . import gencfg
from infra.ya_salt.lib import fileutil

# Module-level logger for this module.
log = logging.getLogger('saltutil')

# Keywords passed to state module functions that salt itself consumes in this
# state module (requisites) rather than forwarding them to the actual state
# module function.
STATE_REQUISITE_KEYWORDS = frozenset([
    'onchanges',
    'onfail',
    'prereq',
    'prerequired',
    'watch',
    'require',
    'listen',
])

# Directory, relative to constants.VAR_LIB, under which compiled lo states are
# persisted as ``<selector name>.json`` (see SaltRepo._format_lo_path).
LO_STATES_PATH = 'lo_states'

# Baseline salt minion options. load_minion_config() and
# configure_dummy_minion() start from a copy of this dict and then override
# 'id', 'environment' and 'systemd.scope'.
# Several values ('grains', 'sls_list', 'env_order', ...) are mutable: treat
# this dict as read-only and deep-copy it before modifying nested values.
DEFAULT_MINION_OPTS = {
    'interface': '0.0.0.0',
    'master': 'salt',
    'user': 'root',
    'id': '',
    'cachedir': '/var/cache/salt/minion',
    'grains_cache': False,
    'grains_deep_merge': False,
    'renderer': 'yaml_jinja',
    'failhard': False,
    'autoload_dynamic_modules': True,
    'environment': None,
    'pillarenv': None,
    'pillar_opts': False,
    # ``pillar_cache``, ``pillar_cache_ttl`` and ``pillar_cache_backend``
    # are not used on the minion but are unavoidably in the code path
    'pillar_cache': False,
    'pillar_cache_ttl': 3600,
    'pillar_cache_backend': 'disk',
    'state_top': 'top.sls',
    'state_top_saltenv': None,
    'startup_states': '',
    'sls_list': [],
    'top_file': '',
    'file_client': 'local',
    'use_master_when_local': False,
    'top_file_merging_strategy': 'merge',
    'env_order': [],
    'default_top': 'base',
    'fileserver_limit_traversal': False,
    'file_recv': False,
    'file_recv_max_size': 100,
    'file_ignore_regex': [],
    'file_ignore_glob': [],
    'hash_type': 'md5',
    'ipv6': False,
    'test': False,
    'rejected_retry': False,
    'loop_interval': 1,
    'verify_env': True,
    'grains': {},
    'permissive_pki_access': False,
    'restart_on_error': False,
    'cmd_safe': True,
    'sudo_user': '',
    'state_auto_order': True,
}


def load_minion_config(minion_id=None):
    """
    Build salt minion opts for the default 'search_runtime' environment.

    :param str|None minion_id: hostname to override (in tests); defaults to
        ``socket.gethostname()``
    :return: ``(opts, None)`` -- the opts dict and an always-``None`` error,
        following the ``(value, err)`` convention used across this module
    :rtype: tuple[dict, None]
    """
    # Deep copy: DEFAULT_MINION_OPTS holds mutable values ('grains',
    # 'sls_list', 'env_order', ...) which must not be shared with (and then
    # mutated through) the returned opts.
    opts = copy.deepcopy(DEFAULT_MINION_OPTS)
    opts['id'] = minion_id or socket.gethostname()
    # Disable systemd-run as it can result in leaking cgroups.
    # See: RTCSUPPORT-3293
    opts['systemd.scope'] = False
    # Set default env to search_runtime
    opts['environment'] = 'search_runtime'
    return opts, None


def configure_dummy_minion(saltenv, minion_id=None):
    """
    Build salt minion opts for an arbitrary salt environment.

    :param str saltenv: salt environment to put into ``opts['environment']``
    :param str|None minion_id: hostname to override (in tests); defaults to
        ``socket.gethostname()``
    :return: ``(opts, None)`` -- the opts dict and an always-``None`` error,
        following the ``(value, err)`` convention used across this module
    :rtype: tuple[dict, None]
    """
    # Deep copy so nested mutable defaults ('grains', 'sls_list', ...) are
    # never shared with the module-level DEFAULT_MINION_OPTS.
    opts = copy.deepcopy(DEFAULT_MINION_OPTS)
    opts['id'] = minion_id or socket.gethostname()
    # Disable systemd-run as it can result in leaking cgroups.
    # See: RTCSUPPORT-3293
    opts['systemd.scope'] = False
    opts['environment'] = saltenv
    return opts, None


def _resolve_relative_salt_url(url, chunk):
    """
    Return ``url`` with a leading ``salt://./`` resolved against the SLS of
    ``chunk``; any other value (non-relative URL, non-string entry) is
    returned unchanged.

    :raises Exception: if ``url`` is relative but ``chunk`` has no '__sls__'
    """
    prefix = 'salt://./'
    # Duck-typed string check so both str and (Python 2) unicode work.
    if not hasattr(url, 'startswith') or not url.startswith(prefix):
        return url
    sls = chunk.get('__sls__')
    if not sls:
        raise Exception('no __sls__ in chunk: {}'.format(chunk))
    # SLS 'a.b' maps to directory 'a/b' inside the salt file roots.
    return 'salt://{}/{}'.format(sls.replace('.', '/'), url[len(prefix):])


def rewrite_relative_src(chunk):
    """
    Rewrites 'source' in low state if it uses relative salt URL.

    A source ``salt://./<path>`` is resolved against the SLS the chunk came
    from: with ``__sls__ == 'a.b'`` it becomes ``salt://a/b/<path>``.
    'source' may be a single URL or a list of URLs (salt's file.managed
    accepts a list of fallback sources); each relative string entry of a
    list is rewritten, all other entries are kept as is.

    :param dict chunk: low state chunk; modified in place
    :raises Exception: if a relative source is found but the chunk has no
        '__sls__'
    """
    src = chunk.get('source')
    if not src:
        return
    if isinstance(src, list):
        # Assign a new list rather than editing in place, in case the list
        # object is shared with other chunks.
        chunk['source'] = [_resolve_relative_salt_url(item, chunk) for item in src]
    else:
        chunk['source'] = _resolve_relative_salt_url(src, chunk)


def state_args(id_, state, high):
    """
    Collect the argument names given to ``state`` under ID ``id_``.

    Only single-key dict entries count as arguments; anything else (e.g. the
    bare function-name string) is ignored. Unknown IDs or states yield an
    empty set.

    :param id_: state ID looked up in ``high``
    :param str state: state module name, e.g. ``'file'``
    :param dict high: high data
    :rtype: set
    """
    if id_ not in high or state not in high[id_]:
        return set()
    return set(
        next(iter(arg))
        for arg in high[id_][state]
        if isinstance(arg, dict) and len(arg) == 1
    )


def find_name(name, state, high):
    """
    Scans high data for the id referencing the given name
    and return a list of (IDs, state) tuples that match

    Note: if `state` is sls, then we are looking for all IDs that match the given SLS

    * If ``name`` is itself an ID in ``high``, only ``(name, state)`` is returned.
    * If ``state == 'sls'``, every (ID, state module) pair declared in SLS
      ``name`` is returned. Metadata keys such as '__sls__'/'__env__' are
      never reported as state modules.
    * Otherwise, every ID whose ``state`` declaration has a single-key dict
      argument whose value equals ``name`` matches.

    Non-dict entries of ``high`` (e.g. an ``__extend__`` list) are skipped.

    :param name: ID, SLS name, or argument value to look for
    :param str state: state module name, or ``'sls'``
    :param dict high: high data
    :rtype: list[tuple]
    """
    if name in high:
        return [(name, state)]
    ext_id = []
    if state == 'sls':
        # Requiring an entire SLS: depend on every state of every ID in it.
        for nid, item in high.items():
            if not isinstance(item, dict) or item.get('__sls__') != name:
                continue
            for state_mod in item:
                if not state_mod.startswith('__'):
                    ext_id.append((nid, state_mod))
        return ext_id
    # Otherwise we are requiring a single state: scan its arguments for name.
    for nid, item in high.items():
        if not isinstance(item, dict):
            continue
        args = item.get(state)
        if not isinstance(args, list):
            continue
        for arg in args:
            if isinstance(arg, dict) and len(arg) == 1 and arg[next(iter(arg))] == name:
                ext_id.append((nid, state))
    return ext_id


def _get_gencfg_grains(repo, grain_fn=gencfg.gencfg):
    """
    Resolve gencfg grains from the tag stored in ``repo``.

    Reads ``constants.GENCFG_REPO_PATH`` through ``repo.open_file``. The file
    must be a YAML mapping with a non-empty ``gencfg-tag`` key, whose value
    is passed to ``grain_fn``.

    :param repo: object whose ``open_file(path)`` returns a file usable as a
        context manager (e.g. ``SaltRepo``)
    :param grain_fn: callable mapping a gencfg tag to grains
    :return: ``(grains, None)`` on success, ``(None, error_string)`` otherwise
    """
    try:
        with repo.open_file(constants.GENCFG_REPO_PATH) as stream:
            doc = yaml.load(stream, Loader=yaml.SafeLoader)
    except Exception as exc:
        return None, 'failed to load gencfg tag from {}: {}'.format(constants.GENCFG_REPO_PATH, exc)

    if not isinstance(doc, dict):
        return None, 'gencfg-tag file is not a YAML mapping'

    tag = doc.get('gencfg-tag')
    if not tag:
        return None, 'no gencfg-tag in {}'.format(constants.GENCFG_REPO_PATH)

    try:
        grains = grain_fn(tag)
    except Exception as exc:
        return None, 'failed to get gencfg grain: {}'.format(exc)
    return grains, None


class Selector(object):
    """
    One top.sls match, split into its parts.

    A raw match has the form ``[<type>:][<origin>/]<match>``:

    * ``type`` defaults to ``'salt'``;
    * ``origin`` describes the component's originating repo and defaults to
      ``'.'`` -- the salt repo itself, under which all other repos are
      stored (/var/lib/ya-salt/repo/current is origin '.',
      /var/lib/ya-salt/repo/current/core is origin 'core');
    * ``name`` is a short name derived from ``match`` by ``_fmt_name``.
    """

    @staticmethod
    def _fmt_name(m):
        """
        Derive the short name of a match.

        ``deploy.<rest>`` gives ``<rest>``, ``components.<x>[.<more>]`` gives
        ``<x>``, and any other match is returned unchanged.
        """
        deploy_prefix = 'deploy.'
        if m.startswith(deploy_prefix):
            return m[len(deploy_prefix):]
        if m.startswith('components.'):
            return m.split('.')[1]
        return m

    @staticmethod
    def _parse_raw_selector(r):
        """Split a raw match into ``(type, origin, match)``."""
        # A prefix only counts when the separator exists and is not the
        # first character.
        type_prefix, colon, rest = r.partition(':')
        if colon and type_prefix:
            type_ = type_prefix
        else:
            type_, rest = 'salt', r
        origin, slash, match = rest.partition('/')
        if not (slash and origin):
            origin, match = '.', rest
        return type_, origin, match

    def __init__(self, env, match):
        self._env = env
        parsed = self._parse_raw_selector(match)
        self._type, self._origin, self._match = parsed
        self._name = self._fmt_name(self._match)

    def get_env(self):
        return self._env

    def get_match(self):
        return self._match

    def get_name(self):
        return self._name

    def get_origin(self):
        return self._origin

    def get_type(self):
        return self._type

    def __str__(self):
        template = 'Selector(env: {}, type: {}, origin: {}, match: {})'
        return template.format(self._env, self._type, self._origin, self._match)

    def __repr__(self):
        return self.__str__()


class SaltRepo(object):
    """
    A salt states repository (one or more checkout paths) together with a
    ``salt.state.HighState`` built over it.

    Most public methods return a ``(value, err)`` pair where ``err`` is
    ``None`` on success or an error string otherwise.
    """
    @staticmethod
    def init_opts(opts, path, roots):
        """
        Return a deep copy of ``opts`` set up for a local file client.

        ``file_roots`` and ``pillar_roots`` (the same dict object) map
        ``opts['environment']`` to ``os.path.join(p, r)`` for every p in
        ``path`` and r in ``roots``, in path-major order.

        :param dict opts: base salt opts; not modified
        :param list path: repository checkout paths
        :param list roots: root sub-directories to use under each path
        :rtype: dict
        """
        opts = copy.deepcopy(opts)
        opts['file_client'] = 'local'
        v = []
        for p in path:
            for r in roots:
                v.append(os.path.join(p, r))
        file_roots = {
            opts['environment']: v,
        }
        # See salt.utils.jinja.SaltCacheLoader for details.
        # Otherwise it will try to use /var/cache directory.
        opts['file_roots'] = file_roots
        opts['pillar_roots'] = file_roots
        return opts

    @staticmethod
    def _add_gencfg_grains(opts, repo, gencfg_grains_fn=_get_gencfg_grains):
        """
        Merge gencfg grains for ``repo`` into ``opts['grains']`` in place.

        :return: ``(opts, None)`` -- the same, now updated, dict -- or
            ``(None, err)`` if ``gencfg_grains_fn`` reported an error
        """
        g, err = gencfg_grains_fn(repo)
        if err is not None:
            return None, err
        # NOTE(review): a fn returning (None, None) would make this raise
        # TypeError; _get_gencfg_grains does that only if its grain_fn
        # returns None.
        opts['grains'].update(g)
        return opts, None

    def __init__(self, meta, config, path, roots=None, file_client_cls=salt.fileclient.LocalClient, inject_grains=True):
        """
        ``config`` - opts with grains, but without gencfg grains.

        :param meta: repository metadata; ``commit_id`` and ``mtime.seconds``
            are read by ``render_high_selector``
        :param dict config: minion opts; must contain 'id' and 'environment'
        :param str|list path: checkout path, or list of paths searched in
            order by ``open_file``
        :param list|None roots: root sub-directories under each path;
            defaults to ``[config['environment'], 'common']``
        :param file_client_cls: file client class, called with the opts
        :param bool inject_grains: merge gencfg grains into the opts
        """
        if isinstance(path, basestring):
            path = [path]
        self._meta = meta
        self._path = path
        # default roots
        self._roots = roots if roots else [config['environment'], 'common']
        opts = self.init_opts(config, path, self._roots)
        # NOTE(review): the file client is built before grains are injected.
        # _add_gencfg_grains updates and returns this same ``opts`` dict, so
        # the client sees the grains only if it keeps a reference to it.
        self._file_client = file_client_cls(opts)
        self._hostname = config['id']
        self._env = config['environment']
        if inject_grains:
            self._opts, self._err = self._inject_grains(opts)
        else:
            self._opts, self._err = opts, None
        self._hs = salt.state.HighState(self._opts, file_client=self._file_client)

    def _inject_grains(self, opts):
        """
        Best-effort gencfg grain injection.

        On failure the error is logged and ``opts`` is returned as is. The
        returned error is always ``None``, so a gencfg problem does not mark
        the repo as failed.
        """
        g_opts, err = self._add_gencfg_grains(opts, self)
        if err:
            log.error('Cannot load gencfg grains: {}'.format(err))
        if not g_opts:
            g_opts = opts
        return g_opts, None

    def open_file(self, repo_path):
        """
        Open ``repo_path`` under the first configured path where it opens.

        The caller owns (and must close) the returned file object.
        """
        # Errors from all but the last path are ignored. The last path is
        # tried outside the try, so its error propagates to the caller.
        for p in self._path[:-1]:
            try:
                return open(os.path.join(p, repo_path))
            except Exception:
                continue
        return open(os.path.join(self._path[-1], repo_path))

    def has_overrides(self):
        """True when more than one checkout path is configured."""
        return len(self._path) > 1

    def get_path(self):
        """Return a copy of the configured checkout paths."""
        return self._path[:]

    def get_hostname(self):
        return self._hostname

    def get_env(self):
        return self._env

    def get_meta(self):
        return self._meta

    def get_opts(self):
        """Return ``(opts, err)`` as computed in ``__init__``."""
        return self._opts, self._err

    def _hs_render_highstate(self, *args, **kwargs):
        """
        Call ``self._hs.render_highstate`` after resetting its
        ``building_highstate``. The same HighState instance is reused for
        every render (presumably that attribute accumulates across calls --
        confirm against the salt version in use).
        """
        # cleanup highstate
        if self._hs.building_highstate:
            self._hs.building_highstate = {}
        return self._hs.render_highstate(*args, **kwargs)

    def get_executor(self):
        """
        :return: an ``Executor`` over this repo's opts.
        """
        # NOTE(review): inconsistent return shape -- ``(None, err)`` on error,
        # but a bare Executor (not a tuple) on success. With the current
        # __init__/_inject_grains ``self._err`` is always None, so only the
        # bare Executor is returned in practice. Fixing the shape requires
        # updating callers.
        if self._err is not None:
            return None, self._err
        return Executor(self._opts)

    def list_selectors(self):
        """
        Parse top.sls matches into Selectors.

        :return: ``(list[Selector], None)``, or ``(None, err)`` if the top
            file cannot be read/matched or matches more than one environment
        """
        if self._err is not None:
            return None, self._err
        try:
            top_sls = self._hs.get_top()
            top_matches = self._hs.top_matches(top_sls)
        except Exception as e:
            return None, str(e)

        if len(top_matches) > 1:
            return None, 'got multiple environments from top.sls: {}'.format(', '.join(top_matches.keys()))

        selectors = []
        for env, matches in top_matches.items():
            for match in matches:
                selectors.append(Selector(env, match))
        return selectors, None

    def render_high_selector(self, selector):
        """
        Render the high data for a single selector.

        :param Selector selector: selector, e.g. from ``list_selectors``
        :return: ``(High, None)`` stamped with the repo's commit id and mtime
            seconds, or ``(None, err)`` with render errors joined by newlines
        """
        if self._err is not None:
            return None, self._err
        try:
            d, errors = self._hs_render_highstate(matches={selector.get_env(): [selector.get_match()]})
        except Exception as e:
            return None, str(e)
        if errors:
            return None, '\n'.join(str(e) for e in errors)
        return High(d, rev=self._meta.commit_id, ts=self._meta.mtime.seconds), None

    def _format_lo_path(self, selector):
        """Return ``<constants.VAR_LIB>/lo_states/<selector name>.json``."""
        lo_name = '{}.json'.format(selector.get_name())
        return os.path.join(constants.VAR_LIB, LO_STATES_PATH, lo_name)

    def lo_from_selector(self, selector):
        """
        Load the persisted lo state for ``selector``.

        :return: ``(Low, None)``; ``(Low.empty(), None)`` when no file exists;
            ``(None, err)`` when the file exists but cannot be loaded
        """
        lo_path = self._format_lo_path(selector)
        if os.path.exists(lo_path):
            lo, err = Low.from_file(lo_path)
            if err:
                return None, err
            else:
                return lo, None
        else:
            return Low.empty(), None

    def persist_lo_for_selector(self, lo, selector):
        """Write ``lo`` to the selector's lo file via ``Low.to_file``; return its result."""
        return lo.to_file(self._format_lo_path(selector))

    def purge_lo_for_selector(self, selector):
        """Remove the selector's lo file via ``fileutil.unlink``; return its result."""
        return fileutil.unlink(self._format_lo_path(selector))

    def empty_lo(self):
        """Return an empty ``Low``."""
        return Low.empty()


class High(object):
    def __init__(self, high_dict, rev='', ts=-1):
        self._high = high_dict
        self._rev = rev
        self._ts = ts

    def get_rev(self):
        return self._rev

    def get_ts(self):
        return self._ts

    def states(self):
        return self._high

    def verify_high(self, high):
        '''
        Verify that the high data is viable and follows the data structure
        '''
        errors = []
        if not isinstance(high, dict):
            errors.append('High data is not a dictionary and is invalid')
        reqs = {}
        for name, body in high.items():
            if name.startswith('__'):
                continue
            if not isinstance(name, basestring):
                errors.append(
                    'ID \'{0}\' in SLS \'{1}\' is not formed as a string, but '
                    'is a {2}'.format(
                        name,
                        body['__sls__'],
                        type(name).__name__
                    )
                )
            if not isinstance(body, dict):
                err = ('The type {0} in {1} is not formatted as a dictionary'
                       .format(name, body))
                errors.append(err)
                continue
            for state in body:
                if state.startswith('__'):
                    continue
                if not isinstance(body[state], list):
                    errors.append(
                        'State \'{0}\' in SLS \'{1}\' is '
                        'not formed as a list'.format(name, body['__sls__'])
                    )
                else:
                    fun = 0
                    if '.' in state:
                        fun += 1
                    for arg in body[state]:
                        if isinstance(arg, basestring):
                            fun += 1
                            if ' ' in arg.strip():
                                errors.append(('The function "{0}" in state '
                                               '"{1}" in SLS "{2}" has '
                                               'whitespace, a function with whitespace is '
                                               'not supported, perhaps this is an argument '
                                               'that is missing a ":"').format(
                                    arg,
                                    name,
                                    body['__sls__']))
                        elif isinstance(arg, dict):
                            # The arg is a dict, if the arg is require or
                            # watch, it must be a list.
                            #
                            # Add the requires to the reqs dict and check them
                            # all for recursive requisites.
                            argfirst = next(iter(arg))
                            if argfirst in ('require', 'watch', 'prereq', 'onchanges'):
                                if not isinstance(arg[argfirst], list):
                                    errors.append(('The {0}'
                                                   ' statement in state \'{1}\' in SLS \'{2}\' '
                                                   'needs to be formed as a list').format(
                                        argfirst,
                                        name,
                                        body['__sls__']
                                    ))
                                # It is a list, verify that the members of the
                                # list are all single key dicts.
                                else:
                                    reqs[name] = {'state': state}
                                    for req in arg[argfirst]:
                                        if isinstance(req, basestring):
                                            req = {'id': req}
                                        if not isinstance(req, dict):
                                            err = ('Requisite declaration {0}'
                                                   ' in SLS {1} is not formed as a'
                                                   ' single key dictionary').format(
                                                req,
                                                body['__sls__'])
                                            errors.append(err)
                                            continue
                                        req_key = next(iter(req))
                                        req_val = req[req_key]
                                        if '.' in req_key:
                                            errors.append((
                                                'Invalid requisite type \'{0}\' '
                                                'in state \'{1}\', in SLS '
                                                '\'{2}\'. Requisite types must '
                                                'not contain dots, did you '
                                                'mean \'{3}\'?'.format(
                                                    req_key,
                                                    name,
                                                    body['__sls__'],
                                                    req_key[:req_key.find('.')]
                                                )
                                            ))
                                        if req_val.__hash__ is None:
                                            errors.append((
                                                'Illegal requisite "{0}", '
                                                'is SLS {1}\n'
                                            ).format(
                                                str(req_val),
                                                body['__sls__']))
                                            continue

                                        # Check for global recursive requisites
                                        reqs[name][req_val] = req_key
                                        # I am going beyond 80 chars on
                                        # purpose, this is just too much
                                        # of a pain to deal with otherwise
                                        if req_val in reqs:
                                            if name in reqs[req_val]:
                                                if reqs[req_val][name] == state:
                                                    if reqs[req_val]['state'] == reqs[name][req_val]:
                                                        err = ('A recursive '
                                                               'requisite was found, SLS '
                                                               '"{0}" ID "{1}" ID "{2}"'
                                                               ).format(
                                                            body['__sls__'],
                                                            name,
                                                            req_val
                                                        )
                                                        errors.append(err)
                                # Make sure that there is only one key in the
                                # dict
                                if len(list(arg)) != 1:
                                    errors.append(('Multiple dictionaries '
                                                   'defined in argument of state \'{0}\' in SLS'
                                                   ' \'{1}\'').format(
                                        name,
                                        body['__sls__']))
                    if not fun:
                        if state == 'require' or state == 'watch':
                            continue
                        errors.append(('No function declared in state \'{0}\' in'
                                       ' SLS \'{1}\'').format(state, body['__sls__']))
                    elif fun > 1:
                        errors.append(
                            'Too many functions declared in state \'{0}\' in '
                            'SLS \'{1}\''.format(state, body['__sls__'])
                        )
        return errors

    @staticmethod
    def order_chunks(chunks):
        """
        Sort the chunk list verifying that the chunks follow the order
        specified in the order options.
        """
        cap = 1
        for chunk in chunks:
            if 'order' in chunk:
                if not isinstance(chunk['order'], int):
                    continue
                chunk_order = chunk['order']
                if chunk_order > cap - 1 and chunk_order > 0:
                    cap = chunk_order + 100
        for chunk in chunks:
            if 'order' not in chunk:
                chunk['order'] = cap
                continue
            if not isinstance(chunk['order'], (int, float)):
                if chunk['order'] == 'last':
                    chunk['order'] = cap + 1000000
                elif chunk['order'] == 'first':
                    chunk['order'] = 0
                else:
                    chunk['order'] = cap
            if 'name_order' in chunk:
                chunk['order'] = chunk['order'] + chunk.pop('name_order') / 10000.0
            if chunk['order'] < 0:
                chunk['order'] = cap + 1000000 + chunk['order']
        chunks.sort(key=lambda ch: (ch['order'], ch['state'], ch['name'], ch['fun']))
        return chunks

    def compile_high_data(self, high):
        """
        "Compile" the high data as it is retrieved from the CLI or YAML into
        the individual state executor structures
        """
        chunks = []
        for name, body in high.items():
            if name.startswith('__'):
                continue
            for state, run in body.items():
                funcs = set()
                names = set()
                if state.startswith('__'):
                    continue
                chunk = {'state': state,
                         'name': name}
                if '__sls__' in body:
                    chunk['__sls__'] = body['__sls__']
                if '__env__' in body:
                    chunk['__env__'] = body['__env__']
                chunk['__id__'] = name
                for arg in run:
                    if isinstance(arg, basestring):
                        funcs.add(arg)
                        continue
                    if isinstance(arg, dict):
                        for key, val in arg.items():
                            if key == 'names':
                                names.update(val)
                            elif key == 'state':
                                # Don't pass down a state override
                                continue
                            elif key == 'name' and not isinstance(val, basestring):
                                # Invalid name, fall back to ID
                                chunk[key] = name
                            else:
                                chunk[key] = val
                if names:
                    name_order = 1
                    for entry in names:
                        live = copy.deepcopy(chunk)
                        if isinstance(entry, dict):
                            low_name = next(iter(entry))
                            live['name'] = low_name
                            list(map(live.update, entry[low_name]))
                        else:
                            live['name'] = entry
                        live['name_order'] = name_order
                        name_order = name_order + 1
                        for fun in funcs:
                            live['fun'] = fun
                            chunks.append(live)
                else:
                    live = copy.deepcopy(chunk)
                    for fun in funcs:
                        live['fun'] = fun
                        chunks.append(live)
        chunks = self.order_chunks(chunks)
        return chunks

    def render_lo(self):
        """
        Renders lo state from high.
        :rtype: tuple[Low,None|str]
        """
        high = self._high
        # Verify that the high data is structurally sound
        errors = self.verify_high(high)
        if errors:
            return None, errors
        # Compile and verify the raw chunks
        chunks = self.compile_high_data(high)
        # If there are extensions in the highstate, process them and update
        # the low data chunks
        if errors:
            return None, errors
        for chunk in chunks:
            rewrite_relative_src(chunk)
        return Low(chunks, self._rev, self._ts), None


class Low(object):
    """
    Compiled lo state: an ordered list of chunk dicts plus the revision and
    timestamp of the high data it came from. Serialised as the JSON object
    ``{"lo": [...], "rev": ..., "ts": ...}``.

    :type _lo: list[dict]
    """

    @classmethod
    def empty(cls):
        """Return a Low with no chunks, an empty revision and timestamp 0."""
        return cls(lo=[], rev='', ts=0)

    @classmethod
    def from_bytes(cls, buf):
        """
        Build a Low from JSON produced by ``to_bytes``.

        JSON objects are decoded as ``collections.OrderedDict`` so that chunk
        key order is preserved. Missing 'rev'/'ts' default to '' and 0; a
        missing 'lo' yields ``None`` states.

        :raises ValueError: if ``buf`` is not valid JSON
        :raises AttributeError: if the top-level JSON value is not an object
        """
        doc = json.loads(buf, object_pairs_hook=collections.OrderedDict)
        return cls(doc.get('lo'), doc.get('rev', ''), doc.get('ts', 0))

    @classmethod
    def from_file(cls, path):
        """
        Load a Low from ``path``, reading at most 10 MiB.

        :return: ``(Low, None)``, or ``(None, error_string)`` if the file
            cannot be read or parsed
        """
        try:
            with open(path, 'r') as stream:
                try:
                    payload = stream.read(10 * 1024 * 1024)
                    return cls.from_bytes(payload), None
                except (ValueError, AttributeError) as e:
                    return None, 'failed to load {}: {}'.format(path, e)
        except EnvironmentError as e:
            return None, 'failed to load {}: {}'.format(path, e)

    def __init__(self, lo, rev, ts):
        self._lo = lo
        self._rev = rev
        self._ts = ts

    def to_dict(self):
        """Return the JSON-serialisable representation."""
        return {
            'lo': self._lo,
            'rev': self._rev,
            'ts': self._ts,
        }

    def to_bytes(self):
        """Serialise ``to_dict()`` as a JSON string."""
        return json.dumps(self.to_dict())

    def to_file(self, path):
        """Atomically write ``to_bytes()`` to ``path``, creating parent dirs."""
        payload = self.to_bytes()
        return fileutil.atomic_write(path, payload, make_dirs=True)

    def get_rev(self):
        return self._rev

    def get_ts(self):
        return self._ts

    def states(self):
        return self._lo


class StoppableState(salt.state.State):
    """
    salt.state.State that can be stopped through a shared context.

    The original description calls the stop signal a filesystem flag; how it
    is set lives behind ``ctx`` and is not visible here. On every ``call``,
    ``ctx.done()`` is checked first. Once it is true, the chunk is not run
    and a failed result carrying ``ctx.error()`` is returned instead.
    Outcomes of executed chunks are reported through ``ctx.ok``/``ctx.fail``.
    """

    def __init__(self, ctx, *args, **kwargs):
        # Store the context before delegating to salt's initializer.
        self.ctx = ctx
        super(StoppableState, self).__init__(*args, **kwargs)

    def call(self, low, chunks=None, running=None):
        """
        Run one low chunk unless the context says to stop.

        A 'started' key (integer unix time taken before the chunk runs) is
        added to the returned result. Exceptions from the base ``call`` are
        reported via ``ctx.fail`` and re-raised.
        """
        started = int(time.time())
        if self.ctx.done():
            state_name = low['name']
            log.info("Skip '{}' state: {}".format(state_name, self.ctx.error()))
            return {
                'started': started,
                'result': False,
                'name': state_name,
                'changes': {},
                'comment': self.ctx.error(),
                '__sls__': low['__sls__'],
            }

        try:
            result = super(StoppableState, self).call(low, chunks=chunks, running=running)
        except Exception as exc:
            self.ctx.fail(low['name'], str(exc))
            raise

        state_name = result['name']
        if not result['result']:
            # 'comment' is deliberately not used as the error: it may be very
            # long and would pollute the logs.
            self.ctx.fail(state_name, state_name + ' failed')
        else:
            self.ctx.ok(state_name)
        result['started'] = started
        return result


class Executor(object):
    """
    Executes lo state.

    To execute lo we need a file client (to access salt:// files) and
    options, but State does not accept a file client. Salt generates one
    itself and stores it in an unusual place (``__context__``); see
    salt/modules/cp.py for details. That is why only ``opts`` is kept here.
    """

    def __init__(self, opts):
        # salt opts used to build a StoppableState on every execute() call
        self.opts = opts

    def execute(self, ctx, low):
        """
        Run every chunk of ``low`` through a StoppableState bound to ``ctx``.

        :param ctx: run context handed to StoppableState (provides
            ``done``/``error``/``ok``/``fail``)
        :param Low low: lo state to execute
        :return: ``(result, None)`` where ``result`` is what
            ``call_chunks`` returns, or ``(None, error_string)`` if building
            the state or running the chunks raised
        """
        try:
            # Build the state inside the try: StoppableState construction can
            # fail too and must honour the (None, err) contract.
            s = StoppableState(ctx, opts=self.opts)
            return s.call_chunks(low.states()), None
        except Exception as e:
            log.exception('failed to execute lo state')
            return None, str(e)


def get_lo_diff(a, b):
    """
    Describe the first difference between two lo state lists.

    'order' keys are ignored when comparing, since salt can number equal
    states differently. Only the first differing pair is reported: a state
    inserted at the front would otherwise make every later state differ.
    If the common prefix matches, the first surplus state of the longer list
    is reported instead. The caller's lists are not modified.

    :param list[dict] a: first lo state list
    :param list[dict] b: second lo state list
    :return: human-readable report, or None when the lists are equivalent
    :rtype: str|None
    """
    report = []
    if len(a) != len(b):
        report.append('len(a={}) != len(b={})'.format(len(a), len(b)))
        report.append('First diff state:')
    # Private copies: 'order' is popped from compared states below.
    a = copy.deepcopy(a)
    b = copy.deepcopy(b)
    for idx, (left, right) in enumerate(zip(a, b)):
        left.pop('order', None)
        right.pop('order', None)
        if left != right:
            report.append('a[{}]:\n {}'.format(idx, pprint.pformat(left)))
            report.append('b[{}]:\n {}'.format(idx, pprint.pformat(right)))
            break
    else:
        common = min(len(a), len(b))
        if len(a) > common:
            report.append('a[{}]:\n {}'.format(common, pprint.pformat(a[common])))
        elif len(b) > common:
            report.append('b[{}]:\n {}'.format(common, pprint.pformat(b[common])))
    return '\n'.join(report) if report else None


def mutate_lo_state_to_nop(lo_state):
    """
    Turn a lo state chunk into a ``test.nop`` chunk, in place.

    Only the bookkeeping keys ('order', '__sls__', '__env__', '__id__',
    'name') are kept, and the function is replaced by ``test.nop``. The
    chunk keeps its position and identity, but every other argument is
    dropped. Keys missing from ``lo_state`` are simply not copied, rather
    than raising KeyError. This covers '__env__' in particular, which
    ``High.compile_high_data`` only sets when the high data carries it.

    :param dict lo_state: chunk to rewrite; mutated in place
    """
    kept_keys = ('order', '__sls__', '__env__', '__id__', 'name')
    new_state = {key: lo_state[key] for key in kept_keys if key in lo_state}
    new_state['fun'] = 'nop'
    new_state['state'] = 'test'
    # Mutate rather than rebind so that callers holding the chunk (e.g. the
    # enclosing lo list) see the change.
    lo_state.clear()
    lo_state.update(new_state)
