import os
import uuid
import collections
from random import shuffle
from urllib.request import urlopen
import multiprocessing
import tempfile
import itertools
import math

from torch.utils import data
import boto3
import botocore
import cv2
import numpy as np
import pymongo
import torch
import tqdm

from bran.shared.config import config
from bran.shared.logger import log


class MongoDatasetCreator(object):
    """Build a train/val/test dataset collection in MongoDB.

    Labelled documents are read from every collection of the
    ``pre_classify`` database whose name starts with ``dataset``. They are
    shuffled, split, and written to the ``output_dataset`` collection of the
    ``dataset`` database, each tagged with a ``split_type`` field.

    Keyword Args:
        dataset: Collection-name prefix selecting the source collections
            (required).
        output_dataset: Name of the destination collection. Defaults to
            ``dataset``.
        labels: Label values to keep (required); used in an ``$in`` filter
            on the ``label`` field.
        val_split: Fraction of the records assigned to the validation
            split. Defaults to 0.1.
        test_split: Fraction of the records assigned to the test split.
            Defaults to 0.1.

    Raises:
        KeyError: If ``dataset`` or ``labels`` is not supplied.
    """

    def __init__(self, **kwargs):
        self.dataset = kwargs["dataset"]
        self.output_dataset = kwargs.get('output_dataset', self.dataset)
        self.labels = kwargs['labels']
        self.val_split = kwargs.get('val_split', 0.1)
        self.test_split = kwargs.get('test_split', 0.1)
        self.client = pymongo.MongoClient(config.MONGO_DB)

    def get_labelled_data(self, query):
        """Collect labelled documents from the matching source collections.

        Args:
            query: Extra MongoDB filter merged on top of the label filter.
                A ``label`` key in ``query`` replaces the default
                ``{"$in": self.labels}`` condition. ``None`` or an empty
                mapping means "no extra filter".

        Returns:
            A list of dicts with the keys ``bucket``, ``key``, ``label`` and
            ``extra`` (``None`` when the source document has no ``extra``).
        """
        q = {"label": {"$in": self.labels}}
        if query:
            q.update(query)

        labeled_data = []
        # Only 'pre_classify' is scanned; an 'upload' database used to be
        # listed next to it but is currently disabled.
        for db in ['pre_classify']:
            # list_collection_names() replaces collection_names(), which was
            # deprecated and later removed from PyMongo.
            for collection in self.client[db].list_collection_names():
                if not collection.startswith(self.dataset):
                    continue

                mongo = self.client[db][collection]
                log.info("Loading {} documents from {}/{}".format(
                    mongo.count_documents(q),
                    db, collection))

                labeled_data.extend(
                    {
                        "bucket": doc['bucket_name'],
                        "key": doc['key'],
                        "label": doc['label'],
                        "extra": doc.get('extra'),
                    }
                    for doc in mongo.find(q)
                )

        return labeled_data

    @staticmethod
    def _split_records(records, val_split, test_split):
        """Slice ``records`` into consecutive (train, val, test) lists.

        The val and test sizes are ``int(len(records) * fraction)``, i.e.
        truncated towards zero; the train split receives the remainder and
        comes first in the original order.
        """
        total = len(records)
        val_len = int(total * val_split)
        test_len = int(total * test_split)
        train_len = total - val_len - test_len

        train_data = records[:train_len]
        val_data = records[train_len: train_len + val_len]
        test_data = records[train_len + val_len:]
        return train_data, val_data, test_data

    def populate(self, query=None, size=None):
        """Rebuild the output collection from freshly shuffled labelled data.

        The existing output collection is dropped before the new documents
        are inserted.

        Args:
            query: Optional extra MongoDB filter forwarded to
                ``get_labelled_data``. Defaults to no extra filter.
            size: Optional cap on the number of records kept after
                shuffling. ``None``, zero, or a negative value keeps all.
        """
        # A None default avoids sharing one mutable dict across calls.
        query = {} if query is None else query

        log.info("MongoDB: {}".format(config.MONGO_DB))
        log.info("New collection: {}".format(self.output_dataset))
        mongo = self.client["dataset"][self.output_dataset]

        all_labeled_data = self.get_labelled_data(query)
        shuffle(all_labeled_data)

        if size and size > 0:
            all_labeled_data = all_labeled_data[:size]

        train_data, val_data, test_data = self._split_records(
            all_labeled_data, self.val_split, self.test_split)

        log.info("Train set: {}, Val set: {}, Test set: {}".format(
            len(train_data), len(val_data), len(test_data)))

        mongo.drop()
        splits = (
            ('train', train_data),
            ('val', val_data),
            ('test', test_data),
        )
        for split_type, records in splits:
            # insert_many() rejects an empty list, which happens for small
            # datasets or zero split fractions; skip those splits instead of
            # failing after the collection has already been dropped.
            if not records:
                continue
            for r in records:
                r['split_type'] = split_type
            mongo.insert_many(records)

        log.info("Collection {} populated".format(self.output_dataset))
