import logging
from collections import defaultdict

import ticket_parser2 as tp2
from gevent import spawn
from gevent.event import Event
from gevent.lock import Semaphore
from ticket_parser2.low_level import ServiceContext

from object_validator import Dict, DictScheme, String
from sepelib.core import config, constants
from walle.clients.juggler import send_event, JugglerCheckStatus
from walle.clients.utils import request, get_json_response, HttpClientError
from walle.errors import RecoverableError, UserRecoverableError, FixableError
from walle.models import timestamp
from walle.stats import stats_manager

log = logging.getLogger(__name__)

# Seconds the refresher sleeps after a successful refresh of TVM keys and service tickets.
TICKET_REFRESH_PERIOD = constants.HOUR_SECONDS
# Seconds the refresher sleeps before retrying after a failed refresh.
TIME_BETWEEN_RETRIES = 10
# Seconds a caller waits for the ticket manager to become ready before TvmIsNotReady is raised.
GET_TICKET_TIMEOUT = 10


class TvmApiError(RecoverableError):
    """Raised when an HTTP request to the TVM API fails."""

    def __init__(self, message, *args):
        full_message = "Error in communication with TVM API: " + message
        super().__init__(full_message, *args)


class TvmIsNotReady(RecoverableError):
    """Raised when TVM tickets did not become ready within the allowed wait time."""

    def __init__(self):
        message = "TVM ticket update didn't finish before timeout"
        super().__init__(message)


class TvmUnknownAppId(UserRecoverableError):
    """Raised when an incoming ticket's source TVM app id is not mapped to any alias."""

    def __init__(self, tvm_app_id):
        template = "TVM app id {} is not known to Wall-e"
        super().__init__(template, tvm_app_id)


class TvmSourceIsNotAllowed(UserRecoverableError):
    """Raised when an incoming ticket's source is known but not among the allowed aliases."""

    def __init__(self, tvm_app_id):
        template = "TVM app id {} is not allowed to access Wall-e"
        super().__init__(template, tvm_app_id)


class TvmServiceTicketUpdateError(FixableError):
    """Raised when TVM API answered with an error instead of a service ticket for a destination."""

    def __init__(self, tvm_app_id, alias, api_error_msg):
        template = "Got TVM API error '{}' while getting service ticket for tvm app id {} (alias {})"
        super().__init__(template, api_error_msg, tvm_app_id, alias)


