"""Useful tools for making requests."""

import cgi
import json
import logging
import os
import random
import re
import socket
import time
from contextlib import contextmanager
from functools import partial

import requests
import six
from fastjsonschema import JsonSchemaException
from requests import RequestException, HTTPError
from requests.structures import CaseInsensitiveDict
from six import wraps
from six.moves.urllib.parse import urlencode, urlparse, urlunparse

import object_validator
from sepelib.core import constants
from walle import constants as walle_constants
from walle.statbox.loggers import requests_logger
from walle.stats import stats_manager
from walle.util.gevent_tools import gevent_idle_iter

# Module-level logger; used below to report malformed or unexpected responses.
log = logging.getLogger(__name__)


class HttpClientError(HTTPError):
    """HTTP error for a 4xx response; ``_exception_for_status`` builds it for 400-499 statuses."""

    def __init__(self, reason, response):
        template = "{} Client Error: {} for url: {}"
        message = template.format(response.status_code, reason, response.url)
        super().__init__(message, response=response)


class HttpNotFound(HttpClientError):
    """Client error for a 404 response, so callers can catch "not found" separately."""


class HttpServerError(HTTPError):
    """HTTP error for a 5xx response; ``_exception_for_status`` builds it for 500-599 statuses."""

    def __init__(self, reason, response):
        template = "{} Server Error: {} for url: {}"
        message = template.format(response.status_code, reason, response.url)
        super().__init__(message, response=response)


def request(
    service,
    method,
    url,
    params=None,
    check_status=True,
    success_codes=(requests.codes.ok,),
    allow_errors=(),
    error_from_contents=False,
    headers=None,
    timeout=None,
    as_text=False,
    errors_log_level=logging.ERROR,
    **kwargs
):
    """Send an HTTP request via ``requests.request`` and record stats for it.

    :param service: service name used as the stats/log key.
    :param method: HTTP method.
    :param url: request URL; ``params`` are passed to ``requests`` separately.
    :param params: query parameters.
    :param check_status: raise if the status is in neither ``success_codes`` nor ``allow_errors``.
    :param error_from_contents: append the (normalized) response body to the raised error text.
    :param headers: extra headers; ``User-Agent`` defaults to ``Wall-E/<version>``.
    :param timeout: request timeout; ``constants.NETWORK_TIMEOUT`` when ``None``.
    :param as_text: return the UTF-8 decoded, stripped body instead of the response object.
    :param errors_log_level: logging level used when the body cannot be decoded.
    :param kwargs: forwarded to ``requests.request``.
    :returns: the response object, or ``str`` when ``as_text`` is set.
    :raises HTTPError: (a subclass built by ``_exception_for_status``) on an unexpected status.
    :raises RequestException: when ``as_text`` is set and the body is not valid UTF-8.
    """
    stats_url = _format_url(url, params)

    request_headers = CaseInsensitiveDict(headers)
    request_headers.setdefault("User-Agent", "Wall-E/" + walle_constants.version)

    effective_timeout = constants.NETWORK_TIMEOUT if timeout is None else timeout

    with _request_stats(service, method, stats_url, effective_timeout) as stats_context:
        response = requests.request(
            method, url, params=params, timeout=effective_timeout, headers=request_headers, **kwargs
        )
        stats_context.update(status_code=response.status_code)

    status = response.status_code
    if check_status and status not in success_codes and status not in allow_errors:
        details = strip_api_error(response.text) if error_from_contents else ""
        raise _exception_for_status(response, error=details)

    if not as_text:
        return response

    raw_body = response.content
    try:
        text = six.ensure_text(raw_body, "utf-8")
    except ValueError:
        log.log(errors_log_level, "Got an invalid data from %s: %r.", response.url, raw_body)
        raise RequestException("Got an invalid data.", response=response)

    return text.strip()


def json_request(
    service,
    method,
    url,
    scheme=None,
    check_status=True,
    success_codes=(requests.codes.ok,),
    allow_errors=(),
    error_scheme=None,
    headers=None,
    data=None,
    form_data=None,
    check_content_type=True,
    with_response=False,
    timeout=None,
    validator=None,
    error_validator=None,
    **kwargs
):
    """Make a JSON API request and return the parsed (and optionally validated) body.

    ``data`` is serialized with ``json.dumps`` and sent with ``Content-Type: application/json``
    unless the caller already set a Content-Type. Otherwise ``form_data`` (if any) is sent as is.
    ``data`` takes precedence when both are given.

    Status checking and validation are delegated to ``get_json_response``:
    ``scheme``/``validator`` apply to ``success_codes``, and ``error_scheme``/``error_validator``
    apply to ``allow_errors``. Remaining keyword arguments are forwarded to ``request``.

    :returns: the parsed body, or a ``(response, body)`` tuple when ``with_response`` is true.
    """
    if data is not None:
        headers = CaseInsensitiveDict(headers)
        headers.setdefault("Content-Type", "application/json")
        body = json.dumps(data)
    else:
        body = form_data

    response = request(
        service,
        method,
        url,
        headers=headers,
        data=body,
        timeout=timeout,
        check_status=check_status,
        success_codes=success_codes,
        allow_errors=allow_errors,
        **kwargs
    )

    parsed = get_json_response(
        response,
        scheme=scheme,
        check_status=check_status,
        success_codes=success_codes,
        allow_errors=allow_errors,
        error_scheme=error_scheme,
        check_content_type=check_content_type,
        validator=validator,
        error_validator=error_validator,
    )

    if with_response:
        return response, parsed
    return parsed


