from __future__ import annotations
import textwrap
from typing import List, Set

from infra.awacs.proto import modules_pb2
from awacs.wrappers.main import (ModuleWrapperBase,
                                 Threshold, Shared, Holder, Balancer2, ErrorDocument,
                                 Report, RegexpSection, Headers, ResponseHeaders, Rewrite, StatsEater)
from .model import BaseUpstreamSuggester
from .tlem import SectionVisitor
from .uemutil import (
    regexp_section_matcher_to_l7um_matcher_pb,
    balancer2_to_by_dc_scheme_pb,
    balancer2_to_l7um_flat_scheme_pb,
    errordocument_to_l7um_static_response_pb,
    to_rewrite_actions,
    list_set_fields,
)
from ..core import NamespaceConfig


# Trail of matcher / backend names rendered in error and warning messages.
Path = List[str]
# Chain of module wrappers; matchers consume it from the head.
Segment = List[ModuleWrapperBase]


def fmt_path(path):
    """Render a path (list of names) as ``a -> b -> c``."""
    separator = ' -> '
    return separator.join(path)


def fmt_segment(segment: Segment):
    """Render a segment as the arrow-joined class names of its modules."""
    class_names = (module.__class__.__name__ for module in segment)
    return ' -> '.join(class_names)


class MismatchError(ValueError):
    """Raised when a module chain does not fit the shape a matcher expects.

    The message is prefixed with ``At <path>:`` and indented underneath it.

    Attributes:
        path: snapshot of the path at the time the error was raised.
    """

    def __init__(self, message: str, path: Path | None = None):
        # Use a None default instead of a shared mutable list, and snapshot the
        # path: callers may keep appending to their own list after raising, which
        # would otherwise make ``self.path`` disagree with the formatted message.
        path = [] if path is None else list(path)
        super().__init__(f'At {fmt_path(path)}:\n{textwrap.indent(message, "  ")}')
        self.path = path


class Matcher:
    """Base class for matchers that consume a chain of module wrappers.

    A matcher consumes a prefix of a ``Segment`` and records what it learned in
    ``ctx``. It collects non-fatal ``warnings`` as ``(path, message)`` tuples and
    raises ``MismatchError`` when the segment does not fit. ``matched`` is set to
    True once a match succeeded.

    Leaf matchers override ``_match``, which inspects exactly one module.
    Composite matchers override ``match`` directly.
    """

    def __init__(self, id: str | None = None):
        # Optional suffix that distinguishes same-class matchers in paths.
        self.id = id
        self.ctx = {}
        self.warnings = []
        self.matched = False

    def propagate_from(self, other_matcher: Matcher):
        """Merge ``other_matcher``'s ctx and warnings into this matcher."""
        self.ctx.update(other_matcher.ctx)
        self.warnings.extend(other_matcher.warnings)

    @property
    def name(self):
        """Class name, suffixed with ``(id)`` when an id was given."""
        rv = self.__class__.__name__
        if self.id is not None:
            rv += f'({self.id})'
        return rv

    def warn(self, path: Path, message: str):
        """Record a non-fatal problem found at ``path``."""
        self.warnings.append((path, message))

    def match(self, segment: Segment, path: Path) -> Segment:
        """Match the head of ``segment`` with ``_match`` and return the rest.

        ``path`` is not modified. A new list extended with this matcher's name is
        handed to ``_match``, so sibling or failed matchers never leak their names
        into the caller's path.

        Raises:
            MismatchError: if ``segment`` is empty or its head does not fit.
        """
        path = [*path, self.name]
        if not segment:
            # Report running off the end of the chain as a mismatch (not IndexError)
            # so optional / any-order composites can handle it.
            raise MismatchError('unexpected end of module chain', path)
        self._match(segment[0], path)
        self.matched = True
        return segment[1:]

    def _match(self, segment: ModuleWrapperBase, path: Path):
        """Check a single module; leaf subclasses must implement this."""
        raise NotImplementedError

    def __repr__(self):
        # Class name only; used when listing matchers in error messages.
        return self.__class__.__name__


