import json
from collections import Counter, defaultdict
from datetime import datetime, timedelta
from functools import partial
from typing import Tuple, Callable, Union, Iterable, TypeVar, Dict, List

import constance
from django.conf import settings
from django.db.models import Q, F, OuterRef, Count, IntegerField, Subquery, Func
from django.db.models.functions import Coalesce
from django.utils import timezone
from metrics_framework.models import MetricPoint

from idm.core.constants.action import ACTION
from idm.core.constants.role import ROLE_STATE
from idm.core.models import Action, Role, SystemRolePush, System, SystemMetainfo, RoleNode
from idm.core.workflow.exceptions import RoleNodeDoesNotExist
from idm.monitorings import metric
from idm.monitorings.juggler import JugglerEvent, JugglerStatus
from idm.users.models import User, GroupMembership

# Signature of a check callback: returns a ``(status, description)`` pair.
# ``JugglerCheck`` invokes it without arguments.
TCheckCallback = Callable[..., Tuple[JugglerStatus, str]]
# Description reported when a metric returned no value (``None``).
NO_DATA_MESSAGE = "No data"


# Every check created through ``JugglerCheck.register``, in registration order.
ACTIVE_CHECKS: List['JugglerCheck'] = list()


class JugglerCheck:
    """A named monitoring check that turns a callback result into a Juggler event.

    ``callback`` returns a ``(status, description)`` pair; ``tags`` are attached to every
    event produced by :meth:`get_event`.
    """

    def __init__(self, name: str, *, callback: TCheckCallback = None, tags: List[str] = None):
        self.name = name
        self.callback = callback
        self.tags = tags if tags else []

    @classmethod
    def register(
            cls,
            name_or_func: Union[str, TCheckCallback] = None,
            callback: TCheckCallback = None,
            tags: List[str] = None,
    ) -> Union[Callable[[TCheckCallback], 'JugglerCheck'], 'JugglerCheck']:
        """Create a check, append it to ``ACTIVE_CHECKS`` and return it.

        Supported call forms:

        * ``@JugglerCheck.register`` - bare decorator; the check is named after the function;
        * ``@JugglerCheck.register('name')`` / ``@JugglerCheck.register(tags=[...])`` -
          decorator factory; the name falls back to the function name when not given;
        * ``JugglerCheck.register('name', callback)`` - direct registration of ``callback``.
        """
        def create(func: TCheckCallback, check_name: str) -> 'JugglerCheck':
            new_check = cls(name=check_name, callback=func, tags=tags)
            ACTIVE_CHECKS.append(new_check)
            return new_check

        if callable(name_or_func):
            # Bare decorator: the decorated function arrives as the first argument.
            return create(name_or_func, name_or_func.__name__)

        def decorator(func: TCheckCallback) -> 'JugglerCheck':
            return create(func, name_or_func or func.__name__)

        return decorator if callback is None else decorator(callback)

    def __call__(self):
        return self.callback()

    def get_status(self) -> Tuple[JugglerStatus, str]:
        """Run the callback and return its ``(status, description)`` pair."""
        return self.callback()

    def get_event(self) -> JugglerEvent:
        """Build a Juggler event named after this check from the current callback result."""
        status, description = self.get_status()
        return JugglerEvent(service=self.name, status=status, description=description, tags=self.tags)


@JugglerCheck.register(tags=['roles'])
def active_roles_of_inactive_groups() -> Tuple[JugglerStatus, str]:
    """CRITICAL when the metric reports active roles of inactive groups.

    Zero/empty entries and keys containing ``onhold`` are ignored.
    ``personal_by_ref_amount`` only counts once it reaches
    ``settings.IDM_PERSONAL_BY_REF_AMOUNT_THRESHOLD``. WARNING when the metric has no data.
    """
    metric_data = metric.ActiveRolesOfInactiveGroupsMetric.get()
    if metric_data is None:
        return JugglerStatus.WARNING, NO_DATA_MESSAGE

    # The setting is only read for the ``personal_by_ref_amount`` key, as before.
    has_problem = any(
        value
        and 'onhold' not in key
        and not (key == 'personal_by_ref_amount' and value < settings.IDM_PERSONAL_BY_REF_AMOUNT_THRESHOLD)
        for key, value in metric_data.items()
    )
    if has_problem:
        return JugglerStatus.CRITICAL, f'Active roles of inactive groups: {json.dumps(metric_data)}'
    return JugglerStatus.OK, ''


