# -*- coding: utf-8 -*-
import Queue

from itertools import imap
from multiprocessing.pool import ThreadPool
from operator import itemgetter

import mpfs.engine.process

from mpfs.common.util import grouped_chunks, ThreadSafeCounter
from mpfs.config import settings
from mpfs.dao.base import BaseDAO, BaseDAOItem, PostgresBaseDAOImplementation, ValuesTemplateGenerator
from mpfs.dao.fields import UidField, DateField, EnumField
from mpfs.dao.session import Session
from mpfs.metastorage.postgres.schema import user_activity_info, PlatformType
from mpfs.metastorage.postgres.services import Sharpei, SharpeiUserNotFoundError
from mpfs.metastorage.postgres.queries import (
    SQL_GET_USER_ACTIVITY_INFO, SQL_BULK_UPSERT_USER_ACTIVITY_INFO,
    SQL_CLEAR_USER_ACTIVITY_INFO,
    SQL_GET_USER_PLATFORM_LAST_ACTIVITY_DATE,
    SQL_BULK_UPSERT_USER_ACTIVITY_INFO_AND_GET_EARLIEST_ACTIVITY_AFTER,
    SQL_BULK_GET_LATEST_ACTIVITY,
)

THREAD_POOL_SIZE = settings.user_activity_info['dao']['sharpei_thread_pool_size']
QUERY_CHUNK_SIZE = settings.user_activity_info['dao']['query_chunk_size']
PROMO_COME_BACK_USER_DISCOUNT_MINIMUM_DIFFERENCE_IN_MONTHS = settings.promo['come_back_user_discount']['minimum_difference_in_months']


log = mpfs.engine.process.get_default_log()
error_log = mpfs.engine.process.get_error_log()


class UserActivityDAOItem(BaseDAOItem):
    postgres_table_obj = user_activity_info
    is_sharded = True

    uid = UidField(mongo_path=None, pg_path=user_activity_info.c.uid)
    platform_type = EnumField(mongo_path=None, pg_path=user_activity_info.c.platform_type, enum_class=PlatformType)
    first_activity = DateField(mongo_path=None, pg_path=user_activity_info.c.first_activity)
    last_activity = DateField(mongo_path=None, pg_path=user_activity_info.c.last_activity)


class UserActivityInfoDAO(BaseDAO):
    dao_item_cls = UserActivityDAOItem

    def __init__(self):
        super(UserActivityInfoDAO, self).__init__()
        self._mongo_impl = None
        self._pg_impl = PostgresUserActivityInfoDAOImplementation(self.dao_item_cls)

    def bulk_update_activity_dates_and_fetch_closest_activity_dates(self, activity_data, error_info_container):
        """
        Добавить данные об активности и вытащить самые близкие даты активности пользователей до и после обновления.

        Метод добавления данных об активности пользователей по всем шардам с многопоточными походами в шарпей. В
        качестве результата возвращает генератор uid'ов с датами активности. В качестве дат достается самая поздняя дата
        активности до обновления и самая ранняя дата после обновления. К несчастью, бизнес логику поиска таких уидов
        пришлось делать в DAO, потому что кол-во данных, которое сюда попадает каждый день, довольно велико, и есть
        желание все это сделать за один проход.

        :param iterable[dict] activity_data: дынные об активности пользователей
        :param ErrorInfoContainer error_info_container: количество ошибок шарпея, количество пользователей, которых не
        нашли на шардах
        :return Generator[dict(uid, activity_date_before_update, activity_date_after_update)]: уиды и даты вернувшихся
        пользователей
        """
        return self._pg_impl.bulk_update_activity_dates_and_fetch_closest_activity_dates(activity_data, error_info_container)

    def find_by_uid(self, uid):
        implementation = self._get_impl(uid)
        if implementation is self._mongo_impl:
            return []
        return implementation.find_by_uid(uid)

    def remove_by_uid(self, uid):
        implementation = self._get_impl(uid)
        if implementation is self._mongo_impl:
            raise NotImplementedError
        return implementation.remove_by_uid(uid)

    def bulk_insert(self, uid, items):
        implementation = self._get_impl(uid)
        if implementation is self._mongo_impl:
            raise NotImplementedError
        return implementation.bulk_insert(uid, items)

    def update_activity_dates(self, item):
        implementation = self._get_impl(item.uid)
        if implementation is self._mongo_impl:
            return
        return implementation.bulk_insert(item.uid, (item,))

    def get_last_platform_activity(self, uid, platform_type):
        implementation = self._get_impl(uid)
        if implementation is self._mongo_impl:
            return None
        return implementation.get_last_platform_activity(uid, platform_type)


