from concurrent.futures import (
    Future,
    TimeoutError,
)
import json
import logging
from queue import (
    Empty,
    Queue,
)
from threading import Thread
import urllib.parse as url_parse
import uuid

from dataclasses import dataclass
from passport.backend.qa.autotests.base.builders.base import (
    AutotestsBuilderUnexpectedResponseError,
    AutotestsJsonBuilder,
)
from passport.backend.qa.autotests.base.settings.push_api import (
    PUSH_API_SERVICE,
    PUSH_API_TVM_DST_ALIAS,
    PUSH_API_URL,
    PUSH_API_WS_CLIENT_NAME,
    PUSH_API_WS_PUSH_READ_TIMEOUT,
)
from websocket import WebSocketApp


# Module-level logger named after this module; used by the websocket callbacks below.
log = logging.getLogger(__name__)


@dataclass
class Sign:
    """Signature pair read from the ``/v2/secret_sign`` response.

    Both fields are forwarded unchanged as the ``sign`` and ``ts`` query
    parameters of the websocket subscribe URL.
    """

    # Value of the ``sign`` key of the response.
    sign: str
    # Value of the ``ts`` key of the response (presumably a timestamp — confirm).
    ts: int


@dataclass
class CloseResult:
    """Close status and message handed to the websocket ``on_close`` callback.

    NOTE(review): the websocket client may pass ``None`` for both values when
    the connection ends without a close frame — confirm and widen the
    annotations if so.
    """

    # Status code received by the ``on_close`` callback.
    code: int
    # Close message received by the ``on_close`` callback.
    msg: str


class WsTimeoutError(Exception):
    """Raised when opening the websocket or waiting for a push times out."""


class WsSession:
    """Push API websocket subscription driven by a background thread.

    ``WebSocketApp.run_forever()`` is executed in a daemon thread.  Its
    callbacks hand results back to the caller through two futures (open /
    close) and a bounded queue of received push messages.

    The object is usable as a context manager: entering calls :meth:`start`,
    leaving calls :meth:`close`.
    """

    ws: WebSocketApp

    def __init__(
        self, ws_url: str, uid: int, service: str, client: str, sign: Sign,
    ):
        """Build the subscribe URL and the websocket app; nothing is opened yet.

        :param ws_url: websocket base URL; ``/v2/subscribe/websocket`` is appended.
        :param uid: sent as the ``user`` query parameter.
        :param service: sent as the ``service`` query parameter.
        :param client: sent as the ``client`` query parameter.
        :param sign: its ``sign`` and ``ts`` are sent as query parameters.
        """
        # A fresh random session id per subscription.
        self.session = str(uuid.uuid4())
        subscribe_base_url = f'{ws_url}/v2/subscribe/websocket'
        subscribe_query = url_parse.urlencode(dict(
            service=service,
            user=uid,
            client=client,
            session=self.session,
            sign=sign.sign,
            ts=sign.ts,
        ))
        subscribe_url = f'{subscribe_base_url}?{subscribe_query}'

        self.thread = Thread(target=self._thread_wrapper, daemon=True)
        # Set by the worker thread when it finishes for any reason.
        self.thread_died = False
        # Exception that escaped ``run_forever()``, if any.
        self.thread_exception = None
        # Parsed push messages waiting to be consumed by ``recv()``.
        self.msg_queue = Queue(maxsize=1024)
        # Resolved by ``_on_open`` / ``_on_close`` respectively.
        self.open_future = Future()
        self.close_future = Future()

        # These are callbacks. The only alternative seems to be an asyncio
        # loop, and it is not clear that would be any better =^._.^=
        self.ws = WebSocketApp(
            url=subscribe_url,
            on_open=self._on_open,
            on_close=self._on_close,
            on_message=self._on_message,
            on_error=self._on_error,
        )

    @staticmethod
    def _parse_data(raw_data) -> dict:
        """Decode a raw websocket frame into a dict.

        The frame must be a JSON object.  If it has a truthy ``message``
        field, that field is itself decoded from JSON and replaced in place.

        :raises RuntimeError: if either layer cannot be decoded or the frame
            is not a JSON object; the underlying error is kept as the cause.
        """
        try:
            data = json.loads(raw_data)
        except (ValueError, TypeError) as err:
            # ValueError covers JSONDecodeError; TypeError covers non-text input.
            raise RuntimeError('Wrong WS data {}: {}'.format(raw_data, err)) from err

        if not isinstance(data, dict):
            # Fail with the same error type instead of an AttributeError on ``.get``.
            raise RuntimeError('Wrong WS data {}: {}'.format(
                raw_data,
                'expected a JSON object, got {}'.format(type(data).__name__),
            ))

        nested = data.get('message')
        if nested:
            try:
                data['message'] = json.loads(nested)
            except (ValueError, TypeError) as err:
                raise RuntimeError('Wrong WS data {}: {}'.format(raw_data, err)) from err

        return data

    def _thread_body(self):
        """Run the websocket event loop until it returns or raises."""
        self.ws.run_forever()

    def _thread_wrapper(self):
        """Thread target: run the loop, record any exception and mark the thread dead."""
        try:
            self._thread_body()
        except Exception as err:
            # Stored so the calling thread can re-raise it later.
            self.thread_exception = err
        finally:
            self.thread_died = True

    def _check_exception(self):
        """Re-raise, in the calling thread, the exception recorded by the worker."""
        if self.thread_exception is not None:
            raise self.thread_exception

    def _check_running(self):
        """Raise if the worker thread failed or has already finished.

        :raises RuntimeError: if the thread finished without an exception.
        """
        self._check_exception()
        if self.thread_died:
            raise RuntimeError('Websocket died unexpectedly')

    @staticmethod
    def _is_push(message) -> bool:
        """Return True if the parsed frame has a truthy ``message`` field."""
        return bool(message.get('message'))

    def _on_open(self, *_):
        """``on_open`` callback: unblock ``start()``."""
        self.open_future.set_result(None)
        log.debug('Opened WS connection, session %s', self.session)

    def _on_message(self, _, raw_message):
        """``on_message`` callback: parse the frame and enqueue it if it is a push.

        ``Queue.put`` blocks while the queue is full, so an unread backlog of
        1024 messages stalls the websocket thread.
        """
        # Lazy %-args: the payload is not formatted when debug logging is off.
        log.debug('Received push api WS message: %s', raw_message)
        message = self._parse_data(raw_message)
        if self._is_push(message):
            self.msg_queue.put(message)

    def _on_close(self, _, code, msg):
        """``on_close`` callback: publish the close status through ``close_future``."""
        log.debug('WS connection closed with %s %s', code, msg)
        self.close_future.set_result(CloseResult(code=code, msg=msg))

    def _on_error(self, _, error):
        """``on_error`` callback: raise the reported error.

        NOTE(review): presumably this is meant to propagate out of
        ``run_forever()`` so that ``_thread_wrapper`` records it — that depends
        on how the websocket client treats exceptions raised by callbacks;
        confirm against the installed version.
        """
        if isinstance(error, Exception):
            raise error
        raise RuntimeError('Websocket error: {}'.format(error))

    def recv(self, timeout=PUSH_API_WS_PUSH_READ_TIMEOUT):
        """Return the next queued push message, waiting up to ``timeout`` seconds.

        :raises WsTimeoutError: if no message arrived in time.
        :raises RuntimeError: if the worker thread has finished.
        :raises Exception: whatever exception terminated the worker thread.
        """
        self._check_running()
        try:
            return self.msg_queue.get(block=True, timeout=timeout)
        except Empty as err:
            # Prefer reporting a dead connection over a plain timeout.
            self._check_running()
            raise WsTimeoutError(
                'Push message has not been received in {} sec'.format(timeout)
            ) from err

    def start(self):
        """Start the worker thread and wait up to 10 seconds for the connection to open.

        :raises WsTimeoutError: if the connection was not opened in time.
        :raises Exception: whatever exception terminated the worker thread.
        """
        self.thread.start()
        try:
            self.open_future.result(timeout=10)
        except TimeoutError as err:
            # A failure in the worker is more informative than the timeout.
            self._check_exception()
            raise WsTimeoutError('Timeout opening websocket') from err

    def close(self):
        """Close the websocket unless the worker thread has already finished.

        :raises Exception: whatever exception terminated the worker thread.
        """
        self._check_exception()
        if self.thread_died:
            return
        self.ws.close()

    def __enter__(self):
        """Open the connection and return the session."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection; exceptions from the ``with`` body are not suppressed."""
        self.close()


