from __future__ import print_function, unicode_literals

import argparse
import logging
import signal
import sys
import time
import json

from google.protobuf import json_format

from sandbox.common import threading as sb_threading
from sandbox.taskbox import binary as sb_cli
from sandbox.common.types import misc as sb_misc

from tasklet import runtime
from tasklet.api import tasklet_pb2
from tasklet.basic.proto import remote_exec_tasklet
from tasklet.domain.sandbox import utils as sb_utils
from tasklet.runtime import context as rt_context
from tasklet.runtime import dispatch
from tasklet.runtime import utils
from tasklet.services import server
from tasklet.cli import proto_utils

# Root logger (getLogger with no name). main() writes the raw argv and
# unhandled-exception tracebacks through it.
service_logger = logging.getLogger(None)


def parse_args(argv=None):
    """Build the command-line parser and parse the given arguments.

    Subcommands:
      * ``run``       -- execute a tasklet implementation in the selected domain;
      * ``sb-upload`` -- upload the Sandbox task binary;
      * ``schema``    -- print a tasklet's input/output schema;
      * ``list``      -- print the registered tasklet names.

    :param argv: argument list to parse (without the program name). ``None``
        (the default) makes argparse use ``sys.argv[1:]``, as before.
    :return: ``argparse.Namespace``; ``action`` holds the chosen subcommand.
    """
    parser = argparse.ArgumentParser()

    # Sandbox options are shared by "run" and "sb-upload" via argparse parents.
    sb_opts_parser = argparse.ArgumentParser(add_help=False)
    sb_group = sb_opts_parser.add_argument_group("Sandbox arguments")
    sb_group.add_argument("--sb-url", metavar="URL", help="Sandbox base URL")
    sb_group.add_argument("--sb-owner", metavar="OWNER", help="Sandbox owner")
    sb_group.add_argument("--sb-use-skynet", action="store_true", help="Use skynet for binary upload")
    sb_group.add_argument("--sb-schema", action="store_true", help="Store schema in sandbox attributes")

    subparsers = parser.add_subparsers(title="subcommands", dest="action")
    # Set as an attribute: the `required` keyword of add_subparsers is not available on Python 2.
    subparsers.required = True

    run_cmd = subparsers.add_parser("run", help="execute tasklet", parents=[sb_opts_parser])
    run_cmd.add_argument("tasklet_impl_name", metavar="TASKLET", help="tasklet implementation name")

    run_cmd.add_argument(
        "--server", action="store_const", const=True, dest="server",
        help="start local server to display current status"
    )

    # Exactly one execution domain must be chosen; all flags write into `world`.
    world = run_cmd.add_mutually_exclusive_group(required=True)
    world.add_argument(
        "--test", action="store_const", const=tasklet_pb2.Domain.TEST, dest="world",
        help="execute locally (test)"
    )
    world.add_argument(
        "--local", action="store_const", const=tasklet_pb2.Domain.LOCAL, dest="world",
        help="execute locally"
    )
    world.add_argument(
        "--sandbox", "--sandbox-tasklet", action="store_const", const=tasklet_pb2.Domain.SANDBOX, dest="world",
        help="execute in Sandbox using TASKLET_* task"
    )
    world.add_argument(
        "--yt", action="store_const", const=tasklet_pb2.Domain.YT, dest="world",
        help="execute in YT"
    )
    world.add_argument(
        "--sandbox-glycine", action="store_const", const="SANDBOX_GLYCINE", dest="world",  # Deprecated
        help="execute in Sandbox using GLYCINE_2 task"
    )

    run_cmd.add_argument("--input", help="json dump of input parameters", default="{}")

    misc_group = run_cmd.add_argument_group("miscellaneous arguments")
    misc_group.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity")

    subparsers.add_parser("sb-upload", help="upload Sandbox task binary", parents=[sb_opts_parser])

    schema = subparsers.add_parser("schema", help="get tasklet input/output schema")
    schema.add_argument("tasklet_name", metavar="TASKLET", help="tasklet name")
    schema.add_argument(
        "-b", "--binary", action="store_true",
        help="serialize schema into base64 encoded binary format"
    )

    subparsers.add_parser("list", help="list tasklet names")
    return parser.parse_args(argv)