class RegexpSectionMatcher(Matcher):
    """Match a ``regexp_section`` and convert its matcher for the macro.

    On success ``ctx['matcher_pb']`` holds the converted matcher pb.
    """

    def _match(self, m, path):
        if isinstance(m, RegexpSection):
            try:
                converted = regexp_section_matcher_to_l7um_matcher_pb(m.matcher.pb)
            except Exception as e:
                # Conversion failures are reported as a mismatch at this path.
                raise MismatchError(str(e), path)
            self.ctx['matcher_pb'] = converted
            return
        raise MismatchError('regexp_section expected', path)


class SharedMatcher(Matcher):
    """Match a ``shared`` module.

    A ``backends`` uuid sets ``ctx['can_handle_announce_checks']``. Any other
    uuid is accepted with a warning and leaves ``ctx`` untouched.
    """

    def _match(self, m, path):
        if not isinstance(m, Shared):
            raise MismatchError('shared expected', path)
        shared_uuid = m.pb.uuid
        if shared_uuid != 'backends':
            self.warn(path, f'shared with unexpected uuid {shared_uuid}')
            return
        self.ctx['can_handle_announce_checks'] = True


class ThresholdMatcher(Matcher):
    """Match a ``threshold`` module; its settings are not inspected yet."""

    def _match(self, m, path):
        if not isinstance(m, Threshold):
            # Fixed typo in the error message ('hhreshold').
            raise MismatchError('threshold expected', path)
        # TODO: check and remember that it meets the THRESHOLD_PROFILE_CORE_MAPS profile


class StatsEaterMatcher(Matcher):
    """Match a ``stats_eater`` module; only its presence is checked."""

    def _match(self, m, path):
        if isinstance(m, StatsEater):
            return
        raise MismatchError('stats_eater expected', path)


class ReportMatcher(Matcher):
    """Match a ``report`` module and remember its uuid in ``ctx['uuid']``.

    Warns when the report sets fields outside ``EXPECTED_FIELDS``.
    """

    # TODO(review): placeholder whitelist carried over from the original code;
    # replace with the report fields the macro can actually reproduce.
    EXPECTED_FIELDS = frozenset({'asd', 'asdss', '123'})

    def _match(self, m, path):
        if not isinstance(m, Report):
            raise MismatchError('report expected', path)
        # TODO: reverse engineer L7UpstreamMacroMonitoringSettings.{fill_report_pb,fill_default_report_pb}
        self.ctx['uuid'] = m.pb.uuid
        unexpected = set(list_set_fields(m.pb)) - self.EXPECTED_FIELDS
        if unexpected:
            # The original message was an unformatted placeholder ('{}') and said
            # "expected" for fields that are in fact unexpected.
            self.warn(path, f'seen unexpected fields {sorted(str(f) for f in unexpected)}')


class HeadersMatcher(Matcher):
    """Match a ``headers`` module; ``ctx['actions']`` gets the header actions."""

    def _match(self, m, path):
        if not isinstance(m, Headers):
            raise MismatchError('headers expected', path)
        # Like the other converting matchers, report conversion failures as a
        # mismatch (with path context) rather than letting a raw exception escape.
        try:
            self.ctx['actions'] = SectionVisitor.get_header_actions(m)
        except Exception as e:
            raise MismatchError(str(e), path) from e


class ResponseHeadersMatcher(Matcher):
    """Match a ``response_headers`` module; ``ctx['actions']`` gets its actions."""

    def _match(self, m, path):
        if not isinstance(m, ResponseHeaders):
            raise MismatchError('response_headers expected', path)
        # Like the other converting matchers, report conversion failures as a
        # mismatch (with path context) rather than letting a raw exception escape.
        try:
            self.ctx['actions'] = SectionVisitor.get_header_actions(m)
        except Exception as e:
            raise MismatchError(str(e), path) from e


