from datetime import datetime
from typing import AsyncIterable, Iterable, Mapping, Optional, Tuple

from sqlalchemy import and_, delete, func, or_, select, tuple_, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql import Selectable, desc
from sqlalchemy.sql.base import Executable

from sendr_aiopg.query_builder import CRUDQueries, Filters, RelationDescription

from mail.payments.payments.core.entities.enums import PayStatus
from mail.payments.payments.core.entities.order import Order
from mail.payments.payments.core.entities.service import ServiceMerchant
from mail.payments.payments.core.entities.transaction import Transaction, TransactionStatus
from mail.payments.payments.storage.db.tables import orders as t_orders
from mail.payments.payments.storage.db.tables import service_merchants as t_service_merchants
from mail.payments.payments.storage.db.tables import services as t_service
from mail.payments.payments.storage.db.tables import shops as t_shops
from mail.payments.payments.storage.db.tables import transactions as t_transactions
from mail.payments.payments.storage.exceptions import TransactionNotFound
from mail.payments.payments.storage.mappers.base import BaseMapper
from mail.payments.payments.storage.mappers.order.serialization import OrderDataMapper
from mail.payments.payments.storage.mappers.service.serialization import ServiceDataMapper
from mail.payments.payments.storage.mappers.service.service_merchant import ServiceMerchantDataMapper
from mail.payments.payments.storage.mappers.shop import ShopDataMapper
from mail.payments.payments.utils.datetime import utcnow
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class TransactionDataMapper(SelectableDataMapper):
    """Row mapper bound to the transactions table, producing ``Transaction`` entities."""

    entity_class = Transaction
    selectable = t_transactions


class TransactionDataDumper(TableDataDumper):
    """Table data dumper bound to the transactions table and the ``Transaction`` entity.

    Used as ``dumper_cls`` by the query builders of ``TransactionMapper``.
    """

    entity_class = Transaction
    table = t_transactions


