# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from datetime import timedelta
from decimal import Decimal

from common.apps.train_order.enums import CoachType
from common.data_api.train_api.tariffs.utils import get_possible_numbers
from common.dynamic_settings.default import conf
from common.models.currency import Price
from common.models.schedule import RThreadType, DeLuxeTrain
from common.models.transport import TransportType, TransportSubtype
from common.models_utils.geo import Point
from common.utils.date import timedelta2minutes
from travel.rasp.library.python.common23.date.environment import now_aware
from travel.rasp.train_api.tariffs.train.base.availability_indication import AvailabilityIndication
from travel.rasp.train_api.tariffs.train.base.models import TrainSegment, TrainTariff, TrainMinPrices, PlaceReservationType
from travel.rasp.train_api.tariffs.train.base.segments import fill_min_prices_fees
from travel.rasp.train_api.tariffs.train.base.utils import (
    make_tariff_segment_key, build_train_order_url, fix_broken_classes
)
from travel.rasp.train_api.tariffs.train.base.worker import TrainTariffsResult
from travel.rasp.train_api.tariffs.train.segment_builder.helpers.black_list import (
    redirect_blacklisted_trains_to_ufs, get_ufs_black_list_train_numbers
)
from travel.rasp.train_api.wizard_api.client import train_wizard_api_client
from travel.rasp.train_api.wizard_api.helpers.dt import get_dt, get_dt_from_wizard_dict

log = logging.getLogger(__name__)


class WizardThreadInfo(object):
    """Plain container with thread-like attributes for a wizard-built segment.

    All fields start from the class-level defaults below; ``build_segment``
    in this module assigns the ones it can read from the raw wizard data.
    """

    # Thread type object; ``type_id`` exposes its ``id``.
    type = None
    # Country codes copied from the raw wizard segment.
    first_country_code = None
    last_country_code = None
    # Transport type id, fixed to the train constant.
    t_type_id = TransportType.TRAIN_ID
    # Train number taken from the raw wizard data.
    number = None
    # Optional TransportSubtype / DeLuxeTrain objects.
    t_subtype = None
    deluxe_train = None

    @property
    def type_id(self):
        """Return ``type.id``, or ``None`` when ``type`` is unset (falsy)."""
        thread_type = self.type
        if not thread_type:
            return None
        return thread_type.id


def get_wizard_tariffs(train_query, expired_timeout=conf.TRAIN_PURCHASE_WIZARD_CONFIDENCE_MINUTES):
    """Build a ``TrainTariffsResult`` for ``train_query`` from train-wizard-api data.

    :param train_query: query handed to ``get_wizard_segments`` and stored in the result.
    :param expired_timeout: freshness threshold forwarded to ``need_to_update``.
        NOTE(review): the default is read from ``conf`` once, when the function
        is defined — confirm that later changes of the setting need not be picked up.
    :return: result with ``STATUS_PENDING`` when ``need_to_update`` reports stale data
        and ``STATUS_SUCCESS`` otherwise; ``None`` if any step raised (the error is logged).
    """
    try:
        wizard_segments = redirect_blacklisted_trains_to_ufs(get_wizard_segments(train_query))
        if need_to_update(wizard_segments, expired_timeout):
            result_status = TrainTariffsResult.STATUS_PENDING
        else:
            result_status = TrainTariffsResult.STATUS_SUCCESS
        return TrainTariffsResult(train_query, segments=wizard_segments, status=result_status)
    except Exception:
        # Best-effort source: any failure is logged and reported to the caller as "no result".
        log.exception('Ошибка при получении поездов из train-wizard-api')
        return None


def get_wizard_segments(query):
    """Fetch raw segments with tariffs for ``query`` and convert them to segments.

    :param query: object providing ``departure_point``, ``arrival_point`` and
        ``departure_date`` (plus whatever ``build_segment`` reads from it).
    :return: list of truthy results of ``build_segment``; an empty list when the
        client returned nothing.
    """
    response = train_wizard_api_client.get_raw_segments_with_tariffs(
        departure_point_key=query.departure_point.point_key,
        arrival_point_key=query.arrival_point.point_key,
        departure_date=query.departure_date,
        safe=False
    )
    if not response:
        return []

    raw_items = response['segments']
    # Load every referenced station in one call before building the segments.
    stations_by_key = _preload_stations(raw_items)

    built_segments = []
    for raw_item in raw_items:
        built = build_segment(raw_item, query, stations_by_key)
        # build_segment returns None for segments without any classes; skip those.
        if built:
            built_segments.append(built)
    return built_segments