class RewriteMatcher(Matcher):
    """Match a ``rewrite`` module; ``ctx['actions']`` gets the converted actions."""

    def _match(self, m, path):
        if isinstance(m, Rewrite):
            try:
                rewrite_actions = to_rewrite_actions(m.pb.actions)
            except Exception as e:
                raise MismatchError(str(e), path)
            self.ctx['actions'] = rewrite_actions
        else:
            raise MismatchError('rewrite expected', path)


class OuterBalancerMatcher(Matcher):
    """Match a per-DC outer ``balancer2`` and build a ``by_dc_scheme`` pb.

    Every backend must be named ``<prefix>_<dc>`` with ``<dc>`` in ``KNOWN_DCS``.
    A ``devnull`` backend enables the devnull branch. All other backends must
    share one prefix and consist of an optional ``report`` followed by an inner
    backends ``balancer2``. The inner balancer settings must be identical across
    DCs; backend ids may differ.

    On success ``ctx['by_dc_scheme_pb']`` holds the resulting scheme.

    Raises:
        MismatchError: if any of the above does not hold.
    """

    # DC suffixes accepted in outer balancer backend names.
    KNOWN_DCS = ('sas', 'man', 'vla', 'iva', 'myt', 'devnull')

    def __init__(self, upstream_id: str, id: str = None):
        # Forward ``id`` (previously dropped) so ``name`` reflects it.
        super().__init__(id)
        # Used to tell whether a per-DC report uuid is the default one.
        self.upstream_id = upstream_id

    def _match(self, m, path):
        if not isinstance(m, Balancer2):
            raise MismatchError('balancer2 expected', path)
        if not m.backends:
            raise MismatchError('balancer2.backends expected', path)

        first_prefix = None
        first_backend_ctx = None
        seen_devnull = False
        try:
            by_dc_scheme_pb = balancer2_to_by_dc_scheme_pb(m)
        except Exception as e:
            raise MismatchError(str(e), path) from e

        for backend in m.backends:
            backend_name = backend.pb.name
            if '_' not in backend_name:
                raise MismatchError(f'unexpected backend name {backend_name}', path)

            prefix, dc_name = backend_name.rsplit('_', 1)
            if dc_name not in self.KNOWN_DCS:
                raise MismatchError(f'unexpected backend name with unknown dc: {backend_name}', path)

            if dc_name == 'devnull':
                seen_devnull = True
                continue

            if first_prefix is None:
                first_prefix = prefix
            elif first_prefix != prefix:
                raise MismatchError(f'prefixes differ: {first_prefix} != {prefix}', path)

            inner_balancer_matcher = BackendsBalancerMatcher()
            mb_report_matcher = MaybeMatcher(ReportMatcher())
            backend_matcher = SequenceMatcher([
                mb_report_matcher,
                inner_balancer_matcher,
            ])
            backend_segment = list(backend.nested.walk_chain())
            # Hand over a copy so nothing done by the children can alter our path.
            backend_matcher.match(backend_segment, list(path))

            ctx = inner_balancer_matcher.ctx
            # Backend ids differ between DCs by design: move them into the per-DC
            # pb so they are excluded from the cross-DC equality check below.
            dc_pb = by_dc_scheme_pb.dcs.add(name=dc_name, backend_ids=ctx.pop('backend_ids'))
            if mb_report_matcher.matcher.matched:
                if mb_report_matcher.ctx['uuid'] != self.upstream_id + '_' + dc_name:
                    dc_pb.monitoring.uuid = mb_report_matcher.ctx['uuid']
            else:
                dc_pb.compat.disable_monitoring = True

            if first_backend_ctx is None:
                first_backend_ctx = ctx
            elif first_backend_ctx != ctx:
                raise MismatchError(f'dc branches differ: {first_backend_ctx} != {ctx}', path)

            # Propagate warnings from the whole branch (report + inner balancer);
            # previously only the inner balancer, which never warns, was read.
            # A separate loop variable keeps the ``path`` parameter intact for the
            # following backends.
            for warn_path, message in backend_matcher.warnings:
                self.warn(warn_path + [backend_name], message)

        if first_backend_ctx is None:
            # Only devnull backends: there is no branch to derive the scheme from.
            raise MismatchError('balancer2 has no non-devnull backends', path)

        by_dc_scheme_pb.dc_balancer.weights_section_id = first_prefix

        if seen_devnull:
            by_dc_scheme_pb.devnull.static.status = 200
            by_dc_scheme_pb.devnull.static.content = 'OK'
        else:
            by_dc_scheme_pb.compat.disable_devnull = True

        by_dc_scheme_pb.balancer.CopyFrom(first_backend_ctx.pop('flat_scheme_pb').balancer)
        self.ctx['by_dc_scheme_pb'] = by_dc_scheme_pb


