# -*- coding: utf-8 -*-
import calendar
import operator
from itertools import islice

from yt.wrapper import TablePath

from travel.avia.avia_api.avia.v1.email_dispenser.helpers.qkey_utils import qkey_to_params_tuple

# Columns copied verbatim from every source record into the mapper output.
NEEDED_COLUMNS = [
    'unixtime',
    'partner_id',
    'national_price',
    'national_currency_id',
    'national_version_id',
    'forward_segments',
    'backward_segments',
    'with_baggage',
]

# Columns actually read from the source tables: everything above plus
# `query_id`, which is only used to derive the qkey.
SOURCE_COLUMNS = NEEDED_COLUMNS + ['query_id']

# How many rows per qkey are kept in the "popular" and "minprice" tables.
TOP_N_POPULAR = 20
TOP_N_MINPRICE = 20


# Pipeline primitives

class VariantsLogSubscriptionPipeline(object):
    """
    Holds the search parameters and runs the three YT steps in sequence:
    per-variant minimal prices, cheapest rows per qkey, most popular rows per qkey.
    """

    def __init__(self, yt, qkeys, left_date_time, right_date_time, logger):
        """
        :param yt: yt wrapper
        :param set[basestring] qkeys: qkeys whose variants are searched for
        :param datetime.datetime left_date_time: lower bound of the time window
        :param datetime.datetime right_date_time: upper bound of the time window
        :param logging.Logger logger: logger used for progress messages
        """
        self.yt, self.qkeys = yt, qkeys
        self.left_date_time, self.right_date_time = left_date_time, right_date_time
        self.logger = logger

    def run_pipeline(
        self, source_tables, minprice_variants_table, minprice_qkey_table,
        top_popular_table, popularity_map, stations_cache
    ):
        """
        Run all steps; each later step reads `minprice_variants_table`
        written by the first one.

        :param list[basestring] source_tables: tables with variants
        :param basestring minprice_variants_table: path for the per-variant minimal prices
        :param basestring minprice_qkey_table: path for the cheapest rows of each qkey
        :param basestring top_popular_table: path for the most popular rows of each qkey
        :param dict[tuple,int] popularity_map: flight-popularity mapping
        :param dict[basestring,dict] stations_cache: passed through to the qkey parser
        """
        log = self.logger.info
        yt = self.yt

        # Step 1: one row per (qkey, variant) with its minimal price.
        log('Getting minimal price for each variant dimensions = (qkey X variant)')
        get_minimal_prices_for_each_variant(
            yt,
            source_tables,
            minprice_variants_table,
            popularity_map,
            self.qkeys,
            self.left_date_time,
            self.right_date_time,
            stations_cache
        )

        # Step 2: first rows of every qkey in price order.
        log('Getting minimal price for each qkey (qkey)')
        get_min_price_variants(yt, minprice_variants_table, minprice_qkey_table)

        # Step 3: first rows of every qkey in popularity order.
        log('Leaving top %s popular variants qkey (qkey X variant)', TOP_N_POPULAR)
        get_popular_variants(yt, minprice_variants_table, top_popular_table)

        log('Done with YT pipeline')