def get_json_response(
    response,
    scheme=None,
    check_status=True,
    success_codes=(requests.codes.ok,),
    allow_errors=(),
    error_scheme=None,
    check_content_type=True,
    errors_log_level=logging.ERROR,
    validator=None,
    error_validator=None,
):
    """Parse and validate the JSON body of an already received response.

    :param response: a ``requests`` response object.
    :param scheme: ``object_validator`` scheme applied to the parsed body.
    :param check_status: when true, a status in neither ``success_codes`` nor ``allow_errors``
        raises the exception built by ``_exception_for_status``. A status in ``allow_errors`` is
        validated with ``error_scheme``/``error_validator`` instead of ``scheme``/``validator``.
        When false, ``scheme``/``validator`` are used regardless of the status.
    :param check_content_type: require an ``application/json`` or ``text/json`` Content-Type
        (compared case-insensitively). Its ``charset`` parameter, if present, selects the
        decoding; otherwise UTF-8 is used.
    :param errors_log_level: logging level used when the body is rejected.
    :param validator: callable applied to the parsed body that raises ``JsonSchemaException``
        on invalid data (presumably a compiled fastjsonschema validator). It is mutually
        exclusive with ``scheme``.
    :param error_validator: same as ``validator`` for ``allow_errors`` responses. It is mutually
        exclusive with ``error_scheme``.
    :returns: the parsed body (as returned by ``object_validator.validate`` when a scheme is used).
    :raises ValueError: if both a scheme and a validator are given for the same case.
    :raises RequestException: if the content type is wrong, the body cannot be decoded or
        parsed, or validation fails.
    """
    if (scheme is not None and validator is not None) or (error_scheme is not None and error_validator is not None):
        raise ValueError("scheme and validator is mutually exclusive parameters")

    result_validator = validator
    result_scheme = scheme

    if check_status and response.status_code not in success_codes:
        if response.status_code not in allow_errors:
            raise _exception_for_status(response)
        result_scheme = error_scheme
        result_validator = error_validator

    try:
        encoding = "utf-8"
        content_type, type_options = cgi.parse_header(response.headers.get("Content-Type", ""))

        # Media types are case-insensitive (RFC 7231, section 3.1.1.1), but cgi.parse_header()
        # returns the type exactly as the server sent it, so normalize before comparing.
        if content_type.lower() in ("application/json", "text/json"):
            encoding = type_options.get("charset", encoding)
        elif check_content_type:
            raise ValueError("Invalid content type: {}".format(content_type))

        result = json.loads(six.ensure_text(response.content, encoding))
    except Exception as e:
        # Deliberately broad. Wrong content type, unknown charset, undecodable bytes and
        # malformed JSON are all reported as a single "invalid JSON response" error.
        log.log(errors_log_level, "%s: Got an invalid JSON response", response.url)
        raise RequestException("The server returned an invalid JSON response: {}".format(e), response=response)

    if result_validator is not None:
        try:
            result_validator(result)
        except JsonSchemaException as e:
            log.log(errors_log_level, "%s: Got an invalid JSON response: %r.", response.url, result)
            raise RequestException("The server returned an invalid JSON response: {}".format(e), response=response)

    if result_scheme is not None:
        try:
            result = object_validator.validate("result", result, result_scheme)
        except object_validator.ValidationError as e:
            log.log(errors_log_level, "%s: Got an invalid JSON response: %r.", response.url, result)
            raise RequestException("The server returned an invalid JSON response: {}".format(e), response=response)

    return result


