# coding: utf-8
import re
import collections
import semantic_version

import regex
import contextdecorator
import ipaddr
import pire
import six
from awacs.lib import validators
from six.moves import BaseHTTPServer

from infra.awacs.proto import modules_pb2 as proto
from .luaparser import read_string
from .errors import ValidationError


KILOBYTE = 1024
MEGABYTE = KILOBYTE * KILOBYTE

# Wildcard address value: is_addr_external() treats it as external and the
# address comparison helpers treat it as matching any non-loopback address.
ASTERISK = '*'
META_MODULE_ID = 'awacs-logs'

# Symbolic (non-literal) local addresses; is_addr_external() reports them as
# not external.
LOCAL_V4_ADDR = 'local_v4_addr'
LOCAL_V6_ADDR = 'local_v6_addr'

# Loopback range / address used by is_addr_local() and is_addr_external().
LOCAL_IPV4_NETWORK = ipaddr.IPv4Network('127.0.0.1/8')
LOCAL_IPV6_ADDR = ipaddr.IPAddress('::1')

# Optionally signed integer followed by a unit suffix (e.g. "150ms", "-1d");
# checked by is_valid_func() for "time:<duration>". See the parser tests at:
# https://a.yandex-team.ru/arc/trunk/arcadia/util/datetime/parser_ut.cpp?rev=5533942#L547
DURATION_RE = re.compile(r'^[+-]?\d+(ns|us|ms|s|m|h|d|w|y)$')

# Function names accepted by is_valid_func(), and through it by
# validate_header_func(). "time:<duration>" is checked separately against
# DURATION_RE and is not listed here.
VALID_FUNCS = (
    # https://wiki.yandex-team.ru/balancer/Cookbook/#kakdobavljatchto-todinamicheskoevzagolovki
    'reqid',
    'market_reqid',
    'realip',
    'realport',
    'localip',
    'localport',
    'starttime',
    'url',
    'location',
    'host',
    'yuid',
    'proto',
    'scheme',
    'ssl_client_cert_cn',
    'ssl_client_cert_subject',
    'ssl_client_cert_serial_number',
    'ssl_client_cert_verify_result',
    'ssl_handshake_info',
    'ssl_ticket_name',
    'ssl_ticket_iv',
    'ssl_early_data',
    'exp_static',
    'tcp_info',
    'ja3',
    'ja4',
    'search_reqid',
    'p0f',
)


# Cluster alias -> canonical cluster name, applied by cluster_name_from_alias().
# Names that are not listed here are returned unchanged by that function.
OVERRIDE_CLUSTER_NAMES = {
    'test_sas': 'sas-test',
    'sas_test': 'sas-test',
    'man_pre': 'man-pre',
}


def is_close(a, b=None, tol=1e-8, ref=None):
    """
    Return True if ``a`` and ``b`` differ by at most ``tol * ref``.

    If ``b`` is omitted, ``ref`` is the value compared against. If ``ref`` is
    omitted, the scale is ``abs(a) + abs(b)``. Comparing against zero therefore
    requires an explicit ``ref`` (enforced with assertions).
    """
    # https://groups.google.com/forum/#!topic/python-ideas/H3Aoxfgax8E%5B176-200%5D
    assert ref is not None or (a != 0 and b != 0)
    if b is None:
        assert ref is not None
        b = ref
    scale = abs(a) + abs(b) if ref is None else ref
    return abs(a - b) <= tol * scale


def is_addr_external(ip):
    """
    Tell whether the address ``ip`` (a Value) is not a local one.

    Function calls and the symbolic LOCAL_V4_ADDR / LOCAL_V6_ADDR values are
    reported as not external. ASTERISK is external. Any other value is parsed
    as an IP address (ipaddr raises ValueError if it is not one) and is
    external unless it is in 127.0.0.0/8 or equals ::1.
    """
    if ip.is_func() or ip.value in (LOCAL_V4_ADDR, LOCAL_V6_ADDR):
        return False
    if ip.value == ASTERISK:
        return True
    return not is_addr_local(ip.value)


