import logging
from decimal import ROUND_DOWN, ROUND_HALF_UP, Decimal
from typing import Iterable, List, Optional, Set

from cache_memoize import cache_memoize

from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Permission
from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.db import transaction
from django.db.models import Max, Min, Sum
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _

from lms.core.models.aggregates import HasNull
from lms.staff.services import load_staff_users
from lms.tags.models import Tag
from lms.tags.services import normalize_tag_name

from .models import (
    Cohort, Course, CourseGroup, CourseModule, CourseOccupancy, CourseStudent, CourseVisibility, StudentCourseProgress,
    StudentModuleProgress,
)
from .permissions import CourseGroupObjectPermission

User = get_user_model()
log = logging.getLogger(__name__)


def update_occupancy(course: Course) -> None:
    """
    Обновляет наполняемость курса
    """
    defaults = {}
    aggr = CourseGroup.objects.available().filter(course=course).aggregate(
        sum_max_participants=Sum('max_participants'),
        min_max_participants=Min('max_participants'),
    )
    if aggr['min_max_participants'] == 0:
        defaults['maximum'] = 0
    elif aggr['sum_max_participants'] is not None:
        defaults['maximum'] = aggr['sum_max_participants']

    defaults['current'] = course.enrolled_users.pending_or_enrolled().count()

    CourseOccupancy.objects.update_or_create(course=course, defaults=defaults)


def update_group_participants(group: CourseGroup, commit: bool = True) -> None:
    """
    Считает кол-во активных заявок
    """
    cnt_students = group.enrolled_users.pending_or_enrolled().count()
    if cnt_students != group.num_participants:
        group.num_participants = cnt_students
    if commit:
        group.save()


def calc_course_dates(course: Course) -> dict:
    """
    Вычисляет мин/макс даты для курса, проверяя значения в группах
    """
    # TODO: не учтен случай, когда интервалы времен групп не пересекаются
    # TODO: не учтен случай, когда группы есть, но ни одна не подходит
    results = {}
    aggr = CourseGroup.objects.available().filter(course=course).not_full().aggregate(
        begin_date=Min('begin_date'),
        end_date=Max('end_date'),
        enroll_begin=Min('enroll_begin'),
        enroll_end=Max('enroll_end'),
        begin_date_has_null=HasNull('begin_date'),
        end_date_has_null=HasNull('end_date'),
        enroll_begin_has_null=HasNull('enroll_begin'),
        enroll_end_has_null=HasNull('enroll_end'),
    )

    for field in ['begin_date', 'end_date', 'enroll_begin', 'enroll_end']:
        results[f'calc_{field}'] = (
            getattr(course, field)
            if aggr.get(f'{field}_has_null') or aggr[field] is None
            else aggr[field]
        )

    return results


def update_course_on_change_group(course: Course) -> None:
    """
    Обновляет даты по курсу, при изменении группы

    :param course:
    :return:
    """
    data = {}

    # Устанавливает значения вычисляемых дат в курсе
    data.update(calc_course_dates(course))

    # Обновление поля наличия групп в курсе
    data['enable_groups'] = course.groups.exists()

    need_update = False
    for field, value in data.items():
        if getattr(course, field) != value:
            need_update = True
            break

    if need_update:
        Course.objects.filter(pk=course.pk).update(**data)


def set_course_calc_dates(course: Course) -> None:
    """
    Устанавливает значения вычисляемых дат в курсе
    """
    dates = calc_course_dates(course)

    for field, value in dates.items():
        setattr(course, field, value)


def update_permissions_on_add_teams(course: Course, teams_ids: Set[int]):
    """
    Обновление прав доступа, при добавлении команд к курсу

    `teams_ids` - список команд, которые добавляются

    :param course:
    :param teams_ids:
    :return:
    """
    course_content_type = ContentType.objects.get_for_model(Course)
    permissions_to_apply = set(
        Permission.objects.filter(
            presets__teams__in=teams_ids,
            content_type=course_content_type,
        ).values_list(
            'presets__teams__id', 'id',
        )
    )
    existing_permissions = set(
        CourseGroupObjectPermission.objects.filter(
            content_object=course,
            group__courseteam__id__in=teams_ids,
        ).values_list(
            'group_id', 'permission_id',
        )
    )

    permissions_to_create = permissions_to_apply - existing_permissions

    CourseGroupObjectPermission.objects.bulk_create(
        (
            CourseGroupObjectPermission(
                content_object=course,
                group_id=permission_to_create[0],
                permission_id=permission_to_create[1],
            )
            for permission_to_create in permissions_to_create
        ),
        batch_size=settings.BULK_BATCH_SIZE_DEFAULT,
    )


