# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from collections import defaultdict
from datetime import datetime, timedelta, time
from functools import partial
from itertools import groupby, chain

import pytz
from django.conf import settings

from common.apps.info_center.push import get_dir_pair_key
from common.apps.suburban_events.api import EventStateType, get_cancelled_path_for_segments, get_states_for_segments
from common.data_api.platforms.helpers import SegmentPlatformsBatch
from common.data_api.platforms.instance import platforms as platforms_client
from common.models.geo import ExternalDirectionMarker
from common.models.schedule import RThread, TrainSchedulePlan, RTStation
from common.models.transport import TransportType, TransportSubtype
from common.models_utils import fetch_related
from common.models_utils.i18n import RouteLTitle
from travel.rasp.library.python.common23.date import environment
from travel.rasp.library.python.common23.logging import log_run_time

from route_search.transfers import transfers
from route_search.transfers.variant import Variant
from route_search.facilities import fill_suburban_facilities as fill_segments_suburban_facilities
from route_search.shortcuts import search_routes

from travel.rasp.export.export.v3.core.helpers import (
    log_search, get_transport_type, get_days_and_except_texts, clean_number, find_suburban_and_train, cut_excess_str,
    get_thread_type, set_key, get_facilities_list, fill_thread_local_start_dt
)
from travel.rasp.export.export.v3.core.settings import get_settings
from travel.rasp.export.export.v3.core.suburban_events import get_segment_state_dict, STATE_NAME_BY_TYPE
from travel.rasp.export.export.v3.core.tariffs.tariffs import get_suburban_tariffs, get_segment_tariff
from travel.rasp.export.export.v3.core.tariffs.train_tariffs import add_train_tariffs
from travel.rasp.export.export.v3.core.teasers import get_search_teasers
from travel.rasp.export.export.v3.selling.suburban import add_suburban_selling_tariffs, SELLING_V2
from travel.rasp.export.export.v3.selling.train_api import set_segments_train_selling_info, initialize_train_tariffs
from travel.rasp.export.export.v3.views.utils import get_server_time, get_day_start_end_utc_in_iso


log = logging.getLogger(__name__)
# Module-local shortcut: every log_run_time(...) call below reports through ``log`` at DEBUG level.
log_run_time = partial(log_run_time, logger=log, log_level=logging.DEBUG)


SUBURBAN_COMPANIES_BLACKLIST = [59181]  # Company id of RZD/DOSS (Russian Railways)


class TransfersMode(object):
    """Allowed values of the ``transfers_mode`` search option.

    ``NONE`` never searches transfer variants and ``ALL`` always does.
    ``NO_DIRECT`` and ``AUTO`` search them only when no direct segment was found.
    """

    AUTO = 'auto'
    NO_DIRECT = 'no_direct'
    ALL = 'all'
    NONE = 'none'


def _get_transfers(point_from, point_to, date):
    """Return suburban transfer variants between two points on ``date``.

    This is best-effort: any error is logged and an empty list is returned, so the
    caller can carry on with direct segments only.
    """
    try:
        suburban_code = TransportType.objects.get(id=TransportType.SUBURBAN_ID).code
        return list(transfers.get_transfer_variants(point_from, point_to, date, [suburban_code]))
    except Exception as ex:
        log.exception(repr(ex))
        return []


def _get_segments_states(segments, disable_cancels=True):
    """Load event states for ``segments`` on a best-effort basis.

    Returns ``{}`` when suburban states are disabled in settings, or when loading
    fails (the error is logged). ``disable_cancels`` is forwarded as
    ``cancels_as_possible_delay``.
    """
    if not settings.ENABLE_SUBURBAN_STATES:
        return {}

    with log_run_time('get states for {} segments'.format(len(segments))):
        try:
            return get_states_for_segments(segments, all_keys=True, cancels_as_possible_delay=disable_cancels)
        except Exception as ex:
            log.exception(repr(ex))
            return {}


