from dataclasses import dataclass
from typing import List, Optional

from sandbox.projects.Travel.resources import dicts
from sandbox.common.rest import Client as SandboxClient
from sandbox.common.proxy import OAuth

from travel.avia.country_restrictions.lib.types import PointType
from travel.avia.country_restrictions.lib.geo_format_manager import repositories
from travel.avia.library.python.sandbox.resource_fetcher import ResourceFetcher
from travel.library.python.dicts.base_repository import BaseRepository
from travel.library.python.dicts.settlement_repository import SettlementRepository as PKSettlementRepo
from travel.library.python.dicts.region_repository import RegionRepository as PKRegionRepo
from travel.library.python.dicts.country_repository import CountryRepository as PKCountryRepo


@dataclass
class PointTypeRepositories:
    """Pair of repositories that serve a single point type (one prefix in ``GeoFormatManager.routes``)."""

    # Repository whose ``get`` is called with the numeric id taken from a point key.
    point_key_repo: BaseRepository
    # Repository whose ``get`` is called with a geo id.
    geo_id_repo: BaseRepository


class GeoFormatManager:
    """Convert between geo ids and point keys for settlements, regions and countries.

    A point key is a one-character type prefix followed by the numeric id of the
    object in its dictionary (for example ``c`` + settlement id). The dictionaries
    are fetched from Sandbox once, when the manager is constructed, and loaded
    into two repositories per point type: one queried by dictionary id and one
    queried by geo id.
    """

    # One-character prefixes of the supported point key types.
    SETTLEMENT_PREFIX = 'c'
    REGION_PREFIX = 'r'
    COUNTRY_PREFIX = 'l'

    def __init__(self, token: str):
        """Fetch the latest ready dictionaries from Sandbox and load the repositories.

        :param token: Sandbox OAuth token used to fetch the dictionary resources.
        :raises ValueError: if ``token`` is ``None``.
        """
        # Fail fast: validate the token before any repository is built
        # or any resource is fetched.
        if token is None:
            raise ValueError('No sandbox token provided for GeoFormatManager')

        self.geo_id_settlement_repository = repositories.GeoIdSettlementRepository()
        self.point_key_settlement_repository = PKSettlementRepo()

        self.geo_id_region_repository = repositories.GeoIdRegionRepository()
        self.point_key_region_repository = PKRegionRepo()

        self.geo_id_country_repository = repositories.GeoIdCountryRepository()
        self.point_key_country_repository = PKCountryRepo()

        sandbox_oauth = OAuth(token)
        sandbox_client = SandboxClient(auth=sandbox_oauth)
        resource_fetcher = ResourceFetcher(sandbox_client, sandbox_oauth)

        # Every dictionary resource is fetched once and feeds both repositories
        # of its point type (geo-id repository first, then point-key repository).
        dictionaries = (
            (
                dicts.TRAVEL_DICT_RASP_SETTLEMENT_PROD,
                self.geo_id_settlement_repository,
                self.point_key_settlement_repository,
            ),
            (
                dicts.TRAVEL_DICT_RASP_REGION_PROD,
                self.geo_id_region_repository,
                self.point_key_region_repository,
            ),
            (
                dicts.TRAVEL_DICT_RASP_COUNTRY_PROD,
                self.geo_id_country_repository,
                self.point_key_country_repository,
            ),
        )
        for resource_type, geo_id_repo, point_key_repo in dictionaries:
            resource_data = resource_fetcher.fetch_latest_ready(resource_type)
            geo_id_repo.load_from_string(resource_data)
            point_key_repo.load_from_string(resource_data)

        # Insertion order matters: get_point_key_by_geo_id() probes the
        # repositories in this order (settlement, region, country).
        self.routes = {
            self.SETTLEMENT_PREFIX: PointTypeRepositories(
                point_key_repo=self.point_key_settlement_repository,
                geo_id_repo=self.geo_id_settlement_repository,
            ),
            self.REGION_PREFIX: PointTypeRepositories(
                point_key_repo=self.point_key_region_repository,
                geo_id_repo=self.geo_id_region_repository,
            ),
            self.COUNTRY_PREFIX: PointTypeRepositories(
                point_key_repo=self.point_key_country_repository,
                geo_id_repo=self.geo_id_country_repository,
            ),
        }

    def get_point_key_by_geo_id(self, geo_id: int) -> Optional[str]:
        """Return the point key of the first object found for ``geo_id``.

        Repositories are probed in the order of ``self.routes``; ``None`` is
        returned when none of them knows the geo id.
        """
        for prefix, converters in self.routes.items():
            obj = converters.geo_id_repo.get(geo_id)
            if obj is not None:
                return prefix + str(obj.Id)

        return None

    def get_geo_id_by_point_key(self, point_key: str) -> Optional[int]:
        """Return the geo id of the object addressed by ``point_key``.

        :returns: the object's ``GeoId``, or ``None`` when the key is empty, its
            prefix is not a supported point type, or the object is not in the
            dictionary.
        :raises ValueError: if the prefix is supported but the rest of the key
            is not an integer.
        """
        # Check the prefix first: keys of other types (and empty keys) are simply
        # not convertible, so their id part must not be parsed at all.
        converters = self.routes.get(point_key[:1])
        if converters is None:
            return None

        obj = converters.point_key_repo.get(int(point_key[1:]))
        return obj.GeoId if obj is not None else None

    def get_point_key_parents(self, point_key: str) -> List[str]:
        """Return point keys of the direct ancestors of ``point_key``.

        A settlement yields its region and country keys (in that order), a
        region yields its country key, anything else yields an empty list.

        :raises ValueError: if the part of the key after the prefix is not an integer.
        """
        point_type = PointType.from_str(point_key[0])
        obj_id = int(point_key[1:])

        if point_type == PointType.SETTLEMENT:
            obj = self.point_key_settlement_repository.get(obj_id)
            # Some objects may be absent from the avia dictionaries; we still try
            # to serve them, just without any known parents.
            if obj is None:
                return []
            # NOTE(review): RegionId/CountryId are emitted as-is; presumably an
            # unset id would still produce a key here - confirm whether such
            # parents should be filtered out.
            return [
                f'{self.REGION_PREFIX}{obj.RegionId}',
                f'{self.COUNTRY_PREFIX}{obj.CountryId}',
            ]

        if point_type == PointType.REGION:
            obj = self.point_key_region_repository.get(obj_id)
            # Same as above: an unknown region is served without parents.
            if obj is None:
                return []
            return [f'{self.COUNTRY_PREFIX}{obj.CountryId}']

        # Countries, and any other point type, have no parents.
        return []

    def has_point_key_in_hierarchy(self, point_key: str, search_point_key: str) -> bool:
        """Tell whether ``search_point_key`` is ``point_key`` itself or one of its parents."""
        if point_key == search_point_key:
            return True

        return search_point_key in self.get_point_key_parents(point_key)

    def get_dictionary_obj_by_point_key(self, point_key: str):
        """Return the dictionary object addressed by ``point_key``.

        :returns: whatever the matching point-key repository returns for the id,
            or ``None`` when the point type is not settlement, region or country.
        :raises ValueError: if the part of the key after the prefix is not an integer.
        """
        point_type = PointType.from_str(point_key[0])
        obj_id = int(point_key[1:])

        if point_type == PointType.SETTLEMENT:
            return self.point_key_settlement_repository.get(obj_id)
        if point_type == PointType.REGION:
            return self.point_key_region_repository.get(obj_id)
        if point_type == PointType.COUNTRY:
            return self.point_key_country_repository.get(obj_id)

        return None

    def get_point_ru_name_by_point_key(self, point_key: str) -> Optional[str]:
        """Return ``TitleDefault`` of the object addressed by ``point_key``, or ``None`` if it is not found."""
        obj = self.get_dictionary_obj_by_point_key(point_key)
        return obj.TitleDefault if obj is not None else None
