from __future__ import absolute_import, unicode_literals, print_function

import re
import string
import numbers
import traceback
import datetime as dt
import itertools as it
from docutils import core as docutils_core

import six

from .. import encoding
from .. import config as common_config
from .. import itertools as common_itertools


# noinspection PyShadowingBuiltins
zip, range, map = six.moves.zip, six.moves.range, six.moves.map


def print_table(data, printer=print):
    """
    Print the data as pretty-printed table.
    Example:

    .. code-block:: python

        >>> data = [
        ...     ["X", "0", "-"],
        ...     None,
        ...     ["0", "X", "-"],
        ...     None,
        ...     ["-", "-", "X"],
        ... ]
        >>> print_table(data)
        +---+---+---+
        | X | 0 | - |
        +---+---+---+
        | 0 | X | - |
        +---+---+---+
        | - | - | X |
        +---+---+---+

    :param data:    Array of table rows of strings. `None` (or empty) row means horizontal line.
                    A `None` cell (other than the first one of a row) is printed blank and without
                    the column separator in front of it. Rows shorter than the widest row are
                    padded with empty cells. The rows passed in are not modified.
    :param printer: Print function to be used to output the actual content.
    """

    escape_seq_re = re.compile(r"\x1b\[.*?m")

    def visible_len(cell):
        # Terminal escape sequences (colours etc.) occupy no screen space, so don't count them
        return len(escape_seq_re.sub("", cell)) if cell else 0

    # Build a private normalized copy instead of converting the caller's rows in place
    # (which also used to fail for tuple rows). `None` cells stay `None`: they have special rendering.
    rows = []
    for row in data:
        if not row:
            rows.append(None)  # horizontal line
            continue
        rows.append([
            None if cell is None else six.text_type(encoding.force_unicode_safe(cell))
            for cell in row
        ])

    # The first row may be a separator, so take the width from the real rows
    cols = max([len(row) for row in rows if row is not None] or [0])
    if not cols:
        return
    for row in rows:
        if row is not None:
            row.extend([""] * (cols - len(row)))

    lens = [max(visible_len(row[i]) for row in rows if row is not None) for i in range(cols)]
    hr = "+%s+" % "+".join("-" * (l + 2) for l in lens)

    for row in [None] + rows + [None]:
        if row is None:
            printer(hr)
            continue
        parts = [
            " ".join([
                " " if i and cell is None else "|",
                cell or "",
                " " * (width - visible_len(cell)),
            ])
            for i, (cell, width) in enumerate(zip(row, lens))
        ]
        parts.append("|")
        printer("".join(parts))


SIZE_SUFFIXES = ("B", "K", "M", "G", "T")


def str2size(value):
    """
    Convert a human-readable data size string into the number of bytes.
    Supports suffixes B, K, M, G, T (case-insensitive) as powers of 1024; a string without
    a suffix must be a plain integer. Surrounding whitespace is ignored, and numbers are
    simply passed through `int()`.
    Examples:

    .. code-block:: python

        >>> str2size("15K")
        15360
        >>> str2size("120G")
        128849018880

    :raises ValueError: if the value cannot be parsed (including an empty string).
    """

    if isinstance(value, numbers.Number):
        return int(value)

    text = encoding.force_unicode(value).strip()
    if not text:
        # Previously an IndexError leaked out of `text[-1]`; callers expect ValueError
        raise ValueError("Empty size string: {!r}".format(value))

    try:
        exponent = SIZE_SUFFIXES.index(text[-1].upper())
    except ValueError:
        # No known suffix: the whole string must be an integer number of bytes
        return int(text)
    return int(float(text[:-1]) * (1 << (10 * exponent)))


###############################################################################
# A number of functions copied from :mod:`kernel.util.misc`
###############################################################################

