from dataclasses import dataclass
from datetime import datetime
from dateutil.relativedelta import relativedelta
from decimal import Decimal

from asyncpg import Connection, Record
from smb.common.pgswim import PoolType

from crm.agency_cabinet.client_bonuses.common.structs import (
    BonusAmount,
    BonusDetails,
    BonusDetailsList,
    BonusStatusType,
    ClientBonus,
    ClientGraph,
    FetchBonusesDetailsInput,
    GraphPoint,
    ListClientsBonusesInput,
    ProgramBonusesGraph,
    GetClientsBonusesSettingsInput,
    ClientBonusSettings,
    ReportInfo,
    ListBonusesReportsInfoInput,
    GetReportUrlResponse,
    CreateReportInput,
    GetDetailedReportInfoResponse,
    ClientType,
    DeleteReportInput,
    DeleteReportOutput,
    ListCashbackProgramsInput,
    CashbackProgram,

)
from crm.agency_cabinet.client_bonuses.server.lib.db.engine import DB
from crm.agency_cabinet.common.consts.report import ReportsStatuses
from crm.agency_cabinet.common.mds import create_presigned_url_async
from crm.agency_cabinet.common.mds.definitions import S3AsyncClient

from .exceptions import (
    ClientNotFound, FileNotFoundException, UnsuitableAgencyException,
    ReportNotReadyException, NoSuchReportException
)

# Public procedure classes exported by this module (used by `from ... import *`).
__all__ = [
    "FetchBonusesDetails",
    "ListClientsBonuses",
    "GetClientsBonusesSettings",
    "FetchClientBonusesGraph",
    "ListBonusesReportsInfo",
    "GetReportUrl",
    "GetDetailedReportInfo",
    "CreateReport",
    "DeleteReport",
    "ListCashbackPrograms",
]


async def check_client_exists(agency_id: int, client_id: int, con: Connection) -> bool:
    """Tell whether client ``client_id`` is registered under agency ``agency_id``.

    :param agency_id: agency that must own the client.
    :param client_id: client identifier to look up.
    :param con: open asyncpg connection the query is executed on.
    :return: the value of the ``SELECT EXISTS (...)`` query.
    """
    query = """
        SELECT EXISTS (
            SELECT *
            FROM clients
            WHERE id=$1 AND agency_id=$2
        )
    """

    # Positional parameters: $1 -> client id, $2 -> agency id.
    exists = await con.fetchval(query, client_id, agency_id)
    return exists