class BackendsBalancerMatcher(Matcher):
    """Match a ``balancer2`` over generated proxy backends (a flat scheme).

    On success ``ctx['flat_scheme_pb']`` holds the converted flat scheme and
    ``ctx['backend_ids']`` the ids of the included backends.

    Raises:
        MismatchError: if the module is not a convertible ``balancer2`` or it
            does not use ``generated_proxy_backends.include_backends``.
    """

    def _match(self, m, path):
        if not isinstance(m, Balancer2):
            raise MismatchError('balancer2 expected', path)
        try:
            self.ctx['flat_scheme_pb'] = balancer2_to_l7um_flat_scheme_pb(m)
        except Exception as e:
            raise MismatchError(str(e), path) from e
        # NOTE(review): generated_proxy_backends is presumably None when backends
        # are listed explicitly; guard so that case becomes a mismatch instead of
        # an AttributeError.
        generated = m.generated_proxy_backends
        if generated is None or not generated.include_backends:
            raise MismatchError('hardcoded backends used', path)
        self.ctx['backend_ids'] = list(generated.include_backends.pb.ids)


class ErrorDocumentMatcher(Matcher):
    """Match an ``errordocument``; ``ctx['static_response_pb']`` gets the response."""

    def _match(self, m, path):
        if isinstance(m, ErrorDocument):
            try:
                response_pb = errordocument_to_l7um_static_response_pb(m)
            except Exception as e:
                raise MismatchError(str(e), path)
            self.ctx['static_response_pb'] = response_pb
        else:
            raise MismatchError('errordocument expected', path)


class AnyOfMatcher(Matcher):
    """
    Try ``matchers`` in order and keep the first one that succeeds.

    The winner's ctx and warnings are merged into this matcher. If every
    candidate fails, a MismatchError listing each candidate's error is raised.
    """

    def __init__(self, matchers: List[Matcher]):
        super().__init__()
        self.matchers = matchers

    def match(self, segment, path):
        failures = []
        for candidate in self.matchers:
            try:
                # Each candidate gets its own copy of the path.
                rest = candidate.match(segment, list(path))
            except MismatchError as err:
                failures.append(err)
                continue
            self.propagate_from(candidate)
            self.matched = True
            return rest

        details = '\n' + '\n'.join(f'* {err}' for err in failures)
        raise MismatchError(
            f'could not parse segment {fmt_segment(segment)} into matchers {self.matchers}: {details}',
            path,
        )


class SequenceMatcher(Matcher):
    """Run ``matchers`` one after another; the segment must be consumed fully.

    Warnings of the children are collected; their ctx is not merged.

    Raises:
        MismatchError: from a failing child, or if modules remain unconsumed.
    """

    def __init__(self, matchers: List[Matcher]):
        super().__init__()
        self.matchers = matchers

    def match(self, segment, path):
        path = [*path, self.name]
        for m in self.matchers:
            # Give each child its own copy so one sibling's name never shows up
            # in the path of the next one.
            segment = m.match(segment, list(path))
            self.warnings.extend(m.warnings)
        if segment:
            raise MismatchError(f'could not parse segment fully: {fmt_segment(segment)}', path)
        self.matched = True
        return segment


