import logging
import os

import boto3
import enum

from sandbox import sdk2
from sandbox.projects.music.deployment.helpers.MusicBaseTask import MusicBaseTask


class MusicUpdateQaSamples(MusicBaseTask):
    """Make the QA ``music-samples`` bucket a copy of the production one.

    Every object of the production bucket is downloaded into the task's
    working directory (one local file per object key). Each file is uploaded
    to the QA bucket under the same key, and QA objects whose keys do not
    exist in production are deleted. Finally the local copies are removed.
    """
    bucket_name = 'music-samples'

    qa_endpoint_url = 'https://s3.mdst.yandex.net'
    prod_endpoint_url = 'https://s3.mds.yandex.net'

    class Env(enum.Enum):
        """S3 installation to talk to; selects both endpoint and secret."""
        QA = 1
        PRODUCTION = 2

    class Parameters(sdk2.task.Parameters):
        # Yav secrets whose data must contain 'aws.accessKeyId' and
        # 'aws.secretKey' (both read in create_s3_bucket).
        qa_s3_secret = sdk2.parameters.YavSecret(
            'Yav qa secret with S3 access key',
            default='sec-01dx0v6rb3wavq387m5tes207e'
        )

        prod_s3_secret = sdk2.parameters.YavSecret(
            'Yav prod secret with S3 access key',
            default='sec-01dx0vbyda03xta8vzf3gaanp1'
        )

    def on_execute(self):
        """Copy production samples to QA, always cleaning up local copies."""
        prod_sample_keys = self.load_production_samples()
        try:
            self.upload_sample(prod_sample_keys)
        finally:
            # Remove the downloaded files even when the upload fails.
            self.delete_local_files(prod_sample_keys)
        self.set_info('All done')

    @staticmethod
    def delete_local_files(prod_sample_keys):
        """Delete the local files whose paths are given in ``prod_sample_keys``."""
        logging.info('Removing all local files')
        for file_name in prod_sample_keys:
            os.remove(file_name)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it is missing."""
        # Object keys may contain '/', which maps to nested local directories.
        # No exist_ok so this also runs on Python 2.
        parent = os.path.dirname(path)
        if parent and not os.path.isdir(parent):
            os.makedirs(parent)

    def load_production_samples(self):
        """Download every object of the production bucket to the working dir.

        Each object is saved to a local file whose path equals its key.
        If any download fails, the files written so far are removed before
        the error propagates.

        Returns:
            list of the downloaded object keys (which are also local paths).
        """
        logging.info('Loading samples from production')
        _, bucket = self.create_s3_bucket(self.Env.PRODUCTION)

        sample_keys = []
        completed = False
        try:
            for s3_file in bucket.objects.all():
                # NOTE(review): keys are used verbatim as local paths, which
                # trusts the production bucket's key names.
                self._ensure_parent_dir(s3_file.key)
                with open(s3_file.key, 'wb') as local_file:
                    # Record the file as soon as it exists so that cleanup
                    # also removes a partially written copy on failure.
                    sample_keys.append(s3_file.key)
                    bucket.download_fileobj(s3_file.key, local_file)
            completed = True
        finally:
            # A success flag is used instead of except/raise so the original
            # error is preserved on Python 2 as well.
            if not completed:
                self.delete_local_files(sample_keys)

        logging.info('Downloaded {count} samples'.format(count=len(sample_keys)))
        return sample_keys

    def upload_sample(self, prod_sample_keys):
        """Make the QA bucket contain exactly the given local samples.

        The samples are uploaded first, overwriting same-named objects.
        Only afterwards are QA objects whose keys are not in
        ``prod_sample_keys`` deleted. A failed upload therefore never leaves
        QA emptied.

        Args:
            prod_sample_keys: object keys that are also paths of local files.
        """
        _, bucket = self.create_s3_bucket(self.Env.QA)

        logging.info('Uploading samples to qa')
        for sample_key in prod_sample_keys:
            with open(sample_key, 'rb') as data:
                bucket.upload_fileobj(data, sample_key)

        logging.info('Removing stale samples from qa')
        wanted_keys = set(prod_sample_keys)
        for s3_file in bucket.objects.all():
            if s3_file.key not in wanted_keys:
                s3_file.delete()

    def create_s3_bucket(self, env):
        """Return an ``(s3_resource, bucket)`` pair for ``bucket_name``.

        Args:
            env: Env.PRODUCTION selects the production endpoint and secret;
                any other value selects the QA ones.
        """
        if env == self.Env.PRODUCTION:
            secret_param, endpoint_url = self.Parameters.prod_s3_secret, self.prod_endpoint_url
        else:
            secret_param, endpoint_url = self.Parameters.qa_s3_secret, self.qa_endpoint_url

        s3_secret = secret_param.data()
        session = boto3.Session(aws_access_key_id=s3_secret['aws.accessKeyId'],
                                aws_secret_access_key=s3_secret['aws.secretKey'])
        s3 = session.resource('s3', endpoint_url=endpoint_url)
        bucket = s3.Bucket(self.bucket_name)
        return s3, bucket
