import collections
import itertools
import logging
import time
from functools import partial
from google.protobuf import json_format
import yaml
import re

from infra.rtc.rebootctl.proto import reboot_pb2

from . import context
from . import job

log = logging.getLogger(__name__)


def get_host_list(spec, term):
    """Resolve a script's host selection into a list of host names.

    ``spec.select`` and ``spec.exclude`` each carry a ``selector`` oneof that is either
    ``sky`` (a sky expression string) or ``hosts`` (an explicit name list).

    :param spec: script spec with ``select`` and ``exclude`` messages.
    :param term: terminal helper, used by the sky resolver for progress output.
    :return: list of selected host names that are not excluded.
    :raises Exception: if ``select`` has no selector, or for the unsupported
        ``select:hosts`` / ``exclude:sky`` combination.
    """
    select = spec.select
    exclude = spec.exclude
    select_kind = select.WhichOneof("selector")
    exclude_kind = exclude.WhichOneof("selector")
    if select_kind == exclude_kind == 'sky':
        # Let the sky resolver compute the difference in a single query.
        return get_host_list_sky('[ {} ] - [ {} ]'.format(select.sky, exclude.sky), term)
    if select_kind not in ('sky', 'hosts'):
        # Previously this fell through and returned None, which crashed later on len().
        raise Exception('Script has no host selector (expected select.sky or select.hosts)')
    if select_kind == 'hosts' and exclude_kind == 'sky':
        raise Exception('Combination select:hosts - exclude:sky not allowed')
    # An unset exclude selector yields an empty host list here.
    # A set gives O(1) membership tests instead of scanning a list per candidate.
    exclude_hosts = set(get_host_list_hosts(exclude.hosts))
    if select_kind == 'sky':
        candidates = get_host_list_sky(select.sky, term)
    else:
        candidates = get_host_list_hosts(select.hosts)
    return [name for name in candidates if name not in exclude_hosts]


def get_host_list_sky(spec, term):
    """Expand a sky host expression into a list of host names.

    :param spec: sky expression string passed to the resolver.
    :param term: terminal helper used to announce the resolution.
    :return: list of resolved host names.
    """
    # Imported lazily: only needed when a sky selector is actually used.
    from library.sky import hostresolver

    resolver = hostresolver.Resolver(use_hq=True)
    term.info("Resolving host names {}...".format(spec))
    resolved = resolver.resolveHosts(spec)
    return list(resolved)


def get_host_list_hosts(spec):
    """Return the names from an explicit ``hosts`` selector that look like hostnames.

    A name is accepted when it:

    * consists of lowercase letters, digits, ``.`` and ``-``;
    * is at least three characters long;
    * starts and ends with a letter or digit;
    * optionally ends with a single FQDN dot after that.

    :param spec: message with a repeated ``list`` field of host names.
    :return: list of accepted names, in input order.
    """
    # The final dot is escaped and optional. An unescaped ``.`` would accept any trailing
    # character (``-``, whitespace, ...). ``\Z`` rejects a trailing newline, which ``$``
    # would let through.
    hostname_re = re.compile(r'^[a-z0-9][a-z0-9.-]+[a-z0-9]\.?\Z')
    accepted = []
    rejected = []
    for name in spec.list:
        if hostname_re.match(name):
            accepted.append(name)
        else:
            rejected.append(name)
    if rejected:
        # Make typos visible instead of silently dropping hosts from the selection.
        logging.getLogger(__name__).warning('Ignoring invalid host names: %s', ', '.join(rejected))
    return accepted


def get_location_filter(spec):
    """Build a Python expression string that tests ``host['location']`` against ``spec``.

    Each field set on ``spec`` (from ``spec.ListFields()``) becomes a clause
    ``host['location'][<name>] == <value>``, with the name and value converted via ``str()``.
    One clause is returned as is. Several clauses are wrapped in ``all((...))``. No clauses
    gives ``all(())``, which evaluates to True.

    :param spec: protobuf message whose set fields describe the wanted location.
    :return: expression string meant to be evaluated with a ``host`` variable in scope.
    """
    # ``{!r}`` emits properly escaped string literals. Hand-quoting with '...' produced invalid
    # (or silently different) source for values containing quotes or backslashes.
    clauses = [
        "host['location'][{!r}] == {!r}".format(str(field.name), str(value))
        for field, value in spec.ListFields()
    ]
    if len(clauses) == 1:
        return clauses[0]
    return 'all(({}))'.format(', '.join(clauses))