def _get_thread_cancels(segments, segment_states, disable_cancels=True):
    """Fetch cancelled thread paths for the segments whose state reports a cancellation.

    Returns ``{}`` when suburban states are disabled in settings or ``disable_cancels``
    is set. Lookup errors are logged and also yield ``{}``.

    :param segment_states: mapping segment -> state (arrival/departure events).
    """
    if not settings.ENABLE_SUBURBAN_STATES or disable_cancels:
        return {}

    def is_cancelled(event):
        return getattr(event, 'type', None) == EventStateType.CANCELLED

    cancelled_segments = []
    for segment in segments:
        state = segment_states.get(segment)
        if state and (is_cancelled(state.arrival) or is_cancelled(state.departure)):
            cancelled_segments.append(segment)

    with log_run_time('get thread cancels for {} segments'.format(len(cancelled_segments))):
        try:
            return get_cancelled_path_for_segments(cancelled_segments)
        except Exception as ex:
            log.exception(repr(ex))
            return {}


def is_subscription_allowed(suburban_segments):
    """Tell whether subscriptions may be offered for a search result.

    Subscriptions are switched off (https://st.yandex-team.ru/SUBURBAN-3214), so this
    always returns False, whatever ``suburban_segments`` contains.

    The previous rule (https://st.yandex-team.ru/RASPFRONT-6492) allowed subscriptions
    unless there were no segments, or every segment's thread company id was listed in
    ``SUBURBAN_COMPANIES_BLACKLIST``.
    """
    return False


def search_on_date(
    request, point_from, point_to, point_from_reduced, point_to_reduced, timezone, departure_date,
    days_ahead, tomorrow_upto, selling_version=None, selling_flows=None, selling_barcode_presets=None,
    national_version=None, lang=None, transfers_mode=TransfersMode.NONE, disable_cancels=True
):
    """Search segments between two points starting from ``departure_date`` and build the response dict.

    The result is the ``common_search`` dict extended with:

    * ``date_time``: the requested ``date``, ``days_ahead`` and the UTC bounds of the day
      in ``point_from``'s timezone;
    * train / suburban selling data, when ``selling_version`` is set;
    * ``days``: direct segments and transfer variants grouped by departure date;
    * ``teasers`` and ``sup_tags`` (subscription keys), both computed from direct segments only.

    :param days_ahead: days the search window extends past midnight of ``departure_date``.
    :param tomorrow_upto: hours added to the search window on top of ``days_ahead``.
    :param selling_version: if truthy, train tariffs are added; from ``SELLING_V2`` on,
        suburban selling tariffs are added as well.
    :param transfers_mode: a ``TransfersMode`` value. ``ALL`` always searches transfers;
        ``NO_DIRECT``/``AUTO`` search them only when no direct segment was found.
    :param disable_cancels: if True, cancellations are reported as possible delays and
        cancelled thread paths are not requested.
    """
    today = environment.today()

    # Keep suburban segments plus trains of searchable subtypes (suburban first, then trains).
    direct_segments = find_segments_on_date(point_from, point_to, departure_date, days_ahead, tomorrow_upto)
    suburban_segments, train_segments = get_suburban_train_segments(direct_segments)
    direct_segments = suburban_segments + train_segments

    # all_segments also receives the legs of transfer variants, so that the bulk enrichment
    # steps below (plans, facilities, tariffs, platforms, states) cover them too.
    all_segments = list(direct_segments)

    transfer_variants, transfer_segments = [], []
    if (transfers_mode == TransfersMode.ALL or
       (transfers_mode in [TransfersMode.NO_DIRECT, TransfersMode.AUTO] and not all_segments)):

        transfer_variants = _get_transfers(point_from, point_to, departure_date)
        transfer_segments = list(chain.from_iterable(v.segments for v in transfer_variants))
        all_segments += transfer_segments

    # current_plan is unused here; next_plan feeds the days/except texts of every segment.
    current_plan, next_plan = TrainSchedulePlan.add_to_threads([s.thread for s in all_segments if s.thread], today)
    fill_segments_suburban_facilities(all_segments)
    day_start_utc_iso, day_end_utc_iso = get_day_start_end_utc_in_iso(departure_date, point_from.pytz)

    subscription_allowed = is_subscription_allowed(suburban_segments)

    search_result = common_search(
        request, all_segments, point_from, point_to, point_from_reduced, point_to_reduced,
        subscription_allowed, departure_date
    )

    search_result['date_time'].update({
        'date': departure_date.strftime("%Y-%m-%d"),
        'days_ahead': days_ahead,
        'day_start_utc': day_start_utc_iso,
        'day_end_utc': day_end_utc_iso
    })

    if selling_version:
        add_segments_train_tariffs(point_from, point_to, departure_date, search_result, all_segments, train_segments)

        if selling_version >= SELLING_V2:
            add_suburban_selling_tariffs(
                direct_segments, all_segments, search_result, selling_version, selling_flows, selling_barcode_presets
            )

    # Platforms, event states and cancellations are loaded for all segments at once.
    dynamic_platforms = SegmentPlatformsBatch()
    dynamic_platforms.try_load(platforms_client, all_segments)

    segment_states = _get_segments_states(all_segments, disable_cancels)
    thread_cancels = _get_thread_cancels(all_segments, segment_states, disable_cancels)

    # "days" renders direct segments and whole transfer variants (not their individual legs).
    search_result['days'] = build_segments_on_date(
        direct_segments + transfer_variants,
        departure_date, today, next_plan, timezone, point_from,
        segments_states=segment_states,
        dynamic_platforms=dynamic_platforms,
        thread_cancels=thread_cancels
    )

    teasers = get_search_teasers(direct_segments, point_from, point_to, national_version=national_version, lang=lang)
    set_key(search_result, 'teasers', teasers)

    search_result['sup_tags'] = build_segments_subscription(direct_segments)

    return search_result