@dataclass
class FetchBonusesDetails:
    """Fetch the accrued/spent bonus log of one client for a period."""

    db: DB

    async def __call__(
        self,
        params: FetchBonusesDetailsInput,
    ) -> BonusDetailsList:
        """Return the client's bonus entries between ``datetime_start`` and ``datetime_end``.

        Accrued bonuses are aggregated per calendar month, with a per-program
        breakdown ordered by ``program_id``. Spent bonuses are returned one row
        per ``spent_client_bonuses`` record, tagged with the month of ``spent_at``.

        :param params: agency id, client id and the inclusive period bounds.
        :return: entries ordered by month (newest first), then by type.
        :raises ClientNotFound: if no client with this id exists in this agency.
        """
        # ``UNION ALL`` rather than ``UNION``: plain ``UNION`` removes duplicate
        # rows, which would collapse two distinct spends of the same amount made
        # in the same month into a single entry.
        sql = """
            with accrued as (
                SELECT
                    client_id,
                    'accrued' as type,
                    array_agg(
                        jsonb_build_object(
                            'program_id', program_id,
                            'amount', amount::text
                        )
                        ORDER BY program_id
                    ) as amounts,
                    SUM(amount) as total,
                    date_trunc('month', gcb.gained_at) as date
                FROM gained_client_bonuses as gcb
                WHERE gained_at BETWEEN $3 AND $4
                    AND client_id=$1
                GROUP BY client_id, date_trunc('month', gcb.gained_at)
            ),
            spent as (
                SELECT
                    client_id,
                    'spent' as type,
                    array[]::jsonb[] as amounts,
                    amount as total,
                    date_trunc('month', spent_at) as date
                FROM spent_client_bonuses
                WHERE spent_at BETWEEN $3 AND $4
                    AND client_id=$1
            )
            SELECT
                details.type as type,
                details.amounts as amounts,
                details.total as total,
                details.date as date
            FROM clients
            LEFT JOIN (
                SELECT * FROM accrued
                UNION ALL
                SELECT * FROM spent
            ) as details ON clients.id = details.client_id
            WHERE clients.id=$1 AND clients.agency_id=$2
            ORDER BY details.date DESC, details.type
        """

        async with self.db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sql,
                params.client_id,
                params.agency_id,
                params.datetime_start,
                params.datetime_end,
            )

            # The LEFT JOIN returns at least one row for an existing client,
            # so no rows at all means the client/agency pair is unknown.
            if not rows:
                raise ClientNotFound(
                    client_id=params.client_id, agency_id=params.agency_id
                )

        return BonusDetailsList(items=self._format_bonuses_log(rows))

    @staticmethod
    def _format_bonuses_log(rows: list[Record]) -> list[BonusDetails]:
        """Convert query rows into ``BonusDetails`` objects.

        A single row whose ``type`` is NULL is what the LEFT JOIN yields for an
        existing client with no bonus activity in the period; it maps to an
        empty list.
        """
        if len(rows) == 1 and rows[0]["type"] is None:
            return []

        return [
            BonusDetails(
                type=BonusStatusType(row["type"]),
                date=row["date"],
                amounts=[
                    BonusAmount(program_id=item["program_id"], amount=Decimal(item["amount"]))
                    for item in row["amounts"]
                ],
                total=Decimal(row["total"]),
            )
            for row in rows
        ]


@dataclass
class ListClientsBonuses:
    """List an agency's clients together with their bonus totals for a period."""

    db: DB

    async def __call__(self, params: ListClientsBonusesInput) -> list[ClientBonus]:
        """Return one ``ClientBonus`` per matching client, newest clients first.

        Filters applied in SQL:

        * ``client_type``: ``ALL`` / ``ACTIVE`` / ``EXCLUDED`` (by ``clients.is_active``);
        * ``bonus_type``: ``ALL`` / ``WITH_ACTIVATION_OVER_PERIOD`` (active clients
          with gains in the period) / ``WITH_SPENDS_OVER_PERIOD`` (active clients
          with spends in the period);
        * ``search_query``: case-insensitive substring of the login, or, when the
          query consists only of decimal digits, an exact client id.

        A non-positive ``limit`` is replaced with 100.
        """
        sql = """
            WITH
                gains AS (
                    SELECT gained.client_id, gained.currency as currency, SUM(gained.amount) as gains
                    FROM gained_client_bonuses as gained
                    JOIN clients ON gained.client_id = clients.id
                    WHERE clients.agency_id = $1 AND
                        gained.gained_at BETWEEN $2 AND $3
                    GROUP BY gained.client_id, gained.currency
                ),
                spends AS (
                    SELECT spent.client_id, spent.currency as currency, SUM(spent.amount) as spends
                    FROM spent_client_bonuses as spent
                    JOIN clients ON spent.client_id = clients.id
                    WHERE clients.agency_id = $1
                        AND spent.spent_at BETWEEN $2 AND $3
                    GROUP BY spent.client_id, spent.currency
                )
            SELECT
                clients.id as client_id,
                clients.login as email,
                clients.is_active as active,
                COALESCE(gains.currency, spends.currency, 'RUR') as currency,
                COALESCE(gains.gains, 0) as accrued,
                COALESCE(spends.spends, 0) as spent,
                COALESCE(bonuses.amount, 0) as awarded
            FROM clients
            LEFT JOIN gains ON clients.id = gains.client_id
            LEFT JOIN spends ON clients.id = spends.client_id
            LEFT JOIN client_bonuses_to_activate as bonuses
                ON clients.id = bonuses.client_id
            WHERE clients.agency_id = $1
                AND (
                    $6 = 'ALL'
                    OR ($6 = 'ACTIVE' AND clients.is_active)
                    OR ($6 = 'EXCLUDED' AND NOT clients.is_active)
                )
                AND (
                    $7 = 'ALL'
                    OR (
                        $7 = 'WITH_ACTIVATION_OVER_PERIOD'
                        AND clients.is_active
                        AND gains.gains IS NOT NULL
                    )
                    OR (
                        $7 = 'WITH_SPENDS_OVER_PERIOD'
                        AND clients.is_active
                        AND spends.spends IS NOT NULL
                    )
                )
                AND (
                    $8::text IS NULL
                    OR ($8::text IS NOT NULL AND LOWER( clients.login ) LIKE '%' || LOWER( $8 ) || '%')
                    OR clients.id = $9
                )
            ORDER BY clients.create_date DESC
            LIMIT $4 OFFSET $5
        """

        # ``isdecimal`` rather than ``isnumeric``: ``isnumeric`` is also True for
        # characters such as '²' or '½' that ``int()`` rejects with ValueError,
        # so a user-typed search query could crash the request.
        search_query_client_id = None
        if params.search_query is not None and params.search_query.isdecimal():
            search_query_client_id = int(params.search_query)

        async with self.db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sql,
                params.agency_id,
                params.datetime_start,
                params.datetime_end,
                # asyncpg for some reason expects an int here and LIMIT ALL
                # cannot be substituted; on the other hand the query should not
                # return many rows: the login ILIKE %client_id% filter should
                # not yield many matches.
                # TODO: do this properly.
                # NOTE(review): PostgreSQL treats ``LIMIT NULL`` like ``LIMIT ALL``,
                # so passing None may be the proper fix — confirm the intended
                # behaviour for non-positive limits first.
                params.limit if params.limit > 0 else 100,
                params.offset,
                params.client_type.value,
                params.bonus_type.value,
                params.search_query,
                search_query_client_id,
            )

        return [ClientBonus(**r) for r in rows]


