"""Common utilities."""

import datetime
import errno
import functools
import itertools
import argparse
import os
import queue as Queue
import re
import sys
import threading
import time
import traceback
import json

from walle_api.client import WalleClientError, WalleAuthorizationError, WalleApiError

import pcli.text_table
import yaml

# When true, _add_parser() registers every subcommand under its short name (the explicit `short` alias or,
# if none is given, the first letter of the full name). Changed via set_short_mode().
_SHORT_MODE = False


class LogicalError(Exception):
    """Signals that execution reached a branch that must be unreachable."""

    def __init__(self):
        # The message is fixed: the error carries no context besides its traceback.
        super().__init__("Logical error.")


class Error(Exception):
    """Base exception whose message may be a ``str.format`` template.

    The first positional argument is the message. Any further positional or keyword arguments are
    substituted into it with ``str.format``; when there are none the message is used verbatim, so a
    message with literal braces is safe as long as no formatting arguments are passed.
    """

    def __init__(self, *args, **kwargs):
        template = args[0]
        positional = args[1:]

        if positional or kwargs:
            text = template.format(*positional, **kwargs)
        else:
            text = template

        super().__init__(text)


class NegativeUserAnswerError(Error):
    """Raised when the user declines an interactive confirmation."""

    def __init__(self):
        super().__init__("User answered 'No'.")


class TableView(object):
    """A simple abstraction for displaying data in table view.

    Columns are objects exposing an ``id`` attribute and a ``fields`` iterable (the object fields
    required to display the column).
    """

    def __init__(self, default_columns, other_columns=tuple()):
        """
        :param default_columns: columns shown when the user does not request specific ones
        :param other_columns: additional columns that may be requested explicitly
        """

        # Snapshot the defaults and chain the two iterables instead of using `+`, so that a list of
        # default columns can be combined with the default (tuple) value of `other_columns`.
        default_columns = tuple(default_columns)

        self.__default_columns = default_columns
        self.__columns = {column.id: column for column in itertools.chain(default_columns, other_columns)}

    def get_column_ids(self):
        """Returns the sorted list of all known column ids."""

        return sorted(self.__columns.keys())

    def get_column_fields(self, column_ids=None):
        """Returns the flat list of object fields required by the specified columns (all columns by default)."""

        if column_ids is None:
            column_ids = self.get_column_ids()

        return list(itertools.chain.from_iterable(self.__columns[column_id].fields for column_id in column_ids))

    def parse_column_list(self, column_ids_string):
        """Parses a user-specified comma-separated column list.

        :param column_ids_string: the list as a string, or None to use the default columns
        :returns: a ``(column_ids, fields)`` tuple
        :raises Error: if an unknown column is requested or the list is empty
        """

        if column_ids_string is None:
            column_ids = [column.id for column in self.__default_columns]
        else:
            column_ids = parse_arg_list(column_ids_string)

            # Report the first invalid name in the order the user typed it, so the message is deterministic
            invalid_column_ids = [column_id for column_id in column_ids if column_id not in self.__columns]
            if invalid_column_ids:
                raise Error("Invalid column name: {}.", invalid_column_ids[0])

            if not column_ids:
                raise Error("You must specify at least one column to output.")

        return column_ids, self.get_column_fields(column_ids)

    def render(self, object_name, objects, column_ids, batch, total=None):
        """Prints the objects to stdout.

        :param object_name: name of the listed objects, used only in the summary line
        :param objects: objects to display
        :param column_ids: ids of the columns to output
        :param batch: if true, print one space-separated line per object instead of a table
        :param total: if not None, a summary line with the listed and total counts is printed after the table
        """

        if batch:
            # Batch mode: missing values are shown as "-"
            for obj in objects:
                print(" ".join(str(obj.get(column_id, "-")) for column_id in column_ids))
        else:
            columns = [self.__columns[column_name] for column_name in column_ids]

            table = pcli.text_table.Table(columns)
            table.add_rows(objects)
            table.draw(placeholder="-")

            if total is not None:
                print("\n> Listed {count} {name}. Total: {total}.".format(
                    count=len(objects), name=object_name, total=total))