def is_addr_local(ip):
    """
    Return True if the IP string ``ip`` is in 127.0.0.0/8 or is ::1.

    :raises ValueError: if ipaddr cannot parse ``ip``
    """
    parsed = ipaddr.IPAddress(ip)
    if parsed in LOCAL_IPV4_NETWORK:
        return True
    return parsed == LOCAL_IPV6_ADDR


def format_ip(ip):
    """
    Return ``ip`` in a form that can be followed by ":port".

    IPv6 literals are wrapped in square brackets. IPv4 addresses, and anything
    ipaddr rejects with ValueError (e.g. symbolic names), are returned unchanged.
    """
    try:
        version = ipaddr.IPAddress(ip).version
    except ValueError:
        return ip
    return '[{}]'.format(ip) if version == 6 else ip


def format_ip_port(ip, port):
    """
    Render "ip:port", bracketing IPv6 literals via format_ip().

    :param ip: raw address string (e.g. ``Value.value``), not a Value itself
    :type ip: six.text_type
    :type port: int
    :rtype: six.text_type
    """
    host = format_ip(ip)
    return '{}:{}'.format(host, port)


def format_addr(addr):
    """
    Render an (ip, port) pair of Values as "ip:port".

    :type addr: (Value, Value)
    :rtype: six.text_type
    """
    ip_value, port_value = addr[0].value, addr[1].value
    return format_ip_port(ip_value, port_value)


# Integer or decimal amount with an ms/s/m suffix, e.g. "150ms", "1.5s", "2m".
TIMEDELTA_RE = re.compile(r'^([0-9]+)(\.[0-9]+)?(ms|s|m)$')
# Integer amount with a unit up to days, e.g. "30m", "2d".
LONG_TIMEDELTA_RE = re.compile(r'^([0-9]+)(ms|s|m|h|d)$')
# Either a concrete status code ("404") or a family ("5xx"). The alternation
# must be grouped: without the group, "^" bound only to the first branch and
# "$" only to the second, so values such as "2000" or "200abc" were accepted.
STATUS_CODE_OR_FAMILY_RE = re.compile(r'^(?:[1-5][0-9]{2}|[1-5]xx)$')
# A concrete status code in the 100-599 range.
STATUS_CODE_RE = re.compile(r'^[1-5][0-9]{2}$')


def validate_match(value, pattern):
    """
    Require ``value`` to match the compiled regexp ``pattern`` (re.match semantics).

    :raises ValidationError: quoting ``pattern.pattern`` on mismatch
    """
    if re.match(pattern, value) is None:
        raise ValidationError('must match {}'.format(pattern.pattern))


def validate_timedelta(value):
    """Require ``value`` to match TIMEDELTA_RE (e.g. "150ms", "1.5s", "2m")."""
    if TIMEDELTA_RE.match(value) is None:
        raise ValidationError('"{}" is not a valid timedelta string'.format(value))


def validate_timedeltas(value):
    """
    Validate a comma-separated list of timedeltas, e.g. "100ms,1s".

    :raises ValidationError: naming the whole value if any item is invalid
    """
    parts = value.split(',')
    for part in parts:
        try:
            validate_timedelta(part)
        except ValidationError:
            raise ValidationError('"{}" is not a valid timedeltas string'.format(value))


def validate_long_timedelta(value):
    """Require ``value`` to match LONG_TIMEDELTA_RE (integer with ms/s/m/h/d unit)."""
    if LONG_TIMEDELTA_RE.match(value) is None:
        raise ValidationError('"{}" is not a valid timedelta string'.format(value))


def validate_ip(value, field_name=None):
    """
    Accept an IPv6 or IPv4 address, or the wildcards "::" and ASTERISK.

    :raises ValidationError: tagged with ``field_name`` otherwise
    """
    is_valid = validators.ipv6(value) or validators.ipv4(value) or value in ('::', ASTERISK)
    if not is_valid:
        raise ValidationError('is not a valid IP address', field_name=field_name)


