import json
import logging
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Set, Type

from google.protobuf.json_format import MessageToJson

# noinspection PyUnresolvedReferences
from yt.wrapper import YPath, write_table

from travel.library.python.tools import replace_args_from_env
from travel.hotels.lib.python3.cli.cli import create_progress_bar, auto_progress_reporter
from travel.hotels.lib.python3.yt import ytlib
from travel.hotels.lib.python3.yt.ytlib import ensure_table_exists, schema_from_dict, link
from travel.hotels.lib.python3.yt.versioned_path import VersionedPath
# noinspection PyUnresolvedReferences
from travel.hotels.proto.region_pages.region_pages_pb2 import EPageType
from travel.hotels.tools.region_pages_builder.common.tanker_data import (
    ConfigRegionFilter, RegionNameData, read_templates
)
from travel.hotels.tools.region_pages_builder.renderer.renderer.exceptions import RenderExceptionContainer
from travel.hotels.tools.region_pages_builder.renderer.renderer.templater import (
    Hotel, Station, Region, RegionKey, RegionData, Templater
)
from travel.hotels.tools.region_pages_builder.renderer.renderer.templating_mapper import TemplatingMapper


# touch here 0


# Log line layout: timestamp | level | logger name | message (fixed-width columns).
FORMAT = ' | '.join((
    '%(asctime)-15s',
    '%(levelname)-4.4s',
    '%(name)-12.12s',
    '%(message)s',
))
# Configure root logging at import time, and only let WARNING+ through from the
# logger named 'transitions'.
logging.basicConfig(format=FORMAT, level=logging.INFO)
logging.getLogger('transitions').setLevel(logging.WARNING)


@dataclass
class PageConfig:
    """Everything needed to render pages of one kind (e.g. 'state' or 'city') for a set of regions."""

    # Page kind name, e.g. 'state' or 'city'.
    page_name: str
    # Page type passed to `mapper.render_page` together with the region data.
    page_type: EPageType
    # Renders a RegionData into a page message.
    mapper: TemplatingMapper
    # Regions linked from each region's page, keyed by (geo_id, filter_slug).
    cross_links: Dict[RegionKey, List[Region]]
    # Filter config looked up by a region's filter slug.
    filters: Dict[str, ConfigRegionFilter]


# One row of the data miner's `regions` YT table, as yielded by the YT client.
RawRegion = Dict[str, Any]


# The RegionData class itself (not an instance of it).
RegionDataType = Type[RegionData]


