import aioboto3
import ujson as json

from mail.shiva.stages.api.settings.s3api import S3ApiSettings
from mail.shiva.stages.api.props.shard.helpers import chunks

# Object keys are laid out as '<uid>/<last_mid>_<count>' (built in
# ArchiveStorage.save_messages, parsed by get_mid).
FOLDER_SEPARATOR: str = '/'     # separates the uid "folder" from the object name
MID_COUNT_SEPARATOR: str = '_'  # separates the last mid from the record count
DELETE_CHUNK_SIZE: int = 100    # keys sent per S3 delete_objects request


class Metrics:
    """Reports outcomes of S3 calls to a stats sink.

    Every ``s3_*`` method forwards a meter name equal to its own method name
    to ``stats.increase_task_meter``.
    """

    def __init__(self, stats):
        # Only ``stats.increase_task_meter(name)`` is ever called on this object.
        self._stats = stats

    def increase(self, name):
        """Forward the meter ``name`` to ``stats.increase_task_meter``."""
        self._stats.increase_task_meter(f"{name}")

    # --- ListObjectsV2 -------------------------------------------------------
    def s3_list_objects_v2_ok(self):
        """Record a successful list_objects_v2 call."""
        self.increase('s3_list_objects_v2_ok')

    def s3_list_objects_v2_error(self):
        """Record a failed list_objects_v2 call."""
        self.increase('s3_list_objects_v2_error')

    # --- DeleteObjects -------------------------------------------------------
    def s3_delete_objects_ok(self):
        """Record a successful delete_objects call."""
        self.increase('s3_delete_objects_ok')

    def s3_delete_objects_error(self):
        """Record a failed delete_objects call."""
        self.increase('s3_delete_objects_error')

    # --- PutObject -----------------------------------------------------------
    def s3_put_object_ok(self):
        """Record a successful put_object call."""
        self.increase('s3_put_object_ok')

    def s3_put_object_error(self):
        """Record a failed put_object call."""
        self.increase('s3_put_object_error')

    # --- GetObject -----------------------------------------------------------
    def s3_get_object_ok(self):
        """Record a successful get_object call."""
        self.increase('s3_get_object_ok')

    def s3_get_object_error(self):
        """Record a failed get_object call."""
        self.increase('s3_get_object_error')


def get_mid(key):
    """Extract the integer mid from an archive object key.

    Takes the second ``FOLDER_SEPARATOR``-delimited segment of ``key`` (the
    ``<last_mid>_<count>`` part of ``<uid>/<last_mid>_<count>``) and returns
    the text before its first ``MID_COUNT_SEPARATOR`` as an int.

    :raises IndexError: if ``key`` contains no ``FOLDER_SEPARATOR``.
    :raises ValueError: if the mid part is not an integer.
    """
    object_name = key.split(FOLDER_SEPARATOR)[1]
    mid_text, *_count_parts = object_name.split(MID_COUNT_SEPARATOR)
    return int(mid_text)