def validate_comma_separated_subnets(value, field_name=None):
    """
    Validate a comma-separated list of subnets in "address/prefix" form.

    Each item is an IP address followed by "/" and a decimal prefix length,
    at most 32 for IPv4 and at most 128 for IPv6.

    :type value: six.text_type
    :param field_name: passed through to ValidationError
    :raises ValidationError: on the first malformed item
    """
    for item in value.split(','):
        try:
            ip_part, prefix_part = item.rsplit('/', 1)
            # ASCII digits only, anchored with \Z: "$" would also accept a
            # trailing newline ("10.0.0.0/8\n"), and on Python 3 "\d" also
            # matches non-ASCII digits that int() happily converts.
            if not re.match(r'^[0-9]+\Z', prefix_part):
                raise ValueError()
            prefix = int(prefix_part)
            ip = ipaddr.IPAddress(ip_part)
        except ValueError:
            raise ValidationError('"{}" is not a valid subnet'.format(item), field_name)
        # prefix cannot be negative here: the regexp above admits digits only.
        if prefix > 128:
            raise ValidationError('"{}" is not a valid prefix'.format(prefix_part), field_name)
        if prefix > 32 and ip.version != 6:
            raise ValidationError('"{}" is not a valid prefix for "{}"'.format(prefix_part, ip), field_name)


MAX_PORT = (1 << 16) - 1  # 65535


def validate_port(value, field_name=None):
    """
    Require ``value`` to be a port number in 1..MAX_PORT.

    :raises ValidationError: tagged with ``field_name`` otherwise
    """
    in_range = 0 < value <= MAX_PORT
    if not in_range:
        raise ValidationError('is not a valid port', field_name=field_name)


def __compare_with_asterisk(ip):
    """
    Return ``ip`` if the ASTERISK wildcard may match it, else None.

    Loopback addresses are not matched by the wildcard. Values that
    is_addr_local() cannot parse (ValueError) are treated as non-loopback.
    """
    try:
        local = is_addr_local(ip.value)
    except ValueError:
        return ip
    return None if local else ip


def __compare_ips(ip1, ip2):
    """
    Return the IP Value common to ``ip1`` and ``ip2``, or None.

    If either side is ASTERISK, the other side is returned unless it is a
    loopback address. Otherwise the two must be equal.
    """
    if ip1.value == ASTERISK:
        return __compare_with_asterisk(ip2)
    if ip2.value == ASTERISK:
        return __compare_with_asterisk(ip1)
    return ip1 if ip1 == ip2 else None


def __compare_addrs(addr1, addr2):
    """
    Return the (ip, port) pair common to two addresses, or None.

    The IPs are matched with __compare_ips() and the ports must be equal.
    """
    ip1, port1 = addr1
    ip2, port2 = addr2
    matched_ip = __compare_ips(ip1, ip2)
    if not (matched_ip and port1 == port2):
        return None
    return matched_ip, port1


def __contains(container, value, compare):
    """
    Return the first truthy ``compare(elem, value)`` over ``container``, or None.

    Elements are compared lazily, in order, stopping at the first match.
    """
    matches = (compare(elem, value) for elem in container)
    return next((found for found in matches if found), None)


def contains_ip(ips, ip):
    """Return the element of ``ips`` that matches ``ip`` (see __compare_ips), or None."""
    return __contains(ips, ip, compare=__compare_ips)


def contains_addr(addrs, addr):
    """Return the (ip, port) of ``addrs`` that matches ``addr`` (see __compare_addrs), or None."""
    return __contains(addrs, addr, compare=__compare_addrs)


def intersect_addrs(addrs_1, addrs_2):
    """Return, for each address of ``addrs_1``, its match in ``addrs_2``, dropping non-matches."""
    matches = [contains_addr(addrs_2, addr) for addr in addrs_1]
    return list(filter(None, matches))