def update_permissions_on_remove_teams(course: Course, teams_ids: Set[int]) -> None:
    """
    Обновление списка permissions при удалении команд из курса

    `teams_ids` - список команд, которые удаляются

    :param course:
    :param teams_ids:
    :return:
    """
    CourseGroupObjectPermission.objects.filter(
        content_object=course,
        group__courseteam__id__in=teams_ids,
    ).delete()


def update_permissions_on_clear_teams(course: Course):
    """
    Обновление прав доступа, при удалении всех команд из курса

    :param course:
    :return:
    """
    teams_ids = getattr(course, '_teams_ids', None)
    if teams_ids:
        update_permissions_on_remove_teams(course, teams_ids)


def is_course_available_for_user(course: Course, user: User) -> bool:
    """
    Проверяет доступен ли курс для пользователя

    Если включен CACHE_ENABLED=True, то результат будет закэширован

    :param course:
    :param user:
    :return:
    """

    def check() -> bool:
        visibility = getattr(course, 'visibility', None)
        return visibility.available_for(user=user) if visibility else True

    if not getattr(settings, 'CACHE_ENABLED', False):
        return check()

    key = settings.COURSE_AVAILABLE_FOR_CACHE_KEY_TEMPLATE.format(
        user_pk=user.pk,
        course_pk=course.id,
    )

    return cache.get_or_set(key, check, timeout=settings.COURSE_VISIBILITY_CACHE_TTL)


def get_unavailable_courses_ids(user: User) -> Set[int]:
    """
    Получение списка курсов, недоступных пользователю

    Если включен CACHE_ENABLED=True, то результат будет закэширован

    :param user:
    :return:
    """

    def check() -> Set[int]:
        return CourseVisibility.objects.unavailable_for(user)

    if not getattr(settings, 'CACHE_ENABLED', False):
        return check()

    key = settings.COURSES_UNAVAILABLE_FOR_CACHE_KEY_TEMPLATE.format(
        user_pk=user.pk,
    )

    return cache.get_or_set(key, check, timeout=settings.COURSE_VISIBILITY_CACHE_TTL)


def get_categories_with_available_courses_ids(user: User) -> Iterable[int]:
    """
    Возвращает список категорий, с курсами доступными для пользователя

    :param user:
    :return:
    """
    return (
        Course.categories.through.objects
        .filter(course__is_active=True, course__is_archive=False, course__show_in_catalog=True)
        .exclude(course_id__in=get_unavailable_courses_ids(user=user))
        .values_list('coursecategory_id', flat=True)
    )


def flush_cache_courses_unavailable_for(users_ids: Optional[Iterable[int]] = None) -> None:
    """
    Сбрасывает кэш недоступных курсов для списка пользователей

    :param users_ids: если список пустой, будет очищен весь кэш
    :return:
    """
    key = settings.COURSES_UNAVAILABLE_FOR_CACHE_KEY_TEMPLATE

    if not users_ids:
        if hasattr(cache, 'delete_pattern'):
            cache.delete_pattern(key.format(user_pk='*'))
        else:
            users_ids = User.objects.values_list('id', flat=True)

    if users_ids:
        cache.delete_many(
            key.format(user_pk=user_id) for user_id in users_ids
        )


def flush_cache_course_available_for(
    users_ids: Optional[Iterable[int]] = None,
    courses_ids: Optional[Iterable[int]] = None,
) -> None:
    """
    Сбрасывает кэш доступности курсов для пользоветелей

    :param users_ids:
    :param courses_ids:
    :return:
    """
    key = settings.COURSE_AVAILABLE_FOR_CACHE_KEY_TEMPLATE

    if hasattr(cache, 'delete_pattern'):
        if users_ids and courses_ids:
            cache.delete_many(
                key.format(user_pk=user_id, course_pk=course_id)
                for user_id in users_ids for course_id in courses_ids
            )
        elif courses_ids:
            for course_id in courses_ids:
                cache.delete_pattern(
                    key.format(user_pk='*', course_pk=course_id)
                )
        elif users_ids:
            for user_id in users_ids:
                cache.delete_pattern(
                    key.format(user_pk=user_id, course_pk='*')
                )
        else:
            cache.delete_pattern(key.format(user_pk='*', course_pk='*'))

    else:
        users_ids = users_ids if users_ids else User.objects.values_list('id', flat=True)
        courses_ids = (
            courses_ids
            if courses_ids
            else list(CourseVisibility.objects.values_list('course_id', flat=True))
        )
        cache.delete_many(
            key.format(user_pk=user_id, course_pk=course_id)
            for user_id in users_ids for course_id in courses_ids
        )


