# -*- coding: utf-8 -*-
import datetime
import itertools

from mpfs.dao.base import BaseDAO, BaseDAOItem, PostgresBaseDAOImplementation, MongoBaseDAOImplementation
from mpfs.dao.session import Session
from mpfs.dao.fields import ByteStringField, DateTimeField, UTCDateTimeField, IntegerField, JsonField, Md5Field, \
                            NullableStringField, UidField, UuidField
from mpfs.metastorage.postgres.schema import operations
from mpfs.metastorage.postgres.queries import SQL_OPERATIONS_GET_BY_DATE_PERIOD_AND_STATE, SQL_OPERATION_UPDATE_DTIME, \
                            SQL_OPERATIONS_GET_BY_STATE_TYPES_SUBTYPES_AND_DTIME_AGE, \
                            SQL_OPERATIONS_COUNT_BY_STATE_TYPES_SUBTYPES_AND_DTIME_AGE, \
    SQL_OPERATIONS_COUNT_BY_STATE_TYPES_WITH_LIMIT
from mpfs.metastorage.postgres.services import Sharpei


class OperationDAOItem(BaseDAOItem):
    """DAO item describing a single operation record.

    Every field declares where it lives in a MongoDB document (``mongo_path``)
    and which column of the PostgreSQL ``operations`` table backs it
    (``pg_path``).
    """
    mongo_collection_name = 'operations'
    postgres_table_obj = operations
    is_sharded = True
    # Name of the Mongo sub-document handled specially by BaseDAOItem
    # (presumably stored compressed, per the attribute name -- confirm in BaseDAOItem).
    mongo_compressed_field = 'data'

    # Mongo stores the primary key as ``_id``.
    id = ByteStringField(mongo_path='_id', pg_path=operations.c.id)
    uid = UidField(mongo_path='uid', pg_path=operations.c.uid)

    # In Mongo, ctime is nested under ``data`` while dtime/mtime are top-level.
    ctime = DateTimeField(mongo_path='data.ctime', pg_path=operations.c.ctime)
    # dtime is the only timestamp declared as a UTC datetime field.
    dtime = UTCDateTimeField(mongo_path='dtime', pg_path=operations.c.dtime)
    mtime = DateTimeField(mongo_path='mtime', pg_path=operations.c.mtime)

    state = IntegerField(mongo_path='state', pg_path=operations.c.state)
    version = IntegerField(mongo_path='version', pg_path=operations.c.version)

    type = NullableStringField(mongo_path='data.type', pg_path=operations.c.type, default_value=None)
    subtype = NullableStringField(mongo_path='data.subtype', pg_path=operations.c.subtype, default_value=None)
    md5 = Md5Field(mongo_path='md5', pg_path=operations.c.md5, default_value=None)
    uniq_id = UuidField(mongo_path='uniq_id', pg_path=operations.c.uniq_id, default_value=None)
    ycrid = NullableStringField(mongo_path='ycrid', pg_path=operations.c.ycrid, default_value=None)

    # Operation payload; in Mongo it is nested as ``data.data``.
    data = JsonField(mongo_path='data.data', pg_path=operations.c.data, default_value=None)

    # Mongo document keys that are not mapped to any field above and are
    # ignored during validation.
    validation_ignored_mongo_dict_fields = (
        'data.uniq_id',  # for some reason stored twice -- in ``data`` and in the top-level dict -- with the same value. If data.uniq_id is None, the top-level dict has no such key
        'data.id',  # for some reason duplicated from ``_id``
        'type',  # duplicated in data.type
        'subtype',  # duplicated in data.subtype
    )

    # Keys removed from the dict produced when converting to the Mongo format.
    # The value semantics are defined in BaseDAOItem; presumably a key is
    # dropped when it holds the listed value -- TODO confirm.
    exclude_keys_after_conversion_to_mongo = {
        'data': {
            'type': None,
            'subtype': None,
            'data': 'None'  # NOTE(review): the string 'None', unlike the None used for every other key -- confirm this is intentional
        },
        'md5': None,
        'uniq_id': None,
    }