def search_on_all_days(request, point_from, point_to, point_from_reduced, point_to_reduced, timezone,
                       national_version=None, lang=None):
    """Search segments between two points regardless of date and build the response dict.

    The result is the ``common_search`` dict extended with ``segments`` (rendered with
    ``HH:MM`` times) and ``teasers``.
    """
    today = environment.today()

    suburban_segments, train_segments = get_suburban_train_segments(find_segments_on_all_days(point_from, point_to))
    segments = suburban_segments + train_segments

    threads = [segment.thread for segment in segments if segment.thread]
    _current_plan, next_plan = TrainSchedulePlan.add_to_threads(threads, today)

    search_result = common_search(
        request, segments, point_from, point_to, point_from_reduced, point_to_reduced,
        is_subscription_allowed(suburban_segments)
    )
    search_result['segments'] = build_segments_on_all_days(segments, today, next_plan, timezone)

    set_key(
        search_result, 'teasers',
        get_search_teasers(segments, point_from, point_to, national_version=national_version, lang=lang)
    )
    return search_result


def get_suburban_train_segments(segments):
    """Split ``segments`` into ``(suburban_segments, train_segments)``.

    A segment whose transport type is train is kept only if its thread subtype code is one
    of ``TransportSubtype.get_train_search_codes()``; otherwise it is dropped. Every other
    segment counts as suburban. Input order is preserved in both lists.
    """
    fetch_related([segment.thread for segment in segments], 't_subtype', 'express_lite', model=RThread)

    suburban, trains = [], []
    for seg in segments:
        if seg.t_type.id != TransportType.TRAIN_ID:
            suburban.append(seg)
            continue

        subtype = seg.thread.t_subtype
        if subtype and subtype.code in TransportSubtype.get_train_search_codes():
            trains.append(seg)

    return suburban, trains


def add_segments_train_tariffs(point_from, point_to, departure_date, search_result, all_segments, train_segments):
    """Attach train tariffs and selling info to segments, on a best-effort basis.

    "Pseudo-train" segments are those in ``all_segments`` whose thread subtype has
    ``has_train_tariffs`` set. When there are such segments or ``train_segments``, the
    train tariffs are initialised, ``search_result['train_tariffs_polling']`` is set, and the
    selling info and tariffs are added to both groups. Any error is logged and swallowed.
    """
    try:
        pseudo_train_segments = [
            seg for seg in all_segments
            if seg.thread.t_subtype and seg.thread.t_subtype.has_train_tariffs
        ]
        if not (pseudo_train_segments or train_segments):
            return

        query = {
            'point_from': point_from.point_key,
            'point_to': point_to.point_key,
            'date': departure_date,
        }
        tariffs_by_key, polling_status = initialize_train_tariffs(query)
        search_result['train_tariffs_polling'] = polling_status
        set_segments_train_selling_info(tariffs_by_key, pseudo_train_segments, train_segments)
        add_train_tariffs(pseudo_train_segments, train_segments, point_from, point_to, departure_date)
    except Exception as ex:
        log.exception('Не смогли получить цены от поездов: {}'.format(repr(ex)))