def lookup_tasklet_impl(tasklet_impl_name, logger, find_any=False):
    """Resolve an implementation name (or a tasklet name) to its implementation class.

    The name is first looked up as an implementation name. If no class path is
    registered for it, it is treated as a tasklet name and mapped to its
    implementations. Exactly one implementation is required unless ``find_any``
    is true, in which case the first entry returned by the registry is used.

    The process exits with status 2 when the name is unknown or ambiguous.

    :return: tuple ``(impl_cls, impl_cls_path, resolved_impl_name)``.
    """
    resolved_name = tasklet_impl_name
    cls_path = dispatch.impl_class_path(resolved_name)

    if not cls_path:
        candidates = dispatch.name_to_impl(resolved_name)
        if not candidates:
            logger.error("Unknown tasklet type %s", resolved_name)
            sys.exit(2)
        # Several implementations and no permission to pick one: refuse to guess.
        if len(candidates) > 1 and not find_any:
            logger.error("Uncertainty, found more than one implementation: %s", candidates)
            sys.exit(2)
        resolved_name = candidates[0]
        cls_path = dispatch.impl_class_path(resolved_name)

    return utils.import_symbol(cls_path), cls_path, resolved_name


def list_tasklet_names(logger):
    """``list`` subcommand: print the registered tasklet names as a JSON array."""
    dispatch.initialize_tasklet_registry()
    registered = dispatch.list_tasklet_names()
    if not registered:
        # Normalize a missing/empty registry answer to an empty JSON list.
        registered = []
    logger.info("There are %s tasklets", len(registered))
    print(json.dumps(registered))


def _raise_on_failure(result, logger, message_template):
    """Log the Python traceback (if any) and raise when ``result`` reports failure.

    ``result`` is accessed through ``success``, ``is_python_error``,
    ``python_error.tb`` and ``error``. Both the launch result and the unpacked
    ``JobResult`` in ``run`` are read this way.

    :raises Exception: ``message_template`` formatted with ``result.error``.
    """
    if result.success:
        return
    if result.is_python_error:
        logger.error("Traceback:\n%s", result.python_error.tb)
    raise Exception(message_template.format(result.error))


def _wait_for_keyboard_interrupt(logger):
    """Block until Ctrl+C so the local status server stays up after the run."""
    logger.info("Tasklet execution ended. Waiting KeyboardInterrupt to shutdown http server.")
    while True:
        try:
            time.sleep(1)
        except KeyboardInterrupt:
            logger.info("KeyboardInterrupt caught, shutdown server")
            break


def run(args, logger):
    """``run`` subcommand: execute a tasklet implementation and print its result.

    Resolves the implementation and builds a ``JobStatement`` from ``--input`` and
    the domain context. The statement is launched wrapped in a ``RemoteExec`` job,
    and the resulting ``JobResult`` is printed as JSON.

    :raises SystemExit: code 1 when Sandbox execution is requested from a non-Linux
        build. Implementation lookup failures exit with code 2.
    :raises Exception: when the launch itself or the tasklet reports failure.
    """
    dispatch.initialize_tasklet_registry()
    impl_cls, impl_cls_path, tasklet_impl_name = lookup_tasklet_impl(args.tasklet_impl_name, logger)
    logger.debug("Tasklet implementation %s has found for %s", impl_cls_path, args.tasklet_impl_name)
    holder = impl_cls.__holder_cls__

    # Sandbox execution is only allowed for binaries built for Linux.
    if args.world == tasklet_pb2.Domain.SANDBOX and sb_misc.OSFamily.from_system_name() != sb_misc.OSFamily.LINUX:
        logger.error(
            "It's forbidden to execute in Sandbox tasklet built with OS other than LINUX."
            " For more info see https://docs.yandex-team.ru/ya-make/manual/project_specific/tasklet#binary"
        )
        raise SystemExit(1)

    ctx = rt_context.setup(args.world, holder.Context.DESCRIPTOR)

    if args.server:
        # Add the local status-server entry to the job context.
        ctx.entries.add().CopyFrom(server.cons_server_cli())

    input_msg = holder.Input()
    # Unknown JSON keys in --input are ignored rather than rejected.
    json_format.Parse(args.input, input_msg, ignore_unknown_fields=True)

    req = tasklet_pb2.JobStatement()
    req.input.Pack(input_msg)
    req.name = tasklet_impl_name
    req.ctx.CopyFrom(ctx)

    # The requested tasklet is wrapped into a RemoteExec job and launched via the runtime.
    remote_executor = tasklet_pb2.JobInstance()
    remote_executor.statement.CopyFrom(remote_exec_tasklet.RemoteExec(ctx, req).statement)
    run_result = runtime.launch(remote_executor)
    _raise_on_failure(run_result, logger, "Internal error occurred: {}")

    job_result = tasklet_pb2.JobResult()
    run_result.output.Unpack(job_result)
    _raise_on_failure(job_result, logger, "Tasklet failed with error: {}")

    print(json_format.MessageToJson(job_result))

    if args.server:
        _wait_for_keyboard_interrupt(logger)


