import cv2
import math
import numpy as np
import multiprocessing
from bran.shared.logger import log
import tqdm
import os
import boto3
import itertools

# Target aspect ratios, each written as "WxH". stretch_one() produces one
# stretched copy of every image per entry, using fake_stretch().
# Source for the "god rez" entries:
# https://www.reddit.com/r/FortniteCompetitive/comments/a9nbfq/if_youre_using_any_res_near_to_this_without_a_dot/
# NOTE(review): the names in parentheses are presumably players known to use
# that resolution - confirm.
ASPECT_RATIOS = [
    "16x9",  # base
    "16x10",  # standard
    "4x3",  # standard
    "1079x1080",  # god rez (astroh)
    "1153x1080",  # god rez
    "1444x1080",  # god rez
    "950x1080",  # (Dryleaf)
    "1154x1080"  # (72hrs, chap)
]

# Local root directory for stretched JPEGs; stretch_one() writes each file to
# <CACHE_DIR>/<bucket>/<stretched key>.
CACHE_DIR = '/var/bran/gamesense/cache'

# Module-level boto3 S3 resource, used by stretch_one() to download source
# images.
s3 = boto3.resource('s3')


def fake_stretch(img, ratio):
    """
        Simulate a "stretched" resolution: crop an equal number of columns
        from the left and right of `img` so that the remaining region has
        aspect ratio `ratio`, then resize that region back to the original
        width and height.

        img -> numpy [h, w, c]
        ratio -> "AxB" (width x height)

        Returns an image with the same width and height as `img`. When
        `ratio` is as wide as, or wider than, the image's own aspect ratio,
        nothing is cropped.

        Raises ValueError if `ratio` is not of the form "AxB" with numeric
        A and B.
    """
    a, b = ratio.split("x")
    height, width = img.shape[0], img.shape[1]
    # Fraction of the original width that must be kept to reach `ratio`.
    crop_x = float(a)/float(b) / (width/height)
    # Clamp at 0: for a target ratio wider than the image, crop_x > 1 and the
    # margin would be negative, which numpy slicing interprets as "count from
    # the end" and would keep only a thin strip of the right edge.
    crop_x_px = max(0, math.floor(width * (1 - crop_x) / 2))
    cropped_image = img[:, crop_x_px: width-crop_x_px]
    return cv2.resize(cropped_image, (width, height))


def stretch_one(record):
    """
        Download one image from S3 and cache a stretched copy of it for
        every entry in ASPECT_RATIOS.

        record -> dict with 'bucket', 'key' and 'label'

        Returns a list of dicts with keys "bucket", "key", "label" and
        "s3_key", one per aspect ratio whose stretched JPEG exists under
        CACHE_DIR. "key" is the key of the stretched copy (relative to
        CACHE_DIR/<bucket>), "s3_key" is the original key. Ratios whose file
        could not be written are logged and left out. Returns [] when the
        downloaded bytes cannot be decoded as an image.
    """
    bucket = record['bucket']
    key = record['key']
    label = record['label']

    image_bytes = s3.Object(bucket, key).get()['Body'].read()
    image = cv2.imdecode(np.frombuffer(
        image_bytes, np.uint8), cv2.IMREAD_COLOR)

    # cv2.imdecode signals an undecodable buffer by returning None.
    if image is None:
        log.warning("Broken image {}".format(record))
        return []

    base_key = os.path.splitext(key)[0]
    stretched = []
    for ar in ASPECT_RATIOS:
        new_key = '{}.{}.jpg'.format(base_key, ar)
        upload_path = os.path.join(CACHE_DIR, bucket, new_key)
        # An existing file is treated as a valid cache entry and reused.
        if not os.path.exists(upload_path):
            img = fake_stretch(image, ar)
            os.makedirs(os.path.dirname(upload_path), exist_ok=True)
            # Write to a per-process temp file and rename it into place, so
            # an interrupted write never leaves a truncated JPEG at
            # upload_path that later runs would mistake for a cached result.
            # The temp name keeps the ".jpg" suffix because cv2.imwrite
            # picks the encoder from the file extension.
            tmp_path = '{}.{}.tmp.jpg'.format(
                os.path.splitext(upload_path)[0], os.getpid())
            written = cv2.imwrite(
                tmp_path, img, [cv2.IMWRITE_JPEG_QUALITY, 80])
            if not written:
                # cv2.imwrite reports failure through its return value;
                # do not advertise a file that was never written.
                log.warning("Failed to write {}".format(upload_path))
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                continue
            os.replace(tmp_path, upload_path)
        stretched.append({"bucket": bucket, "key": new_key,
                          "label": label, "s3_key": key})
    return stretched


def stretch_one_safe(record):
    """
        Best-effort wrapper around stretch_one() for use as a pool worker.

        Returns whatever stretch_one(record) returns, or [] if it raised.
        Failures are logged as warnings instead of propagating, so a single
        bad record does not abort the whole batch.
    """
    try:
        return stretch_one(record)
    except KeyboardInterrupt:
        # Deliberately swallowed: the record simply yields no results.
        pass
    except cv2.error:
        log.warning("Broken image {}".format(record))
    except Exception:
        # `Exception` rather than a bare `except:` so that SystemExit and
        # GeneratorExit are not silently swallowed.
        log.warning("Error processing record: {}".format(
            record), exc_info=True)
    return []


def stretch_images(records):
    """
        Run stretch_one_safe() over `records` in a process pool and return
        all resulting dicts as one flat list, in record order.

        records -> sized collection of record dicts (len() is taken for the
                   progress bar total)

        The pool uses three worker processes per CPU reported by
        multiprocessing.cpu_count(); progress is shown with tqdm.
    """
    worker_count = 3*multiprocessing.cpu_count()
    with multiprocessing.Pool(worker_count) as pool:
        per_record = pool.imap(stretch_one_safe, records)
        progress = tqdm.tqdm(per_record, total=len(records))
        flattened = []
        # Each worker result is a (possibly empty) list; concatenate them.
        for batch in progress:
            flattened.extend(batch)
        return flattened
