# -*- encoding: utf-8 -*-
from datetime import datetime
from typing import Optional, Dict

import pytz

from travel.avia.api_gateway.application.cache.cache_root import CacheRoot
from travel.avia.api_gateway.lib.landings.templater import get_point_title_nominative

# Moscow timezone; used in this module as the fallback when a user's timezone
# cannot be resolved from the geo caches.
MSK_TZ = pytz.timezone('Europe/Moscow')


class PersonalSearchMapper:
    """Maps raw personal-search data to the ``{'suggests': [...]}`` response.

    Geo points referenced by the input are resolved through the caches exposed
    by the ``cache_root`` object passed to the constructor.
    """

    def __init__(self, cache_root):
        # type: (CacheRoot) -> None
        self._cache_root = cache_root

    def map(self, geo_id, personal_search_data):
        # type: (Optional, Dict) -> dict
        """
        Build the suggests response from personal-search data.

        :param geo_id: geo id of the user; used to pick the timezone in which
            "today" is evaluated
        :param personal_search_data: dict whose 'result' key holds the raw suggests
        :return: dict with a single 'suggests' key (empty list when there is no data)
        """
        if not personal_search_data or not personal_search_data.get('result'):
            return {'suggests': []}

        user_tz = self.pytz_by_geo_id(geo_id)
        today = datetime.now(user_tz).strftime('%Y-%m-%d')

        suggests = [self._build_suggest(s, today) for s in personal_search_data['result']]
        return {'suggests': suggests}

    def _build_suggest(self, suggest_data, today):
        """
        Convert one raw suggest into its response form.

        Both dates are dropped when the start date is missing or earlier than today.

        :param suggest_data: raw suggest dict
        :param today: current date of the user formatted as '%Y-%m-%d'
        :return: dict describing the suggest
        """
        start_date = suggest_data['when']
        end_date = suggest_data.get('returnDate') or None
        # Dates are compared as strings, which is chronological only for the
        # '%Y-%m-%d' format -- NOTE(review): confirm 'when' always uses it.
        # The emptiness check keeps a missing date from raising on Python 3.
        if not start_date or start_date < today:
            start_date = None
            end_date = None
        return {
            'from': self._map_geo_point(suggest_data['from']),
            'to': self._map_geo_point(suggest_data['to']),
            'startDate': start_date,
            'endDate': end_date,
            'travelers': suggest_data['travelers'],
            'aviaClass': suggest_data['aviaClass'],
        }

    def pytz_by_geo_id(self, geo_id, default=MSK_TZ):
        """
        Return the timezone for the given geo_id.

        Falls back to ``default`` when the timezone cannot be determined.

        :param geo_id: id of the user's geolocation
        :param default: timezone to return when none can be resolved from geo_id
        :return: pytz timezone object
        """
        if geo_id:
            settlement = self._cache_root.settlement_cache.get_settlement_by_geo_id(geo_id)
            if settlement:
                return self._get_timezone_by_settlement(settlement, default)
        return default

    def _get_timezone_by_settlement(self, settlement, default=MSK_TZ):
        """
        Resolve a settlement's timezone, trying the settlement first and then its region.

        :param settlement: settlement object from the settlement cache
        :param default: timezone to return when neither lookup succeeds
        :return: pytz timezone object
        """
        timezone_cache = self._cache_root.timezone_cache
        tz = timezone_cache.get_timezone_by_id(settlement.TimeZoneId)
        if tz:
            return pytz.timezone(tz.Code)
        region = self._cache_root.region_cache.get_region_by_id(settlement.RegionId)
        if region:
            tz = timezone_cache.get_timezone_by_id(region.TimeZoneId)
            if tz:
                return pytz.timezone(tz.Code)
        # Honour the caller-supplied fallback instead of a hard-coded timezone.
        return default

    @staticmethod
    def _map_point_key_to_type(point_key):
        """
        Derive the point type from the first character of a point key.

        :param point_key: key such as 'c213' ('c' settlement, 's' station, 'l' country)
        :return: PointType constant, or None for an empty key or unknown prefix
        """
        if not point_key:
            return None
        type_char = point_key[0]
        if type_char == 'c':
            return PointType.SETTLEMENT
        if type_char == 's':
            return PointType.STATION
        if type_char == 'l':
            return PointType.COUNTRY
        return None

    def _map_geo_point(self, geo_point_data):
        """
        Resolve a raw geo point reference into its response form.

        :param geo_point_data: dict with a 'pointKey' (or 'pointCode') entry
        :return: dict describing the point, or an empty dict when the key is
            missing, malformed or not present in the caches
        """
        point_key = geo_point_data.get('pointKey', geo_point_data.get('pointCode'))
        point_type = self._map_point_key_to_type(point_key)
        if point_type is None:
            return {}
        try:
            point_id = int(point_key[1:])
        except ValueError:
            # A non-numeric id is treated the same way as an unknown point.
            return {}
        geo_point = self._get_geo_point(point_type, point_id)
        if not geo_point:
            return {}

        return {
            'type': point_type,
            'title': get_point_title_nominative(geo_point),
            'pointCode': self._get_point_code(point_type, geo_point),
            'pointKey': point_key,
            'countryTitle': self._get_country_title(point_type, geo_point),
            'cityTitle': self._get_settlement_title(point_type, geo_point),
            'regionTitle': self._get_region_title(point_type, geo_point),
        }

    def _get_geo_point(self, point_type, point_id):
        """Fetch the cached object for the given point type and id (None for an unknown type)."""
        if point_type == PointType.SETTLEMENT:
            return self._cache_root.settlement_cache.get_settlement_by_id(point_id)
        if point_type == PointType.STATION:
            return self._cache_root.station_cache.get_station_by_id(point_id)
        if point_type == PointType.COUNTRY:
            return self._cache_root.country_cache.get_country_by_id(point_id)
        return None

    def _get_country_title(self, point_type, geo_point):
        """Return the title of the point's country, or '' when it cannot be obtained."""
        if point_type == PointType.COUNTRY:
            return get_point_title_nominative(geo_point)
        # Best effort: any failure while resolving the country yields an empty
        # title. ``Exception`` (rather than a bare ``except``) lets
        # KeyboardInterrupt and SystemExit propagate.
        try:
            country = self._cache_root.country_cache.get_country_by_id(geo_point.CountryId)
            return get_point_title_nominative(country)
        except Exception:
            return ''

    def _get_point_code(self, point_type, geo_point):
        """Return the code of the point; which code is used depends on the point type."""
        if point_type == PointType.COUNTRY:
            return geo_point.Code.lower()
        if point_type == PointType.SETTLEMENT:
            return geo_point.Iata or geo_point.SirenaId
        return self._cache_root.station_code_cache.get_station_code_by_id(geo_point.Id)

    def _get_settlement_title(self, point_type, geo_point):
        """Return the settlement title for a settlement or station point, '' otherwise."""
        if point_type == PointType.SETTLEMENT:
            return get_point_title_nominative(geo_point)
        if point_type == PointType.STATION:
            settlement = self._get_settlement_by_station(geo_point)
            if not settlement:
                return ''
            return get_point_title_nominative(settlement)
        return ''

    def _get_settlement_by_station(self, station):
        """
        Find the settlement a station belongs to.

        The station's own SettlementId is tried first, then the
        station-to-settlement cache.

        :param station: station object from the station cache
        :return: settlement object, or None when nothing is found
        """
        settlement_cache = self._cache_root.settlement_cache
        if station.SettlementId:
            settlement = settlement_cache.get_settlement_by_id(station.SettlementId)
            if settlement:
                return settlement

        settlement_id = self._cache_root.station_to_settlement_cache.get_settlement_id_by_station_id(station.Id)
        if settlement_id:
            # Look up the id found via the fallback cache, not station.SettlementId,
            # which was already tried (or empty) above.
            return settlement_cache.get_settlement_by_id(settlement_id)

        return None

    def _get_region_title(self, point_type, geo_point):
        """Return the title of the point's region, or '' for countries and unknown regions."""
        if point_type == PointType.COUNTRY:
            return ''
        region = self._cache_root.region_cache.get_region_by_id(geo_point.RegionId)
        if not region:
            return ''
        return get_point_title_nominative(region)


class PointType:
    """Integer codes for the kinds of geo points."""

    # These values are placed as-is into the 'type' field of a mapped geo point.
    SETTLEMENT, STATION, COUNTRY = 0, 1, 2