def validate_pire_regexp(string, lua_unescape_first=True):
    """
    Check that ``string`` is a regexp accepted by PIRE and contains no ^/$ anchors.

    :param string: the regexp, optionally still in Lua-escaped form
    :param lua_unescape_first: if true, first unescape ``string`` with
        read_string() ("as Lua would do")
    :raises ValidationError: if Lua unescaping fails, a text value is not pure
        ASCII, re/regex or pire reject the regexp, or the canonized FSM
        contains a begin (^) or end ($) mark
    """
    if lua_unescape_first:
        try:
            string = read_string(string)  # unescape it as Lua would do
        except ValueError as e:
            raise ValidationError('is not a valid Lua string ({})'.format(e))
    # Text is converted to ASCII bytes before being handed to re/pire;
    # non-ASCII characters are reported as an invalid regexp.
    if isinstance(string, six.text_type):
        try:
            string = string.encode('ascii')
        except UnicodeEncodeError as e:
            raise ValidationError('is not a valid regexp ({})'.format(e))
    try:
        # First a syntax check with the stdlib re module...
        try:
            re.compile(string)
        except AssertionError:
            # romanovich@: "sorry, but this version only supports 100 named groups"
            # is the only assertion error I've ever seen (see AWACS-822 for details).
            # Let's re-check using https://pypi.org/project/regex/:
            regex.compile(string)
        # ...then parse with PIRE itself. Any failure in either step
        # (including non-AssertionError re errors) becomes a ValidationError.
        fsm = pire.parse_regexp(string)
        fsm.canonize()
    except Exception as e:
        raise ValidationError('is not a valid regexp ({})'.format(e))
    # Anchors are rejected explicitly.
    # NOTE(review): presumably because the consumer applies its own anchoring — confirm.
    if fsm.contains_begin_mark():
        raise ValidationError('is not a valid regexp: using ^ anchor is not allowed')
    if fsm.contains_end_mark():
        raise ValidationError('is not a valid regexp: using $ anchor is not allowed')


def validate_re2_regexp(string, lua_unescape_first=True):
    """
    Check that ``string`` compiles as a regular expression.

    NOTE(review): the check uses Python's ``re`` module, which only
    approximates RE2 syntax. Confirm this is intended.

    :param lua_unescape_first: if true, first unescape ``string`` with
        read_string() ("as Lua would do")
    :raises ValidationError: if unescaping or compilation fails
    """
    # https://st.yandex-team.ru/SWAT-5830
    pattern = string
    if lua_unescape_first:
        try:
            pattern = read_string(string)  # unescape it as Lua would do
        except ValueError as e:
            raise ValidationError('is not a valid Lua string ({})'.format(e))
    try:
        re.compile(pattern)
    except Exception as e:
        raise ValidationError('is not a valid regexp ({})'.format(e))


def validate_item_uniqueness(items):
    """
    Require the (hashable) elements of ``items`` to be pairwise distinct.

    :raises ValidationError: naming the first repeated element
    """
    already_seen = set()
    remember = already_seen.add
    for element in items:
        if element in already_seen:
            raise ValidationError('duplicate item "{}"'.format(element))
        remember(element)


def validate_key_uniqueness(items):
    """
    Require the keys of the (key, value) pairs in ``items`` to be distinct.

    :raises ValidationError: naming the first repeated key
    """
    seen_keys = set()
    for key, _unused in items:
        if key in seen_keys:
            raise ValidationError('duplicate key "{}"'.format(key))
        seen_keys.add(key)


def validate_status_code_or_family(value):
    """Require ``value`` to match STATUS_CODE_OR_FAMILY_RE ("404" or "5xx")."""
    if STATUS_CODE_OR_FAMILY_RE.match(value) is None:
        raise ValidationError('unknown status code or family: {}'.format(value))


def validate_status_code(value):
    """Require ``value`` to match STATUS_CODE_RE (a code in 100-599)."""
    if STATUS_CODE_RE.match(value) is None:
        raise ValidationError('unknown status code: {}'.format(value))


def validate_status_code_3xx(value):
    """Require ``value`` to be a 3xx status code."""
    if not value.startswith('3') or STATUS_CODE_RE.match(value) is None:
        raise ValidationError('unknown status code: {}'.format(value))


