import getpass
import logging
from contextlib import closing

import psycopg2
from psycopg2.extras import DictCursor
from psycopg2.sql import SQL, Identifier
from tqdm import tqdm

from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand
from django.db import models, transaction

from mentor.contrib.staff.client import staff_api
from mentor.mentorships.models import Mentor, Mentorship
from mentor.staff.models import StaffProfile
from mentor.staff.services import load_staff_profile

# The user model configured for this Django project.
User = get_user_model()

# Translation of the legacy ``mentoring_requests.request_state`` integer codes
# into new ``Mentorship.Status`` values. Codes are positional: 0, 1, 2, 3.
STATUS_MAP = dict(
    enumerate(
        (
            Mentorship.Status.CREATED,  # 0
            Mentorship.Status.ACCEPTED,  # 1
            Mentorship.Status.DECLINED,  # 2
            Mentorship.Status.COMPLETED,  # 3
        )
    )
)


class DryRunException(Exception):
    """Raised deliberately inside the import transaction to roll it back on ``--dry-run``."""


class Command(BaseCommand):
    """Import users, mentors and mentorships from the legacy mentor database.

    Reads the ``users``, ``users_backup``, ``mentors_targeting`` and
    ``mentoring_requests`` tables of the old PostgreSQL database over a direct
    psycopg2 connection. The data is recreated through the Django ORM inside a
    single transaction, which is rolled back when ``--dry-run`` is given.
    """

    def create_or_update_model(self, model, **kwargs):
        """Fetch-or-create (``--new-only``) or upsert a ``model`` row.

        ``kwargs`` (lookup fields plus an optional ``defaults``) are forwarded
        unchanged to ``get_or_create`` / ``update_or_create``. Returns the
        ``(instance, created)`` tuple of the underlying manager call.
        """
        method = "get_or_create" if self.new_only else "update_or_create"
        return getattr(model.objects, method)(**kwargs)

    def load_uids_chunk(self, chunk):
        """Resolve uids for the logins in ``chunk`` with one staff-api request.

        Logins already present in ``self.login_to_uid_map`` are skipped. No
        request is sent when nothing is left to resolve.
        """
        # dict.fromkeys de-duplicates while keeping the original order, so the
        # same login is never requested twice in one call.
        logins = list(
            dict.fromkeys(
                row["staff_login"]
                for row in chunk
                if row["staff_login"] not in self.login_to_uid_map
            )
        )
        if not logins:
            return

        response = staff_api.persons.get(
            _limit=len(logins),
            _fields="login,uid",
            login=",".join(logins),
        )

        for person in response["result"]:
            self.login_to_uid_map[person["login"]] = person["uid"]

    def load_uids(self, conn, table_name):
        """Fill ``self.login_to_uid_map`` for every ``staff_login`` in ``table_name``.

        Rows are fetched ``self.users_chunk_size`` at a time, so each
        staff-api request carries at most that many logins.
        """
        self.stdout.write(f"Loading uids for users from {table_name} table")

        query = SQL("SELECT staff_login FROM {}").format(Identifier(table_name))

        with conn.cursor(cursor_factory=DictCursor) as cursor:
            cursor.execute(query)

            with tqdm(total=cursor.rowcount) as pbar:
                # fetchmany() returns an empty list once the result set is exhausted.
                for chunk in iter(lambda: cursor.fetchmany(self.users_chunk_size), []):
                    self.load_uids_chunk(chunk)
                    # Advance by the real row count: the last chunk is usually short.
                    pbar.update(len(chunk))

    def import_users(self, conn, table_name):
        """Create or update a ``User`` for every row of ``table_name``.

        Rows whose ``id`` was already imported (e.g. from a previously
        processed table) are skipped. Fills ``self.user_map`` (old id ->
        ``User``) and ``self.original_user_map`` (old id -> raw row) for the
        later import steps.

        Raises:
            ValueError: if no uid was resolved for a row's ``staff_login``.
        """
        self.load_uids(conn, table_name)

        self.stdout.write(f"Importing users from {table_name} table")

        users_query = SQL("SELECT * FROM {}").format(Identifier(table_name))

        with conn.cursor(cursor_factory=DictCursor) as cursor:
            cursor.execute(users_query)

            for row in tqdm(cursor, total=cursor.rowcount):
                original_user_id = row["id"]
                if original_user_id in self.user_map:
                    continue

                login = row["staff_login"]
                if login not in self.login_to_uid_map:
                    # Fail with context instead of a bare KeyError; the
                    # surrounding transaction is rolled back either way.
                    raise ValueError(
                        f"Staff uid wasn't found for login {login!r}. "
                        f"User id from old db {original_user_id}"
                    )

                user, created = self.create_or_update_model(
                    User,
                    username=login,
                    defaults={
                        "yauid": self.login_to_uid_map[login],
                        "email": row["email"],
                        "first_name": row["name"],
                        "last_name": row["surname"],
                    },
                )

                if created:
                    self.created_users += 1
                elif not self.new_only:
                    self.updated_users += 1

                self.user_map[original_user_id] = user
                self.original_user_map[original_user_id] = row

    def import_mentors(self, conn):
        """Create or update a ``Mentor`` for every row of ``mentors_targeting``.

        Each mentor's staff profile is ensured and loaded first, because the
        mentor's ``carrier_begin`` comes from the profile's ``joined_at``.
        Fills ``self.mentor_map`` (old user id -> ``Mentor``).

        Raises:
            ValueError: if ``load_staff_profile`` reports failure for a user.
        """
        self.stdout.write("Importing mentors from mentors_targeting table")

        mentors_query = "SELECT * FROM mentors_targeting"

        with conn.cursor(cursor_factory=DictCursor) as cursor:
            cursor.execute(mentors_query)

            for row in tqdm(cursor, total=cursor.rowcount):
                original_user_id = row["id"]
                user = self.user_map[original_user_id]

                # get_or_create keeps re-runs (--new-only or update mode) from
                # failing on a profile that an earlier run already created.
                StaffProfile.objects.get_or_create(user=user)
                if not load_staff_profile(user.pk):
                    raise ValueError(
                        f"Staff profile wasn't loaded. User id from old db {original_user_id}"
                    )

                # Creating the profile with ``user=user`` caches that in-memory
                # instance on ``user``. load_staff_profile only gets ``user.pk``
                # and cannot update it, so re-read the profile from the database.
                staff_profile = StaffProfile.objects.get(user=user)

                mentor, created = self.create_or_update_model(
                    Mentor,
                    user=user,
                    defaults={
                        "description": row["description"],
                        "carrier_begin": staff_profile.joined_at,
                    },
                )

                if created:
                    self.created_mentors += 1
                elif not self.new_only:
                    self.updated_mentors += 1

                self.mentor_map[original_user_id] = mentor

    def import_mentorships(self, conn):
        """Create a ``Mentorship`` for every row of ``mentoring_requests``.

        Requests whose mentor is not flagged ``is_mentor`` in the old users
        data are skipped. The status author is the mentee while the request is
        still ``CREATED``, and the mentor's user for every other status.

        NOTE(review): rows are always inserted, even with ``--new-only``, so
        re-running the import duplicates mentorships. Confirm this is intended.
        """
        self.stdout.write("Importing mentorships from mentoring_requests table")

        mentorships_query = "SELECT * FROM mentoring_requests"

        with conn.cursor(cursor_factory=DictCursor) as cursor:
            cursor.execute(mentorships_query)

            for row in tqdm(cursor, total=cursor.rowcount):
                original_mentor_id = row["mentor_id"]
                if not self.original_user_map[original_mentor_id]["is_mentor"]:
                    continue

                mentor = self.mentor_map[original_mentor_id]
                mentee = self.user_map[row["menty_id"]]
                status = STATUS_MAP[row["request_state"]]
                status_by = (
                    mentee if status == Mentorship.Status.CREATED else mentor.user
                )

                mentorship = Mentorship(
                    mentee=mentee,
                    mentor=mentor,
                    intro=row["menty_msg"],
                    status=status,
                    status_by=status_by,
                    status_message=row["mentor_answer"] or "",
                    removed_by_mentor=row["deleted_by_mentor"],
                    removed_by_mentee=row["deleted_by_menty"],
                )
                mentorship.save()

                self.created_mentorships += 1
                self.mentorship_map[row["id"]] = mentorship

    def import_old_mentor_db(self, options):
        """Run the whole import with the parsed command ``options``.

        All steps run inside one ``transaction.atomic()`` block. With
        ``--dry-run``, a ``DryRunException`` is raised at the end to roll it
        back. Summary counters are printed in both cases.
        """
        # Suppress records at INFO level and below; WARNING and above still pass.
        # NOTE(review): neither this nor the signal disconnect below is undone
        # afterwards, so it persists for the rest of the process.
        logging.disable(logging.INFO)

        # Disable User's post_save signal to prevent updating StaffProfile.
        models.signals.post_save.disconnect(
            sender=User, dispatch_uid="user_post_save_handler"
        )

        self.new_only = options.get("new_only")
        self.users_chunk_size = options.get("users_chunk_size")
        self.no_mentorships = options.get("no_mentorships")

        self.login_to_uid_map = {}
        self.user_map = {}
        self.original_user_map = {}
        self.mentor_map = {}
        self.mentorship_map = {}

        self.created_users = 0
        self.updated_users = 0
        self.created_mentors = 0
        self.updated_mentors = 0
        self.created_mentorships = 0

        with closing(
            psycopg2.connect(
                dbname=options["name"],
                user=options["login"],
                password=options["password"],
                host=options["host"],
                port=options["port"],
                sslmode=options["sslmode"],
            ),
        ) as conn:
            try:
                with transaction.atomic():
                    self.import_users(conn, table_name="users")
                    self.import_users(conn, table_name="users_backup")
                    self.import_mentors(conn)

                    if not self.no_mentorships:
                        self.import_mentorships(conn)

                    if options.get("dry_run", False):
                        # Leaving atomic() via an exception rolls everything back.
                        raise DryRunException

            except DryRunException:
                self.stdout.write("Rollback")

            else:
                self.stdout.write("Commit")

            self.stdout.write(
                f"USERS created: {self.created_users}, updated: {self.updated_users}"
            )
            self.stdout.write(
                f"MENTORS created: {self.created_mentors}, updated: {self.updated_mentors}"
            )
            self.stdout.write(f"MENTORSHIPS created: {self.created_mentorships}")

    def add_arguments(self, parser):
        """Declare the connection, mode and batching options of the command."""
        parser.add_argument(
            "--host",
            dest="host",
            type=str,
            required=True,
            help="Old mentor db host",
        )
        parser.add_argument(
            "--name",
            "-n",
            dest="name",
            type=str,
            required=True,
            help="Old mentor db name",
        )
        parser.add_argument(
            "--login",
            "-l",
            dest="login",
            type=str,
            required=True,
            help="Old mentor db user login",
        )
        parser.add_argument(
            "--password",
            "-w",
            dest="password",
            type=str,
            help="Old mentor db user password (omit to be prompted for it without echo)",
        )
        parser.add_argument(
            "--port",
            "-p",
            dest="port",
            type=int,
            default=5432,
            help="Old mentor db port",
        )
        parser.add_argument(
            "--new-only",
            action="store_true",
            dest="new_only",
            default=False,
            help="Import only new data",
        )
        parser.add_argument(
            "--no-mentorships",
            action="store_true",
            dest="no_mentorships",
            default=False,
            help="Don't import mentorships",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            dest="dry_run",
            help="Run without commit to db",
        )
        parser.add_argument(
            "--sslmode",
            dest="sslmode",
            type=str,
            default="disable",
            choices=[
                "disable",
                "allow",
                "prefer",
                "require",
                "verify-ca",
                "verify-full",
            ],
            help="SSL mode",
        )
        parser.add_argument(
            "--users-chunk-size",
            dest="users_chunk_size",
            type=int,
            default=100,
            help="Max users per staff-api request",
        )

    def handle(self, *args, **options):
        """Prompt for the password if it was not given, then run the import."""
        if options.get("password") is None:
            options["password"] = getpass.getpass()

        self.import_old_mentor_db(options)
