import os
import math
import itertools
import functools
import uuid
import collections
from random import shuffle
from urllib.request import urlopen
import multiprocessing
from torch.utils import data
import boto3
import cv2
import numpy as np
import pymongo
import torch
import tqdm
from bran.shared.config import config
from bran.shared.logger import log
import bran.shared.image_utils as ImageUtils
from bran.train.stretch import stretch_images

# Module-level S3 resource, created once at import time and shared by every
# dataset instance in this process.
# NOTE(review): boto3 documents resource objects as not thread-safe; confirm
# this is only used from one thread per process.
s3 = boto3.resource('s3')

# Root directory of the local on-disk cache for downloaded S3 objects.
CACHE_DIR = '/var/bran/gamesense/cache'


class MalformedError(Exception):
    """Raised when a fetched image cannot be decoded into pixel data."""


class MongoDataset(data.Dataset):
    """Map-style dataset of labelled images indexed in MongoDB and stored in S3.

    Records are read from the ``dataset`` database, collection
    ``collection``, filtered by ``label`` (must be one of ``labels``) and
    ``split_type``. Each record is expected to carry ``bucket``, ``key`` and
    ``label`` fields. Images are downloaded from S3 on first access and cached
    under ``CACHE_DIR``.

    Keyword Args:
        collection: Name of the MongoDB collection in the ``dataset`` database.
        split_type: Value matched against each record's ``split_type`` field.
        labels: Sequence of labels to load; a sample's class index is the
            position of its label in this sequence.
        stretch: If true, the loaded records are passed through
            ``stretch_images`` before use. Defaults to False.
        transform: Callable applied to each decoded image. Defaults to the
            identity function.
        max_load_attempts: Maximum number of records ``__getitem__`` tries
            (the requested one plus substitutes) before giving up. Defaults
            to the class attribute of the same name.
    """

    # Default cap on records tried per __getitem__ call when images fail to
    # load; bounds the work done during e.g. an S3 outage.
    max_load_attempts = 100

    def __init__(self, **kwargs):
        self.collection = kwargs["collection"]
        self.split_type = kwargs["split_type"]
        self.labels = kwargs["labels"]
        self.stretch = kwargs.get("stretch", False)
        self.transform = kwargs.get("transform", lambda x: x)
        self.max_load_attempts = kwargs.get(
            "max_load_attempts", self.max_load_attempts)

        # NOTE(review): pymongo documents MongoClient as not fork-safe; this
        # handle is only used by populate_images here and should not be used
        # from DataLoader worker processes.
        self.db = pymongo.MongoClient(config.MONGO_DB)[
            'dataset'][self.collection]
        self.current_key = None
        self.current_image = None
        self.classes_count = collections.defaultdict(int)
        self.images = None
        # Index of the most recently loaded sample; used as the first
        # substitute when an image fails to load.
        self.last_good_index = None
        self.populate_images()

    def populate_images(self):
        """(Re)load matching records from MongoDB into ``self.images``.

        Recomputes ``classes_count`` from scratch and clears
        ``last_good_index``, so calling this again neither double-counts
        classes nor leaves a stale fallback index behind.
        """
        query = {"label": {"$in": self.labels}, "split_type": self.split_type}
        total = self.db.count_documents(query)
        log.info("Loading {} Images".format(total))

        self.images = list(self.db.find(query))

        if self.stretch:
            self.images = stretch_images(self.images)

        # Reset rather than accumulate so repeated calls stay correct.
        self.classes_count = collections.defaultdict(int)
        for i in self.images:
            self.classes_count[i['label']] += 1
        self.last_good_index = None
        log.info("Total: {}, Class Distribution: {}".format(
            len(self.images), self.classes_count))

    def get_class_weight(self):
        """Return the number of loaded records per label, in ``labels`` order.

        Despite the name these are raw counts, not inverse-frequency weights.
        """
        return [self.classes_count[label] for label in self.labels]

    def get_image(self, bucket, key):
        """Return the image at ``bucket``/``key``, downloading it on a cache miss.

        The object is cached at ``CACHE_DIR/<bucket>/<key>`` and decoded with
        ``cv2.imread(..., cv2.IMREAD_COLOR)``, which returns None when the
        file cannot be decoded.

        NOTE(review): ``os.path.join`` discards ``CACHE_DIR`` if ``key``
        starts with '/'; confirm keys are always relative.
        """
        cache_path = os.path.join(CACHE_DIR, bucket, key)
        if not os.path.exists(cache_path):
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            s3.Bucket(bucket).download_file(key, cache_path)
        return cv2.imread(cache_path, cv2.IMREAD_COLOR)

    def __len__(self):
        return len(self.images)

    def _candidate_indices(self, index):
        """Yield record indices to try for ``index``, in priority order.

        Order: ``index`` itself, then the last successfully loaded index, then
        every other index in order after ``index``, wrapping around. No
        position is yielded twice.
        """
        # Yield the requested index first so an out-of-range (or empty
        # dataset) lookup raises IndexError before any modulo arithmetic.
        yield index
        count = len(self.images)
        start = index % count
        seen = {start}
        if self.last_good_index is not None:
            last_good = self.last_good_index % count
            if last_good not in seen:
                seen.add(last_good)
                yield last_good
        for offset in range(1, count):
            candidate = (start + offset) % count
            if candidate not in seen:
                yield candidate

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor, index)`` for record ``index``.

        If the image cannot be fetched or decoded, the error is logged and a
        substitute record is used instead: first the last successfully loaded
        record, then the following records (wrapping around). At most
        ``max_load_attempts`` records are tried. The returned ``index`` is the
        index of the record actually used.

        Raises:
            IndexError: if ``index`` is out of range.
            RuntimeError: if no candidate record could be loaded.
        """
        attempts = max(1, self.max_load_attempts)
        last_error = None
        for candidate in itertools.islice(
                self._candidate_indices(index), attempts):
            record = self.images[candidate]
            bucket = record["bucket"]
            key = record["key"]
            label = record["label"]

            try:
                image = self.get_image(bucket, key)
                if image is None:
                    raise MalformedError('malformed rip')
            except Exception as e:
                # Best-effort: a broken/missing image should not abort an
                # epoch, so log it and move on to the next candidate.
                log.error("Failed to load image: {}, reason: {}".format(record, e))
                last_error = e
                continue

            image_tensor = self.transform(image)
            label_tensor = torch.tensor(self.labels.index(label)).long()

            self.last_good_index = candidate
            return image_tensor, label_tensor, candidate

        raise RuntimeError(
            "Could not load any image for index {} after {} attempt(s)".format(
                index, attempts)) from last_error
