import datetime
import json
import itertools
from collections import Counter

from travel.rasp.bus.scripts.cache_research.emulator import Cache
from travel.rasp.bus.scripts.cache_research.log_parser import LogEntry, Request


class Warmer:
    """Emulated cache warmer driven by a chronological stream of log entries.

    Time is split into analysis windows of ``warmer_analyse_time`` minutes.
    During a window the warmer counts how often each direction is requested;
    when the window closes (:meth:`reset`) it ranks the most requested
    directions, scores candidate requests for the next
    ``warmer_calendar_depth`` days and then, simulated second by simulated
    second, issues them to suppliers without exceeding the per-supplier limits
    in ``rps_restrictions``.
    """

    # Decay factor applied to every value of ``total_direction_counter`` on each reset.
    K_PREVIOUS_DIRECTION_COUNTER = 0.8

    def __init__(self, config_file, start_time, segments):
        """Load warmer settings from a JSON config and initialise the state.

        :param config_file: path to a JSON file with the keys
            ``warmer_analyse_time`` (minutes), ``k_days_before_ride``,
            ``k_cache_age``, ``rps_restrictions`` (supplier -> max requests per
            simulated second), ``warmer_top_size`` and ``warmer_calendar_depth``
            (number of days ahead to warm).
        :param start_time: datetime at which the first analysis window starts.
        :param segments: mapping ``supplier -> {from_id: collection of to_id}``
            describing which directions each supplier serves.
        """
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        self.analyse_time = datetime.timedelta(minutes=config["warmer_analyse_time"])
        self.k_days_before_ride = config["k_days_before_ride"]
        self.k_cache_age = config["k_cache_age"]
        self.rps_restrictions = config["rps_restrictions"]
        self.top_size = config["warmer_top_size"]
        self.calendar_depth = config["warmer_calendar_depth"]
        self.current_time = start_time
        # Start of the current analysis window.
        self.last_time = self.current_time
        self.segments = segments

        # direction -> rank (0 = most requested) among the top directions.
        self.direction_top = dict()
        # Direction request counts for the current window only.
        self.direction_counter = Counter()
        # Decayed direction counts accumulated over all closed windows.
        self.total_direction_counter = Counter()
        # Requests already charged to each supplier in the current simulated second.
        self.current_supplier_rps = Counter()
        # Pending (request, remaining suppliers) pairs, highest priority first.
        self.evaluated_requests = []
        # (LogEntry, suppliers) pairs emitted by the current run() call.
        self.warmer_queue = []
        # Every request the warmer has sent (fully or partially) or left pending at a reset.
        self.warmer_requests = set()

    def is_supplier_rps_limit_exceeded(self, supplier):
        """Return True if ``supplier`` cannot take another request this second.

        Suppliers without an entry in ``rps_restrictions`` are always treated
        as exceeded, i.e. the warmer never sends to them.
        """
        if supplier in self.rps_restrictions:
            return self.rps_restrictions[supplier] <= self.current_supplier_rps[supplier]
        return True

    def are_suppliers_available(self):
        """Return True if at least one rate-limited supplier still has budget this second."""
        return any(not self.is_supplier_rps_limit_exceeded(supplier) for supplier in self.rps_restrictions)

    def _warm(self, delta_seconds):
        """Spend ``delta_seconds`` simulated seconds sending pending warm-up requests.

        Each simulated second has its own RPS budget, so ``current_supplier_rps``
        is reset after every second -- including a second that was skipped
        because every supplier was already saturated. Otherwise a single
        saturated second would keep the budget exhausted for every following
        second and the warmer would stop sending altogether.
        """
        for seconds_shift in range(delta_seconds):
            if self.are_suppliers_available():
                self._warm_second(seconds_shift)
            self.current_supplier_rps.clear()

    def _warm_second(self, seconds_shift):
        """Send pending requests, in priority order, to every supplier with budget left."""
        for request, suppliers in self.evaluated_requests:
            free_suppliers = {s for s in suppliers if not self.is_supplier_rps_limit_exceeded(s)}
            for supplier in free_suppliers:
                self.current_supplier_rps[supplier] += 1
            if free_suppliers:
                self.send_request(request, free_suppliers, seconds_shift)
                # In-place: the pending entry keeps only the suppliers still to be asked.
                suppliers -= free_suppliers
                self.warmer_requests.add(request)
        # Drop requests every supplier has received (or that no supplier serves);
        # slice assignment keeps the list object itself.
        self.evaluated_requests[:] = [
            (request, suppliers) for request, suppliers in self.evaluated_requests if suppliers
        ]

    def send_request(self, request, available_suppliers, seconds_shift):
        """Queue ``request`` for ``available_suppliers`` at ``current_time + seconds_shift``."""
        request_time = self.current_time + datetime.timedelta(seconds=seconds_shift)
        self.warmer_queue.append((LogEntry(request=request, datetime=request_time), available_suppliers))

    def evaluate_request(self, request):
        """Return the priority score of ``request``; lower scores are warmed first.

        score = direction rank + k_days_before_ride * days until the ride
                + k_cache_age / seconds since the cached entry's date.

        A request with no entry in ``Cache.DATE`` gets age ``inf``, so the
        cache term vanishes. A request cached exactly at ``current_time``
        (age 0) scores ``inf`` unless ``k_cache_age`` is 0.

        Raises KeyError if ``request.direction`` is not in ``direction_top``.
        """
        cache_key = (request, Cache.ALL_CACHE_KEY)
        days_before_ride = (request.date - self.current_time.date()).days
        # NOTE(review): presumably Cache.DATE holds the time each entry was last cached -- confirm.
        age = ((self.current_time - Cache.DATE[cache_key]).total_seconds()
               if cache_key in Cache.DATE else float('inf'))
        if age == 0:
            value = self.direction_top[request.direction] + self.k_days_before_ride * days_before_ride
            return value if self.k_cache_age == 0 else float('inf')
        return (self.direction_top[request.direction] +
                self.k_days_before_ride * days_before_ride + self.k_cache_age / age)

    def update_top(self):
        """Rank this window's ``top_size`` most requested directions (0 = most requested).

        Ties keep ``direction_counter`` insertion order.
        """
        ranked = itertools.islice(self.direction_counter.most_common(), self.top_size)
        self.direction_top = {direction: rank for rank, (direction, _) in enumerate(ranked)}

    def evaluate_requests(self):
        """Build ``evaluated_requests`` for every top direction over the calendar depth.

        Candidates are the requests for each top direction on each of the next
        ``calendar_depth`` days. They are sorted by :meth:`evaluate_request`
        score and paired with a fresh set of the suppliers serving their
        direction.
        """
        today = self.current_time.date()
        scored = []
        for day_shift in range(1, 1 + self.calendar_depth):
            ride_date = today + datetime.timedelta(days=day_shift)
            for direction in self.direction_top:
                request = Request(direction=direction, date=ride_date)
                scored.append((request, self.evaluate_request(request)))
        # list.sort is stable, so equal scores keep this deterministic generation
        # order (collecting into a set would make tie order depend on hashing).
        scored.sort(key=lambda pair: pair[1])
        suppliers_by_direction = {direction: self.find_suppliers_for_direction(direction)
                                  for direction in self.direction_top}
        # Each request gets its own copy because _warm_second() shrinks it in place.
        self.evaluated_requests = [(request, set(suppliers_by_direction[request.direction]))
                                   for request, _ in scored]

    def find_suppliers_for_direction(self, direction):
        """Return the set of suppliers whose segments contain ``direction.from_id -> direction.to_id``."""
        return {supplier for supplier, routes in self.segments.items()
                if direction.from_id in routes and direction.to_id in routes[direction.from_id]}

    def process_remaining_requests(self):
        """Record every still-pending evaluated request in ``warmer_requests``.

        NOTE(review): this also counts requests that were never sent to any
        supplier -- presumably intentional for the research statistics; confirm.
        """
        self.warmer_requests.update(request for request, _ in self.evaluated_requests)

    def reset(self):
        """Close the current analysis window and prepare the next one.

        Records pending requests, moves the window start forward by
        ``analyse_time``, re-ranks directions from this window's counts,
        rescores candidate requests, and folds this window's counts into the
        decayed ``total_direction_counter``. Finally it moves ``current_time``
        to the new window start.
        """
        self.process_remaining_requests()
        self.evaluated_requests.clear()
        self.last_time += self.analyse_time
        self.update_top()
        # NOTE(review): requests are scored against the old current_time; it is only
        # moved to the new window start at the end of this method -- confirm intended.
        self.evaluate_requests()
        # NOTE(review): total_direction_counter is maintained here but never read in
        # this class; update_top() ranks by direction_counter alone.
        for direction in self.total_direction_counter:
            self.total_direction_counter[direction] *= self.K_PREVIOUS_DIRECTION_COUNTER
        self.total_direction_counter.update(self.direction_counter)
        self.direction_counter.clear()
        self.current_time = self.last_time

    def is_time_over(self, time):
        """Return True if ``time`` is at or past the end of the current analysis window."""
        return time - self.last_time >= self.analyse_time

    def run(self, entry):
        """Advance the emulation to ``entry`` and return the warm-up requests sent meanwhile.

        The time since the previous entry is warmed in two parts. The part up
        to the end of the current analysis window is warmed first. If ``entry``
        lies past the window end, :meth:`reset` runs and the remainder is
        warmed. Then ``entry`` itself is recorded: its direction is counted and
        it uses one unit of RPS budget from every supplier serving that
        direction.

        :param entry: log entry exposing ``datetime`` and ``request.direction``.
        :return: ``warmer_queue`` -- the ``(LogEntry, suppliers)`` pairs emitted
            by this call. The same list object is cleared on the next call.
        """
        self.warmer_queue.clear()
        next_warm_cycle_seconds = 0
        seconds_to_entry = (entry.datetime - self.current_time).total_seconds()
        if seconds_to_entry > 0:
            seconds_to_window_end = (self.last_time + self.analyse_time - self.current_time).total_seconds()
            next_warm_cycle_seconds = int(seconds_to_entry - seconds_to_window_end)
            # NOTE(review): int() truncates, so an entry less than a second after the
            # previous one warms zero seconds and the RPS budget is not reset.
            self._warm(int(min(seconds_to_entry, seconds_to_window_end)))
        if self.is_time_over(entry.datetime):
            self.reset()
        if next_warm_cycle_seconds > 0:
            # NOTE(review): if entry is more than one window past the end, this single
            # cycle spans several windows without another reset.
            self._warm(next_warm_cycle_seconds)
        self.current_time = entry.datetime
        self.direction_counter[entry.request.direction] += 1
        for supplier in self.find_suppliers_for_direction(entry.request.direction):
            self.current_supplier_rps[supplier] += 1
        return self.warmer_queue
