import enum
import logging
import logging.handlers
import sys
import traceback
import typing

import flask
from google.protobuf import json_format

from infra.rtc_sla_tentacles.backend.lib.clickhouse import database
from infra.rtc_sla_tentacles.backend.lib.api.handlers import pipeline, manage_processes, pods_list, features
from infra.rtc_sla_tentacles.backend.lib.api.utils import parse_request_params
from infra.rtc_sla_tentacles.backend.lib.clickhouse.client import ClickhouseClient
from infra.rtc_sla_tentacles.backend.lib.api import monitoring, utils
from infra.rtc_sla_tentacles.backend.lib.config import utils as config_utils
from infra.rtc_sla_tentacles.backend.lib.missing_nodes import missing_nodes
from infra.rtc_sla_tentacles.backend.lib.mongo.client import MongoClient
from infra.rtc_sla_tentacles.backend.lib.api.proto import (
    overview_pb2, allocation_zone_details_pb2, pods_list_pb2, incidents_pb2, filters_pb2,
)
from infra.rtc_sla_tentacles.backend.lib.incidents.storage import Incident, IncidentFilter, IncidentStorage
from infra.rtc_sla_tentacles.backend.lib.harvesters_snapshots import manager as harvester_snapshots_manager
from infra.rtc_sla_tentacles.backend.lib.metrics import metrics_provider


class DeployEngine(str, enum.Enum):
    """Deploy engine label written into an allocation zone's overview tags.

    Members subclass ``str``, so each compares equal to its plain string value.
    """

    yp_lite = "yp_lite"  # chosen when the zone config has a ``yp_cluster``
    gencfg = "gencfg"  # chosen when the zone config has a ``gencfg_master_group``


class BaseApiException(Exception):
    """Base class for API errors that are rendered as a JSON error response.

    ``message`` becomes the JSON body and ``status_code`` the HTTP status
    (500 unless a subclass overrides it).
    """

    status_code = 500

    def __init__(self, message):
        # Forward the message so str(exc), exc.args, log output and pickling all
        # see it; calling Exception.__init__() with no arguments left str(exc)
        # empty and made unpickling fail (it re-calls cls(*args) with no message).
        super().__init__(message)
        self.message = message


class NoMetricsForAllocationZone(BaseApiException):
    """API error answered with HTTP 400.

    Judging by its name, intended for allocation zones that have no metrics.
    """

    status_code: int = 400


class WrongInputProto(BaseApiException):
    """API error (HTTP 400) for a request body that cannot be parsed into the expected proto."""

    status_code: int = 400


