#!/usr/bin/env python2

from __future__ import print_function

import os
import py
import sys
import copy
import json
import getpass
import inspect
import logging
import argparse
import datetime as dt
import importlib

import six
from six.moves import urllib_parse as urlparse

if six.PY2:
    import pathlib2 as pathlib
else:
    import pathlib

from sandbox.common import auth as common_auth
from sandbox.common import hash as common_hash
from sandbox.common import rest as common_rest
from sandbox.common import errors as common_errors
from sandbox.common import format as common_format
from sandbox.common import upload as common_upload
from sandbox.common import console as common_console
from sandbox.common import patterns as common_patterns
from sandbox.common import itertools as common_itertools
from sandbox.common import projects_handler

import sandbox.common.types.misc as ctm
import sandbox.common.types.resource as ctr
import sandbox.common.types.task as ctt

from sandbox import sdk2
from sandbox.taskbox import age


# Names of environment variables read by the CLI.
TOKEN_ENV_VAR_NAME = "SANDBOX_TOKEN"
LOGS_DIR_VAR_NAME = "LOGS_DIR"

# Logger for internal diagnostics of this module; user-facing messages go through the logger given to CommandHandler.
service_logger = logging.getLogger(__name__)

# Default directory for per-run log files (overridable via the LOGS_DIR_VAR_NAME variable) and timestamp format.
DEFAULT_LOGS_DIR = os.path.expanduser("~/.sandbox/logs")
LOGS_DATEFMT = "%Y-%m-%d %H:%M:%S"

REVISION_FILE_NAME = ".revision"


def svn_version():
    """Return the VCS version info reported by `library.python.svn_version` (imported lazily)."""
    from library.python import svn_version as vcs_info
    return vcs_info.svn_version()


def svn_revision():
    """Return the VCS revision reported by `library.python.svn_version` (imported lazily)."""
    from library.python import svn_version as vcs_info
    return vcs_info.svn_revision()


def dump_sources(dir_path):
    """
    Extract the program sources packed into the running executable into `dir_path`.

    :param dir_path: destination directory
    :return: whatever `extract_program_sources` returns
    """
    from library.python.sfx import extract_program as sfx
    program_path = sys.argv[0]
    return sfx.extract_program_sources(program_path, dir_path)


def dump_resources(dir_path, key_prefix="sandbox/"):
    """
    Write resources embedded into the binary whose keys start with `key_prefix` under `dir_path`.

    Each resource key is used as a path relative to `dir_path` (`ensure=True` lets py.path create missing
    parent directories).

    :param dir_path: destination directory
    :param key_prefix: only resources with keys starting with this prefix are written
    :return: sorted list of the written resource keys
    """
    from library.python import resource

    dumped_keys = []
    for key, payload in resource.iteritems():
        if not key.startswith(key_prefix):
            continue
        # Resource keys are assumed to be full paths in Arcadia.
        py.path.local(dir_path).join(key).write_binary(payload, ensure=True)
        dumped_keys.append(key)
    dumped_keys.sort()
    return dumped_keys


def extract_resource(path):
    """
    Look up a resource embedded into the binary by its key.

    :param path: resource key
    :return: result of `library.python.resource.find` (callers in this module treat `None` as "not found")
    """
    from library.python.resource import find as find_embedded
    return find_embedded(path)


def extract_module(path):
    """
    Return the source code of the Python module importable by `path`.

    :param path: absolute import path of the module, e.g. ``"json"``
    :return: module source as a string, or `None` if the module cannot be imported or its source is unavailable
    """
    try:
        return inspect.getsource(importlib.import_module(path))
    except (ImportError, IOError, TypeError, ValueError):
        # ImportError: no such module; IOError/OSError: source file cannot be located;
        # TypeError: built-in module (no source) or a relative path without a package; ValueError: empty path.
        return None


def resource_attributes(binary_hash):
    """
    Build the attributes attached to an uploaded tasks binary resource.

    :param binary_hash: hash of the binary (callers in this module pass the MD5 sum of the executable)
    :return: dict with the VCS revision, the given hash and the taskbox age converted to a string
    """
    attrs = {}
    attrs[ctr.BinaryAttributes.REVISION] = svn_revision()
    attrs[ctr.BinaryAttributes.BINARY_HASH] = binary_hash
    attrs[ctr.BinaryAttributes.BINARY_AGE] = str(age.AGE)
    return attrs


