from yt.wrapper import ypath_join, TablePath

from datacloud.features.locations.constants import DAYS_TO_TAKE, LAT_LON_PRECISION, HASH_LAT_PRECISION, \
    HASH_LON_PRECISION, PERCENTILE_MAX, PERCENTILE_MIN, COUNTRY_KMEANS_TOP, COUNTRY_KMEANS_CLUSTERS, COUNTRY_MLB_TOP, \
    REGION_KMEANS_TOP, REGION_KMEANS_CLUSTERS, REGION_MLB_TOP, CITY_MIN_COUNT, NATIVE_COUNTRY_CODE


class LocationsBuildConfig():
    """Parameters and YT table paths used to build the locations features.

    The constructor resolves every optional argument against its default
    (module constants from ``datacloud.features.locations.constants`` or
    hard-coded paths) and turns all working-table paths into ``TablePath``
    objects obtained by joining them onto ``root``.
    """

    def __init__(
        self,
        root,
        tag='LOCATIONS',
        is_retro=True,
        snapshot_date=None,

        days_to_take=None,
        lat_lon_precision=None,
        hash_lat_precision=None,
        hash_lon_precision=None,
        percentile_max=None,
        percentile_min=None,
        bandits_table=None,
        homework_table=None,

        country_kmeans_top=None,
        country_kmeans_clusters=None,
        country_mlb_top=None,
        region_kmeans_top=None,
        region_kmeans_clusters=None,
        region_mlb_top=None,
        city_min_count=None,
        native_country_code=None,
        use_pretrain_transformer=True,
        custom_transformer=None,

        input_yuid_table=None,
        input_table=None,
        locations_round_table=None,
        locations_stat_table=None,
        locations_stat_table_map=None,
        locations_stat_table_cat=None,
        locations_bandits_table=None,
        locations_homework_table=None,
        locations_out_merged=None,
        out_table=None,
    ):
        """Resolve defaults and build the table paths.

        Args:
            root: Base path; every working table path is joined onto it with
                ``ypath_join``.
            tag: Stored unchanged as ``self.tag``.
            is_retro: When true, ``ext_id_key`` is ``'external_id'`` and the
                features directory is the plain weekly directory. Otherwise
                ``ext_id_key`` is ``'cid'`` and ``snapshot_date`` is appended
                to the weekly directory.
            snapshot_date: Path component appended to the features directory
                when ``is_retro`` is false; required in that case.
            days_to_take, lat_lon_precision, hash_lat_precision,
            hash_lon_precision, percentile_max, percentile_min,
            country_kmeans_top, country_kmeans_clusters, country_mlb_top,
            region_kmeans_top, region_kmeans_clusters, region_mlb_top,
            city_min_count, native_country_code: Tuning values. ``None``
                selects the same-named upper-case module constant; any other
                value, including ``0``, is kept as passed.
            bandits_table: Source bandits table. A falsy value selects the
                ``geo_bandits`` path formatted with the resolved hash
                precisions.
            homework_table: Source homework table; a falsy value selects the
                production ``homework_yuid`` path.
            use_pretrain_transformer: Stored unchanged.
            custom_transformer: Transformer path; a falsy value selects the
                ``locations_transformer/latest`` path.
            input_yuid_table, input_table, locations_round_table,
            locations_stat_table, locations_stat_table_map,
            locations_stat_table_cat, locations_bandits_table,
            locations_homework_table, locations_out_merged, out_table:
                Working table paths joined onto ``root``. A falsy value
                selects the default, which for the stat/bandits/homework/
                merged/out tables lives under ``self.features_dir``.

        Raises:
            ValueError: If ``is_retro`` is false and ``snapshot_date`` is
                ``None``.
        """
        def pick(value, default):
            # Only None means "not provided"; an explicit falsy override such as 0
            # must not be silently swapped for the default.
            return default if value is None else value

        def rooted(path, **table_path_kwargs):
            # Anchor a table path at ``root`` and wrap it in a TablePath.
            return TablePath(ypath_join(root, path), **table_path_kwargs)

        if not is_retro and snapshot_date is None:
            # Fail early with a clear message instead of handing None to
            # ypath_join, which expects string path parts.
            raise ValueError('snapshot_date is required when is_retro is False')

        bandits_table_original = r'//home/x-products/production/datacloud/static-data/geo_bandits/{lat}_{lon}_latest'
        custom_transformer_original = r'//home/x-products/production/datacloud/static-data/locations_transformer/latest'
        weekly_dir = 'datacloud/aggregates/locations/weekly'

        self.root = root
        self.tag = tag

        # Tuning parameters.
        self.days_to_take = pick(days_to_take, DAYS_TO_TAKE)
        self.lat_lon_precision = pick(lat_lon_precision, LAT_LON_PRECISION)
        self.hash_lat_precision = pick(hash_lat_precision, HASH_LAT_PRECISION)
        self.hash_lon_precision = pick(hash_lon_precision, HASH_LON_PRECISION)
        self.percentile_max = pick(percentile_max, PERCENTILE_MAX)
        self.percentile_min = pick(percentile_min, PERCENTILE_MIN)

        # External source tables; the default bandits path depends on the
        # resolved hash precisions, so it is formatted after they are set.
        self.bandits_table = bandits_table or bandits_table_original.format(
            lat=self.hash_lat_precision,
            lon=self.hash_lon_precision,
        )
        self.homework_table = homework_table or r'//home/user_identification/homework/prod/homework_yuid'

        self.country_kmeans_top = pick(country_kmeans_top, COUNTRY_KMEANS_TOP)
        self.country_kmeans_clusters = pick(country_kmeans_clusters, COUNTRY_KMEANS_CLUSTERS)
        self.country_mlb_top = pick(country_mlb_top, COUNTRY_MLB_TOP)
        self.region_kmeans_top = pick(region_kmeans_top, REGION_KMEANS_TOP)
        self.region_kmeans_clusters = pick(region_kmeans_clusters, REGION_KMEANS_CLUSTERS)
        self.region_mlb_top = pick(region_mlb_top, REGION_MLB_TOP)
        self.city_min_count = pick(city_min_count, CITY_MIN_COUNT)
        self.native_country_code = pick(native_country_code, NATIVE_COUNTRY_CODE)
        self.use_pretrain_transformer = use_pretrain_transformer
        self.custom_transformer = custom_transformer or custom_transformer_original

        # The mode selects the id column name and whether features are stored
        # per snapshot date.
        if is_retro:
            self.ext_id_key = 'external_id'
            self.features_dir = ypath_join(weekly_dir)
        else:
            self.ext_id_key = 'cid'
            self.features_dir = ypath_join(weekly_dir, snapshot_date)

        features_dir = self.features_dir

        # Input tables.
        self.input_yuid_table = rooted(input_yuid_table or 'input_yuid')
        self.input_table = rooted(input_table or 'datacloud/grep/geo/geo')
        self.locations_round_table = rooted(locations_round_table or 'datacloud/grep/locations/round')

        # Intermediate and output tables; defaults live under the features directory.
        self.locations_stat_table = rooted(
            locations_stat_table or ypath_join(features_dir, 'geo_stat'))
        self.locations_stat_table_map = rooted(
            locations_stat_table_map or ypath_join(features_dir, 'geo_stat_map'))
        # NOTE(review): the schema key column is hard-coded as 'external_id' even
        # when ext_id_key is 'cid' (is_retro=False) - confirm this is intended.
        self.locations_stat_table_cat = rooted(
            locations_stat_table_cat or ypath_join(features_dir, 'geo_stat_cat'),
            schema=[
                {'name': 'external_id', 'type': 'string'},
                {'name': 'mode_country_id', 'type': 'int64'},
                {'name': 'mode_region_id', 'type': 'int64'},
                {'name': 'mode_city_id', 'type': 'int64'},
                {'name': 'mode_city_type', 'type': 'int64'},
            ],
        )
        self.locations_bandits_table = rooted(
            locations_bandits_table or ypath_join(features_dir, 'bandits_stat'))
        self.locations_homework_table = rooted(
            locations_homework_table or ypath_join(features_dir, 'homework_stat'))
        self.locations_out_merged = rooted(
            locations_out_merged or ypath_join(features_dir, 'out_merged'))
        self.out_table = rooted(out_table or ypath_join(features_dir, 'features'))