def size2str(size, till=None):
    """
    Size to string formatter.

    :param size:    Number of bytes (anything accepted by `float()`).
    :param till:    Optional unit to force, one of " KMGTPEZY" (case-insensitive). Without it the
                    largest unit keeping the absolute value below 1024 is used (capped at "Y").
    :return:        String like "512.00byte(s)" or "1.50KiB".
    :raises ValueError: if `till` is not a single known unit character.
    """
    mods = " KMGTPEZY"
    size = float(size)
    if till:
        till = encoding.force_unicode(till).upper()
        # Explicit check instead of `assert`: asserts vanish under `python -O`, and an unknown unit
        # would then make the loop below fall through and silently return None.
        if len(till) != 1 or till not in mods:
            raise ValueError("Invalid size unit: {!r}".format(till))
    for mod in mods:
        if till:
            matched = mod == till
        else:
            matched = abs(size) < 0x400 or mod == mods[-1]
        if matched:
            if mod == " ":
                return str("%.2fbyte(s)" % (size,))
            return str("%.2f%siB" % (size, mod))
        size /= 0x400


TIME_FORMAT = "%Y-%m-%d %H:%M:%S"


def dt2str(x):
    """ Render a datetime using :data:`TIME_FORMAT`; `None` is rendered as the string "None". """
    if x is None:
        return str("None")
    return str(x.strftime(TIME_FORMAT))


def str2dt(s):
    """ Inverse of :func:`dt2str`: parse a :data:`TIME_FORMAT` string, mapping "None" back to `None`. """
    if s == "None":
        return None
    return dt.datetime.strptime(s, TIME_FORMAT)


def utcdt2iso(v=None):
    """
    Format a UTC moment as an ISO 8601 string with a "Z" suffix.

    :param v:   A naive datetime assumed to already be in UTC, or a timezone-aware datetime
                (converted to UTC first). Defaults to the current UTC time.
    :return:    String such as "2020-01-02T03:04:05Z".
    """
    if v is None:
        v = dt.datetime.utcnow()
    elif isinstance(v, dt.datetime) and v.utcoffset() is not None:
        # An aware datetime's isoformat() already ends with "+HH:MM"; appending "Z" to that
        # would produce an invalid timestamp, so shift to UTC and drop the tzinfo first.
        v = (v - v.utcoffset()).replace(tzinfo=None)
    return str(v.isoformat() + "Z")


def td2str(delta, full=False):
    """
    Timedelta to string formatter.

    :param delta:   `datetime.timedelta` or a number of seconds (float or integer).
    :param full:    If false, use compact "d"/"h"/"m"/"s" suffixes; once the first non-zero
                    component is printed every following one is printed too, and seconds are
                    always printed (e.g. "01h 00m 05s", "00s"). If true, use "day(s)"/"hour(s)"/...
                    and print only non-zero components (e.g. "01 hour(s) 05 second(s)");
                    a zero duration is rendered as "00 second(s)".
    :return:        The formatted string; sub-second precision is dropped.
    """
    if isinstance(delta, (float, six.integer_types)):
        delta = dt.timedelta(seconds=delta)
    hrs, secs = divmod(delta.seconds, 3600)
    mins, secs = divmod(secs, 60)

    if full:
        suffixes = [" %s(s)" % unit for unit in ("day", "hour", "minute", "second")]
    else:
        suffixes = list("dhms")

    parts = []
    for suffix, value in zip(suffixes, (delta.days, hrs, mins, secs)):
        if value or (not full and (parts or suffix == "s")):
            parts.append("{0:0>2d}{1:s}".format(value, suffix))
    if not parts:
        # Only reachable with full=True and a zero duration: every component was skipped as zero,
        # and an empty string would be inconsistent with the compact mode's "00s".
        parts.append("{0:0>2d}{1:s}".format(0, suffixes[-1]))
    return str(" ".join(parts))


