import pymongo
from bson.binary import Binary
from pybloom_live import ScalableBloomFilter
from pymongo.errors import OperationFailure

from jafar import jafar_mongo
from jafar.feed.base import FeedBlock, FeedPage, FeedMeta
from jafar.filters import MultipleBloomItemFilter

try:
    from cStringIO import StringIO
except:
    from StringIO import StringIO
import datetime
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class FeedDictBackend(object):
    """Map feed objects to and from plain dicts for storage.

    This class performs no I/O itself: it converts ``FeedMeta``,
    ``FeedPage``, ``FeedBlock`` and ``MultipleBloomItemFilter`` instances
    into dicts (bloom filters become BSON ``Binary`` blobs) and back.
    Actual persistence is left to subclasses such as ``FeedMongoBackend``.
    """

    def __init__(self, feed, lifetime_seconds):
        """Create a backend for ``feed``.

        :param feed: the feed whose data this backend stores.
        :param lifetime_seconds: lifetime used to derive ``expire_at`` for
            stored documents that do not carry one.
        """
        self.feed = feed
        self.lifetime_seconds = lifetime_seconds

    def _resolve_expire_at(self, data):
        """Return the expiry time recorded in a stored document.

        Falls back to ``data['generated_at'] + lifetime_seconds`` when
        ``expire_at`` is missing or falsy (e.g. ``None``); ``generated_at``
        is only read in that fallback case.

        :param data: stored meta or page dict.
        """
        expire_at = data.get('expire_at')
        if expire_at:
            return expire_at
        return data['generated_at'] + datetime.timedelta(seconds=self.lifetime_seconds)

    def deserialize_filter(self, data):
        """Rebuild a ``MultipleBloomItemFilter`` from :meth:`serialize_filter` output.

        :param data: dict with ``num_filters`` and ``filters`` (a list of
            serialized ``ScalableBloomFilter`` blobs).
        """
        def deserialize_bloom_filter(binary):
            # ``fromfile`` reads from a file-like object, so copy the blob
            # into an in-memory buffer and rewind it before reading.
            io = StringIO()
            io.write(binary)
            io.seek(0)
            return ScalableBloomFilter.fromfile(io)

        return MultipleBloomItemFilter(
            num_filters=data['num_filters'],
            bloom_filters=[
                deserialize_bloom_filter(binary)
                for binary in data['filters']
            ]
        )

    def deserialize_meta(self, meta):
        """Build a ``FeedMeta`` from a stored meta dict (see :meth:`serialize_meta`)."""
        return FeedMeta(
            page_count=meta['page_count'],
            item_filter=self.deserialize_filter(meta['item_filter']),
            explanation_filter=self.deserialize_filter(meta['explanation_filter']),
            generated_at=meta['generated_at'],
            expire_at=self._resolve_expire_at(meta)
        )

    def deserialize_page(self, page_data):
        """Build a ``FeedPage`` (and its blocks) from a stored page dict."""
        return FeedPage(
            blocks=[self.deserialize_block(block) for block in page_data['blocks']],
            generated_at=page_data['generated_at'],
            expire_at=self._resolve_expire_at(page_data)
        )

    def deserialize_block(self, block_data):
        """Build a ``FeedBlock`` from a stored block dict.

        Optional keys (``rotation_interval``, ``placement_id``,
        ``external_promo_provider``, ``external_promo_ttl``) default to
        ``None``; a missing ``content_type`` defaults to ``'apps'``.
        """
        return FeedBlock(
            items=block_data['items'],
            title=block_data['title'],
            subtitle=block_data['subtitle'],
            algorithm=block_data['algorithm'],
            card_type=block_data['card_type'],
            content_type=block_data.get('content_type', 'apps'),  # TODO: remove when new configs appear in mongo
            explanation=block_data['explanation'],
            reserved=block_data['reserved'],
            rotation_interval=block_data.get('rotation_interval'),
            placement_id=block_data.get('placement_id'),
            external_promo_provider=block_data.get('external_promo_provider'),
            external_promo_ttl=block_data.get('external_promo_ttl')
        )

    def serialize_filter(self, item_filter):
        """Serialize a multi-bloom item filter into a storable dict.

        Each bloom filter in ``item_filter.content`` is written to an
        in-memory buffer and wrapped in a BSON ``Binary``.
        """
        def serialize_bloom_filter(bloom_filter):
            io = StringIO()
            bloom_filter.tofile(io)
            return Binary(io.getvalue())

        return {
            'num_filters': item_filter.num_filters,
            'filters': [
                serialize_bloom_filter(bloom)
                for bloom in item_filter.content
            ]
        }

    def serialize_meta(self, meta):
        """Serialize a ``FeedMeta`` (including both filters) into a dict."""
        return {
            "page_count": meta.page_count,
            "item_filter": self.serialize_filter(meta.item_filter),
            "explanation_filter": self.serialize_filter(meta.explanation_filter),
            "generated_at": meta.generated_at,
            "expire_at": meta.expire_at
        }

    def serialize_page(self, page):
        """Serialize a ``FeedPage`` and all of its blocks into a dict."""
        return {
            "blocks": [self.serialize_block(block) for block in page.blocks],
            "generated_at": page.generated_at,
            "expire_at": page.expire_at
        }

    def serialize_block(self, block):
        """Serialize a single ``FeedBlock`` into a dict of its attributes."""
        return {
            "items": block.items,
            "title": block.title,
            "subtitle": block.subtitle,
            "card_type": block.card_type,
            "content_type": block.content_type,
            "algorithm": block.algorithm,
            "explanation": block.explanation,
            "reserved": block.reserved,
            "rotation_interval": block.rotation_interval,
            "placement_id": block.placement_id,
            "external_promo_provider": block.external_promo_provider,
            "external_promo_ttl": block.external_promo_ttl,
        }


