""" Native logbroker runner """
import json
import logging
import math
import multiprocessing
from multiprocessing import Process
import os
import queue
import signal
import socket
import stat
from threading import Thread
from time import (
    sleep,
    time,
)
from typing import Generator

from passport.backend.logbroker_client.core.consumers.simple.native_worker import NativeLogbrokerWorker
from passport.backend.logbroker_client.core.native_runner.config import (
    Config,
    parse_config,
)


# Module logger (named 'logbroker') used by the runner, the heartbeat thread
# and the worker wrapper.
log: logging.Logger = logging.getLogger('logbroker')


# Seconds between repeated "Waiting for pids ..." progress messages while
# workers are being stopped.
DEATH_MESSAGES_THRESHOLD: int = 3
# Seconds to wait after the graceful stop signal (SIGINT) before the runner
# gives up on a graceful exit and escalates.
SIGKILL_THRESHOLD: int = 30
# Total budget, in seconds, for collecting heartbeat replies from all workers
# for a single heartbeat request.
HEARTBEAT_TIMEOUT: float = 30.0


class HeartbeatThread(Thread):
    """Background thread that serves worker heartbeats over a UNIX stream socket.

    A client connects to ``socket_path`` and sends ``b'HEARTBEAT'``. The thread
    puts a request on every worker's ``heartbeat_req_queue``. It then answers
    with exactly one JSON object:

    * ``{"result": {pid: reply, ...}}`` when every worker replied on its
      ``heartbeat_resp_queue`` within ``HEARTBEAT_TIMEOUT`` seconds. JSON turns
      the integer pids into string keys.
    * ``{"error": "..."}`` as soon as one worker did not reply in time.
    """

    def __init__(self, processes, socket_path):
        """
        :param processes: worker processes exposing ``pid``,
            ``heartbeat_req_queue`` and ``heartbeat_resp_queue``.
        :param socket_path: filesystem path the UNIX socket is bound to.
        """
        super().__init__()
        self.processes = processes
        self.socket_path = socket_path
        self.sock = None

    def _return_error(self, connection, message):
        """Send ``{"error": message}`` as JSON on ``connection``."""
        data = json.dumps(dict(error=message))
        connection.sendall(data.encode())

    def _return_result(self, connection, message):
        """Send ``{"result": message}`` as JSON on ``connection``."""
        data = json.dumps(dict(result=message))
        connection.sendall(data.encode())

    @staticmethod
    def _recv_exactly(connection, size):
        """Read ``size`` bytes from a stream socket, stopping early only at EOF.

        A single ``recv`` on a stream socket may return fewer bytes than asked
        for, so keep reading until the whole command has arrived.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = connection.recv(remaining)
            if not chunk:  # peer closed the connection
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def process_connection(self, connection):
        """Serve a single heartbeat request on an accepted connection.

        Writes at most one JSON object back. Nothing is written if the command
        is not ``b'HEARTBEAT'``.
        """
        data = self._recv_exactly(connection, len(b'HEARTBEAT'))
        if data != b'HEARTBEAT':
            log.warning('Wrong data received from heartbeat socket: {}'.format(data))
            return

        log.debug('Heartbeat request to master')
        for process in self.processes:
            process.heartbeat_req_queue.put('1')

        heartbeats = dict()
        start_time = time()
        for process in self.processes:
            # All processes share one HEARTBEAT_TIMEOUT budget; clamp so a
            # late process gets a non-blocking poll instead of a negative timeout.
            timeout_left = max(0.0, HEARTBEAT_TIMEOUT - (time() - start_time))
            try:
                # ``timeout`` must be passed by keyword: the first positional
                # parameter of ``Queue.get`` is ``block``, so passing it
                # positionally would wait forever.
                heartbeats[process.pid] = process.heartbeat_resp_queue.get(
                    timeout=timeout_left,
                )
            except queue.Empty:
                self._return_error(
                    connection,
                    'Heartbeat request from timed out on {}'.format(process.pid),
                )
                # One response per connection: do not also send a result.
                return
        self._return_result(connection, heartbeats)

    def _remove_socket_file(self):
        """Unlink ``socket_path`` if it exists and is a UNIX socket.

        Any other kind of file at that path is left untouched.
        """
        try:
            mode = os.stat(self.socket_path).st_mode
        except FileNotFoundError:
            return
        if stat.S_ISSOCK(mode):
            os.unlink(self.socket_path)

    def bind_socket(self):
        """Bind the listening socket and serve connections forever.

        This method only returns by raising. ``run`` logs the error, cleans up
        and calls it again.
        """
        # A leftover socket file at the path would make bind() fail.
        self._remove_socket_file()
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.sock.bind(self.socket_path)
        # Short accept timeout so the loop wakes up regularly.
        self.sock.settimeout(1)
        self.sock.listen()
        while True:
            try:
                connection, _ = self.sock.accept()
            except socket.timeout:
                continue
            # Always close the accepted connection. Bound its blocking calls so
            # a silent or slow client cannot stall the heartbeat server forever.
            with connection:
                connection.settimeout(HEARTBEAT_TIMEOUT)
                try:
                    self.process_connection(connection)
                except Exception:
                    log.exception('Exception in heartbeat connection')

    def _close_socket(self):
        """Best-effort cleanup of the listening socket and its socket file."""
        if self.sock is not None:
            try:
                self.sock.close()
            except OSError:
                pass
            self.sock = None
        try:
            self._remove_socket_file()
        except OSError:
            pass

    def run(self):
        """Keep the heartbeat server alive; after a failure wait 5 s and rebind."""
        while True:
            try:
                self.bind_socket()
            except Exception:
                log.exception('Exception in heartbeat server')
                sleep(5)
            finally:
                self._close_socket()


class WorkerProcess(Process):
    """A ``Process`` that owns a pair of heartbeat queues.

    The queues are created in the parent and exposed as
    ``heartbeat_req_queue`` / ``heartbeat_resp_queue``. They are also injected
    into the target's keyword arguments under the same names, so the target
    must accept those two keyword parameters.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        req_queue = multiprocessing.Queue()
        resp_queue = multiprocessing.Queue()
        self.heartbeat_req_queue = req_queue
        self.heartbeat_resp_queue = resp_queue
        # ``Process`` keeps the target's kwargs in its private ``_kwargs`` dict;
        # adding the queues there makes ``run()`` pass them to the target.
        self._kwargs['heartbeat_req_queue'] = req_queue
        self._kwargs['heartbeat_resp_queue'] = resp_queue


