# coding: utf-8
import six
from awacs.lib import OrderedDict


class Func(object):
    """
    Declarative description of a helper function that can be emitted into a Lua config.

    :param name: function name
    :param args: OrderedDict mapping required argument names to validator callables
    :param lua: Lua source text defining the function (may be an empty string)
    :param optional_args: OrderedDict mapping optional argument names to validators;
        an empty OrderedDict is used when falsy
    :param args_validator: callable ``(args, optional_args)`` for cross-argument checks;
        defaults to one that accepts everything
    :param requires: names of other functions this one depends on (default: empty frozenset)
    :param overridable: when true, ``_call_func_providers`` is added to the required functions
    """

    def __init__(self, name, args, lua, optional_args=None, args_validator=None, requires=None, overridable=False):
        def _accept_any(args, optional_args):
            # Default cross-argument validator: never rejects anything.
            return 1

        self.name, self.args, self.lua = name, args, lua
        self.optional_args = optional_args if optional_args else OrderedDict()
        self.args_validator = args_validator if args_validator else _accept_any
        self.requires = requires if requires else frozenset()
        self.overridable = overridable

    def get_required_func_names(self):
        """
        Return a new mutable set with the names of functions this one depends on.

        Overridable functions additionally depend on ``_call_func_providers``.
        """
        names = set(self.requires)
        if not self.overridable:
            return names
        names.add(_call_func_providers.name)
        return names


def any_arg(value):
    """
    Validator that accepts every argument value, literal or call.

    :type value: Value
    """
    return None


def str_arg(value):
    """
    Validator requiring a literal argument to be a string; calls pass unchecked.

    :type value: Value
    :raises: ValueError
    """
    if not value.is_func() and not isinstance(value.value, six.string_types):
        raise ValueError('must be a string')


def non_empty_str_arg(value):
    """
    Validator requiring a literal argument to be a non-empty string; calls pass unchecked.

    :type value: Value
    :raises: ValueError
    """
    str_arg(value)
    if value.is_func():
        return
    if not value.value:
        raise ValueError('must not be empty')


def bool_arg(value):
    """
    Validator requiring a literal argument to be a boolean; calls pass unchecked.

    :type value: Value
    :raises: ValueError
    """
    if not value.is_func() and not isinstance(value.value, bool):
        raise ValueError('must be a boolean')


def int_arg(value):
    """
    Validator requiring a literal argument to be an integer; calls pass unchecked.

    ``bool`` is explicitly rejected: it is a subclass of ``int`` in Python, so
    ``True``/``False`` would otherwise be accepted as integers.

    :type value: Value
    :raises: ValueError
    """
    if value.is_func():
        return
    v = value.value
    if isinstance(v, bool) or not isinstance(v, int):
        raise ValueError('must be an integer')


def positive_int_arg(value):
    """
    Validator requiring a literal argument to be a non-negative integer; calls pass unchecked.

    Despite the name, zero is accepted. ``bool`` is explicitly rejected because
    it is a subclass of ``int`` in Python.

    :type value: Value
    :raises: ValueError
    """
    if value.is_func():
        return
    v = value.value
    if isinstance(v, bool) or not isinstance(v, int):
        raise ValueError('must be an integer')
    if v < 0:
        raise ValueError('must be non-negative')


def int_range_arg(value, min_, max_, exclusive_min=False, exclusive_max=False):
    """
    Validator requiring a literal argument to be an integer within a range; calls pass unchecked.

    ``bool`` is explicitly rejected because it is a subclass of ``int`` in Python.

    :type value: Value
    :param min_: lower bound, inclusive unless ``exclusive_min``
    :type min_: int
    :param max_: upper bound, inclusive unless ``exclusive_max``
    :type max_: int
    :type exclusive_min: bool
    :type exclusive_max: bool
    :raises: ValueError
    """
    if value.is_func():
        return
    v = value.value
    if isinstance(v, bool) or not isinstance(v, int):
        raise ValueError('must be an integer')

    below_min = v <= min_ if exclusive_min else v < min_
    if below_min:
        cond = 'greater than' if exclusive_min else 'greater or equal to'
        raise ValueError('must be {} {}'.format(cond, min_))

    above_max = v >= max_ if exclusive_max else v > max_
    if above_max:
        cond = 'less than' if exclusive_max else 'less or equal to'
        raise ValueError('must be {} {}'.format(cond, max_))