def schema(args, logger):
    """``schema`` subcommand: print the serialized schema of ``args.tasklet_name``."""
    dispatch.initialize_tasklet_registry()
    serialized = get_schema(args.tasklet_name, args.binary, logger)
    print(serialized)


def get_schema(tasklet_name, binary, logger):
    """Return the serialized input/output description of ``tasklet_name``.

    This function does not initialize the tasklet registry itself. Any
    implementation of the tasklet is accepted (``find_any=True``).

    :param binary: passed through to ``proto_utils.serialize_description``; the
        ``--binary`` CLI flag describes it as base64-encoded binary output.
    Exits with status 2 when no implementation is registered under the name.
    """
    impls = dispatch.name_to_impl(tasklet_name)
    if impls:
        impl_cls = lookup_tasklet_impl(tasklet_name, logger, find_any=True)[0]
        return proto_utils.serialize_description(impls, impl_cls, binary)

    logger.error("Tasklet %s is not found. You can check available tasklets using 'list' command", tasklet_name)
    sys.exit(2)


def sb_upload(args, logger):
    """``sb-upload`` subcommand: upload the Sandbox task binary and print the result.

    With ``--sb-schema``, each registered tasklet's description is serialized with
    ``binary_meta=True`` and attached as an extra ``schema_<name>`` attribute.
    """
    dispatch.initialize_tasklet_registry()
    # The registry may answer with a falsy value (e.g. None) when nothing is
    # registered; iterate an empty list in that case instead of crashing.
    names = dispatch.list_tasklet_names() or []

    schemas = {}
    if args.sb_schema:
        for name in names:
            implementations = dispatch.name_to_impl(name)
            impl_cls, _, _ = lookup_tasklet_impl(name, logger, find_any=True)
            schemas["schema_" + name] = proto_utils.serialize_description(implementations, impl_cls, binary_meta=True)

    print(sb_utils.tasks_binary_resource(
        owner=args.sb_owner,
        enable_taskbox=True,
        sandbox_url=args.sb_url,
        use_skynet=args.sb_use_skynet,
        extra_attrs=schemas,
    ))


def _ensure_py3(logger):
    if sys.version_info.major == 2:
        logger.warning(
            "You are using Python 2 program. Tasklet core libraries allow to use both major versions of python. "
            "Please migrate your binaries to Python 3."
        )


def main():
    """CLI entry point: parse arguments, set up logging and run the chosen subcommand.

    Exit behavior:
      * Ctrl+C exits with status 1 after pointing at the log file.
      * Deliberate ``SystemExit`` raised by subcommands (they log their own
        error first) propagates unchanged and is not reported as unhandled.
      * Any other exception is logged with its traceback and re-raised.
    """
    args = parse_args()

    logger, log_path = sb_cli.setup_logging(
        "tasklet_run",
        # --verbose only exists on the "run" subcommand, hence getattr.
        logging.DEBUG if getattr(args, "verbose", False) else logging.INFO,
        base_logger="tasklet",
        fmt="%(asctime)s.%(msecs)d %(levelname)-6s %(threadName)s (%(pathname)s:%(lineno)d) %(message)s",
    )
    service_logger.debug(str(sys.argv))

    # SIGUSR2 dumps all thread stacks to the log. The signal does not exist on
    # every platform (e.g. Windows), so only register it where available.
    if hasattr(signal, "SIGUSR2"):
        signal.signal(signal.SIGUSR2, lambda *_: sb_threading.dump_threads(logger))

    _ensure_py3(logger)

    try:
        if args.action == "run":
            run(args, logger)
        elif args.action == "sb-upload":
            sb_upload(args, logger)
        elif args.action == "schema":
            schema(args, logger)
        elif args.action == "list":
            list_tasklet_names(logger)
    except KeyboardInterrupt:
        logger.error("Interrupted. More logs here: %s", log_path)
        service_logger.debug("Last trace", exc_info=True)
        sys.exit(1)
    except SystemExit:
        # Intentional exits (sys.exit(2) on lookup errors, SystemExit(1) for the
        # Sandbox OS check) were already reported; keep their exit code as is.
        raise
    except BaseException:
        service_logger.exception("Unhandled exception detected")
        logger.error("Unhandled exception detected. More logs here: %s", log_path)
        raise


# Run the CLI only when this module is executed directly, not when imported.
if __name__ == "__main__":
    main()
