from logging import getLogger

from travel.avia.admin.lib.yt_helpers import yt_client_fabric


class FlightFrequencyCollector(object):
    """Count how often each flight number occurs per national version.

    The counting is done by a map-reduce operation started through a client
    obtained from ``yt_fabric.create()``.
    """

    def __init__(self, yt_fabric, logger):
        """Store the collaborators.

        :param yt_fabric: object whose ``create()`` returns a client exposing
            ``run_map_reduce``, ``TempTable`` and ``read_table``.
        :param logger: logger instance; it is stored but not used by the
            methods of this class.
        """
        self._yt_fabric = yt_fabric
        self._logger = logger

    def collect(self, sources, output):
        """Run the frequency map-reduce from ``sources`` into ``output``.

        Every output row is the reduce key (the ``number`` field) extended
        with one entry per national version seen for that number, whose value
        is the count of mapped rows carrying that national version.

        :param sources: passed as ``source_table`` to ``run_map_reduce``.
        :param output: passed as ``destination_table`` to ``run_map_reduce``.
        """
        # The mapper and reducer are deliberately local functions that do not
        # reference ``self``.
        # NOTE(review): presumably this keeps the client and logger out of
        # whatever gets shipped to the operation -- confirm before turning
        # them into methods.

        def _extract_flights_numbers(record):
            """Yield one ``{number, national_version}`` row per flight number.

            Numbers are read from the ``backward_numbers`` and
            ``forward_numbers`` fields, each treated as a ';'-separated list.
            Blank entries are skipped, and records without a national version
            produce nothing. A number repeated in a record is yielded once
            per occurrence.
            """
            # ``.get(key, '')`` only covers a missing key. A key that is
            # present with a None value would break the concatenation below,
            # so every falsy value is normalised with ``or``.
            national_version = record.get('national_version') or ''
            if not national_version:
                # Nothing can be emitted for this record; skip the parsing.
                return

            backward_numbers = record.get('backward_numbers') or ''
            forward_numbers = record.get('forward_numbers') or ''

            for raw_number in (backward_numbers + ';' + forward_numbers).split(';'):
                number = raw_number.strip()
                if number:
                    yield {
                        'number': number,
                        'national_version': national_version
                    }

        def _calculate_frequency_for_flight_number(key, items):
            """Yield a single row: the key fields plus per-version counts."""
            answer = dict(key)
            for item in items:
                nv = item['national_version']
                answer[nv] = answer.get(nv, 0) + 1

            yield answer

        self._yt_fabric.create().run_map_reduce(
            source_table=sources,
            destination_table=output,

            mapper=_extract_flights_numbers,
            reducer=_calculate_frequency_for_flight_number,

            sort_by=['number'],
            reduce_by=['number'],
        )

    def collect_to_list(self, sources):
        """Run :meth:`collect` into a temporary table and return its rows.

        :param sources: forwarded to :meth:`collect`.
        :return: list of the rows read back from the temporary table.
        """
        with self._yt_fabric.create().TempTable() as temp_table:
            self.collect(sources, temp_table)
            # Materialise the rows before leaving the ``with`` block.
            # NOTE(review): presumably the temporary table is dropped on
            # exit -- confirm against the client implementation.
            return list(self._yt_fabric.create().read_table(temp_table))

# Module-level collector instance, wired to the project's YT client fabric
# and to a logger named after this module.
flight_frequency_collector = FlightFrequencyCollector(
    logger=getLogger(__name__),
    yt_fabric=yt_client_fabric,
)