def number_arg(value):
    """
    Validator requiring a literal argument to be an int or float; calls pass unchecked.

    ``bool`` is explicitly rejected because it is a subclass of ``int`` in Python.

    :type value: Value
    :raises: ValueError
    """
    if value.is_func():
        return
    v = value.value
    if isinstance(v, bool) or not isinstance(v, (int, float)):
        raise ValueError('must be a number')


def create_enum_arg(choices):
    """
    Build a validator that only accepts literal values contained in ``choices``.

    Calls are always rejected by the returned validator.

    :param choices: iterable of allowed literal values (listed by ``repr`` in errors)
    :rtype: callable
    """
    def enum_arg(value):
        if value.is_func():
            raise ValueError('call is not allowed')
        if value.value in choices:
            return
        listed = ', '.join(map(repr, choices))
        raise ValueError('must be one of the following: {}'.format(listed))

    return enum_arg


# Lua helper returning this host's source address for a family ("v4" or "v6"),
# parsed from `ip route get` output towards a fixed external address.
# - If the Lua global `disable_external` is truthy, fixed placeholders
#   ("127.1.1.1" / "127.2.2.2") are returned and no shell command runs.
# - If the command prints nothing, or the token "proto", "127.0.0.2" is returned.
# NOTE(review): the v4 branch prints the *last* field of the line containing "src";
#   confirm no iproute2 version on target hosts appends further fields after the address.
# NOTE(review): the v6 branch only picks addresses starting with "2a02" — confirm intended.
# NOTE(review): io.popen() can return nil on failure, in which case handler:read() raises.
get_ip_by_iproute = Func(
    name='get_ip_by_iproute',
    args=OrderedDict([
        # Only the literals "v4"/"v6" are accepted; calls are rejected.
        ('family', create_enum_arg(('v4', 'v6'))),
    ]),
    lua="""
function get_ip_by_iproute(addr_family)
  if disable_external then
    if addr_family == "v4" then
      return "127.1.1.1"
    elseif addr_family == "v6" then
      return "127.2.2.2"
    else
      error("invalid parameter")
    end
  end

  local ipcmd
  if addr_family == "v4" then
    ipcmd = "ip route get 77.88.8.8 2>/dev/null| awk '/src/ {print $NF}'"
  elseif addr_family == "v6" then
    ipcmd = "ip route get 2a00:1450:4010:c05::65 2>/dev/null | grep -oE '2a02[:0-9a-f]+' | tail -1"
  else
    error("invalid parameter")
  end
  local handler = io.popen(ipcmd)
  local ip = handler:read("*l")
  handler:close()
  if ip == nil or ip == "" or ip == "proto" then
    return "127.0.0.2"
  end
  return ip
end
""".strip()
)

# Lua helpers:
#   check_int(value, var_name) -- tonumber(value), or raise an error naming var_name.
#   get_int_var(name, default) -- numeric value of the global `name`, or `default`
#                                 when that global is nil/false.
# get_int_var passes `name` through to check_int. Without it, building the error
# message would itself fail by concatenating nil. `value` is declared local so it
# does not leak into the config's global scope.
get_int_var = Func(
    name='get_int_var',
    args=OrderedDict([
        ('var', str_arg),
        ('default', int_arg),
    ]),
    lua="""
function check_int(value, var_name)
  return tonumber(value) or error("Could not cast variable \\"" .. tostring(var_name) .. "\\" to a number.")
end

function get_int_var(name, default)
  local value = _G[name]
  return value and check_int(value, name) or default
end
""".strip()
)

# Lua helper: integer port taken from the global `name` (via get_int_var), falling
# back to `default`. It raises if neither is set or the value is outside 0..65535,
# then adds `offset` when given. The range check runs before the offset is applied.
# `value` is local so it does not leak into the config's global scope.
get_port_var = Func(
    name='get_port_var',
    args=OrderedDict([
        ('var', str_arg),
    ]),
    optional_args=OrderedDict([
        ('offset', int_arg),
        ('default', int_arg),
    ]),
    requires={'get_int_var'},
    lua="""
function get_port_var(name, offset, default)
  local value = get_int_var(name, default)
  if value == nil then
    error("Neither port variable \\"" .. name .. "\\" nor default port is specified.")
  end
  if value < 0 or value > 65535 then
    error("Variable \\"" .. name .. "\\" is not a valid port: " .. value)
  end
  if offset ~= nil then
    value = value + offset
  end
  return value
end
""".strip()
)

