import contextlib
import logging
import signal
import sys
import time

import gevent
import mongoengine
import pytz
import yappi as yap
from apscheduler.schedulers.gevent import GeventScheduler as Scheduler
from gevent import Greenlet
from gevent.event import Event

import sepelib.mongo.util
from sepelib.core import config, constants
from sepelib.core.exceptions import Error
from sepelib.mongo.util import register_database
from walle import constants as walle_constants
from walle.constants import DEFAULT_TIER_ID, CronType
from walle.models import Settings
from walle.stats import stats_manager
from walle.util import mongo
from walle.util.cloud_tools import get_tier
from walle.util.mongo import lock as mongo_lock
from walle.util.setup import setup_logging, make_origin_regex

# Module-level logger used by the application bootstrap code below.
log = logging.getLogger(__name__)


class Service:
    """Names of the daemons (services) the walle application can run.

    A role from the ``roles.services`` config section enables a subset of these names;
    ``ALL`` lists every known service.
    """

    API = "api"
    DMC = "dmc"
    DNS = "dns"
    FSM = "fsm"
    CRON = "cron"
    SCENARIO = "scenario"

    # Every known service name, in a fixed order.
    ALL = [
        API,
        DMC,
        DNS,
        FSM,
        CRON,
        SCENARIO,
    ]


