import e2
import itertools as it
import json
import logging
import os
from datetime import datetime
from numpy import ndarray
from typing import Any, Callable, Optional, Tuple

# Callback that receives every raw detection, whatever its confidence:
# (channel_name, detector_attribute_name, result, confidence).
ResultFn = Callable[[str, str, Any, float], None]
# Frame processor returned by Strategy.open: returns True when some detector
# produced a truthy result with sufficient confidence, otherwise None.
ProcessFn = Callable[[ndarray], Optional[bool]]
# Closer returned by Strategy.open; disconnects the channel's E2 DataSource.
CloseFn = Callable[[], None]

# Game identifier passed to the E2 DataSource connection.
game_id = '511224'
# Minimum detector confidence used when neither the Strategy constructor nor
# the parameters file provides one; less confident results are not applied.
default_minimum_confidence = .9

logger = logging.getLogger(__name__)

class StateData:
    '''Publish the detected game state of one channel to an E2 DataSource.

    Three E2 fields are maintained: the current legend, the current scene
    ('unknown' at connection time, then 'lobby' or 'game'), and a dictionary
    of Champion counts keyed by the names given at construction.
    '''

    def __init__(self, channel_id, e2_access_token, legend_names, is_debugging=False):
        '''Create the state and connect the E2 DataSource.

        Args:
            channel_id: Channel identifier, passed to E2 as a one-element list.
            e2_access_token: Access token for the E2 connection.
            legend_names: Initial keys of the Champion counts (each starts at
                0); Strategy passes the legend names plus 'total'.
            is_debugging: When true, connect to the 'dev' environment (else
                'prod') and pass a callback that logs its arguments at DEBUG.
        '''
        self.__data_source = e2.DataSource()
        self.__e2_access_token = e2_access_token
        self.__legend = None
        self.__scene = 'unknown'  # last scene value sent to E2
        self.__was_champion = False
        self.__champion_counts = {k: 0 for k in legend_names}
        self.__legend_field = 'legend'
        self.__scene_field = 'scene'
        self.__champion_counts_field = 'champion_counts'

        # Connect the E2 DataSource object.
        environment = 'dev' if is_debugging else 'prod'
        e2_data = {
            self.__legend_field: self.__legend,
            self.__scene_field: self.__scene,
            self.__champion_counts_field: self.__champion_counts,
        }

        def on_token_expired(*_):
            logger.error('invalid token')

        def on_debug(*args):
            logger.debug(' '.join(map(str, args)))

        # on_debug is always defined, so this no longer relies on `and`
        # short-circuiting to avoid a NameError. When not debugging, the
        # falsy is_debugging value itself is passed, exactly as before.
        self.__data_source.connect(
            e2_access_token, game_id, environment, on_token_expired, e2_data, [channel_id],
            debug_fn=is_debugging and on_debug,
        )

    def __set_scene(self, scene):
        '''Send *scene* to E2 and remember it for the is_idle getter.'''
        self.__data_source.update_field(self.__scene_field, scene)
        self.__scene = scene

    def close(self):
        '''Disconnect the E2 DataSource. Calling it more than once is harmless.'''
        if self.__data_source:
            self.__data_source.disconnect()
            self.__data_source = None

    @property
    def legend(self):
        '''The legend most recently set, or None when unknown or idle.'''
        return self.__legend

    @legend.setter
    def legend(self, value):
        '''Publish a legend: sets the scene to 'game', sends the legend, and
        clears the Champion flag so the next Champion detection is counted.'''
        self.__set_scene('game')
        self.__data_source.update_field(self.__legend_field, value)
        self.__was_champion = False
        self.__legend = value

    @property
    def is_idle(self):
        '''True while the last scene sent to E2 is 'lobby'.'''
        return self.__scene == 'lobby'

    @is_idle.setter
    def is_idle(self, value):
        '''Publish the idle state: 'lobby' when truthy (also clearing the legend
        and the Champion flag), otherwise 'game'.'''
        self.__set_scene('lobby' if value else 'game')
        if value:
            self.__data_source.update_field(self.__legend_field, None)
            self.__was_champion = False
            self.__legend = None

    @property
    def is_champion(self):
        '''True once a Champion was counted since the last legend/idle update.'''
        return self.__was_champion

    @is_champion.setter
    def is_champion(self, value):
        '''Count a Champion at most once until the flag is cleared.

        Only a truthy value while the flag is clear increments the current
        legend's count (if any) and 'total'; falsy values are ignored.
        '''
        if value and not self.__was_champion:
            self.__was_champion = True
            counts = self.__champion_counts
            # .get() tolerates a key absent from the initial legend_names
            # instead of raising KeyError in the middle of frame processing.
            if self.__legend:
                counts[self.__legend] = counts.get(self.__legend, 0) + 1
            counts['total'] = counts.get('total', 0) + 1
            self.__set_scene('game')
            self.__data_source.update_field(self.__champion_counts_field, counts)

    def add_timestamp(self):
        '''Add a timestamp to prevent E2 disconnection.'''
        # Explicit spelling of '%F %R': those directives are not among
        # Python's documented strftime codes and depend on the platform C library.
        self.__data_source.update_field('now', datetime.now().strftime('%Y-%m-%d %H:%M'))

