"""
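Flask plugin that collects per-request metrics and pushes a PROXY_API_CALL statistics signal for every
request, plus an MDS_API_CALL signal for each S3 request made while serving it.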
"""

import datetime as dt
import urllib.parse

import flask
import requests

from sandbox.common import rest as common_rest
from sandbox.common import statistics as common_statistics
from sandbox.common.types import statistics as ctst


class RequestStatistics:
    """
    Per-request metrics collector.

    One instance is created at the start of each request and kept in ``flask.g``.
    While the request is served, handlers may feed it extra details: resource metadata,
    data source/size, and time spent in API calls. S3 requests are reported immediately as
    MDS_API_CALL signals. When the request is torn down, a single PROXY_API_CALL signal
    summarizing the whole request is pushed.
    """

    # Status code reported when no response was recorded because the after_request hooks
    # did not run (e.g. an exception escaped the app). Such requests are treated as
    # server errors rather than being dropped from statistics.
    _FALLBACK_RESPONSE_CODE = 500

    def __init__(self):
        self._request = flask.request
        self._response = None  # set by after_request(); may stay None on failure paths

        self._start_time = dt.datetime.utcnow()
        self._finish_time = None
        self._api_duration_ms = 0  # accumulated time spent in API calls

        self._resource_id = 0
        self._resource_type = ""
        self._resource_owner = ""
        self._data_source = ""
        self._data_size = 0

    def after_request(self, response: flask.Response):
        """Remember the outgoing response so its status code can be reported at teardown."""
        self._response = response

    def teardown_request(self, exc):
        """
        Record the finish time and push the PROXY_API_CALL signal.

        Failures while building or pushing the signal are logged and never propagated,
        so statistics cannot break request handling.

        :param exc: exception passed by Flask's teardown machinery (unused here)
        """
        self._finish_time = dt.datetime.utcnow()
        try:
            self._send_proxy_request_signal()
        except Exception:  # noqa
            flask.g.logger.exception("Failed to send %s signal", ctst.SignalType.PROXY_API_CALL)

    def provide_resource_meta(self, resource_meta):
        """
        Attach metadata of the resource being served.

        :param resource_meta: mapping with at least "id", "type" and "owner" keys
        """
        self._resource_id = resource_meta["id"]
        self._resource_type = resource_meta["type"]
        self._resource_owner = resource_meta["owner"]

    def provide_data_source(self, data_source):
        """Attach the name of the source the data was served from (falsy values are reported as "")."""
        self._data_source = data_source

    def provide_data_size(self, data_size):
        """Attach the amount of data served for this request."""
        self._data_size = data_size

    def on_request_to_api(self, request: common_rest.Client.Request):
        """
        Account for time spent in an API request made while serving this request.

        NOTE(review): ``request.duration`` is assumed to be in milliseconds, since it is
        accumulated into the ``api_duration`` field; confirm against ``common_rest``.
        """
        self._api_duration_ms += request.duration

    def on_request_to_s3(self, response: requests.Response, size: int = None, streaming_duration: float = 0):
        """
        Push an MDS_API_CALL signal describing a single S3 request.

        :param response: response of the S3 request
        :param size: number of bytes transferred; when None, ``len(response.content)`` is used
        :param streaming_duration: extra seconds spent streaming the body, added on top of
            ``response.elapsed``
        """
        if size is None:
            size = len(response.content)
        req = response.request
        path = urllib.parse.urlparse(req.url).path
        path_parts = path.split("/", 2)
        bucket_id = path_parts[1] if len(path_parts) > 1 else ""  # path starts with "/", so bucket_id has index 1
        now = dt.datetime.utcnow()
        common_statistics.Signaler().push(dict(
            type=ctst.SignalType.MDS_API_CALL,
            date=now,
            timestamp=now,
            user=flask.g.user,
            server=flask.current_app.ctx.config.this.fqdn,

            proxy_request_id=flask.g.req_id,
            mds_request_id=response.headers.get("X-Amz-Request-Id", ""),
            method=req.method,
            path=path,
            response_code=response.status_code,

            bucket_id=bucket_id,
            resource_id=self._resource_id,
            data_size=size,

            total_duration=int((response.elapsed.total_seconds() + streaming_duration) * 1000),  # milliseconds
        ))

    def _send_proxy_request_signal(self):
        """Push a PROXY_API_CALL signal summarizing the current request."""
        req = flask.request
        # Prefer the matched URL rule pattern so calls aggregate per endpoint rather than per
        # concrete path; fall back to the raw path when routing matched no rule.
        rule_name = req.url_rule.rule if req.url_rule else req.path
        # Sort parameters so equivalent query strings are reported identically.
        query_string = urllib.parse.urlencode(sorted(req.args.items(multi=True)))
        # after_request may not have run (e.g. an exception propagated out of the app), in
        # which case no response was recorded; report a server error instead of failing on None.
        if self._response is not None:
            response_code = self._response.status_code
        else:
            response_code = self._FALLBACK_RESPONSE_CODE

        common_statistics.Signaler().push(dict(
            type=ctst.SignalType.PROXY_API_CALL,
            date=self._finish_time,
            timestamp=self._finish_time,
            server=flask.current_app.ctx.config.this.fqdn,

            remote_ip=flask.g.remote_ip,
            user=flask.g.user,

            request_id=flask.g.req_id,
            endpoint="{} {}".format(req.method, rule_name),
            path=req.path,
            query_string=query_string,
            response_code=response_code,

            resource_id=self._resource_id,
            resource_type=self._resource_type,
            resource_owner=self._resource_owner,
            data_source=self._data_source or "",
            data_size=self._data_size,

            total_duration=(self._finish_time - self._start_time).total_seconds() * 1000,  # milliseconds
            api_duration=self._api_duration_ms,
        ))


def _before_request():
    """Attach a fresh RequestStatistics collector to the current request context."""
    collector = RequestStatistics()
    flask.g.request_statistics = collector


def _after_request(response):
    """Hand the outgoing response to the request's collector, if any, and return it unchanged."""
    collector = flask.g.get("request_statistics")
    if not collector:
        return response
    collector.after_request(response)
    return response


def _teardown_request(exc):
    """Let the request's collector, if any, finalize and push its signal."""
    collector = flask.g.get("request_statistics")
    if not collector:
        return
    collector.teardown_request(exc)


def init_plugin(app: flask.Flask):
    """Register the request-statistics hooks on *app* (before, after and teardown, in that order)."""
    hooks = (
        (app.before_request, _before_request),
        (app.after_request, _after_request),
        (app.teardown_request, _teardown_request),
    )
    for register, handler in hooks:
        register(handler)
