"""IDM client."""
import http.client
import json
import logging
from collections import defaultdict

import six
from requests import RequestException

from object_validator import DictScheme, List, String, Integer, Dict
from sepelib.core import config
from sepelib.core.exceptions import LogicalError
from walle.clients import cauth, staff, juggler
from walle.clients.juggler import JUGGLER_MSG_MAX_SIZE
from walle.clients.utils import request, get_json_response, retry
from walle.errors import RecoverableError
from walle.idm.common import idm_list_path_to_string, idm_string_path_to_list
from walle.util.misc import drop_none, ellipsis_string
from walle.util.validation import AnyScheme, NoneValue

log = logging.getLogger(__name__)

# Page size sent as the "limit" query parameter by _objects_pager; a page with fewer
# objects than this is treated as the last one.
PAGING_LIMIT = 500


class IDMInternalError(RecoverableError):
    """Recoverable error signalling a failed interaction with the IDM API.

    The message is prefixed with a fixed IDM marker; positional and keyword arguments
    are forwarded to RecoverableError unchanged (callers in this module pass a
    "{}"-style template together with its arguments).
    """

    def __init__(self, message, *args, **kwargs):
        prefixed_message = "Error in communication with IDM: " + message
        super().__init__(prefixed_message, *args, **kwargs)

    def __unicode__(self):
        """Text representation used by the audit log.

        ``self.message`` may be utf-8 encoded bytes (from a batch request) or text, so it
        is normalized to text first. Python 3 never calls ``__unicode__`` implicitly; it is
        only reached through explicit calls.
        """
        message_text = six.ensure_text(self.message, "utf-8")
        return "<{}>: {}".format(self.__class__, message_text)


class IDMBatchRequestError(RecoverableError):
    """A batch request to IDM failed.

    :param code_to_message: mapping of HTTP status code -> list of error messages
        collected for that code.

    The codes are kept in ``self.codes`` so that callers can filter particular kinds of
    failures.
    """

    def __init__(self, code_to_message, **kwargs):
        # Separate the per-code chunks so that several codes don't run into each other,
        # and tolerate missing (None) messages instead of failing in str.join.
        message = ", ".join(
            "{}: {}".format(code, "; ".join("" if msg is None else str(msg) for msg in messages))
            for code, messages in code_to_message.items()
        )
        self.codes = list(code_to_message.keys())
        super().__init__("IDM batch request error: {}", message, **kwargs)


def iter_role_nodes(**query_args):
    """Lazily iterate over IDM role nodes (``GET /rolenodes/``), page by page.

    ``query_args`` are passed through as query parameters.
    """
    for role_node in _objects_pager("GET", "/rolenodes/", **query_args):
        yield role_node


class RoleStateType:
    """Coarse classification of IDM role states.

    IDM reports many distinct role states; this module only needs to know which of
    four types a state belongs to.
    """

    REQUESTED = "requested"
    ACTIVE = "active"
    REVOKING = "revoking"
    INACTIVE = "inactive"

    # IDM role state -> one of the four types above; any other state is rejected by state_to_type
    _state_to_type = {
        "rerequested": ACTIVE,
        "need_request": ACTIVE,
        "granted": ACTIVE,
        "review_request": ACTIVE,
        "created": INACTIVE,
        "deprived": INACTIVE,
        "failed": INACTIVE,
        "expired": INACTIVE,
        "declined": INACTIVE,
        "requested": REQUESTED,
        "approved": REQUESTED,
        "depriving": REVOKING,
        "depriving_validation": REVOKING,
    }

    @staticmethod
    def state_to_type(role_state):
        """Map an IDM role state to REQUESTED / ACTIVE / REVOKING / INACTIVE.

        :raises LogicalError: if the state is not present in the mapping.
        """
        role_type = RoleStateType._state_to_type.get(role_state)
        if role_type is None:
            raise LogicalError()
        return role_type


class IDMRole:
    """Compact view of an IDM API role object, keeping only the fields this module uses.

    Attributes: ``id``, ``path`` (slug path split into a list), ``state`` (raw IDM state),
    ``type`` (RoleStateType bucket), ``member`` (username or group slug, or None),
    ``group_id`` (set only for group roles) and ``is_group``.
    """

    def __init__(self, role_dict):
        self.id = role_dict["id"]
        self.path = idm_string_path_to_list(role_dict["node"]["slug_path"])
        self.state = role_dict["state"]
        self.type = RoleStateType.state_to_type(self.state)
        self.member = None
        self.group_id = None
        self.is_group = False

        # A personal role takes precedence; "group" is only looked at when "user" is None.
        user_info = role_dict["user"]
        if user_info is not None:
            self.member = user_info["username"]
            return

        group_info = role_dict["group"]
        if group_info is not None:
            self.is_group = True
            self.member = group_info["slug"]
            self.group_id = group_info["id"]

    def __repr__(self):
        readable_path = idm_list_path_to_string(self.path)
        return "<IDM Role: path {}, state: {}, is_group: {}, member: {}>".format(
            readable_path, self.state, self.is_group, self.member
        )


