# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

import requests
from django.conf import settings
from django.utils.encoding import force_text
from pybreaker import CircuitBreaker

from common.apps.train.models import COACH_TYPE_CHOICES
from common.data_api.train_api.tariffs.utils import get_possible_numbers
from common.models.schedule import TrainPurchaseNumber
from common.settings.configuration import Configuration
from common.settings.utils import define_setting
from travel.library.python.tracing.instrumentation import traced_function


# Coach type code -> human-readable name, built from the model's choices.
# force_text converts each label to text (presumably the labels are lazy
# translation proxies — confirm in common.apps.train.models).
COACH_TYPE_NAMES = {coach[0]: force_text(coach[1]) for coach in COACH_TYPE_CHOICES}

log = logging.getLogger(__name__)


# URL of the trains search API "open_direction" endpoint used to fetch tariffs.
# Presumably define_setting picks the per-configuration value (production host
# under Configuration.PRODUCTION) and falls back to `default` (testing host)
# otherwise — confirm against common.settings.utils.define_setting.
define_setting(
    'TRAIN_WIZARD_TARIFF_URL',
    {Configuration.PRODUCTION: 'https://production.search-api.trains.internal.yandex.ru/searcher/public-api/open_direction/'},
    default='https://testing.search-api.trains.internal.yandex.ru/searcher/public-api/open_direction/'
)

# Timeout (seconds) passed to requests.get for the tariff API call.
TRAIN_WIZARD_TIMEOUT = 1

# Keyword arguments for pybreaker.CircuitBreaker: fail_max is the number of
# failures before the circuit opens, reset_timeout the number of seconds it
# stays open before a trial call is allowed through.
define_setting('TRAIN_WIZARD_BREAKER_PARAMS', default={'fail_max': 5, 'reset_timeout': 20})

# Module-wide circuit breaker for the tariff API. pybreaker only counts
# exceptions that propagate out of the callable it wraps.
train_wizard_breaker = CircuitBreaker(**settings.TRAIN_WIZARD_BREAKER_PARAMS)


def round_train_price(price):
    """Round a train price for display.

    Prices strictly greater than 100 are rounded to a whole number (returned
    with ``round(price, 0)``, so a float stays a float); anything else keeps
    two decimal places.

    See https://st.yandex-team.ru/RASPFRONT-6300#5bf58f55c96c51001c4bb84f
    """
    if price > 100:
        digits = 0
    else:
        digits = 2
    return round(price, digits)