class Api:
    """Flask application serving the tentacles backend HTTP API.

    Every route lives on one blueprint that is mounted twice: at the application
    root and under the ``/tentacles/`` prefix.
    """

    def __init__(self, config_interface):
        """Create the Flask app, its metric storage and all API routes.

        :param config_interface: configuration facade; this constructor reads the API
            config (its ``processes`` key) and the environment name from it and passes
            it on to the Mongo/ClickHouse clients and handler helpers.
        """
        self.config_interface = config_interface
        self._mongo_client = MongoClient(config_interface)
        self.api_config = self.config_interface.get_api_config()
        self.env_name = self.config_interface.get_env_name()
        self.processes = self.api_config["processes"]
        self.app = flask.Flask("Api")
        self.metric_storage = monitoring.MetricStorage(config_interface, self._mongo_client, self.app.logger)

        api_blueprint = flask.Blueprint("api_blueprint", "Api")

        def _parse_proto(proto_message):
            """Parse the JSON request body into ``proto_message`` and return it.

            Accepts a message instance or a message class (instantiated here), so a
            call site that forgets the ``()`` does not fail on every request.

            :raises WrongInputProto: if the body does not parse into the message.
            """
            if isinstance(proto_message, type):
                proto_message = proto_message()
            try:
                return json_format.Parse(flask.request.data, proto_message)
            except Exception as exc:
                raise WrongInputProto(str(exc)) from exc

        def _json_response(proto_response, including_default_value_fields=True):
            """Convert a proto message into a JSON-serialisable dict.

            Field names keep their proto spelling. The default-fields flag is passed
            positionally so the call does not depend on that keyword's name, which
            differs between protobuf releases.
            """
            return json_format.MessageToDict(
                proto_response,
                including_default_value_fields,
                preserving_proto_field_name=True
            )

        def _get_unused_nodes_metrics(allocation_zone_id) -> typing.Optional[metrics_provider.MetricsTentaclesYpUnusedNodes]:
            """Return unused-YP-nodes metrics for a zone, or ``None`` if unavailable.

            ``None`` is returned for non-daemonset locations, when the
            ``yp_unused_nodes_monitoring`` snapshot is missing, or when its ``metric``
            payload cannot be deserialised; the last two cases are logged.
            """
            if not config_utils.is_daemonset_location(allocation_zone_id):
                return None
            try:
                snapshot = self.metric_storage.harvester_snapshot_manager.read_last_snapshot_data(
                    "yp_unused_nodes_monitoring", allocation_zone_id,
                )
            except harvester_snapshots_manager.SnapshotNotFoundError:
                self.app.logger.exception("Can't find unused nodes snapshot")
                return None
            try:
                return metrics_provider.MetricsTentaclesYpUnusedNodes(**snapshot['metric'])
            except Exception:
                self.app.logger.exception("Can't deserialize yp unused metrics")
                return None

        def _report_incident_error(response, operation):
            """Log the exception being handled and store its traceback in ``response.error``.

            Must be called from inside an ``except`` block. The traceback is still sent
            to the client as before, but the failure now also reaches the app log.
            """
            self.app.logger.exception("Incidents API call '%s' failed", operation)
            response.error = traceback.format_exc()

        def _parse_columns(params, default_columns):
            """Return the columns listed in the comma-separated ``columns`` request param.

            Blank entries (e.g. from a trailing comma) are dropped; when nothing is left,
            a fresh copy of ``default_columns`` is returned.
            """
            columns_by_comma = params.get("columns")
            if columns_by_comma:
                columns = [c.strip() for c in columns_by_comma.split(",") if c.strip()]
                if columns:
                    return columns
            return list(default_columns)

        def _last_timestamp(client):
            """Return ``max_ts.timestamp()`` of the first row of the last-timestamp query."""
            last_ts_response = next(iter(client.select(database.get_last_timestamp())))
            return last_ts_response.max_ts.timestamp()

        def _table_response(columns, rows, ts):
            """Render ``rows`` as a JSON ``columns``/``values`` table with ``meta.ts``."""
            return flask.jsonify({
                "result": {
                    "columns": columns,
                    "values": [[utils.from_db_to_proto(getattr(r, c)) for c in columns] for r in rows],
                },
                "meta": {
                    "ts": ts,
                }
            })

        @api_blueprint.route("/api/data/overview/", methods=["GET", "POST"])
        @self.metric_storage.stats_collector("nodes_overview")
        def get_overview():
            """Return features, location and deploy engine for every allocation zone."""
            response = overview_pb2.TDataOverviewResponse()
            allocation_zones = response.result.allocation_zones
            allocation_zones_metrics = self.metric_storage.fetch_all_zones_metrics()
            for allocation_zone_id, allocation_zone_config in self.config_interface.get_allocation_zones().items():
                allocation_zone = allocation_zones.add()
                features.fill_allocation_zone_features(
                    allocation_zones_metrics, allocation_zone, allocation_zone_id, allocation_zone_config,
                )
                tags = allocation_zone.tags
                tags.location = allocation_zone_config.get("location", "")

                # A YP cluster takes precedence; zones with neither key get no engine tag.
                if allocation_zone_config.get("yp_cluster"):
                    tags.deploy_engine = DeployEngine.yp_lite
                elif allocation_zone_config.get("gencfg_master_group"):
                    tags.deploy_engine = DeployEngine.gencfg

            return _json_response(response)

        @api_blueprint.route("/api/data/allocation_zone_details/", methods=["POST"])
        @self.metric_storage.stats_collector("allocation_zone_details")
        def get_allocation_zone_details():
            """Return features, pipeline and manage-process data for one allocation zone."""
            # Pass an instance: json_format.Parse fills an existing message object.
            request = _parse_proto(allocation_zone_details_pb2.TAllocationZoneDetailsRequest())
            allocation_zone_id = request.allocation_zone_id

            response = allocation_zone_details_pb2.TAllocationZoneDetailsResponse()
            allocation_zone_config = self.config_interface.get_allocation_zone_config(allocation_zone_id)
            allocation_zones_metrics = self.metric_storage.fetch_all_zones_metrics(allocation_zone_id)
            features.fill_allocation_zone_features(
                allocation_zones_metrics,
                response.result.features,
                allocation_zone_id,
                allocation_zone_config,
            )

            pipeline.fill_allocation_zone_pipeline(ClickhouseClient(self.config_interface), allocation_zone_id,
                                                   response.result, _get_unused_nodes_metrics(allocation_zone_id))
            allocation_zone_metrics = allocation_zones_metrics.metrics[allocation_zone_id]
            manage_processes.fill_allocation_zone_processes(
                allocation_zone_metrics.redeployed_on_time,
                allocation_zone_metrics.reallocation,
                self.config_interface,
                allocation_zone_id,
                response.result.manage_processes,
            )
            return _json_response(response)

        @api_blueprint.route("/api/data/pods_list/", methods=["POST"])
        @self.metric_storage.stats_collector("pods_list")
        def get_pods_list():
            """Return the pods list; unused-node metrics are fetched only for that filter."""
            proto_request = _parse_proto(pods_list_pb2.TPodsListRequest())
            unused_nodes_metrics = None
            if proto_request.filter == filters_pb2.EPodFilter.UNUSED_YP_NODES:
                unused_nodes_metrics = _get_unused_nodes_metrics(proto_request.allocation_zone_id)
            return _json_response(
                pods_list.get_pods_list(proto_request, self.config_interface, unused_nodes_metrics),
                including_default_value_fields=False,
            )

        @api_blueprint.route("/api/data/missing_nodes_overview/", methods=["GET", "POST"])
        @self.metric_storage.stats_collector("missing_nodes_overview")
        def get_missing_nodes_overview():
            """Return missing-node counts grouped by the requested columns."""
            client = ClickhouseClient(self.config_interface)
            params = parse_request_params(flask.request.data)
            columns = _parse_columns(params, missing_nodes.EXTERNAL_VISIBLE_WALLE_COLUMNS)
            # NOTE(review): column names come straight from the request and reach the
            # query builder and getattr(); confirm get_overview_query validates them.
            result = client.select(missing_nodes.get_overview_query(params.get("ts"), columns))
            ts = _last_timestamp(client)
            return _table_response(columns + ["count"], result, ts)

        @api_blueprint.route("/api/data/missing_nodes/", methods=["GET", "POST"])
        @self.metric_storage.stats_collector("missing_nodes")
        def get_missing_nodes():
            """Return missing nodes with the requested columns (``fqdn`` by default)."""
            client = ClickhouseClient(self.config_interface)
            params = parse_request_params(flask.request.data)
            columns = _parse_columns(params, ["fqdn"])
            # NOTE(review): user-supplied columns/params reach the query builder; confirm
            # get_missing_nodes_query validates or escapes them.
            result = client.select(missing_nodes.get_missing_nodes_query(
                params.get("ts"), columns, params,
            ))
            ts = _last_timestamp(client)
            return _table_response(columns, result, ts)

        # noinspection PyBroadException
        @api_blueprint.route("/api/incidents/create", methods=["POST"])
        @self.metric_storage.stats_collector("incidents_create")
        def incidents_create():
            """Create an incident and return its id; failures go to ``error``."""
            response = incidents_pb2.TCreateIncidentResponse()
            try:
                proto_request = _parse_proto(incidents_pb2.TCreateIncidentRequest())
                incident_storage = IncidentStorage(self.config_interface, self._mongo_client)
                incident_obj = Incident.from_protobuf(proto_request.data)
                _id = incident_storage.create_incident(incident_obj)
                response.result.id = str(_id)
            except Exception:
                _report_incident_error(response, "create")
            return _json_response(response)

        # noinspection PyBroadException
        @api_blueprint.route("/api/incidents/delete", methods=["POST"])
        @self.metric_storage.stats_collector("incidents_delete")
        def incidents_delete():
            """Delete an incident by id; failures go to ``error``."""
            response = incidents_pb2.TDeleteIncidentResponse()
            try:
                proto_request = _parse_proto(incidents_pb2.TDeleteIncidentRequest())
                incident_storage = IncidentStorage(self.config_interface, self._mongo_client)
                _id = IncidentStorage.id_from_string(proto_request.id)
                incident_storage.delete_incident(_id)
            except Exception:
                _report_incident_error(response, "delete")
            return _json_response(response)

        # noinspection PyBroadException
        @api_blueprint.route("/api/incidents/list_filters", methods=["POST"])
        @self.metric_storage.stats_collector("incidents_list_filters")
        def incidents_list_filters():
            """Return the available incident filters; failures go to ``error``."""
            response = incidents_pb2.TListIncidentFiltersResponse()
            try:
                incident_storage = IncidentStorage(self.config_interface, self._mongo_client)
                filters_list_obj = incident_storage.list_filters()
                response.result.filters_list.CopyFrom(filters_list_obj.to_protobuf())
            except Exception:
                _report_incident_error(response, "list_filters")
            return _json_response(response)

        # noinspection PyBroadException
        @api_blueprint.route("/api/incidents/read", methods=["POST"])
        @self.metric_storage.stats_collector("incidents_read")
        def incidents_read():
            """Return a filtered, paginated incident list with totals; failures go to ``error``."""
            response = incidents_pb2.TReadIncidentsResponse()
            try:
                proto_request = _parse_proto(incidents_pb2.TReadIncidentsRequest())
                incident_storage = IncidentStorage(self.config_interface, self._mongo_client)
                filter_obj = IncidentFilter.from_protobuf(proto_request.filter)
                total, total_minutes_out_of_slo, incident_objs = incident_storage.read_incidents(
                    filter_obj, proto_request.limit, proto_request.offset
                )
                response.result.total = total
                response.result.total_minutes_out_of_slo = total_minutes_out_of_slo
                for incident_obj in incident_objs:
                    response.result.incidents.add().CopyFrom(incident_obj.to_protobuf())
            except Exception:
                _report_incident_error(response, "read")
            return _json_response(response)

        # noinspection PyBroadException
        @api_blueprint.route("/api/incidents/read_data", methods=["POST"])
        @self.metric_storage.stats_collector("incidents_read_data")
        def incidents_read_data():
            """Return one incident's data by id; failures go to ``error``."""
            response = incidents_pb2.TReadIncidentDataResponse()
            try:
                proto_request = _parse_proto(incidents_pb2.TReadIncidentDataRequest())
                incident_storage = IncidentStorage(self.config_interface, self._mongo_client)
                _id = IncidentStorage.id_from_string(proto_request.id)
                incident_obj = incident_storage.read_incident_data(_id)
                response.result.data.CopyFrom(incident_obj.to_protobuf())
            except Exception:
                _report_incident_error(response, "read_data")
            return _json_response(response)

        # noinspection PyBroadException
        @api_blueprint.route("/api/incidents/update", methods=["POST"])
        @self.metric_storage.stats_collector("incidents_update")
        def incidents_update():
            """Update an existing incident; failures go to ``error``."""
            response = incidents_pb2.TUpdateIncidentResponse()
            try:
                proto_request = _parse_proto(incidents_pb2.TUpdateIncidentRequest())
                incident_storage = IncidentStorage(self.config_interface, self._mongo_client)
                incident_obj = Incident.from_protobuf(proto_request.data)
                incident_storage.update_incident(incident_obj)
            except Exception:
                _report_incident_error(response, "update")
            return _json_response(response)

        @api_blueprint.route("/stats/", methods=["GET", "POST"])
        def stats():
            """Return request stats: shared across processes unless exactly one is configured."""
            if self.processes == 1:
                return flask.jsonify(self.metric_storage.get_local_stats())
            return flask.jsonify(self.metric_storage.get_shared_stats())

        @api_blueprint.route("/api/monitorings/yasm/", methods=["GET"])
        @self.metric_storage.stats_collector("monitorings_yasm")
        def get_monitorings_yasm():
            """Return the concatenated pods-manager, missing-nodes and allocation-zone YASM metrics."""
            with self.metric_storage.shared_metrics() as shared_metrics:
                result = (
                    shared_metrics["yasm_pods_manager_metrics"]
                    + shared_metrics["yasm_missing_nodes_metrics"]
                    + shared_metrics["yasm_allocation_zone_metrics"]
                )
            return flask.jsonify(result)

        @api_blueprint.route("/ping", methods=["GET"])
        @self.metric_storage.stats_collector("ping")
        def ping():
            """
                Used to monitor API containers availability.
            """
            return "OK"

        self.app.register_blueprint(api_blueprint)
        # Flask >= 2.1 rejects registering one blueprint twice under the same name;
        # an explicit name keeps the prefixed mount working (Flask 1.x ignores it).
        self.app.register_blueprint(api_blueprint, url_prefix="/tentacles/", name="tentacles_api_blueprint")

        @self.app.errorhandler(BaseApiException)
        def handle_no_metrics_for_allocation_zone(error):
            """Render any BaseApiException as a JSON body with its own status code."""
            response = flask.jsonify(error.message)
            response.status_code = error.status_code
            return response

    def make_api_app(self, log_file, log_stdout, it_is_uwsgi):
        """Attach log handlers, set up uWSGI hooks and return the Flask app.

        :param log_file: when truthy, add a file log handler (see note below).
        :param log_stdout: when truthy, log to stdout with a DEBUG-level handler.
        :param it_is_uwsgi: when truthy, start metric collection and register
            post-fork re-initialisation of the Mongo client and Sentry sender.
        :return: the configured ``flask.Flask`` application.
        """
        fmt = logging.Formatter("%(asctime)s %(message)s")
        if log_file:
            # NOTE(review): ``log_file`` is used only as a flag; the file name is fixed
            # to "common.binary.api". Without ``maxBytes`` this handler never rotates.
            handler = logging.handlers.RotatingFileHandler("common.binary.api")
            handler.setFormatter(fmt)
            self.app.logger.addHandler(handler)
        if log_stdout:
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(fmt)
            handler.setLevel(logging.DEBUG)
            self.app.logger.addHandler(handler)
        if it_is_uwsgi:
            self.metric_storage.start_collect()
            import uwsgidecorators
            # NOTE(rocco66): uwsgi is a preforking (or “fork-abusing”) server, so we need it
            uwsgidecorators.postfork(self._mongo_client.init_client)
            uwsgidecorators.postfork(self.config_interface.sentry_sender.init)
        return self.app
