import logging
import re
from collections import Counter
from functools import partial
from multiprocessing.pool import ThreadPool

from celery.exceptions import SoftTimeLimitExceeded
from django.conf import settings
from django.contrib.postgres.aggregates import ArrayAgg
from django.db import transaction
from django.db.models import Count
from django.utils import timezone
from more_itertools import chunked
from requests.exceptions import ConnectionError, HTTPError

from config import celery_app as app
from photoalbum.contrib.deface import detect_face
from photoalbum.contrib.geocoder import geocoder
from photoalbum.contrib.yadisk import yadisk
from photoalbum.files.constants import EXCLUSIONS, MB
from photoalbum.files.models import (
    MODEL_FOR_MEDIA_TYPE,
    Folder,
    Location,
    MediaFile,
    Photo,
    PhotoDuplicate,
    Tag,
    UserOnPhoto,
    Video,
    create_tags,
)
from photoalbum.files.services import (
    delete_empty_folders,
    delete_old_media,
    download_files,
    upload_file_or_folder,
)

logger = logging.getLogger(__name__)


@app.task()
def get_metadata(*args, **kwargs):
    """Mirror folder and file metadata from the remote disk into the database.

    Starting at ``yadisk.init_folder`` (plus the optional ``init_path``), every
    listing page returned by ``yadisk.get_all_resources`` is scanned.
    Directories are stored as ``Folder`` rows and queued for scanning; all
    other items are stored through the model registered for their media type
    in ``MODEL_FOR_MEDIA_TYPE`` (``MediaFile`` when none is registered).

    Keyword Args:
        limit: ``limit`` passed to ``yadisk.get_all_resources`` (default 1000).
        offset: ``offset`` passed to ``yadisk.get_all_resources`` (default 0).
        force: when truthy, every metadata field of an already known file is
            overwritten; otherwise only its filename, path and folder are.
        init_path: suffix appended to ``yadisk.init_folder`` to build the
            path the scan starts from (default ``""``).
    """
    # Tag name -> primary key, handed to ``create_tags`` for every new file.
    tag_ids_by_name = dict(Tag.objects.values_list("name", "pk"))
    request_params = dict(
        limit=kwargs.get("limit", 1000),
        offset=kwargs.get("offset", 0),
    )
    force_update = kwargs.get("force")
    start_path = yadisk.init_folder + kwargs.get("init_path", "")
    # LIFO stack of (Folder row or None for the root, remote path) to scan.
    pending = [(None, start_path)]
    stats = Counter()

    logger.info("Started loading metadata")
    while pending:
        parent_folder, folder_path = pending.pop()
        request_params["path"] = folder_path
        for page in yadisk.get_all_resources(**request_params):
            logger.info("{path}: ".format(path=folder_path))
            entries = page.get("_embedded", dict()).get("items", [])
            # A missing or empty item list means there is nothing on this page.
            for entry in entries or ():
                size = entry.get("size")
                name = entry.get("name")
                if size == 0:
                    continue
                exclusion_hits = [re.match(pattern, name) for pattern in EXCLUSIONS]
                if any(exclusion_hits):
                    continue

                remote_path = entry.get("path")
                resource_id = entry.get("resource_id")
                if entry.get("type") == "dir":
                    subfolder, _ = Folder.objects.update_or_create(
                        resource_id=resource_id,
                        defaults={
                            "name": name,
                            "path": remote_path,
                            "parent": parent_folder,
                        },
                    )
                    pending.append((subfolder, remote_path))
                    continue

                media_type = entry.get("media_type")
                attrs = {
                    "filename": name,
                    "path": remote_path,
                    "size": size,
                    "chksum": entry.get("sha256"),
                    "media_type": media_type,
                    "mime_type": entry.get("mime_type"),
                    "created": entry.get("created", timezone.now()),
                    "modified": timezone.now(),
                    "folder": parent_folder,
                }

                exif = entry.get("exif", dict())
                shot_date = yadisk.get_date_of_shot(entry)
                if media_type == "image":
                    coords = (exif.get("gps_longitude"), exif.get("gps_latitude"))
                    # Coordinates are kept only when both values are present.
                    attrs["geo_coords"] = str(coords) if all(coords) else ""
                    attrs["photo_date"] = shot_date
                    attrs["exif"] = exif
                elif media_type == "video":
                    attrs["video_date"] = shot_date
                    attrs["exif"] = exif
                elif media_type in settings.YADISK_EXCLUDED_MEDIATYPES:
                    continue

                model = MODEL_FOR_MEDIA_TYPE.get(media_type, MediaFile)
                media_object, is_new = model.objects.get_or_create(
                    resource_id=resource_id,
                    defaults=attrs,
                )

                if is_new:
                    new_tags = create_tags(media_object.path, tag_ids_by_name)
                    if new_tags:
                        media_object.tags.set(new_tags)
                    stats[f"{media_type}_created"] += 1
                    continue

                # Known file: refresh everything only when forced, otherwise
                # just the fields that describe where the file lives.
                if force_update:
                    changes = attrs
                else:
                    changes = {
                        "filename": name,
                        "path": remote_path,
                        "folder": parent_folder,
                    }
                for attr_name, value in changes.items():
                    setattr(media_object, attr_name, value)
                media_object.save()
                stats[f"{media_type}_updated"] += 1

    logger.info(dict(stats))
    logger.info("Finished loading metadata")