def retry(func=None, max_tries=3, interval=5, backoff=2, jitter=0.2, max_time=None, exceptions=None, skip=None):
    """Retry wrapper for the api-accessing methods.
    Take retry parameters and return a wrapper that retries given function on given exceptions.

    :param func: function to wrap. If this parameter is missed, then return value is a decorator.
    :param max_tries: max number of tries in total. Any negative value (e.g. `-1`) retries forever.
        `0` means 1 try (and no retry), the same as `1`.
    :param interval: delay in seconds before the first retry.
    :param backoff: multiplier applied to the delay after each retry. `1` makes interval constant.
    :param jitter: upper bound (seconds) of an additional uniformly distributed random delay
        added before each retry.
    :param max_time: max working time in seconds. No retries after this time is out.
    :param exceptions: exceptions to retry on. `socket.error` is always retried.
    :param skip: exceptions that are re-raised immediately, even if they also match `exceptions`.

    :type interval: int
    :type max_tries: int
    """
    # Materialize once so a decorator built from a generator/list can be reused safely.
    exceptions = tuple(exceptions or ())
    skip = tuple(skip or ())

    if func is None:
        # Used as ``@retry(...)``: bind the options and wait for the function to wrap.
        return partial(
            retry,
            max_tries=max_tries,
            interval=interval,
            backoff=backoff,
            jitter=jitter,
            max_time=max_time,
            exceptions=exceptions,
            skip=skip,
        )

    # Added only here (not before ``partial``) so socket.error is not appended twice.
    retry_on = exceptions + (socket.error,)

    @wraps(func)
    def _retry_wrapper(*args, **kwargs):
        attempts = 1
        start_time = time.time()
        current_interval = interval
        while True:
            try:
                return func(*args, **kwargs)
            except skip:  # don't retry on these exceptions
                raise
            except retry_on:
                # A negative max_tries retries forever. 0 must stop after the first try, so compare
                # with >= rather than == (attempts starts at 1 and would never equal 0).
                if max_tries is not None and 0 <= max_tries <= attempts:
                    raise
                if max_time and time.time() - start_time >= max_time:
                    raise
                attempts += 1
                # random.randint() rejects float bounds (TypeError since Python 3.12, ValueError for
                # non-integral values before), so draw the jitter directly in seconds.
                time.sleep(current_interval + random.uniform(0, jitter))
                current_interval *= backoff

    return _retry_wrapper


def iter_csv_response(response, fields_number, ignore_extra_fields):
    """Yield rows of a tab-separated streamed response, one list of strings per line.

    Lines are read through ``gevent_idle_iter`` (presumably to let other greenlets run while
    streaming), decoded as UTF-8 and split on ``"\\t"``.

    :param response: a streamed ``requests`` response.
    :param fields_number: exact number of fields every row must have.
    :param ignore_extra_fields: if true, rows with more fields are truncated instead of rejected.
    :raises RequestException: on a short row, an unexpected long row, or non-UTF-8 data.
    """
    for raw_line in gevent_idle_iter(response.iter_lines()):
        try:
            fields = six.ensure_text(raw_line, "utf-8").split("\t")

            too_short = len(fields) < fields_number
            too_long = len(fields) > fields_number
            if too_short or (too_long and not ignore_extra_fields):
                raise ValueError
            if too_long:
                fields = fields[:fields_number]
        except ValueError:
            log.error("Got an invalid data from %s: %r.", response.url, raw_line)
            raise RequestException("Got an invalid data.", response=response)

        yield fields


def strip_api_error(error, strip_html=False):
    """Normalize an API error text for embedding into an exception message.

    Whitespace runs collapse to single spaces and the ends are trimmed. Texts longer than
    1000 characters are cut and suffixed with ``"..."``. Other non-empty texts get a trailing
    period if they lack one.

    :param error: raw error text.
    :param strip_html: also drop ``<br>``/``<br/>`` tags (case-insensitive); no other markup is touched.
    """
    if strip_html:
        error = re.sub(r"<br\s*/?>", "", error, flags=re.IGNORECASE)

    collapsed = re.sub(r"\s+", " ", error).strip()

    if len(collapsed) > 1000:
        return collapsed[:1000] + "..."
    if collapsed and collapsed[-1] != ".":
        return collapsed + "."
    return collapsed


def _exception_for_status(response, error=None):
    """Build (but do not raise) the :class:`HTTPError` matching the response status.

    404 maps to :class:`HttpNotFound`, other 4xx codes to :class:`HttpClientError`, 5xx to
    :class:`HttpServerError`, and anything else to a plain :class:`HTTPError`.

    :param response: the failed response.
    :param error: optional extra text appended to the reason as ``"<reason>. <error>"``.
    """
    reason = response.reason
    if isinstance(reason, bytes):
        # Same approach as requests.Response.raise_for_status: some servers localize their
        # reason strings, so try UTF-8 first and fall back to ISO-8859-1 (see requests PR #3538).
        try:
            reason = six.ensure_text(reason, "utf-8")
        except UnicodeDecodeError:
            reason = six.ensure_text(reason, "iso-8859-1")

    if error:
        reason = "{}. {}".format(reason, error)

    status = response.status_code
    if status == 404:
        return HttpNotFound(reason, response)
    if 400 <= status < 500:
        return HttpClientError(reason, response)
    if 500 <= status < 600:
        return HttpServerError(reason, response)
    return HTTPError("Server returned an error: {} {}.".format(status, reason), response=response)