class NativeRunner:
    """Master that spawns, supervises and stops native logbroker workers.

    One ``WorkerProcess`` is started per partition of every target parsed from
    ``config['lbc']``. Supervision ends on SIGTERM/SIGINT, or as soon as any
    worker is found dead. All remaining workers are then stopped with SIGINT,
    escalating to SIGKILL after ``SIGKILL_THRESHOLD`` seconds.
    """

    targets_config: list[Config]
    processes: list[Process]

    # Pause between liveness polls in the shutdown wait loops, so they do not
    # busy-spin a CPU core while workers exit.
    _STOP_POLL_INTERVAL = 0.1

    def __init__(self, config: dict):
        """
        :param config: general runner config. ``config['lbc']`` holds the
            target definitions. The whole dict is forwarded to every worker and
            read for the optional ``stat_socket`` path.
        """
        self._general_config = config
        self.targets_config = parse_config(config['lbc'])
        self.processes = []
        # Set by ``handle_signal``; polled by ``supervise_processes``.
        self._interrupted = False

    def create_worker(self, config: Config, task: dict, worker_args: dict):
        """Start a worker process for ``task`` and register it in ``self.processes``.

        :param config: target config, published to the worker as
            ``NativeLogbrokerWorker.current_config``.
        :param task: task dict produced by ``generate_worker_tasks``.
        :param worker_args: keyword arguments for ``NativeLogbrokerWorker``.
        """
        # NOTE(review): ``run_worker`` is a closure, so this relies on the
        # ``fork`` start method. ``spawn``/``forkserver`` cannot pickle it.
        def run_worker(heartbeat_req_queue, heartbeat_resp_queue):
            try:
                NativeLogbrokerWorker.current_task = task
                NativeLogbrokerWorker.current_config = config
                worker = NativeLogbrokerWorker(bind_signals=True, **worker_args)
                worker.run_task(
                    task,
                    heartbeat_req_queue=heartbeat_req_queue,
                    heartbeat_resp_queue=heartbeat_resp_queue,
                )
            except Exception:
                log.exception('Unhandled worker exception')
                raise

        process = WorkerProcess(target=run_worker, daemon=False)
        self.processes.append(process)
        process.start()
        log.info('Started worker pid={}'.format(process.pid))

    def generate_worker_tasks(self) -> Generator[tuple[Config, dict], None, None]:
        """Yield a ``(config, task)`` pair for every worker to start.

        Each target gets ``config.partitions`` workers. When
        ``config.fixed_partitions`` is set, each task also carries a 1-based
        ``partition_group``.
        """
        # TODO: teach the workers to work with dataclasses after the move to py3
        for config in self.targets_config:
            log.info('Starting {} workers for target {}'.format(
                config.partitions, config,
            ))
            for n_worker in range(1, config.partitions + 1):
                worker_task = dict(
                    host=config.host,
                    port=config.data_port,
                    ca_cert=config.ca_cert,
                    client_id=config.client_id,
                    credentials_config=config.credentials_config,
                    topic=config.topic,
                    decompress=config.decompress,
                    use_client_locks=config.fixed_partitions,
                    connect_timeout=config.connect_timeout,
                    max_count=config.max_count,
                )
                if config.fixed_partitions:
                    worker_task.update(partition_group=n_worker)
                yield config, worker_task

    def spawn_workers(self):
        """Start one worker per task from ``generate_worker_tasks``."""
        for config, task in self.generate_worker_tasks():
            self.create_worker(
                config,
                task,
                dict(
                    handler={
                        'class': config.handler_class,
                        'args': config.handler_args,
                    },
                    config=self._general_config,
                ),
            )

    def spawn_heartbeat_thread(self):
        """Start a daemon ``HeartbeatThread`` if ``stat_socket`` is configured."""
        stat_socket = self._general_config.get('stat_socket')
        if stat_socket:
            thread = HeartbeatThread(self.processes, stat_socket)
            thread.daemon = True
            thread.start()

    def run(self):
        """Spawn workers, supervise them until shutdown, then stop them all."""
        log.critical('Starting lbc')
        # Spawning is inside the try so that, if starting a later worker fails,
        # the workers already started are still stopped by the finally clause.
        try:
            self.spawn_workers()
            try:
                self.spawn_heartbeat_thread()
            except Exception:
                log.exception('Exception starting heartbeat thread')
            self.bind_signals()
            self.supervise_processes()
        finally:
            sleep(0.5)
            self.stop_processes()

    def bind_signals(self):
        """Install ``handle_signal`` for SIGTERM and SIGINT in this process."""
        log.info('Setup signal handlers')
        for signum in [signal.SIGTERM, signal.SIGINT]:
            signal.signal(signum, self.handle_signal)

    def handle_signal(self, signum, _frame):
        """Signal handler: flag the runner for shutdown on SIGTERM/SIGINT."""
        if signum == signal.SIGTERM or signum == signal.SIGINT:
            log.critical('{} ({}). Terminating.'.format(
                signal.strsignal(signum), signum,
            ))
            self._interrupted = True

    def _pids(self, dead=False) -> list[int]:
        """Return the pids of alive workers, or of dead ones when ``dead`` is true."""
        return [p.pid for p in self.processes if p.is_alive() != dead]

    def start_processes(self):
        """Start every process in ``self.processes``.

        NOTE(review): ``run`` does not use this; ``create_worker`` already
        starts each process, and ``Process.start`` refuses a second start.
        """
        for process in self.processes:
            process.start()

    def supervise_processes(self):
        """Poll once per second; return on a signal or when any worker has died."""
        while True:
            if self._interrupted:
                log.critical('Interrupted by signal. Exiting')
                return
            dead = self._pids(dead=True)
            if dead:
                log.critical(
                    'Workers died unexpectedly: {} '
                    'Terminating.'.format(dead),
                )
                return
            sleep(1)

    @staticmethod
    def send_signal(pids: list[int], signum: int):
        """Send ``signum`` to each pid, ignoring processes that cannot be signalled."""
        for pid in pids:
            try:
                os.kill(pid, signum)
            except OSError:
                # Best-effort: the worker may already have exited.
                pass

    def wait_for_sigterm_death(self) -> bool:
        """Ask alive workers to stop with SIGINT and wait for them to exit.

        :returns: ``True`` if every worker exited, or ``False`` if some were
            still alive after ``SIGKILL_THRESHOLD`` seconds.
        """
        alive = self._pids()
        if not alive:
            log.info('All workers are stopped')
            return True
        log.info('Sending SIGINT and waiting for workers to die: {}'.format(
            alive,
        ))
        self.send_signal(alive, signal.SIGINT)
        term_sent = last_death_message = time()

        while True:
            alive = self._pids()
            if not alive:
                log.info('All workers have stopped')
                return True
            until_kill = SIGKILL_THRESHOLD - (time() - term_sent)
            if until_kill <= 0:
                return False
            if time() - last_death_message >= DEATH_MESSAGES_THRESHOLD:
                log.info('Waiting for pids (TERM, time left: {}): {}'.format(
                    math.floor(until_kill), alive,
                ))
                last_death_message = time()
            sleep(self._STOP_POLL_INTERVAL)

    def wait_for_sigkill_death(self):
        """Send SIGKILL to all alive workers and wait until they are gone."""
        alive = self._pids()
        if not alive:
            log.info('All workers have stopped')
            return
        log.info('Sending SIGKILL to {}'.format(alive))
        # Escalation step: SIGKILL cannot be caught or ignored by the worker.
        self.send_signal(alive, signal.SIGKILL)
        last_death_message = time()

        while True:
            alive = self._pids()
            if not alive:
                log.info('All workers have been killed')
                return
            if time() - last_death_message >= DEATH_MESSAGES_THRESHOLD:
                log.info('Waiting for pids (KILL): {}'.format(alive))
                last_death_message = time()
            sleep(self._STOP_POLL_INTERVAL)

    def stop_processes(self):
        """Stop all workers gracefully, escalating to SIGKILL if needed."""
        if not self.wait_for_sigterm_death():
            self.wait_for_sigkill_death()

        log.critical('Terminated')
