import threading
from classes import logger as log
from collections import deque


# ---------------------------------------------------------------------------
# Timing
# ---------------------------------------------------------------------------
# Period, in seconds, of the timer that drives BitRateControl.run_latency_bwe.
BWE_TIME_INTERVAL = 0.5
# Not referenced by the BitRateControl class.
BWE_INTERVAL = 1.0

# The next three values are tick counts (a duration in seconds divided by the
# timer period); because of the true division they are floats.
# Ticks skipped by run_latency_bwe before it starts recording queue samples.
BWE_IGNORE_START_SECONDS = 5 / BWE_TIME_INTERVAL
# increase_bitrate calls required before a regular (fast) step up is applied.
BWE_UP_WAIT = 3 / BWE_TIME_INTERVAL
# Same as BWE_UP_WAIT, used when BitRateControl.is_slow_up() is true.
BWE_UP_WAIT_SLOW = 7 / BWE_TIME_INTERVAL

# ---------------------------------------------------------------------------
# Step sizes
# ---------------------------------------------------------------------------
# Multiplier for a slow step up; also the lower limit of the fast factor.
BWE_UP_SLOW_FACTOR = 1.0125
# Multiplier for a slow step up once BitRateControl.is_creep_up() is true.
BWE_UP_CREEP_FACTOR = 1.001
# A fast step up covers 1/BWE_FAST_UP_STEPS of the distance between the
# target bitrate and the estimated upper bound.
BWE_FAST_UP_STEPS = 3
# Divisor applied to the squared excess queue delay (ms) in decrease_bitrate.
BWE_X_FACTOR = 4000000
# Not referenced by the BitRateControl class.
BWE_UP_FACTOR = 1.05
BWE_DOWN_FACTOR = 0.5

# ---------------------------------------------------------------------------
# Queue-level history
# ---------------------------------------------------------------------------
# Number of queue-level samples kept by BitRateControl.
BWE_HIST_LEN = 6
# A single sample above this value (ms) counts as "bad" in decide_up_down.
BWE_WATERMARK_MS = 600
# The sum of all kept samples (ms) must exceed this before stepping down.
BWE_CUM_WATERMARK_MS = 1500


# Results of BitRateControl.decide_up_down.
DOWN, LEVEL, UP = -1, 0, 1
# Not referenced by the BitRateControl class.
SHRINKING, GROWING = -1, 1