def build_segment(raw_segment, query, stations_by_key):
    """Convert one raw train-wizard-api segment into a ``TrainSegment``.

    :param raw_segment: dict with at least the ``train``, ``places``, ``departure``,
        ``arrival`` and ``duration`` entries read below.
    :param query: object providing the railway time zones and the departure/arrival
        points whose ``pytz`` zones are used for the local times.
    :param stations_by_key: mapping from station key to station object.
    :return: the filled segment, or ``None`` when the segment has neither classes
        nor broken classes.
    """
    segment = TrainSegment()

    train_info = raw_segment['train']
    segment.raw_train_name = train_info.get('raw_train_name')
    segment.original_number = train_info['number']
    # Fall back to the original number when no display number is provided.
    segment.number = train_info['display_number'] or segment.original_number
    segment.train_number_to_get_route = segment.original_number
    segment.possible_numbers = get_possible_numbers(segment.number)

    places_info = raw_segment['places']
    segment.tariffs['electronic_ticket'] = places_info.get('electronic_ticket', False)
    segment.can_supply_segments = True

    segment.ufs_title = segment.title = train_info['title']

    segment.coach_owners = train_info['coach_owners'] or []
    segment.has_dynamic_pricing = train_info['has_dynamic_pricing']
    segment.two_storey = train_info['two_storey']
    segment.is_suburban = train_info['is_suburban']
    segment.provider = train_info['provider']

    departure_info = raw_segment['departure']
    local_departure = get_dt_from_wizard_dict(departure_info['local_datetime'])
    arrival_info = raw_segment['arrival']
    local_arrival = get_dt_from_wizard_dict(arrival_info['local_datetime'])

    # Sanity check only: a mismatch is logged, the reported arrival is still used.
    expected_duration = timedelta(minutes=int(raw_segment['duration']))
    if local_arrival - local_departure != expected_duration:
        log.warning(
            'Рассчитанное %s и указанное %s время прибытия отличаются',
            local_departure + expected_duration, local_arrival
        )

    segment.station_from = stations_by_key[departure_info['station']['key']]
    segment.station_to = stations_by_key[arrival_info['station']['key']]

    segment.railway_departure = local_departure.astimezone(query.departure_railway_tz)
    segment.railway_arrival = local_arrival.astimezone(query.arrival_railway_tz)
    segment.departure = local_departure.astimezone(query.departure_point.pytz)
    segment.arrival = local_arrival.astimezone(query.arrival_point.pytz)

    brand_info = train_info['brand']
    segment.is_deluxe = bool(brand_info and brand_info['is_deluxe'])

    # The key is computed from the segment itself, so it must follow the assignments above.
    segment.key = make_tariff_segment_key(segment)

    classes = _build_classes(places_info['records'])
    broken_classes = fix_broken_classes(raw_segment.get('broken_classes'))
    if not (classes or broken_classes):
        return None

    segment.tariffs['classes'] = classes
    segment.tariffs['broken_classes'] = broken_classes
    places_updated_at = places_info['updated_at']
    if places_updated_at:
        segment.updated_at = get_dt_from_wizard_dict(places_updated_at)

    # Country codes may be given on the segment itself or inside its ``train`` entry.
    first_country_code = raw_segment.get('first_country_code') or train_info.get('first_country_code')
    last_country_code = raw_segment.get('last_country_code') or train_info.get('last_country_code')
    segment.first_country_code = first_country_code
    segment.last_country_code = last_country_code

    thread = WizardThreadInfo()
    segment.thread = thread
    thread.first_country_code = first_country_code
    thread.last_country_code = last_country_code
    thread.type = RThreadType.get_by_code(train_info.get('thread_type'))
    thread.number = train_info['number']

    t_subtype_id = train_info.get('t_subtype_id')
    if brand_info:
        thread.deluxe_train = DeLuxeTrain.objects.get(id=brand_info['id'])
    if t_subtype_id:
        thread.t_subtype = TransportSubtype.objects.get(id=t_subtype_id)

    return segment


def _preload_stations(raw_segments):
    """Load the departure and arrival stations of all raw segments in one call.

    :param raw_segments: iterable of raw segment dicts with
        ``departure.station.key`` and ``arrival.station.key`` entries.
    :return: whatever ``Point.in_bulk`` returns for the collected keys
        (``build_segment`` indexes it by station key).
    """
    # Keys are collected per segment, departure first; duplicates are not removed.
    station_keys = [
        raw_segment[endpoint]['station']['key']
        for raw_segment in raw_segments
        for endpoint in ('departure', 'arrival')
    ]
    return Point.in_bulk(station_keys)


def _build_classes(raw_coach_tariffs):
    tariffs = {}
    for raw_tariff in raw_coach_tariffs or []:
        currency = raw_tariff['price']['currency']
        total_price = raw_tariff['price']['value']
        price_details = raw_tariff['price_details'] or {}
        service_price = Decimal(price_details.get('service_price', 0))
        fee = Decimal(price_details.get('fee', 0))
        ticket_price = Decimal(price_details.get('ticket_price', total_price - fee))

        tariff = TrainTariff(
            coach_type=CoachType(raw_tariff['coach_type']),
            ticket_price=Price(ticket_price, currency),
            service_price=Price(service_price, currency),
            seats=raw_tariff['count'],
            lower_seats=raw_tariff.get('lower_count'),
            upper_seats=raw_tariff.get('upper_count'),
            lower_side_seats=raw_tariff.get('lower_side_count'),
            upper_side_seats=raw_tariff.get('upper_side_count'),
            max_seats_in_the_same_car=raw_tariff['max_seats_in_the_same_car'],
            several_prices=price_details.get('several_prices', False),
            place_reservation_type=PlaceReservationType(
                price_details.get('place_reservation_type', PlaceReservationType.USUAL)
            ),
            is_transit_document_required=price_details.get('is_transit_document_required', False),
            availability_indication=AvailabilityIndication(
                price_details.get('availability_indication', AvailabilityIndication.AVAILABLE)
            ),
            service_class=raw_tariff.get('service_class'),
            has_non_refundable_tariff=raw_tariff.get('has_non_refundable_tariff', False),
        )
        tariff.fee = Price(fee, currency)

        valid, errors = tariff.validate()
        if valid:
            tariffs[tariff.coach_type] = tariff

    return tariffs


