import itertools as it
import json
import logging
import m3u8
import operator as op
import random
import requests
import threading
from datetime import datetime, timedelta, timezone
from time import sleep
from typing import Callable, List, Optional, Tuple
from urllib.parse import urljoin
try:
    from .utilities import desired_resolution
except ImportError:
    from utilities import desired_resolution

# Signature of the callback invoked once a stream's worker thread finishes.
CompletionFn = Callable[[], None]
logger = logging.getLogger(__name__)
# NOTE(review): not referenced elsewhere in this file; it may be imported by
# other modules — confirm before removing.
video_api_client_id = 'kimne78kx3ncx6brgo4mv6wki5h1ko'

# https://johannesbader.ch/2014/01/find-video-url-of-twitch-tv-live-streams-or-past-broadcasts/
# The token request carries the client-id header and returns the sig/token, so
# it uses HTTPS, as the VOD token request to the same api.twitch.tv host does.
LIVE_TOKEN_API_URL_FORMAT = 'https://api.twitch.tv/api/channels/{channel_login}/access_token'
# NOTE(review): usher.twitch.tv is still requested over plain HTTP; confirm
# whether an HTTPS usher endpoint (the VOD format uses usher.ttvnw.net) works.
LIVE_USHER_API_URL_FORMAT = ('http://usher.twitch.tv/api/channel/hls/{channel_login}.m3u8' +
                             '?allow_source=true' +
                             '&p={random_value}' +
                             '&player=twitchweb' +
                             '&sig={sig}' +
                             '&token={token}' +
                             '&type=any')
VOD_TOKEN_API_URL_FORMAT = ('https://api.twitch.tv/api/vods/{vod_id}/access_token' +
                            '?need_https=true' +
                            '&platform=web' +
                            '&player_backend=mediaplayer' +
                            '&player_type=site')
VOD_USHER_API_URL_FORMAT = ('https://usher.ttvnw.net/vod/{vod_id}.m3u8' +
                            '?allow_source=true' +
                            '&p={random_value}' +
                            '&player_backend=mediaplayer' +
                            '&playlist_include_framerate=true' +
                            '&reassignments_supported=true' +
                            '&sig={sig}' +
                            '&supported_codecs=vp09,avc1' +
                            '&token={token}')
# Number of attempts made to fetch a media playlist before giving up.
maximum_try_count = 3

class Segment:
    """One media segment of a playlist: its timestamp, duration and URL."""

    def __init__(self, timestamp: Optional[datetime], duration: timedelta, url: str):
        # timestamp may be None when the playlist gives no program date-time.
        self.__duration = duration
        self.__timestamp = timestamp
        self.__url = url

    def __str__(self) -> str:
        # Render the duration as plain seconds so the "s" suffix is accurate;
        # str(timedelta) would produce e.g. "0:00:02s".
        seconds = self.__duration.total_seconds()
        return f'{self.__timestamp} ({seconds:g}s)\n{self.__url}'

    @property
    def duration(self) -> timedelta:
        """Playback length of the segment."""
        return self.__duration

    @property
    def timestamp(self) -> Optional[datetime]:
        """Wall-clock time of the segment, or None if unknown."""
        return self.__timestamp

    @property
    def url(self) -> str:
        """Absolute URL from which the segment can be downloaded."""
        return self.__url

# Callback that receives each newly observed Segment; returning a truthy value
# asks Stream.process to end the stream.
ObserverFn = Callable[[Segment], bool]