class PostgresUserActivityInfoDAOImplementation(PostgresBaseDAOImplementation):

    def __init__(self, *args, **kwargs):
        super(PostgresUserActivityInfoDAOImplementation, self).__init__(UserActivityDAOItem)
        self._sharpei = Sharpei()

    def bulk_update_activity_dates_and_fetch_closest_activity_dates(self, activity_data, error_info_container):
        template_generator = ValuesTemplateGenerator(
            ('uid', 'platform_type', 'first_activity', 'last_activity'), expected_values_count=QUERY_CHUNK_SIZE)

        sharpei_error_counter = ThreadSafeCounter()
        sharpei_pool = Queue.Queue()
        for _ in xrange(THREAD_POOL_SIZE):
            sharpei_pool.put(Sharpei())

        def shard_id_getter(record):
            sharpei = sharpei_pool.get()
            shard_id = None
            uid = record['uid']
            try:
                shard_id = sharpei.get_shard(uid).get_id()
            except SharpeiUserNotFoundError:
                pass
            except Exception:
                sharpei_error_counter.increment()
                error_log.exception('Cant get shard_id for user %s' % uid)
            finally:
                sharpei_pool.put(sharpei)
            return shard_id, record

        pool = ThreadPool(processes=THREAD_POOL_SIZE)
        data_with_shards = pool.imap_unordered(shard_id_getter, activity_data)

        try:
            for shard_id, data_chunk in grouped_chunks(data_with_shards, itemgetter(0), chunk_size=QUERY_CHUNK_SIZE):
                if shard_id is None:
                    continue

                values_template = template_generator.get_values_tmpl(len(data_chunk))
                values_for_template = template_generator.get_values_for_tmpl(imap(itemgetter(1), data_chunk))

                uids_to_update = {x[1]['uid'] for x in data_chunk}
                # TODO: место для потенциального ускорения за счет распараллеливания
                session = Session.create_from_shard_id(shard_id)

                result = session.execute(
                    SQL_BULK_GET_LATEST_ACTIVITY,
                    {'uids': [UserActivityDAOItem.get_field_pg_representation('uid', x) for x in uids_to_update]}
                )
                activity_dates_before_update = {
                    UserActivityDAOItem.convert_from_postgres('uid', x['uid']): x['activity_before_update']
                    for x in result
                }

                query = (SQL_BULK_UPSERT_USER_ACTIVITY_INFO_AND_GET_EARLIEST_ACTIVITY_AFTER
                         % {'values': values_template})
                result = session.execute(query, values_for_template)
                updated_uids = set()
                for row in result:
                    uid = UserActivityDAOItem.convert_from_postgres('uid', row['uid'])
                    update_result_doc = {
                        'uid': uid,
                        'activity_before_update': activity_dates_before_update.get(uid),
                        'activity_after_update': row['activity_after_update'],
                    }
                    updated_uids.add(uid)
                    yield update_result_doc

                uids_failed_update = uids_to_update - updated_uids
                for missing_uid in uids_failed_update:
                    error_log.error(
                        'Sharpei says that shard: %s contain uid: %s, but this is wrong' % (shard_id, missing_uid))
                error_info_container.missing_uids_count += len(uids_failed_update)
        finally:
            pool.terminate()

        error_info_container.sharpei_errors_count = sharpei_error_counter.get_value()

    def find_by_uid(self, uid):
        session = Session.create_from_uid(uid)
        result = session.execute(SQL_GET_USER_ACTIVITY_INFO, {'uid': uid})
        return [UserActivityDAOItem.create_from_pg_data(record) for record in result]

    def remove_by_uid(self, uid):
        session = Session.create_from_uid(uid)
        session.execute(SQL_CLEAR_USER_ACTIVITY_INFO, {'uid': uid})

    def bulk_insert(self, uid, items):
        if not items:
            return
        uid = self.get_field_repr('uid', uid)
        for item in items:
            if uid != self.get_field_repr('uid', item.uid):
                raise ValueError("Can save only for one uid. %s %s" % (uid, item.uid))

        template_generator = ValuesTemplateGenerator(('uid', 'platform_type', 'first_activity', 'last_activity'))
        values_template = template_generator.get_values_tmpl(len(items))
        values_for_template = template_generator.get_values_for_tmpl(item.as_raw_pg_dict() for item in items)
        query = SQL_BULK_UPSERT_USER_ACTIVITY_INFO % {'values': values_template}
        session = Session.create_from_uid(uid)
        session.execute(query, values_for_template)

    def get_last_platform_activity(self, uid, platform_type):
        session = Session.create_from_uid(uid)
        result = session.execute(
            SQL_GET_USER_PLATFORM_LAST_ACTIVITY_DATE, {'uid': uid, 'platform_type': platform_type.value}).fetchone()
        last_activity = result['last_activity'] if result else None
        return last_activity