def common_search(request, segments, point_from, point_to, point_from_reduced, point_to_reduced, subscription_allowed,
                  date_=None):
    """Prepare ``segments`` for rendering and build the fields shared by all search responses.

    Side effects on ``segments``:

    * calls ``fill_thread_local_start_dt``;
    * prefetches thread titles, the stations of departure route stations and thread subtypes;
    * points every ``rtstation_from.thread`` at the segment's own thread, to avoid extra DB queries.

    The search is logged when ``settings.USERS_SEARCH_LOG`` is enabled.

    :return: dict with ``narrowed_from``/``narrowed_to`` (the point's ESR code when it was
        reduced, else None), ``date_time`` (server time), ``settings``,
        ``subscription_allowed`` and suburban ``tariffs``.
    """
    fill_thread_local_start_dt(segments)
    tariffs = get_suburban_tariffs(segments)
    RouteLTitle.fetch([seg.thread.L_title for seg in segments])

    # Reuse already-loaded threads instead of letting Django fetch them again.
    for seg in segments:
        seg.rtstation_from.thread = seg.thread

    fetch_related([seg.rtstation_from for seg in segments], 'station', model=RTStation)
    fetch_related([seg.thread for seg in segments], 't_subtype', model=RThread)

    if settings.USERS_SEARCH_LOG:
        log_search(log, request, point_from, point_to, date_, segments)

    return {
        'narrowed_from': point_from.get_code('esr') if point_from_reduced else None,
        'narrowed_to': point_to.get_code('esr') if point_to_reduced else None,
        'date_time': {'server_time': get_server_time()},
        'settings': get_settings(),
        'subscription_allowed': subscription_allowed,
        'tariffs': tariffs,
    }


def find_segments_on_date(point_from, point_to, date_, days_ahead, tomorrow_upto):
    """Return segments from ``point_from`` to ``point_to`` departing within a time window.

    The window starts at midnight of ``date_`` localized to ``point_from`` and lasts
    ``days_ahead`` days plus ``tomorrow_upto`` hours. Candidates from
    ``find_suburban_and_train`` are assumed to be ordered by departure: iteration stops at
    the first one that departs after the window ends.
    """
    midnight = datetime.combine(date_, time(0, 0))
    window_start = point_from.localize(loc=midnight)
    window_end = window_start + timedelta(days=days_ahead, hours=tomorrow_upto)

    found = []
    for segment in find_suburban_and_train(point_from, point_to, window_start.date()):
        if segment.departure < window_start:
            continue
        if segment.departure > window_end:
            break
        found.append(segment)

    # Workaround for RASP-8365: drop segments whose naive (wall-clock) departure or
    # arrival falls before midnight of the requested date.
    return [
        seg for seg in found
        if datetime.replace(seg.departure, tzinfo=None) >= midnight
        and datetime.replace(seg.arrival, tzinfo=None) >= midnight
    ]


def find_segments_on_all_days(point_from, point_to):
    """Search suburban segments between two points without a departure date.

    Trains of the searchable subtypes are included through ``add_train_subtypes``.
    """
    suburban_type = TransportType.objects.get(code='suburban')
    found, _, _ = search_routes(
        point_from=point_from,
        point_to=point_to,
        departure_date=None,
        transport_types=[suburban_type],
        add_train_subtypes=TransportSubtype.get_train_search_codes(),
    )
    return found


