# -*- coding: utf-8 -*-
import os
import logging
import datetime
import uuid

import sandbox.sdk2 as sdk2

import api.copier.errors as copier_errors
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import file_utils as fu
from sandbox.projects.common import task_env
from sandbox.common.types import task as ctt
from sandbox.common.types import notification as ctn
from sandbox.common.errors import TaskFailure
from sandbox.common.share import skynet_get
from sandbox.projects.websearch.begemot.common import add_realtime_version_file
from sandbox.projects.websearch.begemot.resources import BEGEMOT_REALTIME_PACKAGE
from sandbox.sdk2.helpers import subprocess
import sandbox.projects.ads.emily.storage.client.binary as mls_client
from sandbox.sandboxsdk import environments


# Relative directory in which per-rule model directories and the realtime
# version file are assembled; it is also the `tar -C` root when packing.
WORKING_DIR = "realtime_data"


class StorageSources(object):
    """Identifiers of the storages a resource spec can name in its 'source' key.

    A spec without a 'source' key is treated as SANDBOX
    (see `_get_resource_from_proper_storage`).
    """
    SANDBOX = "sandbox"
    ML_STORAGE = "ml_storage"
    YT = "YT"


class YTResource(object):
    def __init__(self, spec):
        """Create a resource-like wrapper around a YT resource spec.

        :param spec: dict describing the resource; keys missing as real
            attributes are served from it by `__getattr__`.
        """
        # Imported inside the constructor rather than at module level
        # (presumably because yandex-yt is only installed in the task
        # environment — confirm).
        import yt.wrapper as yt_wrapper

        self.spec = spec
        self.attributes = {}
        self.type = StorageSources.YT
        # A random identifier; it is not derived from anything stored in YT.
        self.id = str(uuid.uuid4())
        self.created = datetime.datetime.now()
        self.yt = yt_wrapper

    def __iter__(self):
        """Iterate over the (name, value) pairs collected by `read_attributes`."""
        pairs = self.attributes.items()
        return iter(pairs)

    def __getattr__(self, name):
        return self.spec.get(name, None)

    def download_to(self, local_dir):
        """Download every component listed in the spec into `local_dir`.

        Before downloading, the YT client configuration is pointed at the
        proxy given in the spec and a token read from Vault.

        :param local_dir: local directory that receives the downloaded data.
        """
        spec = self.spec
        self.yt.config["proxy"]["url"] = spec['resource_proxy_url']
        self.yt.config["token"] = sdk2.Vault.data(spec['owner'], spec['yt_token_secret_name'])
        for component_spec in spec['components']:
            self.download_resource(component_spec, local_dir)

    def download_resource(self, resource_spec, local_dir):
        found, not_found = self.find_resource_children(resource_spec)
        if not_found:
            logging.error('Some of resource dependencies not found: %s', ' '.join(not_found))  # TODO
            raise TaskFailure()
        for child_path, path_suffix in found:
            logging.info('Downloading resource from YT {}'.format(child_path))
            self.download_yt_node(child_path, local_dir, path_suffix)
        if found:
            date_str = self.yt.get_attribute(found[0][0], 'creation_time')
            try:
                self.original_created_time = datetime.datetime.strptime(date_str, '%Y-%m-%dT%H:%M:%S.%fZ').isoformat()
            except ValueError:
                logging.error('Can not parse creation_time string %s', date_str)

    def get_versions_in_order(self, dir_path, version_type, format=None, reverse=True):
        version_names = []
        for entry in self.yt.list(dir_path):
            if version_type is datetime.datetime:
                try:
                    datetime.datetime.strptime(entry, format or '%Y-%m-%dT%H:%M:%S.%f')
                except:
                    continue
            elif version_type is int:
                try:
                    int(entry)
                except:
                    continue
            version_names.append(entry)
        version_names.sort(reverse=reverse)
        return version_names

    def find_resource_children(self, resource_spec):
        def _find_resource_children(resource_spec, found, not_found, path_prefix=None):
            path = resource_spec.get('path')
            if not path:
                raise TaskFailure()

            if path_prefix:
                path = os.path.join(path_prefix, path)
            logging.info("Looking for path {}".format(path))

            if not self.yt.exists(path):
                logging.info("YT path {} does not exist".format(path))
                not_found.append(path)
                return False

            self.read_attributes(path, resource_spec.get('attributes', []))

            node_type = self.yt.get_attribute(path, 'type')
            if node_type == 'table' or node_type == 'file' or \
                    (node_type == 'map_node' and
                        'version_type' not in resource_spec and
                            not resource_spec.get('children')):
                logging.info('Found leaf node {}'.format(path))
                found.append((path, resource_spec.get('path_suffix')))
                return True
            elif node_type != 'map_node':
                raise TaskFailure()

            if 'version_type' in resource_spec:
                logging.info('Directory with version_type')
                initial_found_len = len(found)
                initial_attr_keys = set(self.attributes.keys())
                for version in self.get_versions_in_order(
                    path,
                    resource_spec['version_type'],
                    resource_spec.get('format'),
                    reverse=True
                ):
                    logging.info('Trying version {}'.format(version))
                    version_spec = {
                        'path': os.path.join(path, version),
                        'local_path': resource_spec.get('local_path'),
                        'children': resource_spec.get('children')
                    }
                    if _find_resource_children(version_spec, found, not_found):
                        del not_found[:]
                        return True
                    del found[initial_found_len:]
                    self.attributes = {k: v for k, v in self.attributes.items() if k in initial_attr_keys}
                return False

            logging.info('Directory without versions')
            for child_spec in resource_spec['children']:
                if not _find_resource_children(child_spec, found, not_found, path):
                    return False
            return True

        found, not_found = [], []
        _find_resource_children(resource_spec, found, not_found)
        return found, not_found

    def download_yt_node(self, ypath, local_dir, path_suffix=None):
        local_path = os.path.join(local_dir, path_suffix or os.path.basename(ypath))
        ypath_type = self.yt.get_attribute(ypath, 'type')
        if ypath_type == "table":
            torrent_id = self.yt.sky_share(ypath)
            subprocess.check_call(['sky', 'get', '-wud', local_dir, torrent_id])
        elif ypath_type == "file":
            with open(local_path, 'wb') as f:
                f.write(self.yt.read_file(ypath).read())
        elif ypath_type == "map_node":
            self.recursively_download_directory(ypath, local_path)
        else:
            raise TaskFailure("Cannot download YT node with type {}".format(ypath_type))
        return local_path

    def recursively_download_directory(self, yt_dir, local_dir):
        if not os.path.exists(local_dir):
            os.makedirs(local_dir)
        for node_name in self.yt.list(str(yt_dir)):
            ypath = os.path.join(yt_dir, node_name)
            self.download_yt_node(ypath, local_dir)

    def read_attributes(self, ypath, attr_list):
        for attr in attr_list:
            self.attributes[attr] = self.yt.get_attribute(ypath, attr)