@dataclass
class GetClientsBonusesSettings:
    """Report the date range covered by an agency's spent-bonus records."""

    db: DB

    async def __call__(self, params: GetClientsBonusesSettingsInput) -> ClientBonusSettings:
        """Return the earliest and latest ``spent_at`` over all of the agency's clients.

        Both dates come back as ``None`` when the agency has no spent bonuses,
        since ``MIN``/``MAX`` over an empty set evaluate to NULL.
        """
        query = """
            SELECT MIN(spent_bonuses.spent_at), MAX(spent_bonuses.spent_at)
            FROM spent_client_bonuses as spent_bonuses
            JOIN clients as clients
            ON spent_bonuses.client_id = clients.id
            WHERE clients.agency_id = $1;
        """

        async with self.db.acquire(PoolType.replica) as con:
            bounds = await con.fetchrow(query, params.agency_id)
            # Unaliased aggregate columns are named after their function.
            earliest, latest = bounds['min'], bounds['max']

        return ClientBonusSettings(first_date=earliest, last_date=latest)


@dataclass
class FetchClientBonusesGraph:
    """Build the bonus history graph (spent, accrued, per program) of one client."""

    db: DB

    async def __call__(self, agency_id: int, client_id: int) -> ClientGraph:
        """Load the client's bonus history and format it as a ``ClientGraph``.

        :raises ClientNotFound: if no client with ``client_id`` exists in ``agency_id``.
        """
        sql = """
        WITH overall_spent AS (
            SELECT
                client_id,
                array_agg(
                    jsonb_build_object(
                        'point', spent_at,
                        'value', amount::text
                    ) ORDER BY spent_at
                ) AS spent
            FROM spent_client_bonuses
            WHERE client_id=$1
            GROUP BY client_id
        ),
        programs AS (
            SELECT
                client_id,
                program_id,
                jsonb_build_object(
                    'program_id', program_id,
                    'historical_monthly_data', array_agg(
                            jsonb_build_object(
                                'point', gained_at,
                                'value', amount::text
                            ) ORDER BY gained_at
                        )
                ) as programs
            FROM gained_client_bonuses
            WHERE client_id=$1
            GROUP BY client_id, program_id
        ),
        overall_gained AS (
            SELECT
                client_id,
                array_agg(
                    jsonb_build_object(
                        'point', gained_at,
                        'value', amount::text
                    ) ORDER BY gained_at
                ) as overall_gained
            FROM (
                 SELECT client_id,
                        sum(amount) as amount,
                        gained_at
                 FROM gained_client_bonuses
                 GROUP BY client_id, gained_at
            ) as gained
            WHERE client_id=$1
            GROUP BY client_id
        )
        SELECT
            clients.id,
            COALESCE(cba.amount, 0) as bonuses_available,
            overall_gained.overall_gained as overall_accrued,
            overall_spent.spent as overall_spent,
            array_remove(
                array_agg(
                    programs.programs
                    ORDER BY programs.program_id
                ), NULL
            ) as programs
        FROM clients
        LEFT JOIN client_bonuses_to_activate as cba ON cba.client_id = clients.id
        LEFT JOIN overall_spent ON overall_spent.client_id = clients.id
        LEFT JOIN overall_gained ON overall_gained.client_id = clients.id
        LEFT JOIN programs ON programs.client_id = clients.id
        WHERE clients.id = $1 AND agency_id = $2
        GROUP BY
            clients.id,
            cba.amount,
            overall_gained.overall_gained,
            overall_spent.spent
        """

        async with self.db.acquire(PoolType.replica) as con:
            graph_dataset = await con.fetchrow(sql, client_id, agency_id)
            if graph_dataset is None:
                raise ClientNotFound(client_id=client_id, agency_id=agency_id)

        return self._format_graph(graph_dataset)

    @staticmethod
    def _format_graph(graph_dateset: Record) -> ClientGraph:
        """Turn the raw query row into a ``ClientGraph``.

        The overall spent/accrued series are projected onto a shared monthly
        time axis (missing months become 0); per-program series are passed
        through point by point.
        """
        time_series = FetchClientBonusesGraph._get_time_series(
            graph_dateset["overall_spent"],
            graph_dateset["overall_accrued"]
        )

        overall_spent = FetchClientBonusesGraph._fill_data(time_series, graph_dateset["overall_spent"])
        overall_accrued = FetchClientBonusesGraph._fill_data(time_series, graph_dateset["overall_accrued"])

        programs = [
            ProgramBonusesGraph(
                program_id=program["program_id"],
                historical_monthly_data=[
                    GraphPoint(
                        point=datetime.fromisoformat(point["point"]),
                        value=Decimal(point["value"]),
                    )
                    for point in program["historical_monthly_data"]
                ],
            )
            for program in graph_dateset["programs"]
        ]

        return ClientGraph(
            bonuses_available=graph_dateset["bonuses_available"],
            overall_spent=overall_spent,
            overall_accrued=overall_accrued,
            programs=programs,
        )

    @staticmethod
    def _get_time_series(overall_spent, overall_accrued):
        """Return monthly time points covering both series.

        The axis starts at the earliest first point of the two series and
        advances one calendar month at a time while it does not exceed the
        latest last point. Each series is expected to be ordered by point (the
        SQL ``array_agg`` calls use ``ORDER BY``). Returns ``[]`` when both
        series are empty or NULL.
        """
        def get_time_bounds(data):
            if not data:
                return None, None
            return datetime.fromisoformat(data[0]["point"]), datetime.fromisoformat(data[-1]["point"])

        start_spent, end_spent = get_time_bounds(overall_spent)
        start_accrued, end_accrued = get_time_bounds(overall_accrued)

        starts = [dt for dt in (start_spent, start_accrued) if dt is not None]
        ends = [dt for dt in (end_spent, end_accrued) if dt is not None]
        if not starts:
            return []

        start, end = min(starts), max(ends)
        time_series = []
        months = 0
        t = start
        while t <= end:
            time_series.append(t)
            months += 1
            # Offset from the fixed start instead of the previous point so a
            # month-end start does not drift (Jan 31 -> Feb 28 -> Mar 31, not
            # Jan 31 -> Feb 28 -> Mar 28).
            t = start + relativedelta(months=months)

        return time_series

    @staticmethod
    def _fill_data(time_series, data):
        """Project a ``[{"point": iso, "value": str}, ...]`` series onto ``time_series``.

        Values recorded at the same point are summed; axis points with no data
        get ``Decimal(0)``. Data points that do not coincide exactly with an
        axis point are ignored. ``data`` may be ``None``.
        """
        totals = {}
        for item in data or []:
            point = datetime.fromisoformat(item["point"])
            value = Decimal(item["value"])
            # overall_spent is aggregated from raw rows (not grouped by
            # spent_at), so several entries may share one point; sum them
            # instead of keeping only the first.
            totals[point] = totals[point] + value if point in totals else value

        return [GraphPoint(t, totals.get(t, Decimal(0))) for t in time_series]


