from marshmallow import ValidationError, fields, validate, validates_schema
from marshmallow.validate import Length
from marshmallow_enum import EnumField

from mail.payments.payments.api.schemas.base import (
    CURRENCY_RUB, BasePaginatedRequestSchema, BaseSchema, SuccessResponseSchema
)
from mail.payments.payments.api.schemas.moderation import ModerationSchema
from mail.payments.payments.core.entities.enums import NDS, MerchantOAuthMode, PeriodUnit


class SubscriptionPriceSchema(BaseSchema):
    """A single price entry: amount, currency and region id, all mandatory."""

    price = fields.Decimal(required=True)
    # The only accepted currency value is CURRENCY_RUB.
    currency = fields.String(required=True, validate=validate.Equal(CURRENCY_RUB))
    region_id = fields.Integer(required=True)


class SubscriptionSchema(BaseSchema):
    """Schema of a subscription, used for both request loading and response dumping.

    Fields declared with ``dump_only=True`` are serialization-only.
    ``merchant_oauth_mode`` is read from and written to the external key ``mode``.

    Schema-level rules enforced by :meth:`validate_schema`:

    * ``region_id`` values inside ``prices`` must be unique;
    * ``trial_period_amount`` and ``trial_period_units`` must be provided
      together or omitted together.
    """

    # Output-only identifiers.
    uid = fields.Integer(dump_only=True)
    subscription_id = fields.Integer(dump_only=True)

    # Mandatory description of the subscription and its billing period.
    title = fields.String(required=True)
    fiscal_title = fields.String(required=True)
    nds = EnumField(NDS, required=True, by_value=True)
    period_amount = fields.Integer(required=True)
    period_units = EnumField(PeriodUnit, required=True, by_value=True)
    # Optional trial period; both parts must be given together (see validate_schema).
    trial_period_amount = fields.Integer()
    trial_period_units = EnumField(PeriodUnit, by_value=True)
    # At least one price entry is required.
    prices = fields.List(fields.Nested(SubscriptionPriceSchema), validate=(Length(min=1),), required=True)
    merchant_oauth_mode = EnumField(MerchantOAuthMode, by_value=True, load_from='mode', dump_to='mode')
    # Defaults to False when the key is absent from the input.
    fast_moderation = fields.Boolean(missing=False)

    moderation = fields.Nested(ModerationSchema, dump_only=True)

    created = fields.DateTime(dump_only=True)
    updated = fields.DateTime(dump_only=True)

    enabled_customer_subscriptions = fields.Integer(dump_only=True)
    deleted = fields.Boolean(dump_only=True)

    @validates_schema
    def validate_schema(self, data):
        """Check cross-field invariants of the deserialized subscription.

        Args:
            data: the deserialized payload (accessed as a mapping).

        Raises:
            ValidationError: if two prices share a ``region_id``, or if exactly
                one of ``trial_period_amount`` / ``trial_period_units`` is set.
        """
        # NOTE(review): with marshmallow 2.x defaults (skip_on_field_errors=False) this hook
        # presumably also runs after field validation has failed, so ``prices`` may hold
        # partially deserialized entries (no ``region_id`` key, or not a mapping at all).
        # Skip such entries instead of crashing with KeyError/TypeError; the field-level
        # errors for them are reported by the nested schema itself.
        region_ids = []
        for price in data.get('prices') or []:
            try:
                region_ids.append(price['region_id'])
            except (KeyError, TypeError):
                continue
        if len(set(region_ids)) != len(region_ids):
            raise ValidationError('region_id in prices must be unique')

        trial_period_amount = data.get('trial_period_amount')
        trial_period_units = data.get('trial_period_units')
        # True when exactly one of the two trial fields is present.
        if (trial_period_amount is None) is not (trial_period_units is None):
            raise ValidationError(
                'trial_period_amount and trial_period_units must be not empty or empty simultaneously'
            )


class GetSubscriptionRequestSchema(BasePaginatedRequestSchema):
    """Request schema for fetching subscriptions; declares no fields beyond its base class."""


class GetSubscriptionResponseSchema(SuccessResponseSchema):
    """Success response whose ``data`` is a list of subscriptions."""

    data = fields.Nested(
        SubscriptionSchema,
        many=True,
    )


class SubscriptionResponseSchema(SuccessResponseSchema):
    """Success response whose ``data`` is a single subscription."""

    data = fields.Nested(
        SubscriptionSchema,
    )


# Module-level schema instances.

# Request side.
get_subscription_request_schema = GetSubscriptionRequestSchema()
post_subscription_request_schema = SubscriptionSchema()

# Response side.
get_subscription_response_schema = GetSubscriptionResponseSchema()
subscription_response_schema = SubscriptionResponseSchema()