class OperationDAO(BaseDAO):
    """Data-access object for operations stored in MongoDB or PostgreSQL.

    Per-user calls are routed to a single backend through ``self._get_impl(uid)``.
    Cross-user scans either query both backends or, where noted, PostgreSQL only.
    """
    dao_item_cls = OperationDAOItem

    def __init__(self):
        super(OperationDAO, self).__init__()
        self._mongo_impl = MongoOperationDAOImplementation(self.dao_item_cls)
        self._pg_impl = PostgresOperationDAOImplementation(self.dao_item_cls)

    def fetch_by_date_period_and_states(self, from_dt, to_dt, states):
        """Yield operations in a ``dtime`` period whose state is in ``states``.

        Results from MongoDB come first, followed by results from PostgreSQL.

        :param from_dt: lower bound of the ``dtime`` period.
        :param to_dt: upper bound of the ``dtime`` period.
        :param states: iterable of operation states to match.
        """
        # Materialize once: the same value is handed to two backends, and a
        # one-shot iterable (e.g. a generator) would be exhausted by the first.
        states = list(states)
        items = itertools.chain(
            self._mongo_impl.fetch_by_date_period_and_states(from_dt, to_dt, states),
            self._pg_impl.fetch_by_date_period_and_states(from_dt, to_dt, states),
        )
        for item in items:
            yield item

    def set_dtime(self, uid, oid, dtime):
        """Set ``dtime`` of operation ``oid`` belonging to ``uid``.

        :raises TypeError: if ``dtime`` is not a ``datetime.datetime``.
        """
        if not isinstance(dtime, datetime.datetime):
            raise TypeError('dtime must be datetime')
        impl = self._get_impl(uid)
        return impl.set_dtime(uid, oid, dtime)

    def fetch_by_age_types_subtypes_and_states(self, age, types_subtypes, states):
        """Yield operations older than ``age`` seconds with matching type/subtype and state.

        Users are being migrated off MongoDB, so only PostgreSQL is queried.
        """
        return self._pg_impl.fetch_by_age_types_subtypes_and_states(age, types_subtypes, states)

    def get_count_by_age_types_subtypes_and_states(self, age, types_subtypes, states):
        """Count operations older than ``age`` seconds with matching type/subtype and state.

        Users are being migrated off MongoDB, so only PostgreSQL is queried.
        """
        return self._pg_impl.get_count_by_age_types_subtypes_and_states(age, types_subtypes, states)

    def get_count_by_types_and_states(self, uid, types, states, limit):
        """Count ``uid``'s operations with matching type and state, using ``uid``'s backend."""
        return self._get_impl(uid).get_count_by_types_and_states(uid, types, states, limit)


