import json
import logging
from datetime import datetime, timedelta
from itertools import dropwhile
from threading import Condition
from time import sleep
from typing import Any, Callable, Iterable, List, Optional
try:
    from .messaging import MessageQueue, MessageType, QueueMessage, is_too_large
    from .utilities import create_connect_message, generate_session_id, get_segment, validate_data, validate_path
    from .web_socket import WebSocket
except ImportError:
    from messaging import MessageQueue, MessageType, QueueMessage, is_too_large
    from utilities import create_connect_message, generate_session_id, get_segment, validate_data, validate_path
    from web_socket import WebSocket

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Callback handed to the application; it is called with the new token string.
TokenRefreshFn = Callable[[str], None]
# Application callback invoked when a token is needed.  It receives a
# TokenRefreshFn; a truthy return value makes the data source wait until that
# callback has delivered a token.
TokenExpiredFn = Callable[[TokenRefreshFn], bool]

# Upper bound on the length of the JSON-encoded "Connect" message.
maximum_connect_message_size = 100_000
# Delay between sends; also passed to the WebSocket as its poll interval.
send_delay = timedelta(seconds=1)
# WebSocket endpoint the data source connects to.
url = 'wss://metadata.twitch.tv/api/ingest'

class DataSource:
    def __init__(self):
        self.__current_data = self.__web_socket = self.__session_id = self.__token = self.__create_connect_message_fn = None
        self.__connect_fn = self.__open_fn = self.__on_token_expired = self.__last_send_time = None
        self.__queue = MessageQueue()

    def connect(self, *, token: str, game_id: str, environment: str, on_token_expired: TokenExpiredFn, initial_data: dict, broadcaster_ids: Optional[Iterable[str]]=None, session_id: Optional[str]=None, timeout=timedelta(seconds=10), is_debug: Optional[bool]=None) -> str:
        if self.__current_data:
            raise AssertionError('already connected')

        # Validate the configuration.
        if not callable(on_token_expired):
            raise TypeError('on_token_expired is not callable')
        data = validate_data(initial_data)
        self.__on_token_expired = on_token_expired
        self.__session_id = generate_session_id() if session_id is None else str(session_id)
        self.__token = str(token)

        # Compose the connect function.
        broadcaster_ids = broadcaster_ids and [str(v) for v in broadcaster_ids]
        game_id = str(game_id)
        environment = str(environment)
        is_debug = is_debug or False
        def create_connect_message_fn(data: dict) -> dict:
            return create_connect_message(self.__session_id, self.__token, broadcaster_ids, game_id, environment, is_debug, data)
        self.__create_connect_message_fn = create_connect_message_fn
        connect_message_size = len(json.dumps(self.__create_connect_message_fn(data)))
        if connect_message_size > maximum_connect_message_size:
            raise RuntimeError('initial data object is too large')
        def connect_fn(data: dict) -> Optional[Exception]:
            try:
                # Send the "Connect" message.
                self.__send(self.__create_connect_message_fn(data))
            except Exception as ex:
                return ex
        self.__connect_fn = connect_fn

        # Compose the open function.
        def open_fn() -> WebSocket:
            def on_close() -> None:
                if self.__current_data is not None:
                    # Reconnect.
                    self.__queue.enqueue(QueueMessage.make_reconnect())
                    self.__send_message()
            def on_message(data: Any) -> None:
                try:
                    if type(data) is dict and len(data.keys()) == 1:
                        key = next(iter(data.keys()))
                        if key == 'connected':
                            if data['connected']:
                                return
                        elif key == 'error':
                            error = data['error']
                            if type(error) is dict:
                                code = error['code']
                                if code == 'invalid_connect_token':
                                    # The server has rejected the authorization token.
                                    self.__queue.enqueue(QueueMessage.make_reauthorize(clearing_token=True))
                                    return
                                elif code == 'connection_not_authed':
                                    # Ignore this since it means we're in the middle of getting connected.
                                    return
                                elif code == 'waiting_on_refresh_message':
                                    # Ignore this since it means we're in the middle of refreshing.
                                    return
                        elif key == 'reconnect':
                            reconnect_delay = data['reconnect']
                            if type(reconnect_delay) in [int, float]:
                                # The server requested a reconnection.  Close the socket.  The reconnection
                                # process will open a new one.
                                self.__queue.replace_with(QueueMessage.make_reconnect(timedelta(milliseconds=reconnect_delay)))
                                self.__close()
                                return
                    logger.error('[DataSource.on_message] unexpected response from server:  "%s"', data)
                except Exception:
                    logger.exception('[DataSource.on_message]')
            # Create the WebSocket and await a connection.
            web_socket = WebSocket(url)
            web_socket.connect(on_close, on_message, self.__send_message, poll=send_delay, timeout=timeout)
            return web_socket

        # Open the WebSocket.
        self.__web_socket = open_fn()
        self.__open_fn = open_fn

        # Enqueue a "Reauthorize" message without clearing the token.
        self.__queue.enqueue(QueueMessage.make_reauthorize(clearing_token=False))

        # Set the current data.
        self.__current_data = data

        return self.__session_id

    def __close(self) -> None:
        self.__web_socket.disconnect()

    def disconnect(self, is_aborting: bool=False) -> None:
        if not is_aborting and self.__last_send_time:
            while self.__queue:
                # Send remaining data after waiting based on last_send_time.
                next_send_time = self.__last_send_time + send_delay
                final_delay = (next_send_time - datetime.now()).total_seconds()
                if final_delay > 0:
                    sleep(final_delay)
                self.__send_message()
        self.__current_data = None
        self.__close()

    def __reauthorize(self) -> bool:
        if not self.__token:
            cv = Condition()
            def fn(token: str) -> None:
                self.__token = token or 'no-token'
                with cv:
                    cv.notify()
            if self.__on_token_expired(fn):
                with cv:
                    cv.wait_for(lambda: self.__token)
        if self.__token:
            error = self.__connect_fn(self.__current_data)
            if not error:
                return True
            logger.error('[DataSource.reauthorize] connection error "%s"', error)
        self.disconnect(True)
        return False

    def __send(self, message) -> None:
        self.__web_socket.send(json.dumps(message))
        self.__last_send_time = datetime.now()

    def __validate_connection(self) -> None:
        # Ensure connect has already been invoked.
        if self.__current_data is None:
            raise AssertionError('connection not established')

    def __send_message(self) -> None:
        try:
            unique_message = self.__queue.find_unique_message()
            if unique_message:
                # There is a unique message.  Clear the queue and process the message.
                self.__queue.clear()
                if unique_message.type == MessageType.Reauthorize:
                    if unique_message.value:
                        self.__token = ''
                    self.__reauthorize()
                elif unique_message.type == MessageType.Reconnect:
                    # Await the requested delay.
                    reconnect_delay = (unique_message.value - datetime.now()).total_seconds()
                    if reconnect_delay > 0:
                        sleep(reconnect_delay)

                    if self.__current_data:
                        # Open a new connection.
                        self.__web_socket = self.__open_fn()

                        # Perform the "Reauthorize" action above.
                        if self.__current_data:
                            self.__reauthorize()
                elif unique_message.type == MessageType.Refresh:
                    self.__send({ 'refresh': { 'data': self.__current_data } })
                else:
                    raise RuntimeError(f'unexpected message type "{unique_message.type}"')
            else:
                # Trim the queue based on repeated modifications.
                trimmed_queue = []
                while self.__queue:
                    message = self.__queue.dequeue()
                    if message.type == MessageType.Append:
                        current_message = next(dropwhile(lambda m: m.type != message.type or m.path != message.path, trimmed_queue), None)
                        if current_message:
                            current_message.value.extend(message.value)
                        else:
                            trimmed_queue.append(message)
                    elif message.type in [MessageType.Remove, MessageType.Update]:
                        trimmed_queue[:] = (m for m in trimmed_queue if m.path != message.path)
                        trimmed_queue.append(message)
                    else:
                        raise RuntimeError(f'unexpected message type "{message.type}"')
                self.__queue.replace_with(*trimmed_queue)

                # Take messages off of the queue until reaching the maximum payload size.
                deltas = []
                while self.__queue:
                    message = self.__queue.peek()
                    delta = message.as_delta()
                    if is_too_large([*deltas, delta]):
                        break
                    deltas.append(delta)
                    self.__queue.dequeue()
                if deltas:
                    self.__send({ 'delta': deltas })
        except Exception as ex:
            logger.exception('[DataSource.send_message]')

    def __check_connect_message_size(self, data: dict) -> None:
        connect_message = self.__create_connect_message_fn(data)
        if len(json.dumps(connect_message)) > maximum_connect_message_size:
            raise RuntimeError('data object is too large')

    def append_to_list_field(self, path: str, values: Iterable) -> None:
        # Validate the connection, path, and values.
        self.__validate_connection()
        path = validate_path(path)
        values = json.loads(json.dumps(values))
        if not isinstance(values, List):
            raise AssertionError('values is not a list')

        # Update the current data.  Enqueue the "Append" message only if the
        # length of the field will increase.
        if len(values):
            new_data = json.loads(json.dumps(self.__current_data))
            segment = get_segment(new_data, path)
            if not segment or segment['field'] not in segment['parent']:
                raise AssertionError(f'"{path}" does not specify a known field')
            field = segment['parent'][segment['field']]
            if type(field) is not list:
                raise AssertionError(f'"{path}" does not specify a list field')
            field.extend(values)
            self.__check_connect_message_size(new_data)
            self.__queue.enqueue(QueueMessage(MessageType.Append, path, values))
            self.__current_data = new_data
        else:
            logger.warning('[DataSource.append_to_list_field] values is empty; ignoring path "%s"', path)

    def remove_field(self, path: str) -> None:
        # Validate the connection and path.
        self.__validate_connection()
        path = validate_path(path)
        if path.endswith(']'):
            raise AssertionError(f'"{path}" does not specify a field')
        segment = get_segment(self.__current_data, path)
        if not segment or segment['field'] not in segment['parent']:
            logger.warning('[DataSource.remove_field] ignoring removal of an unknown field "%s"', path)
            return

        # Enqueue the "Remove Field" message.
        self.__queue.enqueue(QueueMessage(MessageType.Remove, path))

        # Update the current data.
        parent = segment['parent']
        field = segment['field']
        del parent[field]

    def update_field(self, path: str, value: Any) -> None:
        # Validate the connection, path, and value.
        self.__validate_connection()
        path = validate_path(path)
        value = json.loads(json.dumps(value))

        # If updating the _metadata field, ensure it is valid.
        if path == '_metadata':
            validate_data({ path: value })

        # Verify the path results in a value.
        new_data = json.loads(json.dumps(self.__current_data))
        segment = get_segment(new_data, path)
        if not segment:
            raise AssertionError(f'"{path}" does not specify a known field')
        elif type(segment['parent']) is list and segment['field'] >= len(segment['parent']):
            raise AssertionError(f'"{path}" is out of bounds')
        if segment['field'] in segment['parent'] and segment['parent'][segment['field']] == value:
            # Skip update if the value did not change.
            return
        segment['parent'][segment['field']] = value
        self.__check_connect_message_size(new_data)

        # Enqueue the "Update" message.
        self.__queue.enqueue(QueueMessage(MessageType.Update, path, value))

        # Update the current data.
        self.__current_data = new_data

if __name__ == '__main__':
    # Script mode: run the callables of the "tests" module.
    # Arguments: <send delay in milliseconds> <WebSocket URL>
    import sys
    import tests

    # Rebind the module-level settings from the command line.
    send_delay = timedelta(milliseconds=int(sys.argv[1]))
    url = sys.argv[2]
    tests._configure(send_delay, url)

    # A test is any callable attribute whose name contains an underscore but
    # does not start with one.
    for name in dir(tests):
        if '_' not in name or name.startswith('_'):
            continue
        candidate = getattr(tests, name)
        if not callable(candidate):
            continue
        print(name.replace('_', ' '))
        candidate()
    print('done')
