from datetime import datetime
from decimal import Decimal
from typing import AsyncIterable, List, Optional

from sqlalchemy import and_, desc, or_, select

from sendr_aiopg import BaseMapper
from sendr_aiopg.data_mapper import SelectableDataMapper, TableDataDumper
from sendr_aiopg.query_builder import CRUDQueries, Filters
from sendr_utils import utcnow

from mail.ohio.ohio.core.entities.order import NDS, Item, Order, OrderData, Refund, RefundStatus
from mail.ohio.ohio.core.entities.order_service import OrderService
from mail.ohio.ohio.storage.db.tables import orders as t_orders
from mail.ohio.ohio.storage.exceptions import OrderNotFoundStorageError
from mail.ohio.ohio.storage.mappers.customer import CustomerSerialMixin

from mail.ohio.ohio.storage.mappers.utils import fix_service_id
import json
import os


class OrderDataMapper(SelectableDataMapper):
    """Maps rows of the ``orders`` table to :class:`Order` entities.

    ``map_order_data`` rebuilds :class:`OrderData` from the dict stored in the
    ``order_data`` column (presumably invoked by the base mapper for that column —
    confirm against ``SelectableDataMapper``). Monetary values are stored as strings
    and restored as ``Decimal``.
    """

    selectable = t_orders
    entity_class = Order

    @staticmethod
    def _to_decimal(value) -> Decimal:
        """Convert a stored value to ``Decimal``; falsy values (``None``, ``''``, ``0``) become ``Decimal(0)``."""
        # Decimal(0) instead of Decimal(0.0): avoid constructing a Decimal from a float literal.
        return Decimal(value) if value else Decimal(0)

    def _map_items(self, items: List[dict]) -> List[Item]:
        """Build :class:`Item` objects from their stored dict form; ``image_*`` keys are optional."""
        return [
            Item(
                amount=self._to_decimal(item['amount']),
                price=self._to_decimal(item['price']),
                currency=item['currency'],
                nds=NDS(item['nds']),
                name=item['name'],
                image_path=item.get('image_path'),
                image_url=item.get('image_url'),
            )
            for item in items
        ]

    def _map_refunds(self, refunds: List[dict]) -> List[Refund]:
        """Build :class:`Refund` objects (including their nested items) from their stored dict form."""
        return [
            Refund(
                trust_refund_id=refund['trust_refund_id'],
                refund_status=RefundStatus(refund['refund_status']),
                total=self._to_decimal(refund['total']),
                currency=refund['currency'],
                items=self._map_items(refund['items']),
            )
            for refund in refunds
        ]

    def map_order_data(self, data: Optional[dict]) -> Optional[OrderData]:
        """Convert the raw ``order_data`` dict into :class:`OrderData`.

        :param data: stored dict, or ``None`` when the column is null.
        :returns: the mapped :class:`OrderData`, or ``None`` if ``data`` is ``None``.
        """
        if data is None:
            return None
        return OrderData(
            total=self._to_decimal(data['total']),
            currency=data['currency'],
            items=self._map_items(data['items']),
            description=data['description'],
            refunds=self._map_refunds(data['refunds']),
        )


class OrderDataDumper(TableDataDumper):
    """Serializes :class:`OrderData` into the plain dict stored in the ``order_data`` column.

    ``Decimal`` amounts are written via ``str()`` and enum fields via their ``.value``,
    so the result contains only JSON-friendly primitives.
    """

    table = t_orders
    entity_class = Order

    @staticmethod
    def _dump_single_item(entry: Item) -> dict:
        """Serialize one :class:`Item`."""
        return {
            'amount': str(entry.amount),
            'price': str(entry.price),
            'currency': entry.currency,
            'nds': entry.nds.value,
            'name': entry.name,
            'image_path': entry.image_path,
            'image_url': entry.image_url,
        }

    def _dump_single_refund(self, entry: Refund) -> dict:
        """Serialize one :class:`Refund`, including its nested items."""
        return {
            'trust_refund_id': entry.trust_refund_id,
            'refund_status': entry.refund_status.value,
            'total': str(entry.total),
            'currency': entry.currency,
            'items': self._dump_items(entry.items),
        }

    def _dump_items(self, items: List[Item]) -> List[dict]:
        return list(map(self._dump_single_item, items))

    def _dump_refunds(self, refunds: List[Refund]) -> List[dict]:
        return list(map(self._dump_single_refund, refunds))

    def dump_order_data(self, data: Optional[OrderData]) -> Optional[dict]:
        """Serialize ``data``; a ``None`` value is passed through as ``None``."""
        if data is not None:
            return {
                'total': str(data.total),
                'currency': data.currency,
                'items': self._dump_items(data.items),
                'description': data.description,
                'refunds': self._dump_refunds(data.refunds),
            }
        return None


