import glob
import logging
import os.path
from datetime import datetime, timezone
from operator import itemgetter
from typing import List, Optional, Tuple
from urllib.parse import urlparse

from email_validator import EmailNotValidError, validate_email
from marshmallow import Schema, ValidationError, fields, validate, validates

from maps_adv.common.email_sender import (
    Client as EmailClient,
    EmailSenderError,
    MailingListSource,
)
from maps_adv.geosmb.clients.facade import CouponStatus, FacadeIntClient

from .data_manager import BaseDataManager
from .enums import ScenarioName, SubscriptionStatus
from .exceptions import CurrentCouponDuplicate, CurrentStatusDuplicate


class MessageGroupRequiredFieldsSchema(Schema):
    """Validate the group-level fields shared by every message of a mailing.

    Only ``subject`` is checked: it must be a string of at least one
    character. The field is not declared ``required``.
    """

    subject = fields.String(validate=[validate.Length(min=1)])


class MessageRequiredFieldsSchema(Schema):
    """Validate the per-message fields needed to schedule one email.

    ``recipient`` must be a non-empty string accepted by ``email_validator``.
    ``template_vars`` must be a dict with string keys. If it contains
    ``button_url`` and ``allowed_button_url_domains`` is not ``None``, the
    URL's hostname must be one of those domains.
    """

    recipient = fields.String(validate=validate.Length(min=1))
    template_vars = fields.Dict(keys=fields.String())

    def __init__(
        self,
        *args,
        allowed_button_url_domains: Optional[Tuple[str, ...]],
        **kwargs,
    ):
        """:param allowed_button_url_domains: hostnames accepted for
        ``template_vars["button_url"]``; ``None`` disables the check.
        """
        super().__init__(*args, **kwargs)
        self._allowed_button_url_domains = allowed_button_url_domains

    @validates("recipient")
    def validate_recipient_is_email(self, value: str):
        """Reject recipients that ``validate_email`` does not accept."""
        try:
            validate_email(value)
        except EmailNotValidError:
            raise ValidationError("Invalid email")

    @validates("template_vars")
    def validate_button_url_template_arg(self, value: dict):
        """Check that ``button_url`` (when present) targets an allowed domain."""
        if "button_url" not in value or self._allowed_button_url_domains is None:
            return

        button_url = value["button_url"]
        # Dict values are not type-checked by the field. A non-string here
        # would make urlparse raise AttributeError/TypeError; marshmallow
        # does not catch those, so they would escape Schema.validate().
        if not isinstance(button_url, str):
            raise ValidationError("button_url is invalid")

        try:
            parsed_button_url = urlparse(button_url)
        except ValueError:
            raise ValidationError("button_url is invalid")

        if parsed_button_url.hostname not in self._allowed_button_url_domains:
            raise ValidationError("domain not allowed for button_url")


class EmailTemplatesManager:
    """Registry of HTML email templates loaded from a single directory.

    Every ``<name>.tpl.html`` file directly inside the directory becomes the
    template ``<name>``; its content is read as UTF-8 text.
    """

    TEMPLATE_EXT = ".tpl.html"

    def __init__(self, templates_dir: str):
        self._templates_dir = templates_dir
        # Stays None until load_templates() runs, so lookups can detect it.
        self._templates = None

    def load_templates(self) -> "EmailTemplatesManager":
        """Read all templates from disk, replacing any previously loaded set.

        :return: ``self``, so construction and loading can be chained.
        """
        templates = {}
        # Escape the directory so glob metacharacters in it ("[", "*", "?")
        # are matched literally; only the file-name part is a pattern.
        pattern = os.path.join(
            glob.escape(self._templates_dir), f"*{self.TEMPLATE_EXT}"
        )
        for fpath in glob.glob(pattern):
            template_key = os.path.basename(fpath).rsplit(
                self.TEMPLATE_EXT, maxsplit=1
            )[0]
            with open(fpath, "rt", encoding="utf-8") as f:
                templates[template_key] = f.read()

        # Publish only a fully read set, never a partially filled dict.
        self._templates = templates
        return self

    def __getitem__(self, item: str) -> str:
        """Return the content of template ``item``.

        :raises RuntimeError: if ``load_templates`` has not been called yet.
        :raises KeyError: if no template with that name was loaded.
        """
        if self._templates is None:
            raise RuntimeError("Template files not loaded")

        return self._templates[item]


