from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal

from asyncpg import Record
from dateutil import parser

from crm.agency_cabinet.certificates.common.structs import (
    AgencyCertificate,
    AgencyCertificateDetails,
    AgencyCertificateDetailsRequest,
    AgencyCertificatesHistoryEntry,
    CertifiedEmployee,
    CertifiedEmployees,
    DirectBonusPoint,
    DirectCertificationCondition,
    DirectCertificationScores,
    DirectKPI,
    EmployeeCertificate,
    EmployeeCertificateStatus,
    FetchAgencyCertificatesHistoryRequest,
    ListEmployeesCertificatesRequest,
)
from crm.agency_cabinet.certificates.server.lib.db.engine import DB
from crm.agency_cabinet.certificates.server.lib.exceptions import (
    AgencyCertificateNotFound,
)
from smb.common.pgswim import PoolType

# Query objects exported by this module; each is a dataclass wrapping a DB
# handle and is invoked with ``await obj(...)``.
__all__ = [
    "ListAgencyCertificates",
    "FetchAgencyCertificatesHistory",
    "ListEmployeesCertificates",
    "FetchAgencyCertificateDetails",
]


@dataclass
class ListEmployeesCertificates:
    """List an agency's certified employees together with their certificates.

    Certificates are grouped per ``(agency_id, employee_email)``; each group
    becomes one :class:`CertifiedEmployee`. Data is read from the replica
    pool and paginated with ``LIMIT``/``OFFSET``.
    """

    db: DB

    async def __call__(self, request: ListEmployeesCertificatesRequest) -> CertifiedEmployees:
        """Return one page of certified employees for ``request.agency_id``.

        Positional query parameters, in the order they are bound below:

        * ``$1`` -- ``request.agency_id``.
        * ``$2`` -- the current UTC time. Each certificate's status is derived
          from it: ``expired`` when ``expiration_time < now``,
          ``expires_in_semiyear`` when it expires within the next 6 months,
          otherwise ``active``.
        * ``$3`` / ``$4`` -- ``request.pagination.limit`` / ``.offset``.
        * ``$5`` -- ``request.search_query``; when set, keeps employees whose
          name or email contains it (case-insensitive ``LIKE``).
        * ``$6`` -- ``request.project``; when set, keeps only certificates of
          that project.
        * ``$7`` -- ``request.status``; when set, keeps only certificates
          matching that status filter.

        Employees are ordered by name (``NULL`` names last), with the email as
        a tie-breaker. Because the sort key is unique, ``LIMIT``/``OFFSET``
        pages are stable even when names repeat or are missing.

        NOTE(review): the ``active`` status filter matches every
        not-yet-expired certificate, including ones the CASE expression
        labels ``expires_in_semiyear`` -- confirm this is intended.
        NOTE(review): ``LAST_VALUE(employee_name)`` has no ``ORDER BY`` in its
        window, so which stored name is picked for an employee is not
        defined by the query.
        NOTE(review): ``%`` and ``_`` inside ``search_query`` act as ``LIKE``
        wildcards.
        """
        sql = """
            SELECT
                agency_id,
                employee_email as email,
                employee_name as name,
                array_agg(
                    jsonb_build_object(
                        'project', project,
                        'start_time', start_time,
                        'expiration_time', expiration_time,
                        'external_id', external_id,
                        'status', CASE
                            WHEN expiration_time < $2 THEN 'expired'
                            WHEN expiration_time < $2 + interval '6 months' THEN 'expires_in_semiyear'
                            ELSE 'active'
                            END
                    ) ORDER BY project
                ) AS certificates
            FROM (
                SELECT
                    agency_id,
                    employee_email,
                    external_id,
                    LAST_VALUE(employee_name) OVER (
                        PARTITION BY agency_id, employee_email
                    ) as employee_name,
                    project,
                    start_time,
                    expiration_time
                FROM employee_certificates
                WHERE (
                        $6::text IS NULL
                        OR ($6::text IS NOT NULL AND project = $6)
                    )
                    AND (
                        $7::text IS NULL
                        OR ($7::text IS NOT NULL AND $7::text = 'expired' AND expiration_time < $2)
                        OR ($7::text IS NOT NULL AND $7::text = 'expires_in_semiyear'
                         AND expiration_time >= $2 AND expiration_time < $2 + interval '6 months'
                        )
                        OR ($7::text IS NOT NULL AND $7::text = 'active' AND expiration_time >= $2)
                    )
            ) as data
            WHERE agency_id = $1
                AND (
                    $5::text IS NULL
                    OR ($5::text IS NOT NULL AND LOWER( employee_name ) LIKE '%' || LOWER( $5 ) || '%')
                    OR ($5::text IS NOT NULL AND LOWER( employee_email ) LIKE '%' || LOWER( $5 ) || '%')
                )
            GROUP BY agency_id, employee_email, employee_name
            ORDER BY employee_name NULLS LAST, employee_email
            LIMIT $3 OFFSET $4
        """

        async with self.db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sql,
                request.agency_id,
                datetime.now(timezone.utc),
                request.pagination.limit,
                request.pagination.offset,
                request.search_query,
                request.project,
                request.status,
            )

        return CertifiedEmployees(employees=[self._format_row(r) for r in rows])

    @staticmethod
    def _format_row(row: Record) -> CertifiedEmployee:
        """Convert one aggregated result row into a :class:`CertifiedEmployee`.

        ``row["certificates"]`` holds the objects built by ``jsonb_build_object``.
        Their timestamps are expected to be ISO-8601 strings, because jsonb
        serialises timestamps as text, and are parsed with ``isoparse``.
        """
        return CertifiedEmployee(
            name=row["name"],
            email=row["email"],
            agency_id=row["agency_id"],
            certificates=[
                EmployeeCertificate(
                    start_time=parser.isoparse(cert["start_time"]),
                    expiration_time=parser.isoparse(cert["expiration_time"]),
                    project=cert["project"],
                    external_id=cert["external_id"],
                    status=EmployeeCertificateStatus(cert["status"]),
                )
                for cert in row["certificates"]
            ],
        )