# Lua helper: value of the global `name`, or `default` when that global is nil/false.
get_str_var = Func(
    name='get_str_var',
    lua="""
function get_str_var(name, default)
  return _G[name] or default
end
""".strip(),
    args=OrderedDict((
        ('var', str_arg),
        ('default', str_arg),
    )),
)

# Lua helper: value of the environment variable `name`. When it is unset, `default`
# is returned if provided; otherwise an error is raised.
# `rv` is local so it does not leak into the config's global scope.
get_str_env_var = Func(
    name='get_str_env_var',
    args=OrderedDict([
        ('var', str_arg),
    ]),
    optional_args=OrderedDict([
        ('default', str_arg),
    ]),
    lua=r"""
function get_str_env_var(name, default)
  local rv = os.getenv(name)
  if rv == nil then
    if default == nil then
      error(string.format('Environment variable "%s" is not set.', name))
    else
      return default
    end
  else
    return rv
  end
end
""".strip()
)

# Lua helper: "<dir>/current-<name>-balancer[-<port>]". <dir> is the global
# `log_dir`, else `default_log_dir`, else "/place/db/www/logs".
# `rv` is local so it does not leak into the config's global scope.
get_log_path = Func(
    name='get_log_path',
    args=OrderedDict([
        ('name', str_arg),
        ('port', int_arg),
    ]),
    optional_args=OrderedDict([
        ('default_log_dir', str_arg),
    ]),
    lua="""
function get_log_path(name, port, default_log_dir)
  default_log_dir = default_log_dir or "/place/db/www/logs"
  local rv = (log_dir or default_log_dir) .. "/current-" .. name .. "-balancer";
  if port ~= nil then
    rv = rv .. "-" .. port;
  end
  return rv
end
""".strip()
)

# Lua helper: "<dir>/<name>". <dir> is the global `public_cert_dir`, else
# `default_public_cert_dir`, else "/dev/shm/balancer".
get_public_cert_path = Func(
    name='get_public_cert_path',
    lua="""
function get_public_cert_path(name, default_public_cert_dir)
  default_public_cert_dir = default_public_cert_dir or "/dev/shm/balancer"
  return (public_cert_dir or default_public_cert_dir) .. "/" .. name;
end
""".strip(),
    args=OrderedDict((
        ('name', str_arg),
    )),
    optional_args=OrderedDict((
        ('default_public_cert_dir', str_arg),
    )),
)

# Lua helper: "<dir>/<name>". <dir> is the global `private_cert_dir`, else
# `default_private_cert_dir`, else "/dev/shm/balancer/priv".
get_private_cert_path = Func(
    name='get_private_cert_path',
    lua="""
function get_private_cert_path(name, default_private_cert_dir)
  default_private_cert_dir = default_private_cert_dir or "/dev/shm/balancer/priv"
  return (private_cert_dir or default_private_cert_dir) .. "/" .. name;
end
""".strip(),
    args=OrderedDict((
        ('name', str_arg),
    )),
    optional_args=OrderedDict((
        ('default_private_cert_dir', str_arg),
    )),
)

# Declared with an empty Lua body; presumably it is handled specially by the config
# generator rather than emitted as Lua — confirm against the generator.
count_backends = Func(
    name='count_backends',
    lua="",
    args=OrderedDict(),
    optional_args=OrderedDict((
        ('compat_enable_sd_support', bool_arg),
    )),
)

# Declared with an empty Lua body and no arguments; presumably handled specially by
# the config generator — confirm against the generator.
count_backends_sd = Func(
    name='count_backends_sd',
    lua="",
    args=OrderedDict(),
)