def validate_status_codes(value, allow_families=True):
    """
    Validate every status code string in ``value``.

    :param allow_families: also accept families such as "5xx"
    """
    validate_one = validate_status_code_or_family if allow_families else validate_status_code
    for code in value:
        validate_one(code)


HEADER_NAME_RE = re.compile('^[a-zA-Z][-a-zA-Z0-9_]*$')


def validate_header_name(value):
    """
    Require ``value`` to be a header name: a letter followed by letters,
    digits, "-" or "_".

    :raises ValidationError: for any other value
    """
    match = HEADER_NAME_RE.match(value)
    # "$" in HEADER_NAME_RE also matches right before a trailing "\n", so the
    # match must consume the whole value; otherwise "X-Foo\n" would pass.
    if match is None or match.end() != len(value):
        raise ValidationError('invalid header name: "{}"'.format(value))


COOKIE_NAME_RE = re.compile('^[-_a-zA-Z0-9]+$')


def validate_cookie_name(value):
    """
    Require ``value`` to be a non-empty run of letters, digits, "-" or "_".

    :raises ValidationError: "must match <pattern>" otherwise
    """
    match = COOKIE_NAME_RE.match(value)
    # "$" also matches before a trailing "\n"; require a full-length match so
    # that values such as "yandexuid\n" are rejected.
    if match is None or match.end() != len(value):
        raise ValidationError('must match {}'.format(COOKIE_NAME_RE.pattern))


def append_field_name_to_validation_error(e, field_name):
    """
    Prepend ``field_name`` to the path of the validation error ``e``.

    ``e.path`` is expected to be deque-like (appendleft / extendleft).

    :type e: ValidationError
    :param field_name: a single name, or a sequence of names ordered outermost
        first; they end up at the front of ``e.path`` in that same order
    :type field_name: six.text_type | Sequence[six.text_type]
    """
    if not isinstance(field_name, six.string_types):
        # extendleft() inserts items one at a time, so reverse to keep order.
        e.path.extendleft(reversed(field_name))
        return
    e.path.appendleft(field_name)


class validate(contextdecorator.ContextDecorator):
    """
    Context manager / decorator that prefixes ValidationError paths.

    A ValidationError escaping the wrapped block or function gets
    ``field_name`` (a single name or a sequence of names) prepended to its
    path via append_field_name_to_validation_error() and keeps propagating.
    Other exceptions pass through untouched.
    """
    __slots__ = ('field_name',)

    def __init__(self, field_name):
        self.field_name = field_name

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # A falsy return value lets any in-flight exception propagate.
        if not isinstance(exc_val, ValidationError):
            return None
        append_field_name_to_validation_error(exc_val, self.field_name)
        return None


class HTTPRequest(BaseHTTPServer.BaseHTTPRequestHandler):
    """
    Parse an in-memory HTTP request with the stdlib request handler.

    The constructor runs ``parse_request()`` on ``request_text`` (bytes).
    Nothing is written back: ``send_error`` only records the failure in
    ``error_code`` / ``error_message``, which stay None on success.
    """
    # NOTE(review): presumably None so that a request line without a version
    # leaves request_version falsy (validate_request_line checks for that).
    default_request_version = None

    def __init__(self, request_text):
        # The base-class constructor is deliberately not called: it expects a
        # live (request, client_address, server) triple.
        self.error_code = None
        self.error_message = None
        self.headers = None
        self.rfile = six.BytesIO(request_text)
        self.raw_requestline = self.rfile.readline()
        self.parse_request()

    if six.PY3:
        def send_error(self, code, message=None, explain=None):
            # ``explain`` is part of the Python 3 signature; it is ignored.
            self.error_code, self.error_message = code, message
    else:
        def send_error(self, code, message):
            self.error_code, self.error_message = code, message


VALID_REQUEST_LINE_HTTP_VERSIONS = frozenset(('HTTP/1.0', 'HTTP/1.1'))
VALID_REQUEST_LINE_HTTP_COMMANDS = frozenset(('GET', 'HEAD'))