@JugglerCheck.register(tags=['users'])
def fired_users_limit_exceeded() -> Tuple[JugglerStatus, str]:
    """CRITICAL when the metric reports not-blocked fired users; WARNING when it has no data."""
    exceeded = metric.FiredUsersLimitExceededMetric.get()
    if exceeded is None:
        return JugglerStatus.WARNING, NO_DATA_MESSAGE
    if not exceeded:
        return JugglerStatus.OK, ''
    return JugglerStatus.CRITICAL, f'Not blocked fired users count: {exceeded}'


@JugglerCheck.register
def gap_synchronization() -> Tuple[JugglerStatus, str]:
    """Report the most recent failed Gap synchronization inside the monitoring window.

    Looks for ``gap_synchronization_completed`` actions with a non-empty ``error`` that were
    added during the last ``GAP_MONITORING_TIMEDELTA`` hours. If one exists, the check is
    CRITICAL and the description shows its timestamp and error message; otherwise it is OK.
    """
    # timezone.now() matches how Django stores datetimes (aware UTC with USE_TZ, naive local
    # time without it). A naive datetime.utcnow() would be interpreted in the default
    # TIME_ZONE when USE_TZ is on, shifting the window by the UTC offset.
    window_start = timezone.now() - timedelta(hours=constance.config.GAP_MONITORING_TIMEDELTA)
    last_action = (
        Action.objects
        .filter(
            Q(action='gap_synchronization_completed')
            & ~Q(error__exact='')
            & Q(added__gte=window_start)
        )
        # Explicit ordering: ``.first()`` must return the latest failure, not the oldest.
        .order_by('-added')
        .values('error', 'added')
        .first()
    )

    if last_action:
        return JugglerStatus.CRITICAL, '[{}] => Exception message: {}'.format(
            last_action['added'].replace(microsecond=0, tzinfo=None).isoformat(),
            last_action['error'],
        )
    return JugglerStatus.OK, ''


@JugglerCheck.register(tags=['roles'])
def hanging_approving_roles() -> Tuple[JugglerStatus, str]:
    """Report roles returned by ``Role.objects.hanging_approved(threshold=None)``.

    OK when there are none. Otherwise the status is WARNING while the oldest
    ``last_action_added`` is younger than ``HANGING_ROLES_WARN_MINUTES``, and CRITICAL after that.
    """
    roles = Role.objects.hanging_approved(threshold=None)
    if not roles:
        return JugglerStatus.OK, ''

    oldest_action_at = min(role.last_action_added for role in roles)
    minutes_hanging = (timezone.now() - oldest_action_at).total_seconds() / 60
    if minutes_hanging < constance.config.HANGING_ROLES_WARN_MINUTES:
        status = JugglerStatus.WARNING
    else:
        status = JugglerStatus.CRITICAL

    since = oldest_action_at.strftime("%F %T")
    return status, f'IDM has {roles.count()} approved hanging roles since {since}'


@JugglerCheck.register(tags=['roles'])
def hanging_depriving_roles() -> Tuple[JugglerStatus, str]:
    """Report roles returned by ``Role.objects.hanging_depriving()``.

    OK when there are none, or when the number of affected systems and the largest
    per-system count are both below their ``IDM_HANGING_DEPRIVING_*`` thresholds.
    Otherwise the status is WARNING while the oldest ``last_action_added`` is younger than
    ``HANGING_ROLES_WARN_MINUTES``, and CRITICAL after that.
    """
    roles = Role.objects.hanging_depriving()
    roles_per_system = Counter(roles.values_list('system_id', flat=True))
    if not roles_per_system:
        return JugglerStatus.OK, ''

    # Short-circuits: the per-system setting is only read when the first condition holds.
    below_thresholds = (
        len(roles_per_system) < settings.IDM_HANGING_DEPRIVING_SYSTEMS_THRESHOLD
        and max(roles_per_system.values()) < settings.IDM_HANGING_DEPRIVING_BY_SYSTEM_THRESHOLD
    )
    if below_thresholds:
        return JugglerStatus.OK, ''

    oldest_action_at = min(role.last_action_added for role in roles)
    minutes_hanging = (timezone.now() - oldest_action_at).total_seconds() / 60
    if minutes_hanging < constance.config.HANGING_ROLES_WARN_MINUTES:
        status = JugglerStatus.WARNING
    else:
        status = JugglerStatus.CRITICAL

    since = oldest_action_at.strftime("%F %T")
    return status, f'IDM has {roles.count()} depriving hanging roles since {since}'


