import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.functions import col

from itertools import chain

from enum import Enum

from sprav.protos import signal_pb2


# Rubric ids that are treated as hotel rubrics when choosing `hotel_rubric`.
HOTEL_RUBRICS = [
    # Rubric-Id   Rubric Permalink    Name (translated)   Present on the portal
    30785,       # 184106414          Hotel               +
    30655,       # 184106316          Dormitory           -
    31632,       # 20699506347        Hostel              +
    30788,       # 184106420          Camping             +
    3501492236,  # 150049871970       Apartments          -
    31309,       # 255921949          Farm stay           +
    30781,       # 184106404          Sanatorium          +
    30779,       # 184106400          Holiday home        +
    30791,       # 184106426          Tourist camp        +
    3501708107,  # 197061821387       Daily rentals       -
]


class InputTables(Enum):
    """Source tables read by ``AltayMappingBuilder``.

    Each member's value is the table path that is handed to ``spark.read.yt``.
    """

    # Altay current-state snapshot tables.
    ALTAY_PROVIDER = '//home/altay/db/export/current-state/snapshot/provider'
    ALTAY_COMPANY_TO_PROVIDER = '//home/altay/db/export/current-state/snapshot/company_to_provider'
    ALTAY_COMPANY = '//home/altay/db/export/current-state/snapshot/company'
    ALTAY_COMPANY_TO_DUPLICATE = '//home/altay/db/export/current-state/snapshot/company_to_duplicate'
    ALTAY_COMPANY_TO_NAME = '//home/altay/db/export/current-state/snapshot/company_to_name'
    ALTAY_COMPANY_TO_RUBRIC = '//home/altay/db/export/current-state/snapshot/company_to_rubric'
    ALTAY_COMPANY_TO_FEATURE = '//home/altay/db/export/current-state/snapshot/company_to_feature'
    ALTAY_SIGNALS_FOR_MERGE = '//home/altay/db/permalink-clusterization/signals-for-merge'
    # Tables owned by other systems.
    GEOBASE_REGIONS = '//home/geotargeting/public/geobase/regions'
    SPRAV_POPULARITY = '//home/sprav/assay/common/Popularity'
    TRAVEL_PARTNERS = '//home/travel/prod/config/partners'


class ResultTables(Enum):
    """Identifiers of the tables written by ``AltayMappingBuilder``.

    The values are plain ordinals; the actual path of each table is derived
    from the builder's ``results_path`` in ``AltayMappingBuilder._get_table_path``.
    """

    ALTAY_HOTEL_MAPPINGS = 0
    HOTELS_PERMALINKS_ALL = 1
    HOTELS_PERMALINKS_PUBLISHED = 2
    PERMALINK_TO_CLUSTER_PERMALINK = 3
    PERMALINK_TO_CLUSTER_PERMALINK_FILTERED = 4
    PERMALINK_TO_PARTNERID_ORIGINALID = 5
    PARTNERID_ORIGINALID_TO_CLUSTER_PERMALINK = 6
    PERMALINK_TO_HOTEL_INFO = 7


