import Queue
import datetime as dt
import threading as th

from pymongo import errors
from mongoengine import connection

from sandbox.common import console
import sandbox.common.types.resource as ctr

from . import base


class LastUsageTime(base.UpgradeStep):
    """
    Moves `last_usage_time` attribute from per-host collection to the root of the document.

    Stages:
      * `pre`  -- creates the `time.at` index and fills `time.ct`/`time.at` from `ctime` for
                  resources without hosts that have no `time.at` yet;
      * `main` -- fills `time.ct`/`time.at` for every resource still lacking `time.at`, using the
                  latest per-host `hosts.at` value (or `ctime` when there is none) as `time.at`;
      * `post` -- drops the `hosts.at` and `ctime` fields and the `hosts.at_1` index.
    """

    #: Number of worker threads issuing updates concurrently.
    THREADS = 20
    #: Capacity of the work queue; the producer blocks once this many updates are pending.
    MAX_QUEUE_SIZE = 200

    @classmethod
    def worker(cls, queue, collection):
        """
        Consume `(rid, ctime, atime)` tuples from `queue` and store them as `time.ct`/`time.at`.

        Returns when a falsy item (the `None` stop sentinel posted by `_process`) is received.
        NOTE(review): an exception raised by `collection.update` terminates this worker thread.
        """
        while True:
            data = queue.get()
            if not data:
                break
            rid, ctime, atime = data
            collection.update(
                {"_id": rid},
                {"$set": {"time.ct": ctime, "time.at": atime}}
            )

    def _process(self, collection, resources):
        """
        Apply all `resources` updates to `collection` using a pool of `THREADS` workers.

        :param collection: resource collection to update
        :param resources: list of `(rid, ctime, atime)` tuples
        """
        queue = Queue.Queue(self.MAX_QUEUE_SIZE)
        pool = [th.Thread(target=self.worker, args=(queue, collection)) for _ in xrange(self.THREADS)]
        for thread in pool:
            thread.start()

        try:
            pbar = console.ProgressBar("Updating data", len(resources))
            for i, data in enumerate(resources):
                queue.put(data)
                pbar.update(i)
            pbar.finish()
        finally:
            # Always post one stop sentinel per worker: the workers are non-daemon threads blocked
            # on `queue.get()`, so if the producer loop above is interrupted without them, the
            # interpreter would wait for those threads forever on exit.
            with console.LongOperation("Waiting for workers."):
                for _ in pool:
                    queue.put(None)
                for thread in pool:
                    thread.join()

    def pre(self):
        """ Index `time.at` and migrate host-less resources, whose access time is simply `ctime`. """
        collection = connection.get_db()["resource"]
        with console.LongOperation("Creating new index"):
            collection.create_index("time.at")

        with console.LongOperation("Calculating resources count which can be safely updated"):
            cursor = collection.find(
                {
                    # Resources with no hosts have no per-host `at` to aggregate.
                    "$or": [
                        {"hosts": {"$size": 0}},
                        {"hosts": {"$exists": False}}
                    ],
                    "time.at": {"$exists": False}
                },
                {"_id": True, "ctime": True}
            )
            amount = cursor.count()

        resources = []
        pbar = console.ProgressBar("Fetching data", amount)
        for i, doc in enumerate(cursor):
            resources.append((doc["_id"], doc["ctime"], doc["ctime"]))
            pbar.update(i)
        pbar.finish()
        self._process(collection, resources)

    def main(self):
        """ Fill `time.ct`/`time.at` for all resources that still lack `time.at`. """
        collection = connection.get_db()["resource"]

        with console.LongOperation("Calculating resource record to be updated"):
            cursor = collection.find(
                {"time.at": {"$exists": False}},
                {"_id": True, "ctime": True, "hosts.at": True}
            )
            amount = cursor.count()

        resources = []
        pbar = console.ProgressBar("Fetching data", amount)
        for i, doc in enumerate(cursor):
            # With the `hosts.at` projection a host entry lacking `at` comes back without that key;
            # skip such entries instead of failing with `KeyError`.
            host_atimes = [h["at"] for h in (doc.get("hosts") or ()) if "at" in h]
            atime = max(host_atimes) if host_atimes else doc["ctime"]
            resources.append((doc["_id"], doc["ctime"], atime))
            pbar.update(i)
        pbar.finish()

        self._process(collection, resources)

    def post(self):
        """ Drop the `hosts.at` and `ctime` fields and the `hosts.at_1` index. """
        collection = connection.get_db()["resource"]
        with console.LongOperation("Dropping 'hosts.at' field. This can take a while") as op:
            # The positional `$` operator unsets `at` only in the first matching host entry of each
            # document, so repeat until no document has a host entry with `at` left.
            while True:
                updated = collection.update(
                    {"hosts.at": {"$exists": True}},
                    {"$unset": {"hosts.$.at": True}},
                    multi=True
                )
                if not updated["n"]:
                    break
                op.intermediate("Dropped {} records".format(updated['n']))

        with console.LongOperation("Dropping 'ctime' field. This can take a while"):
            collection.update({}, {"$unset": {"ctime": True}}, multi=True)

        with console.LongOperation("Dropping 'hosts.at_1' index"):
            try:
                collection.drop_index("hosts.at_1")
            except errors.OperationFailure:
                # Best effort: the index may already be absent.
                pass


