"""Staff client."""

import logging
from collections import defaultdict

from cachetools.func import ttl_cache

from sepelib.core import config
from sepelib.core.constants import HOUR_SECONDS, MINUTE_SECONDS
from sepelib.core.exceptions import Error, LogicalError
from sepelib.yandex.staff import StaffClient
from walle.errors import UserRecoverableError
from walle.util import misc, db_cache

# Cache sizing (see USER_GROUPS_HOT_CACHE_SIZE below): assume we have roughly 500 projects and want to be able
# to resolve owners for all hosts as fast as possible. These projects have ~250 unique combinations of owners.
# We expect the number of projects to grow, but not that fast.
from walle.util.misc import iter_chunks

# In-process ("hot") cache: size and TTL passed to @ttl_cache on the resolving functions below.
USER_GROUPS_HOT_CACHE_SIZE = 500
USER_GROUPS_HOT_CACHE_TTL = HOUR_SECONDS

# TTL of successfully computed values in the db_cache.memoized ("cold") layer.
USER_GROUPS_COLD_CACHE_TTL = 11 * HOUR_SECONDS

# How many failed lookups the db_cache.memoized layer remembers, and for how long.
USER_GROUPS_ERROR_CACHE_SIZE = 100
USER_GROUPS_ERROR_CACHE_TTL = MINUTE_SECONDS

# Number of groups resolved per Staff request in batch_get_groups_members
# (requesting too many groups at once overloads the Staff backend).
STAFF_GROUP_RESOLVING_BATCH_SIZE = 20

log = logging.getLogger(__name__)


class InvalidOwnerError(UserRecoverableError):
    """Base class for errors about an invalid project owner (a login or a group)."""


class InvalidLoginError(InvalidOwnerError):
    """Raised when a login is not found on Staff or is rejected by the validation filters.

    The offending login is kept in the ``login`` attribute.
    """

    def __init__(self, login):
        message = "Invalid login: {}."
        super().__init__(message, login)
        self.login = login


class InvalidGroupError(InvalidOwnerError):
    """Raised when a group is not found among non-deleted Staff groups.

    The offending group name (with "@") is kept in the ``group`` attribute.
    """

    def __init__(self, group):
        message = "Invalid group: {}."
        super().__init__(message, group)
        self.group = group


@ttl_cache(USER_GROUPS_HOT_CACHE_SIZE, USER_GROUPS_HOT_CACHE_TTL)
@db_cache.memoized(
    "project_owners",
    value_ttl=USER_GROUPS_COLD_CACHE_TTL,
    error_ttl=USER_GROUPS_ERROR_CACHE_TTL,
    error_cache_size=USER_GROUPS_ERROR_CACHE_SIZE,
)
def resolve_owners(owners, allow_robots=True):
    """Get project owners tuple and return list of logins (both specified directly and as group members).

    Any dismissed, external or (optionally) robot employee will be silently ignored,
    as well as unknown logins and unknown/deleted groups.

    :param owners: tuple of owners; "@"-prefixed entries are groups, the rest are logins.
        It is part of the cache key of the decorators above, so it must be hashable.
    :param allow_robots: whether robot accounts are kept.
    :returns: sorted list of unique logins.
    """
    group_owners = [owner for owner in owners if is_group(owner)]
    # A list (not a generator) so that check_logins' emptiness short-circuit works:
    # a generator is always truthy.
    login_owners = [owner for owner in owners if not is_group(owner)]

    groups = check_groups(group_owners, skip_missing=True)
    logins = set(check_logins(login_owners, allow_robots=allow_robots, skip_missing=True))

    for group_ in groups:
        logins.update(get_group_members(group_, allow_robots=allow_robots))

    return sorted(logins)


def check_login(login, allow_dismissed=False):
    """Validate a single login against Staff and return it.

    Raises the same errors as ``check_logins`` (InvalidLoginError when the login is missing).
    """
    validated = check_logins([login], allow_dismissed=allow_dismissed)
    return validated[0]


