import torch
import boto3
from torch.utils import data
import numpy as np
from io import StringIO
from urllib.request import urlopen
from bran.libraries import DetectorLibrary, TrainConfigLibrary
import cv2
from bran.shared.logger import log


class S3Dataset(data.Dataset):
    """PyTorch dataset whose samples are images stored as objects in an S3 bucket.

    Keyword arguments:
        bucket: S3 bucket name (required).
        game: game name; selects the detector class and the default prefix
            (required).
        size: target number of object keys to index (required).
        type: detector type, compared against ``get_type()`` of each candidate
            detector (required).
        prefix: key prefix to list; defaults to ``"images/<game>"``.
        label: label name attached to every sample returned by
            ``__getitem__``.
        labels: sequence of label names. ``label`` is looked up in it to
            produce the integer target. NOTE(review): this kwarg was added
            because the original referenced an undefined ``labels``; confirm
            the intended source of the label list.
    """

    def __init__(self, **kwargs):
        self.bucket = kwargs["bucket"]
        self.game = kwargs["game"]
        self.size = kwargs["size"]
        self.type = kwargs["type"]
        self.prefix = kwargs.get("prefix", "images/{}".format(self.game))
        self.label = kwargs.get("label", None)
        # Label vocabulary used to turn ``self.label`` into an integer target.
        self.labels = kwargs.get("labels", None)
        self.s3 = boto3.resource('s3')
        self.s3_bucket = self.s3.Bucket(self.bucket)
        self.s3_client = boto3.client("s3")
        self.keys_map = {}
        self.all_keys = []
        self.detector_klass = self.get_detector_klass()
        self.paginator = self.s3_client.get_paginator("list_objects")
        self.populate_keys()

    def get_detector_klass(self):
        """Return the detector class registered for ``self.game`` and ``self.type``.

        Candidates come from ``DetectorLibrary().get_items_for_game(self.game)``.
        A candidate matches when the second-to-last component of its
        ``__module__`` path equals ``self.game`` and its ``get_type()`` equals
        ``self.type``. Returns ``None`` when nothing matches.
        """
        detectors = DetectorLibrary().get_items_for_game(self.game)
        log.info("items = {}".format(detectors))
        for detector in detectors:
            parts = detector.__module__.split(".")
            if len(parts) < 2:
                # No package component to compare against the game name.
                continue
            try:
                detector_type = detector.get_type()
            except Exception:
                # Candidates that cannot report a type are skipped, as before.
                continue
            if parts[-2] == self.game and detector_type == self.type:
                log.info("match!!!")
                return detector
        log.info("No detector found for game={} type={}".format(
            self.game, self.type))
        return None

    def populate_keys(self):
        """List object keys under ``self.prefix`` into ``self.all_keys``.

        Pages are fetched until at least ``self.size`` keys have been
        collected. The list is then trimmed to ``self.size`` entries, so
        ``len(dataset)`` never exceeds the requested size.
        """
        pages = self.paginator.paginate(
            Bucket=self.bucket, Prefix=self.prefix)
        log.info("Populating dataset")
        for page in pages:
            log.info("next page")
            # ListObjects omits "Contents" when a page has no matching objects.
            contents = page.get("Contents", [])
            self.all_keys.extend(o["Key"] for o in contents)
            if len(self.all_keys) >= self.size:
                break
        # Whole pages are appended, so drop any overshoot beyond ``size``.
        del self.all_keys[self.size:]

        log.info("Finished populating dataset")

    def get_raw_image(self, index):
        """Download and decode the image at ``self.all_keys[index]``.

        The object is fetched through a presigned ``get_object`` URL and
        decoded with ``cv2.IMREAD_COLOR``.

        Returns:
            ``(image, key)``: the decoded image array and its S3 key.

        Raises:
            ValueError: if the downloaded bytes cannot be decoded as an image.
        """
        key = self.all_keys[index]
        params = {'Bucket': self.bucket, 'Key': key}
        url = self.s3_client.generate_presigned_url(
            ClientMethod='get_object', Params=params)
        # Close the HTTP response as soon as the body has been read.
        with urlopen(url) as resp:
            payload = resp.read()
        image = np.asarray(bytearray(payload), dtype="uint8")
        image = cv2.imdecode(image, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imdecode signals failure by returning None rather than raising.
            raise ValueError("Could not decode image for key {}".format(key))
        return (image, key)

    def __len__(self):
        return len(self.all_keys)

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for sample ``index``.

        The image is preprocessed by ``self.detector_klass.preprocess`` and
        converted to a float tensor. The target is the position of
        ``self.label`` in ``self.labels``, as a ``torch.long`` scalar.

        Raises:
            Exception: if ``label`` or ``labels`` is missing, or if no detector
                class was found. All are checked before any download.
            ValueError: if ``self.label`` is not contained in ``self.labels``.
        """
        if not self.label:
            raise Exception("No label, did you forget to pass it in?")
        if not self.labels:
            raise Exception("No labels, did you forget to pass them in?")
        if self.detector_klass is None:
            raise Exception(
                "No detector class found for game={} type={}".format(
                    self.game, self.type))
        # Resolve the target first so a bad label fails before any network I/O.
        label_index = self.labels.index(self.label)

        # get_raw_image returns (image, key); only the image is preprocessed.
        image, _key = self.get_raw_image(index)
        image = self.detector_klass.preprocess(image)
        image = torch.from_numpy(image).float()
        return (image, torch.tensor(label_index, dtype=torch.long))