class ArchiveStorage:
    """Stores and reads archived user messages as JSON objects in an S3 bucket.

    Each saved batch becomes one object keyed ``<uid>/<last_mid>_<count>``
    (see ``save_messages``). Requests authenticate with a TVM ticket that is
    sent as the S3 session token.
    """

    def __init__(self, s3api_settings: S3ApiSettings, s3_id, tvm, stats, auto_refresh_tvm_ticket=True):
        """
        :param s3api_settings: provides ``location`` (endpoint URL), ``bucket_name``
            and ``ca_cert_path``.
        :param s3_id: destination id passed to ``tvm.get`` when requesting a ticket.
        :param tvm: TVM client exposing ``client_id`` and an awaitable ``get(dst)``.
        :param stats: sink handed to :class:`Metrics`.
        :param auto_refresh_tvm_ticket: if true, a fresh ticket is fetched before
            every client is created; otherwise the ticket is fetched lazily once
            and reused until :meth:`refresh_ticket` is called again.
        """
        self.endpoint_url = s3api_settings.location
        self.bucket_name = s3api_settings.bucket_name
        self.ca_cert_path = s3api_settings.ca_cert_path
        self.access_key = f'TVM_V2_{tvm.client_id}'
        self.s3_id = s3_id
        self.tvm = tvm
        self.tvm_ticket = None
        self.auto_refresh_tvm_ticket = auto_refresh_tvm_ticket
        self.metrics = Metrics(stats)

    async def refresh_ticket(self):
        """Fetch a new TVM ticket for ``s3_id`` and cache it."""
        self.tvm_ticket = await self.tvm.get(self.s3_id)

    async def _s3_client(self):
        """Create an S3 client authenticated with the current TVM ticket.

        Every caller in this class enters the result with ``async with``.
        """
        # Without auto-refresh, a missing ticket used to be sent as the literal
        # token 'TVM2 None'; fetch one lazily instead.
        if self.auto_refresh_tvm_ticket or self.tvm_ticket is None:
            await self.refresh_ticket()
        session_token = f'TVM2 {self.tvm_ticket}'
        return aioboto3.Session().client(
            service_name='s3',
            endpoint_url=self.endpoint_url,
            verify=self.ca_cert_path,
            aws_access_key_id=self.access_key,
            aws_secret_access_key='unused',
            aws_session_token=session_token,
        )

    async def list_user_objects(self, uid):
        """Return the keys of all objects stored under ``<uid>/``.

        Follows ``NextContinuationToken`` until the listing is exhausted.

        :raises RuntimeError: if S3 answers a page with a non-200 status.
        """
        keys = []
        prefix = f"{uid}{FOLDER_SEPARATOR}"
        continuation = {}

        async with await self._s3_client() as s3_client:
            while True:
                try:
                    resp = await s3_client.list_objects_v2(
                        Bucket=self.bucket_name,
                        Prefix=prefix,
                        **continuation,
                    )
                    if resp['ResponseMetadata']['HTTPStatusCode'] != 200:
                        raise RuntimeError(f"Can't list objects in S3: bad response: {resp}")
                except Exception:
                    # BaseException (e.g. asyncio.CancelledError) is not an S3 error.
                    self.metrics.s3_list_objects_v2_error()
                    raise
                self.metrics.s3_list_objects_v2_ok()

                keys.extend(item['Key'] for item in resp.get('Contents', []))
                ctoken = resp.get('NextContinuationToken')
                if not ctoken:
                    break
                continuation = {'ContinuationToken': ctoken}
        return keys

    async def delete_objects(self, keys):
        """Delete ``keys`` in batches of ``DELETE_CHUNK_SIZE`` per request.

        An empty ``keys`` is a no-op.

        :raises RuntimeError: on a non-200 status, or if S3 reports per-key
            ``Errors`` for a batch (later batches are then not attempted).
        """
        if not keys:
            return
        # One client (and one ticket fetch) for all batches, as list_user_objects
        # does for all pages, instead of rebuilding it for every chunk.
        async with await self._s3_client() as s3_client:
            for keys_chunk in chunks(keys, DELETE_CHUNK_SIZE):
                delete_list = [{'Key': key} for key in keys_chunk]
                try:
                    resp = await s3_client.delete_objects(
                        Bucket=self.bucket_name,
                        Delete={'Objects': delete_list},
                    )
                    if resp['ResponseMetadata']['HTTPStatusCode'] != 200:
                        raise RuntimeError(f"Can't delete objects in S3: bad response: {resp}")
                    # DeleteObjects can answer 200 while failing individual keys.
                    if resp.get('Errors'):
                        raise RuntimeError(f"Can't delete some objects in S3: {resp}")
                except Exception:
                    self.metrics.s3_delete_objects_error()
                    raise
                self.metrics.s3_delete_objects_ok()

    async def get_last_saved_mid(self, uid):
        """Return the largest mid encoded in the user's object keys, or 0 if none."""
        objects_keys = await self.list_user_objects(uid)
        return max((get_mid(key) for key in objects_keys), default=0)

    async def save_messages(self, uid, messages):
        """Store ``messages`` as one JSON object keyed ``<uid>/<last_mid>_<count>``.

        ``last_mid`` is the ``mid`` of the last element of ``messages`` and
        ``count`` is the number of records. Each element must provide ``mid``
        and ``saved_data()``. An empty ``messages`` is a no-op, because there
        is no mid to build a key from.

        :raises RuntimeError: if S3 answers with a non-200 status.
        """
        if not messages:
            return
        last_mid = messages[-1].mid
        data = [mess.saved_data() for mess in messages]
        rec_count = len(data)
        body = json.dumps(data).encode('utf-8')
        obj_key = f'{uid}{FOLDER_SEPARATOR}{last_mid}{MID_COUNT_SEPARATOR}{rec_count}'

        async with await self._s3_client() as s3_client:
            try:
                resp = await s3_client.put_object(
                    Bucket=self.bucket_name,
                    Key=obj_key,
                    Body=body,
                    ContentType='application/json',
                )
                if resp['ResponseMetadata']['HTTPStatusCode'] != 200:
                    raise RuntimeError(f"Can't save messages in S3: bad response: {resp}")
            except Exception:
                self.metrics.s3_put_object_error()
                raise
            self.metrics.s3_put_object_ok()

    async def get_messages(self, key):
        """Download the object at ``key`` and return its UTF-8 JSON payload, decoded.

        :raises RuntimeError: if S3 answers with a non-200 status.
        """
        async with await self._s3_client() as s3_client:
            try:
                resp = await s3_client.get_object(
                    Bucket=self.bucket_name,
                    Key=key,
                )
                if resp['ResponseMetadata']['HTTPStatusCode'] != 200:
                    raise RuntimeError(f"Can't get object from S3: bad response: {resp}")
                async with resp['Body'] as stream:
                    data = (await stream.read()).decode('utf-8')
                messages = json.loads(data)
            except Exception:
                self.metrics.s3_get_object_error()
                raise
            # Previously placed after a `return`, so the success metric never fired.
            self.metrics.s3_get_object_ok()
            return messages
