from urllib.parse import parse_qs

from bson.objectid import ObjectId
import pymongo

from yeller.classes.mongo import MongoHeler
from yeller.logger import log
from bebo import dispatcher
from bebo import utils as bebo_utils

# Single MongoHeler instance created at import time and shared by every
# resource class in this module (collection lookup, query preparation,
# write preparation and result serialization all go through it).
# NOTE: `MongoHeler` (sic) is the name actually exported by
# yeller.classes.mongo, so it is spelled that way here on purpose.
mongo_helper = MongoHeler()


class DbResource:
    """Falcon resource exposing read/insert/update/delete on ``<db>/<collection>``.

    ``db`` and ``collection`` are supplied by the route; every handler resolves
    the target collection through the module-level ``mongo_helper``.
    """

    @staticmethod
    def _int_param(qs, name, default):
        """Return parameter *name* from *qs* as a non-negative int.

        ``parse_qs`` maps each key to a *list* of strings (e.g. ``['10']``), so
        the first value is used; a plain scalar is also accepted. Missing,
        non-numeric or negative values yield *default*.
        """
        value = qs.get(name)
        if isinstance(value, list):
            value = value[0] if value else None
        if value is None:
            return default
        try:
            number = int(value)
        except (TypeError, ValueError):
            return default
        return number if number >= 0 else default

    def on_get(self, req, resp, db, collection):
        """Return a page of documents from the collection, newest first.

        Originally described as listing images that haven't been classified,
        optionally filtered by model. The parsed query string is handed to
        ``mongo_helper.prepare_query`` to build the filter; ``limit``
        (default 20) and ``offset`` (default 0) page through results sorted
        by ``created_at`` descending.
        """
        qs = parse_qs(req.query_string)
        # Defaults are written back into qs exactly as before so that
        # prepare_query keeps receiving the same input it always has.
        qs['limit'] = qs.get('limit', 20)
        qs['offset'] = qs.get('offset', 0)
        # parse_qs yields lists of strings; pymongo's skip()/limit() accept
        # only ints, so passing the raw value raised TypeError whenever the
        # client actually supplied limit/offset.
        limit = self._int_param(qs, 'limit', 20)
        offset = self._int_param(qs, 'offset', 0)

        log.info("got qs {}".format(qs))
        log.info("got db {} collection {}".format(db, collection))

        query = mongo_helper.prepare_query(qs)
        log.info("got query {}".format(query))

        cursor = (mongo_helper.get_collection(db, collection)
                  .find(query)
                  .sort("created_at", pymongo.DESCENDING)
                  .skip(offset)
                  .limit(limit))
        resp.body = mongo_helper.serialize_result(list(cursor))

    def on_post(self, req, resp, db, collection):
        """Insert the JSON request body as a new document.

        The body is passed through ``mongo_helper.prepare_for_write`` first
        (presumably mutating it in place, since the same object is then
        inserted — confirm). Responds with the new document's ``_id``.
        """
        mongo_helper.prepare_for_write(req.media)
        inserted = mongo_helper.get_collection(db, collection).insert_one(req.media)
        resp.body = mongo_helper.serialize_result([{"_id": inserted.inserted_id}])

    def on_put(self, req, resp, db, collection):
        """Set ``label`` (and ``extra`` when present) on the document ``id``.

        The update is an upsert. When ``db == 'relabel'`` and the label
        changed relative to ``old_label``, the same ``$set`` is also applied to
        the document whose ``key`` matches in the collection named by ``key``
        (``"<db>/<collection>/..."``). Finally a best-effort tracking event is
        dispatched; failures there are logged and never fail the request.

        Raises ``KeyError`` if ``id`` or ``label`` is missing from the body.
        """
        body = req.media
        id_for_tracking = body["id"]
        bson_id = ObjectId(id_for_tracking)
        mongo_helper.prepare_for_write(body, set_created=False)

        query = {"_id": bson_id}
        update = {"$set": {'label': body['label']}}
        if 'extra' in body:
            update['$set']['extra'] = body['extra']
        log.info("query {}, update {}".format(query, update))
        mongo_helper.get_collection(db, collection).update_one(
            query, update, upsert=True)

        if db == 'relabel' and body['label'] != body.get('old_label'):
            log.info('relabelling {}'.format(body))

            key = body['key']
            # key is expected to look like "<db>/<collection>/..."; a key
            # without a '/' raises IndexError just as before.
            key_parts = key.split('/')
            rdb = key_parts[0]
            rcol = key_parts[1]

            res = mongo_helper.client[rdb][rcol].update_one({'key': key}, update)

            if res.modified_count:
                log.info('relabel modified {}'.format(body))

        # Tracking is best-effort: collection names are assumed to follow
        # "<game>_<event_type>_<original_label>"; anything else is logged.
        try:
            collection_split = collection.split("_")
            game = collection_split[0]
            event_type = collection_split[1]
            original_label = collection_split[2]
            label = body["label"]
            track_payload = {
                "image_id": id_for_tracking,
                "game": game,
                "type": event_type,
                "original_label": original_label,
                "classified_label": label,
                "correct": original_label == label,
            }
            es_payload = bebo_utils.convert_to_es(track_payload)
            dispatcher.write("yeller_es", es_payload)
        except Exception:
            log.exception("error tracking classification")

    def on_delete(self, req, resp, db, collection):
        """Delete the document whose ``_id`` is the body's ``id``.

        Responds with ``{"deleted_count": n}``.
        """
        doc_id = req.media.get('id')
        if doc_id is None:
            # ObjectId(None) would mint a fresh random id that can never match
            # an existing document; answer directly instead of querying.
            log.info("delete without id on db {} collection {}".format(db, collection))
            resp.body = mongo_helper.serialize_result([{"deleted_count": 0}])
            return
        result = mongo_helper.get_collection(
            db, collection).delete_one({"_id": ObjectId(doc_id)})
        resp.body = mongo_helper.serialize_result(
            [{"deleted_count": result.deleted_count}])


class DbQuery:
    """Query ``<db>/<collection>`` with a filter supplied in the JSON body."""

    def on_post(self, req, resp, db, collection):
        """Return matching documents, either newest-first paged or random.

        Body fields: ``qs`` (required, handed to ``prepare_query``),
        ``limit`` (default 20), ``offset`` (default 0, ignored when random)
        and ``random`` (when truthy, a ``$sample`` of ``limit`` documents).
        """
        payload = req.media
        qs = payload['qs']
        limit = int(payload.get('limit', 20))
        offset = int(payload.get('offset', 0))

        log.info("got qs {}".format(qs))
        log.info("got db {} collection {}".format(db, collection))

        query = mongo_helper.prepare_query(qs)
        log.info("got query {}".format(query))

        target = mongo_helper.get_collection(db, collection)
        if not payload.get('random', False):
            ordered = target.find(query).sort("created_at", pymongo.DESCENDING)
            cursor = ordered.skip(offset).limit(limit)
        else:
            pipeline = [
                {'$match': query},
                {'$sample': {'size': limit}},
            ]
            cursor = target.aggregate(pipeline)

        resp.body = mongo_helper.serialize_result(list(cursor))


class DbCount:
    """Count documents in ``<db>/<collection>`` matching a filter."""

    def on_post(self, req, resp, db, collection):
        """Respond with ``{'result': <count>}`` for the body's ``qs`` filter.

        ``qs`` is converted into a Mongo filter by ``prepare_query``.
        """
        filter_doc = mongo_helper.prepare_query(req.media['qs'])
        target = mongo_helper.get_collection(db, collection)
        resp.media = {'result': target.count_documents(filter_doc)}
