# coding: utf-8
import ctypes
import re
import itertools
import functools
import collections
import six

from awacs.lib import OrderedDict
from .tree import Node
from .defs import FUNCS, _call_func_providers


# Bare Lua names: ASCII letters, digits and underscores, not starting with a digit.
# ASCII classes and ``\Z`` are used deliberately: on Python 3 ``\w`` also matches non-ASCII
# letters, and ``$`` also matches right before a trailing newline, so e.g. 'abc\n' would
# otherwise be treated as a bare identifier.
LUA_ID_REGEXP = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*\Z')
# Reserved words cannot be used as bare table keys; ``goto`` is reserved since Lua 5.2.
LUA_KEYWORDS = frozenset(['and', 'break', 'do', 'else', 'elseif', 'end',
                          'false', 'for', 'function', 'goto', 'if', 'in', 'local',
                          'nil', 'not', 'or', 'repeat', 'return', 'then',
                          'true', 'until', 'while'])


def escape_lua_key(key):
    """
    Render a string table key for a Lua table constructor.

    Keys matching ``LUA_ID_REGEXP`` that are not Lua reserved words are returned
    unchanged; anything else is wrapped in the ``["..."]`` bracket form.
    """
    is_bare_name = bool(LUA_ID_REGEXP.match(key)) and key not in LUA_KEYWORDS
    if is_bare_name:
        return key
    return '["{}"]'.format(key)


def call_to_lua_key(call):
    """Render a ``Call`` used as a table key as a computed key: ``[<call expression>]``."""
    lua_expr = call.to_lua()
    return '[%s]' % (lua_expr,)


def indent(lines):
    """Prefix every newline-separated line of ``lines`` (including empty ones) with two spaces."""
    return '\n'.join('  ' + line for line in lines.split('\n'))


def parts_cmp(x, y):
    """
    Three-way comparison of rendered table entries: shorter entries first,
    entries of equal length ordered lexicographically.

    :return: -1, 0 or 1 (suitable for ``functools.cmp_to_key``)
    """
    left = (len(x), x)
    right = (len(y), y)
    return (left > right) - (left < right)


class Context(object):
    """Rendering state passed to ``Config.to_lua``; normally produced by ``Config.compute_context``."""
    __slots__ = ('prerendered_luas', 'shared_configs', 'report_referring_configs', 'tree')

    def __init__(self, prerendered_luas=None, shared_configs=None, report_referring_configs=None, tree=None):
        """
        :param prerendered_luas:
            Mapping from config ids to their Lua representations, used to memorize already rendered configs.
        :param shared_configs:
            Mapping from config ids to pairs of (node, just_pointer); the shared section uuid is
            ``node.config_uuid``. If just_pointer is false, the config is wrapped in a "shared" module
            with that uuid. If just_pointer is true, the config is replaced with a "shared" module that
            only points to that uuid. ``Config.to_lua`` pops entries as it renders them.
        :param report_referring_configs:
            Collection of report config ids (``compute_context`` builds a set) whose "uuid" must be
            moved to the "refers" field. See SWAT-3255 for details.
        :param tree:
            Tree of shareable configs built by ``compute_context``; not read by ``Config.to_lua``.
        """
        self.prerendered_luas = prerendered_luas or {}
        self.shared_configs = shared_configs or {}
        self.report_referring_configs = report_referring_configs or {}
        self.tree = tree


class Lua(str):
    """Marker ``str`` subclass: a ready-made Lua expression that ``to_lua`` emits verbatim, without quoting."""


def to_lua(value, context=None, iter_=iter):
    """
    Render a Python value as a Lua expression.

    :param value: a ``Config`` or ``Call`` (rendered by their own ``to_lua``), a ``Lua``
        snippet (returned verbatim), a string, bool, integer, float or None.
    :param context: ``Context`` forwarded to ``Config``/``Call`` rendering.
    :param iter_: iteration wrapper forwarded to ``Config``/``Call`` rendering.
    :raises ValueError: for any other type.
    """
    if isinstance(value, (Config, Call)):
        return value.to_lua(context=context, iter_=iter_)
    elif isinstance(value, Lua):
        # Lua is a str subclass, so it must be checked before plain strings
        return value
    elif isinstance(value, six.string_types):
        # NOTE(review): no escaping of '"', backslashes or newlines is done here
        return '"{}"'.format(value)
    elif isinstance(value, bool):
        # bool is an int subclass, so it must be checked before integers
        return 'true' if value else 'false'
    elif isinstance(value, six.integer_types):
        # six.integer_types also covers Python 2 ``long``, which a plain ``int`` check rejects
        return str(value)
    elif isinstance(value, float):
        return '{0:.3f}'.format(value)
    elif value is None:
        return 'nil'
    else:
        raise ValueError('can not convert {!r} to lua'.format(value))


