from datetime import date
from decimal import Decimal
import logging
from uuid import UUID

from sqlalchemy import and_, func, select
from intranet.trip.src.db.tables import billing_transactions_table, billing_deposit_table
from intranet.trip.src.models import Transaction, Deposit
from intranet.trip.src.api.schemas import TransactionFilter
from intranet.trip.src.db.gateways.base import DBGateway
from intranet.trip.src.db.subqueries.common import get_person_with_company_column

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class TransactionGateway(DBGateway):
    """Gateway to the billing transactions table.

    Rows are versioned rather than modified in place: ``delete`` flags the
    current row with ``is_obsolete=True`` and ``update`` does the same before
    inserting a replacement row. Every read in this class therefore filters on
    ``is_obsolete`` being false.
    """

    table = billing_transactions_table
    model_class = Transaction

    async def get_transaction(self, transaction_id: UUID) -> Transaction:
        """Fetch the current (non-obsolete) row for ``transaction_id``.

        The row is selected together with a ``person`` column built by
        ``get_person_with_company_column``.

        NOTE(review): the behaviour when nothing matches is decided by
        ``DBGateway._get_one``, which is not visible here - confirm.
        """
        return await self._get_one(
            select_list=[
                self.table,
                get_person_with_company_column(correlate_table=self.table, label='person'),
            ],
            where_clause=and_(
                self.table.c.transaction_id == transaction_id,
                self.table.c.is_obsolete.is_(False),
            ),
        )

    async def create(self, **fields) -> UUID:
        """Insert a new transaction row and return its ``transaction_id``.

        ``is_obsolete`` is always forced to ``False``, overriding any value
        passed in ``fields``.
        """
        fields['is_obsolete'] = False
        create_new_transaction_query = (
            self.table
            .insert()
            .values(**fields)
            .returning(self.table.c.transaction_id)
        )
        return await self.conn.scalar(create_new_transaction_query)

    async def update(self, transaction_id: UUID, **fields) -> UUID:
        """Replace the current row of a transaction with an edited copy.

        To keep the table auditable, editing marks the existing record as
        obsolete and creates a new record carrying the changed fields.

        This operation must be executed inside a database transaction.

        Every table column except ``created_at``, ``updated_at`` and
        ``is_obsolete`` is copied from the current row (so ``transaction_id``
        is carried over too), then overridden by ``fields``.

        Returns the value returned by ``create`` for the new row.
        """
        column_names = {str(column.name) for column in self.table.columns}
        updatable_fields = column_names - {'created_at', 'updated_at', 'is_obsolete'}

        old_transaction_data = dict(await self.get_transaction(transaction_id))
        # The fetched row carries the person as a nested ``person`` object;
        # expose its id under the ``person_id`` column name so the copy below
        # picks it up.
        old_transaction_data['person_id'] = old_transaction_data['person'].person_id

        new_transaction_data = {
            field: old_transaction_data[field]
            for field in updatable_fields
        }
        new_transaction_data.update(fields)
        # Redundant with ``create`` (which forces the same value) but kept so
        # the intent is explicit at this call site.
        new_transaction_data['is_obsolete'] = False

        await self.delete(transaction_id)
        return await self.create(**new_transaction_data)

    async def delete(self, transaction_id: UUID) -> None:
        """Soft-delete a transaction by flagging its current row as obsolete.

        Only the row that is not already obsolete is touched; no row is
        physically removed.
        """
        query = (
            self.table
            .update()
            .where(
                and_(
                    self.table.c.transaction_id == transaction_id,
                    self.table.c.is_obsolete.is_(False),
                ),
            )
            .values(is_obsolete=True)
        )
        await self.conn.execute(query)

    async def get_expenses(
        self,
        company_id: int,
        date_from: date | None = None,
        date_to: date | None = None,
    ) -> Decimal:
        """Return the total of ``price + yandex_fee + provider_fee`` for a company.

        Only current (non-obsolete) rows are summed. ``date_from`` and
        ``date_to`` are optional inclusive bounds on ``execution_date``.
        Returns ``Decimal('0')`` when the query yields no value.
        """
        query = (
            select([
                func.sum(self.table.c.price + self.table.c.yandex_fee + self.table.c.provider_fee),
            ])
            .where(self.table.c.company_id == company_id)
            # Superseded and soft-deleted versions stay in the table; without
            # this filter an edited transaction is counted once per version.
            .where(self.table.c.is_obsolete.is_(False))
        )

        if date_from is not None:
            query = query.where(self.table.c.execution_date >= date_from)

        if date_to is not None:
            query = query.where(self.table.c.execution_date <= date_to)

        return await self.conn.scalar(query) or Decimal('0')

    async def get_list_and_count(self, fltr: TransactionFilter) -> tuple[list[Transaction], int]:
        """Return one page of a company's current transactions plus the total count.

        Filters set to ``None`` on ``fltr`` are ignored. The count covers all
        rows matching the filters, before ``limit``/``offset`` are applied.
        The page is ordered by ``created_at`` descending.
        """
        # TODO: BTRIP-3006 handle the "PPL" case (ППЛ in the original note)
        query = (
            select([
                self.table,
                get_person_with_company_column(correlate_table=self.table, label='person'),
            ])
            .where(self.table.c.is_obsolete.is_(False))
            .where(self.table.c.company_id == fltr.company_id)
        )

        # Plain equality filters: the filter attribute and the table column
        # share the same name.
        for field_name in ('trip_id', 'person_id', 'service_id', 'service_type', 'status'):
            value = getattr(fltr, field_name)
            if value is not None:
                query = query.where(self.table.c[field_name] == value)

        # Date-range filters compare on the date part of the stored value.
        if fltr.execution_date__gte is not None:
            query = query.where(func.date(self.table.c.execution_date) >= fltr.execution_date__gte)

        if fltr.execution_date__lte is not None:
            query = query.where(func.date(self.table.c.execution_date) <= fltr.execution_date__lte)

        if fltr.invoice_date__gte is not None:
            query = query.where(func.date(self.table.c.invoice_date) >= fltr.invoice_date__gte)

        if fltr.invoice_date__lte is not None:
            query = query.where(func.date(self.table.c.invoice_date) <= fltr.invoice_date__lte)

        # Count before ordering and paging: the total does not depend on row
        # order, so the COUNT subquery does not need an ORDER BY.
        count_query = select([func.count()]).select_from(query.subquery())
        count = await self.conn.scalar(count_query)

        query = query.order_by(self.table.c.created_at.desc())

        if fltr.limit is not None:
            query = query.limit(fltr.limit)

        if fltr.offset is not None:
            query = query.offset(fltr.offset)

        rows = await self._fetchall(query)

        return [Transaction(**item) for item in rows], count