class Column(pcli.text_table.Column):
    """A table column that also records which object fields it needs."""

    def __init__(self, id, *args, **kwargs):
        # `fields` is our own keyword and is removed before the rest is handed to the base class.
        # When it is not given, the column is assumed to need only the field named after its id.
        if "fields" in kwargs:
            self.fields = kwargs.pop("fields")
        else:
            self.fields = (id,)

        super().__init__(id, *args, **kwargs)


def get_config_path(ensure_exists=False):
    """Returns the path of the per-user configuration directory (``~/.wall-e``).

    :param ensure_exists: if true, try to create the directory first
    :raises Error: if the directory could not be created for a reason other than it already existing
    """

    config_path = os.path.expanduser("~/.wall-e")
    if not ensure_exists:
        return config_path

    try:
        os.mkdir(config_path)
    except EnvironmentError as e:
        already_exists = e.errno == errno.EEXIST
        if not already_exists:
            raise Error("Unable to create '{}' directory: {}.", config_path, e.strerror)

    return config_path


def handle_error(error, arg=None):
    """Prints a description of the error to stderr.

    Errors that are not instances of Error, WalleClientError or (when that module is already loaded)
    skyresolver.ResolveError are treated as unexpected and are followed by their traceback.

    :param error: the exception to report
    :param arg: optional argument that was being processed; it is included in the message
        (not used for authorization errors)
    """

    if isinstance(error, WalleAuthorizationError):
        message = "{}\nPlease obtain a valid API access token by issuing `wall-e access-token obtain` command.".format(error)
    else:
        message = "API error: " if isinstance(error, WalleApiError) else "Error: "
        if arg is not None:
            message += "{}: ".format(arg)
        message += str(error)

    user_errors = (Error, WalleClientError)

    # Try to not load heavy modules when it's not needed to make CLI more responsive
    if "skyresolver" in sys.modules:
        user_errors += (sys.modules["skyresolver"].ResolveError,)

    print(message, file=sys.stderr)
    if not isinstance(error, user_errors):
        # Print the traceback of the error we were given rather than of whatever exception happens to be
        # handled at the moment: this is identical when called from the matching `except` block and still
        # meaningful when called outside of one.
        traceback.print_exception(type(error), error, error.__traceback__)


def register_subparsers(parser, name, help, short=None):
    """Registers a command that has mandatory subcommands and returns its subparsers object.

    The chosen subcommand name is stored in the ``subcommand`` attribute of the parsed arguments.
    """

    command_parser = _add_parser(parser, short, name, help)
    return command_parser.add_subparsers(required=True, dest="subcommand")


def register_parser(parent, name, handler, help, short=None, with_sudo=False, with_reason=False):
    """Registers a command parser whose ``handle`` default is the specified handler.

    :param with_sudo: add the ``--sudo`` flag
    :param with_reason: add the ``-m/--reason`` option
    :returns: the created parser
    """

    parser = _add_parser(parent, short, name, help)
    parser.set_defaults(handle=handler)

    # (enabled, flags, add_argument options) for every optional common argument, in registration order
    common_arguments = (
        (with_sudo, ("--sudo",),
         dict(default=None, action="store_true", help="use admin privileges to perform the request")),
        (with_reason, ("-m", "--reason"),
         dict(help="reason string to write to audit log")),
    )

    for enabled, flags, options in common_arguments:
        if enabled:
            parser.add_argument(*flags, **options)

    return parser


def parse_arg_list(arg, type=str):
    """Parses a comma-separated list specified by user via command line.

    :param arg: the raw string; None yields None and an empty string yields an empty list
    :param type: item type, either ``str`` or ``int``
    :raises Error: on an empty item or, for ``int``, on a non-integer item
    :raises LogicalError: if an unsupported type is requested
    """

    if arg is None:
        return None

    if not arg:
        return []

    parsed = []

    for raw_item in arg.split(","):
        value = raw_item.strip()
        if not value:
            raise Error("Invalid list specification: {}.", arg)

        if type is int:
            try:
                value = int(value)
            except ValueError:
                raise Error("Invalid argument: {}. It must be an integer.", value)
        elif type is not str:
            raise LogicalError()

        parsed.append(value)

    return parsed