def make_location_filter(spec):
    """Return a predicate that checks a WALL-E host's location against ``spec``.

    Every field set on ``spec`` (from ``spec.ListFields()``) must equal
    ``host['location'][field_name]``. Both the field name and the expected value are
    converted with ``str()``. When no fields are set, every host matches.

    Values are compared directly rather than compiled into Python source and ``eval``-ed.
    Quotes or backslashes in a value therefore cannot break or alter the check.

    :param spec: protobuf message describing the wanted location (read once, at build time).
    :return: callable ``host -> bool``.
    """
    expected = [(str(field.name), str(value)) for field, value in spec.ListFields()]

    def location_matches(host):
        return all(host['location'][name] == value for name, value in expected)

    return location_matches


def location_filter_fm(spec):
    """Return a predicate that is true when a host's location matches every field set on ``spec``.

    ``spec.ListFields()`` is consulted each time the predicate runs, and values are compared
    without any conversion. With no fields set, the host is not inspected at all.
    """
    def matches(host):
        for field, expected in spec.ListFields():
            if not host['location'][field.name] == expected:
                return False
        return True

    return matches


def fetch_hosts_from_walle(w, names, fields=None, strict=True, batch_size=100):
    """Fetch WALL-E host records for ``names``, ``batch_size`` names per request.

    :param w:          WALL-E client; ``w.get_hosts(names=..., fields=...)`` must return a dict
                       whose ``'result'`` is a list of host dicts with a ``'name'`` key.
    :param names:      host names to fetch. An empty list issues no request.
    :param fields:     fields to request; defaults to
                       ``('location', 'name', 'health', 'status', 'project', 'inv', 'ticket')``.
    :param strict:     raise when a batch returns fewer hosts than requested; otherwise log it.
    :param batch_size: number of names per request (must be positive).
    :return:           list of host records, in response order.
    :raises Exception: in strict mode, when WALL-E returns fewer hosts than requested.
    :raises ValueError: when ``batch_size`` is not positive.
    """
    if fields is None:
        fields = ('location', 'name', 'health', 'status', 'project', 'inv', 'ticket')
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))
    names = list(names)
    rv = []
    # Walk exact, non-empty slices. Hand-maintained indices could send an extra request with
    # an empty name list after a short (non-strict) batch, and for empty input.
    for start in range(0, len(names), batch_size):
        chunk = names[start:start + batch_size]
        hosts = w.get_hosts(names=chunk, fields=fields)['result']
        if len(hosts) != len(chunk):
            missing = set(chunk).difference(h['name'] for h in hosts)
            message = 'We requested {} hosts and WALL-E returned only {}. Diff: {}'.format(
                len(chunk), len(hosts), str(missing)
            )
            if strict:
                raise Exception(message)
            log.error(message)
        rv.extend(hosts)
    return rv


def iter_scripts(path):
    """Yield a ``reboot_pb2.Script`` message for each YAML document in the file at ``path``.

    Documents are parsed with a safe YAML loader and converted via
    ``json_format.ParseDict``. Empty documents (e.g. produced by a trailing ``---``) are
    skipped.

    :param path: filesystem path of the YAML scripts file.
    """
    # CSafeLoader only exists when PyYAML is built against libyaml; otherwise fall back to the
    # pure-Python SafeLoader (same safe semantics, slower).
    loader = getattr(yaml, 'CSafeLoader', yaml.SafeLoader)
    with open(path) as f:
        for d in yaml.load_all(f, Loader=loader):
            if d is None:
                continue
            m = reboot_pb2.Script()
            json_format.ParseDict(d, m)
            yield m


def status_color(status):
    """Pick the terminal color for a WALL-E host status: blue while rebooting,
    green when ready, red for anything else."""
    palette = dict(rebooting='blue', ready='green')
    return palette[status] if status in palette else 'red'