class LazyAuth(object):
    """
    Sandbox authentication whose OAuth token is resolved lazily, on first use.

    The raw value given to the constructor is a path to a file with a token, the token itself,
    or `None` to disable authentication.
    """

    def __init__(self, token):
        """
        :param token: path to a file with a token, a raw token, or `None` to disable authentication
        """
        self.__raw_token = token

    @common_patterns.singleton_classproperty
    def current_user(cls):
        """Login of the local user running the binary."""
        return getpass.getuser()

    @common_patterns.singleton_classproperty
    def default_token_path(cls):
        """Path of the token cache file of the current user."""
        return ctm.Upload.TOKEN_CACHE_FILENAME_TMPL.format(cls.current_user)

    @common_patterns.singleton_property
    def token(self):
        """
        Resolve and validate the token.

        Sources are tried in order: the environment variable named by TOKEN_ENV_VAR_NAME, the file at the raw
        token path, SSH-based retrieval (only when the raw value equals the default token path), and finally
        the raw value itself. The result is validated with `common_console.Token.check`.

        :raises SystemExit: with code 1 if no valid token was obtained
        """
        token_path = os.path.abspath(os.path.expanduser(self.__raw_token))
        env_token = os.environ.get(TOKEN_ENV_VAR_NAME)
        # `common_console.Token` receives a writable stream; give it the null device and make sure
        # the handle is closed once the token is resolved and checked (it used to leak).
        with open(os.devnull, "w") as devnull:
            if env_token:
                service_logger.debug("Using token from environment variable %s", TOKEN_ENV_VAR_NAME)
                token = env_token
            elif os.path.isfile(token_path):
                with open(token_path, "r") as fh:
                    service_logger.debug("Using token from file %s", token_path)
                    token = fh.read().strip()
            elif self.__raw_token == self.default_token_path:  # No token as-is, trying to get it via SSH
                service_logger.debug("Using token from ssh")
                token = common_console.Token(
                    Sandbox.BASE_URL, self.current_user, False, devnull
                ).get_token_from_ssh(None)
            else:
                service_logger.debug("Using provided token")
                token = self.__raw_token
            is_valid = token and common_console.Token(
                Sandbox.BASE_URL, self.current_user, False, devnull
            ).check(token)
        if is_valid:
            return token
        service_logger.error(
            "Failed to get valid token, "
            "for more info see https://docs.yandex-team.ru/sandbox/dev/binary-task#cli-auth"
        )
        raise SystemExit(1)

    @property
    def auth(self):
        """`common_auth.OAuth` with the resolved token, or `common_auth.NoAuth` if authentication is disabled."""
        return common_auth.OAuth(self.token) if self.__raw_token is not None else common_auth.NoAuth()


class Sandbox(object):
    """Addresses of a Sandbox installation plus a REST API client built from them."""

    BASE_URL = "https://sandbox.yandex-team.ru"
    BASE_PROXY_URL = "http://proxy.sandbox.yandex-team.ru"

    def __init__(self, base_url, base_proxy_url, auth):
        """
        :param base_url: Sandbox base URL
        :param base_proxy_url: Sandbox proxy base URL
        :type auth: `LazyAuth`
        """
        self.base_url = base_url
        self.base_proxy_url = base_proxy_url
        self._auth = auth

    @property
    def auth(self):
        """Authentication object produced by the wrapped `LazyAuth`."""
        lazy_auth = self._auth
        return lazy_auth.auth

    @property
    def base_api_url(self):
        """REST API root: `/api/v1.0` joined onto `base_url` (the absolute path replaces any path of `base_url`)."""
        api_root = "/api/v1.0"
        return urlparse.urljoin(self.base_url, api_root)

    @common_patterns.singleton_property
    def rest(self):
        """REST client bound to `base_api_url` (cached via `singleton_property`)."""
        api_url = self.base_api_url
        client_auth = self.auth
        return common_rest.Client(base_url=api_url, auth=client_auth)

    def _absolute(self, relative_url):
        """Join `relative_url` onto `base_url`."""
        return urlparse.urljoin(self.base_url, relative_url)

    def task_url(self, task_id):
        """URL of the task with the given ID."""
        return self._absolute("task/{}".format(task_id))

    def resource_url(self, resource_id):
        """URL of the resource with the given ID."""
        return self._absolute("resource/{}".format(resource_id))


