# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from common.utils.date import naive_to_timestamp, UTC_TZ

from common.dynamic_settings.default import conf
from travel.rasp.train_api.train_bandit_api.client import BanditContext, SegmentData, BanditClient, RequestContext
from travel.rasp.train_api.train_purchase.core.models import ClientContracts

log = logging.getLogger(__name__)


def fill_segment_fees(segments, train_query, yandex_uid=None):
    """Fill fees for every tariff of the given segments and return the segments.

    When ``conf.TRAIN_PURCHASE_BANDIT_CHARGING`` is truthy the fees are first
    requested from the bandit service. Any exception raised on that path is
    logged and the function falls back to the contract-based calculation.

    :param segments: iterable of segments; each one exposes ``tariffs['classes']``.
    :param train_query: query object providing ``partner`` (and the data used by
        the bandit path).
    :param yandex_uid: optional user id forwarded to ``calculate_fee``.
    :return: the same ``segments`` object, with tariffs updated in place.
    """
    bandit_enabled = conf.TRAIN_PURCHASE_BANDIT_CHARGING
    if bandit_enabled:
        try:
            return _fill_segments_fees_by_bandit(segments, train_query)
        except Exception:
            # Best effort: a bandit failure must not break the search,
            # so log it and continue with the contract-based fees below.
            log.exception('Error in bandit charging')

    active_contract = ClientContracts.get_active_contract(train_query.partner)
    for current_segment in segments:
        tariffs_by_class = current_segment.tariffs['classes']
        for tariff in tariffs_by_class.values():
            tariff.calculate_fee(active_contract, yandex_uid=yandex_uid)
    return segments


def fill_min_prices_fees(trains, partner):
    """Calculate contract fees for all tariffs of ``trains`` and update their prices.

    For each tariff the fee is calculated with the partner's active contract,
    after which ``price`` is overwritten with ``total_price``. Tariffs are
    modified in place; nothing is returned.

    :param trains: iterable of segments; each one exposes ``tariffs['classes']``.
    :param partner: partner whose active contract is used for the calculation.
    """
    active_contract = ClientContracts.get_active_contract(partner)
    for train in trains:
        tariffs_by_class = train.tariffs['classes']
        for tariff in tariffs_by_class.values():
            tariff.calculate_fee(active_contract)
            # total_price is read after calculate_fee has run.
            tariff.price = tariff.total_price


def in_suburban_search(segment):
    """Return the ``use_in_suburban_search`` flag of the segment's thread subtype.

    ``False`` is returned when the segment, its thread or the thread's
    ``t_subtype`` is missing (falsy); otherwise the attribute value is
    returned as is.
    """
    if not segment:
        return False

    thread = segment.thread
    if not thread:
        return False

    subtype = thread.t_subtype
    if not subtype:
        return False

    return subtype.use_in_suburban_search


def _fill_segments_fees_by_bandit(segments, train_query):
    """Request fees from the bandit client and copy them onto the segments' tariffs.

    One ``SegmentData`` item is built per (segment, car class) pair and all of
    them are sent in a single ``get_fee_for_segments`` call. Every returned fee
    is matched back to its tariff through the sequential ``index`` assigned
    here.

    :param segments: iterable of segments; each one exposes ``tariffs['classes']``,
        ``station_from``, ``station_to``, ``departure``, ``arrival``,
        ``raw_train_name``.
    :param train_query: query object providing ``original_query``, ``icookie``
        and ``bandit_type``.
    :return: the same ``segments`` object, with tariffs updated in place.
    """
    def utc_timestamp(aware_dt):
        # Shift to UTC and strip tzinfo, because naive_to_timestamp is fed a naive datetime.
        return naive_to_timestamp(aware_dt.astimezone(UTC_TZ).replace(tzinfo=None))

    original_query = train_query.original_query
    request_context = RequestContext(
        req_id=original_query.get('req_id'),
        yandex_uid=original_query.get('yandex_uid'),
        user_device=original_query.get('user_device'),
    )

    tariffs_by_position = {}
    bandit_requests = []
    for segment in segments:
        for car_type, tariff in segment.tariffs['classes'].items():
            # Items are numbered in the order they are queued for the request.
            position = len(bandit_requests)
            bandit_requests.append(SegmentData(
                context=BanditContext(
                    icookie=train_query.icookie,
                    point_from=segment.station_from.point_key,
                    point_to=segment.station_to.point_key,
                    departure=utc_timestamp(segment.departure),
                    arrival=utc_timestamp(segment.arrival),
                    train_type=segment.raw_train_name,
                    car_type=car_type,
                    in_suburban_search=in_suburban_search(segment),
                ),
                amount=tariff.ticket_price,
                service_amount=tariff.service_price,
                index=position,
                service_class=tariff.service_class,
                log_info=request_context,
            ))
            tariffs_by_position[position] = tariff

    bandit_client = BanditClient(bandit_type=train_query.bandit_type)
    bandit_fees = bandit_client.get_fee_for_segments(bandit_requests)

    for segment_data, bandit_fee in bandit_fees.items():
        target_tariff = tariffs_by_position[segment_data.index]
        target_tariff.fee = bandit_fee.fee
        target_tariff.fee_percent = bandit_fee.fee_percent
        target_tariff.is_bandit_fee_applied = bandit_fee.is_bandit_fee_applied
        target_tariff.bandit_type = bandit_fee.bandit_type
        target_tariff.bandit_version = bandit_fee.bandit_version

    return segments