def update_cohort_users(cohort: Cohort, force: bool = False):
    if cohort.status != Cohort.StatusChoices.PENDING and not force:
        return

    users = load_staff_users(logins=cohort.logins)

    with transaction.atomic():
        if len(users) == len(cohort.logins):
            cohort.status = Cohort.StatusChoices.READY
        else:
            not_found_logins = list(set(cohort.logins) - {user.username for user in users})
            error_message = (
                _("При обработке когорты %s не найдены логины %s") % (
                    cohort.id,
                    ', '.join(not_found_logins),
                )
            )
            log.error(error_message)
            cohort.error_messages = str(error_message)
            cohort.status = Cohort.StatusChoices.ERROR

        cohort.users.clear()
        cohort.users.add(*users)
        cohort.save()


def process_pending_cohort(cohort: Cohort, force: bool = False):
    flush_cache_courses_unavailable_for()
    flush_cache_course_available_for(courses_ids=[cohort.course_id])
    update_cohort_users(cohort=cohort, force=force)


def calculate_sum_of_module_weights(aggregated_weight, module_weight):
    if not aggregated_weight:
        return Decimal(1) if not module_weight else Decimal(module_weight)
    return Decimal(aggregated_weight) + Decimal(module_weight)


def format_scaled_value(value: Decimal) -> Decimal:
    return value.quantize(Decimal('1.00'), rounding=ROUND_DOWN)


# TODO: нужно подумать, хотим ли мы хранить weight_scaled в бд, кажется проще его вычислять на лету и кэшировать.
# изначально планировалось вызывать эту функцию в post_save сигнале для CourseModule, если поле score обновлялось
@transaction.atomic
def update_course_module_weight_scaled(course_module: CourseModule):
    course_modules = CourseModule.objects.filter(course=course_module.course)
    if not course_module._state.adding:
        course_modules = course_modules.exclude(id=course_module.id)
    aggregated_weight = course_modules.select_for_update().aggregate(sum_weights=Sum('weight'))["sum_weights"]
    total_weight = calculate_sum_of_module_weights(aggregated_weight, course_module.weight)
    modules_to_update = []
    for module in course_modules:
        module.weight_scaled = format_scaled_value(Decimal(module.weight) / total_weight)
        modules_to_update.append(module)
    CourseModule.objects.bulk_update(
        objs=modules_to_update,
        fields=['weight_scaled', 'modified'],
        batch_size=settings.BULK_BATCH_SIZE_DEFAULT,
    )
    course_module.weight_scaled = format_scaled_value(Decimal(course_module.weight) / total_weight)
    course_module.save(update_fields=['weight_scaled', 'modified'])


def calculate_course_score(student: CourseStudent):
    course_progress = Decimal(0)
    module_progresses = StudentModuleProgress.objects.filter(
        course_id=student.course_id, student=student
    ).values('score', 'module_id')
    module_weights = get_course_module_weights(course_id=student.course_id)
    for module_progress in module_progresses:
        weight = module_weights['weights'].get(module_progress['module_id'], Decimal(0))
        course_progress += weight * module_progress['score'] / module_weights['total']

    return min(100, int(course_progress.quantize(
        Decimal('1.'), rounding=ROUND_HALF_UP
    )))


def update_course_progress(student: CourseStudent) -> None:
    CourseStudent.objects.filter(id=student.id).update(modified=now())
    course_progress = StudentCourseProgress.objects.filter(
        course_id=student.course_id, student=student
    ).first()
    course_score = calculate_course_score(student)
    # perf: Надёжней использовать .update_or_create
    if course_progress is None:
        StudentCourseProgress.objects.create(course_id=student.course_id, student=student, score=course_score)
        return
    if course_progress.score != course_score:
        course_progress.score = course_score
        course_progress.save(update_fields=['score', 'modified'])