class CommandLineUploader(object):
    """Upload a local file as a Sandbox resource, reporting progress through the given logger."""

    def __init__(self, sandbox, logger):
        """
        :type sandbox: `Sandbox`
        :param logger: logger for progress and error messages
        """
        self.sb = sandbox
        self.logger = logger

    @staticmethod
    def _file_meta(path, meta_type=common_upload.HTTPHandle.FileMeta):
        """
        Describe the file at `path` for an upload handle.

        The returned object holds a file handle opened in binary mode; the caller is responsible for closing it.
        """
        sz = os.path.getsize(path)
        return meta_type(handle=open(path, "rb"), size=sz, name=os.path.basename(path))

    def _skynet_id(self, path):
        """
        Share `path` via the Skynet copier and return the resulting resource ID.

        :raises SystemExit: with code 1 if the Skynet Python API is missing or fails to initialize
        """
        try:
            import api.copier
            import api.skycore
        except ImportError:
            self.logger.error("Skynet python API not found, option --skynet is not available")
            raise SystemExit(1)
        try:
            copier = api.copier.Copier()
        except (api.skycore.SkycoreError, RuntimeError) as error:
            self.logger.error("Skynet API error, option --skynet is not available: %s", error)
            raise SystemExit(1)
        self.logger.debug("Sharing %s", path)
        res_id = copier.create([os.path.basename(path)], cwd=os.path.dirname(path)).resid()
        self.logger.debug("Shared as %s", res_id)
        return res_id

    def _upload(self, path, resource_type, description, owner, use_skynet, use_mds, attributes):
        """
        Upload `path` as a new resource and return the ID of the registered resource.

        Progress states produced by the upload handle are translated into log messages. The file handle opened
        for the upload is closed when the upload finishes or fails.

        :param use_skynet: share the file via the Skynet copier and upload with `common_upload.SkynetHandle`
        :param use_mds: use `common_upload.MDSHandle` instead of `common_upload.HTTPHandle`
            (ignored when `use_skynet` is set)
        :param attributes: resource attributes
        :return: resource ID
        :raises SystemExit: with code 1 if the upload finished without reporting a registered resource
        """
        resource_meta = common_upload.HTTPHandle.ResourceMeta(
            type=resource_type,
            arch=ctm.OSFamily.from_system_name(),
            owner=owner,
            description=description,
            attributes=attributes,
            release_to_yd=False,
        )
        auth = self.sb.auth  # resolved first, so an invalid token fails before any sharing/uploading work
        skynet_id = self._skynet_id(path) if use_skynet else None
        file_meta = self._file_meta(path)
        try:
            if use_skynet:
                uploader = common_upload.SkynetHandle(
                    resource_meta, auth, skynet_id, self.sb.base_url, 30, file_meta
                )
            else:
                uploader_cls = common_upload.MDSHandle if use_mds else common_upload.HTTPHandle
                uploader = uploader_cls(
                    resource_meta, auth, self.sb.base_url, self.sb.base_proxy_url, 30, file_meta
                )
            state, last_state, last_state_copy, last_share_state = (None,) * 4

            resource_id = None
            for state in uploader():
                if last_state != state:
                    # A new stage started: summarize the finished Check stage and announce the new one.
                    if isinstance(last_state, ctm.Upload.Check):
                        self.logger.info(
                            "%s file%s (%s) to upload",
                            last_state.amount, "s" if last_state.amount > 1 else "",
                            common_format.size2str(last_state.size)
                        )
                    if isinstance(state, ctm.Upload.Check):
                        self.logger.debug("Calculating total files size")
                    if isinstance(state, ctm.Upload.Prepare):
                        self.logger.debug("Preparing upload task")
                    if isinstance(state, ctm.Upload.Share):
                        self.logger.debug("Sharing uploaded data")
                else:
                    # Same stage reported again: log only what changed since the previous report.
                    # NOTE(review): the deep copy suggests the handle re-yields one state object updated in place.
                    if isinstance(state, ctm.Upload.Prepare):
                        if last_state_copy.task_id != state.task_id:
                            self.logger.info(
                                "Uploading task #%s created: %s", state.task_id, self.sb.task_url(state.task_id)
                            )
                        if last_state_copy.resource_id != state.resource_id:
                            resource_id = state.resource_id
                            url = self.sb.resource_url(state.resource_id)
                            self.logger.debug("Resource #%s registered: %s", state.resource_id, url)
                    if isinstance(state, ctm.Upload.ShareResource):
                        if last_share_state != state.resource_state:
                            self.logger.info("Resource is in %s state", state.resource_state)
                            last_share_state = state.resource_state
                    elif isinstance(state, ctm.Upload.Share):
                        if last_share_state != state.task_state:
                            self.logger.info("Task is in %s state", state.task_state)
                            last_share_state = state.task_state

                last_state = state
                last_state_copy = copy.deepcopy(state)
        finally:
            file_meta.handle.close()

        if isinstance(state, ctm.Upload.Share):
            if isinstance(state, ctm.Upload.ShareResource):
                if last_share_state != state.resource_state:
                    self.logger.info("Resource state: '%s'", state.resource_state)
            else:
                if last_share_state != state.task_state:
                    self.logger.info("Task state: '%s'", state.task_state)
            if state.skynet_id:
                self.logger.debug("Skynet copier ID = %s", state.skynet_id)
            if state.md5sum:
                self.logger.debug("MD5 checksum = %s", state.md5sum)
        if resource_id is None:
            # Without an ID the caller would request `resource[None]` and fail with an obscure API error.
            self.logger.error("Upload finished, but no resource was registered")
            raise SystemExit(1)
        return resource_id

    def upload(
        self, path, resource_type, description, owner, binary_hash,
        use_skynet, use_mds, extra_attrs, force, enable_taskbox=None
    ):
        """
        Return a READY resource of `resource_type` with the same binary hash (and taskbox flag), uploading `path`
        as a new resource when there is none or `force` is set.

        :param binary_hash: hash of the binary, used both for lookup and as a resource attribute
        :param extra_attrs: additional resource attributes applied over the default ones
        :param force: upload even if a matching resource already exists
        :param enable_taskbox: value of the TASKBOX_ENABLED attribute; `None` leaves it unset and unfiltered
        :return: resource data as returned by the REST API
        """
        filtering_attrs = {ctr.BinaryAttributes.BINARY_HASH: binary_hash}
        if enable_taskbox is not None:
            filtering_attrs[ctr.BinaryAttributes.TASKBOX_ENABLED] = str(enable_taskbox)
        resources = self.sb.rest.resource.read(
            type=resource_type, attrs=filtering_attrs, state=ctr.State.READY, limit=1
        )["items"]
        if resources:
            res = resources[0]
            self.logger.info(
                "%sse already uploaded resource: %s", "Don't u" if force else "U", self.sb.resource_url(res["id"])
            )
            if not force:
                return res
        attributes = resource_attributes(binary_hash)
        if enable_taskbox is not None:
            attributes[ctr.BinaryAttributes.TASKBOX_ENABLED] = enable_taskbox
        if extra_attrs:
            attributes.update(extra_attrs)
        resource_id = self._upload(path, resource_type, description, owner, use_skynet, use_mds, attributes)
        return self.sb.rest.resource[resource_id].read()