def validate_request_line(value):
    r"""
    Validate a backslash-escaped raw HTTP request such as
    ``GET / HTTP/1.1\r\nHost: x\r\n\r\n``.

    ``value`` holds literal escape sequences (a backslash followed by "n",
    etc.), which are decoded before the request is parsed by HTTPRequest.
    Only GET/HEAD with HTTP/1.0 or HTTP/1.1 are allowed, and HTTP/1.1 also
    requires a Host header.

    :type value: six.text_type
    :raises ValidationError: if the value is malformed or not allowed
    """
    if not value.endswith((r'\n\n', r'\r\n\r\n')):
        raise ValidationError(r'must end with "\n\n" or "\r\n\r\n"')

    try:
        if six.PY3:
            line = value.encode('ascii').decode('unicode_escape').encode('ascii')
        else:
            line = value.decode('string_escape')
    except ValueError as e:
        # Non-ASCII input or a malformed escape (e.g. "\x" without hex digits)
        # used to escape as a raw UnicodeError/ValueError. UnicodeError is a
        # ValueError subclass, so both are reported as validation errors here.
        raise ValidationError('is not valid: {}'.format(e))
    request = HTTPRequest(line)

    if request.error_code:
        if request.error_message:
            raise ValidationError('is not valid: {}'.format(request.error_message))
        else:
            raise ValidationError('is not valid')

    if request.command not in VALID_REQUEST_LINE_HTTP_COMMANDS:
        raise ValidationError('command "{}" is not allowed'.format(request.command))

    if not request.request_version:
        raise ValidationError('http version is missing')

    if request.request_version not in VALID_REQUEST_LINE_HTTP_VERSIONS:
        raise ValidationError('http version "{}" is not valid'.format(request.request_version))

    if request.request_version == 'HTTP/1.1' and 'Host' not in request.headers:
        raise ValidationError('header "Host" must be set for {}'.format(request.request_version))


def timedelta_to_ms(value):
    """
    Convert a timedelta string such as "150ms", "1.5s" or "2m" to milliseconds.

    The expected format is that of TIMEDELTA_RE: an integer or decimal amount
    followed by "ms", "s" or "m". Fractions of a millisecond are truncated
    toward zero.

    :type value: six.text_type
    :rtype: int
    :raises ValueError: if there is no known unit suffix, or the amount is not
        a finite, non-negative number
    """
    import decimal

    if value.endswith('ms'):
        unit_size = 1
    elif value.endswith('s'):
        unit_size = 1000
    elif value.endswith('m'):
        unit_size = 1000 * 60
    else:
        raise ValueError('value is not a valid timedelta')

    unsuffixed_value = value.strip('ms')  # strips "m"/"s" characters from both ends
    try:
        # Decimal rather than float: float('1.001') * 1000 == 1000.9999999999999,
        # which int() would truncate to 1000 instead of 1001.
        quantity = decimal.Decimal(unsuffixed_value)
    except (decimal.InvalidOperation, ValueError):
        raise ValueError('value is not a valid timedelta')

    # is_finite() first: comparing a NaN Decimal with "<" would itself raise.
    if not quantity.is_finite() or quantity < 0:
        raise ValueError('value is not a valid timedelta')

    return int(quantity * unit_size)


def validate_timedelta_range(value, min_, max_, exclusive_min=False, exclusive_max=False):
    """
    Require the timedelta string ``value`` to lie between ``min_`` and ``max_``.

    All three arguments must be valid timedelta strings and are compared in
    milliseconds (see timedelta_to_ms).

    :type value: six.text_type
    :type min_: six.text_type
    :type max_: six.text_type
    :param exclusive_min: reject ``value == min_``
    :param exclusive_max: reject ``value == max_``
    :raises: ValidationError
    """
    for td in (value, min_, max_):
        validate_timedelta(td)
    value_ms = timedelta_to_ms(value)

    lower_ms = timedelta_to_ms(min_)
    too_small = value_ms <= lower_ms if exclusive_min else value_ms < lower_ms
    if too_small:
        relation = 'greater than' if exclusive_min else 'greater or equal to'
        raise ValidationError('must be {} {}'.format(relation, min_))

    upper_ms = timedelta_to_ms(max_)
    too_big = value_ms >= upper_ms if exclusive_max else value_ms > upper_ms
    if too_big:
        relation = 'less than' if exclusive_max else 'less or equal to'
        raise ValidationError('must be {} {}'.format(relation, max_))