def get_minimal_prices_for_each_variant(
    yt, source_tables, out_table,
    popularity, qkeys, left_date_time, right_date_time,
    stations_cache
):
    """
    Run a map-reduce that leaves one row per (qkey, forward routes, backward routes).

    The mapper keeps records of the requested qkeys whose `unixtime` lies in the
    inclusive window and attaches route popularity; the reducer keeps the latest
    record of every partner and emits the cheapest of those.

    :param yt: yt wrapper
    :param list[basestring] source_tables: tables with variants
    :param basestring out_table: table to save minimal prices to
    :param dict[tuple,int] popularity: flight - popularity mapping
    :param set[basestring] qkeys: set of qkeys to get prices for
    :param datetime.datetime left_date_time: lower bound of the window (inclusive)
    :param datetime.datetime right_date_time: upper bound of the window (inclusive)
    :param dict[basestring, dict] stations_cache: passed to `qkey_to_params_tuple`
    """
    window_start = calendar.timegm(left_date_time.utctimetuple())
    window_end = calendar.timegm(right_date_time.utctimetuple())

    def map_qkey_and_popularity(record):
        # The qkey is the fourth dot-separated component of the query id.
        qkey = record.get('query_id').split('.')[3]
        seen_at = record.get('unixtime')
        if not window_start <= seen_at <= window_end:
            return
        if qkey not in qkeys:
            return

        fwd_seg_str = ';'.join(
            [segment['route'] for segment in record.get('forward_segments') or []]
        )
        bwd_seg_str = ';'.join(
            [segment['route'] for segment in record.get('backward_segments') or []]
        )

        forward_param, backward_param = qkey_to_params_tuple(qkey, stations_cache)

        # Popularity is looked up by the direction params extended with the route string.
        forward_score = popularity.get(forward_param + (fwd_seg_str,), 0)
        if backward_param:
            backward_score = popularity.get(backward_param + (bwd_seg_str,), 0)
        else:
            backward_score = 0
        total_score = forward_score + backward_score

        row = dict((column, record[column]) for column in NEEDED_COLUMNS)
        row.update({
            'qkey': qkey,
            'fwd_seg_str': fwd_seg_str,
            'bwd_seg_str': bwd_seg_str,
            'popularity': total_score,
            # Negated copy; get_popular_variants sorts by this column.
            '_popularity': -total_score,
        })
        yield row

    def reduce_latest_partner_variants(key, values):
        # partner_id -> row with the greatest unixtime seen so far
        # (the first row wins when timestamps are equal).
        freshest_by_partner = {}
        for row in values:
            partner = row['partner_id']
            if (
                partner not in freshest_by_partner
                or row['unixtime'] > freshest_by_partner[partner]['unixtime']
            ):
                freshest_by_partner[partner] = row

        for cheapest in reduce_min_price(key, freshest_by_partner.values()):
            yield cheapest

    # Read only the columns the mapper actually uses.
    restricted_sources = []
    for table_name in source_tables:
        restricted_sources.append(TablePath(table_name, columns=SOURCE_COLUMNS))

    yt.run_map_reduce(
        source_table=restricted_sources,
        destination_table=out_table,
        mapper=map_qkey_and_popularity,
        reducer=reduce_latest_partner_variants,
        reduce_by=['qkey', 'fwd_seg_str', 'bwd_seg_str'],
    )


def get_min_price_variants(yt, compressed_table, min_prices_table):
    """
    Keep at most TOP_N_MINPRICE rows of every qkey, taken in
    (qkey, national_price) order, then sort the result by the same columns.

    :param yt: yt wrapper
    :param basestring compressed_table: table to read rows from
    :param basestring min_prices_table: table to write the kept rows to
    """
    price_order = ('qkey', 'national_price')
    keep_first_rows = reduce_top_n(TOP_N_MINPRICE)

    # No mapper: rows are only grouped by qkey and truncated by the reducer.
    yt.run_map_reduce(
        source_table=compressed_table,
        destination_table=min_prices_table,
        mapper=None,
        reducer=keep_first_rows,
        reduce_by=['qkey'],
        sort_by=list(price_order)
    )
    yt.run_sort(min_prices_table, sort_by=list(price_order))


def get_popular_variants(yt, source_table, out_table):
    """
    Keep at most TOP_N_POPULAR rows of every qkey, taken in
    (qkey, _popularity, national_price) order, then sort the result
    by (qkey, _popularity).

    :param yt: yt wrapper
    :param basestring source_table: table to read rows from
    :param basestring out_table: table to write the kept rows to
    """
    keep_first_rows = reduce_top_n(TOP_N_POPULAR)

    # No mapper: rows are only grouped by qkey and truncated by the reducer.
    # `_popularity` holds the negated popularity, `national_price` breaks ties.
    yt.run_map_reduce(
        source_table=source_table,
        destination_table=out_table,
        mapper=None,
        reducer=keep_first_rows,
        sort_by=['qkey', '_popularity', 'national_price'],
        reduce_by=['qkey'],
    )

    final_order = ['qkey', '_popularity']
    yt.run_sort(out_table, sort_by=final_order)


# Common reducers

def reduce_top_n(n):
    """
    Build a reducer that passes through only the first `n` rows of each key.

    :param int n: maximum number of rows to emit per key
    :return: generator function with the `(key, values)` reducer signature
    """
    def reduce_top(key, values):
        # islice stops after `n` rows without pulling the rest of the group.
        leading_rows = islice(values, n)
        for row in leading_rows:
            yield row

    return reduce_top


def reduce_min_price(key, values):
    """
    Emit the single row with the smallest `national_price`.

    When several rows share the smallest price the first of them is emitted.

    :param key: reduce key (unused)
    :param values: non-empty iterable of rows having a `national_price` item
    """
    price_of = operator.itemgetter('national_price')
    yield min(values, key=price_of)
