from __future__ import absolute_import

import os
import sys
import stat
import time
import uuid
import errno
import select
import shutil
import logging
import requests
import operator as op
import threading as th
import collections

import six

if six.PY2:
    import pathlib2 as pathlib
    import subprocess32 as subprocess
else:
    # noinspection PyUnresolvedReferences
    import pathlib
    import subprocess


from sandbox.common import log as common_log
from sandbox.common import config
from sandbox.common import format as common_format
from sandbox.common import itertools as common_itertools


from . import gdb as gdb_helpers
from . import misc as misc_helpers


# Public API of this module. ``subprocess`` is the module imported above (``subprocess32`` on Python 2),
# whose ``Popen`` is monkeypatched at import time by ``SubprocessMeta`` further below.
__all__ = ("subprocess", "ProcessLog", "ProcessRegistry")


def _current_task():
    """Return ``sdk2.Task.current``, the current sandbox task.

    ``sandbox.sdk2`` is imported on each call instead of at module import time
    (presumably to avoid an import cycle between ``sdk2`` and these helpers).
    """
    from sandbox.sdk2 import Task
    return Task.current


class ProcessLog(object):
    """Bridge between subprocess standard streams and Python logging.

    The ``stdout``/``stderr`` properties lazily create :class:`ProcessLog.Pipe` objects, which can be
    passed as ``stdout=``/``stderr=`` to ``subprocess.Popen``. A background thread drains the read ends
    and forwards every received line to :attr:`logger` at the configured level.

    When ``logger`` is a string, it is used as a prefix:

    - a private ``process.<prefix>`` logger is created;
    - each stream is also written to its own file derived from ``<prefix>.out.log`` / ``<prefix>.err.log``.

    As a context manager, exiting closes all pipes and joins the reader thread(s). When a task is known,
    a ``subprocess.CalledProcessError`` raised inside the block is re-raised as
    :class:`ProcessLog.CalledProcessError`, which carries links to the log files.
    """

    class CalledProcessError(subprocess.CalledProcessError):
        """``subprocess.CalledProcessError`` that can render HTML links to the process logs.

        :param returncode: process exit code
        :param cmd: command that was executed
        :param stdout_file: stdout log path relative to the task logs directory (falsy when absent/empty)
        :param stderr_file: stderr log path relative to the task logs directory (falsy when absent/empty)
        :param log_resource: task log resource used to build the links
        :param vault_filter: optional filter applied to the rendered message to mask secrets
        """

        def __init__(self, returncode, cmd, stdout_file=None, stderr_file=None, log_resource=None, vault_filter=None):
            self._stdout_file = stdout_file
            self._stderr_file = stderr_file
            self._log_resource = log_resource
            self.vault_filter = vault_filter
            super(ProcessLog.CalledProcessError, self).__init__(returncode, cmd)

        def get_task_info(self):
            """Return an HTML snippet with the error text and links to the available logs.

            Links point to the stdout/stderr logs if present, otherwise to ``common.log``.
            The result is passed through the vault filter, if one was given.
            """
            ret = [str(self)]
            if self._log_resource:
                if self._stdout_file:
                    info = gdb_helpers.get_html_view_for_logs_file("stdout", self._stdout_file, self._log_resource)
                    if info:
                        ret.extend(["<br/>", info])
                if self._stderr_file:
                    info = gdb_helpers.get_html_view_for_logs_file("stderr", self._stderr_file, self._log_resource)
                    if info:
                        ret.extend(["<br/>", info])
                if not (self._stdout_file or self._stderr_file):
                    info = gdb_helpers.get_html_view_for_logs_file("task log", "common.log", self._log_resource)
                    if info:
                        ret.extend(["<br/>", info])
            result = "".join(ret + ["<hr>"])
            if self.vault_filter:
                result = self.vault_filter.filter_message(result, escape=True)
            return result

    class Pipe(object):
        """Write end of an OS pipe that can be passed as ``stdout``/``stderr`` to ``Popen``.

        The read end (:attr:`read_fd`) is drained by the owning :class:`ProcessLog`, which logs each
        received line at :attr:`log_level`. If ``path`` is given, a file handler is attached to
        ``logger``; it accepts only records of exactly ``log_level``. The file name is varied with a
        counter until an unused one is found.
        """

        # Decoded text received so far that is not yet terminated by a line separator.
        buffer = ""
        # Incremental UTF-8 decoder, created on first use (Python 3 only).
        decoder = None

        def __init__(self, process_log, path, logger, log_level, formatter):
            self.__process_log = process_log
            self.__path = path and self.__unique_log(path)
            self.__log_file_handle = None
            if self.__path:
                open_flags = os.O_WRONLY | os.O_TRUNC | os.O_CREAT
                if os.name == "nt":
                    open_flags |= os.O_NOINHERIT | os.O_BINARY
                file_handle = os.open(str(self.__path), open_flags)
                self.__log_file_handle = logging.StreamHandler(os.fdopen(file_handle, "w"))
                if formatter:
                    self.__log_file_handle.setFormatter(formatter)
                # Only records of this pipe's own level go to its file, so stdout and stderr are
                # split into separate files even though they share one logger.
                self.__log_file_handle.addFilter(
                    type("Filter", (logging.Filter,), dict(filter=lambda _, rec: op.eq(rec.levelno, log_level)))()
                )
                logger.addHandler(self.__log_file_handle)
            self.__log_level = log_level
            self.__r, self.__w = map(self._set_cloexec_flag, os.pipe())
            self.__name = "<logger '{}'{}>".format(
                logger.name if isinstance(logger, logging.Logger) else logging.root.name,
                ", file {}".format(self.__path) if self.__path else ""
            )

        def __str__(self):
            return self.__name

        def __nonzero__(self):
            # Truthy only when the log file exists and is non-empty.
            return bool(self.__path and self.__path.exists() and self.__path.stat().st_size)

        __bool__ = __nonzero__

        @staticmethod
        def __unique_log(path):
            """Return the first path derived from ``path`` (with a counter in the suffix) that does not exist."""
            i = 0
            while True:
                unique_path = path.with_suffix(".".join(common_itertools.chain("", str(i) if i else (), "log")))
                if not unique_path.exists():
                    return unique_path
                i += 1

        @staticmethod
        def _set_cloexec_flag(fd):
            """Set FD_CLOEXEC on ``fd`` so that children do not inherit it implicitly; return ``fd``."""
            import fcntl
            fcntl.fcntl(fd, fcntl.F_SETFD, fcntl.fcntl(fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC)
            return fd

        @property
        def process_log(self):
            return self.__process_log

        @property
        def path(self):
            return self.__path

        @property
        def read_fd(self):
            return self.__r

        @property
        def log_level(self):
            return self.__log_level

        @property
        def closed(self):
            return self.__w is None

        def fileno(self):
            # Popen uses this to pass the write end to the child.
            return self.__w

        def flush(self):
            pass

        def close(self):
            """Close the write end and the file handler; no-op if already closed."""
            if self.closed:
                return
            os.close(self.__w)
            self.__w = None
            if self.__log_file_handle is not None:
                self.__log_file_handle.flush()
                self.__log_file_handle.close()

    def __init__(
        self,
        task=None,
        logger=None,
        set_action=True,
        stdout_level=logging.INFO,
        stderr_level=logging.ERROR,
        formatter=None,
    ):
        """
        :param task: task whose log directory receives the log files; defaults to the current task
        :param logger: logger to forward lines to, or a string prefix (see the class docstring);
            defaults to the ``logging`` module
        :param set_action: create progress meters in :meth:`action`; only honoured when ``task`` is given
        :param stdout_level: level used for lines read from stdout
        :param stderr_level: level used for lines read from stderr
        :param formatter: optional ``logging.Formatter`` for the per-stream log files
        """
        self.__task = task or _current_task()
        self.__logger = logger or logging
        self.__formatter = formatter
        self.__set_action = set_action if task else False
        self.__stdout_level = stdout_level
        self.__stderr_level = stderr_level
        self.__stdout_pipe = None
        self.__stderr_pipe = None
        self.__action = None
        self.__prefix = None
        self._pipes = {}
        self.__bridge_thread = None
        if isinstance(self.__logger, six.string_types):
            self.__prefix = logger
            self.__logger = logging.getLogger("process.{}".format(self.__prefix))
            self.__logger.setLevel(logging.INFO)
            self.__logger.propagate = not logger
            # Drop the logger from the logging manager cache so per-process loggers are not kept forever.
            # noinspection PyUnresolvedReferences
            self.__logger.manager.loggerDict.pop(self.__logger.name)
            vault_filter = common_log.VaultFilter.filter_from_logger(logging.getLogger())
            if vault_filter:
                self.__logger.addFilter(vault_filter)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_val, tb):
        self.close()
        # Exact type check: an already-wrapped ProcessLog.CalledProcessError is left untouched.
        if ex_type is subprocess.CalledProcessError and self.__task:
            logs = self.__task.log_path()
            raise self.CalledProcessError(
                ex_val.returncode, ex_val.cmd,
                self.__stdout_pipe and self.__stdout_pipe.path and self.__stdout_pipe.path.relative_to(logs),
                self.__stderr_pipe and self.__stderr_pipe.path and self.__stderr_pipe.path.relative_to(logs),
                self.__task.log_resource, common_log.VaultFilter.filter_from_logger(logging.getLogger())
            )

    def _join(self):
        """Wait until the reader thread has drained all pipes."""
        if self.__bridge_thread:
            self.__bridge_thread.join()

    def close(self):
        """Close the write ends of all pipes and wait for the remaining output to be logged."""
        for pipe in list(self._pipes.values()):
            pipe.close()
        self._join()

    def action(self, description):
        """Return a new progress meter for ``description``, or None when actions are disabled."""
        if self.__set_action:
            self.__action = misc_helpers.ProgressMeter(description)
            return self.__action

    def _pipe_flush(self, pipe):
        """Log what is left in ``pipe``'s buffer (the last line without a trailing separator)."""
        if pipe.decoder is not None:
            # A truncated multi-byte sequence at EOF becomes a replacement character instead of being lost.
            pipe.buffer += pipe.decoder.decode(b"", True)
        if pipe.buffer:
            self.__logger.log(pipe.log_level, pipe.buffer)
            pipe.buffer = ""

    @staticmethod
    def _pipe_decode(pipe, data):
        """Convert a chunk read from ``pipe`` to native ``str``.

        On Python 3, an incremental UTF-8 decoder is kept per pipe, so a character split between two
        reads is decoded correctly. Undecodable bytes are replaced instead of raising, because an
        exception here would kill the reader thread and leave the pipe undrained.
        """
        if isinstance(data, str):
            # Python 2: os.read() already returns a native (byte) str.
            return data
        if pipe.decoder is None:
            import codecs
            pipe.decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        return pipe.decoder.decode(data)

    def _pipe_push(self, pipe, data):
        """Append ``data`` to ``pipe``'s buffer and log every completed line."""
        pipe.buffer += self._pipe_decode(pipe, data)
        if os.linesep in pipe.buffer:
            lines = pipe.buffer.split(os.linesep)
            pipe.buffer = lines[-1]
            for line in lines[:-1]:
                self.__logger.log(pipe.log_level, line)

    def __bridge(self):
        """Reader loop of the background thread.

        Waits on all registered read ends with ``poll()`` (``select()`` when poll is unavailable),
        forwards received data to the logger, and closes each read end on EOF/HUP/ERR.
        The loop ends when no pipes remain.
        """
        registered_fds = set(six.viewkeys(self._pipes))
        # use select.select() on systems that does not support select.poll()
        poll = select.poll() if hasattr(select, "poll") else None

        def unregister_fd(fd_):
            if poll is not None:
                poll.unregister(fd_)
                registered_fds.discard(fd_)
            self._pipes.pop(fd_)
            os.close(fd_)

        if poll is not None:
            for fd in registered_fds:
                poll.register(fd, select.POLLIN)
        while self._pipes:
            if poll is not None:
                # register newly added fds
                for fd in six.viewkeys(self._pipes) - registered_fds:
                    registered_fds.add(fd)
                    poll.register(fd, select.POLLIN)
            try:
                events = (
                    [(f, select.POLLIN) for f in select.select(self._pipes, [], [], 1)[0]]
                    if poll is None else
                    poll.poll(1000)
                )
            except select.error as ex:
                if ex.args[0] == errno.EINTR:
                    continue
                raise
            for fd, event in events:
                pipe = self._pipes[fd]
                if event & select.POLLIN:
                    data = os.read(fd, 65536)
                    if not data:
                        unregister_fd(fd)
                        self._pipe_flush(pipe)
                        continue
                    self._pipe_push(pipe, data)
                elif event & (select.POLLHUP | select.POLLERR):
                    unregister_fd(fd)
                    self._pipe_flush(pipe)

            if not events:
                for fd, pipe in list(six.iteritems(self._pipes)):
                    if pipe.closed:
                        # poll() reported nothing, but this pipe is already closed
                        # remove it from read list and close read end, logging any unterminated tail
                        unregister_fd(fd)
                        self._pipe_flush(pipe)

    def _ensure_bridge_thread(self):
        """Start the reader thread unless one is already running."""
        if self.__bridge_thread is None or not self.__bridge_thread.is_alive():
            self.__bridge_thread = th.Thread(target=self.__bridge)
            self.__bridge_thread.start()

    @property
    def logger(self):
        return self.__logger

    def __log_path(self, filename):
        """Return the path for log ``filename``, or None when no prefix (file logging) is configured.

        With a task, the file goes to the task log directory. Without one, an absolute prefix keeps its
        own location and a relative prefix resolves against the current working directory.
        """
        if self.__prefix:
            if self.__task:
                return self.__task.log_path(filename)
            else:
                # os.path.isabs, not os.path.abspath: the latter always returns a truthy string,
                # which would turn a relative prefix into a (usually non-existent) directory.
                path = pathlib.Path(self.__prefix) if os.path.isabs(self.__prefix) else pathlib.Path.cwd()
                return path / filename
        else:
            return None

    @property
    def stdout(self):
        """Pipe for the child's stdout, created and registered with the reader thread on first access."""
        if self.__stdout_pipe is not None:
            return self.__stdout_pipe
        p = self.__stdout_pipe = self.Pipe(
            self,
            self.__prefix and self.__log_path(".".join((self.__prefix, "out", "log"))),
            self.__logger,
            self.__stdout_level,
            self.__formatter,
        )
        self._pipes[p.read_fd] = p
        self._ensure_bridge_thread()
        return p

    @property
    def stderr(self):
        """Pipe for the child's stderr, created and registered with the reader thread on first access."""
        if self.__stderr_pipe is not None:
            return self.__stderr_pipe
        p = self.__stderr_pipe = self.Pipe(
            self,
            self.__prefix and self.__log_path(".".join((self.__prefix, "err", "log"))),
            self.__logger,
            self.__stderr_level,
            self.__formatter,
        )
        self._pipes[p.read_fd] = p
        self._ensure_bridge_thread()
        return p

    def raise_for_status(self, p):
        """Wait for process ``p`` and raise if it exited with a non-zero code.

        Raises :class:`ProcessLog.CalledProcessError` when a task and non-empty file logs are available.
        That error carries log links and uses the same vault filtering as :meth:`__exit__`.
        Otherwise raises plain ``subprocess.CalledProcessError``.
        """
        if p.wait():
            # Pipe truthiness means "log file exists and is non-empty".
            if self.__task and self.__prefix and (self.__stdout_pipe or self.__stderr_pipe):
                logs = self.__task.log_path()
                raise self.CalledProcessError(
                    p.returncode, p.args,
                    self.__stdout_pipe and self.__stdout_pipe.path.relative_to(logs),
                    self.__stderr_pipe and self.__stderr_pipe.path.relative_to(logs),
                    self.__task.log_resource,
                    common_log.VaultFilter.filter_from_logger(logging.getLogger()),
                )
            else:
                raise subprocess.CalledProcessError(p.returncode, p.args)


