import json
import logging
import lomond
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Callable, Optional, Union

logger = logging.getLogger(__name__)

class WebSocket:
    """Callback-based wrapper around a ``lomond`` web socket.

    ``connect`` runs the handshake on the calling thread, then hands the
    lomond event generator to a background ``Thread``. That thread dispatches
    poll, text and binary events to the user callbacks.
    """

    def __init__(self, url: str):
        self.url = url
        # Both stay None until ``connect`` succeeds; the reader thread resets
        # ``web_socket`` to None when the event stream ends.
        self.web_socket = self.thread = None

    def connect(
        self,
        on_close: Callable[[], None],
        on_message: Callable[[Any], None],
        on_poll: Callable[[], None],
        poll: Optional[timedelta],
        timeout: Optional[timedelta] = None,
    ) -> None:
        """Open the connection and start dispatching events on a background thread.

        Args:
            on_close: Called from the background thread once the event stream
                has ended (after ``web_socket`` has been reset to None).
            on_message: Called with the decoded JSON value of a text frame, or
                the raw text if it is not valid JSON, or the bytes of a binary
                frame. Exceptions it raises are logged and ignored.
            on_poll: Called for every lomond ``poll`` event.
            poll: Poll interval handed to lomond (in seconds); ``None`` is
                passed through unchanged.
            timeout: How long to wait for the ``ready`` event. A falsy value
                (None or zero) means 10 seconds.

        Raises:
            AssertionError: If already connected, or a callback is not callable.
            RuntimeError: If the connection fails, the server disconnects or the
                stream ends before ``ready``, or the timeout expires.
        """
        if self.web_socket:
            raise AssertionError('already connected')
        if not callable(on_close):
            raise AssertionError('on_close not a function')
        if not callable(on_message):
            raise AssertionError('on_message not a function')
        if not callable(on_poll):
            raise AssertionError('on_poll not a function')

        timeout = timeout or timedelta(seconds=10)
        # Test for None explicitly: timedelta(0) is falsy, so
        # ``poll and poll.total_seconds()`` would hand lomond a timedelta.
        poll_seconds = None if poll is None else poll.total_seconds()
        web_socket = lomond.WebSocket(self.url)
        events = web_socket.connect(poll=poll_seconds, ping_rate=0, close_timeout=1.0)
        self._await_ready(events, timeout)

        thread = Thread(target=self._receive_loop, args=(events, on_close, on_message, on_poll))
        # Publish the connection *before* the thread starts. If the stream
        # ends immediately, the thread's reset to None must not be overwritten
        # by a late assignment here.
        self.web_socket, self.thread = web_socket, thread
        try:
            thread.start()
        except BaseException:
            self.web_socket = self.thread = None
            raise

    @staticmethod
    def _await_ready(events, timeout: timedelta) -> None:
        """Consume handshake events until ``ready``; raise RuntimeError otherwise."""
        start_time = datetime.now()
        # NOTE(review): the deadline is only checked between events, so a
        # blocking ``next`` can overrun it (presumably lomond yields ``poll``
        # events periodically when a poll interval is set -- confirm).
        while datetime.now() - start_time < timeout:
            event = next(events, None)
            if event is None:
                raise RuntimeError('connection closed before ready')
            if event.name in ('disconnected', 'connect_fail'):
                raise RuntimeError(event.reason)
            if event.name == 'ready':
                return
        raise RuntimeError('timeout expired')

    def _receive_loop(self, events, on_close: Callable[[], None], on_message: Callable[[Any], None], on_poll: Callable[[], None]) -> None:
        """Dispatch post-handshake events until the stream ends, then signal close."""
        try:
            for event in events:
                if event.name == 'poll':
                    on_poll()
                elif event.name == 'binary':
                    self._deliver(on_message, event.data)
                elif event.name == 'text':
                    try:
                        data = event.json
                    except ValueError:
                        # Not valid JSON: deliver the raw text instead.
                        data = event.text
                    self._deliver(on_message, data)
        finally:
            # Runs even if on_poll (or lomond) raises, so the object never
            # stays "connected" with a dead reader thread.
            self.web_socket = None
            on_close()

    @staticmethod
    def _deliver(on_message: Callable[[Any], None], data: Any) -> None:
        """Invoke ``on_message``; log (with traceback) and ignore its exceptions."""
        try:
            on_message(data)
        except Exception:
            logger.warning('ignoring on_message exception', exc_info=True)

    def disconnect(self) -> None:
        """Request that the connection be closed, without waiting for it.

        ``web_socket`` is reset and ``on_close`` is called by the background
        thread once lomond's event stream ends.

        Raises:
            AssertionError: If not connected.
        """
        if not self.web_socket:
            raise AssertionError('not connected')
        self.web_socket.close()
        self.thread = None

    def send(self, data: Union[bytes, str]) -> None:
        """Send ``data`` as a text frame if it is a ``str``, otherwise as binary.

        Raises:
            AssertionError: If not connected.
        """
        # Read the attribute once: the reader thread may reset it to None
        # between the check and the send.
        web_socket = self.web_socket
        if not web_socket:
            raise AssertionError('not connected')
        if isinstance(data, str):
            web_socket.send_text(data)
        else:
            web_socket.send_binary(data)

if __name__ == '__main__':
    # Interactive smoke test: <script> <poll_milliseconds> <url>
    import sys
    usage = 'usage: %s <poll_milliseconds> <url>' % sys.argv[0]
    if len(sys.argv) != 3:
        sys.exit(usage)
    try:
        poll_interval = timedelta(milliseconds=int(sys.argv[1]))
    except ValueError:
        sys.exit(usage)
    url = sys.argv[2]
    timeout = timedelta(seconds=30)
    web_socket = WebSocket(url)
    n = 0

    def on_close():
        print('close')

    def on_message(data):
        print('', type(data), data, sep='\n')

    def on_poll():
        global n
        # flush: the '\r' counter has no newline, so it would otherwise sit
        # in the stdout buffer.
        print('\rpoll', n, end=' ', flush=True)
        n += 1

    web_socket.connect(on_close, on_message, on_poll, poll=poll_interval, timeout=timeout)
    input('\t\tsend? ')
    # The server may have closed the connection while we waited for input.
    if web_socket.web_socket:
        web_socket.send('hello')
    else:
        print('not connected')
    input('\t\tdisconnect? ')
    if web_socket.web_socket:
        web_socket.disconnect()
    else:
        print('not connected')
