# coding: utf-8
import logging
import itertools
import cachetools

import gevent
import pire
import six

from awacs.lib.gutils import gevent_idle_iter
from infra.awacs.proto import model_pb2
from awacs.wrappers.base import Holder
from awacs.wrappers.luaparser import read_string
from awacs.wrappers.main import RegexpSection, MatchFsm
from awacs.model.balancer.vector import BalancerVersion, UpstreamVersion
from . import errors


class L7FastBalancerModeValidator(object):
    """
    Validates upstreams of a balancer against L7-fast balancer mode requirements.

    Checks are performed only when ``validator_settings.l7_fast_balancer_mode_enabled`` is set in the
    balancer spec:
      * every upstream not listed in SPECIAL_UPSTREAM_IDS is in L7_FAST_MODE or L7_FAST_SITEMAP_MODE;
      * every upstream has an "order" label, equal to "10000000" unless listed in SPECIAL_UPSTREAM_IDS;
      * non-deleted, non-special upstreams do not share the same prefix_path_router_section route.
    """
    # Upstreams exempt from the mode requirement, the "order" label value requirement
    # (the label itself is still required) and the route uniqueness check.
    SPECIAL_UPSTREAM_IDS = ('slbping', 'slb_ping', 'default', 'awacs-balancer-health-check')

    def __init__(self, namespace_id, balancer_id):
        self._log = logging.getLogger(
            'l7-fast-balancer-mode-validator("{}:{}")'.format(namespace_id, balancer_id))

    @staticmethod
    def _is_l7_fast_mode(upstream_spec_pb):
        """
        Return True if the upstream's yandex_balancer mode is L7_FAST_MODE or L7_FAST_SITEMAP_MODE.

        :type upstream_spec_pb: model_pb2.UpstreamSpec
        :rtype: bool
        """
        yandex_balancer_pb = upstream_spec_pb.yandex_balancer
        return yandex_balancer_pb.mode in (yandex_balancer_pb.L7_FAST_MODE,
                                           yandex_balancer_pb.L7_FAST_SITEMAP_MODE)

    @staticmethod
    def get_l7_fast_mode_upstream_route(upstream_spec_pb):
        """
        Return the prefix_path_router_section route of an L7-fast upstream with surrounding "/" stripped.

        :type upstream_spec_pb: model_pb2.UpstreamSpec
        :rtype: str
        """
        assert upstream_spec_pb.type == model_pb2.YANDEX_BALANCER
        assert L7FastBalancerModeValidator._is_l7_fast_mode(upstream_spec_pb)
        return upstream_spec_pb.yandex_balancer.config.prefix_path_router_section.route.strip('/')

    def _validate_section_prefixes(self, balancer_version, upstream_spec_pbs):
        """
        Ensure that no two non-deleted, non-special upstreams use the same route.

        :type balancer_version: BalancerVersion
        :type upstream_spec_pbs: dict[UpstreamVersion, awacs.proto.model_pb2.UpstreamSpec]
        :raises errors.ConfigValidationError: if two upstreams share a route
        """
        upstream_versions_by_route = {}
        for upstream_version, upstream_spec_pb in six.iteritems(upstream_spec_pbs):
            if upstream_version.deleted:
                continue
            _, upstream_id = upstream_version.upstream_id
            if upstream_id in self.SPECIAL_UPSTREAM_IDS:
                continue
            upstream_route = self.get_l7_fast_mode_upstream_route(upstream_spec_pb)
            conflicting_upstream_version = upstream_versions_by_route.get(upstream_route)
            if conflicting_upstream_version is not None:
                raise errors.ConfigValidationError(
                    'Upstreams "{}" and "{}" have the same route'.format(upstream_id,
                                                                         conflicting_upstream_version.upstream_id[1]),
                    cause=max(balancer_version, upstream_version, conflicting_upstream_version))
            upstream_versions_by_route[upstream_route] = upstream_version

    def validate(self, balancer_version, balancer_spec_pb, upstream_spec_pbs):
        """
        Validate upstreams against L7-fast balancer mode; a no-op if the mode is disabled.

        :type balancer_version: BalancerVersion
        :type balancer_spec_pb: awacs.proto.model_pb2.BalancerSpec
        :type upstream_spec_pbs: dict[UpstreamVersion, awacs.proto.model_pb2.UpstreamSpec]
        :raises errors.ConfigValidationError: on the first violated requirement
        """
        if not balancer_spec_pb.validator_settings.l7_fast_balancer_mode_enabled:
            return

        for upstream_version, upstream_spec_pb in six.iteritems(upstream_spec_pbs):
            _, upstream_id = upstream_version.upstream_id
            assert upstream_spec_pb.type == model_pb2.YANDEX_BALANCER
            is_special = upstream_id in self.SPECIAL_UPSTREAM_IDS

            if not is_special and not self._is_l7_fast_mode(upstream_spec_pb):
                raise errors.ConfigValidationError(
                    'Upstream "{}" is not in L7_FAST_MODE or L7_FAST_SITEMAP_MODE and '
                    'violates L7-fast balancer requirements'.format(upstream_id),
                    cause=max(balancer_version, upstream_version))

            labels = upstream_spec_pb.labels
            if 'order' not in labels or (not is_special and labels.get('order') != '10000000'):
                raise errors.ConfigValidationError(
                    'Upstream "{}" violates L7-fast balancer requirements: '
                    'label "order" must be present (and equal to "10000000" for all upstreams except '
                    '"{}")'.format(upstream_id, '", "'.join(self.SPECIAL_UPSTREAM_IDS)),
                    cause=max(balancer_version, upstream_version))

        self._validate_section_prefixes(balancer_version, upstream_spec_pbs)