def parse_limit(value):
    """Parses a ``$period:$limit`` string into a ``{"period": ..., "limit": ...}`` dict.

    The period is a number followed by one of ``s``, ``d``, ``m``, ``h`` and is returned as is; the limit
    is converted to int.

    :raises ValueError: (without a message) if the value is malformed
    """

    try:
        period, raw_limit = value.split(":")
        limit = int(raw_limit)
    except ValueError:
        # Covers both a wrong number of ':'-separated parts and a non-integer limit
        raise ValueError()

    if re.match(r'^\d+[sdmh]$', period) is None:
        raise ValueError()

    return dict(period=period, limit=limit)


def get_supported_time_formats():
    """
    Returns a string with a list of time formats supported by parse_time() intended to be inserted to argparse help
    messages.
    """

    # Note: a stray '$' after "$time $date" used to leak into the help text
    formats = ("'$number[mhd]|$time|$date|$time $date|$date $time' "
               "where $number[mhd] is for example 3d for 3 day period, "
               "$time is %H:%M|%H:%M:%S and $date is %d.%m.%Y|%Y.%m.%d")

    return formats.replace("%", "%%")  # argparse tries to expand '%'


def parse_time(time_string, future_period):
    """Parses time specified by user in various formats like time, date or a time period.

    A string containing "." or ":" is parsed as a time and/or date, anything else as a period relative
    to the current time.

    :param time_string: the user-specified string or None
    :param future_period: determines whether the period specifies time in the future or in the past.
    :returns: integer timestamp, or None if time_string is None
    """

    if time_string is None:
        return None

    is_absolute = "." in time_string or ":" in time_string
    if is_absolute:
        return _parse_time(time_string)

    offset = parse_period(time_string)
    if not future_period:
        offset = -offset

    return int(time.time()) + offset


def parse_period(period_string):
    """Parses a period like ``30s``, ``5m``, ``2h`` or ``3d`` into a number of seconds.

    The unit letter is case-insensitive. Seconds (``s``) are accepted so that the parser agrees with the
    format advertised by period_format().

    :param period_string: the user-specified string or None
    :returns: the period in seconds, or None if period_string is None
    :raises Error: if the string is not a valid period specification
    """

    if period_string is None:
        return None

    unit_seconds = {
        "s": 1,
        "m": 60,
        "h": 60 * 60,
        "d": 24 * 60 * 60,
    }

    try:
        period = period_string.lower()
        # At least one digit and the unit letter are required
        if len(period) < 2:
            raise ValueError

        amount, unit = int(period[:-1]), period[-1:]
        if unit not in unit_seconds:
            raise ValueError

        return amount * unit_seconds[unit]
    except ValueError:
        raise Error("Invalid period specification: {}.", period_string)


def period_format():
    """Returns a human readable description of the period format (for help messages)."""

    units = (("seconds", "s"), ("minutes", "m"), ("hours", "h"), ("days", "d"))
    return "|".join("${" + unit_name + "}" + suffix for unit_name, suffix in units)


def format_time(time):
    """Formats a timestamp as ``YYYY.MM.DD HH:MM:SS`` in local time."""

    moment = datetime.datetime.fromtimestamp(time)
    return moment.strftime("%Y.%m.%d %H:%M:%S")


def format_choices(choices):
    """Returns the choices converted to plain ``str`` objects, for use in parser messages."""

    return list(map(str, choices))


def print_object(name, obj, as_json=False):
    """Prints the object in human readable form as a single-key mapping ``{name: obj}``.

    :param as_json: render as JSON instead of YAML
    """

    document = {name: obj}

    if as_json:
        text = render_json(document)
    else:
        text = render_yaml(document)

    print(text)


