from __future__ import absolute_import

import glob
import logging
import os
import time
from collections import namedtuple

from sandbox import sdk2
from sandbox import common
from sandbox.projects import resource_types
from sandbox.projects.common import utils
from sandbox.projects.common import network
from sandbox.projects.common.gnupg import GpgKey2
from sandbox.projects.common.nanny import client
from sandbox.sdk2.helpers import subprocess
from sandbox.common.utils import singleton_property


def deep_get(_dict, keys, default=None):
    """Follow `keys` through nested dicts and return the value found.

    A missing key yields `default`. If a non-dict is reached while keys
    remain, `default` is returned as well. With no keys, `_dict` itself is
    returned unchanged.
    """
    node = _dict
    for step in keys:
        if not isinstance(node, dict):
            # Cannot descend any further into a non-mapping value.
            return default
        node = node.get(step, default)
    return node


def get_vault(owner='YASAP', name=''):
    """Return the Sandbox vault secret `name` that belongs to `owner`."""
    secret = sdk2.Vault.data(owner, name)
    return secret


def _wait_all(waitables):
    for waitable in waitables:
        ret = waitable.wait()
        if ret:
            raise Exception("Process %s failed with code %s", waitable, ret)


def yt_path_to_yasm(path, max_len=128):
    """Convert a YT path into a string usable inside a yasm signal name.

    Paths longer than `max_len` characters are truncated to `max_len`.
    Truncation happens before a leading '//' is stripped, matching historical
    signal names. Then '/' and '-' are replaced with '_'.

    :param path: YT path, e.g. ``//home/foo-bar``
    :param max_len: maximum length kept before stripping the '//' prefix
    """
    # Warn only when something is actually cut. The old `>=` check warned
    # for paths of exactly `max_len` characters, which were left intact.
    if len(path) > max_len:
        logging.warning('We will cut path in signal name')
        path = str(path)[:max_len]
    if path.startswith('//'):
        path = path[2:]
    return path.replace('/', '_').replace('-', '_')


def _remove_finished(promises):
    undone = []
    for promise in promises:
        ret = promise.poll()
        if ret is None:
            undone.append(promise)
        if ret:
            raise Exception("Process %s failed with code %s", promise, ret)
    return undone


class MongoDBMixin:
    """Helpers to run a local mongod and restore a mongodump into it.

    The host class must provide ``Requirements.ram``, which is used for the
    default WiredTiger cache size. It must also be usable as the owner of
    ``sdk2.helpers.ProcessLog``.
    """
    # Number of readiness probes (about 0.1 s apart) before giving up on mongod.
    START_MONGODB_MAX_RETRIES = 300
    # Smallest value mongod accepts for --wiredTigerCacheSizeGB.
    MIN_WIRED_TIGER_CACHE_SIZE_GB = 0.25

    def _start_mongo(self, mongo_path, db_path=None, log_path=None, cache_size=None):
        """Start mongod on a free port and wait until it answers ``db.version()``.

        :param mongo_path: MongoDB distribution root containing ``bin/mongod`` and ``bin/mongo``
        :param db_path: data directory, created if missing; defaults to ``mongo_data``
        :param log_path: mongod log file; defaults to ``<cwd>/mongo_log``
        :param cache_size: WiredTiger cache size in GB; defaults to ``_recommended_cache_size()``
        :return: the port mongod listens on
        :raises common.errors.TaskFailure: if mongod exits early or never answers
        """
        # Note: this changes the locale of the whole task process, not only mongod.
        os.environ["LC_ALL"] = "C"
        if not db_path:
            db_path = "mongo_data"
        if not os.path.exists(db_path):
            os.makedirs(db_path)
        if not log_path:
            log_path = os.path.join(os.getcwd(), "mongo_log")
        if not cache_size:
            cache_size = self._recommended_cache_size()
        port = network.get_free_port()
        with sdk2.helpers.ProcessLog(
                self, logging.getLogger('mongod')
        ) as pl:
            mongod = subprocess.Popen(
                [
                    os.path.join(mongo_path, "bin/mongod"),
                    "--port", str(port),
                    "--dbpath", db_path,
                    "--logpath", log_path,
                    "--wiredTigerCacheSizeGB", str(cache_size),
                    "--journal"
                ],
                stdout=pl.stdout,
                stderr=pl.stderr,
            )
            for _ in xrange(self.START_MONGODB_MAX_RETRIES):
                # Fail fast if mongod already died, instead of probing until
                # the retry budget is exhausted.
                ret = mongod.poll()
                if ret is not None:
                    raise common.errors.TaskFailure('mongod exited with code {}'.format(ret))
                try:
                    subprocess.check_output(
                        [
                            os.path.join(mongo_path, "bin/mongo"),
                            "--eval", "db.version()",
                            "localhost:{port}".format(port=port)
                        ],
                        stderr=pl.stderr,
                    )
                except subprocess.CalledProcessError:
                    time.sleep(0.1)
                else:
                    break
            else:
                raise common.errors.TaskFailure('mongo_wait failed')
        return port

    @classmethod
    def _restore_backup(cls, mongo_path, port):
        """Restore the dump located by ``_get_db_dump_path()`` into the mongod on `port`.

        Indexes are not restored (``--noIndexRestore``). ``--oplogReplay`` is
        added when the dump contains ``oplog.bson``.

        :raises common.errors.TaskFailure: if mongorestore exits with a non-zero code
        """
        args = [
            os.path.join(mongo_path, "bin/mongorestore"),
            "--port", str(port),
            "--noIndexRestore",
        ]
        dump_path = cls._get_db_dump_path()
        if os.path.exists(os.path.join(dump_path, 'oplog.bson')):
            args.append("--oplogReplay")
        args.append(dump_path)

        with sdk2.helpers.ProcessLog(
            cls, logging.getLogger('mongorestore')
        ) as pl:
            process = subprocess.Popen(
                args,
                stdout=pl.stdout,
                stderr=pl.stderr,
            )
            ret = process.wait()
        # Previously the exit code was ignored, hiding failed restores.
        if ret:
            raise common.errors.TaskFailure('mongorestore failed with code {}'.format(ret))

    @classmethod
    def _get_db_dump_path(cls, root='mongo-backup'):  # preserve legacy path
        """Return the single dump entry inside ``<root>/backup``.

        Entries whose names end with 'components' are ignored. Exactly one
        other entry must be present.
        """
        db_dump_path = os.path.join(root, 'backup')
        files = [file_name for file_name in os.listdir(db_dump_path) if not file_name.endswith('components')]
        assert len(files) == 1, 'Expected exactly one dump in {}, found {}'.format(db_dump_path, files)
        return os.path.join(db_dump_path, files[0])

    def _recommended_cache_size(self):
        """Return a WiredTiger cache size in GB derived from ``Requirements.ram``.

        Roughly half of the RAM minus 1 GB, assuming ``ram`` is in MiB
        (TODO confirm the unit). It is never below the minimum mongod accepts,
        so small RAM requirements no longer produce 0 or negative sizes.
        """
        size = ((self.Requirements.ram >> 10) - 1) / 2.
        return max(self.MIN_WIRED_TIGER_CACHE_SIZE_GB, size)


