import collections
import shutil
import logging
import os.path
import datetime as dt
import multiprocessing.dummy as mp

import pymongo

from sandbox import sdk2
from sandbox import common
from sandbox.sdk2.helpers import subprocess as sp

import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.projects.sandbox.resources as sb_resources
import sandbox.projects.sandbox.mongo_suite.consts as consts


# Yav secret id used as the default for the task's "aws_credentials" parameter.
DEFAULT_AWS_SECRET = "sec-01f7p8xd4afjqmhy3mbss20fy8"  # zomb-sandbox-mds
# Keys that the credentials secret must contain; the upload step exports their
# values into environment variables with these same names.
AWS_ACCESS_KEY_ID_NAME = "AWS_ACCESS_KEY_ID"
AWS_SECRET_ACCESS_KEY_NAME = "AWS_SECRET_ACCESS_KEY"

# Passed as the `namespace` of the S3 upload and stored in the resource's `mds` info.
MONGO_BACKUP_BUCKET = "sandbox-backup"  # special bucket for backups


class Shard(collections.namedtuple("ShardInfo", ("shard_name", "host", "port"))):
    """
    Address of one MongoDB shard member to back up, written as ``shard_name/host:port``.

    Fields:
        shard_name: name of the shard; also used as the backup file name and logger suffix.
        host: host name as given in the shard string; passed to ``mongodump --host``.
        port: port as given in the shard string (a string when built via ``from_str``).
    """

    @property
    def host_port(self):
        """ ``host:port`` form, compared against replica set member names. """
        return "{}:{}".format(self.host, self.port)

    @property
    def mongo_client(self):
        """ A new ``pymongo.MongoClient`` on every access, always pointed at localhost. """
        # older versions of mongodump can't resolve records which point to IPv6-only hosts,
        # but why resolve anything at all if the database is known to run on localhost?
        return pymongo.MongoClient("mongodb://localhost:{}".format(self.port))

    def __str__(self):
        return "{}/{}:{}".format(self.shard_name, self.host, self.port)

    @classmethod
    def from_str(cls, shard_string):
        """
        Parse ``shardname/host:port`` into a Shard.

        The name is everything before the first ``/``; the port is everything after the
        last ``:``, so hosts that themselves contain colons are kept intact.
        Raises ValueError if either separator is missing.
        """
        shard_name, host_port = shard_string.split("/", 1)
        host, port = host_port.rsplit(":", 1)
        return cls(shard_name, host, port)

    def prepare_for_backup(self, fsync):
        """
        Verify the member is not primary and, if `fsync` is true, call fsync with lock on it.

        Raises AssertionError when the member reports itself as primary.
        """
        logging.debug("Preparing shard %s for backup on %s", self.shard_name, self.host_port)
        self._check_is_not_primary()
        if fsync:
            logging.info("Calling fsync() with lock for %s...", self)
            self.mongo_client.fsync(lock=True)
        else:
            logging.info("Not locking/flushing anything")

    def backup(self, target_path, task):
        """
        Dump the shard with ``mongodump --archive`` piped into ``pixz``.

        The archive is written to ``<target_path>/<shard_name>.pixz``; mongodump's stderr
        goes to a per-shard process log of `task`.

        Returns ``(shard_name, [rc_dump, rc_compress])`` if either process exited with a
        non-zero code, otherwise ``(shard_name, [])``.
        """
        shard_backup_path = target_path.joinpath(self.shard_name)
        with sdk2.helpers.ProcessLog(task, logger="mongodump_{}".format(self.shard_name)) as pl:
            mongo_dump_process = sp.Popen(
                [
                    "/usr/bin/nice",
                    "mongodump",
                    "--oplog",
                    "--archive",
                    "--host", self.host,
                    "--port", str(self.port),
                    "-vvvv",
                ], stdout=sp.PIPE, stderr=pl.stdout
            )
            try:
                compress_process = sp.Popen(
                    [
                        "pixz", "-p14",
                        "-o", "{}.pixz".format(str(shard_backup_path)),
                    ], stdin=mongo_dump_process.stdout
                )
            finally:
                # Drop our copy of the pipe's read end. While the parent keeps it open,
                # mongodump never gets SIGPIPE if pixz dies (or fails to start) and would
                # block forever on a full pipe, hanging the wait() below.
                mongo_dump_process.stdout.close()
            rc_dump = mongo_dump_process.wait()
            rc_compress = compress_process.wait()
        if rc_dump or rc_compress:
            return self.shard_name, [rc_dump, rc_compress]
        return self.shard_name, []

    def unlock(self):
        """ Best-effort unlock: any exception is logged and swallowed, never raised. """
        logging.info("Unlocking shard %s...", self.shard_name)
        try:
            self.mongo_client.unlock()
            logging.info("%s unlocked", self.shard_name)
        except Exception:
            logging.exception("Exception during unlock/status restore")

    def _check_is_not_primary(self):
        """
        Raise AssertionError if the replica set status lists this member as primary.

        The member is looked up by ``host:port``; if no member has that name the check
        passes (NOTE(review): confirm shard strings always match replica set member names).
        """
        members = self.mongo_client.admin.command(consts.Command.GET_STATUS)["members"]
        own_status = next((member for member in members if member["name"] == self.host_port), {})
        # this should not happen, unless there were re-elections during the task's enqueuing, which is unlikely
        if own_status.get("stateStr") == consts.ShardState.PRIMARY:
            # Raised explicitly instead of using `assert`, so the guard survives `python -O`.
            raise AssertionError(
                "Shard {} on {} is primary! Dangerous situation.".format(self.shard_name, self.host_port)
            )