class RandomSequenceMatcher(Matcher):
    """Match any subset of ``matchers``, each at most once, in any order.

    Matching stops when no remaining matcher accepts the head of the segment, or
    when the segment is exhausted. Mismatches are never raised. ``matched`` is
    True only if at least one matcher succeeded. The unconsumed tail is returned
    and warnings of successful children are collected.
    """

    def __init__(self, matchers: Set[Matcher]):
        super().__init__()
        self.matchers = matchers

    def match(self, segment, path):
        path = [*path, self.name]
        remaining = set(self.matchers)
        matched = False
        # Stop once the segment is empty: a leaf matcher would otherwise index
        # segment[0] and crash with IndexError, which is not caught below.
        while remaining and segment:
            for m in list(remaining):
                try:
                    # Fresh copy per attempt so failed attempts do not leave
                    # their names in the path seen by later ones.
                    segment = m.match(segment, list(path))
                except MismatchError:
                    continue
                self.warnings.extend(m.warnings)
                remaining.discard(m)
                matched = True
                break
            else:
                # No remaining matcher accepts the current head.
                break
        self.matched = matched
        return segment


class MaybeMatcher(Matcher):
    """Optionally match ``matcher``; absence is not an error.

    On success the inner matcher's ctx and warnings are merged and the rest of
    the segment is returned. On mismatch, or when the segment is empty, the
    segment is returned unchanged. ``matched`` is always set to True; check
    ``self.matcher.matched`` to know whether the optional part was present.
    """

    def __init__(self, matcher: Matcher):
        super().__init__()
        self.matcher = matcher

    def __repr__(self):
        return f'MaybeMatcher({self.matcher.__class__.__name__})'

    def match(self, segment, path):
        self.matched = True  # it's maybe after all :)
        if not segment:
            # Nothing to inspect: the optional part is simply absent (a leaf
            # matcher would otherwise fail with IndexError on segment[0]).
            return segment
        try:
            # Pass a copy so a failed attempt does not leak into the caller's path.
            rest = self.matcher.match(segment, list(path))
        except MismatchError:
            return segment
        self.propagate_from(self.matcher)
        return rest