@JugglerCheck.register(tags=['roles'])
def hanging_ref_roles() -> Tuple[JugglerStatus, str]:
    """Check that granted roles really have the ref roles declared in ``ref_roles``.

    Selects granted roles whose ``ref_roles`` list is longer than the number of refs owned by
    the same subject (same user with no group, or same group with no user). Every declared
    ref of those roles is then classified:

    * ``corrupted_ref_roles_value`` - ``ref_roles`` or one of its items has an unexpected shape;
    * ``refs_on_inactive_systems`` - the ref targets a system not in ``System.objects.operational()``;
    * ``refs_on_unknown_nodes`` - no active role node of the system matches the ref's ``role_data``;
    * ``non_requested_refs`` - the node exists but the parent role has no ref on it.

    Returns CRITICAL when there are non-requested refs, WARNING when only other problems
    were found, and OK when every counter is zero.
    """
    # Per outer role: how many refs belong to the same subject as the outer role.
    returnable_refs = Role.objects.filter(
        Q(user_id=F('parent__user_id'), group__isnull=True) | Q(group_id=F('parent__group_id'), user__isnull=True),
        parent_id=OuterRef('pk'),
    ).order_by().values('parent').annotate(total=Count('*')).values('total')

    # ``refs`` is deliberately not prefetched: below it is only accessed via
    # ``refs.filter(...)``, which always runs its own query and ignores a prefetch cache.
    # NOTE(review): PostgreSQL's jsonb_array_length raises on non-array JSON, so a corrupted
    # non-list ``ref_roles`` likely makes this query fail rather than be counted below - confirm.
    roles_with_failed_refs = (
        Role.objects
        .annotate(
            declared_refs_count=Func(F('ref_roles'), function='jsonb_array_length', output_field=IntegerField()),
            returnable_refs_count=Coalesce(Subquery(returnable_refs, output_field=IntegerField()), 0),
        ).filter(
            state=ROLE_STATE.GRANTED,
            ref_roles__isnull=False,
            declared_refs_count__gt=0,
            returnable_refs_count__lt=F('declared_refs_count'),
        )
    )
    active_systems: Dict[str, System] = {system.slug: system for system in System.objects.operational()}
    problem_counters = defaultdict(int)
    # (parent role, target system) -> declared ``role_data`` dicts for that system.
    expected_ref_definitions = defaultdict(list)

    for role in roles_with_failed_refs:  # type: Role
        if role.ref_roles is None:
            continue
        if not isinstance(role.ref_roles, list):
            problem_counters['corrupted_ref_roles_value'] += 1
            continue

        for ref_role_data in role.ref_roles:
            if not isinstance(ref_role_data, dict) or not {'system', 'role_data'}.issubset(ref_role_data):
                problem_counters['corrupted_ref_roles_value'] += 1
            elif ref_role_data['system'] not in active_systems:
                problem_counters['refs_on_inactive_systems'] += 1
            elif not isinstance(ref_role_data['role_data'], dict):
                problem_counters['corrupted_ref_roles_value'] += 1
            else:
                target_system = active_systems[ref_role_data['system']]
                expected_ref_definitions[(role, target_system)].append(ref_role_data['role_data'])

    for (parent_role, system), nodes_data in expected_ref_definitions.items():
        known_nodes = RoleNode.objects.active().filter(system=system, data__in=nodes_data).count()
        existing_refs = parent_role.refs.filter(system=system, node__data__in=nodes_data).count()
        # Clamp per role: duplicate nodes/refs must not yield negative values that would
        # cancel out real problems of other roles in the summed totals.
        unknown_nodes = max(len(nodes_data) - known_nodes, 0)
        non_requested_refs = max(len(nodes_data) - unknown_nodes - existing_refs, 0)
        problem_counters['refs_on_unknown_nodes'] += unknown_nodes
        problem_counters['non_requested_refs'] += non_requested_refs

    # ``+=`` on the defaultdict creates keys even for zero increments, so only non-zero
    # counters represent problems.
    problems = {name: amount for name, amount in problem_counters.items() if amount}
    if not problems:
        return JugglerStatus.OK, ''
    status = JugglerStatus.CRITICAL if problems.get('non_requested_refs') else JugglerStatus.WARNING
    return status, f'Ref roles problems: {json.dumps(problems)}'


