from .publisher import Publisher
from classes import logger as log
from classes import device_factory
from classes.branches.directsound import DirectSound
from classes.branches.wasapi import Wasapi
from classes.branches.audio_volume import AudioVolume
from classes.branches.audio_volume_level import AudioVolumeLevel
from classes.error import PublisherDeviceNotFoundError


class AudioDevice(Publisher):
    """Publisher backed by an audio device (WASAPI or DirectSound).

    The source branch is created from the device returned by
    ``device_factory.get_audio_device`` and is linked into the session's
    audio mixer. Volume and level-metering branches are created by the
    ``make_*_branches`` methods.
    """

    def __init__(self, publisher_payload, session):
        """Initialise the publisher; branches stay ``None`` until built.

        :param publisher_payload: dict describing the publisher; the optional
            ``data.mono`` entry forces mono output.
        :param session: owning session object (provides ``audio_mixer``,
            ``bus``, ``session_id`` and ``refresh_session``).
        """
        super(AudioDevice, self).__init__(publisher_payload, session)
        self.volume_branch = None
        self.volume_raw_level_branch = None
        self.volume_post_filter_level_branch = None
        self.force_mono = self._mono_requested(publisher_payload)

    @staticmethod
    def _mono_requested(payload):
        """Return the ``data.mono`` value of *payload*, defaulting to False.

        A missing or ``None`` ``data`` entry is treated as an empty dict.
        """
        return (payload.get("data") or {}).get("mono", False)

    def get_running_sink_branch(self):
        """Return the branch this publisher feeds: the session audio mixer."""
        return self.session.audio_mixer

    def get_running_sink_pad(self):
        """Return the mixer sink pad registered under this publisher's id."""
        return self.session.audio_mixer.get_sink_pad(self.publisher_id)

    def get_values(self, publisher_payload):
        """Extend the base values with volume, visibility and mono flag.

        :raises KeyError: if ``volume`` or ``visible`` is missing from
            *publisher_payload*.
        """
        values = super(AudioDevice, self).get_values(publisher_payload)
        values["volume"] = publisher_payload["volume"]
        values["visible"] = publisher_payload["visible"]
        values["force_mono"] = self._mono_requested(publisher_payload)
        return values

    def link_to_running_sink(self, branch):
        """Add *branch* as an input of the session audio mixer."""
        log.debug("linking %s to %s %s", branch, self.session.audio_mixer, self.publisher_id)
        self.session.audio_mixer.add_input_branch(branch, self.publisher_id)

    def update_running(self):
        """Apply the current volume/visibility to the volume branch, if any."""
        super(AudioDevice, self).update_running()
        if self.volume_branch:
            self.volume_branch.update(self.volume, self.visible)

    def can_update(self, publisher_payload, **kwargs):
        """Tell whether *publisher_payload* can be applied in place.

        Returns False when the device id or the mono flag changed, when no
        audio source has been built yet, or when the audio source reports
        that it needs an update; otherwise defers to the base class result.
        """
        # The base implementation is always invoked first, as before.
        super_can_update = super(AudioDevice, self).can_update(publisher_payload, **kwargs)
        if self.publisher_payload["device_id"] != publisher_payload["device_id"]:
            return False
        if self._mono_requested(self.publisher_payload) != self._mono_requested(publisher_payload):
            return False
        # audio_src is only assigned by make_src_branches(); without a source
        # there is nothing to update in place.
        audio_src = getattr(self, "audio_src", None)
        if audio_src is None or audio_src.need_update():
            return False
        return super_can_update

    def can_pause(self):
        """Audio device publishers never report themselves as pausable."""
        return False

    def make_src_branches(self):
        """Create the source branch matching the device type.

        :returns: single-element list holding the new audio source branch.
        :raises PublisherDeviceNotFoundError: if the device cannot be found
            or its type is neither ``'wasapi'`` nor ``'directsound'``.
        """
        device = device_factory.get_audio_device(self.device_id)
        if not device:
            raise PublisherDeviceNotFoundError(
                self.publisher_payload['widget_id'],
                "Failed to find audio device. device_id: {}".format(self.device_id))

        # Channels only works for type WASAPI (mic): DirectSound returns a
        # giant list of caps and we don't know which one is being used, but
        # for WASAPI we only get one capability back from the audio device's
        # GetMixFormat.
        channels = 0
        samplerate = 0
        if device.caps:
            first_cap = device.caps[0]
            # A missing key falls back to the same "unknown" value (0) that
            # is used when the device reports no caps at all.
            channels = first_cap.get('channels') or 0
            samplerate = first_cap.get('sample_rate') or 0

        is_mono = bool(channels == 1 or self.force_mono)

        log.debug('audio_device: %s, channels: %d, is_mono: %d', device, channels, is_mono)

        self.audio_src = None
        if device.type == 'wasapi':
            self.audio_src = Wasapi(
                self.session.bus, self.session.session_id, device.display_name,
                device.device_path, channels, samplerate, is_mono,
                self.session.refresh_session)
        elif device.type == 'directsound':
            self.audio_src = DirectSound(
                self.session.bus, self.session.session_id, device.display_name,
                device.device_path, self.widget_id, is_mono)
        else:
            # Previously this fell through and returned [None].
            raise PublisherDeviceNotFoundError(
                self.publisher_payload['widget_id'],
                "Unsupported audio device type: {}. device_id: {}".format(device.type, self.device_id))
        return [self.audio_src]

    def make_running_branches(self):
        """Create the volume branch and the raw-level metering branch.

        :returns: ``[volume_branch, volume_raw_level_branch]``.
        """
        self.volume_branch = AudioVolume(self.session.bus, self.session.session_id, self.widget_id)
        self.volume_branch.update(self.volume, self.visible)

        self.volume_raw_level_branch = AudioVolumeLevel(
            self.session.bus, self.session.session_id, self.widget_id, 'audiolevel_raw')
        return [self.volume_branch, self.volume_raw_level_branch]

    def make_post_filter_branches(self):
        """Create the post-filter level metering branch.

        :returns: single-element list holding the new level branch.
        """
        self.volume_post_filter_level_branch = AudioVolumeLevel(
            self.session.bus, self.session.session_id, self.widget_id, 'audiolevel_postprocessed')
        return [self.volume_post_filter_level_branch]