# Settings for decrypting gpg-zip archives.
# key_owner, secret_key_name and public_key_name are passed to GpgKey2.
# recipient is passed to gpg-zip as --recipient.
GpgSettings = namedtuple(
    'GpgSettings',
    'key_owner secret_key_name public_key_name recipient',
)


class ExtractDumpMixin:
    """Decrypts (gpg-zip) and/or untars mongodump pieces into the working directory."""

    @property
    def _decrypt_settings(self):
        """Return the default GpgSettings.

        Key owner is 'YASAP', key names are robot_pdb_builder_*, and the
        recipient is taken from ``Parameters.gpg_key_owner``.
        """
        return GpgSettings(
            key_owner='YASAP',
            secret_key_name='robot_pdb_builder_private',
            public_key_name='robot_pdb_builder_public',
            recipient=self.Parameters.gpg_key_owner,
        )

    @staticmethod
    def _components_root(dump_folder_path):
        """Return `dump_folder_path` with a trailing '_components' suffix removed.

        Uses slicing because ``str.rstrip`` strips a set of characters, not a
        suffix. For example, 'dump_components'.rstrip('_components') == 'du'.
        """
        suffix = '_components'
        if dump_folder_path.endswith(suffix):
            return dump_folder_path[:-len(suffix)]
        return dump_folder_path

    @staticmethod
    def _precreate_db_dirs(root, paths):
        """Create ``root/<db>`` for each distinct parent-directory name in `paths`.

        Directories that already exist are left untouched.
        """
        for db in set(path.split('/')[-2] for path in paths):
            db_dir = os.path.join(root, db)
            if not os.path.isdir(db_dir):
                os.makedirs(db_dir)

    def _extract_dump(self, dump_folder_path, settings=None):
        """
        Decrypts/unpacks the mongodump at `dump_folder_path` into the current working directory.

        If `dump_folder_path` is a directory, every ``*/*`` entry in it is
        processed. Otherwise it is treated as a single file. Entries ending
        with '_encrypted' are decrypted with gpg-zip; all others are extracted
        with ``tar -xzvf``.

        :type dump_folder_path: str
        :type settings: GpgSettings
        :raises Exception: if any gpg-zip or tar process exits with a non-zero code
        """
        if os.path.isdir(dump_folder_path):
            paths = glob.glob(os.path.join(dump_folder_path, '*/*'))
            # workaround for race conditions between gpg-zip processes:
            # pre-create the per-db directories under the decoded dump root
            # (presumably so the parallel extractions do not race creating them)
            if dump_folder_path.endswith('_components'):
                self._precreate_db_dirs(self._components_root(dump_folder_path), paths)
        else:  # file
            paths = [dump_folder_path]

        # convention over configuration: '*_encrypted' means a gpg-zip archive
        to_decrypt = [path for path in paths if path.endswith('_encrypted')]
        to_untar = [path for path in paths if not path.endswith('_encrypted')]

        with sdk2.helpers.ProcessLog(
            self, logging.getLogger('extract')
        ) as pl:
            if to_decrypt:
                if not settings:
                    settings = self._decrypt_settings
                # Keep the GpgKey2 context open until every gpg-zip has exited.
                # All processes are started first and run in parallel.
                with GpgKey2(settings.key_owner, settings.secret_key_name, settings.public_key_name):
                    _wait_all([
                        subprocess.Popen(
                            [
                                "gpg-zip",
                                "-d",
                                "--recipient", settings.recipient,
                                "--gpg-args", "--no-tty",
                                "--gpg-args", "--batch",
                                "--gpg-args", "--trust-model",
                                "--gpg-args", "always",
                                path
                            ],
                            stdout=pl.stdout,
                            stderr=pl.stderr,
                        )
                        for path in to_decrypt
                    ])

            if to_untar:
                _wait_all([
                    subprocess.Popen(
                        [
                            "tar",
                            "-xzvf",  # assuming gzip
                            path,
                        ],
                        stdout=pl.stdout,
                        stderr=pl.stderr,
                    )
                    for path in to_untar
                ])