class Stream:
    """One variant of a master playlist, polled for new media segments.

    Segments can be consumed synchronously by calling process() repeatedly, or
    on a background daemon thread via start()/stop().
    """

    def __init__(self, playlist: m3u8.model.Playlist):
        self.__bandwidth = playlist.stream_info.bandwidth # bits per second
        # Start far in the past so every segment of the first fetch counts as new.
        self.__last_segment_timestamp = now() - timedelta(days=99999)
        # Held by the owner while the worker should keep running; stop()
        # releases it, which lets the worker's acquire() succeed and exit.
        self.__lock = threading.Lock()
        self.__quality = playlist.media[0].name
        self.__resolution = playlist.stream_info.resolution # (width, height) tuple
        self.__session = requests.Session()
        self.__thread = None
        self.__url = playlist.absolute_uri

    def __str__(self) -> str:
        text = f'{self.__bandwidth / 1024:.0f} Kb/s ({self.__quality}), resolution={self.__resolution if self.__resolution else "?"}'
        return f'{text}\n{len(text) * "-"}\n{self.__url}'

    @property
    def resolution(self) -> Optional[Tuple[int, int]]:
        """(width, height) of this variant; may be None if the playlist omits it."""
        return self.__resolution

    def start(self, observer_fn: ObserverFn, completion_fn: CompletionFn) -> None:
        """Process the stream on a daemon thread until it ends or stop() is called.

        observer_fn receives every new segment. completion_fn is called exactly
        once when the worker exits. That happens when the stream ends, the
        observer asks to end it, stop() is called, or processing raises (the
        exception is logged).

        Raises RuntimeError if the stream is already started.
        """
        if self.__thread:
            raise RuntimeError('stream already started')

        def fn():
            try:
                delay = 0
                # Waiting on the lock doubles as an interruptible sleep: the
                # acquire only succeeds once stop() releases the lock.
                while delay is not None and not self.__lock.acquire(timeout=delay):
                    delay = self.process(observer_fn) if self.__thread else None
                if delay is not None:
                    # The lock was acquired because stop() released it; hand it back.
                    self.__lock.release()
            except Exception:
                # Log instead of letting the thread die; completion_fn still runs.
                logger.exception('processing of %s failed', self.__url)
            finally:
                completion_fn()

        self.__lock.acquire()
        self.__thread = threading.Thread(target=fn, daemon=True)
        self.__thread.start()

    def __compose_url(self, segment: m3u8.model.Segment) -> str:
        """Resolve a segment URI against this stream's playlist URL."""
        return urljoin(self.__url, segment.uri) if segment.base_uri is None else segment.absolute_uri

    def __fetch_playlist(self) -> m3u8.model.M3U8:
        """Download and parse the media playlist, retrying with a growing back-off.

        Re-raises the last exception after maximum_try_count failed attempts.
        """
        for try_count in range(1, maximum_try_count + 1):
            try:
                with self.__session.get(self.__url) as r:
                    return m3u8.loads(r.text)
            except Exception as ex:
                if try_count == maximum_try_count:
                    raise
                logger.warning('exception "%s" occurred; retrying after %dms', type(ex).__name__, 200 * try_count)
                sleep(0.2 * try_count)

    def process(self, observer_fn: ObserverFn) -> Optional[float]:
        """Fetch the media playlist once and report each new segment to observer_fn.

        Returns the number of seconds to wait before calling again. Returns None
        when the stream is over: the playlist has no segments, it is an end-list,
        or observer_fn returned a truthy value.
        """
        start_time = now()
        m = self.__fetch_playlist()
        segments = [Segment(s.current_program_date_time, timedelta(seconds=s.duration), self.__compose_url(s)) for s in m.segments]
        if not segments:
            logger.debug('no segments for %s', self.__url)
            return None

        # If this reload yields nothing new, wait half the target duration
        # (RFC 8216 section 6.3.4) instead of re-polling the server immediately.
        target_duration = m.target_duration or segments[-1].duration.total_seconds()
        delay = target_duration / 2

        # Fire an event for each new segment. The generator re-reads the
        # watermark on every step, so it sees updates made inside the loop.
        observed = None
        g = (s for s in segments if s.timestamp is None or s.timestamp > self.__last_segment_timestamp)
        for segment in g:
            if observer_fn(segment):
                # The observer is requesting to end the stream.
                return None
            if segment.timestamp is not None:
                # Never store None: the next `>` comparison would raise TypeError.
                self.__last_segment_timestamp = segment.timestamp
            observed = segment
        if observed is not None:
            # Wait out the rest of the newest segment's duration.
            elapsed = now() - start_time
            delay = max(0.0, (observed.duration - elapsed).total_seconds())

        # Return the delay before invoking this method again if this is not the
        # last list of segments.
        if m.is_endlist:
            logger.debug('is last segment list for %s', self.__url)
            return None
        return delay

    def stop(self) -> None:
        """Close the stream and, if start() launched a worker, signal it to exit."""
        self.close()
        if self.__thread:
            # Releasing the lock lets the worker's acquire() succeed and end its loop.
            self.__lock.release()
            self.__thread = None

    def close(self) -> None:
        """Release the pooled HTTP connections of this stream's session.

        Subclasses overriding this should call super().close().
        """
        self.__session.close()

def fetch_live_stream(client_id: str, channel_login: str, vertical_resolution: int, Stream=Stream, wants_exact=False) -> Tuple[str, Optional[Stream]]:
    """Return the channel ID and the lowest-height live stream that is acceptable.

    A stream is acceptable when its height equals vertical_resolution
    (wants_exact) or is at least vertical_resolution (otherwise). Streams
    without a resolution are skipped. The stream is None if none is acceptable.
    """
    channel_id, streams = fetch_live_streams(client_id, channel_login, Stream=Stream)
    is_acceptable = op.eq if wants_exact else op.ge
    # Skip variants without a resolution; indexing None would raise TypeError.
    g = (s for s in streams if s.resolution and is_acceptable(s.resolution[1], vertical_resolution))
    return channel_id, min(g, key=lambda s: s.resolution[1], default=None)

