# -*- coding: utf-8 -*-
import datetime
import time

from mpfs.config import settings
from mpfs.dao.base import (
    BaseDAO,
    BaseDAOItem,
    MongoBaseDAOImplementation,
    PostgresBaseDAOImplementation,
)
from mpfs.dao.fields import (
    StringField,
    IntegerField,
    DateTimeField,
    UidField,
    UuidField,
    PromoCodeArchiveStatusTypeField,
)
from mpfs.metastorage.postgres.queries import SQL_FIND_AVAILABLE_PROMO_CODE, SQL_FIND_PROMO_CODE, \
    SQL_DECREMENT_PROMO_CODE_COUNT, SQL_FIND_PROMO_CODE_ARCHIVE, SQL_UPSERT_PROMO_CODE_ARCHIVE, SQL_COUNT_PROMO_CODE, \
    SQL_COUNT_PROMO_CODE_ARCHIVE
from mpfs.metastorage.postgres.schema import (
    promo_codes,
    promo_codes_archive,
)


class PromoCodeDAOItem(BaseDAOItem):
    """Promo code record.

    Mapped to the Mongo collection ``promo_codes`` and to the Postgres table
    ``promo_codes``. Each field below pairs a Mongo document path with a
    Postgres column.
    """
    mongo_collection_name = 'promo_codes'
    postgres_table_obj = promo_codes
    is_sharded = False

    @classmethod
    def get_postgres_primary_key(cls):
        """Return the name of the Postgres primary key column (``'id'``)."""
        return 'id'

    # The promo code string itself; stored as the document ``_id`` in Mongo.
    id = StringField(mongo_path='_id', pg_path=promo_codes.c.id)
    # Optional, defaults to None. NOTE(review): presumably a product id — confirm.
    pid = StringField(mongo_path='pid', pg_path=promo_codes.c.pid, default_value=None)
    # Optional reference to a discount template (UUID); defaults to None.
    discount_template_id = UuidField(mongo_path='discount_template_id', pg_path=promo_codes.c.discount_template_id, default_value=None)
    # Validity window. The Mongo availability lookup only matches codes with
    # begin_datetime <= now <= end_datetime.
    begin_datetime = DateTimeField(mongo_path='begin_datetime', pg_path=promo_codes.c.begin_datetime)
    end_datetime = DateTimeField(mongo_path='end_datetime', pg_path=promo_codes.c.end_datetime)
    # Remaining activations. The Mongo availability lookup requires count > 0,
    # and the Mongo decrement_count lowers it by one via ``$inc``.
    count = IntegerField(mongo_path='count', pg_path=promo_codes.c.count)


class PromoCodeDAO(BaseDAO):
    """Backend-agnostic access to promo codes.

    Every public method forwards to the implementation returned by
    ``self._get_impl(None)`` (either the Mongo or the Postgres backend).
    """
    dao_item_cls = PromoCodeDAOItem

    def __init__(self):
        super(PromoCodeDAO, self).__init__()
        item_cls = self.dao_item_cls
        self._mongo_impl = MongoPromoCodeDAOImplementation(item_cls)
        self._pg_impl = PostgresPromoCodeDAOImplementation(item_cls)

    def find_available_promo_code(self, promo_code):
        """Return the item for ``promo_code`` if the backend deems it available, else None."""
        backend = self._get_impl(None)
        return backend.find_available_promo_code(promo_code)

    def find_promo_code(self, promo_code):
        """Return the item for ``promo_code`` regardless of availability, else None."""
        backend = self._get_impl(None)
        return backend.find_promo_code(promo_code)

    def decrement_count(self, promo_code):
        """Ask the active backend to decrement the remaining count of ``promo_code``."""
        backend = self._get_impl(None)
        return backend.decrement_count(promo_code)

    def count(self):
        """Return the number of promo code records reported by the active backend."""
        backend = self._get_impl(None)
        return backend.count()


