import enum

from django.db import models, transaction


class FixOrderItemTariffParams(models.Model):
    """Parameters of a fixed-cost order item tariff.

    Instances compare equal by ``cost`` (value equality), not by primary key.
    """

    cost = models.DecimalField(max_digits=10, decimal_places=2)

    class Meta:
        db_table = 'order_item_tariff_fix_params'

    def __eq__(self, other):
        # For foreign operands (e.g. ``params == None``) let Python fall back to
        # its default comparison instead of raising AttributeError on ``other.cost``.
        if not isinstance(other, FixOrderItemTariffParams):
            return NotImplemented
        return self.cost == other.cost

    # Defining __eq__ implicitly sets __hash__ to None, which would make
    # instances unhashable and break Django code that keeps model instances in
    # sets (e.g. the deletion collector behind ``delete()``). Restore Django's
    # pk-based hash. Note it is not derived from ``cost``: equal-cost rows with
    # different primary keys hash differently.
    __hash__ = models.Model.__hash__

    def __str__(self):
        return '<FixOrderItemTariffParams: cost={}>'.format(self.cost)


class PerMinuteOrderItemTariffParams(models.Model):
    """Parameters of a per-minute order item tariff.

    Instances compare equal by ``cost_per_minute`` (value equality), not by
    primary key.
    """

    cost_per_minute = models.DecimalField(max_digits=10, decimal_places=2)

    class Meta:
        db_table = 'order_item_tariff_per_minute_params'

    def __eq__(self, other):
        # For foreign operands (e.g. ``params == None``) let Python fall back to
        # its default comparison instead of raising AttributeError.
        if not isinstance(other, PerMinuteOrderItemTariffParams):
            return NotImplemented
        return self.cost_per_minute == other.cost_per_minute

    # Defining __eq__ implicitly sets __hash__ to None, which would make
    # instances unhashable and break Django code that keeps model instances in
    # sets (e.g. the deletion collector behind ``delete()``). Restore Django's
    # pk-based hash. Note it is not derived from ``cost_per_minute``: equal
    # rows with different primary keys hash differently.
    __hash__ = models.Model.__hash__

    def __str__(self):
        return '<PerMinuteOrderItemTariffParams: cost_per_minute={}>'.format(self.cost_per_minute)


class OrderItemTariff(models.Model):
    """Tariff of an order item: a type tag plus exactly one matching params row.

    ``type`` selects which of the nullable one-to-one links (``fix_params`` or
    ``per_minute_params``) holds the parameters. The database CHECK constraint
    below enforces that exactly that link, and no other, is set.
    """

    class Type(enum.Enum):
        FIX = 'fix'
        PER_MINUTE = 'per_minute'

    type = models.CharField(max_length=16, choices=[(x.value, x.name) for x in Type])

    fix_params = models.OneToOneField(
        FixOrderItemTariffParams,
        null=True,
        on_delete=models.CASCADE,
        related_name='order_item',
    )
    per_minute_params = models.OneToOneField(
        PerMinuteOrderItemTariffParams,
        null=True,
        on_delete=models.CASCADE,
        related_name='order_item',
    )

    class Meta:
        db_table = 'order_item_tariff'
        # Exclusive arc: the params link matching ``type`` is set, and exactly
        # one link is set overall. The OR group must be parenthesised: AND binds
        # tighter than OR in SQL, so without the parentheses the "exactly one"
        # clause would only apply to the per_minute branch.
        # NOTE(review): changing this expression presumably needs a new
        # migration to take effect on existing databases.
        db_constraints = {
            'exclusive_arc_chk': (
                '''
                CHECK (
                  (
                    (type = 'fix' AND fix_params_id IS NOT NULL)
                    OR
                    (type = 'per_minute' AND per_minute_params_id IS NOT NULL)
                  )
                  AND (
                    CASE WHEN fix_params_id IS NOT NULL THEN 1 ELSE 0 END
                    +
                    CASE WHEN per_minute_params_id IS NOT NULL THEN 1 ELSE 0 END
                  ) = 1
                )
                '''
            ),
        }

    def __eq__(self, other):
        """Value equality: same tariff type and equal params objects."""
        # For foreign operands let Python fall back instead of raising
        # AttributeError on ``other.get_type()``.
        if not isinstance(other, OrderItemTariff):
            return NotImplemented
        return self.get_type() is other.get_type() and self.get_params() == other.get_params()

    # Defining __eq__ implicitly sets __hash__ to None, which would make
    # instances unhashable and break Django code that keeps model instances in
    # sets (e.g. the deletion collector behind ``delete()``). Restore Django's
    # pk-based hash. It is not derived from the compared values.
    __hash__ = models.Model.__hash__

    def __ne__(self, other):
        return not self == other

    def __repr__(self):
        try:
            type_ = self.get_type()
        except ValueError:
            # Unset/unknown ``type`` (e.g. a freshly constructed instance):
            # keep repr usable for debugging instead of raising.
            return '<OrderItemTariff: type={!r}, params=?>'.format(self.type)
        return '<OrderItemTariff: type={}, params={}>'.format(
            type_.name,
            self.get_params(),
        )

    def get_type(self):
        """Return ``type`` as a ``Type`` member.

        Raises:
            ValueError: if ``type`` is not one of the ``Type`` values.
        """
        return self.Type(self.type)

    def get_params(self):
        """Return the params object linked for this tariff's type.

        May be None, since both links are nullable.

        Raises:
            ValueError: if ``type`` is not one of the ``Type`` values.
        """
        type_ = self.get_type()
        if type_ is self.Type.FIX:
            params = self.fix_params
        elif type_ is self.Type.PER_MINUTE:
            params = self.per_minute_params
        else:
            raise RuntimeError('unreachable: {}'.format(type_))
        return params

    def save(self, *args, **kwargs):
        """Save the tariff, first inserting any not-yet-saved params object.

        Everything runs in one transaction on the target database alias. Only
        that alias is forwarded to the params' ``save()``. The other options
        (``force_insert``, ``force_update``, ``update_fields``) describe the
        tariff row and are not valid for a freshly created params row.
        """
        # Model.save(force_insert, force_update, using, update_fields):
        # ``using`` may also arrive positionally as the third argument.
        using = kwargs.get('using', args[2] if len(args) > 2 else None)
        with transaction.atomic(using=using):
            if self.fix_params_id is None and self.fix_params is not None:
                self.fix_params.save(using=using)
                # Re-assign so fix_params_id picks up the newly generated pk.
                self.fix_params = self.fix_params
            if self.per_minute_params_id is None and self.per_minute_params is not None:
                self.per_minute_params.save(using=using)
                # Re-assign so per_minute_params_id picks up the newly generated pk.
                self.per_minute_params = self.per_minute_params
            return super().save(*args, **kwargs)