class TransactionMapper(BaseMapper):
    """Storage mapper for the transactions table.

    Besides plain create/get/save it can load transactions together with
    their order: ``_builder`` joins order + shop, ``_builder_service_merchant``
    joins order + service merchant + service. It also provides the
    search/count queries used for transaction listings.
    """

    name = 'transaction'

    # Statuses treated as "unfinished" by count_unfinished() and eligible for
    # polling in get_for_check().
    _UNFINISHED_STATUSES = (TransactionStatus.ACTIVE, TransactionStatus.HELD)

    _order_relation = RelationDescription(
        name='order',
        base=t_transactions,
        related=t_orders,
        mapper_cls=OrderDataMapper,
        base_cols=('uid', 'order_id'),
        related_cols=('uid', 'order_id'),
    )
    _service_merchant_relation = RelationDescription(
        name='service_merchant',
        base=t_orders,
        related=t_service_merchants,
        mapper_cls=ServiceMerchantDataMapper,
        base_cols=('service_merchant_id',),
        related_cols=('service_merchant_id',),
        outer_join=True,
    )
    _service_relation = RelationDescription(
        name='service',
        base=t_service_merchants,
        related=t_service,
        mapper_cls=ServiceDataMapper,
        base_cols=('service_id',),
        related_cols=('service_id',),
        outer_join=True,
    )

    _shop_relation = RelationDescription(
        name='shop',
        base=t_orders,
        base_cols=('uid', 'shop_id'),
        related=t_shops,
        mapper_cls=ShopDataMapper,
        related_cols=('uid', 'shop_id'),
        outer_join=True,
    )

    _builder_kwargs = dict(
        base=t_transactions,
        id_fields=('uid', 'tx_id'),
        mapper_cls=TransactionDataMapper,
        dumper_cls=TransactionDataDumper,
    )

    # Transaction + order (+ shop via outer join).
    _builder = CRUDQueries(
        **_builder_kwargs,
        related=(_order_relation, _shop_relation),
    )
    # Transaction + order (+ service merchant and its service via outer joins).
    _builder_service_merchant = CRUDQueries(
        **_builder_kwargs,
        related=(_order_relation, _service_merchant_relation, _service_relation),
    )

    @staticmethod
    def map(row: Mapping) -> Transaction:
        """Build a ``Transaction`` from a plain (unlabeled) transactions-table row."""
        return Transaction(
            uid=row['uid'],
            tx_id=row['tx_id'],
            order_id=row['order_id'],
            revision=row['revision'],
            created=row['created'],
            updated=row['updated'],
            status=row['status'],
            trust_purchase_token=row['trust_purchase_token'],
            trust_payment_url=row['trust_payment_url'],
            trust_failed_result=row['trust_failed_result'],
            trust_resp_code=row['trust_resp_code'],
            trust_payment_id=row['trust_payment_id'],
            trust_terminal_id=row['trust_terminal_id'],
            poll=row['poll'],
            check_at=row['check_at'],
            check_tries=row['check_tries'],
        )

    @staticmethod
    def map_related(row: Mapping,
                    mapper: SelectableDataMapper,
                    rel_mappers: Mapping[str, SelectableDataMapper],
                    ) -> Transaction:
        """Build a ``Transaction`` with its related entities from a ``select_related`` row.

        The order is always attached and its ``customer_uid``/``user_email`` are
        copied onto the transaction. The service merchant (with its service) is
        attached only when the order references one and both mappers are present.
        The shop is attached whenever a shop mapper is present.
        """
        transaction: Transaction = mapper(row)
        order: Order = rel_mappers['order'](row)
        transaction.customer_uid = order.customer_uid
        transaction.user_email = order.user_email
        transaction.order = order
        if order.service_merchant_id and 'service_merchant' in rel_mappers and 'service' in rel_mappers:
            service_merchant: ServiceMerchant = rel_mappers['service_merchant'](row)
            service_merchant.service = rel_mappers['service'](row)
            order.service_merchant = service_merchant
        if 'shop' in rel_mappers:
            order.shop = rel_mappers['shop'](row)
        return transaction

    @staticmethod
    def unmap(obj: Transaction) -> dict:
        """Return the writable column values of ``obj`` (key columns uid/tx_id excluded)."""
        return {
            'order_id': obj.order_id,
            'revision': obj.revision,
            'created': obj.created,
            'updated': obj.updated,
            'status': obj.status,
            'trust_purchase_token': obj.trust_purchase_token,
            'trust_payment_url': obj.trust_payment_url,
            'trust_failed_result': obj.trust_failed_result,
            'trust_resp_code': obj.trust_resp_code,
            'trust_payment_id': obj.trust_payment_id,
            'trust_terminal_id': obj.trust_terminal_id,
            'poll': obj.poll,
            'check_at': obj.check_at,
            'check_tries': obj.check_tries,
        }

    @staticmethod
    def _where_order_in(query: Executable, orders_query: Selectable) -> Executable:
        """Keep only rows of ``query`` whose transaction (uid, order_id) pair is returned by ``orders_query``.

        The predicate is attached to ``query`` itself, so its ORDER BY and column labels are untouched.
        """
        return query.where(tuple_(t_transactions.c.uid, t_transactions.c.order_id).in_(orders_query))

    async def count_unfinished(self) -> int:
        """Return the number of transactions whose status is ACTIVE or HELD."""
        query = (
            select([func.count()]).
            select_from(t_transactions).
            where(t_transactions.c.status.in_(self._UNFINISHED_STATUSES))
        )
        return (await self._query_one(query))[0]

    async def create(self, obj: Transaction) -> Transaction:
        """Insert ``obj`` under a freshly acquired tx_id and revision and return the stored row.

        ``obj.tx_id`` and ``obj.revision`` are ignored; both are obtained via the base
        mapper helpers inside the same DB transaction as the INSERT.
        """
        async with self.conn.begin():
            uid = obj.uid
            tx_id = await self._acquire_tx_id(uid)

            unmapped = self.unmap(obj)
            unmapped['revision'] = await self._acquire_revision(uid)

            query = (
                insert(t_transactions).
                values(uid=uid, tx_id=tx_id, **unmapped).
                returning(*t_transactions.c)
            )
            return self.map(await self._query_one(query))

    async def get(self, uid: int, tx_id: int) -> Transaction:
        """Return the transaction ``(uid, tx_id)`` without related entities.

        :raises TransactionNotFound: if no such row exists.
        """
        query = (
            select([t_transactions]).
            where(t_transactions.c.uid == uid).
            where(t_transactions.c.tx_id == tx_id)
        )
        return self.map(await self._query_one(query, raise_=TransactionNotFound))

    async def get_last_by_order(self,
                                uid: int,
                                order_id: int,
                                raise_: bool = True,
                                for_update: bool = False,
                                ) -> Optional[Transaction]:
        """Return the transaction with the highest tx_id for the given order, with order and shop attached.

        :param for_update: lock the matched transactions and orders rows (``FOR UPDATE OF``).
        :raises TransactionNotFound: if nothing matches and ``raise_`` is true;
            otherwise ``None`` is returned in that case.
        """
        query, mapper, rel_mappers = self._builder.select_related(
            id_values=(uid,),
            filters={'order_id': order_id},
            order=('-tx_id',),
            limit=1,
            for_update=for_update,
            lock_of=[t_orders, t_transactions],
        )
        row = await self._query_one(query, raise_=raise_ and TransactionNotFound)
        return self.map_related(row, mapper, rel_mappers) if row is not None else None

    async def find_last_by_orders(self,
                                  uid_and_order_id_list: Iterable[Tuple[int, int]]
                                  ) -> AsyncIterable[Transaction]:
        """Yield, for each given (uid, order_id) pair, the transaction with the highest tx_id.

        Pairs without transactions produce nothing. Related entities are not loaded.
        """
        # Materialize once: the input may be a one-shot iterator, and an empty
        # key set needs no database round-trip at all.
        keys = list(uid_and_order_id_list)
        if not keys:
            return

        key_cols = (t_transactions.c.uid, t_transactions.c.order_id)
        query, mapper = self._builder.select()
        query = (
            query.
            where(tuple_(*key_cols).in_(keys)).
            # DISTINCT ON (uid, order_id) keeps the first row of each group in
            # ORDER BY order, i.e. the one with the largest tx_id.
            order_by(*key_cols, desc(t_transactions.c.tx_id)).
            distinct(*key_cols)
        )

        async for row in self._query(query):
            yield mapper(row)

    async def get_for_check(self) -> Transaction:
        """Lock and return one unfinished, polled transaction whose ``check_at`` has passed.

        The row with the earliest ``check_at`` is chosen. Rows locked by another session
        are skipped (``FOR UPDATE SKIP LOCKED`` on transactions and orders).

        :raises TransactionNotFound: if no transaction is due.
        """
        filters = Filters()
        filters['status'] = lambda field: field.in_(self._UNFINISHED_STATUSES)
        filters['poll'] = True
        filters['check_at'] = lambda field: field <= func.now()

        query, mapper, rel_mappers = self._builder.select_related(
            filters=filters,
            order=['check_at'],
            for_update=True,
            skip_locked=True,
            limit=1,
            lock_of=[t_orders, t_transactions],
        )
        row = await self._query_one(query, raise_=TransactionNotFound)
        return self.map_related(row, mapper, rel_mappers)

    async def save(self, obj: Transaction) -> Transaction:
        """Update the stored row of ``obj`` and return it as re-read from the database.

        A new revision is acquired and ``updated`` is set to ``utcnow()``, overriding
        the values carried by ``obj``.
        """
        uid = obj.uid
        tx_id = obj.tx_id
        async with self.conn.begin():
            unmapped = self.unmap(obj)
            unmapped['revision'] = await self._acquire_revision(uid)
            unmapped['updated'] = utcnow()

            query = (
                update(t_transactions).
                values(**unmapped).
                where(t_transactions.c.uid == uid).
                where(t_transactions.c.tx_id == tx_id).
                returning(*t_transactions.c)
            )
            return self.map(await self._query_one(query))

    async def delete(self, obj):
        """Deleting single transactions is not supported."""
        raise NotImplementedError

    def _filter_by_services(self,
                            transactions_query: Selectable,
                            services: Iterable[int]
                            ) -> Executable:
        """Restrict ``transactions_query`` to transactions whose order belongs to one of ``services``.

        The restriction is a ``(uid, order_id) IN (subquery)`` predicate on the original
        statement. The previous approach wrapped the query in a derived table and joined
        it to the orders. That pushed the ORDER BY into the subquery, so the outer
        statement, which LIMIT/OFFSET are applied to, had no guaranteed row order and
        pagination was nondeterministic.
        """
        service_orders = (
            select([t_orders.c.uid, t_orders.c.order_id]).
            select_from(
                t_orders.join(
                    t_service_merchants,
                    t_orders.c.service_merchant_id == t_service_merchants.c.service_merchant_id,
                )
            ).
            where(t_service_merchants.c.service_id.in_(services))
        )
        return self._where_order_in(transactions_query, service_orders)

    def _prepare_find_query(
        self,
        uid: Optional[int] = None,
        order_id: Optional[int] = None,
        statuses: Optional[Iterable[TransactionStatus]] = None,
        order_pay_statuses: Optional[Iterable[PayStatus]] = None,
        tx_id: Optional[int] = None,
        email_query: Optional[str] = None,
        sort_by: Optional[str] = None,
        descending: Optional[bool] = None,
        created_from: Optional[datetime] = None,
        created_to: Optional[datetime] = None,
        updated_from: Optional[datetime] = None,
        updated_to: Optional[datetime] = None,
        services: Optional[Iterable[int]] = None,
        customer_uid: Optional[int] = None,
    ) -> Tuple[Executable, SelectableDataMapper, Mapping[str, SelectableDataMapper]]:
        """Build the search query shared by ``find`` and ``get_found_count``.

        ``None`` arguments add no filter. ``email_query`` is a case-insensitive substring
        match on the order's ``user_email`` (LIKE wildcards in it are escaped). Without
        ``sort_by`` the result is ordered by tx_id descending; ``sort_by='order_pay_status'``
        sorts by the order's pay status, and ``descending`` applies only together with
        ``sort_by``.

        :returns: ``(query, mapper, rel_mappers)`` suitable for ``map_related``.
        """
        filters = Filters()
        filters.add_not_none('order_id', order_id)
        filters.add_not_none('status', statuses, lambda field: field.in_(statuses))
        filters.add_not_none('order.pay_status', order_pay_statuses, lambda field: field.in_(order_pay_statuses))
        filters.add_not_none('tx_id', tx_id)
        filters.add_not_none('order.customer_uid', customer_uid)
        filters.add_range('created', created_from, created_to)
        filters.add_range('updated', updated_from, updated_to)

        order = ['-tx_id']
        if sort_by is not None:
            sort_field = 'order.pay_status' if sort_by == 'order_pay_status' else sort_by
            order = [('-' if descending else '') + sort_field]

        query, mapper, rel_mappers = self._builder_service_merchant.select_related(
            id_values=None if uid is None else (uid,),
            filters=filters,
            order=order,
        )

        if email_query is not None:
            email_query = email_query.lower()
            matching_orders = select([t_orders.c.uid, t_orders.c.order_id]) \
                .where(func.lower(t_orders.c.user_email).contains(email_query, autoescape=True))
            query = self._where_order_in(query, matching_orders)

        if services is not None:
            query = self._filter_by_services(query, services)

        return query, mapper, rel_mappers

    async def find(
        self,
        created_from: Optional[datetime] = None,
        created_to: Optional[datetime] = None,
        customer_uid: Optional[int] = None,
        descending: Optional[bool] = None,
        email_query: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        order_id: Optional[int] = None,
        order_pay_statuses: Optional[Iterable[PayStatus]] = None,
        services: Optional[Iterable[int]] = None,
        sort_by: Optional[str] = None,
        statuses: Optional[Iterable[TransactionStatus]] = None,
        tx_id: Optional[int] = None,
        uid: Optional[int] = None,
        updated_from: Optional[datetime] = None,
        updated_to: Optional[datetime] = None,
    ) -> AsyncIterable[Transaction]:
        """Yield transactions matching the filters, with order, service merchant and service attached.

        See ``_prepare_find_query`` for filter and sorting semantics; ``limit``/``offset``
        paginate the sorted result.
        """
        query, mapper, rel_mappers = self._prepare_find_query(
            uid=uid,
            order_id=order_id,
            statuses=statuses,
            order_pay_statuses=order_pay_statuses,
            tx_id=tx_id,
            email_query=email_query,
            sort_by=sort_by,
            descending=descending,
            created_from=created_from,
            created_to=created_to,
            updated_from=updated_from,
            updated_to=updated_to,
            services=services,
            customer_uid=customer_uid,
        )

        async for row in self._query(query, limit=limit, offset=offset):
            yield self.map_related(row, mapper, rel_mappers)

    async def get_found_count(
        self,
        created_from: Optional[datetime] = None,
        created_to: Optional[datetime] = None,
        customer_uid: Optional[int] = None,
        email_query: Optional[str] = None,
        order_id: Optional[int] = None,
        order_pay_statuses: Optional[Iterable[PayStatus]] = None,
        services: Optional[Iterable[int]] = None,
        statuses: Optional[Iterable[TransactionStatus]] = None,
        tx_id: Optional[int] = None,
        uid: Optional[int] = None,
        updated_from: Optional[datetime] = None,
        updated_to: Optional[datetime] = None,
    ) -> int:
        """Return how many rows ``find`` would produce for the same filters, ignoring pagination."""
        search_query, _mapper, _rel_mappers = self._prepare_find_query(
            uid=uid,
            order_id=order_id,
            statuses=statuses,
            order_pay_statuses=order_pay_statuses,
            tx_id=tx_id,
            email_query=email_query,
            created_from=created_from,
            created_to=created_to,
            updated_from=updated_from,
            updated_to=updated_to,
            services=services,
            customer_uid=customer_uid,
        )

        query = (
            select([func.count()]).
            select_from(search_query.alias('transactions_query'))
        )
        return await self._query_scalar(query)

    async def get_transactions_count(self) -> int:
        """Return the total number of rows in the transactions table."""
        query = (
            select([func.count()]).
            select_from(t_transactions)
        )
        return await self._query_scalar(query)

    async def delete_by_uid(self, uid: int) -> None:
        """Delete every transaction belonging to ``uid``."""
        query = (
            delete(t_transactions).
            where(t_transactions.c.uid == uid)
        )
        await self.conn.execute(query)