def validate_range(value, min_, max_, exclusive_min=False, exclusive_max=False):
    """
    Require ``min_ <= value <= max_`` (each bound optionally exclusive).

    :type value: int | float
    :type min_: int | float
    :type max_: int | float
    :param exclusive_min: reject ``value == min_``
    :param exclusive_max: reject ``value == max_``
    :raises: ValidationError
    """
    too_small = value <= min_ if exclusive_min else value < min_
    if too_small:
        relation = 'greater than' if exclusive_min else 'greater or equal to'
        raise ValidationError('must be {} {}'.format(relation, min_))
    too_big = value >= max_ if exclusive_max else value > max_
    if too_big:
        relation = 'less than' if exclusive_max else 'less or equal to'
        raise ValidationError('must be {} {}'.format(relation, max_))


def validate_comma_separated_ints(value):
    """Require ``value`` to be comma-separated runs of digits, e.g. "1,20,300"."""
    if not all(chunk.isdigit() for chunk in value.split(',')):
        raise ValidationError('must be a string with comma-separated integers')


class Value(collections.namedtuple('Value', ['type', 'value'])):
    """
    A config value tagged with its kind.

    ``type`` is one of VALUE (plain literal), CALL (an object rendered via
    its own ``to_config``), KNOB or CERT (references that must be injected
    before rendering).
    """
    VALUE = 0
    CALL = 1
    KNOB = 2
    CERT = 3

    def is_func(self):
        return self.type == self.CALL

    def is_knob(self):
        return self.type == self.KNOB

    def is_cert(self):
        return self.type == self.CERT

    def to_config(self, ctx):
        """
        Render the value: CALLs delegate to ``self.value.to_config(ctx=ctx)``
        and plain values are returned as is.

        :raises AssertionError: for KNOB / CERT values that were not injected
        """
        if self.is_func():
            return self.value.to_config(ctx=ctx)
        if self.is_knob():
            raise AssertionError('knob "{}" is not injected'.format(self.value.id))
        if self.is_cert():
            raise AssertionError('certificate "{}" is not injected'.format(self.value.id))
        return self.value


def validate_func_name_one_of(call, choices):
    """
    Require ``call.func_name`` to be one of ``choices``.

    :type call: Call
    :type choices: Iterable[six.text_type]
    :raises ValidationError: listing the allowed names otherwise
    """
    if call.func_name in choices:
        return
    allowed = '", "'.join(choices)
    raise ValidationError('only the following functions allowed here: "{}"'.format(allowed))


def fill_get_public_cert_path(call_pb, name, default_public_cert_dir=None):
    """
    Turn ``call_pb`` into a GET_PUBLIC_CERT_PATH call for certificate ``name``.

    ``default_public_cert_dir`` is only set when it is non-empty.

    :type call_pb: modules_pb2.Call
    :type name: six.text_type
    :type default_public_cert_dir: six.text_type
    """
    call_pb.type = proto.Call.GET_PUBLIC_CERT_PATH
    params = call_pb.get_public_cert_path_params
    params.name = name
    if not default_public_cert_dir:
        return
    params.default_public_cert_dir = default_public_cert_dir


def fill_get_private_cert_path(call_pb, name, default_private_cert_dir=None):
    """
    Turn ``call_pb`` into a GET_PRIVATE_CERT_PATH call for certificate ``name``.

    ``default_private_cert_dir`` is only set when it is non-empty.

    :type call_pb: modules_pb2.Call
    :type name: six.text_type
    :type default_private_cert_dir: six.text_type
    """
    call_pb.type = proto.Call.GET_PRIVATE_CERT_PATH
    params = call_pb.get_private_cert_path_params
    params.name = name
    if not default_private_cert_dir:
        return
    params.default_private_cert_dir = default_private_cert_dir