def rst2html(txt):
    """
    Render reStructuredText into an HTML fragment.

    Rendering is controlled by the `server.web.rst2html` registry setting; when it is off the
    input is only converted to text and returned as is.

    :param txt: reStructuredText source.
    :return:    The "html_body" part from docutils' HTML writer, or `txt` as text.
    """
    rendering_enabled = common_config.Registry().server.web.rst2html
    if not rendering_enabled:
        # NOTE(review): the text is returned without HTML escaping here — confirm callers
        # don't embed it into HTML pages verbatim.
        return six.ensure_text(txt)
    parts = docutils_core.publish_parts(txt, writer_name='html')
    return parts['html_body']


def format_exception():
    """
    Return the traceback of the exception currently being handled as a unicode string.

    Built on :func:`traceback.format_exc`, so it is meant to be called from inside an `except` block.
    """
    return encoding.force_unicode_safe(traceback.format_exc())


# Character classes used by the `ident` state machine below, as unicode strings.
_DIGITS = encoding.force_unicode(string.digits)
_UPPERS = encoding.force_unicode(string.ascii_uppercase)
_LOWERS = encoding.force_unicode(string.ascii_lowercase)
# Transition table of the `ident` state machine: _IDENT_TABLE[state][char] -> (output, next_state).
# Every emitted letter is upper-case; an "_" is prepended to a character that starts a new word.
# A character absent from the current state's dict is not accepted (`ident` reports it as invalid).
_IDENT_TABLE = {
    # Start of input: only "_" or a letter may come first (no leading digit).
    0: dict(
        (c, (c.upper(), 1)) for c in it.chain("_", _UPPERS, _LOWERS)
    ),
    # After the first character or after "_": nothing is prepended.
    1: dict(it.chain(
        (("_", ("_", 1)),),
        ((c, (c, 2)) for c in _DIGITS),
        ((c, (c, 3)) for c in _UPPERS),
        ((c, (c.upper(), 4)) for c in _LOWERS)
    )),
    # After a digit: an upper-case letter starts a new word.
    2: dict(it.chain(
        (("_", ("_", 1)),),
        ((c, (c, 2)) for c in _DIGITS),
        ((c, ("_" + c, 3)) for c in _UPPERS),
        ((c, (c.upper(), 4)) for c in _LOWERS)
    )),
    # After an upper-case letter: a digit starts a new word; further capitals (acronyms)
    # and lower-case letters continue the current one ("ADangerous" -> "ADANGEROUS").
    3: dict(it.chain(
        (("_", ("_", 1)),),
        ((c, ("_" + c, 2)) for c in _DIGITS),
        ((c, (c, 3)) for c in _UPPERS),
        ((c, (c.upper(), 4)) for c in _LOWERS)
    )),
    # After a lower-case letter: a digit or an upper-case letter starts a new word.
    4: dict(it.chain(
        (("_", ("_", 1)),),
        ((c, (c.upper(), 4)) for c in _LOWERS),
        ((c, ("_" + c.upper(), 2)) for c in _DIGITS),
        ((c, ("_" + c.upper(), 3)) for c in _UPPERS)
    ))
}


def ident(name):
    """
    Helper for transforming CamelCase into UPPER_SNAKE_CASE.
    Commonly used for generating type names:

    .. code-block:: python

        >>> ident("CamelCase")
        'CAMEL_CASE'
        >>> ident("ADangerousExample")
        'ADANGEROUS_EXAMPLE'
        >>> ident("CamelCase2SnakeCase4ALL")
        'CAMEL_CASE_2_SNAKE_CASE_4_ALL'

    :raises ValueError: if the name contains a character not allowed by `_IDENT_TABLE`
                        (e.g. starts with a digit or contains punctuation).
    """

    text = encoding.force_unicode(name)
    if text.isupper():
        # NOTE(review): already upper-case names are returned verbatim, bypassing the
        # character validation done below.
        return str(text)

    pieces = []
    state = 0
    for char in text:
        try:
            piece, state = _IDENT_TABLE[state][char]
        except KeyError:
            raise ValueError("Invalid identifier: {!r}".format(text))
        pieces.append(piece)
    return str("".join(pieces))