def check_owners(owners):
    """Validate a mixed collection of owners (logins and "@groups") against Staff.

    :param owners: iterable of owners (any iterable, including a generator, or None);
        entries starting with "@" are groups, the rest are logins.
    :returns: the validated logins (as returned by ``check_logins``) followed by the
        validated groups (as returned by ``check_groups``).
    :raises InvalidLoginError / InvalidGroupError: if some owner does not exist.
    :raises Error: if a Staff request fails.
    """
    # Materialize once: owners are scanned twice below, and a one-shot iterator would
    # be exhausted by the first scan, silently dropping every group.
    owners = list(owners or ())
    if not owners:
        return []

    logins = check_logins([owner for owner in owners if not is_group(owner)])
    groups = check_groups([owner for owner in owners if is_group(owner)])

    return logins + groups


def check_logins(logins, allow_dismissed=False, allow_robots=True, allow_external=False, skip_missing=False):
    """Validate logins against Staff and return those that pass the filters.

    :param logins: iterable of logins (any iterable, including a generator, or None);
        duplicates are ignored.
    :param allow_dismissed: if False, dismissed employees are treated as missing.
    :param allow_robots: if False, robot accounts are treated as missing.
    :param allow_external: if False, employees with "external" affiliation are treated
        as missing (robots are still accepted when ``allow_robots`` is True).
    :param skip_missing: if True, silently drop missing logins instead of raising.
    :returns: sorted list of valid logins.
    :raises Error: if a Staff request fails.
    :raises InvalidLoginError: naming a missing login, unless ``skip_missing`` is True.
    """
    # Materialize before the emptiness check: a generator is always truthy, so checking
    # it directly would never short-circuit and a Staff client would be built for nothing.
    # Sorting also makes chunking (and thus the reported missing login) deterministic.
    logins = sorted(set(logins or ()))
    if not logins:
        return []

    valid_logins = []
    client = _get_staff_client()

    staff_filter = {"is_deleted": "false"}
    query = None
    if not allow_dismissed:
        staff_filter["official.is_dismissed"] = "false"
    if allow_external:
        if not allow_robots:
            staff_filter["official.is_robot"] = "false"
    else:
        query = 'official.affiliation!="external"'
        if allow_robots:
            # robots are affiliated as "external", so they have to be let through explicitly
            query += ' or official.is_robot==true'

    for login_bucket in misc.iter_chunks(logins, 10):
        try:
            response = client.list_persons(
                dict(staff_filter, _query=query, login=",".join(login_bucket), _limit=len(login_bucket)),
                fields=["login"],
            )
            found_logins = {result["login"] for result in response["result"]}
        except Exception as e:
            raise Error("Failed to validate logins against Staff database: {}", e) from e

        valid_logins.extend(found_logins)
        missing_logins = set(login_bucket) - found_logins

        if missing_logins and not skip_missing:
            raise InvalidLoginError(min(missing_logins))

    return sorted(valid_logins)


def check_groups(groups, skip_missing=False):
    """Validate "@group" names against Staff and return those that exist and are not deleted.

    :param groups: iterable of "@"-prefixed group names (any iterable, including a
        generator, or None); duplicates are ignored.
    :param skip_missing: if True, silently drop unknown/deleted groups instead of raising.
    :returns: sorted list of valid group names (with "@").
    :raises ValueError: if a name does not start with "@" (raised by ``group_name``).
    :raises Error: if a Staff request fails.
    :raises InvalidGroupError: naming a missing group, unless ``skip_missing`` is True.
    """
    # Materialize before the emptiness check: a generator is always truthy.
    # Sorting makes chunking (and thus the reported missing group) deterministic.
    groups = sorted(set(groups or ()))
    if not groups:
        return []

    valid_groups = []
    client = _get_staff_client()

    for group_bucket in misc.iter_chunks(groups, 10):
        group_urls = [group_name(group_) for group_ in group_bucket]  # Staff urls have no "@"
        try:
            response = client.list_groups(
                {
                    "url": ",".join(group_urls),
                    "is_deleted": "false",
                    "_limit": len(group_urls),
                },
                fields=("url",),
            )
            found_groups = {group(result["url"]) for result in response["result"]}
        except Exception as e:
            raise Error("Failed to validate groups against Staff database: {}", e) from e

        valid_groups.extend(found_groups)
        missing_groups = set(group_bucket) - found_groups

        if missing_groups and not skip_missing:
            raise InvalidGroupError(min(missing_groups))

    return sorted(valid_groups)