def iter_role_dicts(path_prefix=None, system=None, **query_kwargs):
    """Lazily iterate over raw IDM role dicts (``GET /roles/``) of the given system.

    :param path_prefix: role path in the usual [role, value, ...] form; only its value
        slugs are sent as the "path" filter.
    :param system: IDM system slug; defaults to the configured one.
    """
    query = dict(query_kwargs)
    query["system"] = _get_system(system)
    if path_prefix is not None:
        # IDM's /roles/ filter takes the value path, not the usual role/value path
        query["path"] = _get_idm_value_path(path_prefix)

    for role_dict in _objects_pager("GET", "/roles/", **query):
        yield role_dict


def iter_roles(**query_kwargs):
    """Lazily iterate over IDM roles wrapped into IDMRole objects.

    Accepts the same keyword arguments as iter_role_dicts.
    """
    yield from map(IDMRole, iter_role_dicts(**query_kwargs))


def get_role(role_path, user=None, group_id=None, **query_args):
    """Return the first IDM role at ``role_path`` for ``user`` or group ``group_id``, or None.

    Exactly one of ``user`` / ``group_id`` is used as the IDM query filter (``user`` wins
    when both are given).

    :raises LogicalError: if neither ``user`` nor ``group_id`` is given.
    """
    params = dict(path_prefix=role_path, **query_args)
    if user is not None:
        params["user"] = user
    elif group_id is not None:
        params["group"] = group_id
    else:
        raise LogicalError()

    def _is_match(role):
        if user is not None and not role.is_group:
            return True
        return group_id is not None and role.is_group and role.group_id == group_id

    return next((role for role in iter_roles(**params) if _is_match(role)), None)