class CommandHandler(object):
    """
    Implementation of the CLI subcommands; each public method is named after the subcommand it serves.

    Machine-readable results are collected in `results` and printed to stdout as JSON by the subcommands.
    """

    def __init__(self, args, logger):
        """
        :param args: parsed command line arguments (`argparse.Namespace`)
        :param logger: logger for user-facing messages
        """
        self.args = args
        self.logger = logger
        self.results = {}

    @common_patterns.singleton_property
    def bin_path(self):
        """Path of the running binary."""
        return py.path.local(sys.argv[0])

    @common_patterns.singleton_property
    def bin_hash(self):
        """MD5 sum of the running binary."""
        return common_hash.md5sum(str(self.bin_path))

    @common_patterns.singleton_property
    def sb(self):
        """Sandbox facade configured from the command line (authentication is disabled by `--no-auth`)."""
        return Sandbox(
            self.args.sandbox_url, self.args.proxy_url, LazyAuth(None if self.args.no_auth else self.args.token)
        )

    def _parse_attrs(self):
        """
        Parse the `--attr <name>=<value>` options into a dict (later duplicates win).

        :raises SystemExit: with code 1 if an option is not in `<name>=<value>` form
        """
        attrs = {}
        for raw_attr in getattr(self.args, "attr", None) or []:
            name, sep, value = raw_attr.partition("=")
            if not sep:
                self.logger.error("Invalid resource attribute %r: expected <name>=<value> form", raw_attr)
                raise SystemExit(1)
            attrs[name] = value
        return attrs

    def _upload(self):
        """
        Upload the running binary as a tasks resource, or reuse an already uploaded one.

        Stores a summary of the resource in `results["resource"]`.

        :return: resource ID
        """
        uploader = CommandLineUploader(self.sb, self.logger)
        attrs = self._parse_attrs()
        use_skynet = self.args.skynet
        use_mds = self.args.mds
        resource = uploader.upload(
            sys.argv[0],
            sdk2.service_resources.SandboxTasksBinary.name,
            "Tasks binary.",
            self.args.owner,
            self.bin_hash,
            use_skynet=use_skynet,
            use_mds=use_mds,
            extra_attrs=attrs,
            force=getattr(self.args, "force", False),
            enable_taskbox=getattr(self.args, "enable_taskbox", False),
        )
        self.results["resource"] = {_: resource[_] for _ in ("id", "type", "task", "attributes")}
        return resource["id"]

    def upload(self):
        """`upload` subcommand: upload the binary and print the resource summary as JSON."""
        self._upload()
        print(json.dumps(self.results))

    def content(self):
        """`content` subcommand: inspect task types, code and resources built into the binary."""
        if self.args.task_type:
            types = projects_handler.load_project_types()
            loc = types.get(self.args.task_type)
            if loc:
                print(json.dumps(
                    {
                        "type": self.args.task_type,
                        "class": loc.cls.__name__,
                        "module": loc.cls.__module__,
                    }
                ))
            if not loc:
                self.logger.error("Task type %r is not defined in binary.", self.args.task_type)
            sys.exit(not loc)
        elif self.args.list_types:
            print(json.dumps(
                {"types": sorted(projects_handler.load_project_types())}
            ))
        elif self.args.dump_to:
            dump_sources(self.args.dump_to)
            dump_resources(self.args.dump_to)
        elif self.args.extract_resource or self.args.extract_module:
            content = (
                extract_resource(self.args.extract_resource)
                if self.args.extract_resource else
                extract_module(self.args.extract_module)
            )
            if content is None:
                template = "{} with name '{}' doesn't exist in binary"
                print(
                    template.format("Resource", self.args.extract_resource)
                    if self.args.extract_resource else
                    template.format("Python module", self.args.extract_module),
                    file=sys.stderr,
                )
                sys.exit(1)
            else:
                print(content)
        elif self.args.self_info:
            print(json.dumps({
                "resource_attributes": resource_attributes(self.bin_hash),
                "vcs": {
                    "revision": svn_revision(),
                    "version": svn_version(),
                },
            }))

    def ipython(self):
        """`ipython` subcommand: start an IPython shell inside the binary's environment."""
        # This import is rather slow (~0.5s), don't want it on global level
        import IPython
        IPython.start_ipython([])

    def interpret(self):
        """
        `interpret` subcommand: execute a Python script in this module's globals.

        The script is read from `script_path`, or from stdin when the path is absent or '-'. While it runs,
        `sys.argv` is the program name followed by the remaining command line arguments.
        """
        path = self.args.script_path
        if path and path != "-":
            with open(path) as script_file:
                source = script_file.read()
        else:
            source = sys.stdin.read()
        sys.argv = sys.argv[:1] + self.args.remaining_args
        # Running user-supplied code is the purpose of this subcommand.
        exec(source, globals())

    def run(self):
        """
        `run` subcommand: create (and by default start) a task of a type defined in this binary.

        The binary is uploaded (or an already uploaded copy reused) and set as the task's `tasks_resource`.
        With `--wait`, the task status is polled until it reaches a FINISH or BREAK status.
        """
        if getattr(self.args, "enable_taskbox", False) and ctm.OSFamily.from_system_name() != ctm.OSFamily.LINUX:
            self.logger.error(
                "It's forbidden to run binaries built with OS other than LINUX and flag '--enable-taskbox'."
                " For more info see https://docs.yandex-team.ru/sandbox/dev/binary-task#server-side"
            )
            raise SystemExit(1)
        if self.args.payload == "-":
            kwargs = json.load(sys.stdin)
        elif self.args.payload:
            kwargs = json.loads(self.args.payload)
        else:
            kwargs = {}
        if self.args.owner:
            kwargs["owner"] = self.args.owner
        task_type = type_source = ""
        if self.args.tid:
            kwargs["source"] = self.args.tid
            task_type = self.sb.rest.task[self.args.tid].read()["type"]
            type_source = " of task #{}".format(self.args.tid)
        elif self.args.template:
            kwargs["scheduler_id"] = self.args.template
            task_type = self.sb.rest.scheduler[self.args.template].read()["task"]["type"]
            type_source = " of template #{}".format(self.args.template)
        elif self.args.type:
            kwargs["type"] = task_type = self.args.type

        project_types = projects_handler.load_project_types()
        if not task_type and len(project_types) == 1:
            kwargs["type"] = task_type = list(project_types)[0]

        if not task_type:
            self.logger.error(
                "The binary has multiple task types, one of the arguments --tid --template --type is required"
            )
            sys.exit(1)

        try:
            task_cls = sdk2.Task[task_type]
        except common_errors.UnknownTaskType:
            self.logger.error("Task type %r%s is not defined in binary.", task_type, type_source)
            sys.exit(1)
        if issubclass(task_cls, sdk2.OldTaskWrapper):
            self.logger.error(
                "Task type %r%s is based on SDK1. Only SDK2 tasks can be run as binary task.", task_type, type_source
            )
            sys.exit(1)

        # The payload may omit `requirements` or set it to null; both mean "no explicit requirements".
        requirements = kwargs.get("requirements") or {}
        kwargs["requirements"] = requirements
        requirements["tasks_resource"] = self._upload()

        task = self.sb.rest.task(**kwargs)
        self.results["task"] = {_: task[_] for _ in ("id", "type", "owner")}
        self.logger.info("Task of type %s created: %s", task["type"], self.sb.task_url(task["id"]))
        if not self.args.create_only:
            resp = self.sb.rest.batch.tasks.start.update([task["id"]])
            self.logger.info(resp[0]["message"])

        if self.args.wait:
            break_statuses = tuple(ctt.Status.Group.FINISH) + tuple(ctt.Status.Group.BREAK)
            self.logger.info("Waiting for the task to finish")

            current_status = ""

            for _, _ in common_itertools.progressive_yielder(0.5, 3, float("inf")):
                sb_task = self.sb.rest.task.read(
                    {"limit": 10, "id": task["id"], "fields": ["status"]}
                )["items"][0]

                if sb_task["status"] != current_status:
                    current_status = sb_task["status"]
                    self.logger.info("Task status is %s", current_status)

                    if current_status in break_statuses:
                        self.results["task"]["execution_status"] = current_status
                        break

        print(json.dumps(self.results))


