# -*- encoding: utf-8 -*-
import urllib.request
import urllib.parse
import urllib.error
from typing import Any, List, Optional, Callable
from functools import partial

from tornado.httpclient import AsyncHTTPClient, HTTPResponse, HTTPRequest, HTTPClient
from tornado.web import HTTPError
from tornado.ioloop import IOLoop

from travel.avia.api_gateway import settings
from travel.avia.api_gateway.application.stat import Stat
from travel.avia.api_gateway.application.cache.cache_root import CacheRoot
from travel.avia.api_gateway.lib.coding import decode


class Fetcher:
    """
    Generic Fetcher. Its basic purpose is to gather data from external sources, transform it in some way
    and call finish_callback with the gathered data.

    As a user of Fetcher and its descendants you should create an instance, define finish_callback and
    call `fetch` either with other fetchers (it will wait for their data) or without fetchers.
    The finish callback will be called with the resulting data.

    As a programmer of new fetchers, you may need to override the `fetch` method without parameters and use helper
    methods like `request`, `request_sync` and `raise_for_status` to call your APIs.
    Do not forget to put your data into `result` and call finish_callback with it.
    """

    # Service name used mainly for statistics and error messages.
    service: str = 'unknown'

    # Response codes treated as transient and retried by `request` / `request_sync`.
    # 599 is the code tornado uses for connection-level failures and timeouts; 502/504 are gateway errors.
    _RETRYABLE_CODES = frozenset((502, 504, 599))

    def __init__(
        self,
        finish_callback: Optional[Callable] = None,
        field=None,
        connect_timeout=None,
        request_timeout=None,
        cache_root=None,
        **kwargs,
    ):
        """
        :param finish_callback: function that accepts results from fetcher as a dict.
            Usually it's an `add_result` function of upper-level Fetcher
        :param field: field name into which result should be stored
        :param connect_timeout: connect timeout used in HTTP calls (falls back to settings.CONNECTION_TIMEOUT)
        :param request_timeout: request timeout used in HTTP calls (falls back to settings.REQUEST_TIMEOUT)
        :param cache_root: object with all the application caches
        :param kwargs: other arguments, that can be used in child class without overriding __init__ call
        """
        self.finish_callback = finish_callback
        self.field = field
        # `or` means a falsy value (None or 0) falls back to the settings default.
        self.connect_timeout = connect_timeout or settings.CONNECTION_TIMEOUT
        self.request_timeout = request_timeout or settings.REQUEST_TIMEOUT
        self.cache_root = cache_root  # type: CacheRoot
        self.params = kwargs

        # Result of the fetcher, passed to finish_callback. Starts as a dict, but add_result may replace it
        # with a non-dict value.
        self.result: dict[Any, Any] = {}
        # Field names that must be present in `result` before finish_callback is called.
        self.waiting_fields = set()
        # Requests that must be completed before finish_callback is called.
        self.waiting_requests = set()

    def fetch(self, fetchers: Optional[List['Fetcher']] = None) -> None:
        """
        Start fetching.

        Without fetchers, finish_callback is called immediately with the current result.
        With fetchers, each child's finish_callback is redirected to this fetcher's `add_result`
        and the child is started; its data is merged into `result` when it arrives.

        :param fetchers: child fetchers to run and wait for
        """
        if not fetchers:
            self.finish_callback(self.result)
            # Without this return a `None` argument would fall through to `for ... in None` (TypeError).
            return

        for fetcher in fetchers:
            fetcher.finish_callback = self.add_result
            fetcher.fetch()

    def add_result(self, data: Any, field: Optional[str] = None):
        """
        Store data delivered by a child fetcher or request, then call finish_callback if nothing else is awaited.

        :param data: value to store
        :param field: when given, data is stored as result[field]; otherwise a dict is merged into result
            and any other value replaces result entirely
        """
        if field:
            self.result[field] = data
        elif isinstance(data, dict):
            self.result.update(data)
        else:
            self.result = data
        self._check_finish()

    def add_waiting_request(self, r: str) -> None:
        """
        Mark request `r` as completed: remove it from waiting_requests and call finish_callback if
        nothing else is awaited. Unknown requests are ignored.
        """
        self.waiting_requests.discard(r)
        self._check_finish()

    def _check_finish(self):
        """
        Call finish_callback if the result is complete.

        The result is complete when it is not a dict (it was replaced wholesale in add_result), or when every
        name in waiting_fields is a key of result and no waiting requests remain.
        NOTE(review): this may call finish_callback more than once if results keep arriving after completion.
        """
        if not isinstance(self.result, dict) or (
            self.waiting_fields.issubset(self.result) and not self.waiting_requests
        ):
            self.finish_callback(self.result)

    def request(
        self,
        url: str,
        callback: Callable[[HTTPResponse], None],
        params: Optional[dict] = None,
        attempts: Optional[int] = None,
        raise_error: bool = True,
        method: str = 'GET',
        body: Optional[str] = None,
        **kwargs,
    ):
        """
        Request API asynchronously with retries. The call is retried on 502, 504 and 599 responses,
        waiting settings.REQUEST_ATTEMPTS_TIMEOUT seconds between tries.
        :param url: url to call. ex: 'http://api.avia.yandex.net/api/v1/my_handler'
        :param callback: function that is called with the final response. Return value is ignored
        :param params: params to pass into request. For POST methods it's going to be included into body
        :param attempts: total number of tries, including the first one. Default settings.REQUEST_ATTEMPTS.
            At least one try is always made
        :param raise_error: if True, HTTPError is raised (instead of calling callback) for a 404 response and for
            an error response that remains after all retries; if False, callback gets the response in any case
        :param method: HTTP method (POST, GET, PUT). Default GET
        :param body: body to pass into request. Use either params or body, not both.
        :param kwargs: other parameters to pass into tornado HTTPRequest
        """
        if attempts is None:
            attempts = settings.REQUEST_ATTEMPTS

        def retry(remaining_attempts: int) -> None:
            self.request(
                url,
                callback,
                params=params,
                attempts=remaining_attempts,
                raise_error=raise_error,
                method=method,
                body=body,
                **kwargs
            )

        def on_response(response: HTTPResponse, current_attempts: int) -> None:
            if response.code == 404:
                if raise_error:
                    raise HTTPError(response.code, decode(response.body))
                callback(response)
                return

            # `current_attempts` counts this try too, so retry only while more than one try is left
            # (previously `> 0`, which made one more call than requested).
            if response.code in self._RETRYABLE_CODES and current_attempts > 1:
                IOLoop.current().call_later(
                    settings.REQUEST_ATTEMPTS_TIMEOUT,
                    partial(retry, current_attempts - 1),
                )
                return

            Stat.hit(self.service, (response.request_time or self.request_timeout) * 1000)

            if raise_error:
                self.raise_for_status(response)
            callback(response)

        # NOTE(review): the `callback=` argument of AsyncHTTPClient.fetch is deprecated in Tornado 5.1 and
        # removed in Tornado 6; switch to the returned Future when upgrading.
        AsyncHTTPClient().fetch(
            self._create_request(
                url=url,
                method=method,
                params=params,
                body=body,
                connect_timeout=self.connect_timeout,
                request_timeout=self.request_timeout,
                **kwargs,
            ),
            raise_error=False,
            callback=partial(on_response, current_attempts=attempts),
        )

    def request_sync(
        self,
        url: str,
        params: Optional[dict] = None,
        attempts: Optional[int] = None,
        method: str = 'GET',
        body: Optional[str] = None,
        **kwargs,
    ) -> HTTPResponse:
        """
        Request API synchronously with retries. Same as `request`, but the response is returned directly
        instead of being passed to a callback. Retries happen immediately, without a pause.

        :param attempts: total number of tries, including the first one. Default settings.REQUEST_ATTEMPTS.
            At least one try is always made
        :raises HTTPError: on a 404 response, or on an error response remaining after all retries
        """
        if attempts is None:
            attempts = settings.REQUEST_ATTEMPTS
        # A non-positive value used to skip the loop entirely and silently return None.
        remaining = max(attempts, 1)

        http_client = HTTPClient()
        try:
            while True:
                response = http_client.fetch(
                    self._create_request(
                        url=url,
                        method=method,
                        params=params,
                        body=body,
                        connect_timeout=self.connect_timeout,
                        request_timeout=self.request_timeout,
                        **kwargs,
                    ),
                    raise_error=False,
                )
                if response.code == 404:
                    response_body = (
                        response.body.decode('utf-8') if isinstance(response.body, bytes) else response.body
                    )
                    raise HTTPError(response.code, response_body)

                remaining -= 1
                if response.code in self._RETRYABLE_CODES and remaining > 0:
                    continue

                Stat.hit(self.service, (response.request_time or self.request_timeout) * 1000)
                self.raise_for_status(response)
                return response
        finally:
            # HTTPClient owns a private IOLoop; release it even when an exception escapes.
            http_client.close()

    @staticmethod
    def _create_request(
        url: str, method: str = 'GET', params: Optional[dict] = None, body: Optional[str] = None, **kwargs
    ) -> HTTPRequest:
        """
        Build a tornado HTTPRequest.

        :param url: URL to call
        :param method: HTTP method (GET, POST, PUT, etc...)
        :param params: params to pass into request. For POST they are form-encoded into the body,
            for other methods they are appended to the URL query string
        :param body: body to pass into request. Pass either params or body, not both.
            Sent only for POST, PUT and PATCH; ignored for other methods
        :param kwargs: other parameters to pass to tornado HTTPRequest. `ca_certs` defaults to the system bundle
        :raises ValueError: if both params and body are given for a POST request
        """
        if method == 'POST':
            if params is not None and body is not None:
                raise ValueError('Use "params" or "body" for post request, not both')
            request_url = url
            if body is not None:
                # `is not None` keeps an explicit empty body instead of turning it into None.
                request_body = body
            elif params:
                request_body = urllib.parse.urlencode(params)
            else:
                request_body = None
        else:
            if params:
                # Extend an existing query string instead of producing a second '?'.
                separator = '&' if '?' in url else '?'
                request_url = '{}{}{}'.format(url, separator, urllib.parse.urlencode(params))
            else:
                request_url = url
            # Tornado requires a body for PUT/PATCH and rejects one for other methods by default.
            request_body = body if method in ('PUT', 'PATCH') else None

        # A default rather than a fixed argument, so callers may pass their own `ca_certs`.
        kwargs.setdefault('ca_certs', '/etc/ssl/certs/ca-certificates.crt')
        return HTTPRequest(url=request_url, method=method, body=request_body, **kwargs)

    def raise_for_status(self, response: HTTPResponse):
        """
        Raise tornado.web.HTTPError with the response's code if the response carries an error.

        :param response: response returned by a tornado HTTP client
        :raises HTTPError: when `response.error` is set
        """
        if response.error:
            raise HTTPError(response.code, reason='Error requesting {}: {}'.format(self.service, response.error))