class AltayMappingBuilder:
    def __init__(self, spark, results_path):
        """Store the collaborators used by every build stage.

        :param spark: session object; its ``read`` attribute is used to load tables.
        :param results_path: path prefix under which the result tables are written.
        """
        self.spark, self.results_path = spark, results_path

    def build(self):
        """Run all pipeline stages one after another.

        The order matters: ``_prepare`` fills ``self.providers`` used by the
        following stages, and later stages read result tables that earlier
        stages have written.
        """
        stages = (
            self._prepare,
            self._altay_hotel_mappings,
            self._hotels_permalinks,
            self._permalink_to_cluster_permalink,
            self._permalink_to_partnerid_originalid,
            self._permalink_to_hotel_info,
        )
        for run_stage in stages:
            run_stage()

    def _get_table_path(self, table):
        """Resolve a table identifier to the path of that table.

        :param table: a member of ``InputTables`` or ``ResultTables``.
        :return: the member's value for ``InputTables``; a path below
            ``self.results_path`` for ``ResultTables``.
        :raises ValueError: if ``table`` belongs to neither enum.
        """
        if isinstance(table, InputTables):
            return table.value

        if isinstance(table, ResultTables):
            # Location of every result table relative to ``self.results_path``;
            # auxiliary tables live in the ``extras`` sub-directory.
            relative_paths = {
                ResultTables.ALTAY_HOTEL_MAPPINGS: 'altay_hotel_mappings',
                ResultTables.HOTELS_PERMALINKS_ALL: 'extras/hotels_permalinks_all',
                ResultTables.HOTELS_PERMALINKS_PUBLISHED: 'hotels_permalinks_published',
                ResultTables.PERMALINK_TO_CLUSTER_PERMALINK: 'permalink_to_cluster_permalink',
                ResultTables.PERMALINK_TO_CLUSTER_PERMALINK_FILTERED: 'extras/permalink_to_cluster_permalink_filtered',
                ResultTables.PERMALINK_TO_PARTNERID_ORIGINALID: 'permalink_to_partnerid_originalid',
                ResultTables.PARTNERID_ORIGINALID_TO_CLUSTER_PERMALINK: 'partnerid_originalid_to_cluster_permalink',
                ResultTables.PERMALINK_TO_HOTEL_INFO: 'permalink_to_hotel_info',
            }
            return f'{self.results_path}/{relative_paths[table]}'

        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError(f'Unknown table identificator: {table}')

    def _prepare(self):
        """Collect the permalinks of all ``ytravel*`` providers.

        The result is stored as a plain Python list in ``self.providers`` and is
        used by later stages to filter and pivot provider links.
        """
        provider_table = self._get_table_path(InputTables.ALTAY_PROVIDER)
        travel_provider_rows = (
            self.spark.read
            .yt(provider_table)
            .select("permalink")
            .where("permalink LIKE 'ytravel%'")
            .collect()
        )

        self.providers = [row.permalink for row in travel_provider_rows]

    def _altay_hotel_mappings(self):
        """Write ``ALTAY_HOTEL_MAPPINGS``: one row per permalink, one column per provider.

        Each provider column holds the list of original ids that the provider
        has for the permalink, or NULL when there are none.
        """
        travel_links = (
            self.spark.read
            .yt(self._get_table_path(InputTables.ALTAY_COMPANY_TO_PROVIDER))
            .select(
                col('company_permalink').alias('permalink'),
                'provider_permalink',
                'original_id'
            )
            .where(col('provider_permalink').isin(self.providers))
        )

        # Without the `when`, providers with no links would get an empty list
        # instead of NULL in their pivoted column.
        original_ids_or_null = F.when(F.count('original_id') > 0, F.collect_list('original_id'))

        mappings = (
            travel_links
            .groupby('permalink')
            .pivot('provider_permalink', self.providers)
            .agg(original_ids_or_null)
        )

        target_path = self._get_table_path(ResultTables.ALTAY_HOTEL_MAPPINGS)
        mappings.write.mode("overwrite").optimize_for("scan").yt(target_path)

    def _hotels_permalinks(self):
        """Write the lists of hotel permalinks.

        ``HOTELS_PERMALINKS_ALL`` contains every permalink linked to one of
        ``self.providers``; ``HOTELS_PERMALINKS_PUBLISHED`` keeps only those
        whose company row is exported and has publishing status ``publish``.
        """
        linked_to_travel_provider = col('provider_permalink').isin(self.providers)
        all_hotels = (
            self.spark.read
            .yt(self._get_table_path(InputTables.ALTAY_COMPANY_TO_PROVIDER))
            .where(linked_to_travel_provider)
            .select(col('company_permalink').alias('permalink'))
            .distinct()
            .sort('permalink')
        )
        all_hotels_path = self._get_table_path(ResultTables.HOTELS_PERMALINKS_ALL)
        all_hotels.write.mode("overwrite").optimize_for("scan").yt(all_hotels_path)

        companies = self.spark.read.yt(self._get_table_path(InputTables.ALTAY_COMPANY))
        published_hotels = (
            all_hotels
            .join(companies, ['permalink'])
            .where('is_exported AND publishing_status == "publish"')
            .select('permalink')
            .distinct()
        )
        published_path = self._get_table_path(ResultTables.HOTELS_PERMALINKS_PUBLISHED)
        published_hotels.write.mode("overwrite").optimize_for("scan").yt(published_path)

    def _permalink_to_cluster_permalink(self):
        """Write the permalink -> cluster permalink tables.

        ``PERMALINK_TO_CLUSTER_PERMALINK`` is the full outer join of the hotel
        permalinks with the duplicates table; a permalink without a duplicate
        record is its own cluster. The ``_FILTERED`` variant keeps only rows
        whose cluster permalink is a published hotel.

        Reads ``HOTELS_PERMALINKS_ALL`` and ``HOTELS_PERMALINKS_PUBLISHED``, so
        ``_hotels_permalinks`` must have run first.
        """
        duplicates = (
            self.spark.read
            .yt(self._get_table_path(InputTables.ALTAY_COMPANY_TO_DUPLICATE))
            .select(
                col('duplicate_permalink').alias('permalink'),
                col('company_permalink').alias('cluster_permalink'),
            )
            .distinct()
        )

        all_hotels = self.spark.read.yt(self._get_table_path(ResultTables.HOTELS_PERMALINKS_ALL))
        # Permalinks with no duplicate record point at themselves.
        own_or_cluster = F.coalesce(F.col('cluster_permalink'), F.col('permalink')).alias('cluster_permalink')
        clusters = (
            all_hotels
            .join(duplicates, ['permalink'], 'full')
            .select('permalink', own_or_cluster)
            .distinct()
            .sort('cluster_permalink')
        )

        clusters_path = self._get_table_path(ResultTables.PERMALINK_TO_CLUSTER_PERMALINK)
        clusters.write.mode("overwrite").optimize_for("scan").yt(clusters_path)

        published = self.spark.read.yt(self._get_table_path(ResultTables.HOTELS_PERMALINKS_PUBLISHED))

        cluster_is_published = clusters.cluster_permalink == published.permalink
        published_clusters = (
            clusters
            .join(published, cluster_is_published)
            .select(
                clusters['permalink'].alias('permalink'),
                clusters['cluster_permalink'].alias('cluster_permalink'),
            )
            .sort('cluster_permalink')
        )

        filtered_path = self._get_table_path(ResultTables.PERMALINK_TO_CLUSTER_PERMALINK_FILTERED)
        published_clusters.write.mode("overwrite").optimize_for("scan").yt(filtered_path)

    def _permalink_to_partnerid_originalid(self):
        """Write the permalink <-> (partner, original id) mapping tables.

        ``PERMALINK_TO_PARTNERID_ORIGINALID`` holds the non-hidden provider links
        whose provider permalink equals a ``Code`` in the travel partners table.
        ``PARTNERID_ORIGINALID_TO_CLUSTER_PERMALINK`` holds the same links with
        the permalink replaced by its cluster permalink.

        Reads ``PERMALINK_TO_CLUSTER_PERMALINK``, so
        ``_permalink_to_cluster_permalink`` must have run first.
        """
        permalink_to_partnername_originalid = (
            self.spark.read.yt(self._get_table_path(InputTables.ALTAY_COMPANY_TO_PROVIDER))
            .where('hide == False')
            .select(
                col('company_permalink').alias('permalink'),
                col('provider_permalink').alias('partner_name'),
                col('original_id').alias('originalid'),
            )
            .distinct()
        )

        partners = self.spark.read.yt(self._get_table_path(InputTables.TRAVEL_PARTNERS))

        # The inner join drops links of providers that are not travel partners.
        permalink_to_partnerid_originalid = (
            permalink_to_partnername_originalid
            .join(
                partners,
                permalink_to_partnername_originalid.partner_name == partners.Code,
                'inner'
            )
            .select(
                'permalink',
                col('PartnerIdInt').alias('partnerid'),
                'partner_name',
                'originalid'
            )
            .sort('permalink')
        )

        permalink_to_partnerid_originalid.write.mode("overwrite").optimize_for("scan").yt(self._get_table_path(ResultTables.PERMALINK_TO_PARTNERID_ORIGINALID))

        partnerid_originalid_to_cluster_permalink = (
            permalink_to_partnerid_originalid
            .join(self.spark.read.yt(self._get_table_path(ResultTables.PERMALINK_TO_CLUSTER_PERMALINK)), 'permalink')
            .select('partnerid', 'partner_name', 'originalid', 'cluster_permalink')
            .distinct()
            # The previous `.sort('permalink')` referred to a column that the
            # select above has already dropped, so the written rows could not be
            # ordered by anything present in the table. Order by the lookup key
            # of this table instead.
            # NOTE(review): key choice is inferred from the table name - confirm.
            .sort('partnerid', 'originalid')
        )

        partnerid_originalid_to_cluster_permalink.write.mode("overwrite").optimize_for("scan").yt(self._get_table_path(ResultTables.PARTNERID_ORIGINALID_TO_CLUSTER_PERMALINK))

    def _permalink_to_hotel_info(self):
        """Build and write the ``PERMALINK_TO_HOTEL_INFO`` table.

        For every cluster permalink of a hotel this gathers names, rubrics,
        address and geobase ancestors, platinum flag, stars, hotel types,
        popularity, partner codes and URLs. The combined record is then expanded
        to every permalink of the cluster and enriched with the permalink's
        publishing status.

        Reads the ``HOTELS_PERMALINKS_ALL``, ``PERMALINK_TO_CLUSTER_PERMALINK``
        and ``PERMALINK_TO_PARTNERID_ORIGINALID`` result tables, so the stages
        writing them must have run first.
        """
        # Cluster heads of all hotel permalinks. Note that the column is called
        # `permalink` here; the per-attribute frames rename it to `cluster_permalink`.
        hotels_cluster_permalinks = (
            self.spark.read.yt(self._get_table_path(ResultTables.HOTELS_PERMALINKS_ALL))
            .join(self.spark.read.yt(self._get_table_path(ResultTables.PERMALINK_TO_CLUSTER_PERMALINK)), ['permalink'])
            .select(col('cluster_permalink').alias('permalink'))
            .distinct()
        )

        # ---- Names -------------------------------------------------------
        company_to_name = (
            self.spark.read
            .schema_hint({"value": T.MapType(T.StringType(), T.StringType())})
            .yt(self._get_table_path(InputTables.ALTAY_COMPANY_TO_NAME))
        )

        hotels_names_prep = (
            hotels_cluster_permalinks
            .join(company_to_name.where('type == "main"'), 'permalink')
            .select(
                col('permalink').alias('cluster_permalink'),
                col('value.value').alias('name'),
                col('value.locale').alias('lang'),
            )
        )

        # Prefer the `ru` (resp. `en`) main name; otherwise fall back to
        # whichever main name comes first.
        hotels_names = (
            hotels_names_prep
            .groupby('cluster_permalink')
            .agg(
                F.coalesce(
                    F.first(F.when(F.col('lang') == 'ru', F.col('name')), ignorenulls=True),
                    F.first('name')
                ).alias('name'),
                F.coalesce(
                    F.first(F.when(F.col('lang') == 'en', F.col('name')), ignorenulls=True),
                    F.first('name')
                ).alias('name_en')
            )
        )

        # ---- Rubrics -----------------------------------------------------
        def get_rubrics(company_to_rubric_df, hotels_cluster_permalinks_df, only_main):
            """Aggregate rubrics per cluster permalink.

            With ``only_main`` only rows flagged ``is_main`` are used and a single
            ``rubric_id`` / ``rubric_permalink`` is returned; otherwise the sets
            ``rubric_ids`` / ``rubric_permalinks`` are returned.
            """
            agg_func = F.collect_set
            if only_main:
                company_to_rubric_df = company_to_rubric_df.where(col('is_main'))
                agg_func = F.first
            return (
                hotels_cluster_permalinks_df
                .join(
                    company_to_rubric_df,
                    hotels_cluster_permalinks_df.permalink == company_to_rubric_df.company_permalink
                )
                .groupby('permalink')
                .agg(
                    agg_func('rubric_id').alias('rubric_id' if only_main else 'rubric_ids'),
                    agg_func('rubric_permalink').alias('rubric_permalink' if only_main else 'rubric_permalinks')
                )
                .withColumnRenamed('permalink', 'cluster_permalink')
            )

        company_to_rubric = self.spark.read.yt(self._get_table_path(InputTables.ALTAY_COMPANY_TO_RUBRIC))

        main_hotels_rubrics = get_rubrics(company_to_rubric, hotels_cluster_permalinks, True)
        all_hotels_rubrics = get_rubrics(company_to_rubric, hotels_cluster_permalinks, False)

        # ---- Address and coordinates --------------------------------------
        company = (
            self.spark.read
            .schema_hint({
                "address": {
                    "geo_id": T.LongType(),
                    "region_code": T.StringType(),
                    "pos": T.StructType().add("coordinates", T.ArrayType(T.DoubleType())),
                    "components": T.ArrayType(
                        T.StructType()
                        .add("kind", T.StringType())
                        .add("name", T.StructType().add("value", T.StringType()).add("locale", T.StringType()))
                    ),
                    "formatted": T.StructType().add("value", T.StringType()),
                },
            })
            .yt(self._get_table_path(InputTables.ALTAY_COMPANY))
        )

        def multiline_to_singleline(s):
            """Strip every line of ``s``, drop the empty ones and glue the rest together."""
            return ''.join([x.strip() for x in s.split('\n') if x.strip() != ''])

        def build_address_component_expr(kind, column_name):
            """Return a SQL expression selecting the address component of ``kind``.

            The `ru` variant of the component is preferred; if there is none, the
            first component of that kind is used.
            """
            return multiline_to_singleline(f'''
                coalesce(
                    element_at(filter(raw_address.components, x -> x.kind == "{kind}" AND x.name.locale == "ru"), 1),
                    element_at(filter(raw_address.components, x -> x.kind == "{kind}"), 1)
                ).name.value AS {column_name}
            ''')

        coordinates = (
            hotels_cluster_permalinks
            .join(company, 'permalink')
            .select(
                col('permalink').alias('cluster_permalink'),
                col('address.pos.coordinates').getItem(0).alias('lon'),
                col('address.pos.coordinates').getItem(1).alias('lat'),
                col('address.geo_id').alias('geoid'),
                col('address.region_code').alias('country_code'),
                col('address.formatted.value').alias('original_address'),
                # Kept only so that the component expressions below can read it.
                col('address').alias('raw_address')
            )
            .selectExpr(
                '*',
                build_address_component_expr('locality', 'original_city'),
                build_address_component_expr('country', 'original_country'),
                build_address_component_expr('street', 'original_street'),
                build_address_component_expr('house', 'original_house'),
            )
            .drop('raw_address')
        )

        # ---- Geobase ancestors by type --------------------------------------
        regions = (
            self.spark.read
            .schema_hint({
                "parents_ids": T.ArrayType(T.LongType()),
            })
            .yt(self._get_table_path(InputTables.GEOBASE_REGIONS))
        )

        geo_type_name_by_id = {
            -1: 'hidden',
            0: 'earth',
            1: 'continent',
            2: 'world_part',
            3: 'country',
            4: 'district',
            5: 'region',
            6: 'city',
            7: 'town',
        }

        # create_map takes a flat key, value, key, value, ... sequence.
        geo_type_mapping_expr = F.create_map([F.lit(x) for x in chain(*geo_type_name_by_id.items())])

        regions_with_types = (
            regions
            .select(
                'reg_id',
                geo_type_mapping_expr[col('type')].alias('type')
            )
        )

        # For every region: explode `parents_ids` together with each parent's
        # position, look up the parent's type, and for each type keep the parent
        # with the smallest position in `parents_ids`.
        region_rounds = (
            regions
            .select(
                'reg_id',
                'parents_ids',
                F.sequence(F.lit(0), F.size(col('parents_ids')) - 1).alias('parent_inds')
            )
            .select(
                'reg_id',
                F.explode(F.arrays_zip('parent_inds', 'parents_ids')).alias('parent')
            )
            .select(
                'reg_id',
                col('parent.parent_inds').alias('ind'),
                col('parent.parents_ids').alias('parent'),
            )
            .alias('a')
            .join(
                regions_with_types.alias('b'),
                col('a.parent') == col('b.reg_id')
            )
            .select(
                col('a.reg_id').alias('reg_id'),
                # [position, parent id] pairs: array_min compares the position first.
                F.array('ind', 'parent').alias('parent'),
                'type',
            )
            .groupby('reg_id')
            .pivot('type', geo_type_name_by_id.values())
            .agg(
                F.array_min(F.collect_list('parent')).getItem(1)
            )
        )

        renamed_region_rounds = (region_rounds.select(
            col('reg_id').alias('geoid'),
            col('town').alias('geosearch_town'),
            col('city').alias('geosearch_city'),
            col('region').alias('geosearch_region'),
            col('district').alias('geosearch_district'),
            col('country').alias('geosearch_country'),
        ))

        hotels_addresses = (
            coordinates
            .join(renamed_region_rounds, 'geoid')
        )

        # ---- Platinum flag ---------------------------------------------------
        platinum_permalinks = (
            self.spark.read.yt(self._get_table_path(InputTables.ALTAY_COMPANY_TO_PROVIDER))
            .where(col('provider_permalink') == F.lit('platinum'))
            .select('company_permalink')
            .join(self.spark.read.yt(self._get_table_path(ResultTables.PERMALINK_TO_CLUSTER_PERMALINK)), col('company_permalink') == col('permalink'))
            .select(
                col('cluster_permalink').alias('permalink')
            )
            .distinct()
        )

        # Only platinum clusters get a row here; the missing ones are turned into
        # False after the final join.
        platinum_status = (
            hotels_cluster_permalinks
            .join(platinum_permalinks, 'permalink')
            .select(
                col('permalink').alias('cluster_permalink'),
                F.lit(True).alias('is_platinum')
            )
        )

        # ---- Stars -------------------------------------------------------------
        star_value_mapping = {
            289: "5",
            285: "4",
            668: "3",
            284: "2",
            283: "1",
            3481839596: "unrated",
        }

        company_to_feature = (
            self.spark.read
            .schema_hint({
                "enum_values": T.ArrayType(T.LongType()),
            })
            .yt(self._get_table_path(InputTables.ALTAY_COMPANY_TO_FEATURE))
        )

        star_value_mapping_expr = F.create_map([F.lit(x) for x in chain(*star_value_mapping.items())])

        hotels_stars = (
            hotels_cluster_permalinks
            .join(
                company_to_feature.where(col('feature_permalink') == F.lit('star')),
                hotels_cluster_permalinks.permalink == company_to_feature.company_permalink
            )
            .select(
                col('permalink').alias('cluster_permalink'),
                # Only the first enum value of the feature is considered.
                star_value_mapping_expr[col('enum_values').getItem(0)].alias('stars')
            )
            .groupby('cluster_permalink')
            .agg(F.first('stars').alias('stars'))
        )

        # ---- Hotel types -------------------------------------------------------
        hotel_type_value_mapping = {
            3501500015: "beach_hotel",
            3501500012: "urban_hotel",
            3501500014: "country_hotel",
            3501500013: "mountain_hotel",
        }

        def map_hotel_types(values):
            """Translate enum value ids to hotel type names.

            Unknown ids are kept as their decimal string. The previous
            ``mapping[x] or str(x)`` raised ``KeyError`` for them, so the
            fallback was never reached. NULL input yields NULL.
            """
            if values is None:
                return None
            return [hotel_type_value_mapping.get(x, str(x)) for x in values]

        map_hotel_types_udf = F.udf(map_hotel_types, T.ArrayType(T.StringType()))

        hotels_types = (
            hotels_cluster_permalinks
            .join(
                company_to_feature.where(col('feature_permalink') == F.lit('hotel_type_tech')),
                hotels_cluster_permalinks.permalink == company_to_feature.company_permalink
            )
            .select(
                col('permalink').alias('cluster_permalink'),
                map_hotel_types_udf(col('enum_values')).alias('hotel_types'),
            )
            .groupby('cluster_permalink')
            .agg(F.first('hotel_types').alias('hotel_types'))
        )

        # ---- Popularity --------------------------------------------------------
        popularity = (
            hotels_cluster_permalinks
            .join(self.spark.read.yt(self._get_table_path(InputTables.SPRAV_POPULARITY)), 'permalink')
            .select(
                col('permalink').alias('cluster_permalink'),
                'popularity'
            )
        )

        # ---- Partners ----------------------------------------------------------
        permalink_to_partnerid_originalid = self.spark.read.yt(self._get_table_path(ResultTables.PERMALINK_TO_PARTNERID_ORIGINALID))

        permalinks_with_partners = (
            permalink_to_partnerid_originalid
            .join(self.spark.read.yt(self._get_table_path(InputTables.TRAVEL_PARTNERS)), permalink_to_partnerid_originalid.partnerid == col('PartnerIdInt'), 'inner')
            .select(
                'permalink',
                col('Code').alias('partner_code'),
            )
        )

        # One list entry (and one count) per link row, so a partner with several
        # original ids for the same permalink appears several times.
        hotels_providers = (
            hotels_cluster_permalinks
            .join(permalinks_with_partners, 'permalink')
            .groupby('permalink')
            .agg(
                F.collect_list('partner_code').alias('partners'),
                F.count('*').alias('partners_count')
            )
            .withColumnRenamed('permalink', 'cluster_permalink')
        )

        # ---- URLs --------------------------------------------------------------
        main_urls = (
            self.spark.read
            .schema_hint({
                "urls": T.ArrayType(T.StructType()
                                    .add("value", T.StringType())
                                    .add("type", T.StringType())),
            })
            .yt(self._get_table_path(InputTables.ALTAY_COMPANY))
            .select(
                'permalink',
                F.explode('urls').alias('url_info')
            )
            .where(col('url_info.type') == F.lit("main"))
            .groupby('permalink')
            .agg(F.first('url_info.value').alias('main_url'))
            .withColumnRenamed('permalink', 'cluster_permalink')
        )

        def decode_signal_proto(data):
            """Parse serialized ``signal_pb2.Signal`` bytes and return the message."""
            # NOTE(review): the message object itself is returned while the UDF
            # declares a struct return type - confirm Spark converts it as intended.
            message = signal_pb2.Signal()
            message.ParseFromString(data)
            return message

        decode_signal_udf = F.udf(decode_signal_proto, T.StructType()
                                  .add("provider_id", T.LongType())
                                  .add("company", T.StructType()
                                       .add("urls", T.ArrayType(T.StructType()
                                                                .add("value", T.StringType())
                                                                )
                                            )
                                       )
                                  )

        # URLs reported by `ytravel*` providers, as a provider -> url map per permalink.
        # NOTE(review): several URLs from one provider produce duplicate map keys;
        # confirm how map_from_entries is expected to handle them.
        provider_urls_agg = (
            self.spark.read
            .schema_hint({
                "data": T.BinaryType(),
            })
            .yt(self._get_table_path(InputTables.ALTAY_SIGNALS_FOR_MERGE))
            .select(
                'permalink',
                decode_signal_udf('data').alias("data")
            )
            .alias('clusters')
            .join(
                self.spark.read.yt(self._get_table_path(InputTables.ALTAY_PROVIDER)).where("permalink LIKE 'ytravel%'").alias('providers'),
                col('clusters.data.provider_id') == col('providers.id')
            )
            .select(
                col('clusters.permalink').alias('permalink'),
                F.explode('clusters.data.company.urls').alias('url_info'),
                col('providers.permalink').alias('provider'),
            )
            .select(
                'permalink',
                'provider',
                col('url_info.value').alias('url')
            )
            .where('provider IS NOT NULL AND url IS NOT NULL')
            .groupby('permalink')
            .agg(
                F.map_from_entries(F.collect_list(F.struct('provider', 'url'))).alias('other_urls')
            )
            .withColumnRenamed('permalink', 'cluster_permalink')
        )

        # ---- Combine -----------------------------------------------------------
        # Full joins keep a cluster even if it is missing from some of the
        # attribute frames; URLs are attached with left joins, so they never add
        # clusters of their own.
        hotel_info = (
            hotels_names
            .join(main_hotels_rubrics, 'cluster_permalink', 'full')
            .join(all_hotels_rubrics, 'cluster_permalink', 'full')
            .join(hotels_addresses, 'cluster_permalink', 'full')
            .join(platinum_status, 'cluster_permalink', 'full')
            .join(hotels_stars, 'cluster_permalink', 'full')
            .join(hotels_types, 'cluster_permalink', 'full')
            .join(popularity, 'cluster_permalink', 'full')
            .join(hotels_providers, 'cluster_permalink', 'full')
            .join(main_urls, 'cluster_permalink', 'left')
            .join(provider_urls_agg, 'cluster_permalink', 'left')
        )

        # `hotel_rubric` is the main rubric if it is a hotel rubric, otherwise the
        # first hotel rubric among all rubrics of the cluster (NULL if there is none).
        hotel_info = (
            hotel_info
            .selectExpr(
                '*',
                f'array_intersect(rubric_ids, array({", ".join(map(str, HOTEL_RUBRICS))})) AS hotel_rubric_ids'
            )
            .withColumn('hotel_rubric', F.when(col('rubric_id').isin(HOTEL_RUBRICS), col('rubric_id')).otherwise(col('hotel_rubric_ids').getItem(0)))
            .drop('hotel_rubric_ids')
        )

        # Expand the per-cluster record to every permalink of the cluster.
        hotel_info = (
            hotel_info
            .join(self.spark.read.yt(self._get_table_path(ResultTables.PERMALINK_TO_CLUSTER_PERMALINK)), 'cluster_permalink')
        )

        # Attach the publishing state of each individual permalink.
        hotel_info = (
            hotel_info
            .join(self.spark.read.yt(self._get_table_path(InputTables.ALTAY_COMPANY)).select('permalink', 'publishing_status', 'is_exported'), 'permalink')
        )

        permalink_to_hotel_info = (
            hotel_info
            .withColumn('is_platinum', F.coalesce(col('is_platinum'), F.lit(False)))
            .sort('permalink')
        )

        permalink_to_hotel_info.write.mode("overwrite").optimize_for("scan").yt(self._get_table_path(ResultTables.PERMALINK_TO_HOTEL_INFO))