# Reusable command line arguments: name -> (flags, keyword arguments for `add_argument`);
# registered on subcommand parsers by `_add_parser_arguments`.
_ARGUMENTS = dict(
    owner=(
        ["-o", "--owner"],
        {
            "default": LazyAuth.current_user,
            "help": "Task's owner group (default: {}).".format(LazyAuth.current_user),
        },
    ),
    token=(
        ["--token"],
        {
            "default": LazyAuth.default_token_path,
            "help": (
                "Sandbox token or path to file with token (default: {}). Ignore this parameter if environmental "
                "variable {} is set."
            ).format(LazyAuth.default_token_path, TOKEN_ENV_VAR_NAME),
        },
    ),
    no_auth=(["--no-auth", "--no_auth"], {"action": "store_true", "help": "Disable authentication."}),
    verbose=(["-v", "--verbose"], {"action": "store_true", "help": "Increase verbosity."}),
    sandbox_url=(
        ["--url"],
        {
            "default": Sandbox.BASE_URL,
            "dest": "sandbox_url",
            "help": "Base url to sandbox API. Default: {}".format(Sandbox.BASE_URL),
        },
    ),
    proxy_url=(
        ["--proxy-url"],
        {
            "default": Sandbox.BASE_PROXY_URL,
            "dest": "proxy_url",
            "help": "Base url to sandbox proxy API. Default: {}".format(Sandbox.BASE_PROXY_URL),
        },
    ),
)