def get_task_and_options(spec):
    """Return the task class and its options selected by the ``task`` oneof of ``spec``.

    :param spec: script spec with a ``task`` oneof.
    :return: ``(task_class, task_options)``, where the options are the oneof's set member.
    :raises Exception: if no task is set, or ``job.get_task_class`` does not know the task.
    """
    task_name = spec.WhichOneof("task")
    if task_name is None:
        # getattr(spec, None) would otherwise fail with an unrelated TypeError.
        raise Exception('No task specified in script spec')
    task_class = job.get_task_class(task_name)
    if task_class is None:
        raise Exception('Unknown task: {}'.format(task_name))
    return task_class, getattr(spec, task_name)


def initiate_walle_actions(w, names, reason, spec, keep_going, term):
    """Request the script's WALL-E action for each host in ``names`` and wait for the requests.

    One task object per host is built from ``get_task_and_options(spec)``. The tasks are
    started with ``job.Job().run`` under a cancellable context. Their conditions are
    summarised on the terminal every 5 seconds until every task reports a terminal status.
    Ctrl+C cancels the context, and the loop keeps polling until the tasks wind down.

    :param w:          WALL-E client handed to each task.
    :param names:      host names to act on.
    :param reason:     reason string handed to each task.
    :param spec:       script spec selecting the task type and its options.
    :param keep_going: when false and any request failed, ask the user whether to continue.
    :param term:       terminal helper used for output and the yes/no prompt.
    :return: True when the caller should stop. That is either because the context is done
             after the loop (e.g. cancelled with Ctrl+C), or because some requests failed and
             the user declined to continue. False otherwise.
             NOTE(review): confirm whether ctx.done() can become true without cancellation.
    """
    # Suppose we started rebooting a thousand hosts
    # and decided to cancel it.
    # First: in order to send 1000 requests - we do them in threads in parallel.
    # Second: we cannot kill threads, we cannot wait on mutex, because Ctrl+C (cancellation)
    # won't be processed.
    # Third: thread pool's shutdown only puts None in queue.
    # So we use this clever ugly trick - make our tasks cancellable on their own,
    # and cancel them upon KeyboardInterrupt.
    # There must be a better way to achieve this, but atm we have more pressing issues.
    tasks = []
    task_type, options = get_task_and_options(spec)
    for hostname in names:
        t = task_type(w, hostname, reason, options)
        tasks.append(t)
    term.info("Sending {} requests...".format(task_type.NAME))
    time.sleep(0.5)
    j = job.Job()
    ctx, cancel_func = context.Ctx.with_cancel()
    j.run(ctx, tasks)
    statuses = collections.defaultdict(int)
    # Fire away reboot requests
    while 1:
        try:
            statuses.clear()
            done = True
            for t in tasks:
                # Read the status once per poll so the printed line, the per-condition
                # counter and the terminal check all describe the same state.
                st = t.status()
                term.info(st)
                statuses[st.condition] += 1
                done &= st.is_terminal()
            term.clr()
            if ctx.done():
                if done:
                    term.info("=" * 5, "Shutdown", "=" * 5)
                else:
                    term.info("=" * 5, "Shutting down", "=" * 5)
            else:
                term.info("=" * 5, "Requesting actions", "=" * 5)
            for name, count in sorted(statuses.items(), key=lambda x: x[0]):
                term.info(name, '\t', count)
            if done:
                break
            time.sleep(5)
        except KeyboardInterrupt:
            # No break here: keep looping until the cancelled tasks report terminal
            # statuses, i.e. the queue is drained and the thread pool has shut down.
            term.info("\nCancelling tasks...")
            cancel_func("cancelled by user")
    if ctx.done():
        return True
    has_err = False
    for t in tasks:
        err = t.status().error()
        if err:
            term.info("Failed to reboot {}: {}".format(t.h, err))
        has_err |= bool(err)
    if has_err and not keep_going:
        return not term.read_yes_no("We had errors, continue?")
    return False


def get_done_hosts(w, names, done_hosts, term, task_id):
    """Poll WALL-E once and record which of ``names`` reached the task's target status.

    The target status is ``'manual'`` for the ``power_off`` task and ``'ready'`` otherwise.
    Newly finished host names are added to ``done_hosts`` in place (and logged). The
    terminal is cleared and redrawn with one colored ``name [status]`` entry per host.

    :return: ``(all_done, done_hosts)``. ``all_done`` is true when every host WALL-E returned
             has the target status. NOTE(review): hosts missing from the response do not
             affect ``all_done``.
    """
    target_status = 'manual' if task_id == 'power_off' else 'ready'
    polled = get_hosts_from_walle(w, names, term)
    term.clr()
    entries = []
    all_done = True
    for host in polled:
        reached = host['status'] == target_status
        if reached and host['name'] not in done_hosts:
            log.info("host %s become ready", host)
            done_hosts.add(host['name'])
        all_done &= reached
        colored = term.colorize(host['status'], status_color(host['status']))
        entries.append('{} [{}]'.format(host['name'], colored))
    log_hosts_statuses(entries, term)
    return all_done, done_hosts


