# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging

from django.conf import settings
from django.core.cache import cache
from django.utils.translation import get_language
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response

from common.data_api.baris.instance import baris
from common.data_api.baris.helpers import get_plane_stations_ids
from common.models.geo import Settlement, Station, Station2Settlement
from common.models.transport import TransportType
from travel.rasp.library.python.common23.logging import log_run_time
from route_search.models import ZNodeRoute2
from travel.rasp.morda_backend.morda_backend.search.canonicals.serialization import (
    CanonicalsQuerySchema, CanonicalsResponseSchema
)


# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


def get_railway_points(point_from, point_to, station_pairs):
    """Narrow the search points down to stations when the railway route is unambiguous.

    :param point_from: departure point used when the departure station is ambiguous.
    :param point_to: arrival point used when the arrival station is ambiguous.
    :param station_pairs: iterable of (station_from_id, station_to_id) tuples; it is
        iterated twice, so pass a re-iterable collection.
    :return: tuple (from, to). Each side is the single Station if all pairs share one
        station id on that side, otherwise the corresponding fallback point.
    """
    def _single_station_or(station_ids, fallback):
        # Only a single distinct station is precise enough to replace the point.
        if len(station_ids) != 1:
            return fallback
        (only_id,) = station_ids
        return Station.objects.get(id=only_id)

    from_ids = set(pair[0] for pair in station_pairs)
    to_ids = set(pair[1] for pair in station_pairs)

    return _single_station_or(from_ids, point_from), _single_station_or(to_ids, point_to)


def get_znoderoute_canonicals(point_from, point_to, t_type):
    """Build canonicals for the non-plane transport types found in ZNodeRoute2.

    :param point_from: departure settlement (must expose ``point_key``).
    :param point_to: arrival settlement (must expose ``point_key``).
    :param t_type: transport type code of the current page, or a falsy value. When set,
        that transport type is excluded from the result.
    :return: list of dicts with keys 'point_from', 'point_to' and 'transport_type' (code).
        For train and suburban routes the points are narrowed to stations when unambiguous.
    :raises TransportType.DoesNotExist: if ``t_type`` is not a known transport type code.
    """
    with log_run_time('get_znoderoute_canonicals {} {} {}'.format(point_from.point_key, point_to.point_key, t_type), logger=log):
        data = []

        # Plane canonicals come from BARIS (get_baris_canonicals), so planes are always excluded here.
        # The lookup is on ``t_type_id``, so keep the list homogeneous: ids only, no model instances.
        exclude_t_type_ids = [TransportType.PLANE_ID]
        if t_type:
            exclude_t_type_ids.append(TransportType.objects.get(code=t_type).id)

        znoderoute_rows = list(
            ZNodeRoute2.objects.filter(settlement_from_id=point_from, settlement_to_id=point_to, good_for_start=True, good_for_finish=True)
                .exclude(t_type_id__in=exclude_t_type_ids)
                .values_list('t_type_id', 'station_from_id', 'station_to_id')
                .distinct()
        )

        # t_type_id -> list of (station_from_id, station_to_id) pairs served by that transport type.
        station_pairs_by_t_type = {}
        for t_type_id, station_from_id, station_to_id in znoderoute_rows:
            station_pairs_by_t_type.setdefault(t_type_id, []).append((station_from_id, station_to_id))

        railway_t_type_ids = (TransportType.TRAIN_ID, TransportType.SUBURBAN_ID)
        for t_type_id in set(station_pairs_by_t_type):
            if t_type_id in railway_t_type_ids:
                canonical_point_from, canonical_point_to = get_railway_points(
                    point_from, point_to, station_pairs_by_t_type[t_type_id]
                )
            else:
                canonical_point_from, canonical_point_to = point_from, point_to

            data.append({
                'point_from': canonical_point_from,
                'point_to': canonical_point_to,
                'transport_type': TransportType.objects.get(id=t_type_id).code
            })

    return data