class MongoPromoCodeDAOImplementation(MongoBaseDAOImplementation):
    """Mongo backend for ``PromoCodeDAO``."""

    def find_available_promo_code(self, promo_code):
        """Return the promo code item if it can be activated right now.

        A code matches when it still has activations left (``count > 0``)
        and the current Unix time lies within
        [``begin_datetime``, ``end_datetime``] (inclusive).

        :param promo_code: promo code string (the document ``_id``).
        :return: ``PromoCodeDAOItem`` or ``None`` when no matching document exists.
        """
        # The window bounds are compared against an integer Unix timestamp,
        # so they are expected to be stored that way in Mongo.
        cur_time = int(time.time())
        doc = {
            '_id': promo_code,
            'count': {'$gt': 0},
            'begin_datetime': {'$lte': cur_time},
            'end_datetime': {'$gte': cur_time},
        }
        result = super(MongoPromoCodeDAOImplementation, self).find_one(doc)
        if result is None:
            return None
        return PromoCodeDAOItem.create_from_mongo_dict(result)

    def find_promo_code(self, promo_code):
        """Return the promo code item by id, ignoring count and validity window.

        :param promo_code: promo code string (the document ``_id``).
        :return: ``PromoCodeDAOItem`` or ``None`` when not found.
        """
        doc = {'_id': promo_code}
        result = super(MongoPromoCodeDAOImplementation, self).find_one(doc)
        if result is None:
            return None
        return PromoCodeDAOItem.create_from_mongo_dict(result)

    def decrement_count(self, promo_code):
        """Decrease the ``count`` field of the given promo code by one.

        NOTE(review): the filter does not require ``count > 0``, so the counter
        can go negative if called on an exhausted code — confirm callers check
        availability first.

        :param promo_code: promo code string (the document ``_id``).
        """
        doc = {
            '_id': promo_code,
        }
        update = {
            '$inc': {'count': -1}
        }
        self.update(doc, update)

    def count(self):
        """Return the total number of promo code documents in the collection.

        ``PromoCodeDAO.count`` dispatches here when the Mongo backend is
        selected. This mirrors ``MongoPromoCodeArchiveDAOImplementation.count``
        and the Postgres ``count`` of this DAO.
        """
        return self.find().count()


class PostgresPromoCodeDAOImplementation(PostgresBaseDAOImplementation):
    """Postgres backend for ``PromoCodeDAO``, implemented with raw SQL queries."""

    def find_available_promo_code(self, promo_code):
        """Run ``SQL_FIND_AVAILABLE_PROMO_CODE`` for ``promo_code`` at the current local time.

        :return: ``PromoCodeDAOItem`` built from the first row, or ``None``.
        """
        params = {'id': promo_code, 'now': datetime.datetime.now()}
        session = self._get_session()
        row = session.execute(SQL_FIND_AVAILABLE_PROMO_CODE, params).fetchone()
        return PromoCodeDAOItem.create_from_pg_data(row) if row else None

    def find_promo_code(self, promo_code):
        """Run ``SQL_FIND_PROMO_CODE`` for ``promo_code``.

        :return: ``PromoCodeDAOItem`` built from the first row, or ``None``.
        """
        session = self._get_session()
        row = session.execute(SQL_FIND_PROMO_CODE, {'id': promo_code}).fetchone()
        return PromoCodeDAOItem.create_from_pg_data(row) if row else None

    def decrement_count(self, promo_code):
        """Run ``SQL_DECREMENT_PROMO_CODE_COUNT`` for ``promo_code``; returns nothing."""
        self._get_session().execute(SQL_DECREMENT_PROMO_CODE_COUNT, {'id': promo_code})

    def count(self):
        """Return the first column of ``SQL_COUNT_PROMO_CODE``'s first row as an int."""
        row = self._get_session().execute(SQL_COUNT_PROMO_CODE).fetchone()
        return int(row[0])