def host_to_regexp(host):
    """
    Turn a host pattern into a regexp fragment.

    The host is IDNA-encoded, dots are escaped, and every "*" becomes
    "[^.]+", i.e. exactly one non-empty label.

    :type host: six.text_type
    :rtype: six.text_type
    """
    if six.PY3:
        encoded = str.encode(host, encoding='idna', errors='strict')
    else:
        encoded = host.encode(u'idna')
    escaped = encoded.decode('ascii').replace(u'.', u'\\.')
    # Must come after the dot escaping: the replacement itself contains a dot.
    return escaped.replace(u'*', u'[^.]+')


def hosts_to_regexp(hosts):
    """
    Join host patterns into one alternation, each alternative in parentheses.

    :type hosts: Iterable[six.text_type]
    :rtype: six.text_type
    """
    groups = [u'(' + host_to_regexp(h) + u')' for h in hosts]
    return u'|'.join(groups)


def port_to_regexp(port):
    """
    Render a port number as a regexp fragment (its decimal text).

    :type port: int
    :rtype: six.text_type
    """
    return six.text_type(port)


def ports_to_regexp(ports):
    """
    Render ports as an alternation, parenthesized when there is more than one.

    :type ports: list[int]
    :rtype: six.text_type
    """
    alternatives = u'|'.join(port_to_regexp(p) for p in ports)
    if len(ports) <= 1:
        return alternatives
    return u'(' + alternatives + u')'


def lua_escape(s):
    """
    Escape ``str(s)`` exactly like Python 2's ``string_escape`` codec.

    Backslashes and single quotes get a backslash prefix; tab, newline and
    carriage return become their two-character escapes; every other byte
    outside printable ASCII becomes a lowercase hex escape. On Python 3 the
    text is UTF-8 encoded first and the result is returned as ``str``.

    The original ``str(s).encode('string-escape')`` only works on Python 2;
    the codec does not exist on Python 3, where it raised LookupError.
    """
    import codecs

    text = str(s)
    # Python 2: str is already bytes. Python 3: encode the text first.
    data = text if isinstance(text, bytes) else text.encode('utf-8')
    # codecs.escape_encode is the function Python 2's "string_escape" codec
    # was implemented with, so output is identical there.
    escaped = codecs.escape_encode(data)[0]
    return escaped if isinstance(escaped, str) else escaped.decode('ascii')


def validate_version(version, valid_versions, field_name=None):
    """
    Require ``version`` to be a semantic version string listed in ``valid_versions``.

    :param valid_versions: container checked for semantic_version.Version
        membership (presumably a collection of Version objects)
    :raises ValidationError: tagged with ``field_name``
    """
    if not semantic_version.validate(version):
        raise ValidationError('is not a valid version', field_name=field_name)
    parsed = semantic_version.Version(version)
    if parsed not in valid_versions:
        raise ValidationError('is not supported', field_name=field_name)


def is_valid_func(func):
    """Return True for a name in VALID_FUNCS or for "time:<duration>" (see DURATION_RE)."""
    prefix = 'time:'
    if not func.startswith(prefix):
        return func in VALID_FUNCS
    return DURATION_RE.match(func[len(prefix):]) is not None


def validate_header_func(func, header, hint=None):
    """
    Require ``func`` to be a valid header function for ``header``.

    :param hint: optional text appended to the error message
    :raises ValidationError: if is_valid_func(func) is false
    """
    if not is_valid_func(func):
        parts = ['invalid func "{}" for header "{}"'.format(func, header)]
        if hint:
            parts.append('{}'.format(hint))
        raise ValidationError('. '.join(parts))


def cluster_name_from_alias(cluster):
    """Map a cluster alias to its canonical name; unknown names pass through."""
    try:
        return OVERRIDE_CLUSTER_NAMES[cluster]
    except KeyError:
        return cluster
