import socket
import logging
import datetime
import collections

import pymongo
import pymongo.uri_parser

from sandbox import sdk2
from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc
import sandbox.common.types.notification as ctn

import sandbox.projects.sandbox.mongo_suite.consts as consts
import sandbox.projects.sandbox.mongo_suite.service_mongo_restore as mongo_restore
import sandbox.projects.sandbox.mongo_suite.service_mongo_backuper as mongo_backuper
import sandbox.projects.sandbox.mongo_suite.service_mongo_shard_backuper as shard_backuper


PREPROD_API_URL = "https://www-sandbox1.n.yandex-team.ru/api/v1.0"
DEFAULT_TASK_WATCHER = "ilia-vashurov"


class ServiceMongoBackup(sdk2.ServiceTask):
    """ Backup Sandbox's MongoDB replica set and configuration server's database """

    """
    Algorithm:
    - Run on a server to get cluster's configuration from mongos
    - Ensure balancer is disabled
    - Create one or several children tasks to backup all replica sets including database config
    - After all tasks are completed, wake up and check their state
    - If necessary, run ServiceMongoRestore task on a fresh backup
    - Enable balancer and sleep until 7AM
    - Wake up and disable balancer
    - Wait for backup test to complete

    Replica set backup task has an assertion that the shard's current replica is not primary;
    If shards located on several different hosts, use Zookeeper barrier to synchronize backup start between all replicas
    """

    class Requirements(sdk2.Requirements):
        # Execution requirements of the coordinating task itself (not of the backup subtasks,
        # which receive their own ``__requirements__`` when they are created).

        # NOTE(review): units are whatever sdk2.Requirements.disk_space expects -- presumably MiB,
        # consistent with the "size in GB * 1024" value passed to the subtasks; confirm.
        disk_space = 100
        # Restrict execution to clients tagged SERVER; the task talks to a mongos on localhost.
        client_tags = ctc.Tag.SERVER
        dns = ctm.DnsType.DNS64

    class Parameters(sdk2.Parameters):
        # Task-level settings followed by two groups of user-editable parameters.
        # NOTE(review): some parameters below are declared with ``default_value=`` and others with
        # ``default=``; confirm that sdk2.parameters honours both spellings.

        fail_on_any_error = True
        # E-mail the watcher and "sandbox-errors" on any BREAK-group status or FAILURE,
        # and e-mail only the watcher on SUCCESS.
        notifications = (
            sdk2.Notification(
                tuple(ctt.Status.Group.BREAK) + (ctt.Status.FAILURE,),
                (DEFAULT_TASK_WATCHER, "sandbox-errors"),
                ctn.Transport.EMAIL
            ),
            sdk2.Notification(
                ctt.Status.SUCCESS,
                (DEFAULT_TASK_WATCHER,),
                ctn.Transport.EMAIL
            )
        )

        with sdk2.parameters.Group("Replica set backup options") as whatever:
            # Master switch: when False, on_execute skips both the backup and the children check.
            do_backup_replica = sdk2.parameters.Bool("Backup replica set", default_value=True)
            # Presumably these sub-parameters are only shown when the switch above is True -- confirm.
            with do_backup_replica.value[True]:
                # Forwarded to the backup subtasks (``do_fsync`` / ``fsync`` arguments).
                do_fsync = sdk2.parameters.Bool("Call db.fsync() before backup", default_value=True)
                # Forwarded to the single-host backup subtask only.
                use_mongodump = sdk2.parameters.Bool("Use mongodump for backup", default_value=True)
                # When True and no host carries all shards, the task fails instead of
                # falling back to a multi-host backup.
                run_on_single_host = sdk2.parameters.Bool("Run on single host with all shards required", default=True)
                # When True, replica set members whose "hidden" flag is false are skipped.
                use_hidden_shards = sdk2.parameters.Bool("Use only hidden replicas for backup", default=True)
                # One host per line; blank lines and surrounding whitespace are ignored.
                blacklist = sdk2.parameters.String(
                    "Blacklisted servers (will not enqueue backup tasks on these)", multiline=True
                )
                # Multiplied by the number of shards and by 1024 to get a subtask's disk_space requirement.
                shard_size = sdk2.parameters.Integer("Single shard's approximate size, in GBs", default_value=25)
                # Forwarded to the backup subtasks as ``backup_ttl``; the unit is not stated here
                # (presumably days -- confirm).
                backup_ttl = sdk2.parameters.Integer("For how long should backups be accessible", default_value=14)

        # NOTE(review): the group label is misspelled ("Miscellanous"); it is a runtime string and is
        # left untouched here.
        with sdk2.parameters.Group("Miscellanous") as blah:
            # Port used to build the "mongodb://localhost:<port>" connection URI.
            mongos_port = sdk2.parameters.Integer("mongos port on localhost", default_value=22222)
            # When True, the balancer is switched on after the backup and switched off again later.
            run_balancer = sdk2.parameters.Bool("Run balancer after backup", default=True)
            # When True, a ServiceMongoRestore subtask is started and its result is awaited.
            test_backup = sdk2.parameters.Bool("Test backup right after it's made", default=True)

    class Context(sdk2.Context):
        # State kept between executions of this task.

        # Ids of the enqueued backup subtasks (appended by the _run_backup_on_* methods).
        children = []
        # Sequence number of this run: previous successful run's iteration + 1, or 0;
        # used to rotate between candidate hosts.
        iteration = None
        # Id of the ServiceMongoRestore subtask started by _run_backup_test.
        backup_test_task_id = None
        # Seconds already spent waiting for the backup test task (see _check_test_status).
        total_wait = 0

    @property
    def zk_barrier(self):
        """ Barrier path handed to the shard backup subtasks; it embeds this task's id """
        path_template = "/sandbox-tasks/{}/start_barrier"
        return path_template.format(self.id)

    # because servers may go under multiple aliases
    @staticmethod
    def _resolve_hostname(host):
        """ Return the primary host name that a reverse lookup reports for ``host`` """
        primary_name, _aliases, _addresses = socket.gethostbyaddr(host)
        logging.debug("Resolved %s to %s", host, primary_name)
        return primary_name

    @staticmethod
    def _fqdn_to_hostname(netloc):
        """ Return the part of ``netloc`` before its first dot (all of it when there is no dot) """
        short_name, _dot, _domain = netloc.partition(".")
        return short_name

    @property
    def previous_iteration(self):
        """
        Iteration number stored by the most recent successful task of the same type.

        :return: the stored iteration, or None when there is no such task or it has no iteration in its context
        """
        successful_runs = sdk2.Task.find(type=self.type, status=ctt.Status.SUCCESS)
        latest = successful_runs.order(-sdk2.Task.id).first()
        if not latest:
            return None
        iteration = latest.Context.iteration
        if iteration is ctm.NotExists:
            return None
        return iteration

    @property
    def mongo_connection(self):
        """ Client for the mongos on localhost; note that every access builds a new MongoClient """
        uri = "mongodb://localhost:{}".format(self.Parameters.mongos_port)
        return pymongo.MongoClient(uri)

    @property
    def wait_timeout(self):
        """ Seconds to sleep between two checks of the backup test task (half an hour) """
        # Floor division keeps the value an int on Python 3 as well; true division would produce
        # 1800.0, which would then leak into sdk2.WaitTime and Context.total_wait.
        return 3600 // 2

    @property
    def wait_limit(self):
        """ Upper bound, in seconds, on the total time spent waiting for the backup test task """
        hours = 10  # ServiceMongoRestore's default time_to_kill requirement
        return hours * 3600

    def _ensure_balancer_off(self):
        """
        Make sure the balancer is marked as stopped in ``config.settings``.

        A missing balancer document, or one without the "stopped" field, is treated as
        "balancer is running": the update below is an upsert, so it creates the document if needed.
        """
        conn = self.mongo_connection
        # find_one() returns None instead of raising StopIteration when the document does not exist
        settings = conn.config.settings.find_one({"_id": "balancer"}) or {}
        if settings.get("stopped"):
            logging.debug("Balancer is already stopped")
            return
        logging.debug("Balancer was running! Stopping")
        conn.config.settings.update({"_id": "balancer"}, {"$set": {"stopped": True}}, upsert=True)

    def _balancer_turnon(self):
        """ Mark the balancer as not stopped in ``config.settings`` (creating the document if absent) """
        logging.debug("Starting balancer")
        selector = {"_id": "balancer"}
        change = {"$set": {"stopped": False}}
        self.mongo_connection.config.settings.update(selector, change, upsert=True)

    def _balancer_turnoff(self):
        """ Mark the balancer as stopped in ``config.settings`` (creating the document if absent) """
        logging.debug("Stopping balancer")
        selector = {"_id": "balancer"}
        change = {"$set": {"stopped": True}}
        self.mongo_connection.config.settings.update(selector, change, upsert=True)

    def _run_backupers(self):
        """
        Pick this run's iteration number, start the backup subtasks and suspend until they finish.

        Always raises sdk2.WaitTask for the children, waking up on any FINISH- or BREAK-group status.
        """
        previous = self.previous_iteration
        if previous is None:
            self.Context.iteration = 0
        else:
            self.Context.iteration = previous + 1
        logging.debug("Previous task iteration: %r, current: %r", previous, self.Context.iteration)

        self._backup_replicas()
        self.set_info("Waiting for children to finish backup")

        wake_up_statuses = common.itertools.chain(ctt.Status.Group.FINISH, ctt.Status.Group.BREAK)
        raise sdk2.WaitTask(self.Context.children, wake_up_statuses)

    def _get_shards_on_appropriate_hosts(self, hosts_blacklist):
        """
        Collect, per host, the replica set members that may be used as a backup source.

        A member is kept when its host is not blacklisted, its state is SECONDARY and,
        if ``use_hidden_shards`` is set, it is hidden.

        :param hosts_blacklist: container of host names to skip
        :return: tuple of ({ host -> [ (replica_name, port), ... ] }, number of replica sets reported)
        """
        replica_sets = common.rest.Client().service.status.database.shards.read()
        usable = collections.defaultdict(list)
        # See sandbox/web/api/v1/schemas/service.py#DatabaseShardReplicaInstance
        for replica_set in replica_sets:
            name = replica_set["id"]
            for member in replica_set["replicaset"]:
                host, port = member["id"].rsplit(":", 1)
                rejected = (
                    host in hosts_blacklist
                    or member["state"] != consts.ShardState.SECONDARY
                    or (self.Parameters.use_hidden_shards and not member["hidden"])
                )
                if not rejected:
                    usable[host].append((name, port))
        return usable, len(replica_sets)

    def _backup_replicas(self):
        """
        Choose the backup source host(s) and enqueue the corresponding backup subtasks.

        A single host carrying every replica set is preferred; candidates are rotated by the task's
        iteration number. If there is no such host, several hosts are used, unless
        ``run_on_single_host`` is set.

        :raises common.errors.TaskFailure: no host has all shards and ``run_on_single_host`` is set
        """
        # Call strip() on each line instead of using the unbound str.strip: the latter raises TypeError
        # for unicode values on Python 2. An unset (None) parameter is treated as an empty blacklist.
        raw_blacklist = self.Parameters.blacklist or ""
        stripped = (line.strip() for line in raw_blacklist.splitlines())
        blacklist = {entry for entry in stripped if entry}
        logging.info("Servers blacklist: %s", ", ".join(sorted(blacklist)) or "empty")

        logging.debug("Searching replicas for backup")
        shards_on_hosts, shards_number = self._get_shards_on_appropriate_hosts(blacklist)
        # Sorted so that the round-robin choice below does not depend on dict iteration order
        fully_replicated_hosts = sorted(
            host for host, shard_info in shards_on_hosts.items()
            if len(shard_info) == shards_number
        )
        if fully_replicated_hosts:
            logging.debug("Using single host with all shards available")
            host = fully_replicated_hosts[self.Context.iteration % len(fully_replicated_hosts)]
            self._run_backup_on_single_host(host, shards_on_hosts)
        elif not self.Parameters.run_on_single_host:
            logging.debug("Using several hosts to backup all shards")
            self._run_backup_on_several_hosts(shards_on_hosts)
        else:
            raise common.errors.TaskFailure("Failed to find a single host with all shards available")

    def _run_backup_on_single_host(self, host, shards_on_hosts):
        """
        Enqueue one ServiceMongoBackuper subtask that backs up every shard from ``host``.

        :param host: host name (as reported in the shard ids) to run the backup on
        :param shards_on_hosts: { host -> [ (replica_name, port), ... ] }
        """
        logging.info("Host to run full backup on: %s", host)
        self.set_info("Host to run full backup on: " + host)

        shards = []
        for replicaset_name, port in shards_on_hosts[host]:
            shards.append("{}/{}:{}".format(replicaset_name, host, port))

        requirements = {
            "host": self._fqdn_to_hostname(host),
            # shard_size is in GBs; the factor of 1024 converts it for the disk_space requirement
            "disk_space": self.Parameters.shard_size * len(shards) * (1 << 10)
        }
        subtask = mongo_backuper.ServiceMongoBackuper(
            self,
            description="Full mongo backup from host {}".format(host),
            shards=shards,
            do_fsync=self.Parameters.do_fsync,
            use_mongodump=self.Parameters.use_mongodump,
            backup_ttl=self.Parameters.backup_ttl,
            __requirements__=requirements,
        )
        child_id = subtask.save().enqueue().id
        self.Context.children.append(child_id)

    def _run_backup_on_several_hosts(self, shards_on_hosts):
        """
        Enqueue one ServiceMongoShardBackuper subtask per group of hosts sharing the same replica sets.

        Within each group the host is picked by the task's iteration number. All subtasks get the same
        ZooKeeper barrier path and the number of participants.

        :param shards_on_hosts: { host -> [ (replica_name, port), ... ] }
        """
        # Group hosts with same replica sets (consider each group has unique set of replicas)
        hosts_by_replicas = collections.defaultdict(list)  # {replicas group} -> {hosts group}
        for host, members in shards_on_hosts.items():
            group_key = frozenset(name for name, _ in members)
            hosts_by_replicas[group_key].append(host)

        backup_hosts = []
        for candidates in hosts_by_replicas.values():
            backup_hosts.append(candidates[self.Context.iteration % len(candidates)])

        logging.info("Hosts to run backup on: %s", ", ".join(backup_hosts))
        self.set_info("Hosts to run backup on:\n" + "\n".join(backup_hosts))
        logging.info("Using ZooKeeper barrier path %s", self.zk_barrier)

        group_index = 0
        for host in backup_hosts:
            shards = []
            for replicaset_name, port in shards_on_hosts[host]:
                shards.append("{}/{}:{}".format(replicaset_name, host, port))
            requirements = {
                "host": self._fqdn_to_hostname(host),
                "disk_space": self.Parameters.shard_size * len(shards) * (1 << 10)
            }
            subtask = shard_backuper.ServiceMongoShardBackuper(
                self,
                description="Mongo backup of group {} on host {}".format(group_index, host),
                shards=shards,
                zk_barrier_path=self.zk_barrier,
                zk_barrier_num=len(backup_hosts),
                fsync=self.Parameters.do_fsync,
                backup_ttl=self.Parameters.backup_ttl,
                __requirements__=requirements,
            )
            child_id = subtask.save().enqueue().id
            self.Context.children.append(child_id)
            group_index += 1

    def _check_children_state(self):
        """
        Verify that every backup subtask ended in a SUCCEED-group status.

        On failure, a summary with links to the offending subtasks is added to the task info.

        :raises common.errors.TaskFailure: at least one subtask is not in a SUCCEED-group status
        """
        children_ids = ",".join(map(str, self.Context.children))
        children = self.server.task.read(id=children_ids, children=True, limit=100)["items"]
        # A real list (not filter()): on Python 3 a filter object is always truthy and has no len()
        failures = [child for child in children if child["status"] not in ctt.Status.Group.SUCCEED]
        if not failures:
            return

        msg = "{} of {} tasks have failed".format(len(failures), len(children))
        rich = ", ".join(
            "<a href='{}'>{}</a> (status: {})".format(common.utils.get_task_link(ch["id"]), ch["id"], ch["status"])
            for ch in failures
        )
        # The info text contains HTML links, hence escaping is switched off
        self.set_info(": ".join([msg, rich]), do_escape=False)
        raise common.errors.TaskFailure(msg)

    def _run_backup_test(self):
        """ Enqueue a ServiceMongoRestore subtask for this backup and remember its id in the context """
        description = "Test latest backup, created on {} by task <a href='{}'>#{}</a>".format(
            datetime.datetime.today(), common.utils.get_task_link(self.id), self.id
        )
        restore_task = mongo_restore.ServiceMongoRestore(
            self,
            description=description,
            control_task_id=int(self.id)
        )
        self.Context.backup_test_task_id = restore_task.save().enqueue().id
        self.Context.save()

    def _sleep_with_balancer_on(self):
        """
        Switch the balancer on and, if it is night time, suspend the task until 7 o'clock.

        The wake-up time is 7:00 of the current day, or of the next day when the current hour is 23.
        If that moment is not in the future the method returns without sleeping.

        :raises sdk2.WaitTime: to suspend the task until the wake-up time
        """
        self._balancer_turnon()

        now = datetime.datetime.now()
        wake_up = datetime.datetime(now.year, now.month, now.day, hour=7)
        if now.hour >= 23:  # wakeup should be in the future
            wake_up = wake_up + datetime.timedelta(days=1)

        if not now < wake_up:  # it's daytime
            return

        delta = (wake_up - now).seconds
        msg = "Sleeping from {} to {} ({} seconds)".format(now, wake_up, delta)
        logging.debug(msg)
        self.set_info(msg)
        raise sdk2.WaitTime(delta)

    def _check_test_status(self):
        """
        Poll the backup test subtask and fail unless it has ended with SUCCESS.

        While the subtask is in a status outside the DRAFT, FINISH and BREAK groups, the task is
        suspended for ``wait_timeout`` seconds and the accumulated waiting time is increased.

        :raises common.errors.TaskFailure: waiting time reached ``wait_limit``, or the subtask did not succeed
        :raises sdk2.WaitTime: the subtask is still in progress
        """
        if self.Context.total_wait >= self.wait_limit:
            timeout_message = "Waited for backup test task #{} for too long, consider it failed".format(
                self.Context.backup_test_task_id
            )
            raise common.errors.TaskFailure(timeout_message)

        status = self.server.task[self.Context.backup_test_task_id][:]["status"]
        settled_statuses = (
            tuple(ctt.Status.Group.DRAFT) + tuple(ctt.Status.Group.FINISH) + tuple(ctt.Status.Group.BREAK)
        )
        if status not in settled_statuses:
            self.Context.total_wait += self.wait_timeout
            raise sdk2.WaitTime(self.wait_timeout)

        if status != ctt.Status.SUCCESS:
            failure_message = "Backup test task #{} has ended with status {}".format(
                self.Context.backup_test_task_id, status
            )
            raise common.errors.TaskFailure(failure_message)

    def on_execute(self):
        """
        Drive the whole backup procedure.

        Several steps suspend the task by raising sdk2.WaitTask / sdk2.WaitTime, so this method is
        entered more than once.
        NOTE(review): presumably each ``memoize_stage`` block runs only once across these
        re-executions -- confirm against the sdk2 documentation.
        """
        # Step 1: the balancer must not be running while the backup is taken
        with self.memoize_stage.prepare:
            self._ensure_balancer_off()

        if self.Parameters.do_backup_replica:
            # Step 2: start the backup subtasks; _run_backupers always raises WaitTask
            with self.memoize_stage.backup:
                self._run_backupers()
            # Step 3: after the children are done, fail if any of them did not succeed
            with self.memoize_stage.check_backupers:
                self._check_children_state()

        if self.Parameters.test_backup:
            # Step 4: start the restore test; its result is only checked at the very end
            with self.memoize_stage.test_backup:
                self._run_backup_test()

        if self.Parameters.run_balancer:
            # Step 5: switch the balancer on and possibly sleep until the morning (raises WaitTime)
            with self.memoize_stage.balance:
                self._sleep_with_balancer_on()
            # Step 6: switch the balancer off again
            with self.memoize_stage.stop_balancing:
                self._balancer_turnoff()

        if self.Parameters.test_backup:
            # Step 7: not memoized -- re-evaluated on every wake-up until the test task settles
            self._check_test_status()