if sys.platform == "win32":
    ProcessLog_ = ProcessLog

    class ProcessLog(ProcessLog_):
        """Windows flavour of ProcessLog: one blocking reader thread per pipe instead of a poll/select loop."""

        class Pipe(ProcessLog_.Pipe):
            @staticmethod
            def _set_cloexec_flag(fd):
                # fcntl is not available here; the descriptor is returned unchanged.
                return fd

        def __init__(
            self, task=None, logger=None, set_action=True, stdout_level=logging.INFO, stderr_level=logging.ERROR,
            formatter=None,
        ):
            # ``formatter`` mirrors the base signature so that callers passing it work on every platform.
            ProcessLog_.__init__(
                self,
                task=task, logger=logger, set_action=set_action, stdout_level=stdout_level, stderr_level=stderr_level,
                formatter=formatter,
            )
            self.__bridge_threads = {}

        def __fd_bridge(self, fd):
            """Drain read end ``fd`` until EOF, logging its lines, then close it."""
            pipe = self._pipes[fd]
            try:
                while True:
                    data = os.read(fd, 65536)
                    if not data:
                        self._pipe_flush(pipe)
                        break
                    self._pipe_push(pipe, data)
            finally:
                # Nothing else reads this descriptor; close it like the base implementation does on EOF.
                os.close(fd)

        def _ensure_bridge_thread(self):
            """Start a reader thread for every registered pipe that does not have one yet."""
            for fd in six.viewkeys(self._pipes) - six.viewkeys(self.__bridge_threads):
                t = self.__bridge_threads[fd] = th.Thread(target=self.__fd_bridge, args=(fd,))
                t.start()

        def _join(self):
            """Wait for all reader threads to finish."""
            for t in list(self.__bridge_threads.values()):
                t.join()


