import functools
import time

import bson
import pymongo.collection
import pymongo.cursor


# Public API: install_tracker() patches pymongo with the timing wrappers,
# register_event_handler() subscribes a callback to the events they emit.
__all__ = ("install_tracker", "register_event_handler")


class OriginalMethods(object):
    """Holder for the unpatched pymongo methods, captured when this module is imported.

    The tracking wrappers delegate to these attributes, so the real
    implementations stay reachable after ``install_tracker`` has replaced
    the corresponding attributes on the pymongo classes.
    """
    # NOTE(review): Collection.insert/update/remove are pymongo's legacy CRUD
    # API, and Collection._count / Cursor._refresh are private internals.
    # Confirm they exist on the pymongo version in use; if one is missing,
    # this class body raises AttributeError at import time.
    insert = pymongo.collection.Collection.insert
    update = pymongo.collection.Collection.update
    remove = pymongo.collection.Collection.remove
    _count = pymongo.collection.Collection._count
    _refresh = pymongo.cursor.Cursor._refresh


# Callbacks that send_event invokes, in list order, with each event dict.
# register_event_handler appends to this list.
_event_handlers = []


def install_tracker():
    """Replace pymongo's Collection/Cursor methods with the timing wrappers.

    Patches Collection.insert, update, remove and _count, plus
    Cursor._refresh. An attribute is only reassigned when it currently
    differs from its wrapper, so calling this more than once is harmless.
    """
    collection_cls = pymongo.collection.Collection
    cursor_cls = pymongo.cursor.Cursor
    patches = (
        (collection_cls, "insert", _insert),
        (collection_cls, "update", _update),
        (collection_cls, "remove", _remove),
        (collection_cls, "_count", _count),
        (cursor_cls, "_refresh", _refresh),
    )
    for owner, attr_name, wrapper in patches:
        if getattr(owner, attr_name) != wrapper:
            setattr(owner, attr_name, wrapper)


def register_event_handler(func):
    """Subscribe ``func`` to tracking events.

    ``func`` is called with a single event dict each time a tracked pymongo
    operation completes. Handlers run in registration order.

    Returns ``func`` unchanged, so this also works as a decorator::

        @register_event_handler
        def log_event(event):
            ...

    Raises:
        TypeError: if ``func`` is not callable. Checking here surfaces the
            mistake at registration time. Otherwise it would only fail later,
            inside every tracked database call, after the operation had
            already run.
    """
    if not callable(func):
        raise TypeError(
            "event handler must be callable, got {!r}".format(func))
    _event_handlers.append(func)
    return func


# Event dispatcher: forwards each event to every handler registered through
# `register_event_handler`.
def send_event(event):
    """Call each registered handler with ``event``, in registration order.

    A handler that raises stops dispatch to the remaining handlers, and the
    exception propagates to the caller.
    """
    # NOTE(review): callers are the tracking wrappers, which run this after
    # the database operation has already completed. A failing handler
    # therefore surfaces as an error from an operation that succeeded.
    handlers = _event_handlers
    for notify in handlers:
        notify(event)


# Timing wrapper installed over Collection.insert.
@functools.wraps(OriginalMethods.insert)
def _insert(collection_self, *args, **kwargs):
    """Run the original Collection.insert and report how long it took.

    All arguments are forwarded unchanged. After the insert returns, the
    event ``{"type": "insert", "time": <seconds>}`` is sent and the
    original result is returned. If the insert raises, no event is sent.
    """
    began = time.time()
    outcome = OriginalMethods.insert(collection_self, *args, **kwargs)
    elapsed = time.time() - began
    send_event({"type": "insert", "time": elapsed})
    return outcome


# Timing wrapper installed over Collection.update.
@functools.wraps(OriginalMethods.update)
def _update(collection_self, spec, document, *args, **kwargs):
    """Delegate to the original Collection.update and report its duration.

    After the update returns, the event
    ``{"type": "update", "spec": spec, "document": document,
    "time": <seconds>}`` is sent and the original result is returned
    unchanged. If the update raises, no event is sent.
    """
    t0 = time.time()
    ret = OriginalMethods.update(
        collection_self, spec, document, *args, **kwargs)
    duration = time.time() - t0
    event = dict(type="update", spec=spec, document=document, time=duration)
    send_event(event)
    return ret


# Timing wrapper installed over Collection.remove.
@functools.wraps(OriginalMethods.remove)
def _remove(collection_self, spec_or_id=None, *args, **kwargs):
    """Run the original Collection.remove and emit a "remove" event.

    ``spec_or_id`` defaults to None so that ``collection.remove()`` (no
    filter) keeps working once the tracker is installed. The wrapper
    previously made it a required argument, so such calls raised TypeError.
    This presumably mirrors pymongo's own ``spec_or_id=None`` default;
    confirm against the installed pymongo version.

    After the remove returns, the event
    ``{"type": "remove", "spec_or_id": spec_or_id, "time": <seconds>}`` is
    sent and the original result is returned. If the remove raises, no
    event is sent.
    """
    start_time = time.time()
    result = OriginalMethods.remove(
        collection_self, spec_or_id,
        *args, **kwargs
    )
    total_time = time.time() - start_time

    send_event({
        "type": "remove",
        "spec_or_id": spec_or_id,
        "time": total_time,
    })
    return result