def _add_parser_arguments(parser, *args):
    """
    Register arguments described in `_ARGUMENTS` on `parser`.

    :param parser: `argparse` parser (or argument group) to extend
    :param args: keys of `_ARGUMENTS`, registered in the given order
    """
    for flags, options in (_ARGUMENTS[name] for name in args):
        parser.add_argument(*flags, **options)


def _add_upload_arguments(parser):
    group = parser.add_argument_group("tasks resource uploading")
    upload_via = group.add_mutually_exclusive_group()
    upload_via.add_argument("--skynet", action="store_true", help="Upload binary via skybone (skynet copier) service.")
    upload_via.add_argument("--mds", action="store_true", default=True, help="Upload binary to MDS via proxy.")
    group.add_argument(
        "--force", action="store_true", help="Force upload: ignore another resources with same content (binary)."
    )
    group.add_argument(
        "--attr", type=str, action="append", help="Resource attribute in <name>=<value> form. Multiple argument."
    )
    group.add_argument(
        "--enable-taskbox", action="store_true",
        help="Enable server-side hooks executing for tasks created using this binary."
    )


def _parse_args():
    """
    Build the command line parser with its subcommands and parse `sys.argv`.

    Subcommand names are taken from the names of the corresponding `CommandHandler` methods.
    Prints help to stderr and exits when no arguments are given.

    :return: parsed `argparse.Namespace`; the chosen subcommand is in its `command` attribute
    """
    parser = argparse.ArgumentParser(description="Tasks binary.")
    subparsers = parser.add_subparsers(
        help="<description>", dest="command", metavar="<command>", title="subcommands",
    )

    # run
    run_cmd = subparsers.add_parser(
        CommandHandler.run.__name__,
        description="Run task from binary.",
        formatter_class=argparse.RawTextHelpFormatter,
        help="Subcommand to run task.",
    )

    _add_parser_arguments(run_cmd, "owner", "token", "no_auth")

    group = run_cmd.add_argument_group("task creation alternatives").add_mutually_exclusive_group(required=False)
    group.add_argument("--tid", metavar="<id>", type=int, help="Task ID to copy.")
    group.add_argument("--template", metavar="<id>", type=int, help="Template (scheduler) ID to use.")
    group.add_argument("--type", help="Task type.")

    run_cmd.add_argument("--create-only", action="store_true", help="Do not start execution of created task.")
    run_cmd.add_argument(
        "payload", nargs="?", type=str,
        help="Json with arbitrary parameters of task creation API-request. Pass '-' to read payload from stdin. "
             "Scheme: https://sandbox.yandex-team.ru/media/swagger-ui/index.html#/task/task_list_post"
    )
    run_cmd.add_argument("--wait", action="store_true", help="Wait until task is executed")
    _add_upload_arguments(run_cmd)
    _add_parser_arguments(run_cmd, "sandbox_url", "proxy_url", "verbose")

    # upload
    upload_cmd = subparsers.add_parser(CommandHandler.upload.__name__, help="Binary uploading subcommand.")
    _add_upload_arguments(upload_cmd)
    _add_parser_arguments(upload_cmd, "owner", "token", "no_auth", "sandbox_url", "proxy_url", "verbose")

    # content
    content_cmd = subparsers.add_parser(CommandHandler.content.__name__, help="Binary content analysis.")
    group = content_cmd.add_argument_group().add_mutually_exclusive_group(required=True)
    group.add_argument("--task-type", metavar="<task type>", type=str.upper, help="Check existence of task type.")
    group.add_argument("--list-types", action="store_true", help="List defined task types.")
    group.add_argument("--dump-to", metavar="<path>", type=str, help="Dump all inner code to specified directory.")
    group.add_argument(
        "--extract-resource", metavar="<path>", type=str, help="Extract content of resource file by its key."
    )
    group.add_argument(
        "--extract-module", metavar="<path>", type=str, help="Extract content of python module by its import path."
    )
    group.add_argument("--self-info", action="store_true", help="Dump info about binary itself.")
    _add_parser_arguments(content_cmd, "verbose")

    # ipython
    subparsers.add_parser(CommandHandler.ipython.__name__, help="Run ipython using built tasks code.")

    # interpret
    interpret_cmd = subparsers.add_parser(
        CommandHandler.interpret.__name__, help="Interpret given script using built environment."
    )
    interpret_cmd.add_argument(
        "script_path", nargs="?", metavar="<path>",
        help="Path to script to interpret or '-' to read script from stdin. If absent then stdin is used also."
    )
    interpret_cmd.add_argument("remaining_args", nargs=argparse.REMAINDER, help="Other arguments.")
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        parser.exit()
    return parser.parse_args()


