import asyncio
import logging
import re

from dataclasses import dataclass
from typing import Tuple, Optional, Iterable

from saas.library.python.async_saas_slot import BackendSlot
from saas.library.python.deploy_manager_api import SaasService
from saas.tools.devops.fban.config import ServiceSettings

# Pre-compiled patterns for the metrics FusionInstance reads, keyed by metric name
# and then by field name. Each pattern captures the (optionally negative, optionally
# fractional) number after "<field>=" on the same line as the metric name.
_metrics_regexps = {
    metric_name: {
        field_name: re.compile(rf'{metric_name}[ ,].*{field_name}=[ ]*(-?[\d]+\.?[\d]*)')
        for field_name in field_names
    }
    for metric_name, field_names in (
        ('Fusion_DistAge', ('total',)),
        ('Rtyserver_DistAge', ('total',)),
        ('Fusion_AckAge', ('total', 'time')),
        ('Rtyserver_AckAge', ('total', 'time')),
        ('Fusion_IndexedDocTimestamp', ('total',)),
        ('Rtyserver_IndexedDocTimestamp', ('total',)),
    )
}

# Patterns for reading individual fields out of the instance status text.
# NOTE(review): in 'Active' and 'Search_Server_Running' the greedy `.*` makes the group
# capture only the final digit on the line (e.g. "Active: 10" -> "0"); this is fine for
# 0/1 flags but confirm the status format before reusing these for multi-digit values.
_status_regexps = dict(
    Active=re.compile(r'Active.*([\d]+)'),
    Search_Server_Running=re.compile(r'Search_Server_Running.*([\d]+)'),
    Shards=re.compile(r'Shards.*:[ ]*([\d]+)/'),
)

# Patterns for discovering stream names from TASS signal names; the group captures
# whatever sits between "-SERV-CTYPE" and "-receive_lag" (including its leading separator).
_stream_regexps = dict(
    receive_lag=re.compile(r'^\w+-\w+-SERV-CTYPE(.*)-receive_lag'),
)


@dataclass
class FusionInstanceInfo:
    """Snapshot of one backend instance, as built by ``FusionInstance.load_fusion_info``."""

    state: str                              # a FusionInstance.State value
    receive_lag: Optional[float]            # None when the lag could not be determined
    search_enabled: bool                    # False when the status could not be read
    action_result: Optional[str] = None     # presumably a FusionInstance.ActionResult value set by callers - confirm