class CommonServicesBalancerModeValidator(object):
    """
    Validates upstreams of a balancer running in common services balancer mode.

    Checks are performed only when ``validator_settings.common_services_balancer_mode_enabled`` is set:
      * every upstream has an "order" label, equal to "10000000" unless listed in SPECIAL_UPSTREAM_IDS;
      * every upstream except "default" starts with a regexp section;
      * every upstream except "default" and "awacs-balancer-health-check" starts with a match_fsm by host,
        and host regexps of different upstreams do not intersect.
    """
    # Upstreams that may have any "order" label value (the label itself is still required).
    SPECIAL_UPSTREAM_IDS = ('slbping', 'slb_ping', 'default', 'awacs-balancer-health-check')
    # Max number of FSM-id pairs passed to one pire.get_first_intersection_idx call; the check
    # sleeps between chunks so that other greenlets get a chance to run.
    INTERSECTION_CHECK_CHUNK_SIZE = 200

    def __init__(self, namespace_id, balancer_id):
        self._log = logging.getLogger(
            'common-services-balancer-mode-validator("{}:{}")'.format(namespace_id, balancer_id))

        # fsm_id -> compiled and canonized pire FSM
        self._fsm_cache = cachetools.RRCache(maxsize=2000)
        # full upstream id -> (latest upstream version seen as valid, its fsm_id).
        # Keyed by upstream id rather than by version so it holds one entry per upstream
        # instead of growing with every new valid version.
        self._valid_fsm_ids_by_upstream_id = {}

    def validate(self,
                 # valid:
                 valid_vector,
                 # current:
                 balancer_version, balancer_spec_pb, upstream_spec_pbs, upstreams):

        """
        Validate upstreams against common services balancer mode; a no-op if the mode is disabled.

        Upstreams whose current version equals their version in ``valid_vector``, or whose host FSM
        is unchanged since their valid version was seen by this validator, are assumed to be mutually
        non-intersecting; every other upstream is checked against all of them and against each other.

        :param valid_vector: Vector
        :type balancer_version: BalancerVersion
        :type balancer_spec_pb: awacs.proto.model_pb2.BalancerSpec
        :type upstream_spec_pbs: dict[UpstreamVersion, awacs.proto.model_pb2.UpstreamSpec]
        :param upstreams: already wrapped upstream_spec_pbs
        :type upstreams: dict[UpstreamVersion, Holder]
        :raises errors.ConfigValidationError: on the first violated requirement
        """
        if not balancer_spec_pb.validator_settings.common_services_balancer_mode_enabled:
            return

        self._log.debug('Start validating...')

        self._validate_order_labels(balancer_version, upstream_spec_pbs)

        self._log.debug('FSMs cache size is {}...'.format(len(self._fsm_cache)))
        fsms, non_intersecting_fsm_ids_by_versions, fsm_ids_to_check_by_versions = self._collect_fsms(
            valid_vector, balancer_version, upstreams)

        self._log.debug('FSMs size is {}, FSMs cache size is {}...'.format(len(fsms), len(self._fsm_cache)))
        # Sorting every version to check is not free; only do it when the message will be emitted.
        if self._log.isEnabledFor(logging.DEBUG):
            self._log.debug('Compatible FSMs number: {}, FSM ids to check against them: {}'.format(
                len(non_intersecting_fsm_ids_by_versions), sorted(six.iteritems(fsm_ids_to_check_by_versions))))

        self._validate_fsms_do_not_intersect(balancer_version, fsms,
                                             non_intersecting_fsm_ids_by_versions, fsm_ids_to_check_by_versions)
        self._log.debug('Finished validating')

    @staticmethod
    def _mode_violation_error(balancer_version, upstream_version, reason):
        """Build the error for an upstream that breaks a common services balancer mode requirement."""
        return errors.ConfigValidationError(
            'Upstream "{}" violates common services balancer mode enabled in balancer settings: '
            '{}'.format(upstream_version.upstream_id, reason),
            cause=max(balancer_version, upstream_version))

    @staticmethod
    def _intersection_error(balancer_version, v_1, fsm_id_1, v_2, fsm_id_2):
        """Build the error for two upstreams whose host regexps intersect."""
        return errors.ConfigValidationError(
            "Common services balancer mode violation: {}'s host regexp ({}) "
            "intersects with {}'s host regexp ({})".format(
                v_1.upstream_id, fsm_id_1[0],
                v_2.upstream_id, fsm_id_2[0],
            ),
            cause=max(balancer_version, v_1, v_2))

    def _validate_order_labels(self, balancer_version, upstream_spec_pbs):
        """Require an "order" label on every upstream, equal to "10000000" for non-special ones."""
        for upstream_version, upstream_spec_pb in six.iteritems(upstream_spec_pbs):
            upstream_id = upstream_version.upstream_id[1]
            labels = upstream_spec_pb.labels
            if 'order' not in labels or (upstream_id not in self.SPECIAL_UPSTREAM_IDS and
                                         labels.get('order') != '10000000'):
                raise self._mode_violation_error(
                    balancer_version, upstream_version,
                    'label "order" must be present (and equal to "10000000" for all upstreams except '
                    '"slb_ping", "slbping", "default" and "awacs-balancer-health-check")')

    def _get_or_compile_fsm(self, fsm_id):
        """Return the canonized pire FSM for ``fsm_id`` = (pattern, surround, case_insensitive)."""
        if fsm_id in self._fsm_cache:
            return self._fsm_cache[fsm_id]
        pattern, surround, case_insensitive = fsm_id
        fsm = pire.parse_regexp(pattern, surround=surround, case_insensitive=case_insensitive)
        fsm.canonize()
        # Cache only after canonize() succeeded, so a failure never leaves a half-built FSM behind.
        self._fsm_cache[fsm_id] = fsm
        return fsm

    def _collect_fsms(self, valid_vector, balancer_version, upstreams):
        """
        Extract host FSMs from upstreams and split them into already-compatible and to-be-checked ones.

        :returns: (fsms by fsm_id, non-intersecting fsm_ids by version, fsm_ids to check by version)
        """
        fsms = {}
        non_intersecting_fsm_ids_by_versions = {}
        fsm_ids_to_check_by_versions = {}
        for upstream_version, upstream_holder in gevent_idle_iter(six.iteritems(upstreams), idle_period=100):
            full_upstream_id = upstream_version.upstream_id
            upstream_id = full_upstream_id[1]
            if upstream_id == 'default':
                continue

            # An empty chain is reported like any other non-regexp first module.
            first_module = next(upstream_holder.walk_chain(), None)
            if not isinstance(first_module, RegexpSection):
                raise self._mode_violation_error(balancer_version, upstream_version,
                                                 'it is not a regexp section')

            if upstream_id == 'awacs-balancer-health-check':
                continue

            match_fsm = first_module.matcher.match_fsm
            if not match_fsm or not match_fsm.pb.host:
                raise self._mode_violation_error(balancer_version, upstream_version,
                                                 'first match_fsm must match by host')

            case_insensitive = MatchFsm.DEFAULT_CASE_INSENSITIVE
            if match_fsm.pb.HasField('case_insensitive'):
                case_insensitive = match_fsm.pb.case_insensitive.value
            fsm_id = (read_string(match_fsm.pb.host), match_fsm.pb.surround, case_insensitive)

            valid_upstream_version = valid_vector.upstream_versions.get(full_upstream_id)
            if upstream_version == valid_upstream_version:
                # this version is already marked as valid
                self._valid_fsm_ids_by_upstream_id[full_upstream_id] = (upstream_version, fsm_id)
                non_intersecting_fsm_ids_by_versions[upstream_version] = fsm_id
            elif self._valid_fsm_ids_by_upstream_id.get(full_upstream_id) == (valid_upstream_version, fsm_id):
                # fsm has not changed since this upstream has been marked as valid
                non_intersecting_fsm_ids_by_versions[upstream_version] = fsm_id
            else:
                fsm_ids_to_check_by_versions[upstream_version] = fsm_id

            fsms[fsm_id] = self._get_or_compile_fsm(fsm_id)
        return fsms, non_intersecting_fsm_ids_by_versions, fsm_ids_to_check_by_versions

    def _validate_fsms_do_not_intersect(self, balancer_version, fsms,
                                        non_intersecting_fsm_ids_by_versions, fsm_ids_to_check_by_versions):
        """Check every to-be-checked FSM against all compatible FSMs and against each other."""
        to_check = list(six.iteritems(fsm_ids_to_check_by_versions))
        chunk_size = self.INTERSECTION_CHECK_CHUNK_SIZE
        for i, (v_1, fsm_id_1) in enumerate(to_check):
            self._log.debug('Validating {}...'.format(fsm_id_1))

            # Intersection is symmetric: pairs with earlier to_check entries were already
            # checked when those entries were v_1, so only later entries are paired here.
            other_fsm_ids_by_versions = itertools.chain(six.iteritems(non_intersecting_fsm_ids_by_versions),
                                                        to_check[i + 1:])
            pairs_to_check = []
            other_versions = []
            for v_2, fsm_id_2 in other_fsm_ids_by_versions:
                if fsm_id_1 == fsm_id_2:
                    raise self._intersection_error(balancer_version, v_1, fsm_id_1, v_2, fsm_id_2)
                pairs_to_check.append((fsm_id_1, fsm_id_2))
                other_versions.append(v_2)

            for start in range(0, len(pairs_to_check), chunk_size):
                idx = pire.get_first_intersection_idx(pairs_to_check[start:start + chunk_size], fsms)
                if idx != -1:
                    failed_fsm_id_1, failed_fsm_id_2 = pairs_to_check[start + idx]
                    raise self._intersection_error(balancer_version,
                                                   v_1, failed_fsm_id_1,
                                                   other_versions[start + idx], failed_fsm_id_2)
                gevent.sleep(0.05)
            self._log.debug('Validated {}...'.format(fsm_id_1))
            gevent.sleep(0.1)
