# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
import os

import requests
from django.conf import settings
from pybreaker import CircuitBreaker

from common.data_api.train_api.tariffs.utils import _make_segment_datetime_keys, get_possible_numbers
from common.models.schedule import TrainPurchaseNumber
from common.settings.configuration import Configuration
from common.settings.utils import define_setting
from travel.library.python.tracing.instrumentation import traced_function

from travel.rasp.export.export.v3.core.tariffs.train_tariffs import COACH_TYPE_NAMES, round_train_price
from travel.rasp.export.export.v3.core.tracing import get_request_id


log = logging.getLogger(__name__)


# Train API "train-tariffs" endpoint used by initialize_train_tariffs: the production
# host when running under Configuration.PRODUCTION, the testing host otherwise.
define_setting(
    'TRAIN_API_TARIFF_URL',
    {Configuration.PRODUCTION: 'https://production.train-api.rasp.internal.yandex.net/ru/api/segments/train-tariffs/'},
    default='https://testing.train-api.rasp.internal.yandex.net/ru/api/segments/train-tariffs/'
)

# Train API "train-tariffs/poll" endpoint used by poll_train_tariffs; same
# production/testing split as above.
# NOTE(review): presumably polled while a previous response reported 'querying' — confirm with callers.
define_setting(
    'TRAIN_API_TARIFF_POLL_URL',
    {Configuration.PRODUCTION: 'https://production.train-api.rasp.internal.yandex.net/ru/api/segments/train-tariffs/poll/'},
    default='https://testing.train-api.rasp.internal.yandex.net/ru/api/segments/train-tariffs/poll/'
)

# Prefix prepended to each tariff's 'trainOrderUrl' when building order links.
# Defaults to the TRAIN_ORDER_DOMAIN environment variable; os.environ.get yields
# None when it is unset, which would then be formatted into the URL as "None".
define_setting('TRAIN_ORDER_DOMAIN', default=os.environ.get('TRAIN_ORDER_DOMAIN'))

# Timeout (seconds, as interpreted by requests) for every Train API HTTP call.
TRAIN_API_TIMEOUT = 5

# Keyword arguments for pybreaker.CircuitBreaker: open after `fail_max` failures,
# allow a trial call again after `reset_timeout` seconds.
define_setting('TRAIN_API_BREAKER_PARAMS', default={'fail_max': 3, 'reset_timeout': 60})

# Module-level circuit breaker shared by all Train API requests made through this module.
train_api_breaker = CircuitBreaker(**settings.TRAIN_API_BREAKER_PARAMS)


def prepare_train_selling_info(tariffs):
    """Convert one Train API tariffs entry into the export's ``selling_info`` dict.

    :param tariffs: mapping with a 'classes' dict of coach class name -> tariff,
        where each tariff has 'seats', 'price' ('currency', 'value') and
        'trainOrderUrl' keys.
    :return: ``{'type': 'train', 'tariffs': [...]}`` with one item per coach class.
    """
    class_tariffs = []
    for class_name, tariff in tariffs['classes'].items():
        class_tariffs.append({
            'seats': tariff['seats'],
            'class': class_name,
            # None when the class has no human-readable name registered.
            'class_name': COACH_TYPE_NAMES.get(class_name),
            'currency': tariff['price']['currency'],
            'value': round_train_price(tariff['price']['value']),
            # NOTE(review): the '&' assumes trainOrderUrl already carries a query string.
            'order_url': '{}{}&{}'.format(settings.TRAIN_ORDER_DOMAIN, tariff['trainOrderUrl'], 'utm_source=suburbans'),
        })

    return {
        'type': 'train',
        'tariffs': class_tariffs,
    }


def set_segment_tariff(segment, train_number, tariffs_by_key):
    """Attach Train API tariffs to ``segment`` for one train number.

    Builds the lookup keys from the segment and the possible variants of
    ``train_number`` and stores them in ``segment.train_keys``, overwriting any
    previous value. For every key that has a non-empty entry in
    ``tariffs_by_key``, the resulting selling info is stored in
    ``segment.selling_info`` if it lists at least one tariff. There is no early
    exit, so when several keys match the last one wins.
    """
    candidate_numbers = get_possible_numbers(train_number)
    segment_keys = _make_segment_datetime_keys(segment, candidate_numbers)
    segment.train_keys = segment_keys

    for segment_key in segment_keys:
        key_tariffs = tariffs_by_key.get(segment_key)
        if not key_tariffs:
            continue
        info = prepare_train_selling_info(key_tariffs)
        if info['tariffs']:
            segment.selling_info = info