class BatchRequest:
    """Collects IDM API calls and sends them as a single ``POST /batch/`` request.

    Subrequests are queued with add_role_node / remove_role_node / request_role /
    revoke_role and sent by execute(). Every subrequest carries the set of HTTP status
    codes that are not treated as errors for it.

    When used as a context manager, an exception escaping the ``with`` body is logged,
    reported to Juggler as CRIT and re-raised as IDMInternalError.
    """

    def __init__(self):
        self._subrequests = []
        # last issued subrequest id; ids are "0", "1", ... (see _gen_subreq_id)
        self._cur_subrequest_id = -1
        # subrequest id -> set of HTTP status codes that are not treated as errors
        self._subrequest_id_to_expected_codes = {}

    def add_role_node(self, parent_path, node_dict, system=None):
        """Queue creation of a role node under ``parent_path``; 201 is expected."""
        request_params = _mk_add_role_node_request_params(parent_path, node_dict, system)
        self._add_subrequest(request_params, {http.client.CREATED})

    def remove_role_node(self, path, system=None):
        """Queue deletion of the role node at ``path``; 204 is expected."""
        request_params = _mk_remove_role_node_request_params(path, system)
        self._add_subrequest(request_params, {http.client.NO_CONTENT})

    def request_role(self, path, system=None, user=None, group=None, **query_kwargs):
        """Queue a role request for ``user`` or ``group``; 201 and 409 are expected."""
        request_params = _mk_request_role_request_params(path, system, user, group, **query_kwargs)
        # 409 means that role already exists
        self._add_subrequest(request_params, {http.client.CREATED, http.client.CONFLICT})

    def revoke_role(self, role_id):
        """Queue revocation of role ``role_id``; 204 is expected."""
        request_params = _mk_revoke_role_request_params(role_id)
        self._add_subrequest(request_params, {http.client.NO_CONTENT})

    def execute(self):
        """Send all queued subrequests in one batch call.

        :return: None if nothing is queued, otherwise the HTTP response of the batch call.
        :raises IDMBatchRequestError: if the batch failed at top level or some subrequest
            returned a status code not expected for it.
        """
        if not self._subrequests:
            return None

        self._fix_requests_params()
        data = json.dumps(self._subrequests)
        headers = {"Content-Type": "application/json"}

        # status is checked below: a non-200 batch whose failures are all expected
        # (e.g. 409 for an existing role) is not an error
        resp = _api_request("POST", "/batch/", data=data, headers=headers, check_status=False)

        schema = DictScheme(
            {
                "responses": List(
                    DictScheme(
                        {
                            "body": AnyScheme([Dict(), String(), NoneValue()]),
                            "headers": Dict(),
                            "id": String(optional=True),
                            "status_code": Integer(),
                        }
                    ),
                    optional=True,
                ),
                "error_code": String(optional=True),
                "message": String(optional=True),
            }
        )
        resp_json = get_json_response(resp, scheme=schema, check_status=False)

        if resp.status_code != http.client.OK:
            self._raise_errors(resp_json)

        return resp

    def is_empty(self):
        """Return True if no subrequests are queued."""
        return not self._subrequests

    def _add_subrequest(self, request_params, expected_status_codes):
        """Assign an id to ``request_params`` and queue it with its expected status codes."""
        request_params["id"] = self._gen_subreq_id()
        self._subrequests.append(request_params)
        self._subrequest_id_to_expected_codes[request_params["id"]] = expected_status_codes

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Report an exception from the ``with`` body and re-raise it as IDMInternalError.

        An IDMInternalError (e.g. raised by _api_request) propagates unchanged;
        IDMBatchRequestError and any other exception are wrapped into a new IDMInternalError.
        """
        if exc_val is None:
            return None

        if isinstance(exc_val, IDMInternalError):
            # already the right type - don't wrap it a second time
            error = exc_val
        elif isinstance(exc_val, IDMBatchRequestError):
            # batch request error, we use specific type below to be able to filter particular types of problems
            error = IDMInternalError("Could not execute batch request: {}", exc_val.message)
        else:
            error = IDMInternalError("Could not execute batch request: {}", exc_val)

        log.exception("Got exception during batch request: %s", error)

        juggler.send_event(
            "{}-idm-batch-request".format(_get_system()),
            juggler.JugglerCheckStatus.CRIT,
            message=ellipsis_string(str(error), JUGGLER_MSG_MAX_SIZE),
        )

        if error is exc_val:
            # a false return value lets the original exception propagate untouched
            return False
        raise error from exc_val

    def _gen_subreq_id(self):
        """Return the next sequential subrequest id as a string."""
        self._cur_subrequest_id += 1
        return str(self._cur_subrequest_id)

    def _fix_requests_params(self):
        """Request data param is called body in batch operations, rename them"""
        for subrequest in self._subrequests:
            if "data" in subrequest:
                subrequest["body"] = subrequest.pop("data")

    def _raise_errors(self, resp):
        """Raise IDMBatchRequestError for unexpected statuses in a parsed batch response.

        Subresponses whose status is in their expected set are not errors; expected
        codes >= 400 are only logged.
        """
        # special case: if error was not on subrequest level (for example, no subrequests were passed),
        # error is passed at top-level with different fields ("message" is optional in the schema)
        if "error_code" in resp:
            raise IDMBatchRequestError({resp["error_code"]: [resp.get("message") or ""]})

        code_to_errors = defaultdict(list)
        for subresp in resp.get("responses") or []:
            status_code = subresp["status_code"]
            expected_statuses = self._subrequest_id_to_expected_codes[subresp["id"]]
            if status_code not in expected_statuses:
                code_to_errors[status_code].append(self._get_subresp_message(subresp))
            elif status_code >= 400:
                # expected non-2xx code: log it, but don't raise. Not inside an except block,
                # so log.exception would only attach a meaningless "NoneType: None" traceback.
                log.warning("Got exception during batch request: %s", self._get_subresp_message(subresp))

        if not code_to_errors:
            return

        log.error("Got error response for batch request: %s", dict(code_to_errors))
        raise IDMBatchRequestError(code_to_errors)

    @staticmethod
    def _get_subresp_message(subresp):
        """Extract a human-readable error message from a batch subresponse.

        The body may be a dict (its "message" field is used), a plain string (used as is)
        or None. Always returns a str; "" when there is no message.
        """
        body = subresp.get("body")
        if isinstance(body, dict):
            message = body.get("message")
        elif isinstance(body, (str, bytes)):
            message = body
        else:
            message = None

        if not message:
            return ""
        if isinstance(message, bytes):
            return message.decode("utf-8")
        return str(message)


def is_system_broken():
    """Return the ``is_broken`` flag of the configured IDM system.

    The lookup is retried on IDMInternalError (interval=2, backoff=3).
    """

    @retry(interval=2, backoff=3, exceptions=(IDMInternalError,))
    def _fetch_broken_flag():
        system_info = _get_system_info(_get_system())
        return system_info["is_broken"]

    return _fetch_broken_flag()


def _mk_add_role_node_request_params(parent_path, node_dict, system):
    """Build batch subrequest params that create a role node under ``parent_path``."""
    parent_slug_path = idm_list_path_to_string(parent_path)
    system_name = _get_system(system)

    body = dict(node_dict)
    body["parent"] = parent_slug_path
    body["system"] = system_name

    return {"method": "POST", "path": "/rolenodes/", "data": json.dumps(body)}


def _mk_remove_role_node_request_params(path_list, system):
    """Build batch subrequest params that delete the role node at ``path_list``."""
    slug_path = idm_list_path_to_string(path_list)
    system_name = _get_system(system)
    return {
        "method": "DELETE",
        "path": f"/rolenodes/{system_name}{slug_path}",
    }


def _mk_request_role_request_params(path, system, user, group, **query_kwargs):
    """Build batch subrequest params that request a role for ``user`` or ``group``.

    ``user`` wins when both are given; ``group`` is a group slug resolved to its
    numeric id via staff.

    :raises LogicalError: if neither ``user`` nor ``group`` is given.
    """
    body = dict(query_kwargs)
    body["system"] = _get_system(system)
    body["path"] = _get_idm_value_path(path)

    if user is None and group is None:
        raise LogicalError()

    if user is not None:
        body["user"] = user
    else:
        body["group"] = staff.get_group_id(group)

    return {"method": "POST", "path": "/rolerequests/", "data": json.dumps(body)}


def _mk_revoke_role_request_params(role_id):
    request_params = {"method": "DELETE", "path": "/roles/{}/".format(role_id)}
    return request_params


def _get_idm_value_path(path):
    """Convert a [role, value, role, value, ...] path into IDM's value-only path string."""
    # odd positions hold the value slugs; slicing keeps the sequence type of ``path``
    value_slugs = path[1::2]
    return idm_list_path_to_string(value_slugs)