class PromoCodeArchiveDAOItem(BaseDAOItem):
    """Archive record of a promo code activation.

    Mapped to the Mongo collection ``promo_codes_archive`` and to the Postgres
    table ``promo_codes_archive``. Each field below pairs a Mongo document path
    with a Postgres column.
    """
    mongo_collection_name = 'promo_codes_archive'
    postgres_table_obj = promo_codes_archive
    is_sharded = False

    # Record id (UUID); stored as the document ``_id`` in Mongo.
    id = UuidField(mongo_path='_id', pg_path=promo_codes_archive.c.id)
    # The promo code string; the archive lookups query by this field.
    promo_code = StringField(mongo_path='promo_code', pg_path=promo_codes_archive.c.promo_code)
    # Optional, defaults to None. NOTE(review): presumably a product id — confirm.
    pid = StringField(mongo_path='pid', pg_path=promo_codes_archive.c.pid, default_value=None)
    # Optional, defaults to None. NOTE(review): meaning of sid not visible here — confirm.
    sid = StringField(mongo_path='sid', pg_path=promo_codes_archive.c.sid, default_value=None)
    # Optional reference to a discount template (UUID); defaults to None.
    discount_template_id = UuidField(mongo_path='discount_template_id', pg_path=promo_codes_archive.c.discount_template_id, default_value=None)
    # User the record belongs to.
    uid = UidField(mongo_path='uid', pg_path=promo_codes_archive.c.uid)
    # Note the differing names: Mongo key ``activation_timestamp`` vs the
    # Postgres column ``activation_datetime``.
    activation_datetime = DateTimeField(mongo_path='activation_timestamp',
                                        pg_path=promo_codes_archive.c.activation_datetime)
    # Archive status; allowed values are defined by PromoCodeArchiveStatusTypeField.
    status = PromoCodeArchiveStatusTypeField(mongo_path='status', pg_path=promo_codes_archive.c.status)


class PromoCodeArchiveDAO(BaseDAO):
    """Backend-agnostic access to the promo code archive.

    Every public method forwards to the implementation returned by
    ``self._get_impl(None)`` (either the Mongo or the Postgres backend).
    """
    dao_item_cls = PromoCodeArchiveDAOItem

    def __init__(self):
        super(PromoCodeArchiveDAO, self).__init__()
        item_cls = self.dao_item_cls
        self._mongo_impl = MongoPromoCodeArchiveDAOImplementation(item_cls)
        self._pg_impl = PostgresPromoCodeArchiveDAOImplementation(item_cls)

    def find_promo_code(self, promo_code):
        """Return an archive item whose ``promo_code`` matches, or None."""
        backend = self._get_impl(None)
        return backend.find_promo_code(promo_code)

    def save(self, item):
        """Persist ``item`` (insert or update) through the active backend."""
        backend = self._get_impl(None)
        return backend.save(item)

    def count(self):
        """Return the number of archive records reported by the active backend."""
        backend = self._get_impl(None)
        return backend.count()


class MongoPromoCodeArchiveDAOImplementation(MongoBaseDAOImplementation):
    """Mongo backend for ``PromoCodeArchiveDAO``."""

    def find_promo_code(self, promo_code):
        """Return the first archive document with the given ``promo_code`` as an item.

        :return: ``PromoCodeArchiveDAOItem`` or ``None`` when nothing matches.
        """
        query = {'promo_code': promo_code}
        found = super(MongoPromoCodeArchiveDAOImplementation, self).find_one(query)
        if found is not None:
            return PromoCodeArchiveDAOItem.create_from_mongo_dict(found)
        return None

    def save(self, item):
        """Upsert ``item`` keyed by its ``id`` (the document ``_id``); returns nothing."""
        selector = {'_id': item.id}
        self.update(selector, item.get_mongo_representation(), upsert=True)

    def count(self):
        """Return the total number of archive documents."""
        cursor = self.find()
        return cursor.count()


class PostgresPromoCodeArchiveDAOImplementation(PostgresBaseDAOImplementation):
    """Postgres backend for ``PromoCodeArchiveDAO``, implemented with raw SQL queries."""

    def find_promo_code(self, promo_code):
        """Run ``SQL_FIND_PROMO_CODE_ARCHIVE`` for ``promo_code``.

        :return: ``PromoCodeArchiveDAOItem`` built from the first row, or ``None``.
        """
        session = self._get_session()
        row = session.execute(SQL_FIND_PROMO_CODE_ARCHIVE, {'promo_code': promo_code}).fetchone()
        return PromoCodeArchiveDAOItem.create_from_pg_data(row) if row else None

    def save(self, item):
        """Run ``SQL_UPSERT_PROMO_CODE_ARCHIVE`` with ``item``'s column values; returns nothing."""
        representation = item.get_postgres_representation()
        # Query parameters are keyed by plain column name, not Column objects.
        params = dict((column.name, value) for column, value in representation.iteritems())
        self._get_session().execute(SQL_UPSERT_PROMO_CODE_ARCHIVE, params)

    def count(self):
        """Return the first column of ``SQL_COUNT_PROMO_CODE_ARCHIVE``'s first row as an int."""
        row = self._get_session().execute(SQL_COUNT_PROMO_CODE_ARCHIVE).fetchone()
        return int(row[0])