def setup_logging(log_file_name, level, base_logger=None, fmt="%(asctime)s %(levelname)s %(message)s"):
    """
    Configure logging for the CLI.

    The root logger writes every record to ``<logs dir>/<YYYY-MM-DD>/<log_file_name>.<HH-MM-SS>.log``. The logs
    dir comes from the environment variable named by LOGS_DIR_VAR_NAME, or defaults to DEFAULT_LOGS_DIR.
    Records of `level` and above from the returned logger are also echoed to stderr using `fmt`.
    If the log file cannot be created for lack of permissions (or on a read-only filesystem), the root logger
    writes to stderr instead and file logging is disabled.

    :param log_file_name: base name of the log file (and of the logger if `base_logger` is not given)
    :param level: minimal level of records echoed to stderr
    :param base_logger: name of the logger to configure; defaults to `log_file_name`
    :param fmt: format of records echoed to stderr
    :return: tuple (configured logger, path to the log file or to the "last_*" symlink to it,
             or `None` if file logging is disabled)
    """
    # `errno` is used directly: stdlib `pathlib` (Python 3) does not re-export errno constants, unlike `pathlib2`.
    import errno

    now = dt.datetime.now()
    logs_dir = os.environ.get(LOGS_DIR_VAR_NAME, DEFAULT_LOGS_DIR)
    log_path = os.path.join(
        logs_dir, now.date().isoformat(), "{}.{}.log".format(log_file_name, now.strftime("%H-%M-%S"))
    )

    # Kept in a separate variable: in Python 3 the name bound by `except ... as` is deleted after the clause.
    log_error = None
    try:
        if not os.path.isdir(os.path.dirname(log_path)):
            os.makedirs(os.path.dirname(log_path))
        open(log_path, "a").close()
    except EnvironmentError as error:
        if error.errno not in (errno.EACCES, errno.EPERM, errno.EROFS):
            raise
        log_error = error  # Cannot create the log file: ignore it, but redirect logging to stderr.

    logging.root.setLevel(logging.DEBUG)
    logging.root.handlers = []
    root_handler = logging.FileHandler(log_path) if log_error is None else logging.StreamHandler(sys.stderr)
    root_handler.setFormatter(logging.Formatter(
        "%(asctime)s\t%(levelname)s\t%(threadName)s\t(%(pathname)s:%(lineno)d)\t%(message)s",
        LOGS_DATEFMT
    ))
    logging.root.addHandler(root_handler)

    logger = logging.getLogger(base_logger or log_file_name)
    logger.setLevel(logging.NOTSET)  # propagate all records to root logger

    if log_error is not None:
        logging.warning("Logging to file is disabled because of error: %s", log_error)
        return logger, None

    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(logging.Formatter(fmt, LOGS_DATEFMT))
    handler.setLevel(level)
    logger.addHandler(handler)

    last_log_path = log_path
    if hasattr(os, "symlink"):
        link_path = os.path.join(logs_dir, "last_{}.log".format(log_file_name))
        try:
            if os.path.lexists(link_path):
                os.unlink(link_path)
            os.symlink(log_path, link_path)
        except OSError as error:
            # E.g. Windows without the symlink privilege, or a concurrent run replacing the link.
            logger.debug("Failed to create symlink %s to the log file: %s", link_path, error)
        else:
            last_log_path = link_path
    # Without os.symlink (old Windows Pythons) the log file path itself is returned.
    return logger, last_log_path


