import os
import math
import itertools
import functools
import uuid
import collections
from random import shuffle
from urllib.request import urlopen
import multiprocessing
from torch.utils import data
import boto3
import cv2
import numpy as np
import pymongo
import torch
import tqdm
from bran.shared.config import config
from bran.shared.logger import log
import bran.shared.image_utils as ImageUtils
from bran.train.stretch import stretch_images
from bran.train.mongo_dataset import MalformedError

# Module-level S3 resource, created once at import time; used by
# MongoCollection.get_image to download objects that are not cached yet.
s3 = boto3.resource('s3')

# Root of the local image cache. get_image stores each download at
# CACHE_DIR/<bucket>/<key>.
CACHE_DIR = '/var/bran/gamesense/cache'


class MongoCollection(data.Dataset):
    """Dataset whose samples are MongoDB records pointing at images in S3.

    All matching records are loaded into memory at construction time.
    Indexing a sample fetches the image (through a local disk cache),
    then applies ``transform`` to the image and ``target_transform`` to
    the record.

    Keyword Args:
        query: MongoDB filter selecting the records to load (required).
        collection: Name of the MongoDB collection (required).
        database: Name of the MongoDB database. Defaults to ``'dataset'``.
        stretch: If true, the loaded records are passed through
            ``stretch_images`` before use. Defaults to ``False``.
        transform: Callable applied to the decoded image. Defaults to
            the identity function.
        target_transform: Callable applied to the record to produce the
            target. Defaults to returning ``record['label']``.

    Raises:
        KeyError: If ``query`` or ``collection`` is not supplied.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.query = kwargs['query']
        database_name = kwargs.get('database', 'dataset')
        self.collection = kwargs["collection"]
        self.stretch = kwargs.get("stretch", False)
        self.transform = kwargs.get("transform", lambda x: x)
        self.target_transform = kwargs.get(
            "target_transform", lambda record: record['label'])

        # Despite the name, self.db holds the collection handle.
        self.db = pymongo.MongoClient(config.MONGO_DB)[
            database_name][self.collection]
        self.current_key = None
        self.current_image = None
        self.classes_count = collections.defaultdict(int)
        self.images = None
        # Index of the most recent sample that loaded successfully; used
        # as the first fallback when another sample cannot be loaded.
        self.last_good_index = None
        self.populate_images()

    def populate_images(self):
        """Load all records matching ``self.query`` into ``self.images``.

        Also rebuilds ``self.classes_count`` (records per ``label``) and
        makes sure every record has a ``'bucket'`` entry, falling back to
        its ``'bucket_name'`` field when ``'bucket'`` is absent.
        """
        query = self.query
        total = self.db.count_documents(query)
        log.info("Loading {} Images".format(total))

        self.images = list(self.db.find(query))

        if self.stretch:
            self.images = stretch_images(self.images)

        # Start from a clean counter so calling this method again does
        # not double-count the class distribution.
        self.classes_count = collections.defaultdict(int)
        for record in self.images:
            self.classes_count[record['label']] += 1
            record['bucket'] = record.get('bucket', record.get('bucket_name'))

        log.info("Total: {}, Class Distribution: {}".format(
            len(self.images), self.classes_count))

    def get_image(self, bucket, key):
        """Return the image stored at ``bucket``/``key``.

        The object is downloaded from S3 to ``CACHE_DIR/<bucket>/<key>``
        if that file does not exist yet, then read with ``cv2.imread``.

        Returns:
            The result of ``cv2.imread(..., cv2.IMREAD_COLOR)``; callers
            must handle ``None`` (see ``_try_load``).
        """
        cache_path = os.path.join(CACHE_DIR, bucket, key)
        if not os.path.exists(cache_path):
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            s3.Bucket(bucket).download_file(key, cache_path)
        return cv2.imread(cache_path, cv2.IMREAD_COLOR)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.images)

    def _try_load(self, record):
        """Return the image for ``record``, or ``None`` if loading fails.

        Any exception raised while fetching or decoding the image is
        logged and swallowed so the caller can fall back to another
        record. A missing ``'bucket'`` or ``'key'`` field still raises
        ``KeyError``.
        """
        bucket = record["bucket"]
        key = record["key"]
        try:
            image = self.get_image(bucket, key)
            if image is None:
                raise MalformedError('malformed rip')
        except Exception as e:
            # Deliberately broad: a single bad image must not abort the
            # whole run; the caller substitutes another sample instead.
            log.error("Failed to load image: {}, reason: {}".format(record, e))
            return None
        return image

    def _fallback_indices(self, failed_index):
        """Yield the indices to try after ``failed_index`` failed to load.

        The last successfully loaded index comes first (when there is
        one and it is not the failed index itself), followed by every
        other index in order, wrapping around the end of the dataset.
        Each index is yielded at most once and ``failed_index`` is never
        yielded, so the search is bounded.
        """
        total = len(self.images)
        failed = failed_index % total  # normalise negative indices
        preferred = self.last_good_index
        if preferred is not None:
            preferred %= total
            if preferred != failed:
                yield preferred
        for offset in range(1, total):
            candidate = (failed + offset) % total
            if candidate != preferred:
                yield candidate

    def __getitem__(self, index):
        """Return ``(transform(image), target_transform(record))``.

        If the image at ``index`` cannot be loaded, another sample is
        returned in its place: the last sample that loaded successfully
        if there is one, otherwise the next loadable sample (wrapping
        around the end of the dataset).

        Raises:
            IndexError: If ``index`` is out of range.
            RuntimeError: If no image in the dataset can be loaded.
        """
        record = self.images[index]
        image = self._try_load(record)

        if image is None:
            for candidate in self._fallback_indices(index):
                record = self.images[candidate]
                image = self._try_load(record)
                if image is not None:
                    index = candidate
                    break
            else:
                raise RuntimeError(
                    "No loadable image found in collection '{}'".format(
                        self.collection))

        image_tensor = self.transform(image)
        label_tensor = self.target_transform(record)

        self.last_good_index = index
        return image_tensor, label_tensor