def build_segments_on_date(segments, date_, today, next_plan, timezone, point_from,
                           segments_states=None, dynamic_platforms=None, thread_cancels=None):
    """Group segments and transfer variants by departure date and render each of them.

    ``segments`` is sorted in place by departure. At equal departure, items without a
    ``thread`` (transfer ``Variant`` objects) come first, followed by threaded segments
    ordered by thread uid.

    :param segments: list of plain segments and/or transfer ``Variant`` objects.
    :param date_: the requested date, passed on as ``request_date``.
    :param point_from: departure point; its ``pytz`` timezone defines each day's UTC bounds.
    :param segments_states: optional mapping segment -> event state.
    :param dynamic_platforms: optional platforms batch, used for plain segments only.
    :param thread_cancels: optional mapping segment -> thread cancellation info.
    :return: list of dicts with ``date``, ``day_start_utc``, ``day_end_utc`` and the
        rendered ``segments`` of that day.
    """
    # Both mappings are optional; normalise them so that .get() below never hits None.
    segments_states = segments_states or {}
    thread_cancels = thread_cancels or {}

    def sort_key(item):
        thread = getattr(item, 'thread', None)
        # Never compare None with a str (TypeError on Python 3). A 0/1 flag keeps the
        # previous order, in which items without a thread uid sorted first.
        if not thread:
            return item.departure, 0, ''
        return item.departure, 1, thread.uid

    segments.sort(key=sort_key)

    days_list = []
    for day, day_items in groupby(segments, lambda item: item.departure.date()):
        day_start, day_end = get_day_start_end_utc_in_iso(day, point_from.pytz)
        day_dict = {
            'date': day.strftime("%Y-%m-%d"),
            'day_start_utc': day_start,
            'day_end_utc': day_end,
            'segments': []
        }
        for item in day_items:
            if isinstance(item, Variant):
                element = build_transfer_variant_data(
                    item,
                    today=today,
                    next_plan=next_plan,
                    timezone=timezone,
                    request_date=date_,
                    segments_states=segments_states,
                    thread_cancels=thread_cancels
                )
            else:
                element = build_segment_data(
                    item,
                    today=today,
                    next_plan=next_plan,
                    timezone=timezone,
                    request_date=date_,
                    segment_state=segments_states.get(item),
                    dynamic_platforms=dynamic_platforms,
                    thread_cancel=thread_cancels.get(item)
                )

            day_dict['segments'].append(element)

        days_list.append(day_dict)

    return days_list


def build_segments_on_all_days(segments, today, next_plan, timezone):
    """Render date-less segments, with times formatted as ``HH:MM``.

    ``segments`` is sorted in place by departure time of day, then by thread uid.
    """
    segments.sort(key=lambda seg: (seg.departure.time(), seg.thread.uid))
    return [
        build_segment_data(seg, today=today, next_plan=next_plan, timezone=timezone, time_format='%H:%M')
        for seg in segments
    ]