def _get_group_ids(client, groups, is_deleted=False):
    """Map "@group" names to their Staff ids.

    Queries groups whose ``is_deleted`` flag equals ``is_deleted``, ten names per request.
    Groups that Staff does not return are simply absent from the result.

    :raises Error: if any request fails (or a name is not "@"-prefixed).
    """
    deleted_flag = "true" if is_deleted else "false"
    ids_by_group = {}

    try:
        for chunk in misc.iter_chunks(groups, 10):
            urls = [group_name(name) for name in chunk]  # Staff urls have no "@"
            spec = {
                "url": ",".join(urls),
                "_limit": len(urls),
                "is_deleted": deleted_flag,
            }
            for item in client.list_groups(spec, fields=("url", "id"))["result"]:
                ids_by_group[group(item["url"])] = item["id"]
    except Exception as e:
        raise Error("Failed to validate groups against Staff database: {}", e)

    return ids_by_group


def groups_to_ids(groups):
    """Map "@group" names to Staff ids, looking among live groups first.

    If live groups do not account for every requested name, the remaining names are
    looked up among deleted groups. Names unknown to Staff are absent from the result.
    """
    wanted = set(groups)
    client = _get_staff_client()

    resolved = _get_group_ids(client, wanted, is_deleted=False)
    if len(resolved) != len(wanted):
        leftovers = wanted - set(resolved.keys())
        resolved.update(_get_group_ids(client, leftovers, is_deleted=True))

    return resolved


def get_group_id(group):
    """Return the Staff id of a single "@group" (live or deleted).

    :raises Error: if Staff knows no group with this name.
    """
    ids = groups_to_ids([group])
    if group not in ids:
        raise Error("Could not find id of group {}".format(group))
    return ids[group]


def groups_by_planner_ids(planner_ids, fields=None):
    """Fetch Staff groups whose ``service.id`` is one of ``planner_ids``.

    :param planner_ids: iterable of service (planner) ids.
    :param fields: optional iterable of extra Staff group fields to fetch (set, list or
        tuple); "id", "name" and "service" are always requested.
    :returns: list of raw Staff group records, 50 ids per request.
    """
    # Accept any iterable of field names (the rest of this module passes tuples):
    # calling `.union` on the argument itself only works when it is a set.
    fields = set(fields or ()) | {"id", "name", "service"}

    planner_ids = set(planner_ids)
    if not planner_ids:
        return []

    client = _get_staff_client()

    groups = []
    for planner_id_bucket in misc.iter_chunks(planner_ids, 50):
        # PQL query to execute. See https://github.com/alonho/pql
        query = "service.id in [{}]".format(','.join(str(pid) for pid in planner_id_bucket))

        resp = client.list_groups({"_query": query, "_limit": len(planner_id_bucket)}, fields=fields)["result"]
        groups.extend(resp)

    return groups


def id_to_group(group_id, is_deleted=False):
    """Convert a Staff group id into an "@"-prefixed group name.

    :param group_id: Staff group id.
    :param is_deleted: search among deleted groups instead of live ones.
    :returns: the group's url prefixed with "@".
    :raises Error: if the Staff request fails or no group has this id.
    """
    client = _get_staff_client()
    spec = {"id": group_id, "is_deleted": "true" if is_deleted else "false"}

    try:
        resp = client.list_groups(spec, fields=("url", "id"))["result"]
    except Exception as e:
        log.exception("Failed to convert id to Staff group:")
        raise Error("Failed to convert id to Staff group: {}", e) from e

    if not resp:
        raise Error("Could not find Staff group with id {}", group_id)

    # Keep the url in a distinct local: naming it `group_name` would shadow the module helper.
    url = resp[0]["url"]
    return group(url)