class ProcessRegistryMeta(type):
    """Metaclass implementing :class:`ProcessRegistry` on the class object itself.

    Keeps the processes registered while the registry is enabled. :meth:`finish` looks for their
    coredumps and saves gdb tracebacks and the coredumps themselves as task resources.
    """

    # Registered processes; ``finish()`` rebinds the name to a fresh list.
    __processes = []
    # Thread-local storage for the ``enabled`` flag toggled by ``__enter__``/``__exit__``.
    __local = th.local()
    # Whether the GDB environment has already been prepared.
    __gdb_set_up = False
    Process = collections.namedtuple("Process", ("pid", "rpid", "cmd"))

    def register(cls, pid, cmd):
        """Register process ``pid`` started with ``cmd``; return the id given by ``agentr.register_process``."""
        rpid = _current_task().agentr.register_process(pid)
        cls.__processes.append(cls.Process(pid=pid, rpid=rpid, cmd=cmd))
        return rpid

    def __call__(cls, process, cmd):
        """Register a just-started ``Popen`` object if the registry is enabled in this thread.

        Returns ``cls`` when the process was registered, otherwise None.
        """
        try:
            if cls.__local.enabled:
                cls.register(process.pid, cmd)
                return cls
        except AttributeError:
            # ``enabled`` was never set in this thread
            pass

    def __iter__(cls):
        for process in cls.__processes:
            yield process

    def __enter__(cls):
        cls.__local.enabled = True
        return cls

    def __exit__(cls, *_):
        cls.__local.enabled = False

    @staticmethod
    def __send_trace_to_aggregator(
        aggregator_url,
        core_trace_text,
        binary_path,
        task_id,
        task_type,
    ):
        """
            Send core trace to stack trace aggregator.
            See https://wiki.yandex-team.ru/cores-aggregation for details

            :param aggregator_url: coredump aggregator url
            :param core_trace_text: trace text from gdb
            :param binary_path: coredumped binary full path
            :param task_id: sandbox task id
            :param task_type: sandbox task type
            :raises requests.HTTPError: if the aggregator responds with an error status
        """
        response = requests.post(
            url="{}/submit_core".format(aggregator_url),
            json={
                "parsed_traces": core_trace_text,
                "dump_json": {},
            },
            params={
                "time": int(time.time()),
                # treat binary basename as itype
                "service": os.path.split(binary_path)[-1],
                "ctype": "sandbox-task",
                "server": config.Registry().this.fqdn,
                "prj": task_type,
                "task_id": task_id,
                # TODO(mvel@): pass task tags
            },
            timeout=40,
        )
        response.raise_for_status()
        logging.info("Coredump traceback sent to %s\nResponse: %s", aggregator_url, response.text)

    def __save_gdb_traces(cls, binary_path, coredump_path):
        """Run gdb on the coredump and publish the traceback.

        Writes the gdb output to task logs and adds a link to task info. It then tries to build a
        filtered traceback resource. If that succeeds, the raw trace is sent to the coredump aggregators.
        """
        task = _current_task()
        cmd = gdb_helpers.gdb_trace_command(binary_path, coredump_path)
        core_name = os.path.split(coredump_path)[-1]

        with ProcessLog(task, logger="{}.gdb".format(core_name)) as gdb_pl:
            subprocess.call(cmd, shell=False, stdout=gdb_pl.stdout, stderr=gdb_pl.stdout)

        html_traceback_log = gdb_helpers.get_html_view_for_logs_file(
            "gdb_traceback",
            gdb_pl.stdout.path.name,
            task.log_resource.id
        )
        task.set_info(
            "GDB traceback is available in task logs:<br />{0}".format(html_traceback_log),
            do_escape=False,
        )
        # noinspection PyBroadException
        try:
            from sandbox import sdk2
            filtered_output_name = task.path("{0}.gdb_traceback.filtered.html".format(core_name))
            gdb_helpers.filter_traceback(
                task.id,
                str(task.path("coredump_filter")),
                str(gdb_pl.stdout.path),
                str(filtered_output_name),
            )
            if not filtered_output_name.exists():
                raise Exception("Traceback file {} does not exist".format(filtered_output_name))

            traceback_resource = sdk2.Resource["FILTERED_GDB_TRACEBACK"](
                task, "Filtered traceback", str(filtered_output_name), ttl=120
            )
            sdk2.ResourceData(traceback_resource).ready()

            html_filtered_traceback = gdb_helpers.get_html_view_for_logs_file(
                "filtered_traceback", filtered_output_name.name, traceback_resource.id, is_dir=False
            )
            task.set_info(
                "GDB filtered traceback is also available:<br />{0}".format(html_filtered_traceback),
                do_escape=False,
            )
        except Exception:
            logging.exception("Cannot filter traceback")
        else:
            try:
                with open(str(gdb_pl.stdout.path), "r") as f_core_trace:
                    core_text = f_core_trace.read()
                    cls.__send_trace_to_aggregator(
                        aggregator_url="https://coredumps.yandex-team.ru",
                        core_trace_text=core_text,
                        binary_path=binary_path,
                        task_id=task.id,
                        task_type=task.type,
                    )
                    cls.__send_trace_to_aggregator(
                        aggregator_url="https://coredumps-testing.yandex-team.ru",
                        core_trace_text=core_text,
                        binary_path=binary_path,
                        task_id=task.id,
                        task_type=task.type,
                    )
            except Exception:
                logging.exception("Cannot send traceback to aggregator")

    @staticmethod
    def __save_coredump(coredump_path):
        """Move the coredump into the task directory, gzip it and register it as a CORE_DUMP resource.

        :return: the created resource, or None if the coredump could not be moved or compressed
        """
        from sandbox import sdk2
        task = _current_task()
        coredump_path = os.path.abspath(coredump_path)
        logging.debug("Save coredump %s", coredump_path)
        coredump_filename = os.path.basename(coredump_path)
        saved_coredump_path = str(task.path(coredump_filename))
        gzipped_coredump_path = saved_coredump_path + ".gz"
        try:
            if coredump_path != saved_coredump_path:
                shutil.move(coredump_path, saved_coredump_path)
            returncode = subprocess.call(["gzip", "-f", saved_coredump_path])
        except OSError:
            logging.exception("Error while moving coredump {}".format(coredump_filename))
            task.set_info("Cannot dump coredump {}".format(coredump_filename))
            return None
        # A failed gzip leaves no .gz file. Bail out here rather than letting os.stat() raise,
        # which would abort coredump processing for all remaining processes.
        if returncode or not os.path.exists(gzipped_coredump_path):
            logging.error("Cannot gzip coredump %s (gzip exit code %s)", saved_coredump_path, returncode)
            task.set_info("Cannot dump coredump {}".format(coredump_filename))
            return None
        mode = stat.S_IMODE(os.stat(gzipped_coredump_path).st_mode)
        mode |= stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH
        os.chmod(gzipped_coredump_path, mode)
        coredump_resource = sdk2.Resource["CORE_DUMP"](
            task, "{} coredump".format(coredump_filename), gzipped_coredump_path
        )
        sdk2.ResourceData(coredump_resource).ready()
        task.set_info("COREDUMP was saved as resource:{0}".format(coredump_resource.id))
        task.set_info("<hr/>", do_escape=False)
        return coredump_resource

    def __prepare_gdb(cls):
        """Prepare the configured GDB environment once per registry."""
        if not cls.__gdb_set_up:
            from sandbox.sdk2 import environments
            gdb_version = config.Registry().client.tasks.coredumps_gdb_version
            environments.GDBEnvironment(gdb_version).prepare()
            cls.__gdb_set_up = True

    def __get_coredumps(cls, process_list):
        """Find coredumps of ``process_list`` and save traces/coredumps; return a list of found-dump dicts."""
        task = _current_task()
        # Next logging was added for UPS-445 debug
        logging.debug("Checking coredumps for pids: %s", process_list)
        coredumps_dir = config.Registry().client.tasks.coredumps_dir
        result = []
        coredumps = task.agentr.coredumps
        if os.path.exists(coredumps_dir):
            for p in process_list:
                binary_path = next(common_itertools.chain(p.cmd)).split()[0]
                # Prefer the coredump known to the agent; fall back to searching the coredumps directory.
                core_file = coredumps.get(p.rpid or p.pid)
                if not core_file:
                    binary_name = os.path.split(binary_path)[-1]
                    core_file = gdb_helpers.get_core_path(coredumps_dir, binary_name, p.pid)
                    if core_file:
                        core_file_link = task.agentr.hard_link(core_file) or core_file
                        if core_file != core_file_link:
                            try:
                                os.unlink(core_file)
                            except OSError:
                                pass
                else:
                    core_file_link = core_file

                if core_file:
                    # noinspection PyUnboundLocalVariable
                    task.set_info("Coredump file {} ({}) for process {} was found".format(
                        core_file, common_format.size2str(os.path.getsize(core_file_link)), p.pid
                    ))
                    cls.__prepare_gdb()
                    cls.__save_gdb_traces(binary_path, core_file_link)
                    core_resource = cls.__save_coredump(core_file_link)
                    result.append({
                        "cmd": p.cmd,
                        "pid": p.pid,
                        "core_file": core_file,
                        "core_resource": core_resource.id if core_resource else None
                    })
        else:
            logging.warning("Coredumps folder %s does not exist.", coredumps_dir)
        return result

    def finish(cls):
        """Check registered processes for coredumps, reset the registry and return the findings."""
        result = []
        # noinspection PyBroadException
        try:
            result = cls.__get_coredumps(cls.__processes)
        except Exception:
            logging.exception("Cannot check coredumps.")

        cls.__processes = []
        return result