def split_into_batches(
    media_type=None, batch_size=settings.YADISK_UPLOAD_BATCH_SIZE, **kwargs
):
    """Queue ``download_batch`` tasks for media rows whose ``file`` is empty.

    Args:
        media_type: key into ``MODEL_FOR_MEDIA_TYPE`` selecting the model to
            query; when falsy, ``Photo``, ``Video`` and ``MediaFile`` are all
            queried in turn.
        batch_size: maximum number of primary keys sent with one task.
        **kwargs: forwarded to every ``download_batch`` task together with a
            ``media_ids`` list. ``max_size`` (default 100000) and ``min_size``
            (default 0) are multiplied by ``MB`` and used as exclusive upper
            and lower bounds on the ``size`` field.
    """
    if media_type:
        target_models = [MODEL_FOR_MEDIA_TYPE.get(media_type)]
    else:
        target_models = [Photo, Video, MediaFile]

    upper_bound = kwargs.get("max_size", 100000) * MB
    lower_bound = kwargs.get("min_size", 0) * MB

    for model in target_models:
        pending_ids = model.objects.filter(
            file="",
            size__lt=upper_bound,
            size__gt=lower_bound,
        ).values_list("pk", flat=True)
        for id_batch in chunked(pending_ids, batch_size):
            task_kwargs = {**kwargs, "media_ids": id_batch}
            download_batch.delay(**task_kwargs)


@app.task()
def download_batch(*args, **kwargs):
    """Run ``download_files`` with the given arguments as a Celery task.

    A ``SoftTimeLimitExceeded`` raised during the download is logged at info
    level and not re-raised.
    """
    try:
        download_files(*args, **kwargs)
    except SoftTimeLimitExceeded as timeout_error:
        logger.info(timeout_error)


@app.task()
def sync_files(*args, **kwargs):
    """Refresh metadata, remove stale records and queue all download tasks.

    Runs ``get_metadata`` and ``delete_old_files`` in-process, queues one
    ``download_batch`` task with ``preview=True``, then queues batched
    downloads for images and for videos grouped by size range. ``kwargs`` is
    forwarded to every one of these calls except ``delete_old_files``.
    """
    get_metadata(**kwargs)
    delete_old_files()
    download_batch.delay(preview=True, **kwargs)
    split_into_batches(media_type="image", batch_size=1000, threads=4, **kwargs)

    # Each batch download has to fit within 2 hours.
    # 100 videos of 1 to 2 GB each download in slightly under 2 hours,
    # so for larger videos the batch size is reduced proportionally.
    video_tiers = (
        dict(batch_size=150, threads=4, max_size=1000),
        dict(batch_size=100, threads=3, min_size=1000, max_size=2000),
        dict(batch_size=50, threads=2, min_size=2000, max_size=4000),
        dict(batch_size=20, threads=1, min_size=4000, max_size=8000),
        dict(batch_size=10, threads=1, min_size=8000, max_size=16000),
    )
    for tier in video_tiers:
        split_into_batches(media_type="video", **tier, **kwargs)


@app.task()
def get_duplicates(*args, **kwargs):
    """Record groups of photos that share the same checksum.

    For every ``chksum`` value carried by at least two ``Photo`` rows, a
    ``PhotoDuplicate`` row is created or updated: the last id of the
    aggregated id list becomes ``original_id`` and the remaining ids are set
    as its ``duplicates``.
    """
    groups_with_copies = (
        Photo.objects.values("chksum")
        .annotate(item_ids=ArrayAgg("id"), num_duplicates=Count("id"))
        .filter(num_duplicates__gte=2)
    )
    for group in groups_with_copies:
        photo_ids = group["item_ids"]
        original_id = photo_ids[-1]
        other_ids = photo_ids[:-1]
        duplicate_record, _ = PhotoDuplicate.objects.update_or_create(
            chksum=group["chksum"],
            defaults={"original_id": original_id},
        )
        duplicate_record.duplicates.set(other_ids)


@app.task()
def delete_old_files(*args, **kwargs):
    """Run the cleanup helpers for folders and media, then drop empty folders.

    Calls ``delete_old_media`` for ``Folder``, then for ``MediaFile``
    restricted to the image and video media types, and finally
    ``delete_empty_folders``.
    """
    media_types = ["image", "video"]
    delete_old_media(Folder)
    delete_old_media(MediaFile, media_type__in=media_types)
    delete_empty_folders()