def need_to_update(segments, expired_timeout=conf.TRAIN_PURCHASE_WIZARD_CONFIDENCE_MINUTES):
    """Tell whether the wizard data behind ``segments`` should be refreshed.

    :param segments: iterable of objects with an ``updated_at`` attribute.
    :param expired_timeout: threshold compared with the ``timedelta2minutes`` value
        of the oldest update's age (presumably minutes — confirm against the setting).
    :return: ``True`` when no segment has a truthy ``updated_at`` or the oldest
        ``updated_at`` is older than the threshold; ``False`` otherwise.
    """
    known_update_times = []
    for segment in segments:
        if segment.updated_at:
            known_update_times.append(segment.updated_at)

    if not known_update_times:
        return True

    # The oldest timestamp decides: one stale segment makes the whole set stale.
    oldest_update = min(known_update_times)
    age_in_minutes = timedelta2minutes(now_aware() - oldest_update)
    return age_in_minutes > expired_timeout


def get_wizard_prices(query):
    """Return minimal train prices for the direction described by ``query``.

    :param query: mapping with ``point_from``, ``point_to``, ``national_version``
        and ``partner`` entries.
    :return: list of truthy ``build_min_train_price`` results after
        ``fill_min_prices_fees`` has been applied to them; an empty list when the
        client returned no prices.
    """
    point_from = query['point_from']
    point_to = query['point_to']
    directions = [(point_from.point_key, point_to.point_key)]

    # The search window starts now (in the departure point's time zone) and spans 90 days.
    departure_date_from = now_aware().astimezone(point_from.pytz)
    raw_prices = train_wizard_api_client.get_raw_prices_by_direction(
        directions,
        departure_date_from=departure_date_from,
        departure_date_to=departure_date_from + timedelta(days=90),
        tld=query['national_version']
    )
    if not raw_prices:
        return []

    black_list = get_ufs_black_list_train_numbers()
    built_trains = (
        build_min_train_price(raw_price, point_from, point_to, black_list)
        for raw_price in raw_prices
    )
    # Materialise a real list: on Python 3 ``filter`` yields a one-shot iterator,
    # which would be exhausted by ``fill_min_prices_fees`` before being returned.
    trains = [train for train in built_trains if train]
    fill_min_prices_fees(trains, query['partner'])

    return trains


def build_min_train_price(raw_price, point_from, point_to, black_list):
    """Convert one raw price record into a ``TrainMinPrices`` object.

    :param raw_price: dict with the train attributes, departure/arrival data and
        ``places`` records read below.
    :param point_from: departure point; its ``pytz`` zone is used for the departure time.
    :param point_to: arrival point; its ``pytz`` zone is used for the arrival time.
    :param black_list: container of train numbers; membership of the train's
        original number is passed to ``build_train_order_url`` as ``old_ufs_order``.
    :return: the filled object, or ``None`` when no tariff classes could be built.
    """
    train = TrainMinPrices()
    train.original_number = raw_price['number']
    # Fall back to the original number when no display number is provided.
    train.display_number = raw_price['display_number'] or train.original_number
    train.tariffs = {'electronic_ticket': raw_price['electronic_ticket']}

    train.coach_owners = raw_price['coach_owners'] or []
    train.has_dynamic_pricing = raw_price['has_dynamic_pricing']
    train.two_storey = raw_price['two_storey']
    train.is_suburban = raw_price['is_suburban']

    train.departure = get_dt(raw_price['departure_dt'], point_from.pytz)
    train.arrival = get_dt(raw_price['arrival_dt'], point_to.pytz)

    train.departure_station_id = raw_price['departure_station_id']
    train.arrival_station_id = raw_price['arrival_station_id']

    classes = _build_classes(raw_price['places'])
    if not classes:
        return None

    # Every tariff gets its own order URL because the URL depends on the coach type.
    for tariff in classes.values():
        order_url, order_url_owner = build_train_order_url(
            departure=train.departure,
            number=train.original_number,
            segment_number=train.original_number,
            old_ufs_order=train.original_number in black_list,
            first_country_code=raw_price.get('first_country_code'),
            last_country_code=raw_price.get('last_country_code'),
            coach_type=tariff.coach_type,
            point_from=point_from,
            point_to=point_to,
            # Could be derived from title_dict.title_parts, but it does not seem to be needed.
            train_title=None,
        )
        tariff.train_order_url = order_url
        tariff.train_order_url_owner = order_url_owner

    train.tariffs['classes'] = classes
    return train