# Shard names that are always offered as choices of the `begemot_shard`
# parameter built by `get_build_models_package_parameters`.
BEGEMOT_SHARDS = {
    'ShardNone',
    'YabsHitModels',
}


def get_build_models_package_parameters(rules_to_deploy_default, bg_shard_default, bg_shards=frozenset(), pack_to_tar_default=False):
    """Build the parameter group shared by models-package build tasks.

    :param rules_to_deploy_default: default value of the `rules_to_deploy` list.
    :param bg_shard_default: shard selected by default; always included in
        the offered choices.
    :param bg_shards: extra shard names (any iterable) offered in addition
        to `BEGEMOT_SHARDS`.
    :param pack_to_tar_default: default value of the `pack_to_tar` flag.
    :return: an instance of a freshly declared `sdk2.Parameters` subclass.
    """
    # Sorted so the choices appear in a stable order instead of the arbitrary
    # iteration order of a set; key=str keeps the sort safe for mixed types.
    shard_choices = sorted(BEGEMOT_SHARDS | set(bg_shards) | {bg_shard_default}, key=str)

    class _BuildModelsPackageParameters(sdk2.Parameters):
        begemot_shard = sdk2.parameters.RadioGroup(
            'Environment',
            choices=[(shard, shard) for shard in shard_choices],
            default=bg_shard_default,
        )
        rules_to_deploy = sdk2.parameters.List("List of rules to deploy", default=rules_to_deploy_default)
        pack_to_tar = sdk2.parameters.Bool("whether to pack output resource to .tar", default=pack_to_tar_default)

    return _BuildModelsPackageParameters()