@app.task()
def get_location(photo_id, *args, **kwargs):
    """Attach a ``Location`` matching the photo's coordinates to the photo.

    Nothing happens when the photo does not exist or has no coordinates.
    An existing ``Location`` with the same ``geo_coords`` is reused; the
    geocoder is only queried when no such row exists yet.

    Args:
        photo_id: primary key of the ``Photo`` to process.
    """
    photo = Photo.objects.filter(id=photo_id).first()
    # ``geo_coords`` is stored as an empty string when the EXIF data carries
    # no GPS coordinates; there is nothing to geocode in that case.
    if not photo or not photo.geo_coords:
        return

    location = Location.objects.filter(geo_coords=photo.geo_coords).first()
    if location is None:
        # ``geo_coords`` holds the text of a 2-tuple, e.g. "(a, b)": strip the
        # parentheses and pass both parts to the geocoder.
        coords = photo.geo_coords[1:-1].split(", ")
        # NOTE(review): assumes the geocoder answers with at least
        # "country, city, address"; fewer parts raise ValueError - confirm.
        country, city, address = geocoder.address(*coords).split(", ", 2)
        # get_or_create rather than create, in case another worker stored the
        # same coordinates in the meantime.
        location, _ = Location.objects.get_or_create(
            geo_coords=photo.geo_coords,
            defaults=dict(
                country=country,
                city=city,
                address=address,
            ),
        )

    photo.location = location
    photo.save()


@app.task(bind=True, max_retries=3)
def get_user_on_photo(self, photo_id, *args, **kwargs):
    """Run face detection on a photo and store the recognised logins.

    Only a photo that has a file and is not yet flagged ``detected`` is
    processed. The photo is flagged as detected as soon as the detection call
    returns; connection and HTTP errors are logged and trigger a task retry.
    Each distinct login among the returned candidates is stored once through
    ``UserOnPhoto.objects.get_or_create``.

    Args:
        photo_id: primary key of the ``Photo`` to process.
    """
    unprocessed = Photo.objects.filter(id=photo_id, detected=False).exclude(file="")
    photo = unprocessed.first()
    if not photo:
        return

    logger.info(f"Detecting users in photo {photo.path}")
    try:
        detection = detect_face.detect(photo.file.url)
        photo.detected = True
        photo.save()
    except (ConnectionError, HTTPError) as error:
        logger.error(error)
        raise self.retry(exc=error)

    seen_logins = set()
    with transaction.atomic():
        for candidate in detection.get("candidates", []):
            login = candidate.get("staff")
            # Skip candidates without a login and logins already handled.
            if not login or login in seen_logins:
                continue
            seen_logins.add(login)
            UserOnPhoto.objects.get_or_create(
                photo_id=photo_id,
                login=login,
                match_score=candidate.get("score"),
                staff_image_url=candidate.get("image_url"),
            )

    number_of_users = len(seen_logins)
    logger.info(f"Detected {number_of_users} candidates in photo {photo.path}")


@app.task()
def upload_files(*args, **kwargs):
    """Upload a folder tree and the media files stored in it.

    Folders are created first, one by one, through ``yadisk.create_folder``;
    the files are then uploaded through ``yadisk.upload_file`` by a pool of
    worker threads.

    Keyword Args:
        path: ``path`` of the ``Folder`` whose subtree (itself included) is
            uploaded. Required.
        threads: number of worker threads used for file uploads. Required.
        max_size: exclusive upper bound on file ``size``, multiplied by
            ``MB`` (default 100000).
        min_size: exclusive lower bound on file ``size``, multiplied by
            ``MB`` (default 0).

    Raises:
        KeyError: if ``path`` or ``threads`` is missing from ``kwargs``.
    """
    folders = Folder.objects.filter(path=kwargs["path"]).get_descendants(
        include_self=True
    )

    # Reuse the folder queryset instead of building the same one twice; only
    # rows that actually have a file attached are uploaded.
    files = (
        MediaFile.objects.filter(folder__in=folders.values_list("pk", flat=True))
        .filter(
            size__lt=kwargs.get("max_size", 100000) * MB,
            size__gt=kwargs.get("min_size", 0) * MB,
        )
        .exclude(file="")
    )

    for folder in folders:
        upload_file_or_folder(folder, loader=yadisk.create_folder)

    upload = partial(upload_file_or_folder, loader=yadisk.upload_file)
    # The context manager shuts the pool down even when an upload raises,
    # so worker threads are not leaked. ``map`` blocks until all uploads end.
    with ThreadPool(kwargs["threads"]) as pool:
        pool.map(upload, files)