def is_group(owner):
    """Tell whether an owner string denotes a group (group names carry a leading "@")."""
    prefix = "@"
    return owner.startswith(prefix)


def is_member(login, groups):
    """Return True if ``login`` belongs to at least one of ``groups``.

    No Staff lookup is made when ``groups`` is empty.
    """
    if not groups:
        return False
    return bool(set(get_user_groups(login)).intersection(groups))


def group(name):
    """Turn a Staff group url into our "@"-prefixed group notation."""
    prefix = "@"
    return prefix + name


def group_name(group):
    """Strip the leading "@" from a group name, yielding the Staff group url.

    :raises ValueError: if ``group`` is not a group name.
    """
    if is_group(group):
        return group[1:]
    raise ValueError("owner {} is not a group".format(group))


@ttl_cache(USER_GROUPS_HOT_CACHE_SIZE, USER_GROUPS_HOT_CACHE_TTL)
@db_cache.memoized(
    "user_groups",
    value_ttl=USER_GROUPS_COLD_CACHE_TTL,
    error_ttl=USER_GROUPS_ERROR_CACHE_TTL,
    error_cache_size=USER_GROUPS_ERROR_CACHE_SIZE,
)
def get_user_groups(login):
    """Return the "@"-prefixed groups of a user.

    The result lists the user's direct groups, then the user's department group,
    then every ancestor of that department.

    :raises Error: if the Staff request fails.
    """
    client = _get_staff_client()
    person_fields = _get_group_types()

    try:
        person = client.list_persons({"login": login}, person_fields, one=True)
    except Exception as e:
        raise Error("Failed to get user's groups from the Staff portal: {}", e)

    memberships = [group(entry["group"]["url"]) for entry in person["groups"]]
    department = person["department_group"]
    memberships.append(group(department["url"]))
    memberships.extend(group(ancestor["url"]) for ancestor in department["ancestors"])

    return memberships


@ttl_cache(USER_GROUPS_HOT_CACHE_SIZE, USER_GROUPS_HOT_CACHE_TTL)
@db_cache.memoized(
    "group_members",
    value_ttl=USER_GROUPS_COLD_CACHE_TTL,
    error_ttl=USER_GROUPS_ERROR_CACHE_TTL,
    error_cache_size=USER_GROUPS_ERROR_CACHE_SIZE,
)
def get_group_members(group, allow_robots=True, allow_external=False):
    """Return the logins of all members of a single "@group".

    Membership is matched on direct groups, the department and the department's
    ancestors. Dismissed employees are always excluded; externals and robots are
    filtered according to the flags. All result pages are fetched.
    """
    client = _get_staff_client()
    membership_query = _get_belongs_to_groups_query(group)
    filters, query = _external_and_robots_criteria(membership_query, allow_external, allow_robots)

    pages = _iter_pages(client.list_persons, filters, query, fields=["login"])
    return [person["login"] for page in pages for person in page]


@ttl_cache(USER_GROUPS_HOT_CACHE_SIZE, USER_GROUPS_HOT_CACHE_TTL)
@db_cache.memoized(
    "group_members_batch",
    value_ttl=USER_GROUPS_COLD_CACHE_TTL,
    error_ttl=USER_GROUPS_ERROR_CACHE_TTL,
    error_cache_size=USER_GROUPS_ERROR_CACHE_SIZE,
)
def batch_get_groups_members(groups, allow_robots=True, allow_external=False):
    """
    Resolve members of several Staff groups at once.

    Persons belonging to any group of a batch are requested together and then
    redistributed to their groups. Groups are split into batches of
    STAFF_GROUP_RESOLVING_BATCH_SIZE, because requesting too many groups at once
    overloads the Staff backend. Staff does not know our "@group" notation, so the
    "@" is stripped for the request and added back in the result.

    :param groups: sorted tuple of group names (prefixed with "@"); it must be sorted,
        otherwise equal sets of groups would produce different cache keys.
    :returns: dict[group] -> sorted list of member logins; every requested group is
        present, possibly with an empty list.
    """
    client = _get_staff_client()
    # the requested person fields do not depend on the batch
    fields = ("login",) + _get_group_types()

    members_by_group = defaultdict(set)
    wanted_urls = {group_name(group_) for group_ in groups}  # Staff urls have no "@"

    for chunk in iter_chunks(groups, STAFF_GROUP_RESOLVING_BATCH_SIZE):
        membership_query = _get_belongs_to_groups_query(chunk)
        filters, query = _external_and_robots_criteria(membership_query, allow_external, allow_robots)

        for page in _iter_pages(client.list_persons, filters, query, fields=fields):
            for person in page:
                members_by_group = _extract_groups(person, members_by_group, wanted_urls)

    # groups nobody was found in still get an (empty) entry
    for empty_group in set(groups) - set(members_by_group):
        members_by_group[empty_group] = set()

    return {group_: sorted(logins) for group_, logins in members_by_group.items()}


