import enum
from typing import List

from marshmallow import fields
from marshmallow_enum import EnumField

from maps_adv.stat_controller.client.lib.base import (
    BaseClient,
    FindTaskSchema,
    NoTasksFound,
    StrictSchema,
    StructOfPrimitives,
    TaskSchema,
    UnknownResponse,
    UpdateTaskSchema,
    with_schemas,
)

from .base.client import async_shield

# Public API of this module; UnknownResponse is re-exported from the base
# client package so callers can catch it without importing from there.
__all__ = [
    "Client",
    "TaskStatus",
    "UnknownResponse",
    "NoChargerTaskFound",
]


class TaskStatus(enum.Enum):
    """Statuses of a charger task.

    Declaration order matters: ``slice`` treats members declared later as the
    ones that follow the current status (presumably the order in which a task
    progresses — confirm against the server side).
    """

    accepted = "accepted"
    context_received = "context_received"
    calculation_completed = "calculation_completed"
    billing_notified = "billing_notified"
    charged_data_sent = "charged_data_sent"
    completed = "completed"

    def slice(self) -> List["TaskStatus"]:
        """Return every member declared after this one, in declaration order.

        The last declared member yields an empty list.
        """
        members = list(type(self))
        position = members.index(self)
        return members[position + 1 :]


class CampaignPayloadSchema(StrictSchema):
    """Per-campaign entry nested under ``OrderPayloadSchema.campaigns``.

    Decimal fields are (de)serialized as strings (``as_string=True``).
    ``budget`` and ``daily_budget`` additionally accept NaN/Infinity
    (``allow_nan=True``) — presumably to represent an unlimited budget;
    confirm with the producer of this payload.
    """

    campaign_id = fields.Integer(required=True)
    cpm = fields.Decimal(required=True, as_string=True)
    budget = fields.Decimal(required=True, as_string=True, allow_nan=True)
    daily_budget = fields.Decimal(required=True, as_string=True, allow_nan=True)
    charged = fields.Decimal(required=True, as_string=True)
    charged_daily = fields.Decimal(required=True, as_string=True)
    events_count = fields.Integer(required=True)
    cost_per_event = fields.Decimal(required=False, as_string=True, allow_none=True)
    cost_per_last_event = fields.Decimal(
        required=False, as_string=True, allow_none=True
    )
    events_to_charge = fields.Integer(required=False)
    # Spelled correctly so marshmallow actually enforces presence; an unknown
    # keyword would only be stored as field metadata.
    tz_name = fields.String(required=True)


class OrderPayloadSchema(StrictSchema):
    """One order's entry in a charger task's ``execution_state`` list.

    ``campaigns`` holds one ``CampaignPayloadSchema`` entry per campaign.
    ``order_id`` must be present but may be null.
    """

    order_id = fields.Integer(allow_none=True, required=True)
    budget_balance = fields.Decimal(as_string=True, allow_nan=True, required=True)
    amount_to_bill = fields.Decimal(as_string=True, allow_none=True, required=False)
    billing_success = fields.Boolean(allow_none=True, required=False)
    campaigns = fields.List(
        fields.Nested(CampaignPayloadSchema),
        required=True,
    )


class UpdateChargerTaskSchema(UpdateTaskSchema):
    """Request schema used by ``Client.update_task``.

    Declares the charger-specific ``status`` (a ``TaskStatus`` enum field) and
    a free-form ``execution_state`` that may be omitted or null.
    """

    status = EnumField(TaskStatus)
    execution_state = StructOfPrimitives(allow_none=True, required=False)


class ChargerTaskSchema(TaskSchema):
    """Response schema used by ``Client.find_new_task``.

    ``status`` is a required ``TaskStatus``; ``execution_state``, when present
    and not null, is a list of per-order payloads (``OrderPayloadSchema``).
    """

    status = EnumField(TaskStatus, required=True)
    execution_state = fields.List(
        fields.Nested(OrderPayloadSchema),
        allow_none=True,
        required=False,
    )


class NoChargerTaskFound(NoTasksFound):
    """Raised by ``Client.find_new_task`` when the server replies 200 with ``{}``."""


class Client(BaseClient):
    """Client for the stat-controller charger task endpoints."""

    @with_schemas(FindTaskSchema, ChargerTaskSchema)
    async def find_new_task(self, executor_id: str) -> dict:
        """Request a new charger task for ``executor_id``.

        Sends ``POST /tasks/charger/``, expecting status 201 with the task.

        Raises:
            NoChargerTaskFound: the server answered 200 with exactly ``b"{}"``,
                i.e. there is no task to hand out.
            UnknownResponse: any other unexpected response (re-raised as is).
        """
        json = {"executor_id": executor_id}

        try:
            return await self._request("POST", "/tasks/charger/", 201, json)
        except UnknownResponse as exc:
            if exc.status_code == 200 and exc.payload == b"{}":
                # Chain explicitly: "no task" is a translation of this
                # response, not a failure that happened while handling it.
                raise NoChargerTaskFound from exc
            raise

    @async_shield
    @with_schemas(UpdateChargerTaskSchema, TaskSchema)
    async def update_task(
        self, task_id: int, status: str, executor_id: str, execution_state: List[dict]
    ) -> dict:
        """Report the state of charger task ``task_id``.

        Sends ``PUT /tasks/charger/{task_id}/``, expecting status 200.

        The request body goes through ``UpdateChargerTaskSchema``. That schema
        declares ``status`` as ``EnumField(TaskStatus)`` and ``execution_state``
        with ``allow_none=True``.
        NOTE(review): ``status`` is therefore presumably a ``TaskStatus`` member
        despite the ``str`` annotation — confirm how ``with_schemas`` applies
        the request schema.
        NOTE(review): ``async_shield`` presumably protects the in-flight request
        from caller cancellation — confirm in ``base.client``.
        """
        json = {
            "status": status,
            "executor_id": executor_id,
            "execution_state": execution_state,
        }

        return await self._request("PUT", f"/tasks/charger/{task_id}/", 200, json)
