# coding: utf-8

import datetime
import logging
import os
import shutil

from sandbox.common import errors
import sandbox.common.types.client as ctc
import sandbox.common.types.task as ctt
import sandbox.common.types.misc as ctm
import sandbox.sdk2 as sdk2
from sandbox.projects.common.environments import MongodbEnvironment
from sandbox.projects.common import decorators as deco
from sandbox.projects.Afisha.base import AfishaSandboxBaseTaskBinary


class AfishaMongoDump(AfishaSandboxBaseTaskBinary):
    """Dump selected databases of one MDB MongoDB cluster and upload them to S3.

    The task resolves exactly one cluster in the given MDB folder, runs
    ``mongodump`` for every requested database (optionally restricted to
    specific collections), packs each dump directory into ``<dump>.tar.gz``,
    uploads the archive to the configured S3 bucket and then removes the
    local archive and dump directory.
    """
    BINARY_TASK_ATTR_TARGET = 'Afisha/infra/AfishaMongoDump'
    MONGO_PORT = '27018'
    # CA bundle is downloaded to CA_CERT_FILE and handed to mongodump via --sslCAFile.
    CA_CERT_URL = 'https://crls.yandex.net/allCAs.pem'
    CA_CERT_FILE = 'mdb.pem'

    class Requirements(sdk2.Requirements):
        dns = ctm.DnsType.DNS64
        disk_space = 300 * 1024  # 300 GB
        environments = (
            MongodbEnvironment(),
        )
        client_tags = ctc.Tag.LINUX_XENIAL

    class Parameters(AfishaSandboxBaseTaskBinary.Parameters):
        # NOTE(review): presumably seconds; confirm this is enough for large dumps.
        kill_timeout = 300
        mdb_folder_id = sdk2.parameters.String('MDB folder ID', required=True)
        mdb_cluster = sdk2.parameters.String('MDB cluster filter', required=True)
        user = sdk2.parameters.String('User for mongo_dump', required=True)
        secret_mdb = sdk2.parameters.YavSecretWithKey('Secret with key for MDB token', required=True)
        secret_mongo = sdk2.parameters.YavSecretWithKey('Secret with key for MongoDB user password', required=True)
        secret_s3_key_id = sdk2.parameters.YavSecretWithKey('Secret with key for S3 Access key ID', required=True)
        secret_s3_secret_key = sdk2.parameters.YavSecretWithKey('Secret with key for S3 Secret access key', required=True)
        db_collections_dict = sdk2.parameters.Dict("DBs and collections to dump: key = 'db_name', value = 'collection1;collection2...'", required=True)
        s3_endpoint = sdk2.parameters.String("S3: endpoint", required=True, default="https://s3.mds.yandex.net")
        s3_bucket = sdk2.parameters.String("S3: bucket", required=True)
        # The prefix is concatenated with the archive name as-is (no separator is added).
        s3_prefix = sdk2.parameters.String("S3: prefix", required=True)
        dump_prefix = sdk2.parameters.String("Prefix for dump naming", required=False, default="mongo-dump")

    def on_enqueue(self):
        """Acquire a capacity-1 semaphore keyed by MDB folder, cluster filter and S3 bucket.

        The semaphore is released when the task enters the BREAK or FINISH
        status groups.
        """
        # NOTE(review): the base class on_enqueue is not called here - confirm
        # that AfishaSandboxBaseTaskBinary does not rely on it.
        semaphore_name = '{}_{}_{}_{}'.format(
            'afisha_mongo_dump',
            self.Parameters.mdb_folder_id,
            self.Parameters.mdb_cluster,
            self.Parameters.s3_bucket,
        )
        component_flow_lock = ctt.Semaphores.Acquire(
            name=semaphore_name,
            weight=1,
            capacity=1
        )
        release = (ctt.Status.Group.BREAK, ctt.Status.Group.FINISH)
        self.Requirements.semaphores = ctt.Semaphores(acquires=[component_flow_lock], release=release)

    def _run_logged_process(self, cmd, logger_name):
        """Run ``cmd`` with stdout/stderr sent to a task process log; return its exit code."""
        with sdk2.helpers.ProcessLog(self, logger=logger_name) as pl:
            return sdk2.helpers.subprocess.Popen(
                cmd,
                stdout=pl.stdout,
                stderr=pl.stderr
            ).wait()

    @staticmethod
    def _split_collections(raw_value):
        """Split a ';'-separated collection list into clean names.

        Surrounding whitespace is stripped and empty entries (for example
        from a trailing ';') are dropped. Returns an empty list when nothing
        usable is left.
        """
        if not raw_value:
            return []
        stripped = (name.strip() for name in raw_value.split(';'))
        return [name for name in stripped if name]

    @deco.retries(max_tries=3, delay=60)
    def _get_certificate(self):
        """Download the CA bundle used for the TLS connection to MongoDB.

        Raises:
            Exception: if curl exits with a non-zero code; the ``deco.retries``
                decorator (max_tries=3, delay=60) wraps this method.
        """
        logging.info('Getting CA certificate')
        # --fail makes curl exit non-zero on HTTP error responses, so an error
        # page is never saved as the certificate and the retry can kick in.
        curl_command = ['curl', '--fail', self.CA_CERT_URL, '--output', self.CA_CERT_FILE]
        return_code = self._run_logged_process(curl_command, 'curl')
        if return_code != 0:
            raise Exception('Get certificate process failed!')

    def on_execute(self):
        """Resolve the cluster, then dump, archive and upload every requested database.

        Raises:
            errors.TaskFailure: if no cluster or more than one cluster matches,
                if the cluster has more than one shard, or if any dump,
                archive or upload step fails.
        """
        super(AfishaMongoDump, self).on_execute()
        # Imported inside the method, as in the original code (presumably these
        # modules are only available in the task binary runtime).
        import boto3
        from afisha.infra.libs.mdb import MdbClientMongo

        self._get_certificate()
        mdb_token = self.Parameters.secret_mdb.value()
        mdb = MdbClientMongo(token=mdb_token)

        s3_access_key_id = self.Parameters.secret_s3_key_id.value()
        s3_access_secret_key = self.Parameters.secret_s3_secret_key.value()

        session = boto3.session.Session(
            aws_access_key_id=s3_access_key_id,
            aws_secret_access_key=s3_access_secret_key,
        )
        s3_client = session.client(
            's3',
            endpoint_url=self.Parameters.s3_endpoint,
        )

        self.s3_client = s3_client
        self.mdb = mdb

        logging.info('Getting clusters from MDB folder: %s', self.Parameters.mdb_folder_id)
        clusters = mdb.cluster_list(folder_id=self.Parameters.mdb_folder_id,
                                    filter=self.Parameters.mdb_cluster)

        if not clusters:
            raise errors.TaskFailure('No clusters matched')
        if len(clusters) > 1:
            clusters_names = [cluster.name for cluster in clusters]
            raise errors.TaskFailure('Found more then 1 cluster beginning with {}: {}'.format(
                self.Parameters.mdb_cluster,
                ",".join(clusters_names)
            ))

        try:
            cluster = clusters[0]
            hosts = mdb.cluster_hosts_list(cluster_id=cluster.id)
            shards = mdb.cluster_shards_list(cluster_id=cluster.id)
            mdb_databases = mdb.database_list(cluster_id=cluster.id)

            requested = self.Parameters.db_collections_dict
            databases = [db for db in requested if db in mdb_databases]
            skipped = [db for db in requested if db not in mdb_databases]
            if skipped:
                # Requested databases absent from the cluster are not dumped;
                # make that visible instead of finishing silently.
                logging.warning('Requested databases not found in cluster, skipping: %s', skipped)

            if len(shards) > 1:
                raise errors.TaskFailure(
                    "Trying to dump databases in sharded cluster {}".format(cluster)
                )

            hosts_list = ','.join(host + ':' + self.MONGO_PORT for host in hosts)
            if shards:
                # With a single shard the host string is prefixed with "<shard>/".
                hosts_list = shards[0] + '/' + hosts_list

            for database in databases:
                collections = self._split_collections(requested[database])
                if collections:
                    dump_name = self.run_mongodump(database, hosts_list, collections=collections)
                else:
                    # No collections listed: dump the whole database.
                    dump_name = self.run_mongodump(database, hosts_list)
                archive_name = "{}.tar.gz".format(dump_name)
                self.upload_to_s3(archive_name)
                os.remove(archive_name)
                shutil.rmtree(dump_name)
        except errors.TaskFailure:
            # Already carries a specific message; keep it instead of re-wrapping.
            logging.exception('Dump failed')
            raise
        except Exception as e:
            logging.exception('Something went wrong: %s', e)
            raise errors.TaskFailure('Something went wrong. Details:\n{}'.format(e))

    def run_mongodump(self, database, hosts_list, collections=None):
        """Dump ``database`` with mongodump and pack the result into a tar.gz archive.

        Args:
            database: name of the database to dump.
            hosts_list: value passed to mongodump ``--host``.
            collections: optional list of collection names; when given,
                mongodump is run once per collection into the same output
                directory, otherwise the whole database is dumped.

        Returns:
            The dump directory name; the archive is ``<name>.tar.gz``.

        Raises:
            errors.TaskFailure: if mongodump or tar exits with a non-zero code.
        """
        # NOTE(review): the name carries only day and month (local time), so
        # dumps made on the same calendar day get identical names.
        date_string = datetime.datetime.now().strftime('%d.%m')

        if collections:
            dump_name = "{}-{}-{}-{}".format(self.Parameters.dump_prefix, database, "_".join(collections), date_string)
        else:
            dump_name = "{}-{}-{}".format(self.Parameters.dump_prefix, database, date_string)

        # NOTE(review): the password is passed as a command-line argument and is
        # therefore visible in the process list while mongodump runs.
        base_cmd = [
            "mongodump",
            "-vvvvvv",
            "--host", hosts_list,
            "--username", self.Parameters.user,
            "--password", self.Parameters.secret_mongo.value(),
            "--db", database,
            "--out", dump_name,
            "--ssl",
            "--sslCAFile={}".format(self.CA_CERT_FILE),
            "--readPreference=secondary"
        ]

        if collections:
            for collection in collections:
                logging.info("dumping %s:%s database:collection", database, collection)
                return_code = self._run_logged_process(base_cmd + ["--collection", collection], 'mongodump')
                if return_code != 0:
                    raise errors.TaskFailure(
                        "mongodump process failed while dumping {}:{}. Details in mongodump logs".format(database, collection)
                    )
        else:
            logging.info("dumping %s database", database)
            return_code = self._run_logged_process(base_cmd, 'mongodump')
            if return_code != 0:
                raise errors.TaskFailure(
                    "mongodump process failed while dumping {}. Details in mongodump logs".format(database)
                )

        tar_cmd = [
            "tar", "-zcvf",
            "{}.tar.gz".format(dump_name),
            dump_name,
        ]
        return_code = self._run_logged_process(tar_cmd, 'tar_and_gzip')
        if return_code != 0:
            raise errors.TaskFailure("tar + gzip process failed while dumping {}. Details in tar_and_gzip logs".format(dump_name))

        return dump_name

    def upload_to_s3(self, dump_name):
        """Upload the local file ``dump_name`` to the configured S3 bucket.

        The object key is ``s3_prefix`` immediately followed by ``dump_name``.

        Raises:
            errors.TaskFailure: if the upload raises any exception.
        """
        object_key = self.Parameters.s3_prefix + dump_name
        try:
            self.s3_client.upload_file(dump_name, self.Parameters.s3_bucket, object_key)
        except Exception as e:
            logging.exception("Error occured when uploading %s to s3. Details: %s", dump_name, e)
            raise errors.TaskFailure("Error occured when uploading %s to s3" % dump_name)