def set_segments_train_selling_info(tariffs_by_key, pseudo_train_segments, train_segments):
    """Fill ``selling_info`` on pseudo-train and train segments from ``tariffs_by_key``.

    Pseudo-train segments are matched through the purchase numbers that
    ``TrainPurchaseNumber`` reports for their thread. Regular train segments
    are matched by their own ``number``.

    NOTE(review): for a pseudo segment with several purchase numbers, each call
    to set_segment_tariff overwrites ``segment.train_keys``, so only the last
    number's keys are kept. Confirm that this is intended.
    """
    pseudo_threads = [pseudo_segment.thread for pseudo_segment in pseudo_train_segments]
    train_numbers_by_id = TrainPurchaseNumber.get_train_purchase_numbers(pseudo_threads)

    for pseudo_segment in pseudo_train_segments:
        purchase_numbers = train_numbers_by_id.get(pseudo_segment.thread.id, [])
        for purchase_number in purchase_numbers:
            set_segment_tariff(pseudo_segment, purchase_number, tariffs_by_key)

    for train_segment in train_segments:
        set_segment_tariff(train_segment, train_segment.number, tariffs_by_key)


def build_train_api_params(query):
    """Build the query-string parameters for a Train API train-tariffs request.

    :param query: mapping with 'point_from', 'point_to' and 'date' keys.
    :return: dict of request parameters, including the current request id as '_rid'.
    """
    # Parameters that depend on the incoming query.
    params = {
        'pointFrom': query['point_from'],
        'pointTo': query['point_to'],
        'date': query['date'],
    }
    # Parameters that are the same for every request.
    params.update({
        'national_version': 'ru',
        'partner': 'im',
        'includePriceFee': 1,
    })
    params['_rid'] = get_request_id()
    return params


@train_api_breaker
def _request_train_data(url, params):
    """Perform one Train API request and return the decoded JSON body.

    All errors are deliberately allowed to propagate. ``train_api_breaker``
    only records a failure when the wrapped call raises, so swallowing errors
    here would keep the breaker closed forever.

    :raises requests.RequestException: on connection errors, timeouts and 5xx responses.
    :raises ValueError: (or a subclass) when the body is not valid JSON.
    :raises pybreaker.CircuitBreakerError: while the breaker is open.
    """
    response = requests.get(
        url=url,
        params=params,
        timeout=TRAIN_API_TIMEOUT,
    )
    # A 5xx means the Train API itself is unhealthy, so it should count towards
    # opening the breaker. Other statuses keep the previous behaviour: their
    # JSON body is returned to the caller.
    if response.status_code >= 500:
        response.raise_for_status()
    return response.json()


@traced_function
def get_train_data(query, url):
    """Fetch raw train tariff data for ``query`` from the Train API endpoint ``url``.

    :param query: mapping with 'point_from', 'point_to' and 'date' keys
        (see build_train_api_params).
    :param url: Train API endpoint to call.
    :return: the decoded JSON response, or ``{}`` if anything failed: network
        error, timeout, 5xx, invalid JSON, or an open circuit breaker. Failures
        are logged and never raised.
    """
    try:
        # Parameters are built outside the breaker-wrapped call so local errors
        # (e.g. a malformed query) do not count as Train API failures.
        params = build_train_api_params(query)
        return _request_train_data(url, params)
    except Exception as ex:
        # Best-effort by design: callers get an empty payload instead of an exception.
        log.error('failed %r', ex)
        return {}


def get_segments_tariffs(query, url):
    """Request train tariffs from ``url`` and index them by segment key.

    :return: ``(tariffs_by_key, polling_status)``. ``tariffs_by_key`` maps each
        returned segment's 'key' to its 'tariffs'. ``polling_status`` is the
        response's 'querying' flag, or False when it is absent.
    """
    data = get_train_data(query, url)
    polling_status = data.get('querying', False)
    tariffs_by_key = {
        segment['key']: segment['tariffs']
        for segment in data.get('segments', [])
    }
    return tariffs_by_key, polling_status


def initialize_train_tariffs(query):
    """Fetch tariffs for ``query`` from the initial (non-poll) train-tariffs endpoint.

    :return: ``(tariffs_by_key, polling_status)`` as produced by get_segments_tariffs.
    """
    initial_url = settings.TRAIN_API_TARIFF_URL
    return get_segments_tariffs(query, initial_url)


def poll_train_tariffs(query):
    """Fetch tariffs for ``query`` from the train-tariffs poll endpoint.

    :return: ``(tariffs_by_key, polling_status)`` as produced by get_segments_tariffs.
    """
    poll_url = settings.TRAIN_API_TARIFF_POLL_URL
    return get_segments_tariffs(query, poll_url)