class Expires(base.UpgradeStep):
    """ Add the `time.ex` (expiration time) attribute for all non-immortal ready resources. """

    #: Number of worker threads issuing updates concurrently.
    THREADS = 20
    #: Capacity of the work queue; the producer blocks once this many updates are pending.
    MAX_QUEUE_SIZE = 200
    #: Offset added to the last usage time (`time.at`) to compute `time.ex`.
    EXPIRATION_PERIOD = dt.timedelta(days=14)

    @classmethod
    def worker(cls, queue, collection):
        """
        Consume `(rid, atime)` tuples from `queue` and set `time.ex = atime + EXPIRATION_PERIOD`.

        Returns when a falsy item (the `None` stop sentinel posted by `_process`) is received.
        NOTE(review): an exception raised by `collection.update` terminates this worker thread.
        """
        while True:
            data = queue.get()
            if not data:
                break
            rid, atime = data
            collection.update(
                {"_id": rid},
                {"$set": {"time.ex": atime + cls.EXPIRATION_PERIOD}}
            )

    def _process(self, collection, resources):
        """
        Apply all `resources` updates to `collection` using a pool of `THREADS` workers.

        :param collection: resource collection to update
        :param resources: list of `(rid, atime)` tuples
        """
        queue = Queue.Queue(self.MAX_QUEUE_SIZE)
        pool = [th.Thread(target=self.worker, args=(queue, collection)) for _ in xrange(self.THREADS)]
        for thread in pool:
            thread.start()

        try:
            pbar = console.ProgressBar("Updating data", len(resources))
            for i, data in enumerate(resources):
                queue.put(data)
                pbar.update(i)
            pbar.finish()
        finally:
            # Always post one stop sentinel per worker: the workers are non-daemon threads blocked
            # on `queue.get()`, so if the producer loop above is interrupted without them, the
            # interpreter would wait for those threads forever on exit.
            with console.LongOperation("Waiting for workers."):
                for _ in pool:
                    queue.put(None)
                for thread in pool:
                    thread.join()

    def main(self):
        # Nothing to do at this stage; all the work happens in `post`.
        pass

    def post(self):
        """ Compute `time.ex` from `time.at` for ready resources without released/TTL attributes. """
        collection = connection.get_db()["resource"]

        with console.LongOperation("Calculating resource records to be updated"):
            cursor = collection.find(
                {
                    "state": "READY",
                    # Exclude resources carrying the released or TTL service attribute.
                    "attrs.k": {"$nin": (
                        ctr.ServiceAttributes.RELEASED,
                        ctr.ServiceAttributes.TTL,
                    )},
                },
                {"_id": True, "time.at": True}
            )
            amount = cursor.count()

        resources = []
        pbar = console.ProgressBar("Fetching data", amount)
        for i, doc in enumerate(cursor):
            # NOTE(review): assumes every matched resource already has `time.at`
            # (presumably filled by the `LastUsageTime` step) -- otherwise this raises `KeyError`.
            resources.append((doc["_id"], doc["time"]["at"]))
            pbar.update(i)
        pbar.finish()

        self._process(collection, resources)


class ReleasedTTL(base.UpgradeStep):
    """ Set `ttl` attribute to "inf" for released resources. """

    #: Number of worker threads issuing updates concurrently.
    THREADS = 20
    #: Capacity of the work queue; the producer blocks once this many updates are pending.
    MAX_QUEUE_SIZE = 200

    @classmethod
    def worker(cls, queue, collection):
        """
        Consume `(rid, attrs)` tuples from `queue` and store `attrs` with `ttl` set to "inf".

        An existing `ttl` attribute is overwritten in place; otherwise a new one is appended.
        Returns when a falsy item (the `None` stop sentinel posted by `_process`) is received.
        NOTE(review): an exception raised by `collection.update` terminates this worker thread.
        """
        while True:
            data = queue.get()
            if not data:
                break
            rid, attrs = data
            ttl_attr = next((attr for attr in attrs if attr["k"] == "ttl"), None)
            if ttl_attr is not None:
                ttl_attr["v"] = "inf"
            else:
                attrs.append({"k": "ttl", "v": "inf"})
            collection.update({"_id": rid}, {"$set": {"attrs": attrs}})

    def _process(self, collection, resources):
        """
        Apply all `resources` updates to `collection` using a pool of `THREADS` workers.

        :param collection: resource collection to update
        :param resources: list of `(rid, attrs)` tuples
        """
        queue = Queue.Queue(self.MAX_QUEUE_SIZE)
        pool = [th.Thread(target=self.worker, args=(queue, collection)) for _ in xrange(self.THREADS)]
        for thread in pool:
            thread.start()

        try:
            pbar = console.ProgressBar("Updating data", len(resources))
            for i, data in enumerate(resources):
                queue.put(data)
                pbar.update(i)
            pbar.finish()
        finally:
            # Always post one stop sentinel per worker: the workers are non-daemon threads blocked
            # on `queue.get()`, so if the producer loop above is interrupted without them, the
            # interpreter would wait for those threads forever on exit.
            with console.LongOperation("Waiting for workers."):
                for _ in pool:
                    queue.put(None)
                for thread in pool:
                    thread.join()

    def main(self):
        # Nothing to do at this stage; all the work happens in `post`.
        pass

    def post(self):
        """ Mark ready resources with a `released` attribute that belong to released tasks. """
        with console.LongOperation("Fetching released tasks IDs"):
            tids = [
                task["_id"]
                for task in connection.get_db()["task"].find({"exc.st": "RELEASED"}, {"_id": True})
            ]

        collection = connection.get_db()["resource"]
        with console.LongOperation("Fetching resource records to be updated"):
            cursor = collection.find(
                {
                    "state": "READY",
                    "attrs.k": "released",
                    "tid": {"$in": tids}
                },
                {"_id": True, "attrs": True}
            )
            amount = cursor.count()

        resources = []
        pbar = console.ProgressBar("Fetching data", amount)
        for i, doc in enumerate(cursor):
            resources.append((doc["_id"], doc["attrs"]))
            pbar.update(i)
        pbar.finish()

        self._process(collection, resources)