@JugglerCheck.register
def idm_error_count() -> Tuple[JugglerStatus, str]:
    """CRITICAL when more than ``IDM_ERROR_COUNT_THRESHOLD`` ``ACTION.IDM_ERROR`` actions were
    added during the last ``IDM_ERROR_WINDOW_HOURS`` hours."""
    window_start = timezone.now() - timedelta(hours=constance.config.IDM_ERROR_WINDOW_HOURS)
    errors_count = Action.objects.filter(action=ACTION.IDM_ERROR, added__gt=window_start).count()

    if errors_count > constance.config.IDM_ERROR_COUNT_THRESHOLD:
        message = (
            f"IDM has {errors_count} actions `{ACTION.IDM_ERROR}` "
            f"in last {constance.config.IDM_ERROR_WINDOW_HOURS} hours"
        )
        return JugglerStatus.CRITICAL, message
    return JugglerStatus.OK, ''


@JugglerCheck.register
def logins_to_subscribe() -> Tuple[JugglerStatus, str]:
    """CRITICAL when the number of passport logins waiting for subscription exceeds
    ``settings.IDM_SID67_THRESHOLD``; WARNING when the metric has no data."""
    pending = metric.PassportLoginsToSubscribeMetric.get()
    if pending is None:
        return JugglerStatus.WARNING, NO_DATA_MESSAGE
    if pending > settings.IDM_SID67_THRESHOLD:
        return JugglerStatus.CRITICAL, f'IDM has {pending} passport logins to subscribe'
    return JugglerStatus.OK, ''


@JugglerCheck.register(tags=['users'])
def not_blocked_ad_users() -> Tuple[JugglerStatus, str]:
    """CRITICAL when dismissed users, whose dismissal IDM learned about more than
    ``FIRED_BUT_ACTIVE_IN_AD_TIMEDELTA`` hours ago, are not marked ``ldap_active=False``."""
    grace_period = timezone.timedelta(hours=int(constance.config.FIRED_BUT_ACTIVE_IN_AD_TIMEDELTA))
    noticed_before = timezone.now() - grace_period

    still_active_logins = list(
        User.objects
        .dismissed()
        .filter(idm_found_out_dismissal__lt=noticed_before)
        .exclude(ldap_active=False)
        .values_list('username', flat=True)
    )
    if not still_active_logins:
        return JugglerStatus.OK, ''
    logins_text = ", ".join(still_active_logins)
    return JugglerStatus.CRITICAL, f'IDM has {len(still_active_logins)} fired, but active in AD users: {logins_text}'


@JugglerCheck.register(tags=['roles'])
def not_pushed_system_roles() -> Tuple[JugglerStatus, str]:
    """CRITICAL when ``SystemRolePush`` records exist for systems added more than 30 minutes ago."""
    system_added_before = timezone.now() - timezone.timedelta(minutes=30)
    pending_slugs = set(
        SystemRolePush.objects
        .filter(system__added__lt=system_added_before)
        .values_list('system__slug', flat=True)
    )

    if not pending_slugs:
        return JugglerStatus.OK, ''

    slugs_text = ", ".join(pending_slugs)
    return (
        JugglerStatus.CRITICAL,
        f'IDM has {len(pending_slugs)} systems with not pushed responsibles or team members: {slugs_text}',
    )