def build_segment_data(
        segment, today, next_plan,
        time_format=None, timezone=None, request_date=None, segment_state=None,
        dynamic_platforms=None, thread_cancel=None
):
    """Render a single segment as a response dict.

    The dict contains ``departure``/``arrival`` (time, ESR code of the station and, when
    available, platform), ``thread`` (uid, canonical uid, number, titles, type, transport,
    facilities), ``duration`` in whole minutes (truncated), ``stops``, the ``days``/``except``
    texts, and any tariff / selling data found on the segment.

    :param time_format: if given, times are rendered with ``strftime(time_format)``;
        otherwise they are ISO 8601 strings.
    :param timezone: if given, departure and arrival are converted to it before rendering.
        It is also passed as ``event_tz`` when computing the departure day shift.
    :param request_date: if truthy, UTC times, the thread start time, event states and
        cancellation details are added. ``segment_state`` and ``thread_cancel`` are
        ignored otherwise.
    :param segment_state: optional event state of the segment.
    :param dynamic_platforms: optional ``SegmentPlatformsBatch``; the route station's own
        platform is passed as the fallback value to its ``get_*_safe`` helpers.
    :param thread_cancel: optional cancellation info exposing ``is_fully_cancelled`` and
        ``cancelled_segments``.
    """

    if timezone:
        departure = segment.departure.astimezone(timezone)
        arrival = segment.arrival.astimezone(timezone)
    else:
        departure = segment.departure
        arrival = segment.arrival

    segment_data = {
        'departure': {
            'time': departure.isoformat() if not time_format else departure.strftime(time_format),
            'station': segment.station_from.get_code('esr'),
        },
        'arrival': {
            'time': arrival.isoformat() if not time_format else arrival.strftime(time_format),
            'station': segment.station_to.get_code('esr')
        },
        'thread': {
            'uid': segment.thread.uid,
            'canonical_uid': segment.thread.canonical_uid,
            'number': cut_excess_str(clean_number(segment.thread)),
            'title': segment.thread.L_title(),
            'title_short': segment.thread.L_title(short=True)
        },
        # Whole minutes; int() truncates any remaining seconds.
        'duration': int((segment.arrival - segment.departure).total_seconds() / 60),
    }

    # Optional data that may have been attached to the segment elsewhere. The getattr checks
    # skip empty values, while the hasattr checks below emit train_tariffs / tariffs_ids even if empty.
    if getattr(segment, 'selling_info', None):
        segment_data['selling_info'] = segment.selling_info

    if getattr(segment, 'selling_tariffs_ids', None):
        segment_data['selling_tariffs_ids'] = segment.selling_tariffs_ids

    if getattr(segment, 'train_keys', None):
        segment_data['train_keys'] = segment.train_keys

    if hasattr(segment, 'train_tariffs'):
        segment_data['train_tariffs'] = segment.train_tariffs

    # Presumably the day offset of this departure relative to the thread start date;
    # used to build the days/except texts below.
    shift = segment.rtstation_from.calc_days_shift(
        event='departure',
        start_date=segment.start_date,
        event_tz=timezone)

    segment_data['stops'] = cut_excess_str(segment.L_stops())

    days_text, except_text = get_days_and_except_texts(today, segment.thread, shift, next_plan)

    if hasattr(segment, 'tariffs_ids'):
        segment_data['tariffs_ids'] = segment.tariffs_ids
    set_key(segment_data, 'tariff', get_segment_tariff(segment))

    set_key(segment_data, 'days', cut_excess_str(days_text))
    set_key(segment_data, 'except', cut_excess_str(except_text))
    # Platforms: the route station's own platform is the fallback passed to the dynamic lookup.
    set_key(segment_data['arrival'], 'platform', SegmentPlatformsBatch.get_arrival_safe(
            dynamic_platforms, segment, cut_excess_str(segment.rtstation_to.L_platform())))
    set_key(segment_data['departure'], 'platform', SegmentPlatformsBatch.get_departure_safe(
            dynamic_platforms, segment, cut_excess_str(segment.rtstation_from.L_platform())))
    set_key(segment_data['thread'], 'type', get_thread_type(segment.thread))
    set_key(segment_data['thread'], 'transport', get_transport_type(segment.thread))
    set_key(segment_data['thread'], 'facilities', get_facilities_list(segment.suburban_facilities))

    # Date-specific data: exact UTC times, thread start, event states and cancellations.
    if request_date:
        segment_data['arrival']['time_utc'] = arrival.astimezone(pytz.utc).isoformat()
        segment_data['departure']['time_utc'] = departure.astimezone(pytz.utc).isoformat()
        set_key(segment_data['thread'], 'start_time', segment.thread_local_start_dt.isoformat())

        if segment_state:
            state_dict = get_segment_state_dict(segment_state)
            for state_type, state in state_dict.items():
                # NOTE(review): an unknown state_type makes .get() return None, and the unpacking
                # then raises TypeError. Confirm that STATE_NAME_BY_TYPE covers every key produced
                # by get_segment_state_dict().
                event_type, state_name = STATE_NAME_BY_TYPE.get(state_type)
                segment_data[event_type][state_name] = state
        if thread_cancel:
            segment_data['thread']['cancelled'] = thread_cancel.is_fully_cancelled
            segment_data['thread']['cancelled_segments'] = [
                {
                    'from_title_genitive': _get_cancelled_title(cancelled_segment.rtstation_from.station),
                    'to_title_genitive': _get_cancelled_title(cancelled_segment.rtstation_to.station)
                }
                for cancelled_segment in thread_cancel.cancelled_segments
            ]

    return segment_data


def _get_cancelled_title(station):
    return station.L_popular_title(case='genitive', fallback=True) or station.L_title()


def build_transfer_variant_data(
        variant, today, next_plan,
        time_format=None, timezone=None, request_date=None, segments_states=None, thread_cancels=None,
        dynamic_platforms=None):
    """Render a transfer variant: its legs, transfer points, duration and total price.

    :param variant: transfer ``Variant`` exposing ``segments``, ``transfers`` and ``duration``.
    :param segments_states: optional mapping segment -> event state.
    :param thread_cancels: optional mapping segment -> thread cancellation info.
    :param dynamic_platforms: optional platforms batch forwarded to every leg
        (defaults to None, the previous behaviour).
    :return: dict with ``is_transfer``, ``duration`` (minutes), ``transfer_points`` and
        ``segments``. The total price is passed to ``set_key`` as ``tariff``. It is None
        unless every leg has a tariff and all of them share one currency; otherwise it
        holds the summed value rounded to 2 decimals.
    """
    # Both mappings are optional; normalise them so that .get() below never hits None.
    segments_states = segments_states or {}
    thread_cancels = thread_cancels or {}

    segment_elements = []
    tariff_sum = 0
    currencies = set()
    all_legs_priced = True
    for segment in variant.segments:
        segment_el = build_segment_data(
            segment,
            today, next_plan,
            time_format=time_format,
            timezone=timezone,
            request_date=request_date,
            segment_state=segments_states.get(segment),
            dynamic_platforms=dynamic_platforms,
            thread_cancel=thread_cancels.get(segment)
        )
        segment_elements.append(segment_el)

        tariff = segment_el.get('tariff')
        if tariff:
            tariff_sum += tariff['value']
            currencies.add(tariff['currency'])
        else:
            all_legs_priced = False

    # Do not compute a total when currencies differ, or when any leg (or every leg) has no
    # tariff: a sum over only some of the legs would understate the price of the trip.
    if all_legs_priced and len(currencies) == 1:
        total_tariff = {'value': round(tariff_sum, 2), 'currency': currencies.pop()}
    else:
        total_tariff = None

    result = {
        'is_transfer': True,
        'duration': variant.duration.total_seconds() / 60,
        'transfer_points': [{'title': point.L_title()} for point in variant.transfers],
        'segments': segment_elements,
    }

    set_key(result, 'tariff', total_tariff)

    return result