@dataclass
class ListBonusesReportsInfo:
    """List the metadata of every bonus report that belongs to an agency."""

    db: DB

    async def __call__(self, params: ListBonusesReportsInfoInput) -> list[ReportInfo]:
        """Return all reports of ``params.agency_id`` as ``ReportInfo`` objects.

        No ``ORDER BY`` is applied, so rows come back in database order.
        """
        query = """
            SELECT id, name, created_at, period_from, period_to, status, client_type
            FROM report_meta_info
            WHERE agency_id = $1;
        """

        async with self.db.acquire(PoolType.replica) as con:
            records = await con.fetch(query, params.agency_id)

        return list(map(_row_to_report_info, records))


@dataclass
class GetReportUrl:
    """Produce a presigned S3 download URL for a ready report."""

    db: DB

    async def __call__(self, agency_id: int, report_id: int, s3_client: S3AsyncClient) -> GetReportUrlResponse:
        """Return a presigned URL for the file attached to report ``report_id``.

        :raises NoSuchReportException: no report with this id exists.
        :raises UnsuitableAgencyException: the report belongs to another agency.
        :raises ReportNotReadyException: the report status is not ``ready``.
        :raises FileNotFoundException: the report has no ``file_id``.
        """
        # report_id is passed as a bind parameter ($1) instead of being
        # formatted into the SQL text. This prevents SQL injection when the id
        # is not a plain int, and keeps a single query text for asyncpg's
        # prepared-statement cache instead of one text per report id.
        # LEFT JOIN is equivalent to the former FULL JOIN: the WHERE on
        # report_meta_info.id already discards the unmatched s3_mds_file rows.
        sql = """
            SELECT agency_id, status, file_id, s3_mds_file.bucket, s3_mds_file.name AS filename
            FROM report_meta_info
            LEFT JOIN s3_mds_file
            ON s3_mds_file.id=file_id
            WHERE report_meta_info.id=$1
        """

        async with self.db.acquire(PoolType.replica) as con:
            report = await con.fetchrow(sql, report_id)

        # Ownership is checked before status so other agencies learn nothing
        # about the report's state.
        if report is None:
            raise NoSuchReportException
        if report['agency_id'] != agency_id:
            raise UnsuitableAgencyException
        if report['status'] != ReportsStatuses.ready.value:
            raise ReportNotReadyException
        if report['file_id'] is None:
            raise FileNotFoundException

        url = await create_presigned_url_async(s3_client, report['bucket'], report['filename'])
        return GetReportUrlResponse(report_url=url)