# Lua helper: converts a table of backends into {weight=..., proxy={...}} entries.
# Each backend may be positional ({host, port, weight, cached_ip}) or keyed by those
# names. Entries from `proxy_options` are copied into every proxy table.
# Raises if the resulting table is empty.
gen_proxy_backends = Func(
    name='gen_proxy_backends',
    lua="""
function gen_proxy_backends(backends, proxy_options)
  local result = {}

  for index, backend in pairs(backends) do
    local proxy = {
      host = backend[1] or backend['host'];
      port = backend[2] or backend['port'];
      cached_ip = backend[4] or backend['cached_ip'];
    };

    if proxy_options ~= nil then
      for optname, optvalue in pairs(proxy_options) do
        proxy[optname] = optvalue
      end
    end

    result[index] = {
      weight = backend[3] or backend['weight'];
      proxy = proxy;
    };
  end

  if next(result) == nil then
    error("backends list is empty")
  end

  return result
end
""".strip(),
    args=OrderedDict((
        ('backends', any_arg),
        ('proxy_options', any_arg),
    )),
)

# Lua helper: `name` concatenated with the global `DC`, else `default_geo`,
# else "random".
get_geo = Func(
    name='get_geo',
    lua="""
function get_geo(name, default_geo)
  default_geo = default_geo or "random"
  return name .. (DC or default_geo);
end
""".strip(),
    args=OrderedDict((
        ('name', str_arg),
    )),
    optional_args=OrderedDict((
        ('default_geo', any_arg),
    )),
)

# Lua helper: "<name><separator><dc>". <dc> is the global `DC`, else `default_dc`,
# else "unknown". The separator defaults to "_".
# `dc` is local so it does not leak into the config's global scope.
suffix_with_dc = Func(
    name='suffix_with_dc',
    args=OrderedDict([
        ('name', str_arg),
    ]),
    optional_args=OrderedDict([
        ('default_dc', str_arg),
        ('separator', str_arg),
    ]),
    lua="""
function suffix_with_dc(name, default_dc, separator)
  local dc = DC or default_dc or "unknown";
  separator = separator or "_";
  return name .. separator .. dc;
end
""".strip()
)

# Lua helper: "<dc><separator><name>". <dc> is the global `DC`, else `default_dc`,
# else "unknown". The separator defaults to "_".
# `dc` is local so it does not leak into the config's global scope.
prefix_with_dc = Func(
    name='prefix_with_dc',
    args=OrderedDict([
        ('name', str_arg),
    ]),
    optional_args=OrderedDict([
        ('default_dc', str_arg),
        ('separator', str_arg),
    ]),
    lua="""
function prefix_with_dc(name, default_dc, separator)
  local dc = DC or default_dc or "unknown";
  separator = separator or "_";
  return dc .. separator .. name;
end
""".strip()
)

# Lua helper: "<dir>/<name>". <dir> is the global `ca_cert_dir`, else
# `default_ca_cert_dir`, else "/dev/shm/balancer/priv".
# NOTE(review): unlike the other cert-path helpers, the default dir is validated
# with any_arg rather than str_arg — confirm this is intentional.
get_ca_cert_path = Func(
    name='get_ca_cert_path',
    lua="""
function get_ca_cert_path(name, default_ca_cert_dir)
  default_ca_cert_dir = default_ca_cert_dir or "/dev/shm/balancer/priv"
  return (ca_cert_dir or default_ca_cert_dir) .. "/" .. name;
end
""".strip(),
    args=OrderedDict((
        ('name', str_arg),
    )),
    optional_args=OrderedDict((
        ('default_ca_cert_dir', any_arg),
    )),
)


def validate_get_random_timedelta_args(args, optional_args):
    """
    Cross-argument check for get_random_timedelta: ``start`` must not exceed ``end``.

    :param args: positional arguments; the first two are compared with ``>``
    :param optional_args: unused
    :raises: ValueError
    """
    lower, upper = args[0], args[1]
    if lower > upper:
        raise ValueError('"start" must be less or equal to "end"')


# Lua helper: a random integer in [start, end_] followed by the unit suffix
# ("ms" or "s"), e.g. "150ms". start <= end is enforced on the Python side by
# validate_get_random_timedelta_args.
get_random_timedelta = Func(
    name='get_random_timedelta',
    lua="""
function get_random_timedelta(start, end_, unit)
  return math.random(start, end_) .. unit;
end
""".strip(),
    args=OrderedDict((
        ('start', positive_int_arg),
        ('end', positive_int_arg),
        ('unit', create_enum_arg(('ms', 's'))),
    )),
    args_validator=validate_get_random_timedelta_args,
)

# Declared with an empty Lua body; presumably handled specially by the config
# generator — confirm against the generator.
get_total_weight_percent = Func(
    name='get_total_weight_percent',
    lua="",
    args=OrderedDict((
        ('value', positive_int_arg),
    )),
    optional_args=OrderedDict((
        ('allow_different_backend_weights', bool_arg),
    )),
)