def _get_station_directions(station):
    """Return the set of external directions that have a direction marker at ``station``.

    The directions' suburban zones are loaded in the same query.
    """
    markers = ExternalDirectionMarker.objects.filter(station=station).select_related(
        'external_direction', 'external_direction__suburban_zone'
    )
    return {marker.external_direction for marker in markers}


def get_segments_subscription_objects(segments):
    """Collect the objects a user can subscribe to for ``segments``.

    See https://st.yandex-team.ru/RASPEXPORT-283.

    For every distinct (station_from, station_to) pair:

    * both stations go to the station set, and their settlements (if any) to the settlement set;
    * the external directions shared by both stations go to the direction set;
    * if this pair shares no direction, every (from-direction, to-direction) combination in
      the same suburban zone is added to the direction-pair set, as a tuple ordered by id.

    :return: tuple of sets ``(stations, settlements, directions, direction_pairs)``.
    """
    station_pairs = {(segment.station_from, segment.station_to) for segment in segments}

    directions_cache = {}

    def directions_of(station):
        # A station may appear in several pairs; query its direction markers only once.
        if station not in directions_cache:
            directions_cache[station] = _get_station_directions(station)
        return directions_cache[station]

    stations_subscr, settlements_subscr, dirs_subscr, dirs_pairs_subscr = set(), set(), set(), set()
    for st_from, st_to in station_pairs:
        stations_subscr.add(st_from)
        stations_subscr.add(st_to)

        for settlement in (st_from.settlement, st_to.settlement):
            if settlement:
                settlements_subscr.add(settlement)

        st_from_dirs = directions_of(st_from)
        st_to_dirs = directions_of(st_to)
        common_dirs = st_from_dirs & st_to_dirs
        dirs_subscr |= common_dirs

        # Decide per pair: if THIS pair shares no direction, build direction pairs from common
        # suburban zones. Checking the accumulated set would make the result depend on the
        # (arbitrary) iteration order of station_pairs.
        if not common_dirs:
            by_zone_from = defaultdict(list)
            for dir_from in st_from_dirs:
                if dir_from.suburban_zone:
                    by_zone_from[dir_from.suburban_zone].append(dir_from)

            for dir_to in st_to_dirs:
                for dir_from in by_zone_from.get(dir_to.suburban_zone, []):
                    dirs_pairs_subscr.add(tuple(sorted((dir_from, dir_to), key=lambda d: d.id)))

    return stations_subscr, settlements_subscr, dirs_subscr, dirs_pairs_subscr


def build_segments_subscription(segments):
    """Build the subscription tags for ``segments``.

    :return: dict with the ESR codes of the stations (``suburban_station``), settlement ids
        as strings (``suburban_city``), and direction keys (``suburban_direction``). The
        direction keys are the ids of shared directions as strings, followed by
        ``get_dir_pair_key`` keys for same-zone direction pairs.
    """
    stations, settlements, directions, direction_pairs = get_segments_subscription_objects(segments)

    direction_keys = [str(direction.id) for direction in directions]
    direction_keys.extend(get_dir_pair_key(first.id, second.id) for first, second in direction_pairs)

    return {
        'suburban_station': [station.get_code('esr') for station in stations],
        'suburban_city': [str(settlement.id) for settlement in settlements],
        'suburban_direction': direction_keys,
    }