class Strategy:
    '''Detect legend, idle, and Champion states in frames and publish them to E2.'''

    def __init__(self, Detector, *, e2_access_token: Optional[str]=None, minimum_confidence: Optional[float]=None):
        '''Load the optional parameters file and build the three detectors.

        Args:
            Detector: Detector class providing ``load_parameters(path)``, a
                ``Model(parameters, path)`` attribute, a constructor taking
                ``(parameters, model)``, and instances whose
                ``detect(frame)`` returns ``(result, confidence)``.
            e2_access_token: E2 access token; when falsy, the parameters
                file's ``e2_access_token`` (if any) is used instead.
            minimum_confidence: Confidence a detection needs to update the
                state; when None, taken from the parameters file, then from
                ``default_minimum_confidence``.
        '''
        parameters = self.__read_parameters()
        # A missing key is treated like a missing file instead of raising KeyError.
        self.__e2_access_token = e2_access_token or parameters.get('e2_access_token')
        self.__is_debugging = parameters.get('is_debugging')
        if minimum_confidence is None:
            minimum_confidence = parameters.get('minimum_confidence')
        self.__minimum_confidence = default_minimum_confidence if minimum_confidence is None else minimum_confidence

        # Read the detector parameters files and create the models.
        data_directory_path = os.path.dirname(__file__)
        models = {}
        for model_name in ('legends', 'idle', 'champion'):
            root = os.path.join(data_directory_path, model_name)
            model_parameters = Detector.load_parameters(root + '.json')
            models[model_name] = (model_parameters, Detector.Model(model_parameters, root + '.model'))

        # Get the Legend names for the Champion counts for E2.
        self.__legend_names = list(models['legends'][0]['outputs'].values())
        self.__legend_names.append('total')

        # Create the detectors. The keys are also the StateData attribute
        # names that each detector's result is assigned to.
        self.__detectors = {
            'legend': Detector(*models['legends']),
            'is_idle': Detector(*models['idle']),
            'is_champion': Detector(*models['champion']),
        }
        message = '\n'.join('{}: {}'.format(k, v) for k, v in self.__detectors.items())
        logger.debug(message)

    @staticmethod
    def __read_parameters():
        '''Return the strategy parameters as a dict, or {} if the file is absent.

        The file sits next to this module, with the same base name and a
        ``.json`` extension.
        '''
        parameters_file_path = os.path.splitext(__file__)[0] + '.json'
        try:
            with open(parameters_file_path, encoding='utf-8') as fin:
                return json.load(fin)
        except FileNotFoundError:
            return {}

    def open(self, channel_name: str, channel_id: str, is_debugging: Optional[bool]=None, on_result: Optional[ResultFn]=None) -> Tuple[ProcessFn, CloseFn]:
        '''Start processing one channel.

        Args:
            channel_name: Channel name, used for logging and passed to on_result.
            channel_id: Channel identifier handed to StateData.
            is_debugging: Enables debugging for this channel. Debugging is
                also on when the parameters file sets ``is_debugging``.
            on_result: Optional callback invoked with every raw detection,
                whatever its confidence.

        Returns:
            ``(process_frame, close)``. ``process_frame(frame)`` runs the
            detectors in turn, stops at the first one with a truthy result of
            sufficient confidence, and returns True in that case, else None.
            ``close()`` disconnects the channel's E2 DataSource.
        '''
        # Combine the file-level and per-call flags for this channel only.
        # Storing the result on self (as before) made one channel's debug
        # request stick for every channel opened afterwards.
        is_debugging = self.__is_debugging or is_debugging or False
        # %s, not %d: channel_id is a str, and '%d' would make logging report
        # a formatting error instead of emitting this message.
        logger.info('opening %s (%s)%s', channel_name, channel_id, ' with debugging' if is_debugging else '')

        # Create the state data.
        state_data = StateData(channel_id, self.__e2_access_token, self.__legend_names, is_debugging)

        def make_runner(detector_attribute_name):
            '''Wrap one detector: report the raw result, apply confident results
            to state_data, and return (name, result) when the result is truthy.'''
            detect = self.__detectors[detector_attribute_name].detect

            def run_detector(frame):
                result, confidence = detect(frame)
                if on_result:
                    on_result(channel_name, detector_attribute_name, result, confidence)
                if confidence >= self.__minimum_confidence:
                    setattr(state_data, detector_attribute_name, result)
                    if result:
                        return (detector_attribute_name, result)
                return None
            return run_detector

        runners = {name: make_runner(name) for name in self.__detectors}

        # Detector order: reverse-alphabetical by default ('legend', 'is_idle',
        # 'is_champion'); the detector that last fired is moved to the front.
        detector_names = sorted(self.__detectors, reverse=True)

        def process_frame(frame):
            state_data.add_timestamp()
            # Lazily run detectors so those after the first hit are skipped.
            outcomes = (runners[name](frame) for name in detector_names)
            detector_name, result = next(it.dropwhile(lambda t: not t, outcomes), (None, None))
            if result:
                detector_names.sort(key=lambda s: (s == detector_name, s), reverse=True)
                return True
            detector_names.sort(reverse=True)
            return None

        # Return functions to process a frame and close the state data.
        return process_frame, state_data.close