def _extract_groups(person, group_to_members, groups_to_match):
    """
    Add ``person["login"]`` to ``group_to_members`` for each matching group.

    A person record contains up to three group types (see ``_get_group_types``).
    Group urls found in the record that are also in ``groups_to_match`` (urls without
    "@") are prefixed with "@" and used as keys. ``group_to_members`` must map a group
    to a set (e.g. ``defaultdict(set)``); it is updated in place and also returned.
    """

    def _walk(node, keys):
        """Yield every value reached by following ``keys`` from ``node``.

        List levels are traversed element by element. If the first key is present, the
        remaining ones are expected to be present too; a missing/None value or a
        non-container intermediate level raises LogicalError.
        """
        head, tail = keys[0], keys[1:]
        value = node.get(head)
        if value is None:
            raise LogicalError()

        if not tail:
            yield value
            return

        if isinstance(value, dict):
            yield from _walk(value, tail)
        elif isinstance(value, list):
            for item in value:
                yield from _walk(item, tail)
        else:
            raise LogicalError()

    login = person["login"]
    for group_type in _get_group_types():
        keys = group_type.split(".")
        if keys[0] not in person:
            continue

        urls_in_record = set(_walk(person, keys))
        for group_url in urls_in_record.intersection(groups_to_match):
            group_to_members[group(group_url)].add(login)

    return group_to_members


def _get_belongs_to_groups_query(groups):
    """Build a parenthesized PQL condition matching persons who belong to ``groups``.

    ``groups`` is either a single "@group" string (compared with ``==``) or an
    iterable of "@group" names (compared with ``in [...]``). A person matches through
    any of the field paths returned by ``_get_group_types``.
    """
    group_types = _get_group_types()

    if isinstance(groups, str):
        url = group_name(groups)
        clauses = ['{}=="{}"'.format(group_type, url) for group_type in group_types]
    else:
        url_list = ",".join('"{}"'.format(group_name(name)) for name in groups)
        clauses = ["{} in [{}]".format(group_type, url_list) for group_type in group_types]

    return "({})".format(" or ".join(clauses))


def _get_group_types():
    return "groups.group.url", "department_group.ancestors.url", "department_group.url"


def _external_and_robots_criteria(query, allow_external, allow_robots):
    filters = {"official.is_dismissed": False}
    if allow_external:
        if not allow_robots:
            filters["official.is_robot"] = False
    else:
        # possible affiliations are "yandex", "yamoney" and "external"
        if allow_robots:
            # NB: robots have official.affiliation == external.
            query += ' and (official.is_robot==true or official.affiliation!="external")'
        else:
            query += ' and official.affiliation!="external"'
    return filters, query


def _iter_pages(client_fn, filters, query, fields):
    page, pages = 1, 1
    while page <= pages:
        response = client_fn(spec=dict(filters, _query=query, _page=page), fields=fields)
        yield response["result"]
        page, pages = page + 1, response["pages"]


def _get_staff_client():
    """Create a StaffClient from the "staff.access_token" and optional "staff.url" config values."""
    token = config.get_value("staff.access_token")
    url = config.get_value("staff.url", None)
    return StaffClient(oauth_token=token, staff_url=url)