@contextmanager
def _request_stats(service, method, full_url, timeout):
    """Measure one outgoing request and report it to stats and the requests log.

    Yields a mutable ``dict`` pre-filled with ``service``, ``method``, ``url`` and ``timeout``.
    The caller is expected to add ``status_code`` once the response is received.

    On exit:
      * if the body raised, the exception text is stored under ``error`` and the exception re-raised;
      * otherwise the response is counted via ``_count_response``;
      * in both cases the elapsed time and the "success"/"error" outcome are recorded, and the
        whole dict (plus ``request_time``) is passed to ``requests_logger().log``.
    """
    context = dict(service=service, method=method, url=full_url, timeout=timeout)
    start_time = time.time()

    try:
        yield context  # contract is: context should be updated with response's status_code
    except BaseException as e:
        # BaseException, not Exception. Interruptions that derive only from BaseException
        # (KeyboardInterrupt, or gevent's Timeout/GreenletExit when running under gevent)
        # would otherwise skip this branch and be counted as "success" in ``finally``.
        context.update(error=str(e))
        raise
    else:
        # .get(): a caller that forgot to set status_code lands in _count_response's default
        # bucket instead of raising KeyError after the request itself succeeded.
        _count_response(service, context.get("status_code"))
    finally:
        request_time = time.time() - start_time
        result = "error" if "error" in context else "success"

        _save_request_time(service, result, request_time)
        _count_request(service, result)

        context.update(request_time=request_time)
        requests_logger().log(**context)


def _count_response(service, status_code=None):
    """Increment the per-service response counter for the response's status bucket.

    Statuses 400 and 404 get their own buckets (``"400"``, ``"404"``). Other 1xx-5xx codes are
    grouped by class (``"2xx"``, ``"5xx"``, ...). Anything else, including ``None``, goes to ``"xxx"``.
    """
    if status_code is None:
        bucket = "xxx"
    elif status_code in (400, 404):
        bucket = str(status_code)
    else:
        status_class = status_code // 100
        bucket = "{}xx".format(status_class) if 1 <= status_class <= 5 else "xxx"

    stats_manager.increment_counter((service, "request_count", bucket))


def _save_request_time(service, result, request_time):
    """Record the request duration per service/outcome and in the global total."""
    sample_keys = (
        (service, "request_time", result),
        ("total", "request_time"),
    )
    for sample_key in sample_keys:
        stats_manager.add_sample(sample_key, request_time)


def _count_request(service, result):
    """Increment request counters: per service, global per outcome, and global total.

    :param result: outcome string, one of ``"success"`` or ``"error"``.
    """
    counter_keys = (
        (service, "request_count"),
        ("total", "request_count", result),
        ("total", "request_count"),
    )
    for counter_key in counter_keys:
        stats_manager.increment_counter(counter_key)


def _format_url(url, params=None):
    """Return ``url`` with ``params`` appended to its existing query string (joined with ``&``)."""
    parts = urlparse(url)

    extra_query = _request_params_to_query(params)
    if extra_query:
        merged_query = "&".join(piece for piece in (parts.query, extra_query) if piece)
        parts = parts._replace(query=merged_query)

    return urlunparse(parts)


def _request_params_to_query(params=None):
    if not params:
        return ""

    def encode_string(value):
        return value

        if isinstance(value, unicode):  # noqa
            value = value.encode("utf-8")
        elif isinstance(value, (tuple, list)):
            value = [encode_string(inner_value) for inner_value in value]

        return value

    return urlencode({encode_string(k): encode_string(v) for k, v in params.items() if v is not None}, doseq=True)


def check_certs_exist(cert, key):
    """Raise ``ValueError`` if the certificate or the key file path does not exist.

    The certificate is checked first. The message shows the path, its absolute form and the
    current working directory (all as ``repr`` of ``str``), which helps debug relative paths.
    """
    error_message = "{name} by path {path} doesn't exist (abspath={abspath}, cwd={cwd})"
    for name, path in (("Certificate", cert), ("Key", key)):
        if not os.path.exists(path):
            raise ValueError(
                error_message.format(
                    name=name,
                    path=repr(path),
                    abspath=repr(os.path.abspath(path)),
                    # os.getcwd() (str) rather than os.getcwdb(): bytes would render as b'...'
                    # and be inconsistent with the other repr(str) fields.
                    cwd=repr(os.getcwd()),
                )
            )