class BuildBegemotModelsPackageBase(sdk2.Task):

    # Logical token name -> Vault coordinates ("owner" and secret "name")
    # that `_get_vault_token` passes to `sdk2.Vault.data`.
    VAULT_VARS = {
        "SANDBOX_TOKEN": {
            "owner": "robot-itditp",
            "name": "robot-itditp-sandbox-token",
        },
    }

    class Requirements(task_env.TinyRequirements):
        """See SEARCH-11128."""
        # Disk and RAM requested for the task (the inline comments imply the
        # values are expressed in MiB - confirm against the Sandbox docs).
        disk_space = 256 * 1024  # 256 Gb
        ram = 48 * 1024  # 48 Gb
        # Provides the `yt.wrapper` module that YTResource imports in its constructor.
        environments = [
            environments.PipEnvironment("yandex-yt"),
        ]

    class Parameters(sdk2.Parameters):
        # Shard / rules / packing options; this base task defaults to the
        # 'YabsCaesarModelsTest' shard with an empty rule list (an empty list
        # makes `_prepare_caesar_models_data` deploy every rule of the shard).
        _bmp = get_build_models_package_parameters(
            rules_to_deploy_default=[],
            bg_shard_default='YabsCaesarModelsTest',
        )
        # Juggler events for one host/service pair: CRIT when the task ends in
        # EXCEPTION, OK when it ends in SUCCESS.
        # NOTE(review): no other task status is listed here - confirm that
        # failed runs are meant to stay silent.
        notifications = [
            sdk2.Notification(
                [ctt.Status.EXCEPTION],
                ["host=rt_models_packages&service=build_package_base_status"],
                ctn.Transport.JUGGLER,
                check_status=ctn.JugglerStatus.CRIT
            ),
            sdk2.Notification(
                [ctt.Status.SUCCESS],
                ["host=rt_models_packages&service=build_package_base_status"],
                ctn.Transport.JUGGLER,
                check_status=ctn.JugglerStatus.OK
            )
        ]

    def _get_vault_token(self, name):
        """Return the Vault secret registered under `name` in `VAULT_VARS`.

        Fails through `eh.ensure` when `name` is not a known entry.
        """
        is_known = name in self.VAULT_VARS
        eh.ensure(is_known, "No vault '{}' in VAULT_VARS".format(name))
        entry = self.VAULT_VARS[name]
        owner, secret_name = entry["owner"], entry["name"]
        return sdk2.Vault.data(owner, secret_name)

    def _sky_get(self, rbtorrent, dir="."):
        """Download a skynet torrent into `dir`, reporting it in the task info.

        :param rbtorrent: skynet id of the data to fetch.
        :param dir: destination directory.
        :raises copier_errors.CopierError: re-raised after being logged.
        """
        info_line = 'RBTORRENT: rbtorrent = {} dir = {}'.format(rbtorrent, dir)
        no_timeout = datetime.timedelta()  # zero-length timedelta, as passed originally
        try:
            self.set_info(info_line)
            skynet_get(skynet_id=rbtorrent, data_dir=dir, timeout=no_timeout)
        except copier_errors.CopierError as e:
            logging.exception("Cannot download resource with id %s. Exception %s", rbtorrent, e)
            raise

    def _make_archive(self, archive_path, dirs):
        """Pack `dirs` (paths relative to WORKING_DIR) into a tar at `archive_path`."""
        self.set_info('Packing to tar: {}'.format(archive_path))
        # `-C WORKING_DIR` makes tar resolve the member names inside the working directory.
        tar_prefix = ['tar', '-C', WORKING_DIR, '-cf', archive_path]
        subprocess.check_call(tar_prefix + dirs)

    def _get_resource_publish_time(self, resource):
        # type: (sdk2.Resource) -> str
        try:
            publish_time = resource.original_created_time
        except AttributeError:
            publish_time = resource.created.isoformat()
        return publish_time

    def _get_resource_from_proper_storage(self, spec):
        """Look up the resource described by `spec` in the storage it names.

        The 'source' key selects the storage and defaults to Sandbox:
          * YT         - wraps the spec in a YTResource (nothing is fetched yet);
          * sandbox    - newest resource of `spec['resource_type']` matching
                         `spec['constraints']` (state defaults to READY);
          * ml_storage - resolved through `_get_ml_storage_resource`.

        Sources are compared by value: identity comparison of strings only
        works by accident of interning and breaks for values that are equal
        but not the same object (e.g. parsed from JSON).

        :return: the found resource; the Sandbox search may return None.
        """
        source = spec.get('source', StorageSources.SANDBOX)
        if source == StorageSources.YT:
            return YTResource(spec)
        constraints = spec['constraints']
        resource_type = spec['resource_type']
        if source == StorageSources.SANDBOX:
            # Note: this writes the default into the spec's own constraints dict.
            constraints.setdefault("state", "READY")
            logging.debug("Search resource: {} by {} on sandbox".format(resource_type, constraints))
            return sdk2.Resource[resource_type].find(**constraints).order(-sdk2.Resource.id).first()
        if source == StorageSources.ML_STORAGE:
            logging.debug("Search resource by {} on ml storage".format(constraints))
            return self._get_ml_storage_resource(constraints)
        eh.fail("Unknown source storage: {}".format(source))

    def _human_readable_size(self, size):
        if size is None:
            return "not supported for YT"
        for unit in ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB']:
            if size < 1024.0 or unit == 'PiB':
                break
            size /= 1024.0
        return "{:.2f} {}".format(size, unit)

    def _get_resource(self, model_dir, spec):
        """Find the resource for `spec` and report it in the task info.

        :param model_dir: directory the resource will be downloaded to
            (used for reporting only).
        :param spec: resource spec; must contain 'resource_type'.
        :return: the found resource object.

        Fails through `eh.ensure` when the storage search returns nothing,
        instead of crashing later with an AttributeError on None.
        """
        last_released_resource = self._get_resource_from_proper_storage(spec)
        eh.ensure(
            last_released_resource is not None,
            "No resource found: type = {} constraints = {}".format(
                spec.get('resource_type'),
                spec.get('constraints')
            )
        )
        self.set_info(
            'Found last released: id = {:13} creation_time = {:30} type = {:40} dir = {:60} size = {}'.format(
                last_released_resource.id,
                self._get_resource_publish_time(last_released_resource),
                spec['resource_type'],
                model_dir,
                self._human_readable_size(last_released_resource.size)
            )
        )
        return last_released_resource

    def _download_resource(self, model_dir, resource):
        """Download `resource` into `model_dir`.

        YT resources download themselves; everything else is fetched by its
        skynet id.
        """
        if not isinstance(resource, YTResource):
            self._sky_get(rbtorrent=resource.skynet_id, dir=model_dir)
            return
        resource.download_to(model_dir)

    def _add_multi_resource_model(self, model_dir, resource_specs):
        return [self._get_resource(model_dir, spec) for spec in resource_specs]

    def _get_ml_storage_resource(self, constraints):
        """Resolve a model dump registered in ML Storage to a Sandbox resource.

        `constraints['attrs']` must contain 'key' (model key) and 'dump_key'
        (which part of the model to take); 'version' is optional and, when
        absent or falsy, the latest version is requested.

        :return: the `sdk2.Resource` whose id is the part's 'storage_id'.
        """
        attrs = constraints["attrs"]
        for required in ("key", "dump_key"):
            eh.ensure(required in attrs, "ML Storage | No required argument '{}' in constraints['attrs']".format(required))

        model_key = attrs["key"]
        dump_key = attrs["dump_key"]
        version = attrs.get("version")

        token = self._get_vault_token("SANDBOX_TOKEN")
        client = mls_client.MlStorageBinaryClient(token=token)
        model = client.info(key=model_key, version=version, latest=not version, prod=True)

        # Pick the storage id of the part registered under the requested dump key.
        storage_id = None
        for part_key, part in model["parts"].items():
            if part_key == dump_key:
                storage_id = part["storage_id"]
        eh.ensure(storage_id, "ML Storage | No dumps found with dump_key = '{}' for model '{}'".format(dump_key, model_key))

        last_released_model = sdk2.Resource[storage_id]
        eh.ensure(last_released_model, "ML Storage | Not found resource with id = '{}' for model '{}'".format(storage_id, model_key))
        return last_released_model

    def _prepare_caesar_models_data(self):
        """Download all models of the selected shard into WORKING_DIR.

        Layout produced: WORKING_DIR/<rule>/<model>/ with the downloaded
        resources and a PKGINFO json next to them. The rules come from the
        `rules_to_deploy` parameter, or every rule of the shard when that
        list is empty.

        :return: the list of rule names that were processed.
        """
        shard = self.Parameters._bmp.begemot_shard
        rules = self.rule2models.get(shard, None)
        eh.ensure(rules, "no rules for shard " + shard)

        rules_to_deploy = self.Parameters._bmp.rules_to_deploy or list(rules.keys())

        models_total = 0
        for rule_name in rules_to_deploy:
            models = rules.get(rule_name, None)
            eh.ensure(models, "no models for " + rule_name)

            rule_dir = os.path.join(WORKING_DIR, rule_name)
            if not os.path.exists(rule_dir):
                os.makedirs(rule_dir)

            for model in models:
                model_name = model['model_name']
                update_period = model.get('update_period')
                model_dir = os.path.join(rule_dir, model_name)

                models_total += 1
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)

                # A model either lists several resource specs or is itself a single spec.
                specs = model['resources'] if 'resources' in model else [model]
                resources = self._add_multi_resource_model(model_dir, specs)

                self.set_info('Downloading resource for model {}'.format(model_name))
                for resource in resources:
                    self._download_resource(model_dir=model_dir, resource=resource)

                pkginfo = self._build_pkginfo(resources=resources, model_name=model_name, update_period=update_period)
                fu.json_dump(
                    file_name=os.path.join(model_dir, 'PKGINFO'),
                    data=pkginfo,
                    indent=4,
                    sort_keys=True
                )

        eh.ensure(models_total, "no models have been added")

        return rules_to_deploy

    def _build_pkginfo(self, resources, model_name, update_period=None):
        resource = resources[0]
        combined_id = resource.id if len(resources) == 1 else hash(tuple([r.id for r in resources]))
        created_time = min([self._get_resource_publish_time(r) for r in resources])
        result = {
            'attributes': {
                'id': combined_id,
                'type': str(resource.type),
                'task_id': getattr(resource, 'task_id', None),
                'owner': resource.owner,
                'attributes': {
                    k: v for k, v in resource
                    if k != "ttl"  # fix for serialization inf as Infinity
                },
                'time': {
                    'created': created_time,
                },
                'model_name': model_name,
            }
        }
        if update_period is not None:
            eh.ensure(isinstance(update_period, datetime.timedelta), 'unknown update_period format')
            result['attributes']['update_period'] = update_period.total_seconds()
        return result

    def on_execute(self):
        """Task entry point: download the models and publish the package resource.

        The resource is either WORKING_DIR itself or, when `pack_to_tar` is
        set, a tar archive holding the deployed rule directories and the
        version file.
        """
        rules_to_deploy = self._prepare_caesar_models_data()
        version_file = add_realtime_version_file(WORKING_DIR, self.id)

        if self.Parameters._bmp.pack_to_tar:
            resource_path = os.path.abspath('realtime_wizard_data.tar')
            self._make_archive(resource_path, rules_to_deploy + [version_file])
        else:
            resource_path = WORKING_DIR

        shard = self.Parameters._bmp.begemot_shard
        description = BEGEMOT_REALTIME_PACKAGE.name + ' for {}'.format(shard)
        # Instantiating the resource class is what registers the resource for this task.
        BEGEMOT_REALTIME_PACKAGE(
            self,
            description,
            resource_path,
            released='stable',
            shard=shard,
            ttl=3,
            version=self.id,
        )

    @property
    def rule2models(self):
        raise NotImplementedError("rule2models must be implemented in derived class")