def _get_system(system=None):
    return system or config.get_value("idm.system_name")


def _objects_pager(method, url_path, **query_args):
    """Yield every object of a paginated IDM list endpoint.

    Requests pages of PAGING_LIMIT objects and stops at the first page shorter than
    that. Besides ``offset``, the id of the last seen object is sent as ``last_key``.
    """
    # IDM devs ask to pass last element key to speed up requests:
    # https://wiki.yandex-team.ru/intranet/abc/api/#pagination
    offset, last_key = 0, None

    while True:
        page_params = dict(query_args, offset=offset, limit=PAGING_LIMIT, last_key=last_key)
        response = _api_request(method, url_path, params=drop_none(page_params))
        page = get_json_response(response)["objects"]

        yield from page

        if len(page) < PAGING_LIMIT:
            return

        offset += PAGING_LIMIT
        last_key = page[-1]["id"]


def _get_system_info(system):
    """Fetch and return the parsed JSON description of IDM system ``system``."""
    response = _api_request("GET", f"/systems/{system}/")
    return get_json_response(response)


def _api_request(method, path, params=None, expected_status_code=http.client.OK, headers=None, **kwargs):
    """Perform an authenticated request to the IDM API v1.

    :param method: HTTP method.
    :param path: API path relative to ``/api/v1`` (e.g. "/roles/").
    :param params: query parameters.
    :param expected_status_code: status code passed to ``request`` as the only success code.
    :param headers: extra headers; the dict is copied, not modified.
    :param kwargs: passed through to ``walle.clients.utils.request``.
    :return: the response object returned by ``request``.
    :raises IDMInternalError: if ``request`` raises a RequestException.
    """
    url = "https://" + config.get_value("idm.api_url") + "/api/v1" + path

    # copy, so the caller's dict is neither mutated nor left holding the OAuth token
    headers = dict(headers or {})
    headers["Authorization"] = "OAuth " + config.get_value("idm.access_token")

    try:
        response = request(
            "IDM",
            method,
            url,
            params=params,
            cert=cauth.get_certificate(),
            headers=headers,
            success_codes=(expected_status_code,),
            **kwargs
        )
    except RequestException as e:
        # decoded body text instead of raw bytes, which would render as "b'...'" on Python 3
        msg = e.response.text if e.response is not None else e
        raise IDMInternalError("{}", msg, url=url, params=params) from e

    return response