class PushApi(AutotestsJsonBuilder):
    """Push API client: HTTP requests plus a factory for websocket sessions.

    Every HTTP request carries the configured ``service`` query parameter.
    """

    def __init__(self):
        """Configure the HTTP builder and derive the websocket base URL."""
        super().__init__(base_url=PUSH_API_URL, tvm_dst_alias=PUSH_API_TVM_DST_ALIAS)
        self.service = PUSH_API_SERVICE

        api_url = url_parse.urlparse(PUSH_API_URL)
        if api_url.scheme == 'https':
            scheme = 'wss'
        else:
            scheme = 'ws'
        # NOTE(review): only the hostname is carried over, so a port present in
        # PUSH_API_URL is dropped here — confirm this is intended.
        self.ws_url = f'{scheme}://{api_url.hostname}'

    def request(self, method, path, form_params=None, query_params=None, headers=None,
                expected_http_status=200, **kwargs):
        """Delegate to the base ``request`` with ``service`` added to the query.

        Caller-supplied ``query_params`` take precedence over the default
        ``service`` value.
        """
        merged_query = {'service': self.service}
        if query_params:
            merged_query.update(query_params)

        return super().request(
            method=method,
            path=path,
            form_params=form_params,
            query_params=merged_query,
            headers=headers,
            expected_http_status=expected_http_status,
            **kwargs,
        )

    def get_secret_sign(self, uid: int) -> Sign:
        """Fetch ``/v2/secret_sign`` for ``uid`` and wrap the result in a ``Sign``.

        :raises AutotestsBuilderUnexpectedResponseError: if the response lacks
            the ``sign`` or ``ts`` key.
        """
        response = self.get(path='/v2/secret_sign', query_params={'user': uid})
        try:
            signature, timestamp = response['sign'], response['ts']
        except KeyError:
            raise AutotestsBuilderUnexpectedResponseError(response)

        return Sign(sign=signature, ts=timestamp)

    def open_ws(self, uid: int) -> WsSession:
        """Create a (not yet started) websocket session for ``uid``.

        A secret sign is requested first and passed to the session.
        """
        return WsSession(
            ws_url=self.ws_url,
            service=self.service,
            uid=uid,
            client=PUSH_API_WS_CLIENT_NAME,
            sign=self.get_secret_sign(uid),
        )

    # NOTE: this method name shadows the builtin ``list`` inside the class body;
    # annotations of any method defined below it must not rely on ``list[...]``.
    def list(self, uid: int) -> list[dict]:
        """Return the ``/v2/list`` response for ``uid``.

        :raises: ``self.parser_error_class`` if the response is not a list.
        """
        response = self.get(path='/v2/list', query_params={'user': uid})
        if isinstance(response, list):
            return response
        raise self.parser_error_class('Response {!r} is not a list'.format(response))
