from abc import abstractmethod
from operator import itemgetter
from typing import Dict, List, Optional, Set

import pytz
from asyncpg.connection import Connection
from asyncpg.exceptions import ForeignKeyViolationError

from maps_adv.billing_proxy.lib.db.enums import (
    CurrencyType,
    DbEnumConverter,
    PaymentType,
)
from smb.common.pgswim import PoolType

from . import sqls
from .base import BaseDataManager
from .exceptions import (
    AgencyDoesNotExist,
    ClientsDoNotExist,
    ClientDoesNotExist,
)


class AbstractClientsDataManager(BaseDataManager):
    """Storage interface for clients, agencies and their contracts.

    The signatures below mirror those of ``ClientsDataManager`` so that code
    typed against this interface sees the parameters the implementation
    actually accepts.
    """

    @abstractmethod
    async def client_exists(self, client_id: int) -> bool:
        """Return whether a client with ``client_id`` exists."""
        raise NotImplementedError()

    @abstractmethod
    async def agency_exists(self, agency_id: int) -> bool:
        """Return whether an agency with ``agency_id`` exists."""
        raise NotImplementedError()

    @abstractmethod
    async def find_client_locally(self, client_id: int) -> Optional[Dict]:
        """Return the locally stored client as a dict, or ``None`` if absent."""
        raise NotImplementedError()

    @abstractmethod
    async def upsert_client(
        self,
        id: int,
        name: str,
        email: str,
        con: Optional[Connection] = None,
        **kwargs,
    ) -> Dict:
        """Insert or update a client and return the stored row as a dict.

        Extra keyword arguments are additional client fields to write
        (presumably column names - confirm against ``sqls.upsert_client``).
        ``con`` lets the caller run the statement on its own connection.
        """
        raise NotImplementedError()

    @abstractmethod
    async def sync_client_contracts(
        self, client_id: int, contracts: List[dict], con: Connection = None
    ) -> None:
        """Reconcile the stored contracts of ``client_id`` with ``contracts``."""
        raise NotImplementedError()

    @abstractmethod
    async def insert_client(
        self,
        client_id: int,
        name: str,
        email: str,
        phone: str,
        account_manager_id: Optional[int],
        domain: str,
        partner_agency_id: Optional[int],
        has_accepted_offer: Optional[bool],
        created_from_cabinet: bool = False,
    ) -> Dict:
        """Insert a new client and return the created row as a dict."""
        raise NotImplementedError()

    @abstractmethod
    async def set_account_manager_for_client(
        self, client_id: int, account_manager_id: int
    ) -> None:
        """Assign ``account_manager_id`` to the client ``client_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def set_representatives_for_client(
        self, client_id: int, representatives: List[int]
    ) -> None:
        """Replace the representatives of the client ``client_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_agencies(self) -> List[Dict]:
        """Return all agencies as a list of dicts."""
        raise NotImplementedError()

    @abstractmethod
    async def list_agency_clients(self, agency_id: Optional[int]) -> List[Dict]:
        """Return the clients of ``agency_id`` (which may be ``None``)."""
        raise NotImplementedError()

    @abstractmethod
    async def list_account_manager_clients(self, account_manager_id: int) -> List[Dict]:
        """Return the clients managed by ``account_manager_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_client_ids(self):
        """Return the ids of stored clients."""
        raise NotImplementedError()

    @abstractmethod
    async def add_clients_to_agency(
        self, client_ids: List[int], agency_id: Optional[int]
    ) -> None:
        """Attach ``client_ids`` to ``agency_id`` (which may be ``None``)."""
        raise NotImplementedError()

    @abstractmethod
    async def remove_clients_from_agency(
        self, client_ids: List[int], agency_id: Optional[int]
    ) -> None:
        """Detach ``client_ids`` from ``agency_id`` (which may be ``None``)."""
        raise NotImplementedError()

    @abstractmethod
    async def client_is_in_agency(
        self, client_id: int, agency_id: Optional[int]
    ) -> bool:
        """Return whether ``client_id`` belongs to ``agency_id``."""
        raise NotImplementedError()

    @abstractmethod
    async def find_contract(self, contract_id: int) -> Optional[dict]:
        """Return the contract with ``contract_id`` as a dict, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_contacts_by_client(
        self, client_id: int, con: Optional[Connection] = None
    ) -> List[dict]:
        """Return all contracts of ``client_id`` as dicts."""
        raise NotImplementedError()

    @abstractmethod
    async def set_client_has_accepted_offer(self, client_id: int, is_agency: bool):
        """Mark the client (or agency, when ``is_agency``) as having accepted the offer."""
        raise NotImplementedError()

    @abstractmethod
    async def list_clients(self) -> List[Dict]:
        """Return all clients as a list of dicts."""
        raise NotImplementedError()

    @abstractmethod
    async def list_clients_with_agencies(self, client_ids: List[int]) -> List[int]:
        """Return those of ``client_ids`` that are attached to some agency."""
        raise NotImplementedError()

    @abstractmethod
    async def list_clients_with_orders_with_agency(
        self,
        client_ids: List[int],
        agency_id: Optional[int],
    ) -> List[int]:
        """Return those of ``client_ids`` that have orders with ``agency_id``."""
        raise NotImplementedError()