# Timing wrapper installed over Collection._count.
@functools.wraps(OriginalMethods._count)
def _count(collection_self, *args, **kwargs):
    """Time the original Collection._count and emit a "count" event.

    Arguments are forwarded untouched. The event carries only the elapsed
    time, e.g. ``{"type": "count", "time": 0.0021}``. The original result
    is returned, and no event is sent if the call raises.
    """
    t_begin = time.time()
    counted = OriginalMethods._count(collection_self, *args, **kwargs)
    t_spent = time.time() - t_begin
    send_event(dict(type="count", time=t_spent))
    return counted


def _cursor_private(cursor, name):
    """Return the name-mangled attribute ``Cursor.__<name>`` of ``cursor``, or None."""
    return getattr(cursor, "_Cursor__{}".format(name), None)


def _normalize_query_son(cursor, query_son):
    """Return ``query_son`` in ``{"$query": ..., <modifiers>}`` form.

    A ``bson.SON`` spec is returned unchanged. Any other mapping is
    shallow-copied (if it already has "$query") or wrapped under "$query".
    Truthy cursor attributes are then merged in as "data", "$orderby",
    "$hint", "$snapshot" and "$maxScan". The object passed in is never
    modified.
    """
    if isinstance(query_son, bson.SON):
        return query_son
    if "$query" in query_son:
        # Copy first: query_son may be the cursor's own spec object
        # (presumably what pymongo returns when no modifiers are set).
        # Writing into it would leak these keys into the cursor's state.
        query_son = dict(query_son)
    else:
        query_son = {"$query": query_son}

    modifiers = (
        ("data", "data"),
        ("$orderby", "ordering"),
        ("$hint", "hint"),
        ("$snapshot", "snapshot"),
        ("$maxScan", "max_scan"),
    )
    for son_key, attr_name in modifiers:
        value = _cursor_private(cursor, attr_name)
        if value:
            query_son[son_key] = value
    return query_son


def _format_ordering(son):
    """Return ``son["$orderby"]`` formatted like "+name, -age", or None if absent.

    Direction 1 is rendered as "+"; any other direction value as "-".
    """
    if "$orderby" not in son:
        return None
    return ", ".join(
        "{}{}".format("+" if direction == 1 else "-", field)
        for field, direction in son["$orderby"].items()
    )


# Wrap Cursor._refresh so that each new query (not a getMore) emits an event.
@functools.wraps(OriginalMethods._refresh)
def _refresh(cursor_self):
    """Run the original Cursor._refresh and emit a "query" event for new queries.

    If the cursor already has an id (a getMore), the call is passed straight
    through and no event is sent. Otherwise the original refresh is timed
    and an event dict is sent with these keys:

    - "type" and "time";
    - "operation": one of "query", "command" or "count";
    - "collection";
    - where available, "skip", "limit" and "query", plus "ordering" for
      normal queries.

    The original return value is returned.
    """
    # Read double-underscore (name-mangled) Cursor attributes.
    privar = functools.partial(_cursor_private, cursor_self)

    if privar("id") is not None:
        # getMore, not a query - move on.
        return OriginalMethods._refresh(cursor_self)

    # NOTE: See _refresh() in pymongo/cursor.py and pymongo/message.py for
    # where this information is stored.

    # Time only the actual query.
    start_time = time.time()
    result = OriginalMethods._refresh(cursor_self)
    total_time = time.time() - start_time

    query_son = _normalize_query_son(cursor_self, privar("query_spec")())
    # The query/command document with any "$query" wrapper removed. A SON
    # spec may have no "$query" key at all; indexing it directly used to
    # raise KeyError.
    query_doc = query_son.get("$query", query_son)

    # full_name is "<db_name>.<collection_name>". Collection names may
    # contain dots themselves (e.g. GridFS's "fs.files"), so split only on
    # the first dot.
    collection = privar("collection")
    collection_name = collection.full_name.split(".", 1)[1]

    query_data = {
        "type": "query",
        "time": total_time,
        "operation": "query",
        "collection": collection_name,
    }

    if collection_name == "$cmd":
        query_data["operation"] = "command"
        # Count is a special case: report it against the counted collection.
        # Its arguments live in the command document, not on the cursor.
        if "count" in query_doc:
            query_data["collection"] = query_doc["count"]
            query_data["operation"] = "count"
            query_data["skip"] = query_doc.get("skip")
            query_data["limit"] = query_doc.get("limit")
            query_data["query"] = query_doc.get("query")
    else:
        # Normal query.
        query_data["skip"] = privar("skip")
        query_data["limit"] = privar("limit")
        query_data["query"] = query_doc
        query_data["ordering"] = _format_ordering(query_son)

    send_event(query_data)

    return result