class YasmReportable:
    """Reports metrics through the last released MONITORING_CLIENT_EXECUTABLE."""

    def _yasm_report(self, args=None, stdout=None):
        """Start ``<monitoring client> report-metric`` without waiting for it.

        The server comes from ``Parameters.monitoring_server_host``.

        :param args: extra command-line arguments; any iterable of strings
            (list, tuple, ...). None or empty means no extra arguments.
        :param stdout: stdout for the child process; None inherits ours
        :return: the started process handle, so callers may ``wait()`` on it
        """
        monitoring_client_tool_id = utils.get_and_check_last_released_resource_id(
            resource_types.MONITORING_CLIENT_EXECUTABLE
        )
        monitoring_client_tool = str(sdk2.ResourceData(sdk2.Resource[monitoring_client_tool_id]).path)
        monitoring_args = [
            monitoring_client_tool,
            "report-metric",
            "--monitoring-server",
            self.Parameters.monitoring_server_host,
        ]
        # extend() accepts any iterable; the old `list + args` failed for tuples.
        monitoring_args.extend(args or [])

        with sdk2.helpers.ProcessLog(self, logging.getLogger('monitoring_client')) as pl:
            return subprocess.Popen(
                monitoring_args,
                stdout=stdout,
                stderr=pl.stderr,
            )

    def _report_lag(self, metric_id):
        """Report a 'timelag' metric `metric_id` with value 'start' and the 'minutes' transform."""
        logging.info('Send report lag by %s', metric_id)
        self._yasm_report(
            args=[
                "--id", metric_id,
                "--value", "start",
                "--transform", "minutes",
                "--policy", "timelag"
            ]
        )