@dataclass
class GetDetailedReportInfo:
    """Fetch the metadata of a single report owned by an agency."""

    db: DB

    async def __call__(self, agency_id: int, report_id: int) -> GetDetailedReportInfoResponse:
        """Return the ``ReportInfo`` of report ``report_id``.

        :raises NoSuchReportException: no report with this id exists.
        :raises UnsuitableAgencyException: the report belongs to another agency.
        """
        # Bind parameter instead of f-string interpolation: avoids SQL
        # injection and keeps one cacheable query text for asyncpg.
        sql = """
            SELECT *
            FROM report_meta_info
            WHERE report_meta_info.id=$1
        """

        async with self.db.acquire(PoolType.replica) as con:
            report = await con.fetchrow(sql, report_id)

        if report is None:
            raise NoSuchReportException
        if report['agency_id'] != agency_id:
            raise UnsuitableAgencyException

        return GetDetailedReportInfoResponse(report=_row_to_report_info(report))


@dataclass
class CreateReport:
    """Register a new bonus report request."""

    db: DB

    async def __call__(self, params: CreateReportInput) -> ReportInfo:
        """Insert a report row with status ``requested`` and return it as ``ReportInfo``.

        The insert runs on the master pool. ``id`` and ``created_at`` are not
        supplied by the INSERT, so their values come from the table's defaults
        and are read back through ``RETURNING``.
        """
        insert_query = """
            INSERT INTO report_meta_info (
                name,
                agency_id,
                period_from,
                period_to,
                client_type,
                status
            )

            VALUES (
                $1,
                $2,
                $3,
                $4,
                $5,
                $6
            )
            RETURNING id, name, period_from, period_to, created_at, status, client_type ;
        """

        async with self.db.acquire(PoolType.master) as con:
            values = (
                params.name,
                params.agency_id,
                params.period_from,
                params.period_to,
                params.client_type.value,
                ReportsStatuses.requested.value,
            )
            inserted = await con.fetchrow(insert_query, *values)

        return _row_to_report_info(inserted)