@dataclass
class ListAgencyCertificates:
    """List one certificate per project for an agency.

    For every project, the certificate with the latest ``expiration_time`` is
    chosen. It is paired with an auto-renewal flag computed from the
    ``'general'`` prolongation score group.
    """

    db: DB

    async def __call__(self, agency_id: int) -> list[AgencyCertificate]:
        """Return the agency's per-project certificates, ordered by project.

        ``auto_renewal_is_met`` is ``current_score >= target_score``. It is
        ``False`` when no matching score row exists or the comparison is NULL.
        The query runs against the replica pool.
        """
        query = """
        SELECT
            DISTINCT ON (certs.project) certs.project as project,
            certs.expiration_time as expiration_time,
            certs.id as certificate_id,
            COALESCE(
                scores.current_score >= scores.target_score,
                FALSE
            ) as auto_renewal_is_met
        FROM agency_certificates as certs
        LEFT JOIN agency_certificates_prolongation_score as scores
            ON certs.agency_id=scores.agency_id AND certs.project=scores.project AND scores.score_group = 'general'
        WHERE certs.agency_id = $1
        ORDER BY certs.project, certs.expiration_time DESC
        """

        async with self.db.acquire(PoolType.replica) as connection:
            records = await connection.fetch(query, agency_id)

        return list(map(self._to_certificate, records))

    @staticmethod
    def _to_certificate(record: Record) -> AgencyCertificate:
        """Map a single result row onto an :class:`AgencyCertificate`."""
        return AgencyCertificate(
            id=record["certificate_id"],
            project=record["project"],
            expiration_time=record["expiration_time"],
            auto_renewal_is_met=record["auto_renewal_is_met"],
        )


@dataclass
class FetchAgencyCertificatesHistory:
    """Page through an agency's certificates, latest expiration first."""

    db: DB

    async def __call__(
        self, request: FetchAgencyCertificatesHistoryRequest
    ) -> list[AgencyCertificatesHistoryEntry]:
        """Return one page of certificate history for ``request.agency_id``.

        When ``request.project`` is ``None``, ``COALESCE`` makes the project
        condition compare the column with itself, so every project matches.
        Rows whose ``project`` is NULL never match in either case. Rows are
        ordered by ``expiration_time DESC, project`` and paginated with
        ``request.pagination``.
        """
        query = """
        SELECT
            id,
            project,
            start_time,
            expiration_time
        FROM agency_certificates
        WHERE agency_id = $1 AND project = COALESCE ($2, project)
        ORDER BY expiration_time DESC, project
        LIMIT $3 OFFSET $4
        """

        async with self.db.acquire(PoolType.replica) as connection:
            page = request.pagination
            params = (request.agency_id, request.project, page.limit, page.offset)
            records = await connection.fetch(query, *params)

        return [self._to_entry(record) for record in records]

    @staticmethod
    def _to_entry(record: Record) -> AgencyCertificatesHistoryEntry:
        """Map a single result row onto an :class:`AgencyCertificatesHistoryEntry`."""
        return AgencyCertificatesHistoryEntry(
            id=record["id"],
            project=record["project"],
            start_time=record["start_time"],
            expiration_time=record["expiration_time"],
        )