class RootMatcher(Matcher):
    """Match a whole upstream chain and build an ``L7UpstreamMacro`` pb from it.

    The chain must have this shape:

    1. a ``regexp_section``;
    2. any subset of shared / threshold / report / headers / response_headers /
       rewrite / stats_eater, in any order;
    3. one of: an outer per-DC balancer, a flat backends balancer, or an
       errordocument.

    On success ``ctx['l7_upstream_macro_pb']`` holds the result.

    Raises:
        MismatchError: if the chain does not fit.
    """

    def __init__(self, upstream_id: str, id: str = None):
        # Forward ``id`` (previously dropped) so ``name`` reflects it.
        super().__init__(id)
        self.upstream_id = upstream_id

    def match(self, segment, path):
        regexp_section_matcher = RegexpSectionMatcher()

        shared_matcher = SharedMatcher()
        threshold_matcher = ThresholdMatcher()
        report_matcher = ReportMatcher()
        headers_matcher = HeadersMatcher()
        response_headers_matcher = ResponseHeadersMatcher()
        rewrite_matcher = RewriteMatcher()

        by_dc_scheme_matcher = OuterBalancerMatcher(self.upstream_id)
        flat_scheme_matcher = BackendsBalancerMatcher()
        static_response_matcher = ErrorDocumentMatcher()

        # TODO: this does not fully cover all the features of l7_upstream_macro
        root_matcher = SequenceMatcher([
            regexp_section_matcher,
            RandomSequenceMatcher({
                shared_matcher,
                threshold_matcher,
                report_matcher,
                headers_matcher,
                response_headers_matcher,
                rewrite_matcher,
                StatsEaterMatcher(),
            }),
            AnyOfMatcher([
                by_dc_scheme_matcher,
                flat_scheme_matcher,
                static_response_matcher,
            ]),
        ])

        try:
            root_matcher.match(segment, path)
            self.matched = True
        finally:
            # Keep the warnings collected so far even if matching failed.
            self.warnings.extend(root_matcher.warnings)

        l7um_pb = modules_pb2.L7UpstreamMacro(
            version='0.0.1',
            id=self.upstream_id,
            matcher=regexp_section_matcher.ctx['matcher_pb'],
        )
        if report_matcher.matched:
            if report_matcher.ctx['uuid'] != self.upstream_id:
                l7um_pb.monitoring.uuid = report_matcher.ctx['uuid']
        else:
            l7um_pb.compat.disable_monitoring = True

        # SharedMatcher only fills ctx for the 'backends' uuid (other uuids just
        # warn), so use .get() instead of indexing, which raised KeyError.
        if shared_matcher.matched and shared_matcher.ctx.get('can_handle_announce_checks'):
            l7um_pb.can_handle_announce_checks = True

        if headers_matcher.matched:
            l7um_pb.headers.extend(headers_matcher.ctx['actions'])

        if response_headers_matcher.matched:
            l7um_pb.response_headers.extend(response_headers_matcher.ctx['actions'])

        if rewrite_matcher.matched:
            l7um_pb.rewrite.extend(rewrite_matcher.ctx['actions'])

        if flat_scheme_matcher.matched:
            l7um_pb.flat_scheme.CopyFrom(flat_scheme_matcher.ctx['flat_scheme_pb'])
            l7um_pb.flat_scheme.backend_ids.extend(flat_scheme_matcher.ctx['backend_ids'])

        if by_dc_scheme_matcher.matched:
            l7um_pb.by_dc_scheme.CopyFrom(by_dc_scheme_matcher.ctx['by_dc_scheme_pb'])

            # Cosmetic only: with monitoring disabled at the macro level, clear the
            # per-DC disable_monitoring flags and drop compat messages left empty.
            if l7um_pb.compat.disable_monitoring:
                for dc_pb in l7um_pb.by_dc_scheme.dcs:
                    dc_pb.compat.disable_monitoring = False
                    if not list_set_fields(dc_pb.compat):
                        dc_pb.ClearField('compat')

        # TODO(review): when the errordocument branch matches,
        # static_response_matcher.ctx['static_response_pb'] is never copied into
        # l7um_pb, so the macro ends up without a scheme. Confirm the
        # L7UpstreamMacro field for static responses and wire it up.

        self.ctx['l7_upstream_macro_pb'] = l7um_pb


class Checker(BaseUpstreamSuggester):
    """Suggest an ``l7_upstream_macro`` equivalent of a hand-written upstream."""

    # Rule name for this suggester (presumably consumed by BaseUpstreamSuggester).
    RULE = 'UEM_SIMPLIFIED'

    def suggest(self, upstream_id, holder_pb: modules_pb2.Holder, namespace_config: NamespaceConfig = None):
        """Convert ``holder_pb`` into a Holder containing an ``l7_upstream_macro``.

        ``namespace_config`` is accepted for interface compatibility and unused.
        Non-fatal matcher warnings are logged instead of printed to stdout.

        Returns:
            ``(True, '', holder)`` where ``holder`` wraps the generated macro pb.

        Raises:
            MismatchError: if the upstream chain does not fit the supported shape.
        """
        import logging

        root_matcher = RootMatcher(upstream_id)

        holder = Holder(holder_pb)
        initial_segment = list(holder.walk_chain())
        initial_path = []
        root_matcher.match(initial_segment, initial_path)
        # match() either raises MismatchError or marks the matcher as matched.
        assert root_matcher.matched

        logger = logging.getLogger(__name__)
        for warn_path, message in root_matcher.warnings:
            logger.warning('%s: at %s: %s', upstream_id, fmt_path(warn_path), message)

        return True, '', modules_pb2.Holder(l7_upstream_macro=root_matcher.ctx['l7_upstream_macro_pb'])