def obfuscate_token(token):
    """
    Keep only the leading characters of a secret; used across Sandbox to hide secrets from logs.

    Text and byte strings of 8 or more characters are cut to their first 8 characters.
    Shorter strings and non-string values are returned unchanged.
    """
    keep = 8
    if isinstance(token, (six.text_type, six.binary_type)) and len(token) >= keep:
        return token[:keep]
    return token


def brace_expansion(argv, join=None):
    """
    This is a fixed copy of the `library.sky.hosts.braceExpansion` function, a simple bash-like
    brace expansion. It supports numeric ranges (`{1..5}`, zero-padded `{01..10}`) and
    comma-separated lists (`{a,b,c}`). It does not support nested braces but is pretty useful
    for expanding a hosts list packed by a previously executed command.

    :param argv:    A list of expressions to be processed; each one is also split on whitespace.
    :param join:    Perform result concatenation into a single string joined by specified symbol(s).
    :return:        Expanded list or string depending on `join` argument.
    """
    r = re.compile(r'''
        ([^{]*)
        {
            (?:(?:((?:0*)(\d+))\.\.((?:0*)(\d+)))
            |
            ([^{}\s]+))
        }
        ([^{}]*)
    ''', re.VERBOSE)
    hosts = []
    for arg in it.chain(*(encoding.force_unicode(s).split() for s in argv)):
        pos = 0
        expanded = None  # stays None until the first brace group has been expanded
        while pos < len(arg):
            m = r.match(arg, pos)
            if not m:
                break
            prev = pos
            pos = m.end()
            # Chunks are built eagerly from this match's groups: with a lazy `map` (Python 3)
            # they used to be evaluated later against the NEXT match, producing wrong results.
            prefix, suffix = m.group(1), m.group(7)
            if m.group(3) and m.group(5):  # Range
                bounds = [m.group(2), m.group(4)]
                # Zero-pad to the longest bound only when a bound is written with leading zeros
                padded = any(len(s) > 1 and s[0] == '0' for s in bounds)
                width = (max if padded else min)(len(s) for s in bounds)
                chunk = [
                    '{0}{1:0{len}d}{2}'.format(prefix, x, suffix, len=width)
                    for x in range(int(m.group(3)), int(m.group(5)) + 1)
                ]
            elif m.group(6):  # List
                chunk = [''.join([prefix, x, suffix]) for x in m.group(6).split(',')]
            else:
                chunk = [arg[prev:pos]]
            expanded = chunk if expanded is None else [x + s for s in chunk for x in expanded]
        if expanded and pos < len(arg):
            # Keep unmatched trailing text (e.g. a stray "}") instead of silently dropping it
            expanded = [x + arg[pos:] for x in expanded]
        hosts += expanded if expanded else [arg]

    return [str(h) for h in hosts] if join is None else str(join.join(hosts))


def suggest_color(octets):
    """
    Derive a stable HTML colour ("#RRGGBB") from the given data.

    The colour is a 24-bit CRC of the input using the OpenPGP CRC-24 parameters
    (RFC 4880: initial value 0xB704CE, polynomial 0x1864CFB), so equal inputs always get equal colours.

    :param octets:  Iterable of single characters (e.g. a string) or of integers
                    (e.g. `bytes` on Python 3, or a `bytearray`).
    :return:        Colour string such as "#21CF02".
    """
    crc = 0xB704CE
    poly = 0x1864CFB
    for octet in octets:
        # Iterating bytes/bytearray yields ints, iterating a string yields 1-character strings
        code = octet if isinstance(octet, numbers.Integral) else ord(octet)
        crc ^= code << 16
        for _ in range(8):
            crc <<= 1
            if crc & 0x1000000:
                crc ^= poly
    return "#{:0>6X}".format(crc & 0xFFFFFF)