def fetch_live_streams(client_id: str, channel_login: str, Stream=Stream) -> Tuple[str, List[Stream]]:
    """Return the channel ID and every stream variant of a live channel.

    Requests an access token (sig + token) for channel_login, then fetches the
    usher master playlist and wraps each variant playlist in Stream (or the
    supplied subclass).

    Raises ValueError if the token response contains an 'error' field.
    """
    # Get the signature and token for the channel.
    url = LIVE_TOKEN_API_URL_FORMAT.format(channel_login=channel_login)
    headers = {'client-id': client_id}
    with requests.get(url, headers=headers) as r:
        data = r.json()
    error = data.get('error')
    if error is not None:
        raise ValueError(f'{error}:  {data.get("message", "")}')
    sig = data['sig']
    token = data['token']

    # The token is itself a JSON document that carries the channel ID.
    channel_id = json.loads(token)['channel_id']

    # Get the streams for the channel. randrange needs int bounds: float
    # arguments such as 1e5 raise TypeError on Python 3.12+.
    random_value = random.randrange(10**5, 10**7)
    url = LIVE_USHER_API_URL_FORMAT.format(channel_login=channel_login, sig=sig, token=token, random_value=random_value)
    with requests.get(url) as r:
        m = m3u8.loads(r.text)
    return channel_id, [Stream(p) for p in m.playlists]

def fetch_vod(client_id: str, vod_id: str, vertical_resolution: int, Stream=Stream, wants_exact=False) -> Optional[Stream]:
    """Return the lowest-height VOD stream that is acceptable, or None.

    A stream is acceptable when its height equals vertical_resolution
    (wants_exact) or is at least vertical_resolution (the default, as before).
    Streams without a resolution are skipped.
    """
    streams = fetch_vods(client_id, vod_id, Stream=Stream)
    is_acceptable = op.eq if wants_exact else op.ge
    # Skip variants without a resolution; indexing None would raise TypeError.
    g = (s for s in streams if s.resolution and is_acceptable(s.resolution[1], vertical_resolution))
    return min(g, key=lambda s: s.resolution[1], default=None)

def fetch_vods(client_id: str, vod_id: str, Stream=Stream) -> List[Stream]:
    """Return every stream variant of a VOD.

    Requests an access token (sig + token) for vod_id, then fetches the usher
    master playlist and wraps each variant playlist in Stream (or the supplied
    subclass).

    Raises ValueError if the token response contains an 'error' field.
    """
    # Get the signature and token for the VOD.
    url = VOD_TOKEN_API_URL_FORMAT.format(vod_id=vod_id)
    headers = {'client-id': client_id}
    with requests.get(url, headers=headers) as r:
        data = r.json()
    error = data.get('error')
    if error is not None:
        raise ValueError(f'{error}:  {data.get("message", "")}')
    sig = data['sig']
    token = data['token']

    # Get the streams for the VOD. randrange needs int bounds: float arguments
    # such as 1e5 raise TypeError on Python 3.12+.
    random_value = random.randrange(10**5, 10**7)
    url = VOD_USHER_API_URL_FORMAT.format(vod_id=vod_id, random_value=random_value, sig=sig, token=token)
    with requests.get(url, headers=headers) as r:
        m = m3u8.loads(r.text)
    return [Stream(p) for p in m.playlists]

def now() -> datetime:
    """Current wall-clock time as a timezone-aware datetime at UTC offset zero."""
    utc = timezone.utc
    return datetime.now(tz=utc)

if __name__ == '__main__':
    # Command-line entry point: list a channel's live streams, then print the
    # segments of the first stream whose height does not exceed desired_resolution.
    import argparse
    # ArgumentParser's first positional parameter is `prog`, so pass the text
    # explicitly as the description.
    parser = argparse.ArgumentParser(description='Get stream parameters of a twitch channel.')
    # Boolean flag: "-t" takes no value.
    parser.add_argument('-t', '--thread', action='store_true', help='run stream on a different thread')
    parser.add_argument('client_id')
    parser.add_argument('channel_login')
    args = parser.parse_args()
    channel_id, streams = fetch_live_streams(args.client_id, args.channel_login)
    if streams:
        print(f'Streams for "{args.channel_login}" ({channel_id}) ({len(streams)} sorted by descending quality):')
        print(*streams, sep='\n\n')
        print()
        # Ignore variants without a resolution, then take the first stream that
        # is not taller than desired_resolution (assumes descending order).
        video_streams = [s for s in streams if s.resolution]
        stream = next(it.dropwhile(lambda s: s.resolution[1] > desired_resolution, video_streams), None)
        if stream is None:
            raise SystemExit(f'No stream with a vertical resolution of at most {desired_resolution}')
        s = 'Details of the {}x{} resolution stream'.format(*stream.resolution)
        print(s)
        print('-' * len(s))
        if args.thread:
            print('Press Enter to exit.')
            stream.start(print, lambda: print('stream complete'))
            input()
            stream.stop()
        else:
            print('Press ^C to exit.')
            delay = 0
            try:
                while delay is not None:
                    if delay > 0:
                        sleep(delay)
                    # process() requires an observer; print each new segment.
                    delay = stream.process(print)
            except KeyboardInterrupt:
                print('stopping stream')
            stream.close()
            print('stream complete')
    else:
        print(f'No streams for "{args.channel_login}" ({channel_id})')