class PostgresOperationDAOImplementation(PostgresBaseDAOImplementation):
    """PostgreSQL queries over the sharded ``operations`` table.

    Cross-user queries iterate over every shard id returned by
    ``Sharpei().get_all_shard_ids()``. Per-user queries open a session for
    the shard of ``uid`` via ``Session.create_from_uid``.
    """

    def _build_age_request_params(self, age, op_types_subtypes, states):
        """Build bind parameters shared by the age/type-subtype/state queries.

        :param age: age in seconds; the ``maxdtime`` parameter is
            ``utcnow() - age``, converted to the PostgreSQL ``dtime`` representation.
        :param op_types_subtypes: iterable bound as a tuple under ``types_subtypes``
            (presumably (type, subtype) pairs -- see the SQL).
        :param states: iterable of states, bound as a tuple.
        """
        maxdtime = datetime.datetime.utcnow() - datetime.timedelta(seconds=age)
        return {
            'maxdtime': self.dao_item_cls.get_field_pg_representation('dtime', maxdtime),
            'types_subtypes': tuple(op_types_subtypes),
            'states': tuple(states),
        }

    def fetch_by_date_period_and_states(self, from_dt, to_dt, states):
        """Yield operations in a ``dtime`` period whose state is in ``states``, from all shards."""
        params = {
            'from_dt': self.dao_item_cls.get_field_pg_representation('dtime', from_dt),
            'to_dt': self.dao_item_cls.get_field_pg_representation('dtime', to_dt),
            # Built once, outside the shard loop. Converting inside the loop
            # would exhaust a one-shot iterable on the first shard and bind an
            # empty tuple for every following shard.
            'states': tuple(states),
        }
        for shard_id in Sharpei().get_all_shard_ids():
            session = Session.create_from_shard_id(shard_id)
            for doc in session.execute(SQL_OPERATIONS_GET_BY_DATE_PERIOD_AND_STATE, params):
                yield self.dao_item_cls.create_from_pg_data(doc)

    def set_dtime(self, uid, oid, dtime):
        """Update ``dtime`` of operation ``oid`` of ``uid`` on the user's shard."""
        pg_dtime = self.dao_item_cls.get_field_pg_representation('dtime', dtime)
        session = Session.create_from_uid(uid)
        session.execute(SQL_OPERATION_UPDATE_DTIME, {'dtime': pg_dtime, 'id': oid, 'uid': uid})

    def fetch_by_age_types_subtypes_and_states(self, age, op_types_subtypes, states):
        """Yield matching operations (see ``_build_age_request_params``) from all shards."""
        request_params = self._build_age_request_params(age, op_types_subtypes, states)
        for shard_id in Sharpei().get_all_shard_ids():
            session = Session.create_from_shard_id(shard_id)
            cursor = session.execute(SQL_OPERATIONS_GET_BY_STATE_TYPES_SUBTYPES_AND_DTIME_AGE, request_params)
            for doc in cursor:
                yield self.dao_item_cls.create_from_pg_data(doc)

    def get_count_by_age_types_subtypes_and_states(self, age, op_types_subtypes, states):
        """Return the number of matching operations summed over all shards."""
        request_params = self._build_age_request_params(age, op_types_subtypes, states)
        operations_count = 0
        for shard_id in Sharpei().get_all_shard_ids():
            session = Session.create_from_shard_id(shard_id)
            row = session.execute(SQL_OPERATIONS_COUNT_BY_STATE_TYPES_SUBTYPES_AND_DTIME_AGE,
                                  request_params).fetchone()
            operations_count += row[0]
        return operations_count

    def get_count_by_types_and_states(self, uid, types, states, limit):
        """Return the first column of the limited count query for ``uid`` on the user's shard."""
        params = {'uid': uid,
                  'states': tuple(states),
                  'types': tuple(types),
                  'limit': limit}

        session = Session.create_from_uid(uid)
        return session.execute(SQL_OPERATIONS_COUNT_BY_STATE_TYPES_WITH_LIMIT, params).fetchone()[0]


class MongoOperationDAOImplementation(MongoBaseDAOImplementation):
    """MongoDB queries over the sharded ``operations`` collection."""

    def fetch_by_date_period_and_states(self, from_dt, to_dt, states):
        """Yield operations with ``from_dt <= dtime <= to_dt`` and state in ``states``.

        Every shard returned by ``self._mongo_helper.iter_over_all_shards`` is queried.
        """
        # NOTE(review): unlike get_count_by_types_and_states below (and the
        # Postgres implementation), the bounds are not passed through
        # get_field_mongo_representation('dtime', ...). Confirm they already
        # match the type dtime is stored as in Mongo.
        spec = {
            'dtime': {
                '$lte': to_dt,
                '$gte': from_dt,
            },
            # list() so that any iterable encodes as a BSON array.
            'state': {'$in': list(states)},
        }
        collection_name = self.dao_item_cls.mongo_collection_name
        for coll in self._mongo_helper.iter_over_all_shards(collection_name):
            for doc in coll.find(spec):
                yield self.dao_item_cls.create_from_mongo_dict(doc)

    def set_dtime(self, uid, oid, dtime):
        """Set ``dtime`` of operation ``oid`` of ``uid``; the value is written as given."""
        self.update({'uid': uid, '_id': oid}, {'$set': {'dtime': dtime}})

    def get_count_by_types_and_states(self, uid, types, states, limit=None):
        """Count ``uid``'s operations from the last day with type in ``types`` and state in ``states``.

        :param limit: optional cap on the count; ``None`` counts all matches.
        :return: number of matching operations, at most ``limit`` when given.
        """
        day_ago = datetime.datetime.utcnow() - datetime.timedelta(days=1)
        day_ago_ts = self.dao_item_cls.get_field_mongo_representation('dtime', day_ago)
        spec = {'uid': uid,
                'dtime': {'$gte': day_ago_ts},
                'state': {'$in': list(states)},
                'type': {'$in': list(types)}}
        cursor = self.find(spec)
        if limit is None:
            return cursor.count()
        cursor.limit(limit)
        # pymongo's Cursor.count() ignores limit()/skip() by default, so
        # without this flag the limit above had no effect on the result.
        # Assumes self.find returns a pymongo Cursor (it exposes limit/count).
        return cursor.count(with_limit_and_skip=True)