class Runner:
    """Render static region pages from data-miner YT tables and store them in YT.

    Command-line options (optionally overridden from the environment) are parsed in
    ``__init__``; ``run`` performs the whole pipeline.
    """

    NAME = 'CityPagesRenderer'
    LANG = 'ru'
    # Regions without a category from the filter config and with at most this many
    # hotels get the 'not-much-hotels' category.
    FEW_HOTELS_THRESHOLD = 10

    def __init__(self):
        parser = ArgumentParser()
        parser.add_argument('--yt-proxy', default='hahn')
        parser.add_argument('--yt-token')
        parser.add_argument('--yt-token-path')
        parser.add_argument('--yt-data-miner-path', default=ytlib.get_default_user_path('region_pages/data_miner/latest'))
        parser.add_argument('--yt-output-path', default=ytlib.get_default_user_path('region_pages/rendered'))
        parser.add_argument('--tanker-url', default='https://tanker-api.yandex-team.ru')
        parser.add_argument('--tanker-project', default='travel-backend')
        parser.add_argument('--config-keyset', default=None)
        parser.add_argument('--config-yaml-file', default=None, type=Path)
        parser.add_argument('--config-yt-table', default=None, type=YPath)
        parser.add_argument('--dictionaries-keyset', default=None)
        parser.add_argument('--dictionaries-yaml-file', default=None, type=Path)
        parser.add_argument('--dictionaries-yt-table', default=None, type=YPath)
        parser.add_argument('--city-page-keyset', default=None)
        parser.add_argument('--city-page-yaml-file', default=None, type=Path)
        parser.add_argument('--city-page-yt-table', default=None, type=YPath)
        parser.add_argument('--state-page-keyset', default=None)
        parser.add_argument('--state-page-yaml-file', default=None, type=Path)
        parser.add_argument('--state-page-yt-table', default=None, type=YPath)
        parser.add_argument('--render-filtered', action='store_true')
        parser.add_argument('--only-geo-id', type=int)  # For debug purposes

        self.args = parser.parse_args(replace_args_from_env())
        yt_config = {
            'token': self.args.yt_token,
            'token_path': self.args.yt_token_path,
        }
        self.yt_client = ytlib.create_client(proxy=self.args.yt_proxy, config=yt_config)

        # Resolve once so every table path derived from it is absolute.
        self.args.yt_data_miner_path = ytlib.abspath(self.args.yt_data_miner_path, self.yt_client)
        self.work_dir = self.args.yt_output_path
        self.temp_dir = ytlib.join(self.work_dir, 'temp')
        # Per-region render failures are collected here; run() calls check_and_raise()
        # on it before writing any result.
        self.exception_container = RenderExceptionContainer()

    def run(self):
        """Render all region pages and write them to a `result` table inside a VersionedPath of the output dir.

        Templates are dumped under `debug/templates` of the same versioned path.
        `exception_container.check_and_raise()` is called before the result table is written.
        """
        logging.info("Program started")
        ytlib.ensure_dir(self.yt_client, self.temp_dir)
        with VersionedPath(ytlib.join(self.work_dir), yt_client=self.yt_client) as p:
            data_miner_path = self.args.yt_data_miner_path
            templates_dump_path = ytlib.join(p, 'debug', 'templates')

            raw_regions = self.get_raw_regions(ytlib.join(data_miner_path, 'regions'))
            hotels = self.get_hotels(ytlib.join(data_miner_path, 'hotels'))
            stations = self.get_stations(ytlib.join(data_miner_path, 'stations'))

            rendered = self._get_rendered_pages(
                raw_regions=raw_regions,
                templates_dump_path=templates_dump_path,
                hotels=hotels,
                stations=stations,
            )

            self.exception_container.check_and_raise()

            result_table = ytlib.join(p, 'result')
            ensure_table_exists(result_table, yt_client=self.yt_client, schema=schema_from_dict({
                'geo_id': 'int32',
                'filter_slug': 'string',
                'slug': 'string',
                'page_type': 'string',
                'proto': 'string',
                'json': 'any',
            }))
            write_table(result_table, rendered, client=self.yt_client)
            logging.info(f"Static pages rendered and saved to {link(result_table)}")

    def _read_yt_table(self, table_path, name):
        """Yield every row of `table_path`, reporting progress under `name`."""
        rows = self.yt_client.read_table(table_path)
        yield from auto_progress_reporter(rows, name=name, total=self.yt_client.row_count(table_path))

    def get_raw_regions(self, regions_table: YPath) -> List[RawRegion]:
        """Read region rows, keeping only those with a non-empty `top_permalinks`."""
        regions: List[RawRegion] = []
        skipped_count = 0

        for row in self._read_yt_table(regions_table, 'regions'):
            if row['top_permalinks']:
                regions.append(row)
            else:
                skipped_count += 1
        logging.info(f'Regions loaded {len(regions)}, skipped due to zero hotels {skipped_count}')
        return regions

    def get_regions(
        self,
        raw_regions: List[RawRegion],
        filter_config: Iterable[ConfigRegionFilter],
        region_name_data: RegionNameData,
    ) -> Dict[RegionKey, Region]:
        """Build Region objects keyed by RegionKey(geo_id, filter_slug).

        Rows carrying a `filter_slug` are skipped unless --render-filtered is set.
        The category is taken from `filter_config` by filter slug; if none is found and
        the region has at most FEW_HOTELS_THRESHOLD hotels, it becomes 'not-much-hotels'.
        """
        filter_categories = {item.filterSlug: item.category for item in filter_config}
        regions: Dict[RegionKey, Region] = {}
        for raw_region in raw_regions:
            filter_slug = raw_region.get('filter_slug')
            if filter_slug is not None and not self.args.render_filtered:
                continue
            category = filter_categories.get(filter_slug)
            if category is None and raw_region['hotel_count'] <= self.FEW_HOTELS_THRESHOLD:
                category = 'not-much-hotels'
            region_name_declension = region_name_data.get_name(raw_region['slug'])
            region = Region(raw_region, category, self.LANG, region_name_declension)
            regions[RegionKey(region.geo_id, region.filter_slug)] = region
        return regions

    @staticmethod
    def get_available_filters(region_keys: Iterable[RegionKey]) -> Dict[int, Set[str]]:
        """Group the filter slugs present in `region_keys` by geo_id."""
        available_filters: Dict[int, Set[str]] = {}
        for key in region_keys:
            available_filters.setdefault(key.geo_id, set()).add(key.filter_slug)
        return available_filters

    def get_hotels(self, hotels_table: YPath) -> Dict[int, Hotel]:
        """Load hotels keyed by permalink."""
        hotels = {row['permalink']: Hotel(row) for row in self._read_yt_table(hotels_table, 'hotels')}
        logging.info(f'Got {len(hotels)} hotels')
        return hotels

    def get_stations(self, stations_table: YPath) -> Dict[int, Station]:
        """Load stations keyed by station id."""
        stations = {row['id']: Station(row) for row in self._read_yt_table(stations_table, 'stations')}
        logging.info(f'Got {len(stations)} stations')
        return stations

    def get_cross_links_dict(
        self,
        cross_links_table: YPath,
        geo_id_to_region: Dict[RegionKey, Region],
    ) -> Dict[RegionKey, List[Region]]:
        """Map each rendered region to the Region objects it links to.

        Each row has `geo_id`, optional `filter_slug` and `links` (linked geo ids sharing
        that filter slug). Rows whose own region is not in `geo_id_to_region` are ignored,
        since no page is rendered for them. Links pointing at regions that were dropped
        earlier (e.g. no hotels, or filtered without --render-filtered) are skipped with a
        warning instead of aborting the whole run with a KeyError.
        """
        result: Dict[RegionKey, List[Region]] = {}
        missing_links = 0
        for row in self._read_yt_table(cross_links_table, 'cross-links'):
            filter_slug = row.get('filter_slug')
            region_key = RegionKey(row['geo_id'], filter_slug)
            if region_key not in geo_id_to_region:
                continue

            linked_regions: List[Region] = []
            for linked_geo_id in row['links']:
                linked_region = geo_id_to_region.get(RegionKey(linked_geo_id, filter_slug))
                if linked_region is None:
                    missing_links += 1
                    continue
                linked_regions.append(linked_region)
            result[region_key] = linked_regions

        if missing_links:
            logging.warning(f'Skipped {missing_links} cross links to regions that are not rendered ({cross_links_table})')
        return result

    def _read_page_templates(self, page_name: str, templates_dump_path: YPath, keyset, yaml_file, yt_table):
        """Call read_templates for `page_name`, with its save-state table at `<templates_dump_path>/<page_name>`."""
        return read_templates(
            page_name=page_name,
            yt_client=self.yt_client,
            save_state_table=ytlib.join(templates_dump_path, page_name),
            tanker_url=self.args.tanker_url,
            tanker_project=self.args.tanker_project,
            templates_keyset=keyset,
            templates_yaml_file=yaml_file,
            templates_yt_table=yt_table,
        )

    def _build_page_config(
        self,
        page_name: str,
        page_type: EPageType,
        templates,
        config,
        available_filters: Dict[int, Set[str]],
        cross_links_table_name: str,
        regions: Dict[RegionKey, Region],
        filters: Dict[str, ConfigRegionFilter],
    ) -> PageConfig:
        """Assemble a PageConfig whose cross links come from `cross_links_table_name` in the data-miner dir."""
        cross_links_table = ytlib.join(self.args.yt_data_miner_path, cross_links_table_name)
        return PageConfig(
            page_name=page_name,
            page_type=page_type,
            mapper=TemplatingMapper(Templater(), templates, config, available_filters),
            cross_links=self.get_cross_links_dict(cross_links_table, regions),
            filters=filters,
        )

    def _get_rendered_pages(
        self,
        raw_regions: List[RawRegion],
        templates_dump_path: YPath,
        hotels: Dict[int, Hotel],
        stations: Dict[int, Station],
    ) -> List[Dict[str, Any]]:
        """Load config, dictionaries and page templates, then render every region.

        Page kinds whose templates are not available (falsy result of read_templates)
        get no PageConfig, so regions of that type are not rendered.
        """
        args = self.args

        config = self._read_page_templates(
            'config', templates_dump_path, args.config_keyset, args.config_yaml_file, args.config_yt_table,
        )
        filters = config.get_config_region_filters()

        dictionaries = self._read_page_templates(
            'dictionaries', templates_dump_path,
            args.dictionaries_keyset, args.dictionaries_yaml_file, args.dictionaries_yt_table,
        )
        region_name_data = dictionaries.get_region_name_data() if dictionaries else RegionNameData()

        regions = self.get_regions(raw_regions, filters.values(), region_name_data)
        available_filters = self.get_available_filters(regions.keys())

        config_map: Dict[str, PageConfig] = {}

        state_page_templates = self._read_page_templates(
            'state', templates_dump_path, args.state_page_keyset, args.state_page_yaml_file, args.state_page_yt_table,
        )
        if state_page_templates:
            # NOTE(review): state pages are rendered with PT_City, same as city pages;
            # confirm a dedicated state page type is not intended.
            config_map['state'] = self._build_page_config(
                page_name='state',
                page_type=EPageType.PT_City,
                templates=state_page_templates,
                config=config,
                available_filters=available_filters,
                cross_links_table_name='state_to_city_links',
                regions=regions,
                filters=filters,
            )

        city_page_templates = self._read_page_templates(
            'city', templates_dump_path, args.city_page_keyset, args.city_page_yaml_file, args.city_page_yt_table,
        )
        if city_page_templates:
            city_page_config = self._build_page_config(
                page_name='city',
                page_type=EPageType.PT_City,
                templates=city_page_templates,
                config=config,
                available_filters=available_filters,
                cross_links_table_name='city_cross_links',
                regions=regions,
                filters=filters,
            )
            config_map['city'] = city_page_config
            config_map['other'] = city_page_config  # use city page config for "other" region type

        return self._get_rendered_pages_by_config(
            regions=regions,
            hotels=hotels,
            stations=stations,
            config_map=config_map,
        )

    def _render_region(
        self,
        region_key: RegionKey,
        region: Region,
        hotels: Dict[int, Hotel],
        stations: Dict[int, Station],
        config_map: Dict[str, PageConfig],
    ) -> Optional[Dict[str, Any]]:
        """Render one region into a result-table row.

        Returns None when the region is skipped (excluded by --only-geo-id, or no
        PageConfig for its region type) or when rendering failed. Failures are recorded
        in `self.exception_container` rather than raised.
        """
        only_geo_id = self.args.only_geo_id
        if only_geo_id is not None and region_key.geo_id != only_geo_id:
            return None
        page_config = config_map.get(region.region_type)
        if page_config is None:
            return None
        try:
            filter_config = page_config.filters.get(region.filter_slug)
            cross_links = page_config.cross_links.get(region_key, [])
            if not cross_links:
                logging.warning(f'No cross links for {region_key}')
            region_data = RegionData(region, hotels, stations, cross_links, filter_config)
            page = page_config.mapper.render_page(region_data, page_config.page_type)
            return {
                'geo_id': region_key.geo_id,
                'filter_slug': region_key.filter_slug,
                'slug': region.slug,  # For sitemap_builder, and for convenience
                'page_type': region.region_type,
                'proto': page.SerializeToString(),
                'json': json.loads(MessageToJson(page)),
            }
        except Exception as e:
            # Collect instead of failing fast so one run reports every broken region.
            message = f'{region.nominative}({region_key=})'
            self.exception_container.add_exception(e, exception_metadata=message)
            return None

    def _get_rendered_pages_by_config(
        self,
        regions: Dict[RegionKey, Region],
        hotels: Dict[int, Hotel],
        stations: Dict[int, Station],
        config_map: Dict[str, PageConfig],
    ) -> List[Dict[str, Any]]:
        """Render every region with the PageConfig for its region type; return the result rows."""
        rendered: List[Dict[str, Any]] = []

        # TODO maybe turn it into YT map?
        with create_progress_bar('regions processed', len(regions)) as bar:
            # Every region counts toward progress, including skipped ones, so the bar
            # actually reaches its total of len(regions).
            for processed_count, (region_key, region) in enumerate(regions.items(), start=1):
                row = self._render_region(region_key, region, hotels, stations, config_map)
                if row is not None:
                    rendered.append(row)
                bar.update(processed_count)

        return rendered