@train_wizard_breaker
def _request_train_wizard_data(query_params):
    """Perform the HTTP call to the train wizard tariff API and decode the JSON.

    Every failure (connection error, timeout, non-2xx status, invalid JSON) is
    raised, not swallowed, so that ``train_wizard_breaker`` can count it and
    open the circuit after repeated failures.
    """
    response = requests.get(
        url=settings.TRAIN_WIZARD_TARIFF_URL,
        params=query_params,
        timeout=TRAIN_WIZARD_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()


@traced_function
def get_train_wizard_data(query_params):
    """Fetch data from the train wizard tariff API.

    :param query_params: query-string parameters for the request.
    :return: the decoded JSON object, or an empty dict on any failure --
        including an open circuit breaker or a payload that is not a JSON
        object -- so callers can always use ``.get(...)`` on the result.
    """
    try:
        data = _request_train_wizard_data(query_params)
    except Exception as ex:
        # Best-effort: degrade to an empty result instead of propagating the
        # error. The breaker wraps the helper above, so it still sees failures.
        log.error('failed {}'.format(repr(ex)))
        return {}

    if not isinstance(data, dict):
        log.error('unexpected train wizard response type {}'.format(type(data).__name__))
        return {}

    return data


def set_tariffs(segment, train_number, tariff_by_key):
    """Fill ``segment.train_tariffs`` from the tariff records matching the segment.

    Each candidate from ``get_possible_numbers(train_number)`` is tried in turn.
    The lookup key joins the segment's departure/arrival point keys, the
    candidate number and the ISO-formatted departure/arrival datetimes with
    ``_``. Every matching candidate overwrites ``segment.train_tariffs``, so the
    last match wins. The segment is left untouched when nothing matches.

    :param segment: schedule segment to annotate.
    :param train_number: train number to generate candidate numbers from.
    :param tariff_by_key: mapping of tariff key -> list of place records.
    """
    def _as_tariff(record):
        # Translate one API place record into the segment's tariff dict.
        return {
            'seats': record['count'],
            'class': record['coach_type'],
            'currency': record['price']['currency'],
            'value': round_train_price(float(record['price']['value'])),
            'class_name': COACH_TYPE_NAMES.get(record['coach_type'])
        }

    for candidate in get_possible_numbers(train_number):
        key_parts = (
            segment.station_from.point_key,
            segment.station_to.point_key,
            candidate,
            segment.departure.isoformat(),
            segment.arrival.isoformat(),
        )
        matched_records = tariff_by_key.get('{}_{}_{}_{}_{}'.format(*key_parts))
        if not matched_records:
            continue

        segment.train_tariffs = [_as_tariff(record) for record in matched_records]


def _build_tariff_by_key(api_segments):
    """Index wizard API segments that carry place records by tariff key.

    The key is ``<departure station key>_<arrival station key>_<train number>_
    <departure local datetime>_<arrival local datetime>``, the same layout that
    ``set_tariffs`` builds from a schedule segment. Segments without records are
    ignored. Malformed segments are logged and skipped instead of aborting the
    whole lookup.
    """
    tariff_by_key = {}
    for api_segment in api_segments:
        try:
            # "places" may be absent or null in the external payload.
            records = (api_segment.get('places') or {}).get('records')
            if not records:
                continue
            key = '{}_{}_{}_{}_{}'.format(
                api_segment['departure']['station']['key'],
                api_segment['arrival']['station']['key'],
                api_segment['train']['number'],
                api_segment['departure']['local_datetime']['value'],
                api_segment['arrival']['local_datetime']['value']
            )
        except (AttributeError, KeyError, TypeError) as ex:
            log.warning('skipping malformed train wizard segment: %r', ex)
            continue
        tariff_by_key[key] = records
    return tariff_by_key


def add_train_tariffs(pseudo_train_segments, train_segments, point_from, point_to, departure_date):
    """Attach train tariffs fetched from the wizard API to schedule segments.

    For each pseudo train segment, the purchase numbers of its thread (from
    ``TrainPurchaseNumber``) are used for matching. Each train segment is
    matched by its own ``number``. Matching segments get ``train_tariffs`` set
    through ``set_tariffs``; segments without a match are left untouched.

    :param pseudo_train_segments: segments whose thread may map to purchase train numbers.
    :param train_segments: train segments matched by ``segment.number``.
    :param point_from: departure point (its ``point_key`` is sent to the API).
    :param point_to: arrival point (its ``point_key`` is sent to the API).
    :param departure_date: departure date passed as-is in the query string.
    """
    train_numbers_by_id = TrainPurchaseNumber.get_train_purchase_numbers([segment.thread
                                                                          for segment in pseudo_train_segments])

    if not (train_numbers_by_id or train_segments):
        return

    query_params = {
        'departure_point_key': point_from.point_key,
        'arrival_point_key': point_to.point_key,
        'departure_date': departure_date,
        'order_by': 'departure',
    }
    data = get_train_wizard_data(query_params)
    # Guard against a non-object payload and a missing or null "segments" field.
    api_segments = (data.get('segments') if isinstance(data, dict) else None) or []

    tariff_by_key = _build_tariff_by_key(api_segments)
    if not tariff_by_key:
        # Nothing can match; skip the per-segment work.
        return

    for segment in pseudo_train_segments:
        for train_number in train_numbers_by_id.get(segment.thread.id, []):
            # NOTE(review): set_tariffs() expands get_possible_numbers() again,
            # so numbers are expanded twice here. Kept as-is: dropping this outer
            # expansion could change which keys match unless get_possible_numbers
            # is closed under re-expansion -- confirm before simplifying.
            for number in get_possible_numbers(train_number):
                set_tariffs(segment, number, tariff_by_key)

    for segment in train_segments:
        set_tariffs(segment, segment.number, tariff_by_key)
