# -*- coding: utf-8 -*-
from __future__ import absolute_import

from geoindex import GeoGridIndex, GeoPoint

from travel.avia.library.python.common.models.geo import (
    Settlement, Station, StationType, Station2Settlement, CityMajority
)


class SubstitutionCache(object):
    """Map arbitrary settlements to settlements that are served by airports.

    ``precache()`` loads every non-hidden settlement that is linked to a
    non-hidden airport, either directly via ``Station.settlement_id`` or via
    ``Station2Settlement``. ``avia_settlement()`` then resolves any
    settlement to one of those "avia" settlements.

    Until ``precache()`` has been called, all lookup structures are empty and
    ``avia_settlement()`` returns ``None`` for settlements it does not know.
    """

    # The grid precision and the query radius are coupled: precision 2 was
    # chosen to support 200 km queries in ``geoindex``. Change them together.
    _GRID_PRECISION = 2
    _SEARCH_RADIUS_KM = 200

    def __init__(self):
        self._index = GeoGridIndex(precision=self._GRID_PRECISION)
        self._region_centers = {}  # region_id -> avia region-capital settlement
        self._capitals = {}  # country_id -> avia country-capital settlement
        self._avia_settlement_ids = set()

    @staticmethod
    def _load_avia_settlements():
        """Return non-hidden settlements linked to non-hidden airports.

        A settlement counts when it is an airport's own ``settlement_id`` or
        is related to an airport through ``Station2Settlement``. The result
        is a tuple ordered by ``id``, which makes the "last one wins" choice
        in ``_pick_by_key`` deterministic.
        """
        airports = tuple(Station.objects.filter(
            station_type_id=StationType.AIRPORT_ID,
            hidden=False,
        ))

        settlement_ids = {airport.settlement_id for airport in airports}
        settlement_ids.update(
            relation.settlement_id
            for relation in Station2Settlement.objects.filter(
                station__in=airports
            )
        )
        # Airports without a settlement contribute None; it can never match
        # an id, so drop it rather than sending it to the database.
        settlement_ids.discard(None)

        return tuple(
            Settlement.objects.filter(
                id__in=sorted(settlement_ids),
                hidden=False,
            ).order_by('id')
        )

    @staticmethod
    def _pick_by_key(settlements, key_attr, majority_id):
        """Map ``getattr(s, key_attr)`` to settlement for the given majority.

        Settlements whose key is falsy are skipped, as are those whose
        ``majority_id`` differs. When several settlements share a key, the
        last one in iteration order wins.
        """
        return {
            getattr(s, key_attr): s
            for s in settlements
            if getattr(s, key_attr) and s.majority_id == majority_id
        }

    def precache(self):
        """(Re)load all lookup structures from the database.

        Everything is built in locals first and assigned at the end. As a
        result, a failure while building leaves the previous cache contents
        intact, and a refresh never exposes a half-filled index.
        """
        avia_settlements = self._load_avia_settlements()

        index = GeoGridIndex(precision=self._GRID_PRECISION)
        for s in avia_settlements:
            # Zero / None coordinates are treated as "unknown" and not indexed.
            if s.latitude and s.longitude:
                index.add_point(
                    GeoPoint(latitude=s.latitude, longitude=s.longitude, ref=s)
                )

        region_centers = self._pick_by_key(
            avia_settlements, 'region_id', CityMajority.REGION_CAPITAL_ID
        )
        capitals = self._pick_by_key(
            avia_settlements, 'country_id', CityMajority.CAPITAL_ID
        )
        avia_settlement_ids = {s.id for s in avia_settlements}

        self._index = index
        self._region_centers = region_centers
        self._capitals = capitals
        self._avia_settlement_ids = avia_settlement_ids

    def _nearest(self, latitude, longitude):
        """Return the closest indexed avia settlement within the radius.

        Returns ``None`` when either coordinate is falsy (zero included,
        mirroring the indexing rule in ``precache``) or when nothing lies
        within ``_SEARCH_RADIUS_KM``.
        """
        if not (latitude and longitude):
            return None

        points_with_dists = tuple(self._index.get_nearest_points(
            GeoPoint(latitude=latitude, longitude=longitude),
            radius=self._SEARCH_RADIUS_KM,
        ))
        if not points_with_dists:
            return None
        # Each item pairs a point with its distance; pick the smallest distance.
        return min(points_with_dists, key=lambda pair: pair[1])[0].ref

    def avia_settlement(self, settlement):
        """Resolve ``settlement`` to an avia settlement.

        The lookup is tried in order:

        1. the settlement itself, if it is an avia settlement;
        2. the nearest avia settlement within the search radius;
        3. the avia region capital of the settlement's region;
        4. the avia capital of the settlement's country.

        Returns ``None`` when none of these apply.
        """
        if settlement.id in self._avia_settlement_ids:
            return settlement
        return (
            self._nearest(settlement.latitude, settlement.longitude) or
            self._region_centers.get(settlement.region_id) or
            self._capitals.get(settlement.country_id)
        )


# Module-level shared instance. Its lookup structures stay empty until
# ``substitution_cache.precache()`` is called.
substitution_cache = SubstitutionCache()