def render_yaml(data):
    """Renders the data as block-style YAML (safe dumper) with trailing newlines stripped."""

    dumped = yaml.safe_dump(data, default_flow_style=False)
    return dumped.rstrip("\n")


def render_json(data):
    """Renders the data as JSON with sorted keys and a two-space indent."""

    dump_options = {"indent": 2, "sort_keys": True}
    return json.dumps(data, **dump_options)


def render_yaml_unsorted(data):
    """Renders the data as block-style YAML keeping the key order and allowing unicode output.

    Trailing newlines are stripped.
    """

    dumped = yaml.dump(data, default_flow_style=False, sort_keys=False, allow_unicode=True)
    return dumped.rstrip("\n")


def render_result(func=None, renderer=render_yaml):
    """Decorator for command handlers that prints the value returned by the handler.

    May be used bare (``@render_result``) or with arguments (``@render_result(renderer=...)``).
    The wrapped handler is called as ``func(client, args)``; its result is printed as JSON when
    ``args.json`` is true and with ``renderer`` otherwise. The wrapper itself returns nothing.
    """

    if func is None:
        # Called with arguments only: return a decorator bound to the chosen renderer
        return functools.partial(render_result, renderer=renderer)

    @functools.wraps(func)
    def wrapper(client, args):
        outcome = func(client, args)
        render = render_json if args.json else renderer
        print(render(outcome))

    return wrapper


def set_short_mode(short):
    """Sets the module-wide short mode flag consulted when subcommand parsers are registered."""

    global _SHORT_MODE

    _SHORT_MODE = short


def question_user(*args, **kwargs):
    """Asks the user a yes/no question and raises if the answer is negative.

    Positional arguments are ``cli_args, message`` followed by ``str.format`` arguments for the message;
    keyword arguments are passed to ``str.format`` as well. Nothing is asked when ``cli_args.yes`` is set
    or when stdin is not a terminal.

    :raises NegativeUserAnswerError: if the user answers "n"
    """

    cli_args, message = args[0], args[1]
    format_args = args[2:]

    # Assume "yes" if requested explicitly or if not attached to a terminal
    if cli_args.yes or not sys.stdin.isatty():
        return

    prompt = (message + " (y/n) ").format(*format_args, **kwargs)

    # Keep asking until we get exactly "y" or "n"
    while True:
        answer = input(prompt)
        if answer in ("y", "n"):
            break

    if answer == "n":
        raise NegativeUserAnswerError()


def process_args(args, func, concurrency=10):
    """Concurrently processes the specified arguments.

    Attention: Immediately interrupts by SystemExit leaving alive processing threads.

    :param args: sized, indexable collection of arguments
    :param func: callable invoked as ``func(arg)`` for every argument
    :param concurrency: maximum number of worker threads, must be at least 1
    :returns: True if no call raised an exception, False otherwise (errors are reported via handle_error())
    :raises Error: if the concurrency is invalid
    """

    if concurrency < 1:
        raise Error("Invalid concurrency: {0}.", concurrency)

    if len(args) == 0:
        return True

    # Serializes error reporting so that messages from different threads don't interleave
    lock = threading.Lock()

    def processor(arg):
        try:
            func(arg)
        except Exception as e:
            with lock:
                handle_error(e, arg=arg)
            return False

        return True

    # A single argument is processed in the calling thread
    if len(args) == 1:
        return processor(args[0])

    # Read by the workers through the closure; set on SystemExit to make them stop taking new arguments
    interrupted = False
    result = {"ok": True}

    queue = Queue.Queue()
    for arg in args:
        queue.put(arg)

    def worker(queue):
        while not interrupted:
            try:
                arg = queue.get(False)
            except Queue.Empty:
                break

            if not processor(arg):
                result["ok"] = False

    thread_pool = []

    try:
        try:
            for _ in range(min(queue.qsize(), concurrency)):
                # Daemon threads don't prevent the process from exiting. The flag is passed to the
                # constructor because Thread.setDaemon() is deprecated.
                thread = threading.Thread(target=worker, args=(queue,), daemon=True)
                thread_pool.append(thread)
                thread.start()
        except SystemExit:
            interrupted = True
            raise
        finally:
            if not interrupted:
                for thread in thread_pool:
                    while thread.is_alive():
                        # Any timeout must be specified to be able to receive UNIX signals in the main thread
                        thread.join(24 * 60 * 60)
    except SystemExit:
        # Raised while waiting for the workers
        interrupted = True
        raise

    return result["ok"]