def get_baris_canonicals(point_from, point_to, t_type):
    """Return the plane canonical for the points if BARIS reports flights between them.

    :param point_from: departure point (must expose ``point_key`` and ``time_zone``).
    :param point_to: arrival point (must expose ``point_key`` and ``time_zone``).
    :param t_type: transport type code of the current page; if it is the plane code,
        no plane canonical is suggested.
    :return: an empty list, or a one-element list with a plane canonical dict.
    """
    timer_label = 'get_baris_canonicals {} {} {}'.format(point_from.point_key, point_to.point_key, t_type)
    with log_run_time(timer_label, logger=log):
        plane_code = TransportType.get_plane_type().code
        if t_type == plane_code:
            return []

        from_station_ids = get_plane_stations_ids(point_from)
        to_station_ids = get_plane_stations_ids(point_to)
        search_result = baris.get_p2p_all_days_search(
            get_language(), from_station_ids, to_station_ids,
            point_from.time_zone, point_to.time_zone
        )
        if not search_result.flights:
            return []

        return [{
            'point_from': point_from,
            'point_to': point_to,
            'transport_type': plane_code
        }]


def get_canonicals(point_from, point_to, t_type):
    """Collect canonicals from ZNodeRoute2 and BARIS for the given points.

    When ``t_type`` is set and more than one canonical was found, an extra entry with
    ``transport_type`` None (the same points, no transport filter) is appended.

    :return: list of dicts with keys 'point_from', 'point_to' and 'transport_type'.
    """
    result = list(get_znoderoute_canonicals(point_from, point_to, t_type))
    result.extend(get_baris_canonicals(point_from, point_to, t_type))

    if t_type and len(result) > 1:
        result.append({
            'point_from': point_from,
            'point_to': point_to,
            'transport_type': None
        })

    return result


def _get_search_settlement(point, station_2_settlement):
    """Return the settlement a search point belongs to, or None if it cannot be determined.

    :param point: Settlement or Station parsed from the query string.
    :param station_2_settlement: mapping station id -> Settlement built from Station2Settlement.
    """
    if isinstance(point, Settlement):
        return point
    if point.settlement is not None:
        return point.settlement
    return station_2_settlement.get(point.id)


@api_view(['GET'])
def canonicals(request):
    """
    https://st.yandex-team.ru/RASPFRONT-7974
    List of data for canonical pages of other transport types.
    /ru/search/canonicals/?pointFrom=c65&pointTo=c197&transportType=bus
    """
    context, errors = CanonicalsQuerySchema().load(request.GET)
    if errors:
        return Response({'result': {}, 'errors': errors}, status=status.HTTP_400_BAD_REQUEST)

    # Stations without their own ``settlement`` may still be bound to one via Station2Settlement.
    # Query by the *requested* points and keep Settlement objects (not ids): the code below
    # needs ``point_key`` and model equality.
    query_stations = [point for point in [context.point_from, context.point_to] if isinstance(point, Station)]
    station_2_settlement = {}
    if query_stations:
        station_2_settlement = {
            s_2_s.station_id: s_2_s.settlement
            for s_2_s in Station2Settlement.objects.filter(station__in=query_stations).select_related('settlement')
        }

    point_from = _get_search_settlement(context.point_from, station_2_settlement)
    point_to = _get_search_settlement(context.point_to, station_2_settlement)

    if not (point_from and point_to) or point_from == point_to:
        data = []
    else:
        cache_key = settings.CACHEROOT + '/search/canonicals' + point_from.point_key + point_to.point_key + str(context.transport_type)
        data = cache.get(cache_key)

        # An empty list is a valid cached result; only a cache miss (None) triggers recomputation.
        if data is None:
            try:
                data = get_canonicals(point_from, point_to, context.transport_type)
                cache.set(cache_key, data, settings.CACHES['default']['LONG_TIMEOUT'])
            except Exception as ex:
                log.exception('canonicals failed for %s %s %s',
                              point_from.point_key, point_to.point_key, context.transport_type)
                # BaseException.message is deprecated since Python 2.6 and removed in Python 3.
                return Response({'result': {}, 'errors': ['{}'.format(ex)]}, status=status.HTTP_400_BAD_REQUEST)

    data = [
        {
            'canonical': canonical,
            'point_from': canonical['point_from'],
            'point_to': canonical['point_to']
        }
        for canonical in data
    ]
    response_data, errors = CanonicalsResponseSchema().dump({'canonicals': data})

    return Response({'result': response_data, 'errors': errors}, status=status.HTTP_200_OK)