class ProcessRegistry(six.with_metaclass(ProcessRegistryMeta, object)):
    """Context manager used to register task subprocesses.

    All behaviour lives in ``ProcessRegistryMeta``, and the class object itself serves as both context
    manager and registry. It is not meant to be instantiated: calling it registers a ``Popen`` object
    (when enabled) and returns the class or None.

    While the ``with`` block is active in the current thread, every ``Popen`` created through the
    patched ``subprocess`` module is registered automatically. ``finish()`` later checks the registered
    processes for coredumps.

    Usage:

    with sdk2.helpers.ProcessRegistry as reg:
        proc1 = sp.Popen(command_line, ...)  # Popen registers processes automatically
        reg.register(pid2, command_line2)  # use register method to register some other process
    """


class SubprocessMeta(type):
    """Metaclass whose sole purpose is to monkeypatch ``subprocess`` when a class using it is created.

    ``__new__`` receives the ``subprocess`` module as the "base class" (see ``_subprocess`` below) and
    reloads it. It then replaces ``Popen.__init__`` and ``Popen.wait`` with wrappers that:

    - log the command;
    - register the process in :class:`ProcessRegistry`;
    - open a progress meter while waiting on a process that writes into a :class:`ProcessLog`.

    ``__new__`` has no return statement, so no real class is produced.
    """

    # noinspection PyUnusedLocal
    # noinspection PyMethodParameters
    def __new__(mcs, _, bases, __):
        # bases[0] is the module passed as the base of ``_subprocess`` (the ``subprocess`` module)
        base = bases[0]
        # Keep the pre-reload exception class and restore it after reload below:
        # ProcessLog.CalledProcessError subclasses it, and ProcessLog.__exit__ compares types by identity.
        # noinspection PyPep8Naming
        CalledProcessError = base.CalledProcessError
        # Reload to start from the module's own Popen methods
        # (presumably so a repeated import does not wrap already-wrapped methods).
        six.moves.reload_module(base)
        popen_init = subprocess.Popen.__init__
        popen_wait = subprocess.Popen.wait
        logger = logging.getLogger(base.__name__)

        # No-op context manager used when no progress meter is available.
        fake_meter = type("EmptyMeter", (object,), dict(__enter__=lambda _: None, __exit__=lambda *_, **__: None))()

        def init(self, args, **kws):
            # Replacement for Popen.__init__. Only keyword ``stdout``/``stderr`` are inspected for
            # ProcessLog pipes; the first pipe found supplies the ProcessLog used by ``wait`` below.
            # NOTE(review): extra positional Popen arguments (bufsize, executable, ...) are not accepted here.
            # ``self.__id`` etc. are name-mangled to ``_SubprocessMeta__id`` and shared with ``wait``.
            self.__id = uuid.uuid4().hex[:8]
            logger.info("[%s] Running: %r", self.__id, args)
            stdout = kws.get("stdout")
            stderr = kws.get("stderr")
            process_log = None
            self.__pipes = []
            for std in filter(lambda _: isinstance(_, ProcessLog.Pipe), (stdout, stderr)):
                self.__pipes.append(std)
                if process_log is None:
                    process_log = std.process_log

            self.__process_log = process_log
            popen_init(self, args, **kws)

            # Goes to ProcessRegistryMeta.__call__: registers the process only if the registry is enabled.
            # noinspection PyArgumentList
            ProcessRegistry(self, args)

            logger.debug("[%s] PID=%d OUT=%s ERR=%s", self.__id, self.pid, stdout, stderr)

        def wait(self, timeout=None, endtime=None):
            # Replacement for Popen.wait: an absolute ``endtime`` is converted to a relative timeout,
            # and the wait is wrapped in a ProcessLog progress meter when one is available.
            if endtime is not None:
                timeout = endtime - time.time()
            if timeout is not None:
                logger.debug("[%s] Wait process with PID %s for %s second(s)", self.__id, self.pid, timeout)

            progress_helper = fake_meter
            if self.__process_log:
                title = "[{}] Waiting process {!r}{}".format(
                    self.__id, self.args, "" if timeout is None else " ({} seconds timeout)".format(timeout)
                )
                # `or` is to support cases where we're outside of task context (.action() returns None)
                progress_helper = self.__process_log.action(title) or fake_meter

            with progress_helper:
                return popen_wait(self, timeout=timeout)

        # Install the patches on the (reloaded) module.
        subprocess.CalledProcessError = CalledProcessError
        subprocess.Popen.__init__ = init
        subprocess.Popen.wait = wait


# noinspection PyPep8Naming
class _subprocess(six.with_metaclass(SubprocessMeta, subprocess)):
    # Defining this class installs the subprocess patches at import time: SubprocessMeta.__new__ runs
    # with the ``subprocess`` module as its base. That ``__new__`` returns nothing, so ``_subprocess``
    # is presumably bound to None rather than a usable class; it exists only for the side effect.
    pass