class NannyMixin(object):
    """Thin convenience wrappers around ``client.NannyClient``."""

    @singleton_property
    def nanny_client(self):
        """Nanny API client using the 'nanny_oauth_token' vault secret of owner 'YASAP'."""
        return client.NannyClient(
            api_url='http://nanny.yandex-team.ru/',
            oauth_token=get_vault(name='nanny_oauth_token'),
        )

    def get_service_value(self, service, properties_chain, default=None):
        """Return the nested value at `properties_chain` in `service`, or `default`."""
        return deep_get(service, properties_chain, default)

    def __wait_services_for_status(self, service_ids, status="ONLINE", timeout=300):
        """Poll every 30 s until all `service_ids` report `status`.

        The status is read from ``content.summary.value`` of each service's
        current state.

        :return: True once all services have the status
        :raises common.errors.TaskFailure: if `timeout` seconds pass first
        """
        waitfor = time.time() + timeout
        while time.time() < waitfor:
            active = True
            for service_id in service_ids:
                service_state = self.nanny_client.get_service_current_state(service_id)
                active &= self.get_service_value(service_state, ['content', 'summary', 'value'], '') == status
            if not active:
                time.sleep(30)
            else:
                return True
        raise common.errors.TaskFailure("Failed to set status {status} to services {services}".format(
            services=service_ids,
            status=status,
        ))

    def get_service(self, service_id):
        """Return the service description, or None if the Nanny API call fails."""
        try:
            return self.nanny_client.get_service(service_id)
        except client.NannyApiException as e:
            logging.debug('Can not find service: {}'.format(e))
            return None

    def get_exist_services(self, category):
        """Return the services listed under `category` (empty list if none)."""
        resp = self.nanny_client.list_services_by_category(category)

        return resp.get('result', [])

    def copy_service(self, **kwargs):
        return self.nanny_client.copy_service(**kwargs)

    def get_service_instances(self, service_id):
        return self.nanny_client.get_service_instances(service_id)

    def update_service_instances(self, service_id, data):
        return self.nanny_client.update_service_instances(service_id, data)

    def get_service_resources(self, service_id):
        return self.nanny_client.get_service_resources(service_id)

    def get_snapshot_id(self, service_id):
        """Return the 'snapshot_id' from the service resources, or None."""
        return self.get_service_resources(service_id).get('snapshot_id', None)

    def update_service_resources(self, service_id, data):
        return self.nanny_client.update_service_resources(service_id, data)

    def update_service_engine(self, service_id, engine):
        # Return the client's response like the sibling wrappers do
        # (callers that ignore it are unaffected).
        return self.nanny_client.update_service_engines(service_id, engine)

    def create_event(self, service_id, type='', content=None):
        """Send an event ``{'type': type, 'content': content}`` to `service_id`.

        :param content: event content; None means an empty dict
        """
        # Fresh dict per call instead of a shared mutable default argument.
        payload = {
            'type': type,
            'content': content if content is not None else {},
        }

        # Fixed: the old `self.nanny_client.self.create_event` always raised AttributeError.
        return self.nanny_client.create_event(service_id, payload)

    def update_service_sandbox_file(self, **kwargs):
        return self.nanny_client.update_service_sandbox_file(**kwargs)

    def remove_snapshots(self, service_ids):
        """Delete all snapshots of `service_ids`, then wait for them to go OFFLINE."""
        for service_id in service_ids:
            self.nanny_client.delete_all_snapshots(service_id)

        self.__wait_services_for_status(service_ids, status="OFFLINE")

    def remove_services(self, service_ids):
        """Delete every service in `service_ids`."""
        for service_id in service_ids:
            self.nanny_client.delete_service(service_id)


class FetchDumpMixin(object):
    """Fetches a mongodump per component (``sky get``) or from a Sandbox resource."""

    def _fetch_dump(self):
        """Download the dump and return the path to it.

        With ``Parameters.components`` set, the torrents matching
        ``self._dbs_with_collections()`` are fetched with ``sky get``. The
        ``mongo-backup/backup/*_components`` directory is then returned.
        Otherwise the ``Parameters.mongo_dump_id`` resource is synced and the
        dump inside it is located with ``self._get_db_dump_path``.

        :raises common.errors.TaskFailure: if no ``*_components`` directory appears
        """
        if self.Parameters.components:
            self._sky_get(
                # file dumps are not supported so far
                self._filter_dumps(self._dbs_with_collections())
            )
            matches = glob.glob('mongo-backup/backup/*_components')
            if not matches:
                # Explicit error instead of a bare IndexError.
                raise common.errors.TaskFailure(
                    'No *_components directory found in mongo-backup/backup'
                )
            return matches[0]
        resource_path = str(
            sdk2.ResourceData(
                sdk2.Resource[self.Parameters.mongo_dump_id]
            ).path
        )
        return self._get_db_dump_path(resource_path)

    def _sky_get(self, torrents, max_concurrency=3):
        """Run ``sky get`` for each torrent, at most `max_concurrency` at a time.

        :raises Exception: if any ``sky get`` exits with a non-zero code
        """
        with sdk2.helpers.ProcessLog(
            self, logging.getLogger('sky_log')
        ) as pl:
            promises = []
            for torrent in torrents:
                promises.append(
                    subprocess.Popen(
                        [
                            'sky', 'get', torrent
                        ],
                        stdout=pl.stdout,
                        stderr=pl.stderr,
                    )
                )
                # Throttle: wait for a slot before starting the next download.
                while len(promises) >= max_concurrency:
                    time.sleep(1)
                    promises = _remove_finished(promises)
            _wait_all(promises)

    def _filter_dumps(self, dbs_with_collections):
        """Return torrents from ``Parameters.components`` whose key is a requested ``'<db>.<collection>'``.

        :param dbs_with_collections: mapping of db name to an iterable of collection names
        :return: list of matching torrents, in ``Parameters.components`` iteration order
        """
        # A set gives O(1) membership checks (the old code scanned a list).
        wanted = {
            '{}.{}'.format(db, collection)
            for db, collections in dbs_with_collections.items()
            for collection in collections
        }
        return [
            torrent
            for key, torrent in self.Parameters.components.items()
            if key in wanted
        ]