class ClientsDataManager(AbstractClientsDataManager):
    """asyncpg-backed implementation of ``AbstractClientsDataManager``.

    SQL text comes from the ``sqls`` module; this class binds parameters,
    turns records into dicts and maps database errors onto domain exceptions.
    """

    # Fields of an already stored contract that ``sync_client_contracts``
    # overwrites with the incoming values, in this order.
    _SYNCED_CONTRACT_FIELDS = (
        "external_id",
        "currency",
        "is_active",
        "date_start",
        "date_end",
        "payment_type",
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): not referenced by any method of this class; presumably
        # used by subclasses or external code - confirm before removing.
        self._moscow_tz = pytz.timezone("Europe/Moscow")

    @staticmethod
    def _process_data(data):
        """Return a contract as a dict with enum-typed ``currency``/``payment_type``.

        A ``dict`` argument is modified in place and returned; any other
        mapping-like record is copied into a new dict first.
        """
        if not isinstance(data, dict):
            data = dict(data)

        data["currency"] = DbEnumConverter.to_enum(CurrencyType, data["currency"])
        data["payment_type"] = DbEnumConverter.to_enum(
            PaymentType, data["payment_type"]
        )

        return data

    async def _client_or_agency_exists(self, id_: int, is_agency: bool):
        """Run the existence query for ``id_`` with the given agency flag."""
        async with self.connection() as con:
            return await con.fetchval(sqls.client_or_agency_exists, id_, is_agency)

    async def client_exists(self, client_id: int) -> bool:
        """Return whether a client (agency flag ``False``) with this id exists."""
        return await self._client_or_agency_exists(client_id, False)

    async def agency_exists(self, agency_id: int) -> bool:
        """Return whether an agency (agency flag ``True``) with this id exists."""
        return await self._client_or_agency_exists(agency_id, True)

    async def find_client_locally(self, client_id: int) -> Optional[Dict]:
        """Return the stored client as a dict, or ``None`` if there is no row."""
        async with self.connection() as con:
            client = await con.fetchrow(sqls.find_client_by_id, client_id)

        return dict(client) if client else None

    async def upsert_client(
        self,
        id: int,
        name: str,
        email: str,
        con: Connection = None,
        **kwargs,
    ) -> dict:
        """Insert or update a client and return the resulting row as a dict.

        ``id``, ``name`` and ``email`` are always written; extra keyword
        arguments are forwarded as additional fields. The statement is built by
        ``sqls.upsert_client`` from the field names, and the values are bound
        positionally in the same order. Runs on ``con`` when one is given.
        """
        kwargs.update({"id": id, "name": name, "email": email})
        async with self.connection(con) as con:
            keys = list(map(str, kwargs.keys()))
            client = await con.fetchrow(
                sqls.upsert_client(keys, **kwargs),
                *kwargs.values(),
            )
            # NOTE(review): raises TypeError if the statement returns no row -
            # confirm the upsert always RETURNs one.
            return dict(client)

    async def list_client_ids(self):
        """Return the single value produced by ``sqls.list_client_ids``.

        Reads from a replica. The value is presumably an aggregated collection
        of client ids - confirm against the query.
        """
        async with self.connection(type=PoolType.replica) as con:
            return await con.fetchval(sqls.list_client_ids)

    async def sync_client_contracts(
        self, client_id: int, contracts: List[dict], con: Connection = None
    ) -> None:
        """Reconcile the stored contracts of ``client_id`` with ``contracts``.

        * A stored contract present in ``contracts`` gets its
          ``_SYNCED_CONTRACT_FIELDS`` overwritten with the incoming values.
        * A stored contract absent from ``contracts`` is marked inactive (it
          is kept, not deleted).
        * An incoming contract not stored yet is added for ``client_id`` with
          ``preferred`` set to ``False``.

        Every resulting contract is written with ``sqls.upsert_contract``. The
        read and the write share one connection (``con`` when given), so a
        caller-supplied transaction also covers the read.
        """
        actual_contracts_by_id = {contract["id"]: contract for contract in contracts}

        async with self.connection(con) as con:
            existing_contracts = await self.list_contacts_by_client(
                client_id, con=con
            )
            existing_contract_ids = {
                contract["id"] for contract in existing_contracts
            }

            for existing_contract in existing_contracts:
                actual_contract = actual_contracts_by_id.get(existing_contract["id"])
                if actual_contract is None:
                    existing_contract["is_active"] = False
                else:
                    for field in self._SYNCED_CONTRACT_FIELDS:
                        existing_contract[field] = actual_contract[field]

            for contract_id, contract_data in actual_contracts_by_id.items():
                if contract_id in existing_contract_ids:
                    continue
                existing_contracts.append(
                    {
                        "id": contract_data["id"],
                        "external_id": contract_data["external_id"],
                        "client_id": client_id,
                        "currency": contract_data["currency"],
                        "is_active": contract_data["is_active"],
                        "date_start": contract_data["date_start"],
                        "date_end": contract_data["date_end"],
                        "payment_type": contract_data["payment_type"],
                        "preferred": False,
                    }
                )

            await con.executemany(
                sqls.upsert_contract,
                [
                    (
                        contract["id"],
                        contract["external_id"],
                        contract["client_id"],
                        DbEnumConverter.from_enum(contract["currency"]),
                        contract["is_active"],
                        contract["date_start"],
                        contract["date_end"],
                        DbEnumConverter.from_enum(contract["payment_type"]),
                        contract["preferred"],
                    )
                    for contract in existing_contracts
                ],
            )

    async def insert_client(
        self,
        client_id: int,
        name: str,
        email: str,
        phone: str,
        account_manager_id: Optional[int],
        domain: str,
        partner_agency_id: Optional[int],
        has_accepted_offer: Optional[bool],
        created_from_cabinet: bool = False,
    ) -> dict:
        """Insert a new client and return the created row as a dict."""
        async with self.connection() as con:
            client = await con.fetchrow(
                sqls.insert_client,
                client_id,
                name,
                email,
                phone,
                account_manager_id,
                domain,
                partner_agency_id,
                has_accepted_offer,
                created_from_cabinet,
            )

        return dict(client)

    async def set_account_manager_for_client(
        self, client_id: int, account_manager_id: int
    ) -> None:
        """Assign ``account_manager_id`` to the client.

        Raises:
            ClientsDoNotExist: the update returned no row.
        """
        async with self.connection() as con:
            updated = await con.fetchrow(
                sqls.set_account_manager_for_client, client_id, account_manager_id
            )
            if not updated:
                raise ClientsDoNotExist(client_ids=[client_id])

    async def set_representatives_for_client(
        self, client_id: int, representatives: List[int]
    ) -> None:
        """Replace the client's representatives with ``representatives``.

        Raises:
            ClientsDoNotExist: the update returned no row.
        """
        async with self.connection() as con:
            updated = await con.fetchrow(
                sqls.set_representatives_for_client, client_id, representatives
            )
            if not updated:
                raise ClientsDoNotExist(client_ids=[client_id])

    async def list_agencies(self) -> List[Dict]:
        """Return all agencies as a list of dicts."""
        async with self.connection() as con:
            results = await con.fetch(sqls.list_agencies)

        return list(map(dict, results))

    async def list_agency_clients(self, agency_id: Optional[int]) -> List[Dict]:
        """Return the clients of ``agency_id`` as a list of dicts."""
        async with self.connection() as con:
            results = await con.fetch(sqls.list_agency_clients, agency_id)

        return list(map(dict, results))

    async def list_account_manager_clients(self, account_manager_id: int) -> List[Dict]:
        """Return the clients of ``account_manager_id`` as a list of dicts."""
        async with self.connection() as con:
            results = await con.fetch(
                sqls.list_account_manager_clients, account_manager_id
            )

        return list(map(dict, results))

    async def add_clients_to_agency(
        self, client_ids: List[int], agency_id: Optional[int]
    ) -> None:
        """Attach ``client_ids`` to ``agency_id`` in a single transaction.

        Does nothing for an empty ``client_ids``.

        Raises:
            AgencyDoesNotExist: the agency foreign key is violated.
            ClientsDoNotExist: the client foreign key is violated; carries
                every requested id that is missing, sorted.
        """
        if not client_ids:
            return

        try:
            async with self.connection() as con:
                async with con.transaction():
                    await con.executemany(
                        sqls.upsert_agency_client,
                        [(agency_id, client_id) for client_id in client_ids],
                    )
        except ForeignKeyViolationError as exc:
            if exc.constraint_name == "fk_agency_clients_agency_id_clients":
                raise AgencyDoesNotExist(agency_id=agency_id) from exc
            if exc.constraint_name == "fk_agency_clients_client_id_clients":
                # Look up which requested ids are absent so all are reported.
                async with self.connection() as con:
                    missing_ids = await self._list_missing_client_ids(
                        con, set(client_ids)
                    )
                raise ClientsDoNotExist(client_ids=sorted(missing_ids)) from exc
            raise

    async def remove_clients_from_agency(
        self, client_ids: List[int], agency_id: Optional[int]
    ) -> None:
        """Detach ``client_ids`` from ``agency_id``.

        With ``agency_id=None`` the ``remove_clients_from_internal`` query is
        used instead of the agency-specific one.
        """
        if agency_id is not None:
            sql = sqls.remove_clients_from_agency
            params = (client_ids, agency_id)
        else:
            sql = sqls.remove_clients_from_internal
            params = (client_ids,)
        async with self.connection() as con:
            await con.execute(sql, *params)

    @staticmethod
    async def _list_missing_client_ids(
        con: "Connection", client_ids: Set[int]
    ) -> Set[int]:
        """Return the subset of ``client_ids`` not found by the existence query."""
        existent = await con.fetchval(sqls.list_existing_client_ids, client_ids)

        # An aggregate over zero matching rows can come back as NULL/None.
        return client_ids - set(existent or ())

    async def client_is_in_agency(
        self, client_id: int, agency_id: Optional[int]
    ) -> bool:
        """Return whether ``client_id`` belongs to ``agency_id``.

        With ``agency_id=None`` the ``client_is_in_internal`` query is used.
        """
        if agency_id is not None:
            sql = sqls.client_is_in_agency
            params = (client_id, agency_id)
        else:
            sql = sqls.client_is_in_internal
            params = (client_id,)

        async with self.connection() as con:
            return await con.fetchval(sql, *params)

    async def find_contract(self, contract_id: int) -> Optional[dict]:
        """Return the contract with enum-typed fields, or ``None`` if absent."""
        async with self.connection() as con:
            contract = await con.fetchrow(sqls.find_contract_by_id, contract_id)

        return self._process_data(contract) if contract else None

    async def list_contacts_by_client(
        self, client_id: int, con: Optional[Connection] = None
    ) -> List[dict]:
        """Return all contracts of ``client_id`` with enum-typed fields.

        Runs on ``con`` when one is given.
        """
        async with self.connection(con) as con:
            contracts = await con.fetch(sqls.list_contacts_by_client, client_id)

        return list(map(self._process_data, contracts))

    async def set_client_has_accepted_offer(self, client_id: int, is_agency: bool):
        """Mark the client/agency ``client_id`` as having accepted the offer.

        Raises:
            AgencyDoesNotExist: no row updated and ``is_agency`` is true.
            ClientDoesNotExist: no row updated and ``is_agency`` is false.
        """
        async with self.connection() as con:
            updated = await con.fetchrow(
                sqls.set_client_has_accepted_offer, client_id, is_agency
            )
            if not updated:
                if is_agency:
                    raise AgencyDoesNotExist(agency_id=client_id)
                else:
                    raise ClientDoesNotExist(client_id=client_id)

    async def list_clients(self) -> List[Dict]:
        """Return all clients as a list of dicts."""
        async with self.connection() as con:
            results = await con.fetch(sqls.list_clients)

        return list(map(dict, results))

    async def list_clients_with_agencies(self, client_ids: List[int]) -> List[int]:
        """Return the ``client_id`` column of ``sqls.list_clients_with_agencies``."""
        async with self.connection() as con:
            clients_with_agencies = await con.fetch(
                sqls.list_clients_with_agencies, client_ids
            )
        return list(map(itemgetter("client_id"), clients_with_agencies))

    async def list_clients_with_orders_with_agency(
        self, client_ids: List[int], agency_id: Optional[int]
    ) -> List[int]:
        """Return the ``client_id`` column of the orders-with-agency query."""
        async with self.connection() as con:
            clients_with_orders_with_agency = await con.fetch(
                sqls.list_clients_with_orders_with_agency, client_ids, agency_id
            )
        return list(map(itemgetter("client_id"), clients_with_orders_with_agency))