@JugglerCheck.register(tags=['roles'])
def overlengthed_ref_roles_chain() -> Tuple[JugglerStatus, str]:
    """CRITICAL when the metric lists roles with overlengthed ref chains; WARNING without data."""
    chains = metric.OverlengthedRefRoleChainMetric.get()
    if chains is None:
        return JugglerStatus.WARNING, NO_DATA_MESSAGE
    if not chains:
        return JugglerStatus.OK, ''

    listed = ', '.join(str(item) for item in chains)
    return JugglerStatus.CRITICAL, f"Roles with overlengthed refs chain: {listed}"


@JugglerCheck.register(tags=['roles'])
def review_roles_threshold_exceeded() -> Tuple[JugglerStatus, str]:
    """CRITICAL when the metric reports a non-empty value for the last review run;
    WARNING when it has no data."""
    review_stats = metric.ReviewRolesThresholdExceededMetric.get()
    if review_stats is None:
        return JugglerStatus.WARNING, NO_DATA_MESSAGE
    if not review_stats:
        return JugglerStatus.OK, ''

    serialized = json.dumps(review_stats)
    return JugglerStatus.CRITICAL, f'Number of roles to review on last run: {serialized}'


@JugglerCheck.register
def unsubscribed_logins():
    """CRITICAL when the metric returns a non-empty value; WARNING when it has no data."""
    logins = metric.UnsubscribedLoginsMetric.get()
    if logins is None:
        return JugglerStatus.WARNING, NO_DATA_MESSAGE
    if not logins:
        return JugglerStatus.OK, ''
    # The metric value itself is passed through as the event description.
    return JugglerStatus.CRITICAL, logins

_K = TypeVar('_K')
_V = TypeVar('_V')


def common_value(keys: Iterable[_K], value: _V) -> Dict[_K, _V]:
    """Map every key from ``keys`` to the same ``value``.

    Handy for spreading one setting over a group of keys inside a dict literal:
    ``{**common_value(('a', 'b', 'c'), v)}`` gives ``{'a': v, 'b': v, 'c': v}``.
    The same ``value`` object is shared by all keys; it is not copied.
    """
    return dict.fromkeys(keys, value)


@JugglerCheck.register
def unsynchronized_systems() -> Tuple[JugglerStatus, str]:
    """Report monitored systems whose trackable tasks have not finished recently.

    For every task in ``SystemMetainfo.TRACKABLE_TASKS``, a system from the task's queryset
    fails when ``metainfo.monitor_<task>`` is set and either:

    * ``metainfo.last_<task>_finish`` is older than the task's threshold (minutes), or
    * the task never finished and the system was added more than
      ``NEW_SYSTEMS_MONITORINGS_DOWNTIME`` days ago.

    Returns CRITICAL listing failures per task, or OK.
    """
    membership_tasks = (
        'activate_memberships',
        'deprive_memberships',
        'update_memberships',
        'check_memberships',
        'resolve_memberships',
    )
    task_to_queryset = {
        'sync_nodes': System.objects.get_operational().filter(nodes__is_auto_updated=True, nodes__state='active'),
        'default': System.objects.get_operational(),
        **common_value(membership_tasks, System.objects.get_operational().get_systems_with_group_sync_policy()),
    }
    regular_tasks = (
        'activate_memberships',
        'deprive_memberships',
        'update_memberships',
        'recalc_pipeline',
    )
    task_to_threshold_minutes = {
        'default': int(constance.config.UNSYNCHRONIZED_SYSTEMS_DELTA_DAYS) * 24 * 60,  # days -> minutes
        **common_value(regular_tasks, int(constance.config.UNSYNCHRONIZED_REGULAR_SYSTEMS_DELTA_MINUTES)),
    }

    start_time = timezone.now()
    # Loop-invariant: read the grace period for brand-new systems once, not per task.
    min_system_age = start_time - timezone.timedelta(days=int(constance.config.NEW_SYSTEMS_MONITORINGS_DOWNTIME))
    errors = []
    for task_type in sorted(SystemMetainfo.TRACKABLE_TASKS):
        queryset = task_to_queryset.get(task_type, task_to_queryset['default'])
        threshold = task_to_threshold_minutes.get(task_type, task_to_threshold_minutes['default'])
        since_threshold = start_time - timezone.timedelta(minutes=threshold)

        finished_before = {f'metainfo__last_{task_type}_finish__lt': since_threshold}
        never_finished = {f'metainfo__last_{task_type}_finish__isnull': True}
        should_monitor = {f'metainfo__monitor_{task_type}': True}

        failed_systems = sorted(
            queryset
            .filter(**should_monitor)
            .filter(Q(**finished_before) | (Q(**never_finished) & Q(added__lte=min_system_age)))
            .values_list('slug', flat=True)
            .distinct()
        )
        if failed_systems:
            errors.append('"{}" task has not been recently run for systems {}'.format(
                task_type,
                ','.join(failed_systems),
            ))

    if errors:
        return JugglerStatus.CRITICAL, 'System synchronisation errors: ' + '; '.join(errors)
    return JugglerStatus.OK, ''