def unique_list(values):
    """Returns the values as a list without duplicates, keeping the first occurrence of each value."""

    # dict keys are unique and keep insertion order
    return list(dict.fromkeys(values))


def filter_none(d):
    """Returns a copy of the dict without the items whose value is None."""

    kept = {}

    for key, value in d.items():
        if value is None:
            continue
        kept[key] = value

    return kept


def _add_parser(parent, short, name, help):
    """Adds a subparser to the parent subparsers object and returns it.

    In short mode the parser is registered under ``short`` or, if it is None, under the first letter of
    the name. Until a handler is set, the parser's ``handle`` default prints its help.

    :raises Error: if a subparser with the resulting name already exists
    """

    if _SHORT_MODE:
        if short is None:
            name = name[0]
        else:
            name = short

    if name in parent._name_parser_map:
        raise Error("Unable to add a new subparser with name '{}': a subparser with such name already exists.", name)

    parser = parent.add_parser(name, description=help, help=help)
    help_handler = parser_help(parser)
    parser.set_defaults(handle=help_handler)

    return parser


def parser_help(parser):
    """Returns a handler that ignores all its arguments and prints the help of the specified parser."""

    def _help(*_args, **_kwargs):
        parser.print_help()

    return _help


def _parse_time(time_string):
    orig_time_string = time_string

    # Reformat the string to "$time $date" format if it's "$date $time"
    time_string = " ".join(sorted(time_string.split(" "), key=lambda s: "." in s))

    dot_count = time_string.count(".")
    colon_count = time_string.count(":")

    time_format = ""
    has_date = False

    try:
        if colon_count:
            if colon_count == 1:
                time_format = "%H:%M"
            elif colon_count == 2:
                time_format = "%H:%M:%S"
            else:
                raise ValueError

        if dot_count:
            has_date = True
            if time_format:
                time_format += " "

            if dot_count != 2:
                raise ValueError

            if len(time_string.split(" ")[-1].split(".")[0]) == 4:
                time_format += "%Y.%m.%d"
            else:
                time_format += "%d.%m.%Y"

        if not time_format:
            raise ValueError

        time_datetime = datetime.datetime.strptime(time_string, time_format)
        if not has_date:
            time_datetime = datetime.datetime.combine(datetime.date.today(), time_datetime.time())

        return int(time.mktime(time_datetime.timetuple()))
    except ValueError:
        raise Error("Invalid time specification: {}.", orig_time_string)


def drop_none(d):
    """Returns a new dict with all the items of ``d`` except those whose value is None."""

    return dict(item for item in d.items() if item[1] is not None)


def split_comma_separated_values_in_dict(extra):
    """Replaces, in place, every string value containing a comma with the list of its comma-separated parts.

    :param extra: the dict to modify or None
    :returns: the same dict object, or None if ``extra`` is None
    """

    if extra is None:
        return None

    for key in extra:
        value = extra[key]
        if not isinstance(value, str):
            continue
        if ',' in value:
            extra[key] = value.split(',')

    return extra


def str2bool(v):
    """Converts a command line value to bool.

    None and bool values are returned unchanged. Strings are compared case-insensitively against
    yes/true/t/y/1 and no/false/f/n/0.

    :raises argparse.ArgumentTypeError: if the string is not a recognized boolean value
    """

    if v is None or isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False

    raise argparse.ArgumentTypeError('Boolean value expected.')