def get_hosts_from_walle(w, names, term):
    """Return the name, health status and status of ``names`` from WALL-E.

    Any exception raised while listing is reported on the terminal, followed by a
    10 second pause, and the request is retried indefinitely.
    """
    wanted_fields = ('name', 'health.status', 'status')
    while True:
        try:
            polled = w.iter_hosts(names=names, fields=wanted_fields)
            return list(polled)
        except Exception as exc:
            term.info("Failed to poll host statuses: {}".format(exc))
            term.info("Will retry...")
            time.sleep(10)


def log_hosts_statuses(bufs, term):
    """Write status strings to the terminal buffer, three per row, each followed by a tab.

    The buffer returned by ``term.get_buf()`` is flushed after every row of three.
    """
    out = term.get_buf()
    per_row = 3  # Assume screen has 3 columns
    for row_start in range(0, len(bufs), per_row):
        for entry in bufs[row_start:row_start + per_row]:
            out.write(entry)
            out.write('\t')
        out.flush()


def uniq_err(names):
    """Describe repeated host names, or return None if every name is unique.

    Each repeat after the first occurrence is reported with its 0-based position.
    """
    seen = set()
    duplicates = []
    for position, name in enumerate(names):
        if name in seen:
            duplicates.append((name, position))
            continue
        seen.add(name)
    if not duplicates:
        return None
    lines = ['* {} at pos {}'.format(name, position) for name, position in duplicates]
    return 'duplicate hosts names:\n' + '\n'.join(lines)


def hosts_portion_g(hosts, portion=None, boundary=None):
    """
    Yield ``(group, key)`` pairs of WALL-E hosts grouped by location and split into portions.

    :param hosts:     iterable of WALL-E host dicts; when ``boundary`` is non-empty each host
                      needs a ``host['location']`` mapping containing every field in it.
    :param portion:   maximum number of hosts per yielded group. ``None``, ``0`` or a negative
                      value yields every location group whole.
    :param boundary:  list of ``host['location']`` keys by which hosts are grouped.
    :return:          generator of ``(list_of_hosts, key)``. ``key`` is the list of location
                      values shared by the group. It is ``None`` when ``boundary`` is ``None``
                      or ``()``, and ``[]`` for an empty list.

    Groups are ordered by key; within a group hosts keep their input order. When ``portion``
    is set, a last portion smaller than ``portion`` is yielded as is. Groups are
    materialized lists, so they remain valid after the generator advances.
    """
    hosts = list(hosts)
    if not hosts:
        return

    if boundary is None or boundary == ():
        # No grouping: one group in input order. Skip sorting, since comparing None keys
        # fails on Python 3.
        groups = [(None, hosts)]
    else:
        def group_key(host):
            return [host['location'][field] for field in boundary]

        # groupby only merges adjacent items, so sort by the same key first (sort is stable).
        groups = [(key, list(members))
                  for key, members in itertools.groupby(sorted(hosts, key=group_key), group_key)]

    for key, members in groups:
        if portion is None or portion <= 0:
            yield members, key
            continue
        for start in range(0, len(members), portion):
            yield members[start:start + portion], key


def _describe_host_groups(hosts, spec, term):
    """Print the planned host groups (by ``spec.boundary`` / ``spec.portion``) before confirmation."""
    boundary = list(spec.boundary)
    for group, attrs in hosts_portion_g(hosts, portion=spec.portion, boundary=boundary):
        msg = []
        # Only describe the grouping when boundary fields are configured. The former
        # ``list(...) != ()`` test was always true and printed an empty "filter:" line.
        if boundary:
            msg.append("filter: {}".format(
                ", ".join("{}={}".format(k, v) for k, v in zip(boundary, attrs))))
        if spec.portion != 0:
            msg.append("portion size: {}".format(spec.portion))
        if msg:
            term.info("\n  {}".format("; ".join(msg)))
        for h in group:
            term.info('- {} (prj={} switch={} queue={})'.format(
                h['name'],
                h['project'],
                h['location']['switch'],
                h['location']['queue'],
            ))