@JugglerCheck.register(tags=['users'])
def users_without_depratment_groups() -> Tuple[JugglerStatus, str]:
    """CRITICAL when active users have no department group membership.

    A membership in a ``department`` group counts when it is active, or when it ended less than
    ``USERS_WITHOUT_DEPARTMENT_TIMEDELTA`` hours ago. ``AnonymousUser`` is ignored.

    The function name is kept as is (typo included) because ``JugglerCheck.register``
    uses it as the Juggler service name.
    """
    flap_time = timezone.now() - timezone.timedelta(hours=int(constance.config.USERS_WITHOUT_DEPARTMENT_TIMEDELTA))
    group_type_filter = Q(group__type='department')
    active_membership_state_filter = Q(state='active') | Q(date_leaved__gt=flap_time)
    exclude_query = group_type_filter & active_membership_state_filter

    users_with_department_groups = GroupMembership.objects.filter(exclude_query).values_list('user', flat=True)

    # Evaluate once: a separate count() query could disagree with the listed names and
    # costs an extra round-trip.
    usernames = list(
        User.objects
        .users().active()
        .exclude(pk__in=users_with_department_groups)
        .exclude(username='AnonymousUser')
        .values_list('username', flat=True)
    )

    if usernames:
        return JugglerStatus.CRITICAL, f'IDM has {len(usernames)} users without department group: {", ".join(usernames)}'

    return JugglerStatus.OK, ''


def metric_point_status(metric_slug: str, validator: Callable[[float], bool]) -> Tuple[JugglerStatus, str]:
    """Evaluate the latest ``MetricPoint`` of the metric ``metric_slug``.

    WARNING when there is no point or the latest one ``is_outdated()``. CRITICAL when
    ``validator`` returns true for any of the point's values (those values are listed).
    OK otherwise.
    """
    latest = (
        MetricPoint.objects
        .filter(metric__slug=metric_slug)
        .order_by('-created_at')
        .first()
    )
    if not latest:
        return JugglerStatus.WARNING, NO_DATA_MESSAGE
    if latest.is_outdated():
        return JugglerStatus.WARNING, 'Last updated at %s' % latest.created_at

    bad_values = []
    for point_value in latest.values.all():
        if validator(point_value.value):
            bad_values.append(str(point_value))

    if not bad_values:
        return JugglerStatus.OK, ''
    return JugglerStatus.CRITICAL, f'Metric {metric_slug} has errors: ' + ';'.join(bad_values)


def _register_positive_metric_check(metric_slug: str) -> JugglerCheck:
    """Register a check named ``metric_slug`` that is CRITICAL when any value of the metric's
    latest point is greater than zero (evaluated by ``metric_point_status``)."""
    return JugglerCheck.register(
        metric_slug,
        partial(metric_point_status, metric_slug, lambda value: value > 0),
    )


closure_inconsistencies_count: JugglerCheck = _register_positive_metric_check('closure_inconsistencies_count')
closure_inconsistent_paths: JugglerCheck = _register_positive_metric_check('closure_inconsistent_paths')
groups_closure_inconsistent_count: JugglerCheck = _register_positive_metric_check('groups_closure_inconsistent_count')