@dataclass
class FetchAgencyCertificateDetails:
    """Fetch Direct certification details for a single agency.

    The details combine rows from four tables:

    * conditions from ``agency_certificates_direct_conditions``,
    * KPIs from ``agency_certificates_direct_kpi``,
    * bonus scores from ``agency_certificates_direct_bonus_scores``,
    * prolongation scores (``project='direct'``) from
      ``agency_certificates_prolongation_score``.

    Each CTE yields at most one aggregated row for the agency. The CTEs are
    FULL-JOINed, so a result row exists as soon as any of the four tables has
    data. Sections with no rows come back as empty lists.
    """

    db: DB

    _sql = """
            WITH temp_conditions as (
                SELECT
                    agency_id as agency_id,
                    array_agg(jsonb_build_object(
                        'name', name,
                        'value', value,
                        'threshold', threshold,
                        'is_met', is_met
                    )) AS conditions
                FROM agency_certificates_direct_conditions
                WHERE agency_id = $1
                GROUP BY agency_id
            ), temp_kpis as (
                SELECT
                    agency_id as agency_id,
                    array_agg(jsonb_build_object(
                        'name', name,
                        'group_name', group_name,
                        'value', value,
                        'max_value', max_value
                    )) AS kpis
                FROM agency_certificates_direct_kpi
                WHERE agency_id = $1
                GROUP BY agency_id
            ), temp_bonuses as (
                SELECT
                    agency_id as agency_id,
                    array_agg(jsonb_build_object(
                        'name', name,
                        'value', value,
                        'threshold', threshold,
                        'is_met', is_met,
                        'score', score
                    ))  AS scores
                FROM agency_certificates_direct_bonus_scores
                WHERE agency_id = $1
                GROUP BY agency_id
            ), temp_scores as (
                SELECT
                    agency_id as agency_id,
                    array_agg(jsonb_build_object(
                        'score_group', score_group,
                        'value', current_score,
                        'threshold', target_score,
                        'is_met', current_score >= target_score
                    )) AS prolongation_scores
                FROM agency_certificates_prolongation_score
                WHERE agency_id = $1 and project='direct'
                GROUP BY agency_id
            )
            SELECT
                coalesce(
                    temp_kpis.agency_id,
                    temp_bonuses.agency_id,
                    temp_conditions.agency_id,
                    temp_scores.agency_id
                ) as agency_id,
                coalesce(temp_conditions.conditions,  array[]::jsonb[]) as conditions,
                coalesce(temp_bonuses.scores,  array[]::jsonb[]) as scores,
                coalesce(temp_kpis.kpis,  array[]::jsonb[]) as kpis,
                coalesce(temp_scores.prolongation_scores,  array[]::jsonb[]) as prolongation_scores
            FROM temp_kpis
                FULL JOIN temp_bonuses USING(agency_id)
                FULL JOIN temp_conditions USING(agency_id)
                FULL JOIN temp_scores USING (agency_id)
        """

    @staticmethod
    def _to_decimal(value) -> Decimal:
        """Convert a decoded JSON number to ``Decimal`` without float artefacts.

        Values inside the ``jsonb_build_object`` results are JSON numbers. A
        float-producing decoder turns e.g. ``0.1`` into a binary float, and
        ``Decimal(0.1)`` would then be ``0.1000000000000000055...``. Going
        through ``repr`` (the shortest round-trip text) gives ``Decimal('0.1')``.
        Any non-float input (int, str, Decimal, ...) is passed to ``Decimal``
        unchanged.
        NOTE(review): assumes the connection's jsonb codec is ``json.loads``-like
        -- confirm against the DB engine setup.
        """
        if isinstance(value, float):
            value = repr(value)
        return Decimal(value)

    @classmethod
    def _format_agency_certificate_details(cls, row: Record) -> AgencyCertificateDetails:
        """Build :class:`AgencyCertificateDetails` from the single result row.

        Decimal-typed fields are converted with :meth:`_to_decimal`. The other
        fields are passed through exactly as decoded from jsonb.
        """
        to_decimal = cls._to_decimal
        return AgencyCertificateDetails(
            agency_id=row["agency_id"],
            conditions=[
                DirectCertificationCondition(
                    name=condition["name"],
                    value=condition["value"],
                    threshold=condition["threshold"],
                    is_met=condition["is_met"],
                )
                for condition in row["conditions"]
            ],
            kpis=[
                DirectKPI(
                    name=kpi["name"],
                    value=to_decimal(kpi["value"]),
                    max_value=to_decimal(kpi["max_value"]),
                    group=kpi["group_name"],
                )
                for kpi in row["kpis"]
            ],
            bonus_points=[
                DirectBonusPoint(
                    name=bonus_point["name"],
                    value=bonus_point["value"],
                    threshold=bonus_point["threshold"],
                    is_met=bonus_point["is_met"],
                    score=to_decimal(bonus_point["score"]),
                )
                for bonus_point in row["scores"]
            ],
            scores=[
                DirectCertificationScores(
                    score_group=score["score_group"],
                    value=to_decimal(score["value"]),
                    threshold=to_decimal(score["threshold"]),
                    is_met=score["is_met"],
                )
                for score in row["prolongation_scores"]
            ],
        )

    async def __call__(
        self, request: AgencyCertificateDetailsRequest
    ) -> AgencyCertificateDetails:
        """Return Direct certification details for ``request.agency_id``.

        Raises:
            AgencyCertificateNotFound: none of the four source tables has rows
                for the agency, so the query returns no row.
        """
        async with self.db.acquire(PoolType.replica) as con:
            row = await con.fetchrow(self._sql, request.agency_id)

        if row is None:
            raise AgencyCertificateNotFound(agency_id=request.agency_id)

        return self._format_agency_certificate_details(row)