# Lua helpers:
#   do_get_workers() -- default implementation. It reads the global "workers" and
#                       raises if it is missing or not numeric. It stays a global so
#                       _call_func_providers can replace it via _G["do_get_workers"].
#   get_workers()    -- calls do_get_workers() and raises unless the result is a
#                       non-negative integral number.
# Temporaries are declared local so they do not leak into the config's global scope.
get_workers = Func(
    name='get_workers',
    args=OrderedDict(),
    requires={'get_int_var'},
    overridable=True,
    lua=r"""
function do_get_workers()
  -- actual get_workers() implementation, can be overridden
  local value = _G["workers"]
  if value == nil then
    error('Variable "workers" is not specified.')
  end
  local int_value = tonumber(value)
  if int_value == nil then
    error('Could not cast variable "workers" to a number.')
  end
  return int_value
end


function get_workers()
  local value = do_get_workers()
  if type(value) ~= 'number' then
    error(string.format('Provided get_workers() implementation must return a number, not %s.', type(value)))
  end
  if value < 0 or value % 1 ~= 0 then
    error(string.format('Provided get_workers() implementation must return a non-negative integer, not %s', value))
  end
  return value
end
""".strip()
)

# Lua helper: for each name in `overridable_func_names`, if the global
# "<name>_provider" is set, load that file in an environment that falls back to _G
# and run it. The file must return a function, which is installed as the global
# "do_<name>". Any load or run failure raises.
# `ok`/`rv` are local so they do not leak into the config's global scope.
# NOTE(review): loadfile(path, mode, env) is the Lua 5.2+ signature; under Lua 5.1 /
# LuaJIT the env argument is ignored — confirm the target interpreter.
_call_func_providers = Func(
    name='_call_func_providers',
    args=OrderedDict(),
    lua=r"""
function _call_func_providers(overridable_func_names)
  for _, func_name in pairs(overridable_func_names) do
    local func_provider_path = _G[func_name .. "_provider"]
    if func_provider_path ~= nil then
      local env = {}
      setmetatable(env, {__index = _G})
      local provider, err = loadfile(func_provider_path, nil, env)
      if provider == nil then
        error(string.format('Failed to import provider "%s": %s', func_provider_path, err))
      end
      local ok, rv = pcall(provider)
      if ok then
        if type(rv) ~= 'function' then
          error(string.format('Provider "%s" must return a function, not %s.', func_provider_path, type(rv)))
        end
        _G["do_" .. func_name] = rv
      else
        error(string.format('Provider "%s" failed: %s', func_provider_path, rv))
      end
    end
  end
end
""".strip()
)

# Lua helpers: get_its_control_path(filename) delegates to the global
# do_get_its_control_path. By default that returns "./controls/" .. filename.
get_its_control_path = Func(
    name='get_its_control_path',
    lua="""
function do_get_its_control_path(filename)
  -- actual get_its_control_path() implementation, can be overridden
  return "./controls/" .. filename
end

function get_its_control_path(filename)
  return do_get_its_control_path(filename)
end
""".strip(),
    args=OrderedDict((
        ('filename', str_arg),
    )),
)

# Functions callable from user configs.
_USER_FUNCS = (
    get_ip_by_iproute,
    get_str_var,
    get_str_env_var,
    get_int_var,
    get_port_var,
    get_log_path,
    get_public_cert_path,
    get_private_cert_path,
    count_backends,
    count_backends_sd,
    get_geo,
    suffix_with_dc,
    prefix_with_dc,
    get_ca_cert_path,
    get_random_timedelta,
    get_total_weight_percent,
    get_workers,
)
# Functions available only to generated code, not to user configs.
_INTERNAL_FUNCS = (
    gen_proxy_backends,
    get_its_control_path,
    _call_func_providers,
)
# Name -> Func registries: USER_FUNCS holds user-facing functions only;
# FUNCS additionally includes the internal ones.
USER_FUNCS = dict((func.name, func) for func in _USER_FUNCS)
FUNCS = dict(USER_FUNCS)
FUNCS.update((func.name, func) for func in _INTERNAL_FUNCS)
del _USER_FUNCS, _INTERNAL_FUNCS