@dataclass
class DeleteReport:
    """Delete a report owned by the requesting agency."""

    db: DB

    async def __call__(self, params: DeleteReportInput) -> DeleteReportOutput:
        """Delete report ``params.report_id`` if it belongs to ``params.agency_id``.

        :raises NoSuchReportException: the report does not exist, or it was
            removed before the DELETE ran.
        :raises UnsuitableAgencyException: the report belongs to another agency.
        """
        sql_find_report = """
            SELECT agency_id
            FROM report_meta_info
            WHERE id = $1;
        """
        sql_delete_report = """
            DELETE FROM report_meta_info WHERE id = $1 and agency_id = $2
            RETURNING id ;
        """

        # Both statements run on the master. Reading from a replica could miss
        # a report that was just created on the master (replication lag) and
        # wrongly raise NoSuchReportException.
        async with self.db.acquire(PoolType.master) as con:
            report = await con.fetchrow(sql_find_report, params.report_id)

            if report is None:
                raise NoSuchReportException
            if report['agency_id'] != params.agency_id:
                raise UnsuitableAgencyException

            # The agency filter is repeated in the DELETE so the ownership check
            # still holds even if the row changed since the lookup.
            deleted_row = await con.fetchrow(
                sql_delete_report,
                params.report_id,
                params.agency_id
            )

        if deleted_row is None:
            raise NoSuchReportException

        return DeleteReportOutput(is_deleted=True)


def _row_to_report_info(row):
    """Build a ``ReportInfo`` from a ``report_meta_info`` record.

    Every field is copied verbatim except ``client_type``, which is wrapped
    into ``ClientType``.
    """
    kwargs = {column: row[column] for column in ('id', 'name', 'period_to', 'period_from', 'created_at')}
    kwargs['client_type'] = ClientType(row['client_type'])
    kwargs['status'] = row['status']
    return ReportInfo(**kwargs)


def _row_to_cashback_program(row):
    """Build a ``CashbackProgram`` from a ``cashback_programs`` record, copying each column verbatim."""
    columns = (
        'id', 'category_id', 'is_general', 'is_enabled',
        'name_ru', 'name_en', 'description_ru', 'description_en',
    )
    return CashbackProgram(**{column: row[column] for column in columns})


@dataclass
class ListCashbackPrograms:
    """List every cashback program stored in the database."""

    db: DB

    async def __call__(self, params: ListCashbackProgramsInput) -> list[CashbackProgram]:
        """Return all rows of ``cashback_programs`` as ``CashbackProgram`` objects.

        ``params`` is not used by the query.
        """
        query = """
            SELECT id, category_id, is_general, is_enabled, name_ru, name_en, description_ru, description_en
            FROM cashback_programs;
        """

        async with self.db.acquire(PoolType.replica) as con:
            records = await con.fetch(query)

        return list(map(_row_to_cashback_program, records))
