# coding: utf-8
import itertools

from awacs.lib.yamlparser.errors import SchemaError
from infra.awacs.proto import modules_pb2 as proto
from . import defs, errors, luaparser
from .util import Value
import six


def num_to_ith(num):
    """Render *num* as an English ordinal string.

    Examples: 1 -> '1st', 2 -> '2nd', 3 -> '3rd', 11 -> '11th', 21 -> '21st'.
    """
    # Based on http://stackoverflow.com/a/37343184/1577116
    text = str(num)
    # Numbers ending in 11, 12 or 13 always take 'th'. Slicing [-2:-1]
    # yields '' for single-character strings, so no length check is needed.
    if text[-2:-1] == '1':
        return text + 'th'
    suffix = {'1': 'st', '2': 'nd', '3': 'rd'}.get(text[-1], 'th')
    return text + suffix


def validate_args(func, args):
    """Validate a raw argument list against a function definition.

    Checks the argument count first, then runs each positional argument through
    its declared validator (required args first, then optional ones), and
    finally calls ``func.args_validator`` for cross-argument checks.

    :type func: Func
    :param list args: list of raw args
    :raises errors.ValidationError: on a wrong argument count, or when any
        validator raises ValueError
    """
    required_n = len(func.args)
    optional_n = len(func.optional_args)
    given_n = len(args)

    if not optional_n:
        if given_n != required_n:
            raise errors.ValidationError(
                '{} takes exactly {} arguments '
                '({} given)'.format(func.name, required_n, given_n))
    elif given_n < required_n:
        raise errors.ValidationError('{} expects at least {} arguments ({} given)'.format(
            func.name, required_n, given_n))
    elif given_n > required_n + optional_n:
        raise errors.ValidationError('{} expects at most {} arguments ({} given)'.format(
            func.name, required_n + optional_n, given_n))

    # Required and optional specs are positional, so pair them with the given
    # args in declaration order; numbering is 1-based for the error message.
    specs = itertools.chain(six.iteritems(func.args), six.iteritems(func.optional_args))
    for position, (value, (arg_name, validate)) in enumerate(zip(args, specs), start=1):
        try:
            validate(value)
        except ValueError as e:
            reason = getattr(e, 'message', None) or str(e)
            raise errors.ValidationError(
                '{}\'s {} argument ({}) {}'.format(func.name, num_to_ith(position), arg_name, reason))

    optional_args = args[required_n:]
    try:
        func.args_validator(args, optional_args)
    except ValueError as e:
        reason = getattr(e, 'message', None) or str(e)
        raise errors.ValidationError('invalid arguments for {}\'s call: {}'.format(func.name, reason))


def get_func_by_name(name):
    """Look up a user-callable function definition by name.

    :raises errors.ValidationError: if *name* is not in ``defs.USER_FUNCS``
    """
    user_funcs = defs.USER_FUNCS
    if name in user_funcs:
        return user_funcs[name]
    raise errors.ValidationError('unknown function "{}"'.format(name))


def raw_call_to_call_pb(raw_call, call_pb):
    """Fill *call_pb* from a parsed raw call, recursing into nested calls.

    Sets ``call_pb.type`` from the function name and writes each argument into
    ``<func>_params``. A nested call goes into the ``f_<arg>`` sub-message.
    A plain value goes into ``<arg>``: into its ``.value`` when the field is
    message-typed, or directly for scalar fields.

    :type raw_call: luaparser.RawCall
    :type call_pb: proto.Call
    :raises SchemaError: when an argument validator raises ValueError
    """
    name = raw_call.func
    func = get_func_by_name(name)

    call_pb.type = getattr(proto.Call, name.upper())
    params_pb = getattr(call_pb, name + '_params')
    # Mark the params message as present even if no argument gets written.
    params_pb.SetInParent()

    arg_specs = itertools.chain(six.iteritems(func.args), six.iteritems(func.optional_args))
    for position, ((arg_name, validate), value) in enumerate(zip(arg_specs, raw_call.args), start=1):
        try:
            if not isinstance(value, luaparser.RawCall):
                validate(Value(Value.VALUE, value))
                field_desc = params_pb.DESCRIPTOR.fields_by_name[arg_name]
                if field_desc.message_type:
                    getattr(params_pb, arg_name).value = value
                else:
                    setattr(params_pb, arg_name, value)
            else:
                validate(Value(Value.CALL, value))
                raw_call_to_call_pb(value, getattr(params_pb, 'f_{}'.format(arg_name)))
        except ValueError as e:
            reason = getattr(e, 'message', None) or str(e)
            raise SchemaError(
                '{}\'s {} argument ({}) {}'.format(name, num_to_ith(position), arg_name, reason))


def _read_call_param(params_pb, arg_name):
    """Read one argument from a ``<func>_params`` message.

    :returns: pair ``(value, is_set)``.
        - ``value`` is a nested RawCall when the ``f_<arg_name>`` field is set.
        - For message-typed fields it is the wrapped ``.value``, or None when
          the field is unset.
        - Otherwise it is the raw scalar stored in the field.
        ``is_set`` is True for a set sub-call or message field, and for a
        scalar it is the scalar's truthiness (falsy scalars are presumably
        treated as "not provided" because they are indistinguishable from
        the field default).
    """
    fields = params_pb.DESCRIPTOR.fields_by_name
    f_arg_name = 'f_{}'.format(arg_name)
    if f_arg_name in fields and params_pb.HasField(f_arg_name):
        return call_pb_to_raw_call(getattr(params_pb, f_arg_name)), True
    if fields[arg_name].message_type:
        if params_pb.HasField(arg_name):
            return getattr(params_pb, arg_name).value, True
        return None, False
    value = getattr(params_pb, arg_name)
    return value, bool(value)


def call_pb_to_raw_call(call_pb):
    """Convert a ``proto.Call`` message back into a raw call.

    Every required argument is always emitted, with None for an unset
    message-typed field. Optional arguments are positional, so only the
    *trailing* unset ones are dropped.

    An unset optional argument that precedes a set one is emitted as a
    placeholder: None for message-typed fields, or the stored scalar value.
    This keeps later arguments in their correct positions. Previously such
    arguments were skipped, which shifted later arguments into the wrong slot.

    :type call_pb: proto.Call
    :rtype: luaparser.RawCall
    """
    type_enum_desc = call_pb.DESCRIPTOR.fields_by_name['type'].enum_type
    func_name = type_enum_desc.values_by_number[call_pb.type].name.lower()
    func = get_func_by_name(func_name)

    rv = luaparser.RawCall(func=func_name, args=[])
    params_pb = getattr(call_pb, func_name + '_params')
    params_pb.SetInParent()

    for arg_name in six.iterkeys(func.args):
        value, _ = _read_call_param(params_pb, arg_name)
        rv.args.append(value)

    optional_values = []
    # Number of optional values to keep: up to and including the last set one.
    keep_n = 0
    for arg_name in six.iterkeys(func.optional_args):
        value, is_set = _read_call_param(params_pb, arg_name)
        optional_values.append(value)
        if is_set:
            keep_n = len(optional_values)
    rv.args.extend(optional_values[:keep_n])

    return rv