def main():
    """
    Entry point of the tasks binary CLI: parse arguments, configure logging and dispatch to `CommandHandler`.

    Logging is not configured for the `interpret` subcommand; the root logger is passed to the handler instead.
    Any exception, including `SystemExit`, is logged and re-raised.
    """
    args = _parse_args()

    logger = logging.getLogger()
    log_path = None
    if args.command != CommandHandler.interpret.__name__:
        logger, log_path = setup_logging("cli_run", logging.DEBUG if getattr(args, "verbose", False) else logging.INFO)

    logging.debug(str(sys.argv))

    try:
        getattr(CommandHandler(args, logger), args.command)()
    except KeyboardInterrupt:
        logger.error("Interrupted. More logs here: %s", log_path)
        service_logger.debug("Last trace", exc_info=True)
        sys.exit(1)
    except BaseException as exc:
        if isinstance(exc, SystemExit):
            service_logger.info("Exiting with code %s", exc.code)
            # `sys.exit()` without an argument raises SystemExit with code None, which means success like 0.
            if exc.code not in (None, 0):
                logger.error("Exited with non-zero code. More logs here: %s", log_path)
        else:
            service_logger.exception("Unhandled exception detected")
            logger.error("Unhandled exception detected. More logs here: %s", log_path)
        raise


if __name__ == "__main__":
    main()