class FusionInstance:
    """One backend instance (host:port) of a SaaS service.

    Wraps a ``BackendSlot`` to read the instance's status, metrics and TASS signals,
    summarizes them into a ``FusionInstanceInfo`` (``load_fusion_info``) and can
    disable / enable search on the instance (``ban`` / ``unban``).
    """

    class State:
        """Values ``load_fusion_info`` stores in ``FusionInstanceInfo.state``."""

        ALIVE = 'alive'
        DEAD = 'dead'
        INACTIVE = 'inactive'

    class ActionResult:
        """String codes for the outcome of a ban/unban action.

        Not assigned within this class; presumably stored by callers in
        ``FusionInstanceInfo.action_result`` -- verify against callers.
        """

        UNBAN_OK = 'unban-ok'
        UNBAN_FAIL = 'unban-fail'
        BAN_OK = 'ban-ok'
        BAN_FAIL = 'ban-fail'

    # An IndexedDocTimestamp at or below this value (2011-03-13 as a Unix time) is not
    # treated as a real timestamp -- protection against fake 'Versions' in video-quick-saas.
    _MIN_PLAUSIBLE_TIMESTAMP = 1300000000

    def __init__(self, service: SaasService, settings: ServiceSettings, host: str, port: int) -> None:
        self._slot: BackendSlot = BackendSlot(host, port)
        self._host: str = host
        self._port: int = port

        self._service: SaasService = service
        self._settings: ServiceSettings = settings
        # Filled in by load_fusion_info(); None until it has run.
        self._info: Optional[FusionInstanceInfo] = None

    @property
    def host(self) -> str:
        return self._host

    @property
    def port(self) -> int:
        return self._port

    @property
    def info(self) -> Optional[FusionInstanceInfo]:
        """Result of the last ``load_fusion_info`` call, or None if it has not run yet."""
        return self._info

    async def _get_formatted_tass(self) -> dict:
        """Fetch TASS signals as a ``{signal_name: value}`` dict.

        Each TASS item is unpacked as a ``(key, value)`` pair; only the part of the key
        after the last ``;`` is kept as the signal name (a later duplicate wins).

        :raises IncorrectTASSException: if fetching or unpacking the TASS data fails;
            the original error is chained as ``__cause__``.
        """
        try:
            tass_data = await self._slot.get_tass()
            return {key.split(';')[-1]: value for key, value in tass_data}
        except Exception as e:
            raise IncorrectTASSException(e) from e

    def _get_status_value(self, status: str, key: str) -> bool:
        """Read the integer flag *key* from the status text and return it as a bool.

        Known keys use the pre-compiled ``_status_regexps``; any other key falls back to
        "<key>...: <value>". The key is escaped so it is matched literally.

        :raises KeyError: if *key* is not found in *status*.
        :raises ValueError: if the captured text is not an integer (fallback pattern only).
        """
        regexp = _status_regexps.get(key)
        if not regexp:
            regexp = re.compile(rf'{re.escape(key)}.*:[ ]*(.*)')

        res = regexp.search(status)
        if not res:
            raise KeyError(f'Unable to read status "{key}" for {self}')

        return bool(int(res.group(1)))

    @staticmethod
    def _get_metric_value(metrics: str, metric: str, key: str) -> float:
        """Return the numeric ``<key>=`` field of *metric* from the metrics text.

        Metric/field names not present in ``_metrics_regexps`` get an ad-hoc pattern in
        which both names are escaped: stream-suffixed metric names come straight from
        the metrics text and may contain regex metacharacters such as ``.`` or ``(``.

        :raises MetricNotFoundException: if the metric/field pair is not present.
        """
        regexp = _metrics_regexps.get(metric, {}).get(key)
        if not regexp:
            regexp = re.compile(rf'{re.escape(metric)}[ ,].*{re.escape(key)}=[ ]*(-?[\d]+\.?[\d]*)')

        res = regexp.search(metrics)
        if not res:
            raise MetricNotFoundException(f'Unable to read metric "{metric}.{key}"')

        return float(res.group(1))

    @staticmethod
    def _get_stream_names(metrics: str, tass: dict, prefix: str) -> Iterable[str]:
        """Collect stream names from metric names and TASS receive-lag signal names.

        For every occurrence of *prefix* in *metrics*, and every TASS signal matching
        ``_stream_regexps['receive_lag']``, the captured tail minus its first character
        (the separator) is taken as a stream name; an empty name denotes the
        un-suffixed (default) stream.

        :returns: the set of discovered stream names.
        :raises StreamNotFoundException: if no stream name is found.
        """
        stream_names = set()

        # Distributor streams, taken from the metric names themselves.
        for match in re.finditer(rf'{re.escape(prefix)}([^ ]*)', metrics):
            stream_names.add(match.group(1)[1:])

        # Other streams, taken from TASS receive_lag signal names.
        receive_lag_regexp = _stream_regexps['receive_lag']
        for signal in tass:
            match = receive_lag_regexp.search(signal)
            if match:
                stream_names.add(match.group(1)[1:])

        if not stream_names:
            raise StreamNotFoundException(f'No stream names by prefix={prefix}')

        return stream_names

    def _calc_receive_lag_from_metrics(self, itype: str, metrics: str, metric_suffix: str) -> float:
        """Estimate one stream's receive lag from its DistAge / AckAge / IndexedDocTimestamp metrics.

        Result: ``dist + max(ack - dist, 0) + max(now - indexed - dist, 0)``, where ``now``
        is the AckAge ``time`` field. The last term is 0 when IndexedDocTimestamp
        does not look like a Unix time.

        :raises MetricNotFoundException: if any of the required metrics is missing.
        """
        dt_dist = self._get_metric_value(metrics, f'{itype}_DistAge{metric_suffix}', 'total')
        dt_ack = self._get_metric_value(metrics, f'{itype}_AckAge{metric_suffix}', 'total')
        t_now = self._get_metric_value(metrics, f'{itype}_AckAge{metric_suffix}', 'time')
        t_index = self._get_metric_value(metrics, f'{itype}_IndexedDocTimestamp{metric_suffix}', 'total')

        # t_index must look like a time (protection against fake 'Versions' in video-quick-saas).
        dt_index = t_now - t_index - dt_dist if t_index > self._MIN_PLAUSIBLE_TIMESTAMP else 0
        dt_index = dt_index if dt_index > 0 else 0

        dt_post = dt_ack - dt_dist
        dt_post = dt_post if dt_post > 0 else 0

        return dt_index + dt_dist + dt_post

    def _calc_receive_lag(self, metrics: str, tass: dict) -> Optional[float]:
        """Return the largest receive lag over all streams of this instance, or None.

        For each stream, the larger of the TASS ``backend-base`` / ``backend-df``
        ``receive_lag_avvv`` signals is used when present; otherwise the lag is computed
        from metrics. Returns None (and logs) when no stream or a required metric is
        found, or on any unexpected error.
        """
        itype = self._settings.itype.capitalize()
        result = 0

        try:
            for stream in self._get_stream_names(metrics, tass, f'{itype}_DistAge'):
                suffix = f'-{stream}' if stream else ''
                signal_names = (f'backend-base{suffix}-receive_lag_avvv', f'backend-df{suffix}-receive_lag_avvv')
                tass_lags = [value for value in (tass.get(name) for name in signal_names) if value is not None]

                if tass_lags:
                    receive_lag = max(tass_lags)
                else:
                    receive_lag = self._calc_receive_lag_from_metrics(itype, metrics, suffix)

                result = max(result, receive_lag)
        except (StreamNotFoundException, MetricNotFoundException) as e:
            # NOTE(review): the message says "considered zero" but None is returned;
            # confirm how callers interpret a None receive_lag.
            logging.info(
                '[%s %s] The receive lag for %s:%d is considered zero, details: %s',
                self._service.name, self._service.ctype, self.host, self.port, e
            )
            return None
        except Exception:
            logging.exception(
                '[%s %s] Unable to calculate the receive lag for %s:%d',
                self._service.name, self._service.ctype, self.host, self.port
            )
            return None

        return result

    def _is_search_enabled(self, status: str) -> Tuple[bool, Optional[str]]:
        """Return ``(search_enabled, server_state)`` parsed from the status text.

        ``server_state`` is ``State.INACTIVE`` when the ``Active`` flag is 0, else None.
        If either flag cannot be read, search is reported as disabled.
        """
        server_state = None

        try:
            search_enabled = self._get_status_value(status, 'Search_Server_Running')
            if not self._get_status_value(status, 'Active'):
                server_state = self.State.INACTIVE
        except Exception:
            search_enabled = False

        return search_enabled, server_state

    async def load_fusion_info(self) -> None:
        """Fetch status, metrics and TASS concurrently and store the result in ``self.info``.

        Any failure while fetching marks the instance ``DEAD``; otherwise it is
        ``ALIVE``, or ``INACTIVE`` if the status reports it as not active.
        """
        try:
            status, metrics, tass = await asyncio.gather(
                self._slot.get_status(),
                self._slot.get_metrics(),
                self._get_formatted_tass(),
            )
        except Exception as e:
            # Keep the reason: otherwise a refused connection and a malformed TASS payload look the same.
            logging.warning('The instance is dead: %s (%r)', self, e)

            state = self.State.DEAD
            receive_lag = None
            search_enabled = False
        else:
            state = self.State.ALIVE
            receive_lag = self._calc_receive_lag(metrics, tass)

            search_enabled, server_state = self._is_search_enabled(status)
            state = server_state or state

        self._info = FusionInstanceInfo(state, receive_lag, search_enabled)

    async def ban(self) -> bool:
        """Disable search on the instance; return True on success, False (logged) on failure."""
        try:
            await self._slot.disable_search()
        except Exception:
            logging.exception(
                '[%s %s] Unable to disable search on %s:%d',
                self._service.name, self._service.ctype, self.host, self.port
            )
            return False
        return True

    async def unban(self) -> bool:
        """Enable search on the instance; return True on success, False (logged) on failure."""
        try:
            await self._slot.enable_search()
        except Exception:
            logging.exception(
                '[%s %s] Unable to enable search on %s:%d',
                self._service.name, self._service.ctype, self.host, self.port
            )
            return False
        return True

    def __str__(self) -> str:
        return f'{self.__class__.__name__}({self.host}:{self.port})'


class IncorrectTASSException(Exception):
    """Raised when an instance's TASS signals cannot be fetched or unpacked."""


class StreamNotFoundException(Exception):
    """Raised when no stream names can be discovered from metrics or TASS signals."""


class MetricNotFoundException(Exception):
    """Raised when a requested metric field is missing from the metrics text."""