class DepositGateway(DBGateway):
    """Gateway to the billing deposits table."""

    table = billing_deposit_table
    # Primary-key column name; presumably consumed by DBGateway helpers - confirm.
    pk = 'deposit_id'
    model_class = Deposit

    async def get_deposits_sum(self, company_id: int) -> Decimal:
        """Return the sum of ``amount`` over all deposits of a company.

        Returns ``Decimal('0')`` when the query yields no value.
        """
        query = (
            select([func.sum(self.table.c.amount)])
            .where(self.table.c.company_id == company_id)
        )

        return await self.conn.scalar(query) or Decimal('0')

    async def get_last_deposit_date(self, company_id: int) -> date | None:
        """Return the latest ``charge_date`` among a company's deposits.

        Returns whatever ``conn.scalar`` yields for the first row of the
        descending ordering; the annotation implies ``None`` when the company
        has no deposits.

        NOTE(review): if ``charge_date`` is nullable, where NULLs sort in a
        descending ordering is backend-specific - confirm that a NULL row
        cannot shadow a real date here.
        """
        query = (
            select([self.table.c.charge_date])
            .where(self.table.c.company_id == company_id)
            .order_by(self.table.c.charge_date.desc())
            # Only the first row is read, so do not make the database
            # produce the rest.
            .limit(1)
        )

        return await self.conn.scalar(query)