class Application:
    """Process-wide walle application: Flask app, enabled daemons and shutdown sequence.

    ``init_role()`` selects the services to run and ``start()`` starts them, blocking until
    the stop event is set. Every component started along the way registers a stop
    handler, which ``__stop()`` executes in reverse registration order.
    """

    def __init__(self):
        # Created by init_tvm_ticket_manager() when "tvm.app_id" is configured.
        self.tvm_ticket_manager = None

        # Flask application and its blueprints (filled in by init_flask() / setup_*_blueprint()).
        self.flask = None
        self.api_blueprint = None
        self.metrics_blueprint = None
        self.cms_blueprint = None

        # yappi module; configured in setup_profiler_blueprint_and_backend().
        self.yappi = yap
        self.profile_blueprint = None

        # Cron objects keyed by cron id, filled in by __start_cron().
        self.cron = {}

        self.role = None
        # Enabled service names; None until init_role() is called.
        self.__services = None

        # Callbacks run by __stop() in reverse order of registration.
        self.__stop_handlers = []
        self.__logging_initialized = False

        # True when startup or the web server failed; selects the message logged by __stop().
        self.__crashed = False
        # Set by the signal handlers (or a crashed web server) to let __start() return.
        self.__stop_event = Event()

    def init_role(self, role=None):
        """Select the services this process runs.

        :param role: a key of the ``roles.services`` config section, or None to enable every service.
        :raises Error: if the role is already initialized, is unknown, or lists unknown services.
        """
        if self.__services is not None:
            raise Error("Application role is initialized already.")

        self.role = role
        self.__services = _get_role_services(role)
        log.info("The application has following daemons enabled: %s.", ", ".join(self.__services))

    @property
    def services(self):
        """Enabled service names.

        :raises Error: if ``init_role()`` has not been called yet.
        """
        if self.__services is None:
            raise Error("Application role is not initialized.")

        return self.__services

    @contextlib.contextmanager
    def init_blueprint(self, name, path):
        """Create a Flask blueprint mounted at ``path`` and register it when the ``with`` body ends.

        Views must be attached inside the ``with`` body. The blueprint is configured and
        registered only after the body completes; if the body raises, it is not registered.
        """
        from flask import Blueprint
        from walle.util.api import configure_api_blueprint

        blueprint = Blueprint(name, "walle", url_prefix=path)
        yield blueprint

        configure_api_blueprint(blueprint)

        self.flask.register_blueprint(blueprint)

    def setup_api_blueprint(self):
        """Create and register the API blueprint at ``/v1``.

        Installs the client-version check, plus the read-only hook when ``run.read_only`` is set.
        """
        from walle.util.api import check_client_version, read_only

        with self.init_blueprint("api", "/v1") as api_blueprint:
            self.api_blueprint = api_blueprint
            self.api_blueprint.before_request(check_client_version)
            if config.get_value("run.read_only", False):
                self.api_blueprint.before_request(read_only)

            # Import views for this blueprint. This must be done before registering the blueprint in flask.
            # noinspection PyUnresolvedReferences
            import walle.views.api  # noqa

    def setup_cms_api_blueprint(self):
        """Create and register the CMS API blueprint at ``/cms/v1``."""
        with self.init_blueprint("cms", "/cms/v1") as cms_blueprint:
            self.cms_blueprint = cms_blueprint

            # Import views for this blueprint. This must be done before registering the blueprint in flask.
            # noinspection PyUnresolvedReferences
            import walle.views.cms  # noqa

    def setup_metrics_blueprint(self):
        """Create and register the metrics blueprint at ``/metrics/v1``."""
        with self.init_blueprint("metrics", "/metrics/v1") as metrics_blueprint:
            self.metrics_blueprint = metrics_blueprint

            # Import views for this blueprint. This must be done before registering the blueprint in flask.
            # noinspection PyUnresolvedReferences
            import walle.views.metrics  # noqa

    def setup_profiler_blueprint_and_backend(self):
        """Configure yappi (greenlet context, CPU clock) and register the profiler blueprint at ``/profiler/v1``."""
        self.yappi.set_context_backend("greenlet")
        self.yappi.set_clock_type("cpu")

        with self.init_blueprint("profiler", "/profiler/v1") as profiler_blueprint:
            self.profile_blueprint = profiler_blueprint

            # Import views for this blueprint. This must be done before registering the blueprint in flask.
            # noinspection PyUnresolvedReferences
            import walle.views.profiler  # noqa

    def init_flask(self):
        """Create the Flask application; does nothing if it already exists.

        Installs the rate-limit manager and the custom request class. When
        ``web.http.allow_origin`` is configured, also enables CORS for ``/v1/*``.
        """
        if self.flask is not None:
            return

        from walle.util.api import ExtendedFlask

        self.flask = ExtendedFlask("walle", static_url_path='')
        self.flask.debug = config.get_value("run.debug")

        from walle.util.rate_limiter import LimitManager

        self.flask.limit_manager = LimitManager()

        from walle.request import WalleRequest

        self.flask.request_class = WalleRequest

        # Allow cross-site HTTP requests (see https://developer.mozilla.org/en-US/docs/Web/HTTP/Access_control_CORS)
        allow_origin = config.get_value("web.http.allow_origin", None)
        if allow_origin is not None:
            from flask_cors import CORS

            # The config value may be a single origin or a list/tuple of origins.
            if isinstance(allow_origin, (list, tuple)):
                origins = list(map(make_origin_regex, allow_origin))
            else:
                origins = [make_origin_regex(allow_origin)]
            CORS(
                self.flask,
                resources={
                    "/v1/*": {
                        "origins": origins,
                        "supports_credentials": True,
                        "max_age": constants.HOUR_SECONDS,
                    }
                },
            )

    def setup_flask(self):
        """Create Flask and register the blueprints.

        The API and CMS blueprints are registered only when the API service is enabled;
        the metrics and profiler blueprints are always registered.
        """
        self.init_flask()

        if Service.API in self.services:
            self.setup_api_blueprint()
            self.setup_cms_api_blueprint()

        self.setup_metrics_blueprint()
        self.setup_profiler_blueprint_and_backend()

    def setup_logging(self, level=None):
        """Configure logging from the application config. May be called only once.

        :param level: passed through to ``walle.util.setup.setup_logging``.
        :raises Error: on a second call.
        """
        if self.__logging_initialized:
            raise Error("Logging is already initialized.")
        self.__logging_initialized = True
        setup_logging(config, level)

    def init_database(self, check_version=True, connect=True):
        """Register the main ("mongodb") and health ("health-mongodb", alias "health") databases.

        A mongoengine disconnect is registered as a stop handler.

        :param check_version: verify that the stored schema version equals DATABASE_SCHEMA_VERSION.
        :param connect: passed to ``register_database``.
        :raises Error: if the schema version check fails.
        """
        register_database("mongodb", connect=connect)
        register_database("health-mongodb", alias="health", connect=connect)

        self.__add_stop_handler(mongoengine.connection.disconnect)

        if check_version:
            schema_version = self.settings().schema_version
            if schema_version != walle_constants.DATABASE_SCHEMA_VERSION:
                raise Error(
                    "Database schema has an invalid version: {} vs {}.",
                    schema_version,
                    walle_constants.DATABASE_SCHEMA_VERSION,
                )

    def init_tvm_ticket_manager(self):
        """Create the TVM service-ticket manager when ``tvm.app_id`` is configured.

        The manager receives the alias mappers defined below and is stopped together
        with the application. Without ``tvm.app_id``, this method does nothing.
        """
        if config.get_value("tvm.app_id", default=None) is not None:
            from walle.clients import tvm
            from walle.projects import map_cms_project_alias_to_tvm_app_id

            def idm_alias_mapper():
                return config.get_value("idm.tvm_aliases")

            def calendar_alias_mapper():
                # NOTE(review): "tvm_aliase" looks like a typo of "tvm_alias" (cf. "bot.hwr.tvm_alias"),
                # but it is a config key - rename only together with the config.
                return {config.get_value("calendar.tvm_aliase"): config.get_value("calendar.tvm_app_id")}

            def bot_hwr_alias_mapper():
                return {config.get_value("bot.hwr.tvm_alias"): config.get_value("bot.hwr.tvm_app_id")}

            alias_mappers = [
                map_cms_project_alias_to_tvm_app_id,
                idm_alias_mapper,
                calendar_alias_mapper,
                bot_hwr_alias_mapper,
            ]

            self.tvm_ticket_manager = tvm.TvmServiceTicketManager(alias_mappers)
            self.__add_stop_handler(self.tvm_ticket_manager.stop)

    @staticmethod
    def settings() -> Settings:
        """Return the global ``Settings`` document, creating it if it does not exist.

        A newly created document gets the current DATABASE_SCHEMA_VERSION. If another
        process inserts it first, the forced insert fails with NotUniqueError and the
        existing document is read back instead.
        """
        document_id = "global"

        try:
            return Settings.objects(id=document_id).get()
        except mongoengine.DoesNotExist:
            try:
                settings = Settings(id=document_id, schema_version=walle_constants.DATABASE_SCHEMA_VERSION)
                return settings.save(force_insert=True)
            except mongoengine.NotUniqueError:
                return Settings.objects(id=document_id).get()

    def start(self):
        """Start the enabled services and block until the stop event is set.

        The stop event is set by SIGINT/SIGTERM/SIGQUIT or by a web server crash.
        Stop handlers always run on exit. A startup failure is logged, printed to
        stderr and re-raised.
        """
        log.info("Starting the application...")

        self.__setup_signal_handling()

        try:
            self.__start()
        except Exception as exc:
            log.exception("Application start failed")
            print(str(exc), file=sys.stderr)
            self.__crashed = True
            raise
        finally:
            self.__stop()

    @contextlib.contextmanager
    def as_stopping_context(self):
        """Run all registered stop handlers when the ``with`` block exits, even if it raises."""
        try:
            yield
        finally:
            self.__stop()

    @contextlib.contextmanager
    def as_context(self, connect_database=True):
        """Initialize the database (without the schema check) and TVM for a ``with`` block.

        Everything is stopped on exit.

        :param connect_database: passed as ``connect`` to ``init_database()``.
        """
        with self.as_stopping_context():
            self.init_database(check_version=False, connect=connect_database)
            self.init_tvm_ticket_manager()
            yield

    def __start(self):
        """Start shared infrastructure and each enabled daemon, then wait for the stop event."""
        if config.get_value("force_ipv6", False):
            force_ipv6()

        self.init_database()
        self.init_tvm_ticket_manager()

        scheduler = Scheduler()
        self.__add_stop_handler(scheduler.shutdown)
        scheduler.start()

        _start_gevent_checker(scheduler, self.services)

        self.__start_lock_heartbeat(scheduler)
        # Load stage processing handlers. We can't load them only in FSM daemon because they are actually used in other
        # daemons: for example forcing host status initiates in-place cancellation of currently processing task which
        # requires FSM stages logic to be loaded.
        # noinspection PyUnresolvedReferences
        import walle.fsm_stages.handlers  # noqa

        # Every instance runs the API; at the very least it serves unistat for yasm.
        self.__start_api_daemon()

        if Service.CRON in self.services:
            self.__start_cron_daemon(scheduler)

        if Service.DMC in self.services:
            self.__start_dmc_daemon(scheduler)

        if Service.FSM in self.services:
            settings = self.settings()
            self.__start_fsm_daemon(settings)

        if Service.SCENARIO in self.services:
            settings = self.settings()
            self.__start_scenario_daemon(settings)

        if Service.DNS in self.services:
            self.__start_dns_daemon(scheduler)

        # We can assume that at this moment all required models have already been imported.
        sepelib.mongo.util.ensure_all_indexes()

        self.__stop_event.wait()

    def __add_stop_handler(self, handler):
        """Register a no-argument callable to be run by ``__stop()``."""
        self.__stop_handlers.append(handler)

    def __stop(self):
        """Run the registered stop handlers in reverse order, logging (not raising) their failures."""
        if self.__crashed:
            log.error("Something went wrong. Stopping the application...")
        else:
            log.info("Got a termination UNIX signal. Stopping the application...")

        for handler in reversed(self.__stop_handlers):
            log.info("Executing %s stop handler...", handler)

            try:
                handler()
            except Exception:
                log.exception("Application stop handler has crashed:")
            else:
                log.info("Stop handler %s is finished", handler)

        log.info("All stop handlers were executed!")

    def __setup_signal_handling(self):
        """Install gevent handlers so that SIGINT, SIGTERM and SIGQUIT set the stop event."""
        def terminate(*args, **kwargs):
            self.__stop_event.set()

        gevent.signal.signal(signal.SIGINT, terminate)
        gevent.signal.signal(signal.SIGTERM, terminate)
        gevent.signal.signal(signal.SIGQUIT, terminate)

    def __start_lock_heartbeat(self, scheduler: Scheduler) -> None:
        """Start the Mongo lock heartbeat on ``scheduler`` and register its stop handler."""
        mongo_lock.start_heartbeat(scheduler)
        self.__add_stop_handler(mongo_lock.stop_heartbeat)

    def __start_cron_daemon(self, scheduler):
        """Start the optional Juggler/Netmon crons (selected by the ``crons`` config) and the main cron."""
        crons = config.get_value("crons")
        if CronType.JUGGLER in crons:
            # NOTE(review): unlike the other partitioners in this class, this one is not started here;
            # presumably juggler.start() starts it - confirm.
            juggler_partitioner = mongo.MongoPartitionerService("cron-juggler")
            self.__add_stop_handler(juggler_partitioner.stop)

            from walle import juggler

            juggler.start(scheduler, juggler_partitioner, config.get_value("juggler.shards_num"))

        if CronType.NETMON in crons:
            from walle.expert.netmon import Netmon

            netmon = Netmon()
            self.__add_stop_handler(netmon.stop)
            netmon.start(scheduler)

        with self.__start_cron("main") as cron:
            from walle.setup_cron import setup_cron

            setup_cron(cron)

    def __start_dmc_daemon(self, scheduler):
        """Start screening and triage when ``expert_system.enabled`` is set.

        Each process gets its own MongoPartitionerService, named after the instance tier.
        If the tier cannot be determined, DEFAULT_TIER_ID is used.
        """
        if not config.get_value("expert_system.enabled"):
            return

        import walle.expert.screening
        import walle.expert.triage

        tier_id = DEFAULT_TIER_ID
        try:
            tier_id = get_tier()
        except Exception as e:
            # Argument order must match the placeholders: default tier first, then the error.
            log.error(
                "Couldn't get tier for dmc instance, default value will be used: %s. Unexpected error happened: %s",
                DEFAULT_TIER_ID,
                e,
            )

        screening_partitioner = mongo.MongoPartitionerService(f"hosts-dmc-tier-screening-{tier_id}")
        self.__add_stop_handler(screening_partitioner.stop)
        screening_partitioner.start()

        triage_partitioner = mongo.MongoPartitionerService(f"hosts-dmc-tier-triage-{tier_id}")
        self.__add_stop_handler(triage_partitioner.stop)
        triage_partitioner.start()

        self.__add_stop_handler(walle.expert.screening.stop)
        self.__add_stop_handler(walle.expert.triage.stop)

        walle.expert.screening.start(scheduler, screening_partitioner)
        walle.expert.triage.start(scheduler, triage_partitioner)

    def __start_dns_daemon(self, scheduler):
        """Start the DNS fixer with its own "dns-fixer" partitioner."""
        partitioner = mongo.MongoPartitionerService("dns-fixer")
        self.__add_stop_handler(partitioner.stop)
        partitioner.start()

        from walle.dns import dns_fixer

        self.__add_stop_handler(dns_fixer.stop)
        dns_fixer.start(scheduler, partitioner)

    def __start_fsm_daemon(self, settings):
        """Start the host FSM daemon with the global ``settings`` document."""
        from walle.host_fsm import control as fsm

        fsm.start(settings)
        self.__add_stop_handler(fsm.stop)

    def __start_scenario_daemon(self, settings):
        """Start the scenario daemon with the global ``settings`` document."""
        from walle.scenario import control as scenario_daemon

        scenario_daemon.start(settings)
        self.__add_stop_handler(scenario_daemon.stop)

    def __start_api_daemon(self):
        """Set up Flask and run the web server in a separate greenlet.

        Request counters, response times and status-code counters are recorded into
        ``stats_manager``, and access-log lines go to the "access_log" logger.
        """
        from flask import g, request

        self.setup_flask()
        access_log = logging.getLogger("access_log")

        class PerfCounter:
            """Flask hooks that record per-endpoint request statistics."""

            def augment(self, app):
                app.before_request(self._before_request)
                app.after_request(self._after_request)
                app.teardown_request(self._teardown_request)

            def _before_request(self):
                g._walle_started = time.time()
                for section in ("total", request.endpoint or "invalid"):
                    stats_manager.increment_counter(("api", "request_count", section))

            def _teardown_request(self, _=None):
                # The start time may be missing if _before_request did not run for this request.
                started = getattr(g, '_walle_started', None)
                if started is not None:
                    ended = time.time()
                    for section in ("total", request.endpoint or "invalid"):
                        # Clamp at zero in case the wall clock moved backwards.
                        stats_manager.add_sample(("api", "response_time", section), max(ended - started, 0))

            def _after_request(self, response):
                status_code = response.status_code
                status_class = status_code // 100

                # 404 and 429 get their own buckets; everything else is grouped by status class.
                if status_code == 404:
                    metric_name = '404'
                elif status_code == 429:
                    metric_name = '429'
                elif 1 <= status_class <= 5:
                    metric_name = '{}xx'.format(status_class)
                else:
                    metric_name = 'xxx'

                for section in ("total", request.endpoint or "invalid"):
                    stats_manager.increment_counter(("api", "status_code", metric_name, section))
                    stats_manager.increment_counter(("api", "response_count", section))

                return response

        class AccessLog:
            """File-like sink that forwards web-server log lines to the "access_log" logger."""

            def write(self, message):
                access_log.info(message.rstrip())

        def start_server():
            try:
                server.run()
            except gevent.GreenletExit:
                # Raised by server_greenlet.kill() in stop_server(): a requested shutdown,
                # not a crash. Let gevent treat it as a normal greenlet exit.
                raise
            except BaseException:
                log.exception("The web server has crashed:")
                self.__crashed = True
                self.__stop_event.set()

        def stop_server():
            log.info("Stopping the Web server...")

            if server.wsgi.started:
                server.stop()
                server_greenlet.join()
            else:
                server_greenlet.kill()

            log.info("Web server is stopped.")

        log.info("Starting the Web server...")

        from sepelib.flask.server import WebServer

        server_greenlet = Greenlet(start_server)
        server = WebServer(config.get(), self.flask, walle_constants.version, proxyfix=True, logstream=AccessLog())
        PerfCounter().augment(server.app)

        self.__add_stop_handler(stop_server)
        server_greenlet.start()

    @contextlib.contextmanager
    def __start_cron(self, cron_id):
        """Create a Cron, yield it for job setup, then start it.

        The cron is stored in ``self.cron`` and registered as a stop handler before the
        ``with`` body runs. It is started only if the body completes without raising.
        """
        from walle.util.cron import Cron

        cron = Cron(cron_id, retry_delay=5)
        self.cron[cron_id] = cron
        self.__add_stop_handler(cron.stop)

        yield cron

        cron.start()