class BitRateControl():
    """Adjust ``target_bitrate`` from the fill level of ``rtmp_queue``.

    A one-shot ``threading.Timer`` that re-arms itself calls
    ``run_latency_bwe`` every ``BWE_TIME_INTERVAL`` seconds.  Each tick
    records the queue level, keeps the last ``BWE_HIST_LEN`` samples and
    steps the target bitrate up or down.

    The class reads ``target_bitrate`` and ``min_bitrate`` and calls
    ``set_bitrate()`` but defines none of them; presumably they come from
    the class this one is combined with (``__init__`` forwards its
    arguments to ``super().__init__``).
    """

    # Count of increase_bitrate() calls since the last bitrate change.
    # Defined at class level so the first increase_bitrate() call works even
    # when no decrease_bitrate() has run yet (the counter was previously
    # created only there, making earlier increases fail with AttributeError).
    bwe_up_monitor_interval_cnt = 0

    def __init__(self, *args, **kwargs):
        """Initialise the estimator state and start the periodic timer."""
        super(BitRateControl, self).__init__(*args, **kwargs)
        self.bwe_start = 0
        self.reset_bwe()
        # Assign every attribute the timer callback reads *before* arming
        # the timer, so the callback never sees a half-built object.
        self.rtmp_queue = None
        self.bitrate = None
        self.target_fps = None
        self._schedule_bwe_timer()

    def _schedule_bwe_timer(self):
        """Arm a one-shot timer that calls run_latency_bwe after BWE_TIME_INTERVAL seconds."""
        self.bandwidthcontroller_timer = threading.Timer(BWE_TIME_INTERVAL, self.run_latency_bwe)
        self.bandwidthcontroller_timer.start()

    def reset_bwe(self):
        """Clear the direction marker, bound samples, queue history and start counter."""
        self.bwe_change = ""
        self.bwe_data = []
        self._buf_hist = deque([0 for i in range(BWE_HIST_LEN)])
        self.bwe_last_down_bitrate = 0
        self.bwe_start = 0

    def get_slow_factor(self, recalc=False):
        """Return the multiplier for a slow step up.

        ``recalc`` is accepted for interface compatibility and is not used.
        """
        if self.is_creep_up():
            return BWE_UP_CREEP_FACTOR
        return BWE_UP_SLOW_FACTOR

    def get_fast_factor(self, recalc=False):
        """Return the multiplier for a fast step up and store it in ``up_factor_fast``.

        The factor moves the target ``1/BWE_FAST_UP_STEPS`` of the way
        towards the estimated upper bound and is never smaller than
        ``BWE_UP_SLOW_FACTOR``.  ``recalc`` is accepted for interface
        compatibility and is not used.
        """
        bound = self.get_estimated_upper_bound()
        t = self.target_bitrate
        factor = (t + (bound - t) / BWE_FAST_UP_STEPS) / t
        self.up_factor_fast = max(BWE_UP_SLOW_FACTOR, factor)
        if self.up_factor_fast > 1.2:
            log.debug("get_fast_factor: %s bound: %s, target: %s",
                      self.up_factor_fast, bound, self.target_bitrate)
        return self.up_factor_fast

    def run_latency_bwe(self):
        """Timer callback: sample the queue, adjust the bitrate, re-arm the timer."""
        if self._buf_hist is None:
            # bwe_destroy() has run; stop the timer chain instead of
            # re-arming and failing on the missing history every tick.
            return
        try:
            if not self.rtmp_queue:
                # NOTE(review): this path does not re-arm the timer, so the
                # periodic loop ends here. Kept as in the original; confirm
                # that this is the intended way to stop it.
                return
            # Presumably nanoseconds converted to milliseconds - confirm
            # against the element providing "current-level-time".
            self.rtmp_queue_ms = self.rtmp_queue.get_property("current-level-time") // 1000000
            if self.bwe_start < BWE_IGNORE_START_SECONDS:
                # Still inside the start-up window: only count the tick.
                self.bwe_start += 1
            else:
                # Newest sample goes to the left, the oldest one is dropped.
                self._buf_hist.appendleft(self.rtmp_queue_ms)
                self._buf_hist.pop()

                d = self.decide_up_down()
                if d == UP:
                    self.increase_bitrate()
                elif d == DOWN:
                    self.decrease_bitrate()
        except Exception:
            # Boundary of the timer thread: log and keep the loop alive.
            log.error("Unhandled exception in run_latency_bwe", exc_info=True)
        self._schedule_bwe_timer()

    def is_slow_up(self):
        """Return True when the target is at least 87% of the known limit.

        The limit is the smaller of the bitrate at the last step down and
        the estimated upper bound.
        """
        # NOTE(review): bwe_last_down_bitrate is 0 after reset_bwe(), which
        # makes this True for any non-negative target until the first step
        # down - confirm that this is intended.
        bound = min(self.bwe_last_down_bitrate,
                    self.get_estimated_upper_bound())
        return self.target_bitrate >= bound * 0.87

    def get_estimated_upper_bound(self):
        """Return the middle recorded upper-bound sample.

        ``bwe_data`` is kept sorted in descending order by
        ``add_upper_bound``; for an even number of samples the smaller of
        the two central values is returned.  Without samples the current
        ``target_bitrate`` is returned.
        """
        sample_count = len(self.bwe_data)
        if not sample_count:
            return self.target_bitrate
        return self.bwe_data[int(sample_count / 2)]

    def is_creep_up(self):
        """Return True with more than 5 samples and a target within 5% of the bound."""
        bound = self.get_estimated_upper_bound() * 0.95
        return len(self.bwe_data) > 5 and self.target_bitrate >= bound

    def add_upper_bound(self, bitrate):
        """Record ``bitrate`` as an upper-bound sample.

        Samples are kept sorted in descending order.  Once more than 8 are
        held, the largest and the smallest are discarded.
        """
        self.bwe_data.append(bitrate)
        self.bwe_data = sorted(self.bwe_data, reverse=True)
        if len(self.bwe_data) > 8:
            self.bwe_data = self.bwe_data[1:-1]

    def is_bw_restricted(self):
        """Return True when ``target_bitrate`` differs from ``bitrate``."""
        return self.target_bitrate != self.bitrate

    def decide_up_down(self):
        """Return UP, DOWN or LEVEL from the queue-level history.

        * UP    - every kept sample is 0 and the bitrate is restricted.
        * DOWN  - the three newest samples all exceed ``BWE_WATERMARK_MS``,
          the sum of all samples exceeds ``BWE_CUM_WATERMARK_MS``, and the
          queue is growing both between the two newest samples and between
          the older and the newer half of the history.
        * LEVEL - otherwise.
        """
        hist = self._buf_hist
        half = int(BWE_HIST_LEN / 2)
        # Index 0 is the newest sample (see run_latency_bwe).
        current = sum(hist[i] for i in range(half))
        ancient = sum(hist[i] for i in range(half, BWE_HIST_LEN))
        s = ancient + current
        if s == 0:
            if self.is_bw_restricted():
                return UP
            return LEVEL

        is_current_growing = hist[0] > hist[1]
        is_longterm_growing = current > ancient

        # Length of the run of newest samples above the watermark.
        bad_cnt = 0
        for level in hist:
            if level > BWE_WATERMARK_MS:
                bad_cnt += 1
            else:
                break

        if bad_cnt >= 3 and s > BWE_CUM_WATERMARK_MS and is_longterm_growing and is_current_growing:
            return DOWN
        return LEVEL

    def increase_bitrate(self):
        """Raise ``target_bitrate`` once enough calls have accumulated.

        Each call increments ``bwe_up_monitor_interval_cnt``; nothing else
        happens until it reaches ``BWE_UP_WAIT`` (``BWE_UP_WAIT_SLOW`` in
        slow mode).  The new target is capped at ``bitrate`` and applied
        through ``set_bitrate()``.
        """
        is_slow = self.is_slow_up()
        self.bwe_up_monitor_interval_cnt += 1
        wait = BWE_UP_WAIT
        if is_slow:
            wait = BWE_UP_WAIT_SLOW
        if self.bwe_up_monitor_interval_cnt < wait:
            return
        self.bwe_up_monitor_interval_cnt = 0
        recalc = self.bwe_change != "up"
        self.bwe_change = "up"

        if is_slow:
            factor = self.get_slow_factor(recalc)
        else:
            factor = self.get_fast_factor(recalc)

        bitrate = int(min(self.bitrate, self.target_bitrate * factor))
        log.info("increase_bitrate %d -> %d  (queue: %s, factor: %s, slow: %s)",
                 self.target_bitrate, bitrate,
                 self._buf_hist, factor, is_slow)
        self.target_bitrate = bitrate
        self.set_bitrate()

    def decrease_bitrate(self):
        """Lower ``target_bitrate`` according to the current queue delay.

        The first step down after any other direction records the current
        target as an upper-bound sample and cuts to at most 70%.  The new
        target never drops below ``min_bitrate`` and is applied through
        ``set_bitrate()``.
        """
        self.bwe_up_monitor_interval_cnt = 0
        overreact = False
        if self.bwe_change != "down":
            self.bwe_last_down_bitrate = self.target_bitrate
            self.add_upper_bound(self.bwe_last_down_bitrate)
            overreact = True
            self.bwe_change = "down"
        # Queue delay (ms) in excess of two frame durations.
        delta = max(0, self.rtmp_queue_ms - (2 * 1000 / self.target_fps))
        # Quadratic in the excess delay, clamped to [0.5, 1].
        down_factor = 1.0 - (delta * delta) / BWE_X_FACTOR
        down_factor = max(0.5, down_factor)
        down_factor = min(1, down_factor)
        if overreact:
            down_factor = min(0.7, down_factor)
        bitrate = int(max(self.min_bitrate, self.target_bitrate * down_factor))
        log.info("decrease_bitrate %d -> %d (factor = %s, queue: %s)",
                 self.target_bitrate, bitrate, down_factor, self._buf_hist)
        self.target_bitrate = bitrate
        self.set_bitrate()

    def bwe_destroy(self):
        """Stop the periodic timer and drop the collected data."""
        # Cancel the pending timer so its thread does not outlive the
        # controller; a callback already in flight stops re-arming once it
        # sees that _buf_hist is None.
        timer = getattr(self, "bandwidthcontroller_timer", None)
        if timer is not None:
            timer.cancel()
        self.bwe_data = []
        self._buf_hist = None