class OrderMapper(CustomerSerialMixin, BaseMapper):
    """Storage mapper for the ``orders`` table; rows are identified by ``(customer_uid, order_id)``."""

    _builder = CRUDQueries(
        base=t_orders,
        id_fields=('customer_uid', 'order_id'),
        mapper_cls=OrderDataMapper,
        dumper_cls=OrderDataDumper,
    )

    # HACK: temporary subservice -> service translation supplied via the environment.
    @staticmethod
    def _load_subservice_map() -> List[dict]:
        """Parse the ``SERVICE_TO_SUBSERVICE_MAP`` environment variable.

        The value is expected to be a JSON list of objects with ``subservice_id`` and
        ``service_id`` keys; when the variable is unset an empty list is returned.
        """
        return json.loads(os.environ.get('SERVICE_TO_SUBSERVICE_MAP', '[]'))

    def get_service_id_by_subservice(self, subservice_id, service_map: Optional[List[dict]] = None):
        """Return the ``service_id`` mapped to ``subservice_id``, or ``subservice_id`` itself if unmapped.

        :param subservice_id: identifier to translate; the first matching entry wins.
        :param service_map: already-parsed mapping as returned by ``_load_subservice_map``.
            When omitted the environment variable is parsed on every call, so callers
            translating many ids should parse once and pass it in.
        """
        if service_map is None:
            service_map = self._load_subservice_map()
        for service_info in service_map:
            if service_info['subservice_id'] == subservice_id:
                return service_info['service_id']
        return subservice_id

    @staticmethod
    def _fix_subservice(order: Order) -> Order:
        """Rewrite a ``subservice_id`` of ``'1'`` via ``fix_service_id``; return the (mutated) order."""
        if order.subservice_id == '1':
            order.subservice_id = fix_service_id(order.service_id, order.subservice_id)
        return order

    @staticmethod
    def _apply_keyset(query, created_keyset: Optional[datetime], order_id_keyset: Optional[int]):
        """Restrict ``query`` to rows strictly after ``(created_keyset, order_id_keyset)`` in
        ``(-created, -order_id)`` order; returns ``query`` unchanged unless both values are given."""
        if created_keyset is None or order_id_keyset is None:
            return query
        return query.where(or_(
            t_orders.c.created < created_keyset,
            and_(t_orders.c.created == created_keyset, t_orders.c.order_id < order_id_keyset),
        ))

    @staticmethod
    def _is_empty_order_data(order_data: Optional[OrderData]) -> bool:
        """True when the order has no items, no refunds and a zero total, or no order_data at all.

        ``OrderDataMapper.map_order_data`` yields ``None`` for a null column; such orders are
        treated as empty instead of raising ``AttributeError`` on ``.items``.
        """
        if order_data is None:
            return True
        return not order_data.items and not order_data.refunds and order_data.total == 0

    async def create(self, obj: Order) -> Order:
        """Insert ``obj`` inside a transaction, assigning ``updated`` and a fresh ``order_id``.

        ``acquire_next_order_id`` is not defined in this class (presumably provided by
        ``CustomerSerialMixin`` — confirm).
        """
        async with self.conn.begin():
            obj.updated = utcnow()
            obj.order_id = await self.acquire_next_order_id(obj.customer_uid)
            query, mapper = self._builder.insert(obj)
            return mapper(await self._query_one(query))

    async def find_for_customer(self,
                                customer_uid: int,
                                service_ids: Optional[list] = None,
                                subservice_ids: Optional[list] = None,
                                created_keyset: Optional[datetime] = None,
                                order_id_keyset: Optional[int] = None,
                                limit: Optional[int] = None,
                                ) -> AsyncIterable[Order]:
        """
        Returns customer's orders. Keyset parameters must be passed together for pagination or not be passed at all.
        Since pagination requires order_id and created, only returns orders with non-empty created.

        When ``subservice_ids`` is non-empty it is translated to service ids (see
        ``get_service_id_by_subservice``) and overrides ``service_ids``.
        Orders with empty order data are skipped after the query runs, so fewer than
        ``limit`` orders may be yielded even when more rows exist.
        """
        assert (created_keyset is None) == (order_id_keyset is None), 'keyset must be passed or not passed together'
        filters = Filters()
        if subservice_ids:
            # Parse the env mapping once rather than once per subservice id.
            service_map = self._load_subservice_map()
            overridden_service_ids = [
                self.get_service_id_by_subservice(subservice_id, service_map)
                for subservice_id in subservice_ids
            ]
            filters.add_not_none('service_id', overridden_service_ids,
                                 lambda service_id: service_id.in_(overridden_service_ids))
        else:
            filters.add_not_none('service_id', service_ids, lambda service_id: service_id.in_(service_ids))
        # TODO: filter by subservice_id directly once the service/subservice hack above is removed.
        filters['created'] = lambda f: f.isnot(None)
        query, mapper = self._builder.select(
            id_values=(customer_uid,),
            filters=filters,
            order=('-created', '-order_id'),
            limit=limit,
        )
        query = self._apply_keyset(query, created_keyset, order_id_keyset)
        async for row in self._query(query):
            order = self._fix_subservice(mapper(row))
            # Filter out invalid attempts to execute a transaction.
            if self._is_empty_order_data(order.order_data):
                continue
            yield order

    async def find_yandex_account_for_customer(self,
                                               customer_uid: int,
                                               created_keyset: Optional[datetime] = None,
                                               order_id_keyset: Optional[int] = None,
                                               limit: Optional[int] = None,
                                               ) -> AsyncIterable[Order]:
        """
        Returns customer's orders with yandex account operations(yandex_account_topup or yandex_account_withdraw). Keyset parameters must be passed together for pagination or not be passed at all.
        Since pagination requires order_id and created, only returns orders with non-empty created.
        """
        assert (created_keyset is None) == (order_id_keyset is None), 'keyset must be passed or not passed together'
        filters = Filters()
        filters['created'] = lambda f: f.isnot(None)
        query, mapper = self._builder.select(
            id_values=(customer_uid,),
            filters=filters,
            order=('-created', '-order_id'),
            limit=limit,
        )
        # Match on the payment_method key of the service_data JSON column.
        query = query.where(
            or_(
                t_orders.c.service_data['payment_method'].astext == 'yandex_account_topup',
                t_orders.c.service_data['payment_method'].astext == 'yandex_account_withdraw',
            )
        )
        query = self._apply_keyset(query, created_keyset, order_id_keyset)
        async for row in self._query(query):
            yield self._fix_subservice(mapper(row))

    async def get(self, customer_uid: int, order_id: int, for_update: bool = False) -> Order:
        """Fetch one order by primary key; raises ``OrderNotFoundStorageError`` if absent."""
        query, mapper = self._builder.select(id_values=(customer_uid, order_id), for_update=for_update)
        return mapper(await self._query_one(query, raise_=OrderNotFoundStorageError))

    async def get_by_trust_purchase_token(self,
                                          customer_uid: int,
                                          trust_purchase_token: str,
                                          for_update: bool = False,
                                          ) -> Order:
        """Fetch the customer's order with the given ``trust_purchase_token``;
        raises ``OrderNotFoundStorageError`` if absent."""
        query, mapper = self._builder.select(
            id_values=(customer_uid,),
            filters={'trust_purchase_token': trust_purchase_token},
            for_update=for_update,
        )
        return mapper(await self._query_one(query, raise_=OrderNotFoundStorageError))

    async def get_services(self, customer_uid: int, limit: Optional[int] = None) -> AsyncIterable[OrderService]:
        """Yield distinct ``(service_id, subservice_id)`` pairs from the customer's orders.

        When ``limit`` is given only the ``limit`` orders with the highest ``order_id`` are considered.
        """
        subquery = (
            select([t_orders.c.service_id, t_orders.c.subservice_id]).
                where(t_orders.c.customer_uid == customer_uid).
                order_by(desc(t_orders.c.order_id))
        )
        if limit is not None:
            subquery = subquery.limit(limit)
        subquery = subquery.alias('subquery')
        query = select(subquery.columns).select_from(subquery).distinct()
        async for row in self._query(query):
            subservice_id = row['subservice_id']
            if subservice_id == '1':
                subservice_id = fix_service_id(row['service_id'], subservice_id)
            yield OrderService(service_id=row['service_id'], subservice_id=subservice_id)

    async def save(self, obj: Order) -> Order:
        """Update ``obj`` (refreshing ``updated``); ``service_id`` is never overwritten.

        Raises ``OrderNotFoundStorageError`` if no row was updated.
        """
        obj.updated = utcnow()
        query, mapper = self._builder.update(
            obj,
            ignore_fields=('service_id',),
        )
        return mapper(await self._query_one(query, raise_=OrderNotFoundStorageError))