def _get_role_services(role):
    """Return the service names enabled for ``role``.

    :param role: a key of the ``roles.services`` config section, or None for every service.
    :returns: a copy of ``Service.ALL`` (a list) when ``role`` is None, otherwise a set of names.
    :raises Error: if the role is not configured or lists unknown service names.
    """
    if role is None:
        return Service.ALL[:]

    services = config.get_value("roles.services")

    try:
        role_services = set(services[role])
    except KeyError:
        # The KeyError adds nothing beyond this message, so don't chain it.
        raise Error("Role '{}' is unknown.", role) from None

    invalid_services = role_services - set(Service.ALL)
    if invalid_services:
        # Sort so the message is deterministic (set iteration order is not).
        raise Error("Got an invalid service names for role {}: {}.", role, ", ".join(sorted(invalid_services)))

    return role_services


def _start_gevent_checker(scheduler, services):
    """Schedule a no-op job every second so that gevent loop stalls become visible.

    If a busy loop blocks gevent for more than a second, APScheduler logs warnings that
    it failed to run this job in time.
    """
    job_name = "gevent checker for {} daemon".format(",".join(services))
    job_timezone = pytz.timezone("Europe/Moscow")

    scheduler.add_job(lambda: None, "interval", seconds=1, name=job_name, timezone=job_timezone)


def force_ipv6():
    """Patch urllib3 (as used by ``requests``) to resolve addresses over IPv6 when it reports IPv6 support.

    See WALLE-4560 and https://stackoverflow.com/a/46972341/12564527.
    """
    import requests.packages.urllib3.util.connection as urllib3_connection
    import socket

    def allowed_gai_family():
        # Checked on every call: IPv6 if urllib3 says it is available, otherwise IPv4.
        return socket.AF_INET6 if urllib3_connection.HAS_IPV6 else socket.AF_INET

    urllib3_connection.allowed_gai_family = allowed_gai_family
    log.info("Patched urllib3 to use IPv6 in 'requests'.")


# Process-wide Application instance, created when this module is imported.
app = Application()
