import io
import logging
from enum import Enum
from typing import Optional

from aioboto3 import Session
from library.python.awssdk_async_extensions.lib.core import tvm2_session
from botocore.exceptions import ClientError

from intranet.trip.src.api.auth import get_tvm_service_ticket
from intranet.trip.src.config import settings

logger = logging.getLogger(__name__)


class TVMHandler:
    """Supply TVM service tickets addressed to one destination service.

    ``self_id`` holds this application's own TVM id, read from
    ``settings.tvm_services`` under the ``'trip'`` key (may be ``None`` if the
    key is absent).
    """

    def __init__(self, tvm_service_name: str):
        # Name of the service the tickets are issued for.
        self.dst_name = tvm_service_name
        # Our own TVM client id.
        self.self_id = settings.tvm_services.get('trip')

    async def get_service_ticket(self):
        """Return a fresh service ticket for ``self.dst_name``."""
        ticket = await get_tvm_service_ticket(self.dst_name)
        return ticket


class S3Response:
    """Result of an S3 operation.

    Attributes:
        code: an ``S3ResponseCode`` member (``OK`` by default).
        data: payload for successful reads, otherwise ``None``.
        error: the exception that caused an ``ERROR`` result, otherwise ``None``.
    """

    class S3ResponseCode(Enum):
        OK = 200
        NOT_FOUND = 404
        ERROR = 500

    def __init__(self, code=S3ResponseCode.OK, data=None, error=None):
        self.code = code
        self.data = data
        self.error = error

    def __repr__(self):
        # Show only the payload size: object contents can be large and binary.
        if self.data is None:
            data_repr = 'None'
        else:
            try:
                data_repr = f'<len={len(self.data)}>'
            except TypeError:
                data_repr = f'<{type(self.data).__name__}>'
        return (
            f'{type(self).__name__}(code={self.code!r}, '
            f'data={data_repr}, error={self.error!r})'
        )


class AsyncS3Client(object):
    """Async client for a single S3 bucket with simple retries.

    The aioboto3 session is created lazily, on the first ``get``/``put`` call,
    via ``tvm2_session`` using the TVM handler's ticket getter and self id.
    Each request is tried up to ``retry_options['attempts']`` times (default 3).
    If every attempt raises, the last exception is wrapped in an
    ``S3Response`` with code ``ERROR`` instead of being re-raised.
    """

    # Error codes meaning "key does not exist": S3 uses '404' for HEAD-style
    # requests and 'NoSuchKey' for GET-style requests.
    _NOT_FOUND_CODES = frozenset({'404', 'NoSuchKey'})

    def __init__(
        self,
        tvm_handler: TVMHandler,
        bucket_name: str,
        retry_options: Optional[dict[str, int]] = None,
    ):
        """
        :param tvm_handler: provides the service-ticket getter and our TVM id.
        :param bucket_name: bucket every key is read from / written to.
        :param retry_options: only ``'attempts'`` is read (total tries per
            request, default 3, at least 1 is always made).
        """
        self._inited = False
        self.tvm_handler = tvm_handler
        self.bucket_name = bucket_name

        self.s3_session: Optional[Session] = None

        self.retry_options = retry_options or {}

    async def initialize(self):
        """Create the TVM-authenticated aioboto3 session."""
        logger.debug('initing S3 client')
        self.s3_session = await tvm2_session(
            self.tvm_handler.get_service_ticket,
            self.tvm_handler.self_id,
        )
        self._inited = True

    async def _ensure_initialized(self):
        """Create the session on first use."""
        if not self._inited:
            await self.initialize()

    async def _retry(self, func):
        """Await ``func()`` until it succeeds or the attempts run out.

        :return: whatever ``func`` returns on its first successful call, or
            ``S3Response(code=ERROR, error=<last exception>)`` if every attempt
            raised.
        """
        # Guard against 0/negative config: previously no attempt was made and
        # an ERROR with ``error=None`` was returned.
        attempts = max(1, self.retry_options.get('attempts', 3))
        last_exception = None
        for _ in range(attempts):
            try:
                return await func()
            except Exception as exception:  # boundary: converted into S3Response below
                last_exception = exception
                logger.debug("s3 request failed with exception '%s'", last_exception)

        return S3Response(
            error=last_exception,
            code=S3Response.S3ResponseCode.ERROR,
        )

    async def _try_get_object(self, key):
        """Download ``key`` into memory once (no retry).

        :return: ``S3Response(OK, data=<bytes>)`` on success, or
            ``S3Response(NOT_FOUND)`` if S3 reports the key as missing.
        :raises ClientError: for any other S3 error (handled by ``_retry``).
        """
        with io.BytesIO() as buffer:
            try:
                async with self.s3_session.client('s3', endpoint_url=settings.MDS_S3_ENDPOINT_URL) as s3:
                    await s3.download_fileobj(self.bucket_name, key, buffer)
            except ClientError as e:
                error_code = e.response.get('Error', {}).get('Code')
                if error_code in self._NOT_FOUND_CODES:
                    return S3Response(
                        code=S3Response.S3ResponseCode.NOT_FOUND
                    )
                raise

            return S3Response(
                data=buffer.getvalue(),
                code=S3Response.S3ResponseCode.OK,
            )

    async def _try_put_object(self, key, data):
        """Upload the file-like ``data`` to ``key`` once (no retry)."""
        async with self.s3_session.client('s3', endpoint_url=settings.MDS_S3_ENDPOINT_URL) as s3:
            await s3.upload_fileobj(data, self.bucket_name, key)

    @staticmethod
    def _seekable_position(data):
        """Return ``data.tell()`` if ``data`` is a seekable stream, else None."""
        try:
            if data.seekable():
                return data.tell()
        except (AttributeError, OSError, ValueError):
            # Not an io-style stream (or already closed): retries cannot rewind it.
            pass
        return None

    async def get(self, key):
        """Fetch ``key`` from the bucket.

        :return: ``S3Response`` with code OK (``data`` holds the bytes),
            NOT_FOUND, or ERROR (``error`` holds the last exception).
        """
        await self._ensure_initialized()

        async def func():
            return await self._try_get_object(key)

        return await self._retry(func)

    async def put(self, key, data):
        """Upload ``data`` to ``key``.

        :param data: a readable file-like object, or ``bytes``/``bytearray``/
            ``memoryview`` (wrapped in ``io.BytesIO``).
        :return: ``None`` on success (kept for backward compatibility), or an
            ``S3Response`` with code ERROR if every attempt failed.
        """
        await self._ensure_initialized()

        if isinstance(data, (bytes, bytearray, memoryview)):
            data = io.BytesIO(data)
        # A failed attempt may have consumed part of the stream; rewind to the
        # original position before each try so retries upload the full payload.
        start = self._seekable_position(data)

        async def func():
            if start is not None:
                data.seek(start)
            return await self._try_put_object(key, data)

        return await self._retry(func)