class TvmServiceTicketManager:
    """Generates outgoing and validates incoming TVM service tickets.

    A background greenlet (``_ticket_refresher``) re-reads the alias mappings, fetches TVM public keys and service
    tickets for every mapped destination, and publishes them for callers. It refreshes every
    ``TICKET_REFRESH_PERIOD`` seconds, retries failures after ``TIME_BETWEEN_RETRIES`` seconds, and can be woken up
    early by ``_set_need_update()`` or ``stop()``.
    """

    def __init__(self, alias_mappers):
        """:arg alias_mappers: List of callables that return dict alias->tvm_app_id"""
        self._alias_mappers = alias_mappers
        self._alias_to_tvm_app_id = {}
        self._tvm_app_id_to_aliases = {}
        self._alias_to_ticket = {}
        self._alias_to_api_error = {}
        self._walle_tvm_app_id = config.get_value("tvm.app_id")

        # Set after a successful refresh; cleared while a caller-requested refresh is pending and on stop.
        self._ready = Event()
        # Wakes the refresher greenlet before its sleep timeout expires (refresh request or stop()).
        self._need_update = Event()
        self._stop = False

        # Serializes callers that request a refresh for an alias without a ticket.
        self._restart_semaphore = Semaphore()

        # Built by the refresher from TVM public keys; used to check incoming tickets.
        self._service_context = None
        self._refresher = spawn(self._ticket_refresher)

    def get_ticket_for_alias(self, alias):
        """Return the outgoing service ticket for ``alias``.

        If the last refresh produced neither a ticket nor a TVM error for the alias (new alias, or tickets not
        fetched yet), request an immediate refresh and wait up to ``GET_TICKET_TIMEOUT`` seconds for it.

        :raises TvmIsNotReady: the requested refresh did not finish in time.
        :raises TvmServiceTicketUpdateError: TVM API returned no ticket for the alias's destination.
        :raises KeyError: no ticket is available even after a refresh (e.g. no alias mapper knows the alias).
        """
        # Look at the ticket/error maps, not the alias mapping: the mapping is published before the tickets are
        # fetched, so an alias may already be "known" while its ticket is still missing.
        if not self._has_ticket_or_error(alias):
            with self._restart_semaphore:
                if not self._has_ticket_or_error(alias):
                    log.info("We are not aware of alias %s, will update list of aliases and fetch fresh tickets", alias)
                    self._set_need_update()
                    if not self._ready.wait(timeout=GET_TICKET_TIMEOUT):
                        raise TvmIsNotReady()

        if alias in self._alias_to_api_error:
            raise TvmServiceTicketUpdateError(
                self._alias_to_tvm_app_id.get(alias), alias, self._alias_to_api_error[alias]
            )

        ticket = self._alias_to_ticket.get(alias)
        if ticket is None:
            log.error("No TVM service ticket is available for alias %s", alias)
            raise KeyError(alias)
        return ticket

    def _has_ticket_or_error(self, alias):
        """Tell whether the last ticket refresh produced either a ticket or a TVM error for ``alias``."""
        return alias in self._alias_to_ticket or alias in self._alias_to_api_error

    def check_service_ticket(self, ticket, allowed_source_aliases):
        """Validate an incoming service ticket and check that its source may call us.

        :arg ticket: incoming TVM service ticket, passed to ``ServiceContext.check``.
        :arg allowed_source_aliases: aliases whose TVM app ids are allowed as the ticket source.
        :raises TvmIsNotReady: the manager did not become ready within ``GET_TICKET_TIMEOUT`` seconds.
        :raises TvmUnknownAppId: the ticket's source app id is not mapped to any alias.
        :raises TvmSourceIsNotAllowed: none of the source's aliases is in ``allowed_source_aliases``.
        """
        if not self._ready.wait(timeout=GET_TICKET_TIMEOUT):
            raise TvmIsNotReady()

        source_app_id = self.get_service_ticket_source_app_id(ticket)
        source_aliases = self._tvm_app_id_to_aliases.get(source_app_id)
        if source_aliases is None:
            raise TvmUnknownAppId(source_app_id)

        if set(source_aliases).isdisjoint(allowed_source_aliases):
            raise TvmSourceIsNotAllowed(source_app_id)

    def get_service_ticket_source_app_id(self, ticket):
        """Check ``ticket`` with the current service context and return its source TVM app id.

        :raises TvmIsNotReady: TVM public keys have not been fetched yet, so there is no service context.
        """
        service_context = self._service_context
        if service_context is None:
            raise TvmIsNotReady()
        return service_context.check(ticket).src

    def is_ready(self):
        """Tell whether tickets have been refreshed and no caller-requested refresh is pending."""
        return self._ready.is_set()

    def stop(self):
        """Ask the refresher greenlet to exit; it clears the ready flag on its way out."""
        self._stop = True
        self._need_update.set()

    def _set_need_update(self):
        """Wake the refresher for an immediate refresh and mark the manager not ready until it finishes."""
        self._need_update.set()
        self._ready.clear()

    def _ticket_refresher(self):
        """Background loop: refresh tickets, then sleep until the period elapses or a refresh/stop is requested."""
        while not self._stop:
            # Clear the wake-up flag *before* refreshing: a refresh request or stop() arriving while the refresh is
            # in flight (the greenlet yields on network I/O) must survive and wake the wait() below right away.
            self._need_update.clear()
            try:
                self._refresh_tickets()
            except Exception as e:
                self._handle_api_error(e)
                timeout = TIME_BETWEEN_RETRIES
            else:
                self._handle_successful_refresh()
                timeout = TICKET_REFRESH_PERIOD

            self._need_update.wait(timeout=timeout)

        self._ready.clear()

    def _refresh_tickets(self):
        """Re-read alias mappings, rebuild the service context and fetch tickets for all mapped destinations."""
        log.info("Starting TVM service tickets update")
        self._update_alias_to_tvm_app_id()
        if self._alias_to_tvm_app_id:
            self._alias_to_ticket = self._get_tickets_for_dsts(self._alias_to_tvm_app_id.values())
        else:
            # No destinations, but check_service_ticket() still needs a service context, and tickets of aliases
            # that are no longer mapped must not be served anymore.
            self._service_context = self._get_service_context()
            self._alias_to_ticket = {}
            self._alias_to_api_error = {}
        log.info("Finished TVM service tickets update")

    def _update_alias_to_tvm_app_id(self):
        """Merge the output of all alias mappers and rebuild both alias<->app id mappings."""
        alias_mappings = {}
        for mapper in self._alias_mappers:
            alias_mappings.update(mapper())

        tvm_app_id_to_aliases = defaultdict(list)
        for alias, tvm_app_id in alias_mappings.items():
            tvm_app_id_to_aliases[tvm_app_id].append(alias)

        # Publish both directions together. Use a plain dict so lookups of unknown ids cannot insert empty entries.
        self._alias_to_tvm_app_id = alias_mappings
        self._tvm_app_id_to_aliases = dict(tvm_app_id_to_aliases)

    def _get_tickets_for_dsts(self, dsts):
        """Rebuild the service context from fresh public keys, then fetch tickets for ``dsts`` (TVM app ids)."""
        self._service_context = self._get_service_context()
        tickets_resp = self._get_tickets_resp(dsts)
        alias_to_ticket = self._extract_tickets_from_response(tickets_resp)
        return alias_to_ticket

    def _get_tickets_resp(self, dsts):
        """POST a signed ticket request to TVM API and return the validated JSON response (app id -> result)."""
        request_data = self._get_ticket_request_data(dsts)
        raw_resp = self._api_request("POST", "/ticket/", data=request_data, error_from_contents=True, timeout=10)
        schema = Dict(
            key_scheme=String(regex=r"\d+"),
            value_scheme=DictScheme({"ticket": String(optional=True), "error": String(optional=True)}),
        )
        tickets_resp = get_json_response(raw_resp, scheme=schema)
        return tickets_resp

    def _get_ticket_request_data(self, dsts):
        """Build the form data for TVM ``/ticket/``, signed with the current service context."""
        now = timestamp()
        dst_str = ",".join(str(dst) for dst in set(dsts))
        signature = self._service_context.sign(now, dst_str)
        request_data = {
            "grant_type": "client_credentials",
            "src": self._walle_tvm_app_id,
            "dst": dst_str,
            "ts": now,
            "sign": signature,
        }
        return request_data

    def _extract_tickets_from_response(self, resp):
        """Split a TVM ``/ticket/`` response into per-alias tickets and per-alias errors.

        Every mapped destination ends up with either a ticket or an error: an explicit TVM error, an entry that has
        neither, or a destination missing from the response all become errors for that destination only, instead
        of failing the whole refresh. Errors are stored in ``self._alias_to_api_error``.

        :arg resp: validated response, a dict of TVM app id (decimal string) -> {"ticket": ..., "error": ...}.
        :returns: dict alias -> ticket.
        """
        tvm_app_id_to_contents = {int(tvm_app_id_str): contents for tvm_app_id_str, contents in resp.items()}

        alias_to_ticket = {}
        alias_to_api_error = {}
        # Iterate over the destinations we asked for; ids we did not request are ignored.
        for tvm_app_id, aliases in self._tvm_app_id_to_aliases.items():
            contents = tvm_app_id_to_contents.get(tvm_app_id, {})
            error = contents.get("error")
            ticket = contents.get("ticket")
            if not error and ticket is None:
                error = "TVM API response contains no ticket for this destination"

            if error:
                self._handle_service_ticket_update_error(tvm_app_id, error)
                alias_to_api_error.update(dict.fromkeys(aliases, error))
            else:
                alias_to_ticket.update(dict.fromkeys(aliases, ticket))

        # No I/O happens between this assignment and the caller storing the returned tickets, so readers in other
        # greenlets see errors and tickets from the same refresh.
        self._alias_to_api_error = alias_to_api_error
        return alias_to_ticket

    def _get_service_context(self):
        """Create a ticket_parser2 ``ServiceContext`` from our app id, secret and freshly fetched public keys."""
        tvm_keys = self._get_pub_keys()
        return ServiceContext(self._walle_tvm_app_id, config.get_value("tvm.secret"), tvm_keys)

    def _get_pub_keys(self):
        """Fetch TVM public keys (as text) for the installed ticket_parser2 version."""
        pub_keys_path = "/keys?lib_version={version}".format(version=tp2.__version__)
        return self._api_request("GET", pub_keys_path, as_text=True, timeout=10)

    def _handle_api_error(self, exception):
        """Log a failed refresh (must be called from an ``except`` block) and report it to juggler."""
        log.exception("Got exception during TVM tickets update:")
        self._notify_juggler(exception)

    def _handle_successful_refresh(self):
        """Record refresh age, mark the manager ready and report the refresh outcome to juggler."""
        stats_manager.set_age_timestamp("tvm-tickets")
        # A refresh requested while this one was running (new alias, stop()) may not be reflected in its result:
        # keep waiters blocked until the loop runs the requested refresh immediately.
        if not self._need_update.is_set():
            self._ready.set()

        # Report per-destination errors in one final event; sending CRIT per destination and then OK would
        # immediately overwrite the CRIT.
        if self._alias_to_api_error:
            failed = ", ".join(
                "{} ({})".format(alias, error) for alias, error in sorted(self._alias_to_api_error.items())
            )
            self._notify_juggler("failed to get service tickets for aliases: {}".format(failed))
        else:
            self._notify_juggler()

    def _handle_service_ticket_update_error(self, tvm_app_id, error):
        """Log a per-destination TVM error. It is reported to juggler by ``_handle_successful_refresh``."""
        log.error(
            "Got TVM API error '%s' while getting service ticket for tvm app id %s (aliases: %s)",
            error,
            tvm_app_id,
            ", ".join(str(alias) for alias in self._tvm_app_id_to_aliases.get(tvm_app_id, [])),
        )

    def _notify_juggler(self, exception=None):
        """Send the tickets-update status to juggler: OK when ``exception`` is None, CRIT otherwise.

        Best effort: this runs inside the refresher loop, and a failure here must not kill the refresher greenlet.
        """
        if exception is None:
            message = "OK"
            status = JugglerCheckStatus.OK
        else:
            message = "Got error during tickets update: {}".format(exception)
            status = JugglerCheckStatus.CRIT

        try:
            send_event("wall-e-tvm-tickets-update", status=status, message=message)
        except Exception:
            log.exception("Failed to send TVM tickets update status to juggler:")

    def _api_request(self, method, path, data=None, **kwargs):
        """Send a request to the configured TVM API URL, converting HTTP client errors into TvmApiError."""
        url = config.get_value("tvm.api_url") + path
        try:
            return request("tvm", method, url, data=data, **kwargs)
        except HttpClientError as e:
            raise TvmApiError(str(e)) from e


def get_ticket_for_service(alias):
    """Return an outgoing TVM service ticket for ``alias`` from the application's ticket manager."""
    from walle.application import app  # imported at call time, presumably to avoid an import cycle

    manager = app.tvm_ticket_manager
    return manager.get_ticket_for_alias(alias)


def check_ticket(ticket, allowed_tvm_source_aliases):
    """Validate an incoming TVM service ticket against the allowed source aliases via the app's ticket manager."""
    from walle.application import app  # imported at call time, presumably to avoid an import cycle

    manager = app.tvm_ticket_manager
    return manager.check_service_ticket(ticket, allowed_tvm_source_aliases)
