# -*- coding: utf-8 -*-

from __future__ import absolute_import


import time

from celery.app.amqp import TaskProducer
from kombu.messaging import Consumer


def setup_producer_logging():
    """Monkey-patch ``TaskProducer._publish`` so each publish is written to the requests log.

    The previous ``TaskProducer._publish`` is kept and called with exactly the
    same arguments; the replacement only times that call and then emits one
    line (host, vhost, queue label, body, body length, elapsed seconds) to the
    log returned by ``mpfs.engine.process.get_requests_log()``.
    """
    import mpfs.engine.process
    requests_log = mpfs.engine.process.get_requests_log()

    original_publish = TaskProducer._publish

    def _publish(self, body, priority, content_type, content_encoding, headers, properties, routing_key, mandatory,
                 immediate, exchange, declare):
        started_at = time.time()
        outcome = original_publish(self, body, priority, content_type, content_encoding, headers, properties,
                                   routing_key, mandatory, immediate, exchange, declare)
        elapsed = time.time() - started_at

        # The first declared entity's ``name`` (when available) is used as the
        # queue label; otherwise the line shows '?'.
        queue_name = '?'
        if isinstance(declare, list) and declare:
            queue_name = getattr(declare[0], 'name', '?')

        connection = self.connection
        line = 'amqp://%s/%s.%s.%s(%s) %d %.3f' % (connection.host, connection.virtual_host, queue_name,
                                                   'publish', body, len(body), elapsed)
        requests_log.info(line)
        return outcome

    TaskProducer._publish = _publish


def _consumer_queue_name(consumer, message):
    """Return the queue label for *message* received by *consumer*, or ``'?'``.

    The message's ``delivery_info['consumer_tag']`` is compared against the
    values of ``consumer._active_tags``; the key of the matching entry is
    returned. A missing/``None`` ``delivery_info`` or ``_active_tags``, an
    empty tag, or no match all yield ``'?'``.
    """
    delivery_info = getattr(message, 'delivery_info', None) or {}
    consumer_tag = delivery_info.get('consumer_tag')
    if not consumer_tag:
        return '?'
    active_tags = getattr(consumer, '_active_tags', None) or {}
    # .items() instead of .iteritems() so this works on Python 2 and 3.
    for queue_name, tag in active_tags.items():
        if tag == consumer_tag:
            return queue_name
    return '?'


def _lookup(container, key):
    """Return ``container.get(key)`` if *container* has ``get``, else ``None``."""
    return container.get(key) if hasattr(container, 'get') else None


def _task_log_context(body):
    """Extract ``(ycrid, task_id, task_short_name)`` from a task message body.

    Reads ``body['kwargs']['context']['ycrid']``, ``body['id']`` and
    ``body['task']`` (only the part after the last '.' is kept). Missing or
    ``None`` entries, and bodies without a ``get`` method, fall back to
    ``None`` (``'None'`` for the task name) instead of raising.
    """
    ycrid = _lookup(_lookup(_lookup(body, 'kwargs'), 'context'), 'ycrid')
    task_id = _lookup(body, 'id')
    task_name = _lookup(body, 'task')
    if task_name is None:
        task_name = 'None'
    return ycrid, task_id, task_name.rsplit('.', 1)[-1]


def setup_consumer_logging():
    """Monkey-patch ``Consumer.receive`` to set request ids and log each received message.

    Before delegating to the previous ``Consumer.receive``, the wrapper sets
    the process request ids from the message body
    (``set_cloud_req_id(ycrid)``, ``set_req_id('<task_id>-<task_short_name>')``)
    and calls ``reset_cached()``. After it returns, one line (host, vhost,
    queue label, body, body length, elapsed seconds) is written to the log
    returned by ``mpfs.engine.process.get_requests_log()``.
    """
    import mpfs.engine.process
    requests_log = mpfs.engine.process.get_requests_log()

    _wrapped_receive = Consumer.receive

    def receive(self, body, message):
        queue = _consumer_queue_name(self, message)

        host, vhost = None, None
        if self.connection:
            host, vhost = self.connection.host, self.connection.virtual_host

        # Tolerant extraction: an unexpected body layout (non-dict payload,
        # None in place of nested dicts) must not make this logging hook break
        # message consumption before the real receive() runs.
        ycrid, task_id, task_short_name = _task_log_context(body)

        mpfs.engine.process.set_cloud_req_id(ycrid)
        mpfs.engine.process.set_req_id('%s-%s' % (task_id, task_short_name))
        mpfs.engine.process.reset_cached()

        ptime = time.time()
        result = _wrapped_receive(self, body, message)
        ptime = time.time() - ptime

        # Stringified after receive() returns, as before, and only once.
        body_text = str(body)
        requests_log.info('amqp://%s/%s.%s.%s(%s) %d %.3f' % (host, vhost, queue, 'receive', body_text,
                                                              len(body_text), ptime))

        return result

    Consumer.receive = receive