class ServiceMongoBackuper(sdk2.ServiceTask):
    """ Backup Sandbox MongoDB replica set """

    # Flow of a run (see on_execute): check/lock every shard, pick today's iteration
    # number, create the backup resource directory, dump or copy every shard into it,
    # unlock the shards, upload the directory to S3 and mark the resource ready.

    class Requirements(sdk2.Task.Requirements):
        dns = ctm.DnsType.DNS64

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 8 * 3600
        shards = sdk2.parameters.List(
            "Shards to backup", description="Expected format is shardname/host:port", required=True
        )
        do_fsync = sdk2.parameters.Bool("Call db.fsync() before backup", default=True)
        use_mongodump = sdk2.parameters.Bool("Use mongodump for backup", default=True)
        # The storage path is only relevant for the disk-copy mode.
        with use_mongodump.value[False]:
            mongodb_path = sdk2.parameters.String("Mongo DB storage path", default_value="/ssd/mongodb")
        backup_ttl = sdk2.parameters.Integer("For how long should a backup be accessible", default=14)
        aws_credentials = sdk2.parameters.YavSecret(
            "Yav secret with aws credentials", required=True, default_value=DEFAULT_AWS_SECRET,
            description="Should contain keys {} and {}".format(AWS_ACCESS_KEY_ID_NAME, AWS_SECRET_ACCESS_KEY_NAME)
        )

    def _backup_shards_by_mongodump(self, target_path, shards):
        """
        Run `Shard.backup` for every shard on a pool of 4 threads and wait for all of them.

        Returns ``(shard_name, return_codes)`` of the first shard (in input order) whose
        dump or compression failed, or ``(None, None)`` if every shard succeeded.
        An exception raised by a worker is re-raised here once all workers have finished.
        """
        thread_pool = mp.Pool(4)
        try:
            shards_backup_results = thread_pool.map(
                lambda shard: shard.backup(target_path, self),
                shards
            )
        finally:
            # Always drain the pool before returning: the caller unlocks the shards right
            # after this method, and no mongodump may still be running at that point.
            thread_pool.close()
            thread_pool.join()
        for shard_name, return_codes in shards_backup_results:
            if return_codes:
                return shard_name, return_codes
        return None, None

    def _backup_shards_by_disk_copy(self, target_path, shards):
        """
        Archive every entry of the MongoDB storage directory into ``<name>.tar.pixz``.

        Raises Exception if the number of entries differs from the number of shards, or
        if fsync is disabled (a copy of unflushed, unlocked data files is not usable).

        Returns ``(instance_name, tar_return_code)`` for the first failed archive, or
        ``(None, None)`` on success.
        """
        mongodb_path = sdk2.path.Path(self.Parameters.mongodb_path)
        # List once so the count check and the copy loop see the same set of entries.
        instance_names = os.listdir(str(mongodb_path))
        if len(instance_names) != len(shards):
            raise Exception(
                "Number of folders in '{}' ({}) not equal to number of backup shards ({})".format(
                    mongodb_path, len(instance_names), len(shards)
                )
            )
        if not self.Parameters.do_fsync:
            raise Exception("Can't run backup by disk copy without fsync")
        for instance_name in instance_names:
            shard_backup_path = target_path.joinpath(instance_name)
            rc_tar = sp.Popen(
                [
                    "tar",
                    "-I", "pixz -p14",
                    "-cf",
                    "{}.tar.pixz".format(str(shard_backup_path)),
                    str(mongodb_path.joinpath(instance_name)),
                ]
            ).wait()
            if rc_tar:
                return instance_name, rc_tar
        return None, None

    def on_execute(self):
        """
        Back up all configured shards and publish the result.

        Raises common.errors.TaskFailure if a backup process raised or exited non-zero.
        Errors from shard preparation, context setup or resource creation propagate
        unchanged. Shards locked by this run are unlocked in every case.
        """
        shards = [Shard.from_str(shard) for shard in self.Parameters.shards]
        do_fsync = self.Parameters.do_fsync
        # Shards whose preparation completed, i.e. the ones this run may have locked.
        prepared_shards = []
        try:
            for shard in shards:
                shard.prepare_for_backup(do_fsync)
                prepared_shards.append(shard)
            self._set_dayly_iteration_context()
            backup_resource = self._prepare_backup_resource(shards)
            # noinspection PyBroadException
            try:
                if self.Parameters.use_mongodump:
                    failed_shard, error_codes = self._backup_shards_by_mongodump(backup_resource.path, shards)
                else:
                    failed_shard, error_codes = self._backup_shards_by_disk_copy(backup_resource.path, shards)
            except Exception:
                logging.exception("Exception during backup processes")
                raise common.errors.TaskFailure("Failed to backup a shard, please look for exceptions in debug.log")
            if failed_shard:
                raise common.errors.TaskFailure(
                    "Failed to backup {} (return codes: {})".format(failed_shard, error_codes)
                )
        finally:
            # The lock is taken during preparation, so the unlock has to cover everything
            # from there on; otherwise an early failure would leave members locked.
            if do_fsync:
                for shard in prepared_shards:
                    shard.unlock()
        self._upload_resource_to_mds(backup_resource)
        sdk2.ResourceData(backup_resource).ready()

    def _set_dayly_iteration_context(self):
        """
        Store today's date (``YYYYMMDD``) and a per-day iteration counter in the context.

        The counter continues from the latest successful task of this type when that
        task ran on the same date and has both context fields; otherwise it restarts at 0.
        """
        self.Context.date = dt.datetime.today().strftime("%Y%m%d")
        previous_run = sdk2.Task.find(type=self.type, status=ctt.Status.SUCCESS).order(-sdk2.Task.id).first()
        if (
            not previous_run or
            previous_run.Context.date is ctm.NotExists or
            previous_run.Context.iteration is ctm.NotExists or
            previous_run.Context.date != self.Context.date
        ):
            self.Context.iteration = 0
        else:
            self.Context.iteration = previous_run.Context.iteration + 1

    def _prepare_backup_resource(self, shards):
        """
        Create the backup resource named ``mongo_backup-<date>-<iteration>`` with an empty directory.

        Any existing directory at the resource path is removed first. Returns the resource.
        """
        backup_resource = sb_resources.SandboxShardGroupBackup(
            self, "Backup of shards {}".format([shard.shard_name for shard in shards]),
            "mongo_backup-{}-{}".format(self.Context.date, self.Context.iteration),
            ttl=self.Parameters.backup_ttl
        )
        if backup_resource.path.exists():
            shutil.rmtree(str(backup_resource.path))
        backup_resource.path.mkdir(parents=True)
        return backup_resource

    def _upload_resource_to_mds(self, backup_resource):
        """
        Upload the resource directory to the backup bucket and record the key on the resource.

        AWS credentials are read from the secret and exported via environment variables
        before the upload. The resource's `mds` info is only set when a key is returned.
        """
        logging.info("Start uploading mongo backup to MDS")
        # Read the secret once and reuse it for both keys.
        credentials = self.Parameters.aws_credentials.data()
        os.environ[AWS_ACCESS_KEY_ID_NAME] = credentials[AWS_ACCESS_KEY_ID_NAME]
        os.environ[AWS_SECRET_ACCESS_KEY_NAME] = credentials[AWS_SECRET_ACCESS_KEY_NAME]
        mds_key, metadata = common.mds.S3().upload_directory(
            path_or_obj=str(backup_resource.path),
            namespace=MONGO_BACKUP_BUCKET,
            resource_id=backup_resource.path.name,
        )
        logging.debug("Updating mongo backup resource info")
        if mds_key is not None:
            backup_resource.mds = dict(
                key=mds_key,
                namespace=MONGO_BACKUP_BUCKET,
            )