def unique(seq):
    """Return the items of ``seq`` as a list without duplicates, keeping first occurrences in order."""
    seen = set()
    result = []
    for item in seq:
        if item in seen:
            continue
        seen.add(item)
        result.append(item)
    return result


def remove_node(config_ids_by_uuid, node):
    """
    Forget ``node`` and all of its descendants in ``config_ids_by_uuid``.

    For every visited node, one occurrence of its ``config_id`` is removed from the list
    stored under its ``config_uuid``; lists that become empty are deleted from the mapping.
    Nodes are visited in pre-order (parent first, then children left to right).
    """
    pending = [node]
    while pending:
        current = pending.pop()
        bucket = config_ids_by_uuid[current.config_uuid]
        bucket.remove(current.config_id)
        if not bucket:
            del config_ids_by_uuid[current.config_uuid]
        # push children reversed so the leftmost child is handled next
        pending.extend(reversed(list(current.children)))


class Call(object):
    """
    A Lua function call expression ``func(arg1, arg2, ...)``.

    ``func`` is the Lua function name; ``args`` are values accepted by ``to_lua``
    (including nested ``Call`` and ``Config`` objects).
    """
    __slots__ = ('func', 'args')

    def __init__(self, func, args=None):
        self.func = func
        self.args = args or []

    def to_lua(self, context=None, iter_=iter):
        """Render the call, rendering each argument with the module-level ``to_lua``."""
        lua_args = []
        for arg in self.args:
            lua_arg = to_lua(arg, context=context, iter_=iter_)
            lua_args.append(lua_arg)
        return '{}({})'.format(self.func, ', '.join(lua_args))

    def __hash__(self):
        # hashing via repr keeps calls usable as dict keys (they are used as table keys)
        # even though ``args`` is usually an unhashable list
        return hash(repr(self))

    def __repr__(self):
        # Join the argument reprs instead of stripping '[]' off repr(self.args): str.strip removes
        # every leading/trailing bracket, which mangled list arguments (e.g. [[1, 2]] -> '1, 2').
        return '{}({})'.format(self.func, ', '.join(repr(arg) for arg in self.args))

    def __eq__(self, other):
        if not isinstance(other, Call):
            return False
        return other.func == self.func and other.args == self.args

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def iter_nested_calls(self):
        """Yield every Call nested in the arguments (depth-first), including calls inside Config arguments."""
        for arg in self.args:
            if isinstance(arg, Call):
                yield arg
                for nested_call in arg.iter_nested_calls():
                    yield nested_call
            elif isinstance(arg, Config):
                for nested_call in arg.iter_calls():
                    yield nested_call

if six.PY3:
    def shash(value):
        """
        Return the Python 2.7 hash of a string.

        Port of ``string_hash`` from the CPython 2.7 branch (Objects/stringobject.c):
        a multiply/xor loop over the characters kept within 64 bits, the length xored in,
        the result reinterpreted as a signed C ``long``, and -1 mapped to -2.
        Presumably used so that uuids derived from rendered Lua match between Python 2 and 3.

        :param value: input string
        :return: Python 2.7 hash of ``value``
        """
        if not len(value):
            return 0

        mask = (1 << 64) - 1
        acc = (ord(value[0]) << 7) & mask
        for char in value:
            acc = ((acc * 1000003) & mask) ^ ord(char)
        acc ^= len(value) & mask

        # reinterpret the unsigned accumulator as a signed C long
        signed = ctypes.c_long(acc).value
        return -2 if signed == -1 else signed
else:
    # on Python 2 the builtin hash already is the reference implementation
    shash = hash