def calculate_score_scaled(progress: StudentModuleProgress) -> Decimal:
    total_weight = get_course_module_weights(course_id=progress.course_id)['total']
    module_weight = Decimal(progress.module.weight) / total_weight
    score_scaled = Decimal(progress.score) * module_weight
    return format_scaled_value(score_scaled)


def set_score_scaled(progress: StudentModuleProgress) -> None:
    """
    Устанавливает значение абсолютного веса в прогрессе
    """

    score_scaled = calculate_score_scaled(progress)
    setattr(progress, 'score_scaled', score_scaled)


def refresh_course_module_weight_sum_cache(course_id) -> None:
    get_course_module_weights(course_id=course_id, _refresh=True)


@transaction.atomic
def update_module_progress(module: CourseModule, student: CourseStudent, value: int, force: bool = False):
    module_progress = StudentModuleProgress.objects.filter(
        module=module, student=student
    ).select_for_update().first()
    module_score = min(100, max(0, value))
    if module_progress is None:
        StudentModuleProgress.objects.create(
            course=student.course, student=student, module=module, score=module_score
        )
        return
    if module_score > module_progress.score or (force and module_progress.score != module_score):
        module_progress.score = module_score
        module_progress.save(update_fields=['score', 'score_scaled', 'modified'])


@cache_memoize(timeout=settings.COURSE_MODULE_WEIGHT_SUM_CACHE_TIMEOUT)
def get_course_module_weights(course_id):
    module_weights = {}
    total_weight = 0
    modules = CourseModule.objects.filter(course_id=course_id).values('id', 'weight', 'is_active')
    for module in modules:
        weight = 0 if not module['is_active'] else module['weight']
        module_weights[module['id']] = Decimal(weight)
        total_weight += weight
    return {
        'weights': module_weights,
        'total': Decimal(1) if not total_weight else Decimal(total_weight),
    }


@transaction.atomic
def recalculate_student_module_progress(course_id, student: CourseStudent, module_weight_map) -> int:
    """
    Пересчитывает значение score_scaled во всех прогрессах по модулю
    :param course_id: id курса
    :param student: студент
    :param module_weight_map: словарь весов модулей курса
    :return: количество обновлённых записей прогресса по модулю
    """
    module_progresses = StudentModuleProgress.objects.filter(
        course_id=course_id, student=student
    ).select_for_update()
    module_progress_to_update = []

    for progress in module_progresses:
        module_weight = module_weight_map.get('weights', {}).get(progress.module_id, Decimal(0))
        module_weight_scaled = module_weight / module_weight_map.get('total', Decimal(1))
        score_scaled = Decimal(progress.score) * module_weight_scaled
        progress.score_scaled = format_scaled_value(score_scaled)
        module_progress_to_update.append(progress)

    if not module_progress_to_update:
        return 0

    StudentModuleProgress.objects.bulk_update(
        objs=module_progress_to_update,
        fields=['score_scaled'],
        batch_size=settings.BULK_BATCH_SIZE_DEFAULT,
    )
    return len(module_progress_to_update)


def recalculate_all_course_progresses(course_id):
    course_progresses = StudentCourseProgress.objects.filter(
        student__course_id=course_id
    ).select_related('student')
    for course_progress in course_progresses:
        update_course_progress(student=course_progress.student)


@transaction.atomic
def update_course_tags(course: Course, tags: List[str]) -> None:
    if not tags:
        course.tags.clear()
        return

    normalized_tags = {normalize_tag_name(tag): tag for tag in tags}
    existing_tags = Tag.objects.filter(
        normalized_name__in=normalized_tags
    ).values_list('id', 'normalized_name')

    if existing_tags:
        existing_tag_ids, existing_tag_names = list(zip(*existing_tags))
        course.tags.set(existing_tag_ids)
        tag_names_to_create = set(normalized_tags).difference(set(existing_tag_names))
    else:
        tag_names_to_create = set(normalized_tags)

    if tag_names_to_create:
        created_tag_ids = []
        for normalized_tag_name in tag_names_to_create:
            tag = Tag.objects.create(name=normalized_tags[normalized_tag_name])
            created_tag_ids.append(tag.id)
        course.tags.add(*created_tag_ids)
