import csv
import json

from django_bulk_update.helper import bulk_update
from tqdm import tqdm

from django.core.management.base import BaseCommand
from django.db import transaction

from kelvin.accounts.models import User
from kelvin.courses.models import CourseLessonLink, CourseStudent
from kelvin.result_stats.tasks import (
    calculate_course_journal, calculate_lesson_journal, recalculate_courselessonstat, recalculate_studentcoursestat,
)

from ...models import CourseLessonResult, CourseLessonSummary

# Default ``batch_size`` passed to every bulk create / bulk update call below.
BULK_BATCH_SIZE_DEFAULT = 1_000


class DryRunException(Exception):
    """Raised at the end of the atomic block to force a rollback on ``--dry-run``."""


class Command(BaseCommand):
    """Set the same lesson result for a list of students of one CourseLessonLink.

    For every listed student the command:

    * reuses the student's CourseLessonSummary for the lesson, or creates one
      with ``lesson_finished=True``;
    * overwrites the student's CourseLessonResult with the given
      answers/points/max_points, or creates a new one;
    * creates a CourseStudent enrollment for the lesson's course if missing.

    Every result action is written as a row to a CSV log file. The journal and
    statistics recalculation tasks are queued only after the transaction
    commits, so ``--dry-run`` (which rolls back) queues nothing.
    """

    help = 'Create or overwrite lesson results for the given students'

    @staticmethod
    def _with_progress(iterable, progress):
        """Return *iterable* wrapped in a tqdm progress bar when *progress* is truthy."""
        return tqdm(iterable) if progress else iterable

    def _schedule_stat_recalculation(self, course_lesson_link, student_ids, progress):
        """Queue lesson/course journal and per-student statistics recalculation tasks."""
        course_id = course_lesson_link.course_id

        calculate_lesson_journal.delay(clesson_id=course_lesson_link.id)
        calculate_course_journal.delay(course_id=course_id)
        recalculate_courselessonstat.delay(clesson_id=course_lesson_link.id)

        self.stdout.write('Update student journal:')
        for student_id in self._with_progress(student_ids, progress):
            recalculate_studentcoursestat.delay(course_id=course_id, student_id=student_id)

    def create_results(self, options):
        """Create or update results for the requested students in one transaction.

        :param options: parsed command options (see ``add_arguments``):
            ``course_lesson_link_id``, ``students_logins``, ``answers`` (a JSON
            string), ``points``, ``max_points``, ``logfile`` and the optional
            ``progress`` / ``dry_run`` flags.

        Writes an error to stderr and returns without touching the database if
        the CourseLessonLink or any of the logins does not exist.
        """
        course_lesson_link_id = options['course_lesson_link_id']
        # Drop repeated logins (keeping first-seen order). A duplicate would
        # otherwise plan two summaries / results / enrollments for one student.
        students_logins = list(dict.fromkeys(options['students_logins']))
        answers = json.loads(options['answers'])
        points = options['points']
        max_points = options['max_points']
        logfile = options['logfile']
        progress = options.get('progress', False)

        try:
            course_lesson_link = CourseLessonLink.objects.get(id=course_lesson_link_id)
        except CourseLessonLink.DoesNotExist:
            self.stderr.write(f'Course lesson {course_lesson_link_id} does not exist')
            return

        student_login_id_map = dict(
            User.objects.filter(username__in=students_logins).values_list('username', 'id')
        )
        if set(student_login_id_map) != set(students_logins):
            not_found_logins = ', '.join(map(str, set(students_logins) - set(student_login_id_map)))
            self.stderr.write(f'Users {not_found_logins} not found')
            return

        course_id = course_lesson_link.course_id
        student_ids = list(student_login_id_map.values())

        summaries_to_create = []
        results_to_create = []
        results_to_update = []
        course_students_to_create = []

        try:
            with transaction.atomic():
                student_summary_map = {
                    summary.student_id: summary
                    for summary in CourseLessonSummary.objects.filter(clesson=course_lesson_link)
                }
                student_result_map = {
                    clr.summary.student_id: clr
                    for clr in CourseLessonResult.objects.select_related('summary').filter(
                        summary__clesson=course_lesson_link,
                    )
                }
                existing_course_students = set(
                    CourseStudent.objects
                    .filter(course_id=course_id)
                    .values_list('course_id', 'student_id')
                )

                # newline='' is required by the csv module for files it writes to.
                with open(logfile, 'w', newline='') as csv_file:
                    csv_writer = csv.DictWriter(
                        csv_file,
                        fieldnames=[
                            'result_action',
                            'summary_action',
                            'student_id',
                            'student_login',
                            'result_id',
                            'summary_id',
                        ],
                    )
                    csv_writer.writeheader()

                    self.stdout.write('Process results:')
                    for students_login in self._with_progress(students_logins, progress):
                        student_id = student_login_id_map[students_login]

                        course_student = (course_id, student_id)
                        if course_student not in existing_course_students:
                            course_students_to_create.append(course_student)

                        log_row = {
                            'student_id': student_id,
                            'student_login': students_login,
                        }

                        if student_id in student_summary_map:
                            summary = student_summary_map[student_id]

                            log_row['summary_action'] = 'none'
                            log_row['summary_id'] = summary.id
                        else:
                            summary = CourseLessonSummary(
                                student_id=student_id,
                                clesson_id=course_lesson_link.id,
                                lesson_finished=True,
                            )
                            summaries_to_create.append(summary)

                            log_row['summary_action'] = 'create'

                        if student_id in student_result_map:
                            clr = student_result_map[student_id]
                            clr.points = points
                            clr.max_points = max_points
                            clr.answers = answers

                            log_row['result_action'] = 'update'
                            log_row['result_id'] = clr.id
                            results_to_update.append(clr)
                        else:
                            clr = CourseLessonResult(
                                points=points,
                                max_points=max_points,
                                answers=answers,
                            )

                            log_row['result_action'] = 'create'
                            results_to_create.append(clr)

                        # Remember which summary the result belongs to and its log
                        # row, so both can be completed once the bulk writes assign ids.
                        clr._summary = summary
                        clr._log_row = log_row

                    if summaries_to_create:
                        # NOTE(review): the results created below reference these
                        # summaries by id. That needs bulk_create to populate
                        # primary keys (PostgreSQL does). Confirm for the deployed backend.
                        CourseLessonSummary.objects.bulk_create(
                            objs=summaries_to_create,
                            batch_size=BULK_BATCH_SIZE_DEFAULT,
                        )

                    if results_to_create:
                        self.stdout.write('Create results:')

                        for result_to_create in self._with_progress(results_to_create, progress):
                            result_to_create.summary = result_to_create._summary
                            result_to_create._log_row['summary_id'] = result_to_create.summary.id

                        CourseLessonResult.objects.bulk_create(
                            objs=results_to_create,
                            batch_size=BULK_BATCH_SIZE_DEFAULT,
                        )

                        for result_to_create in self._with_progress(results_to_create, progress):
                            result_to_create._log_row['result_id'] = result_to_create.id
                            csv_writer.writerow(result_to_create._log_row)

                    self.stdout.write('Update results:')
                    if results_to_update:
                        bulk_update(
                            results_to_update,
                            update_fields=[
                                'answers', 'points', 'max_points',
                            ],
                            batch_size=BULK_BATCH_SIZE_DEFAULT,
                        )
                        for result_to_update in self._with_progress(results_to_update, progress):
                            csv_writer.writerow(result_to_update._log_row)

                    self.stdout.write('Create course students:')
                    if course_students_to_create:
                        CourseStudent.objects.bulk_create(
                            objs=[
                                CourseStudent(course_id=cs_course_id, student_id=cs_student_id)
                                for cs_course_id, cs_student_id in course_students_to_create
                            ],
                            batch_size=BULK_BATCH_SIZE_DEFAULT,
                        )

                # Queue the recalculation only once the writes are committed.
                # Dispatching inside the atomic block let workers race the commit,
                # and still fired on --dry-run, whose writes are rolled back.
                # on_commit callbacks are discarded when the block rolls back.
                transaction.on_commit(
                    lambda: self._schedule_stat_recalculation(course_lesson_link, student_ids, progress)
                )

                if options.get('dry_run', False):
                    raise DryRunException

                self.stdout.write('Commiting...')

        except DryRunException:
            self.stdout.write('Rollback')
        else:
            self.stdout.write('Commit')

    def add_arguments(self, parser):
        """Register the command-line options consumed by ``create_results``."""
        parser.add_argument(
            '--dry-run', action='store_true', dest='dry_run',
            help='Run without commit to db',
        )
        parser.add_argument(
            '--progress', action='store_true', dest='progress',
            help='Show progress',
        )
        parser.add_argument(
            '--cll', dest='course_lesson_link_id', required=True, type=int,
            help='CourseLessonLink id',
        )
        parser.add_argument(
            '--students', dest='students_logins', required=True, nargs='*',
            help='Students space-separated logins',
        )
        parser.add_argument(
            '--answers', dest='answers', required=True,
            help='Answers to set (json)',
        )
        parser.add_argument(
            '--points', dest='points', required=True, type=int,
            help='Points to set',
        )
        parser.add_argument(
            '--max-points', dest='max_points', required=True, type=int,
            help='Max points to set',
        )
        parser.add_argument(
            '--file', dest='logfile', required=True,
            help='Logfile (csv) for result movement',
        )

    def handle(self, *args, **options):
        """Command entry point: apply the results, then report completion."""
        self.create_results(options)

        self.stdout.write("Done\n")