class Config(object):
    """
    A Lua table constructor under construction.

    ``table`` holds the hash part (keys are strings or ``Call`` objects, see ``to_lua``);
    ``array`` holds the array part: either a list of values or a single ``Call`` whose
    result is spliced in with ``unpack(...)``. Values may be primitives, ``Lua`` snippets,
    ``Call`` objects or nested ``Config`` objects.
    """
    __slots__ = ('table', 'array', 'global_vars', 'shareable', 'report', 'outlets', 'compact', 'shared_uuid')

    def __init__(self, table=None, array=None, global_vars=None, shareable=False, report=False, outlets=(),
                 compact=False, shared_uuid=None):
        """
        :param table: hash part; an ``OrderedDict`` keeps its insertion order when rendered,
            entries of any other mapping are sorted by rendered length, then text.
        :param array: array part (list of values or a ``Call`` to unpack).
        :param global_vars: Lua globals collected by ``to_top_level_lua``.
        :param shareable: whether identical copies of this config may be deduplicated
            into a ``shared`` section by ``compute_context``.
        :param report: whether this is a report config (see ``compute_context``).
        :param outlets: configs returned by ``get_outlets``; ``(self,)`` is used when empty.
        :param compact: render the table on a single line.
        :param shared_uuid: presumably the uuid of the ``shared`` section this config
            belongs to (see ``is_shared_anchor``) — TODO confirm.
        """
        self.shareable = shareable
        self.report = report
        self.table = table if table is not None else {}
        self.array = array or []
        self.global_vars = global_vars or {}
        self.outlets = outlets
        self.compact = compact
        self.shared_uuid = shared_uuid

    def is_shared(self):
        """Return True if ``shared_uuid`` is set."""
        return bool(self.shared_uuid)

    def is_shared_anchor(self):
        """
        Return True unless the config consists only of a ``uuid`` key and an empty array,
        i.e. it carries content beyond the shared uuid (presumably the anchor defining the
        shared section rather than a bare pointer to it).
        """
        assert self.shared_uuid and self.table.get('uuid') == self.shared_uuid
        return not (
            not self.array and
            len(self.table) == 1 and
            next(iter(self.table)) == 'uuid'
        )

    def get_outlets(self):
        """Return ``outlets`` if non-empty, otherwise ``(self,)``."""
        return self.outlets or (self,)

    def extend(self, config):
        """
        Merge ``config`` into this config in place and return ``self``.

        Table entries are updated (the table becomes an ``OrderedDict`` if ``config.table`` is one),
        array parts are concatenated and ``outlets`` are taken from ``config``.

        :raises ValueError: when an unpacked ``Call`` array would be combined with another array part.
        """
        # extend table
        if isinstance(config.table, OrderedDict):
            if not isinstance(self.table, OrderedDict):
                self.table = OrderedDict(six.iteritems(self.table))
        self.table.update(six.iteritems(config.table))

        # extend array
        if self.array:
            if isinstance(self.array, Call):
                if config.array:
                    raise ValueError('can\'t extend an unpacked function with another unpacked function or array')
            elif isinstance(self.array, list):
                if isinstance(config.array, Call):
                    raise ValueError('can\'t extend a config with array with unpacked function')
                else:
                    self.array.extend(config.array)
        else:
            if isinstance(config.array, Call):
                self.array = config.array
            else:
                self.array = list(config.array)

        self.outlets = config.outlets

        return self

    @staticmethod
    def _table_key_sort_key(key):
        """
        Sort key for the keys of non-ordered tables.

        Keys are strings or ``Call`` objects (see ``to_lua``). ``Call`` defines no ordering, so
        comparing it with a string or another ``Call`` raises TypeError on Python 3. Calls are
        placed first and ordered by ``repr``; strings keep their natural order.
        """
        if isinstance(key, Call):
            return (0, repr(key))
        return (1, key)

    def _iter_keys(self):
        # iterate over table entries in deterministic order
        if isinstance(self.table, OrderedDict):
            return six.iterkeys(self.table)
        return sorted(self.table, key=self._table_key_sort_key)

    def _iter_items(self):
        to_chain = []
        # iterate over table entries in deterministic order
        if isinstance(self.table, OrderedDict):
            to_chain.append(six.itervalues(self.table))
        else:
            sort_key = self._table_key_sort_key
            to_chain.append((value for _, value in sorted(six.iteritems(self.table),
                                                          key=lambda kv: sort_key(kv[0]))))
        if isinstance(self.array, list):
            to_chain.append(self.array)
        return itertools.chain.from_iterable(to_chain)

    def iter_calls(self):
        """
        Yield every ``Call`` used in this config tree: ``Call`` keys, ``Call`` values and an
        unpacked ``Call`` array, plus calls nested in their arguments and in nested configs.
        """
        for key in self._iter_keys():
            if isinstance(key, Call):
                yield key
                for nested_call in key.iter_nested_calls():
                    yield nested_call
        for item in self._iter_items():
            if isinstance(item, Config):
                for call in item.iter_calls():
                    yield call
            elif isinstance(item, Call):
                yield item
                for nested_call in item.iter_nested_calls():
                    yield nested_call
        if isinstance(self.array, Call):
            yield self.array
            for nested_call in self.array.iter_nested_calls():
                yield nested_call

    def iter_configs(self):
        """
        Yields configs in a deterministic DFS order (pre-order, starting with self).
        """
        yield self
        for item in self._iter_items():
            if not isinstance(item, Config):
                continue
            for config in item.iter_configs():
                yield config

    def iter_report_configs(self):
        """
        Yields configs with report=True in a deterministic DFS order.
        """
        if self.report:
            yield self
        for item in self._iter_items():
            if not isinstance(item, Config):
                continue
            for report_config in item.iter_report_configs():
                yield report_config

    def iter_shareable_configs(self, parent_shareable_config_id=None):
        """
        Performs DFS on configs with shareable=True in deterministic order.
        Yields pairs (parent shareable config id, config), starting from the deepest ones.
        """
        new_parent_shareable_config_id = parent_shareable_config_id
        if self.shareable:
            new_parent_shareable_config_id = id(self)
        for item in self._iter_items():
            if not isinstance(item, Config):
                continue
            for shareable_config in item.iter_shareable_configs(new_parent_shareable_config_id):
                yield shareable_config
        if self.shareable:
            yield (parent_shareable_config_id, self)

    def iter_shareable_and_shared_configs(self, parent_shareable_config_id=None):
        """
        Performs DFS on configs with shareable=True or a shared_uuid in deterministic order.
        Yields pairs (id of the nearest such ancestor or None, config), starting from the deepest ones.
        """
        new_parent_shareable_config_id = parent_shareable_config_id
        if self.shareable or self.shared_uuid:
            new_parent_shareable_config_id = id(self)
        for item in self._iter_items():
            if not isinstance(item, Config):
                continue
            for shareable_config in item.iter_shareable_and_shared_configs(new_parent_shareable_config_id):
                yield shareable_config
        if self.shareable or self.shared_uuid:
            yield (parent_shareable_config_id, self)

    @staticmethod
    def _create_uuid(s):
        """Derive a uuid string from rendered Lua: the decimal absolute value of ``shash(s)``."""
        return str(shash(s)).lstrip('-')

    def compute_context(self, iter_=iter):
        """
        Decide which sub-configs are rendered as shared sections.

        Every shareable (or already shared) config is rendered and hashed into a uuid, and the
        configs are arranged into a tree by their nearest shareable ancestor. Then, repeatedly,
        the first node in BFS order whose uuid is used by at least two distinct configs becomes
        the anchor (rendered in full inside a ``shared`` section); every other config with that
        uuid becomes a pointer to it and is removed from further analysis together with its
        subtree. Finally, among the remaining report configs sharing a report ``uuid``, all but
        one base config (the last ``just_storage`` one, otherwise the first) are marked to render
        their uuid as ``refers`` (see SWAT-3255).

        :param iter_: iteration wrapper. ``.send()`` is called on ``iter_(tree.bfs())``, so for
            generators it must return the generator itself (e.g. the builtin ``iter``).
        :rtype: Context
        """
        # a map of uuid to config ids
        config_ids_by_uuid = collections.defaultdict(list)
        # parent shareable config id -> nodes already built for its shareable descendants
        dangling_nodes = collections.defaultdict(list)

        mem_context = Context()
        for parent_config_id, config in iter_(self.iter_shareable_and_shared_configs()):
            # iterate all shareable configs, note that deepest ones yielded first, from left to right
            config_id = id(config)
            lua = config.to_lua(context=mem_context, iter_=iter_)
            # optimization: put rendered configs to the temporary context to not render them again and again
            mem_context.prerendered_luas[config_id] = lua
            uuid = self._create_uuid(lua)
            config_ids_by_uuid[uuid].append(config_id)

            if config_id not in dangling_nodes:
                node = Node(config_id=config_id,
                            config=config,
                            config_uuid=uuid)
            else:
                # children were yielded before their parent and are waiting for it
                node = Node(config_id=config_id,
                            config=config,
                            config_uuid=uuid,
                            children=dangling_nodes.pop(config_id))
            dangling_nodes[parent_config_id].append(node)

        # top-level shareable configs have no shareable parent (None)
        branches = dangling_nodes.pop(None, [])
        assert not dangling_nodes

        tree = Node('root', children=branches)

        shared_configs = {}

        while 1:
            # find the first node (in BFS order) whose uuid is used by >= 2 distinct configs
            it = iter_(tree.bfs())
            next(it)  # skip the root node, it's not shareable

            for node in it:
                if node.config_uuid not in config_ids_by_uuid:
                    continue
                shared_config_ids = config_ids_by_uuid[node.config_uuid]
                uniq_shared_config_ids = unique(shared_config_ids)
                if len(uniq_shared_config_ids) >= 2:
                    node_ids_to_remove = collections.Counter(shared_config_ids)
                    del node_ids_to_remove[node.config_id]
                    for shared_config_id in uniq_shared_config_ids:
                        just_pointer = shared_config_id != node.config_id
                        shared_configs[shared_config_id] = (node, just_pointer)
                    break
            else:
                break  # no duplicated uuids left

            # continue the same BFS to collect the pointer nodes; their subtrees are not visited
            nodes_to_remove = []

            try:
                node = next(it)
                while iter_(itertools.count()):
                    if node.config_id in node_ids_to_remove:
                        nodes_to_remove.append(node)
                        node_ids_to_remove[node.config_id] -= 1
                        if not sum(node_ids_to_remove.values()):
                            break
                        node = it.send(True)  # send do_not_visit = True
                    else:
                        node = it.send(False)  # send do_not_visit = False
            except StopIteration:
                pass

            for node in nodes_to_remove:
                remove_node(config_ids_by_uuid, node)
            if nodes_to_remove:
                # remove all collected subtrees at once
                tree.remove_subtrees({n.id for n in nodes_to_remove})

        just_storage_report_config_ids = set()
        report_config_ids_by_report_uuid = collections.defaultdict(list)
        # iterate over all report configs and group them by report uuid
        for report_config in iter_(self.iter_report_configs()):
            report_config_id = id(report_config)
            report_uuid = report_config.table.get('uuid')
            if not report_uuid:
                continue
            if report_config.table.get('just_storage'):
                just_storage_report_config_ids.add(report_config_id)
            if report_config_id not in report_config_ids_by_report_uuid[report_uuid]:
                report_config_ids_by_report_uuid[report_uuid].append(report_config_id)

        # ids still present in config_ids_by_uuid, i.e. not removed as pointers above
        tracked_config_ids = set()
        for config_ids in six.itervalues(config_ids_by_uuid):
            tracked_config_ids.update(config_ids)

        report_referring_configs = set()
        for report_config_ids in six.itervalues(report_config_ids_by_report_uuid):
            report_config_ids = [config_id for config_id in report_config_ids
                                 if config_id in tracked_config_ids]
            if not report_config_ids:
                continue
            # the last just_storage config (or else the first one) keeps its uuid, the rest refer to it
            base_config_i = 0
            for i, config_id in enumerate(report_config_ids):
                if config_id in just_storage_report_config_ids:
                    base_config_i = i

            report_config_ids.pop(base_config_i)
            report_referring_configs.update(report_config_ids)

        return Context(tree=tree,
                       shared_configs=shared_configs,
                       report_referring_configs=report_referring_configs)

    def to_top_level_lua(self, iter_=iter):
        """
        Render this config as a complete Lua program.

        The output consists of global variable assignments, the definitions of every used function
        (plus whatever ``get_required_func_names()`` reports for them), a provider-registration call
        for overridable functions (if any) and finally ``instance = <this config>``; sections are
        separated by two blank lines.
        """
        # NOTE: the builtin ``iter`` (not ``iter_``) is passed here; compute_context calls ``.send()``
        # on ``iter_(tree.bfs())``, which works only when the generator itself is returned.
        context = self.compute_context(iter_=iter)
        instance_lua_config = self.to_lua(context=context, iter_=iter_)

        global_vars = {}
        for config in iter_(self.iter_configs()):
            if config.global_vars:
                global_vars.update(config.global_vars)

        used_func_names = set()
        for call in iter_(self.iter_calls()):
            used_func_names.add(call.func)

        parts = []
        for var_name, value in sorted(global_vars.items()):
            # iter_ must be passed by keyword: to_lua's second positional parameter is ``context``
            parts.append('{} = {}'.format(var_name, to_lua(value, iter_=iter_)))

        required_func_names = set()
        for func_name in used_func_names:
            required_func_names.add(func_name)
            func = FUNCS[func_name]
            required_func_names.update(func.get_required_func_names())

        overridable_func_names = set()
        for func_name in sorted(required_func_names):
            func = FUNCS[func_name]
            parts.append(func.lua)
            if func.overridable:
                overridable_func_names.add(func_name)

        if overridable_func_names:
            parts.append(Call(_call_func_providers.name, args=(Config(array=sorted(overridable_func_names)),)).to_lua())

        parts.append('instance = {}'.format(instance_lua_config))
        return '\n\n\n'.join(parts)

    def to_lua(self, context=None, iter_=iter):
        """
        Render this config as a Lua table constructor.

        With a ``context`` (see ``compute_context``):

        * a report config listed in ``context.report_referring_configs`` renders its ``uuid``
          as (or appended to) ``refers``;
        * values found in ``context.prerendered_luas`` are reused verbatim;
        * values found in ``context.shared_configs`` are wrapped in a ``shared`` table carrying
          the section uuid and, for the anchor, the original key/value. The entry is popped,
          so only the first rendering of that value is wrapped.

        Regular entries come first (insertion order for ``OrderedDict`` tables, otherwise sorted
        with ``parts_cmp``), then ``shared`` entries, then the array part.
        """
        table = self.table
        if context and self.report and id(self) in context.report_referring_configs:
            # move the report uuid into "refers" (see SWAT-3255); copy so self.table is untouched
            table = self.table.copy()
            uuid = table.pop('uuid')
            if 'refers' in table:
                table['refers'] = table['refers'] + ',' + uuid
            else:
                table['refers'] = uuid

        parts = []
        pinned_parts = []  # "shared" wrappers, emitted after the regular entries
        for raw_key, item in iter_(six.iteritems(table)):
            if isinstance(raw_key, Call):
                key = call_to_lua_key(raw_key)
            else:
                key = escape_lua_key(raw_key)
            pinned = False

            if context:
                id_ = id(item)
                if id_ in context.prerendered_luas:
                    # just an optimization
                    lua = context.prerendered_luas[id_]
                elif id_ in context.shared_configs:
                    node, just_pointer = context.shared_configs.pop(id_)
                    item_table = {
                        'uuid': node.config_uuid,
                    }
                    if not just_pointer:
                        # Use the raw key: the nested Config escapes keys itself, so the already
                        # escaped ``key`` would be escaped twice (e.g. '["["a-b"]"]').
                        item_table[raw_key] = item
                    key = 'shared'
                    lua = Config(item_table).to_lua(context=context, iter_=iter_)
                    pinned = True
                else:
                    lua = to_lua(item, context=context, iter_=iter_)
            else:
                lua = to_lua(item, context=context, iter_=iter_)

            # multi-line values get a trailing comment naming their key
            suffix = ' -- ' + key if '\n' in lua else ''
            part = '{} = {};{}'.format(key, lua, suffix)
            if pinned:
                pinned_parts.append(part)
            else:
                parts.append(part)

        if table and isinstance(table, OrderedDict):
            parts = ['\n'.join(parts)]  # join them early to prevent reordering
        else:
            parts.sort(key=functools.cmp_to_key(parts_cmp))
        parts.extend(pinned_parts)

        if isinstance(self.array, Call):
            parts.append('unpack({})'.format(to_lua(self.array, context=context, iter_=iter_)))
        else:
            for item in iter_(self.array):
                lua = to_lua(item, context=context, iter_=iter_)
                parts.append('{};'.format(lua))

        if parts:
            if self.compact:
                return '{ %s }' % ' '.join(parts)
            else:
                return '{\n%s\n}' % indent('\n'.join(parts))
        else:
            return '{}'