class MessagesGroupInvalid(Exception):
    """Raised when a whole group of email messages cannot be sent.

    ``args[0]`` holds the human-readable reason, which the caller records as
    the processing error of every message in the group.
    """


class Domain:
    """Service-level operations on subscriptions and queued email messages.

    Subscription state is read and written through the data manager, coupon
    statuses come from the facade client, and email messages are scheduled
    as promo campaigns through the email sender client.
    """

    __slots__ = [
        "_dm",
        "_facade_client",
        "_email",
        "_email_templates",
        "_schedule_promo_campaign_params",
        "_allowed_button_url_domains",
    ]

    _dm: BaseDataManager
    _facade_client: FacadeIntClient
    _email: EmailClient
    _email_templates: EmailTemplatesManager
    _schedule_promo_campaign_params: dict
    _allowed_button_url_domains: Optional[Tuple[str, ...]]

    def __init__(
        self,
        *,
        dm: BaseDataManager,
        facade_client: FacadeIntClient,
        email_client: EmailClient,
        schedule_promo_campaign_params: dict,
        allowed_button_url_domains: Optional[Tuple[str, ...]],
        templates_dir: Optional[str] = None,
    ):
        """Wire dependencies and eagerly load the email templates.

        :param dm: data manager used for all persistence.
        :param facade_client: client used to look up coupon statuses.
        :param email_client: email sender client, used as an async context
            manager while sending.
        :param schedule_promo_campaign_params: extra keyword arguments passed
            verbatim to ``schedule_promo_campaign``.
        :param allowed_button_url_domains: hostnames accepted in the
            ``button_url`` template variable; ``None`` disables the check.
        :param templates_dir: directory with ``*.tpl.html`` templates;
            defaults to ``templates`` under the current working directory.
        """
        self._dm = dm

        self._facade_client = facade_client

        self._email = email_client
        if templates_dir is None:
            templates_dir = os.path.join(os.getcwd(), "templates")
        self._email_templates = EmailTemplatesManager(
            templates_dir
        ).load_templates()

        self._schedule_promo_campaign_params = schedule_promo_campaign_params
        self._allowed_button_url_domains = allowed_button_url_domains

    async def list_scenarios(self, *, biz_id: int) -> List[dict]:
        """Return the data manager's scenario list for business ``biz_id``."""
        return await self._dm.list_scenarios(biz_id=biz_id)

    async def create_subscription(
        self,
        *,
        biz_id: int,
        scenario_name: ScenarioName,
        coupon_id: Optional[int] = None,
    ) -> dict:
        """Create a subscription of ``biz_id`` to a scenario.

        The subscription is optionally bound to ``coupon_id``.
        """
        return await self._dm.create_subscription(
            biz_id=biz_id, scenario_name=scenario_name, coupon_id=coupon_id
        )

    async def retrieve_subscription(self, *, subscription_id: int, biz_id: int) -> dict:
        """Return subscription ``subscription_id`` of business ``biz_id``."""
        return await self._dm.retrieve_subscription(
            subscription_id=subscription_id, biz_id=biz_id
        )

    async def update_subscription_status(
        self, *, subscription_id: int, biz_id: int, status: SubscriptionStatus
    ):
        """Set a new status on the subscription.

        :raises CurrentStatusDuplicate: if ``status`` is already current.
        """
        subscription = await self._dm.retrieve_subscription_current_state(
            subscription_id=subscription_id, biz_id=biz_id
        )

        if subscription["status"] == status:
            raise CurrentStatusDuplicate()

        await self._dm.update_subscription_status(
            subscription_id=subscription_id, biz_id=biz_id, status=status
        )

    async def replace_subscription_coupon(
        self, *, subscription_id: int, biz_id: int, coupon_id: Optional[int] = None
    ):
        """Bind a different coupon (or none) to the subscription.

        The status becomes ACTIVE when a truthy ``coupon_id`` is given and
        COMPLETED otherwise.

        :return: the new subscription status.
        :raises CurrentCouponDuplicate: if ``coupon_id`` is already bound.
        """
        subscription = await self._dm.retrieve_subscription_current_state(
            subscription_id=subscription_id, biz_id=biz_id
        )

        if subscription["coupon_id"] == coupon_id:
            raise CurrentCouponDuplicate()

        new_status = (
            SubscriptionStatus.ACTIVE if coupon_id else SubscriptionStatus.COMPLETED
        )

        await self._dm.replace_subscription_coupon(
            subscription_id=subscription_id,
            biz_id=biz_id,
            coupon_id=coupon_id,
            status=new_status,
        )

        return new_status

    async def iter_subscriptions_for_export(self, chunk_size: int):
        """Yield, chunk by chunk, the subscriptions whose coupon is running.

        For every chunk from the data manager, coupon statuses are fetched
        from the facade. Subscriptions with a finished coupon are marked
        COMPLETED. Subscriptions whose coupon the facade did not return are
        logged and not exported.
        """
        async for subscriptions in self._dm.iter_subscriptions_for_export(chunk_size):
            coupon_ids = [subscription["coupon_id"] for subscription in subscriptions]
            coupons = await self._facade_client.list_coupons_statuses(
                coupon_ids=coupon_ids
            )

            unknown_coupon_ids = set(coupon_ids) - {
                coupon["coupon_id"] for coupon in coupons
            }
            if unknown_coupon_ids:
                bad_subscriptions = sorted(
                    (
                        dict(
                            subscription_id=sub["subscription_id"],
                            coupon_id=sub["coupon_id"],
                        )
                        for sub in subscriptions
                        if sub["coupon_id"] in unknown_coupon_ids
                    ),
                    key=itemgetter("subscription_id"),
                )

                logging.getLogger(__name__).warning(
                    "Unknown coupons for facade service: %s", bad_subscriptions
                )

            active_coupon_ids = {
                coupon["coupon_id"]
                for coupon in coupons
                if coupon["status"] == CouponStatus.RUNNING
            }
            completed_coupon_ids = {
                coupon["coupon_id"]
                for coupon in coupons
                if coupon["status"] == CouponStatus.FINISHED
            }

            subscription_ids_for_close = [
                sub["subscription_id"]
                for sub in subscriptions
                if sub["coupon_id"] in completed_coupon_ids
            ]
            await self._dm.update_subscriptions_statuses(
                subscription_ids=subscription_ids_for_close,
                status=SubscriptionStatus.COMPLETED,
            )

            subscriptions_for_export = [
                sub for sub in subscriptions if sub["coupon_id"] in active_coupon_ids
            ]

            yield subscriptions_for_export

    async def process_unsent_emails(self):
        """Schedule every unprocessed email message group via the email sender.

        Per group:

        - if the group itself is invalid, every message is marked processed
          with the group's error;
        - invalid messages are marked processed with their own errors;
        - the valid messages are scheduled as one promo campaign and marked
          processed, with the EmailSender error if scheduling failed.

        On any other exception from the email client, only a log record is
        written. The valid messages are not marked processed in that case;
        presumably they are retried on a later run.
        """
        now = datetime.now(tz=timezone.utc)
        message_groups = await self._dm.list_unprocessed_email_messages()
        async with self._email as email_client:
            for message_group in message_groups:
                # Check common group params
                try:
                    (
                        time_to_send,
                        subject,
                        template_content,
                    ) = self._parse_email_messages_group(message_group)
                except MessagesGroupInvalid as e:
                    # Record the same group-level reason for every message.
                    await self._dm.mark_messages_processed(
                        dict.fromkeys(
                            map(itemgetter("id"), message_group["messages"]), e.args[0]
                        ),
                        now,
                        None,
                    )
                    continue

                valid_messages, invalid_messages = self._validate_email_messages(
                    message_group["messages"]
                )

                send_result = None
                # message id -> error text (None means sent successfully)
                process_results = invalid_messages
                if valid_messages:
                    valid_message_ids = [message["id"] for message in valid_messages]
                    try:
                        send_result = await email_client.schedule_promo_campaign(
                            subject=subject,
                            body=template_content,
                            mailing_list_source=MailingListSource.IN_PLACE,
                            mailing_list_params=[
                                {
                                    "email": message["recipient"],
                                    "params": message["template_vars"],
                                }
                                for message in valid_messages
                            ],
                            tags=tuple(
                                message["message_anchor"] for message in valid_messages
                            ),
                            schedule_dt=time_to_send,
                            **self._schedule_promo_campaign_params,
                        )
                    except EmailSenderError as e:
                        process_results.update(
                            dict.fromkeys(
                                valid_message_ids, f"EmailSender error: {str(e)}"
                            )
                        )
                    except Exception as e:
                        # Unexpected failure: keep the traceback in the log and
                        # leave the valid messages unprocessed.
                        logging.getLogger("geosmb.scenarist.mailing_emails").exception(
                            "Failed to send email messages %s because email_client raised %s",  # noqa
                            ",".join(map(str, valid_message_ids)),
                            e.__class__.__name__,
                        )
                    else:
                        process_results.update(dict.fromkeys(valid_message_ids))

                # Mark messages as processed
                if process_results:
                    await self._dm.mark_messages_processed(
                        process_results, now, send_result
                    )

    @staticmethod
    def _format_validation_errors(validation_errors: dict) -> str:
        """Render marshmallow errors as ``"field: err1, err2;field2: err"``.

        Field errors are normally lists of strings. Any other value (e.g. a
        nested dict) is rendered with ``str`` so that ``str.join`` cannot fail.
        """
        parts = []
        for field_name, field_errors in validation_errors.items():
            if isinstance(field_errors, (list, tuple)):
                rendered = ", ".join(map(str, field_errors))
            else:
                rendered = str(field_errors)
            parts.append(f"{field_name}: {rendered}")
        return ";".join(parts)

    def _parse_email_messages_group(
        self, message_group: dict
    ) -> Tuple[datetime, str, str]:
        """Validate the group-wide parameters of a message group.

        :return: ``(time_to_send, subject, template_content)``.
        :raises MessagesGroupInvalid: if ``time_to_send`` is already in the
            past, the group fields fail validation, or the template is
            unknown.
        """
        if message_group["time_to_send"] < datetime.now(tz=timezone.utc):
            logging.getLogger("geosmb.scenarist.mailing_emails").error(
                "Too late to send email messages %s",
                ",".join(str(message["id"]) for message in message_group["messages"]),
            )
            raise MessagesGroupInvalid("Too late to send")

        validation_errors = MessageGroupRequiredFieldsSchema().validate(message_group)

        if validation_errors:
            raise MessagesGroupInvalid(
                self._format_validation_errors(validation_errors)
            )

        try:
            template_content = self._email_templates[message_group["template_name"]]
        except KeyError:
            raise MessagesGroupInvalid("Failed to get template content")

        return (
            message_group["time_to_send"],
            message_group["subject"],
            template_content,
        )

    def _validate_email_messages(self, messages: List[dict]) -> Tuple[list, dict]:
        """Split messages into valid ones and an ``{id: error_text}`` mapping.

        :return: ``(valid_messages, invalid_messages)``.
        """
        valid_messages, invalid_messages = [], {}
        # One schema instance serves all messages; it only holds the domains.
        schema = MessageRequiredFieldsSchema(
            allowed_button_url_domains=self._allowed_button_url_domains
        )

        for message in messages:
            validation_errors = schema.validate(message)
            if validation_errors:
                invalid_messages[message["id"]] = self._format_validation_errors(
                    validation_errors
                )
            else:
                valid_messages.append(message)

        return valid_messages, invalid_messages