class FeedMongoBackend(FeedDictBackend):
    """Persist feed pages and feed meta in MongoDB.

    Pages are stored in ``jafar_mongo.db.feed`` (one document per page) and
    meta in ``jafar_mongo.db.feed_meta``; both are keyed by
    :meth:`get_feed_args`.  With ``lifetime_seconds == 0`` pages and meta are
    never saved and meta is never loaded.
    """

    # Field holding each document's expiry time; create_indexes puts a TTL
    # index with expireAfterSeconds=0 on it.
    ttl_field = 'expire_at'
    # Maximum number of cached feeds (meta documents) kept per device.
    max_feed_count = 100

    def get_feed_args(self):
        """Return the query that identifies the current feed's documents."""
        return {
            'device_id': self.feed.user.device_id,
            'cache_key': self.feed.cache_key,
            'experiment_name': self.feed.experiment_name
        }

    def load_page(self, page_number):
        """Load page ``page_number`` of the current feed, or ``None`` if it is not stored."""
        query = self.get_feed_args()
        query['page_number'] = page_number
        page = jafar_mongo.db.feed.find_one(query)
        if not page:
            return None
        return self.deserialize_page(page)

    def save_page(self, page):
        """Upsert ``page`` as the next page of the current feed.

        The page is stored under number ``self.feed.meta.page_count + 1``.
        Does nothing when ``lifetime_seconds`` is 0.
        """
        if self.lifetime_seconds == 0:
            logger.debug("Not saving page because lifetime_seconds=0")
            return
        query = self.get_feed_args()
        if self.feed.meta.page_count == 0:
            # No pages recorded yet for this feed: enforce the per-device
            # feed limit before storing its first page.
            self.check_overall_feed_count()
        # Store right after the pages already counted in meta.
        query['page_number'] = self.feed.meta.page_count + 1
        document = self.serialize_page(page)
        jafar_mongo.db.feed.update_one(query, {'$set': document}, upsert=True)

    def load_meta(self):
        """Load the current feed's ``FeedMeta``, or ``None`` if absent or caching is disabled."""
        if self.lifetime_seconds == 0:
            logger.debug("Not trying to load meta because lifetime_seconds=0")
            return None
        query = self.get_feed_args()
        meta = jafar_mongo.db.feed_meta.find_one(query)
        if not meta:
            return None
        return self.deserialize_meta(meta)

    def save_meta(self, meta):
        """Upsert the current feed's ``FeedMeta``; no-op when ``lifetime_seconds`` is 0."""
        if self.lifetime_seconds == 0:
            logger.debug("Not trying to save meta because lifetime_seconds=0")
            return
        document = self.serialize_meta(meta)
        jafar_mongo.db.feed_meta.update_one(self.get_feed_args(), {
            '$set': document
        }, upsert=True)

    def clear(self, query=None):
        """Delete pages and meta matching ``query`` (default: the current feed).

        An empty or ``None`` query falls back to :meth:`get_feed_args`, so
        this never issues an unfiltered delete.
        """
        if not query:
            query = self.get_feed_args()
        jafar_mongo.db.feed.delete_many(query)
        jafar_mongo.db.feed_meta.delete_many(query)

    def check_overall_feed_count(self):
        """Evict a device's oldest cached feeds once ``max_feed_count`` is reached.

        If the device already has ``max_feed_count`` or more meta documents,
        the ones with the lowest ``ttl_field`` value are removed so that one
        slot is free for a new feed.
        """
        device_id = self.feed.user.device_id
        # Only cache_key is read below; the projection avoids transferring
        # the serialized bloom filters stored in every meta document.
        user_feeds = jafar_mongo.db.feed_meta.find({'device_id': device_id}, {'cache_key': 1})
        # NOTE: Cursor.count() is deprecated since pymongo 3.7 (removed in 4.0);
        # switch to Collection.count_documents() once the minimum pymongo allows.
        count = user_feeds.count()
        if count < self.max_feed_count:
            return
        extra_count = count - self.max_feed_count + 1
        logger.info("User %s reached maximum number of %s feeds: removing %s of them", device_id, self.max_feed_count, extra_count)
        oldest_feeds = user_feeds.sort(self.ttl_field, 1).limit(extra_count)
        self.clear({
            'cache_key': {'$in': [meta['cache_key'] for meta in oldest_feeds]},
            'device_id': device_id
        })

    @classmethod
    def create_indexes(cls):
        """Create the TTL and lookup indexes on the feed and feed_meta collections.

        Each index is attempted independently, so one failure (logged) does
        not prevent the remaining indexes from being created.
        """
        index_specs = [
            # Documents are removed once the time in ttl_field has passed.
            (cls.ttl_field, {'expireAfterSeconds': 0, 'background': True}),
            ('device_id', {}),
            ('cache_key', {}),
            ([('device_id', pymongo.ASCENDING),
              ('cache_key', pymongo.ASCENDING)], {}),
            ([('device_id', pymongo.ASCENDING),
              ('cache_key', pymongo.ASCENDING),
              ('experiment_name', pymongo.ASCENDING)], {}),
        ]
        for collection in (jafar_mongo.db.feed, jafar_mongo.db.feed_meta):
            for keys, options in index_specs:
                try:
                    collection.create_index(keys, **options)
                except OperationFailure as e:
                    logger.error("Failed to create index for %s collection %s", collection.name, e)