def _process_host_groups(w, hosts, spec, task, reason, keep_going, done_hosts, term):
    """Start the task on each host group and poll WALL-E until the group is done.

    With ``spec.processing_window == 0`` a whole group is started at once. Otherwise at most
    ``processing_window`` hosts of the group are in flight at a time. When
    ``initiate_walle_actions`` returns True (cancelled, or the user declined to continue
    after errors), the remaining groups of this script are skipped. ``done_hosts`` is
    updated in place.
    """
    max_active = spec.processing_window
    for group, _attrs in hosts_portion_g(hosts,
                                         portion=spec.portion,
                                         boundary=list(spec.boundary)):
        names = [h['name'] for h in group]
        if max_active == 0:
            if initiate_walle_actions(w, names, reason, spec, keep_going, term):
                return
            active = names
            pending = []
        else:
            active = []
            pending = names
        while True:
            free = max_active - len(active)
            # Only start requests when there is something left to start; an empty call would
            # still print, sleep and spin up a job.
            if free > 0 and pending:
                starting = pending[:free]
                del pending[:free]
                active.extend(starting)
                if initiate_walle_actions(w, starting, reason, spec, keep_going, term):
                    return
            all_ready, done_hosts = get_done_hosts(w, active, done_hosts, term, task_id=task.id)
            active = [name for name in active if name not in done_hosts]
            if all_ready and not pending:
                break
            time.sleep(10)
        if not keep_going and not term.read_yes_no("Continue?"):
            return


def run(w, path, batch, keep_going, strict, term):
    """Execute every routine script found in the YAML file at ``path``.

    For each script, in order:

    1. resolve its host list and fetch those hosts from WALL-E;
    2. narrow them by the location limit and by the task's own filter;
    3. print the plan and obtain a reason and a confirmation;
    4. drive the task group by group.

    :param w:          WALL-E client.
    :param path:       path to a YAML file with one ``Script`` per YAML document.
    :param batch:      non-interactive mode: when true, do not ask "Start <TASK>?".
    :param keep_going: when true, skip the "Continue?" prompts between groups and after
                       failed requests.
    :param strict:     passed to ``fetch_hosts_from_walle``.
    :param term:       terminal helper used for output and prompts.
    """
    start_time = time.time()
    done_hosts = set()
    term.info("Reading routine scripts from '{}'...".format(path))
    for s in iter_scripts(path):
        task, _task_opts = get_task_and_options(s.spec)
        task_name = task.__name__.upper()
        term.info("-" * 10)
        term.info("Processing script '{}'".format(s.meta.name))
        names = get_host_list(s.spec, term)
        term.info("Found {} hosts...".format(len(names)))
        term.info("Fetching hosts from WALL-E which comply for {}...".format(task_name))
        hosts = fetch_hosts_from_walle(w, names, strict=strict)
        if s.spec.limit.IsInitialized():
            term.info("Limit hosts by GEO location")
            location_filter = make_location_filter(s.spec.limit)
            hosts = [h for h in hosts if location_filter(h)]
        if s.spec.ticket_key and task.id == "power_on":
            task_filter = partial(task.filter, ticket_key=s.spec.ticket_key)
        else:
            task_filter = task.filter
        hosts = [h for h in hosts if task_filter(h)]
        term.info("Filter hosts by task")
        if not hosts:
            term.info("No hosts found for {}...".format(task_name))
            continue
        term.info("Going to {} {} hosts:".format(task_name, len(hosts)))
        _describe_host_groups(hosts, s.spec, term)
        reason = s.spec.reason
        term.info('')
        if not reason:
            reason = term.read_line('Enter reason:')
        else:
            term.info("Reason: {}".format(reason))
        # ``batch`` is never rebound below: the per-group name lists used to shadow it,
        # so only the first script honoured non-interactive mode.
        ans = True if batch else term.read_yes_no("Start {}?".format(task_name))
        if not ans:
            continue
        _process_host_groups(w, hosts, s.spec, task, reason, keep_going, done_hosts, term)
    term.info("Done in %s seconds" % (time.time() - start_time))
