import pickle
from enum import Enum, unique
from textwrap import dedent

import numpy as np
import pandas as pd

from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.yql.yql_helpers import execute_yql
from datacloud.features.locations.helpers import df_to_yt, FeaturesMapper
from datacloud.ml_utils.benchmark_v2.transformers import LocationsTransformer

# Module-level logger shared by all build steps; each record is rendered as
# "<asctime> <message>" according to the format string below.
logger = get_basic_logger(name=__name__, format='%(asctime)s %(message)s')


@unique
class LocationsFeaturesBuildSteps(Enum):
    """Identifiers of the stages of the location-features build.

    Every member's value is the same string as its name, and each member
    corresponds to the ``step_*`` function of this module with that name;
    ``build_locations_vectors`` runs a stage only when its member is present
    in ``steps_to_run``.  ``@unique`` guarantees that no two members share a
    value.
    """
    # Round raw location logs and keep the look-back window.
    step_1_gather_round_logs = 'step_1_gather_round_logs'
    # Per-id country/region/city statistics plus the transformer output.
    step_2_calculate_user_statistics = 'step_2_calculate_user_statistics'
    # Per-id aggregates of the bandits table joined by geohash cluster.
    step_3_calculate_bandits_statistics = 'step_3_calculate_bandits_statistics'
    # Per-id home/work distance and bandits aggregates for home/work points.
    step_4_calculate_homework_statistics = 'step_4_calculate_homework_statistics'
    # Merge the intermediate tables and run the final feature mapper.
    step_5_merge_tables = 'step_5_merge_tables'


# Default value of ``steps_to_run`` in ``build_locations_vectors``: every
# stage of the build.  Kept as an immutable tuple so it is safe to use as a
# default argument.
LOCATIONS_DEFAULT_FEATURES_BUILD_STEPS = (
    LocationsFeaturesBuildSteps.step_1_gather_round_logs,
    LocationsFeaturesBuildSteps.step_2_calculate_user_statistics,
    LocationsFeaturesBuildSteps.step_3_calculate_bandits_statistics,
    LocationsFeaturesBuildSteps.step_4_calculate_homework_statistics,
    LocationsFeaturesBuildSteps.step_5_merge_tables,
)


def step_1_gather_round_logs(yt_client, yql_client, build_config):
    """Write the rounded, time-filtered location logs table.

    Runs a YQL query that reads ``build_config.input_table``, applies
    ``Math::Round`` with ``build_config.lat_lon_precision`` to ``lat`` and
    ``lon``, exposes ``-timestamp_of_log`` as ``log_ts`` and keeps only rows
    with ``-timestamp_of_log >= original_timestamp - days_to_take * 86400``.
    The result replaces ``build_config.locations_round_table`` and is ordered
    by ``build_config.ext_id_key``.

    Args:
        yt_client: Not used by this step; accepted so that all steps are
            called with the same arguments.
        yql_client: Client forwarded to ``execute_yql``.
        build_config: Source of the table paths and numeric settings listed
            above.
    """
    # The %(name)s placeholders in the query are presumably substituted by
    # execute_yql from this mapping -- confirm against the helper.
    substitutions = {
        'ext_id_key': build_config.ext_id_key,
        'days_to_take': build_config.days_to_take,
        'lat_lon_precision': build_config.lat_lon_precision,
        'table_in': build_config.input_table,
        'table_out': build_config.locations_round_table,
    }

    query_template = dedent('''
    $days_to_take = %(days_to_take)i;
    $precision = %(lat_lon_precision)i;
    $locations = '%(table_in)s';
    $out = '%(table_out)s';

    $look_back_seconds = $days_to_take * 24 * 60 * 60;

    INSERT INTO $out WITH TRUNCATE
    SELECT
        %(ext_id_key)s,
        Math::Round(lat, $precision) as lat,
        Math::Round(lon, $precision) as lon,
        -timestamp_of_log as log_ts,
    FROM $locations

    WHERE -timestamp_of_log >= original_timestamp - $look_back_seconds
    ORDER BY %(ext_id_key)s;
    ''')

    execute_yql(
        query=query_template,
        yql_client=yql_client,
        params=substitutions,
        set_owners=False,
        syntax_version=1,
    )


def _yt_type_for_dtype(dtype):
    """Return the YT schema type name used for a DataFrame column dtype."""
    # Compare against the builtin ``float`` / ``int``.  The former
    # ``np.float`` / ``np.int`` spellings were plain aliases of these builtins;
    # they were deprecated in NumPy 1.20 and removed in 1.24, so using them
    # raises AttributeError on current NumPy.
    if dtype == float:
        return 'double'
    if dtype == int:
        return 'int64'
    if dtype == object:
        return 'string'
    # Anything else (bool, other integer/float widths, extension dtypes, ...).
    return 'any'


def step_2_calculate_user_statistics(yt_client, yql_client, build_config):
    """Compute per-id geography statistics and run the locations transformer.

    Inside one YT transaction this step:

    1. Runs a YQL query over ``build_config.locations_round_table`` that
       resolves country/region/city for every point with
       ``Geo::RoundRegionByLocation`` and aggregates counts, modes and
       distinct id lists per ``build_config.ext_id_key`` into
       ``build_config.locations_stat_table``.
    2. Loads that table into a pandas DataFrame.
    3. Obtains a transformer: unpickled from the YT file
       ``build_config.custom_transformer`` when
       ``build_config.use_pretrain_transformer`` is truthy, otherwise a new
       ``LocationsTransformer`` fitted on the DataFrame.
    4. Writes the id column plus the categorical ``mode_*`` columns to
       ``build_config.locations_stat_table_cat``.
    5. Creates ``build_config.locations_stat_table_map`` with a schema derived
       from the remaining columns' dtypes and writes those columns to it.

    Args:
        yt_client: YT client used for the transaction, table/file reads and
            table writes.
        yql_client: Client forwarded to ``execute_yql``.
        build_config: Source of table paths and transformer settings.
    """
    yql_query = dedent('''
    $src_locations = '%(table_in)s';
    $out = '%(table_out)s';

    $locations_with_location =
        SELECT
            %(ext_id_key)s,
            lat,
            lon,
            log_ts,
            country.id as country_id,
            region.id as region_id,
            city.id as city_id,
            city.type as city_type
        FROM(
            SELECT
                %(ext_id_key)s,
                lat,
                lon,
                log_ts,
                Geo::RoundRegionByLocation(lat, lon, "country") as country,
                Geo::RoundRegionByLocation(lat, lon, "region") as region,
                Geo::RoundRegionByLocation(lat, lon, "city") as city,
            FROM $src_locations
        );

    INSERT INTO $out WITH TRUNCATE
    SELECT
        %(ext_id_key)s,
        COUNT(*) as cnt,

        COUNT(DISTINCT country_id) as cnt_country_id,
        MODE(country_id)[0].Value as mode_country_id,
        MODE(country_id)[0].Frequency as mode_country_cnt,
        AGGREGATE_LIST_DISTINCT(country_id) as country_ids,

        COUNT(DISTINCT region_id) as cnt_region_id,
        MODE(region_id)[0].Value as mode_region_id,
        MODE(region_id)[0].Frequency as mode_region_cnt,
        AGGREGATE_LIST_DISTINCT(region_id) as region_ids,

        COUNT(DISTINCT city_id) as cnt_city_id,
        MODE(city_id)[0].Value as mode_city_id,
        MODE(city_id)[0].Frequency as mode_city_cnt,
        MODE(city_type)[0].Value as mode_city_type,

        COUNT(DISTINCT log_ts) as log_ts,
    FROM $locations_with_location
    GROUP BY %(ext_id_key)s
    ORDER BY %(ext_id_key)s;
    ''')
    # Columns that go to the separate "cat" table and are removed from the
    # "map" table.
    cat_features = ['mode_country_id', 'mode_region_id', 'mode_city_id', 'mode_city_type']

    params = {
        'ext_id_key': build_config.ext_id_key,
        'table_in': build_config.locations_round_table,
        'table_out': build_config.locations_stat_table,
    }
    with yt_client.Transaction() as transaction:
        execute_yql(query=yql_query, yql_client=yql_client,
                    params=params, set_owners=False, syntax_version=1,
                    transaction_id=transaction.transaction_id)

        df = pd.DataFrame.from_records(yt_client.read_table(build_config.locations_stat_table))
        if build_config.use_pretrain_transformer:
            # SECURITY NOTE: unpickling executes arbitrary code embedded in the
            # file.  Only point ``custom_transformer`` at a trusted YT path.
            stream = yt_client.read_file(build_config.custom_transformer)
            transformer = pickle.load(stream)
        else:
            transformer = LocationsTransformer(
                country_kmeans_top=build_config.country_kmeans_top,
                country_kmeans_clusters=build_config.country_kmeans_clusters,
                country_mlb_top=build_config.country_mlb_top,
                region_kmeans_top=build_config.region_kmeans_top,
                region_kmeans_clusters=build_config.region_kmeans_clusters,
                region_mlb_top=build_config.region_mlb_top,
                city_min_count=build_config.city_min_count,
                native_country_code=build_config.native_country_code,
            )
            transformer.fit(df)
        transformed = transformer.transform(df)
        yt_client.write_table(build_config.locations_stat_table_cat,
                              df_to_yt(transformed[[build_config.ext_id_key] + cat_features]))

        transformed.drop(columns=cat_features, inplace=True)
        schema = [
            {'name': column, 'type': _yt_type_for_dtype(transformed[column].dtype)}
            for column in transformed.columns
        ]
        # NOTE(review): create() is called without any "already exists"
        # handling -- presumably the map table is absent before this step;
        # confirm the behaviour for re-runs.
        yt_client.create('table', build_config.locations_stat_table_map, attributes={'schema': schema})
        yt_client.write_table(build_config.locations_stat_table_map, df_to_yt(transformed))


def step_3_calculate_bandits_statistics(yt_client, yql_client, build_config):
    """Aggregate bandits-table metrics over the clusters each id was seen in.

    The YQL query turns every row of ``build_config.locations_round_table``
    into a ``geohash_cluster`` key of the form
    ``Round(lat / hash_lat_precision) || '_' || Round(lon / hash_lon_precision)``,
    joins it with ``build_config.bandits_table`` on that key and, per
    ``build_config.ext_id_key``, stores the row count together with the
    average and the ``percentile_min`` / ``percentile_max`` percentiles of
    each bandits metric (``neighbor_*`` columns).  The result replaces
    ``build_config.locations_bandits_table``.

    Args:
        yt_client: Not used by this step; accepted so that all steps are
            called with the same arguments.
        yql_client: Client forwarded to ``execute_yql``.
        build_config: Source of the table paths, hash precisions and
            percentile levels.
    """
    substitutions = {
        'ext_id_key': build_config.ext_id_key,
        'hash_lat_precision': build_config.hash_lat_precision,
        'hash_lon_precision': build_config.hash_lon_precision,
        'percentile_max': build_config.percentile_max,
        'percentile_min': build_config.percentile_min,
        'table_in': build_config.locations_round_table,
        'table_bandits': build_config.bandits_table,
        'table_out': build_config.locations_bandits_table,
    }

    query_template = dedent('''
    $src_logs = '%(table_in)s';
    $src_bandits = '%(table_bandits)s';
    $out = '%(table_out)s';

    $geo_lat_precision = %(hash_lat_precision)s;
    $geo_lon_precision = %(hash_lon_precision)s;
    $perc_max = %(percentile_max)s;
    $perc_min = %(percentile_min)s;

    $logs_single = (
        SELECT
            %(ext_id_key)s,
            Cast(Math::Round(lat / $geo_lat_precision) as String) || '_' ||
            Cast(Math::Round(lon / $geo_lon_precision) as String) as geohash_cluster,
        FROM $src_logs as src
    );

    INSERT INTO $out WITH TRUNCATE
    SELECT
        logs.%(ext_id_key)s as %(ext_id_key)s,
        COUNT(*) as cnt,

        AVG(cnt) as neighbor_cnt_mean,
        PERCENTILE(cnt, $perc_min) as neighbor_cnt_min,
        PERCENTILE(cnt, $perc_max) as neighbor_cnt_max,

        AVG(age) as neighbor_age_mean,
        PERCENTILE(age, $perc_min) as neighbor_age_min,
        PERCENTILE(age, $perc_max) as neighbor_age_max,

        AVG(income) as neighbor_income_mean,
        PERCENTILE(income, $perc_min) as neighbor_income_min,
        PERCENTILE(income, $perc_max) as neighbor_income_max,

        AVG(score) as neighbor_score_mean,
        PERCENTILE(score, $perc_min) as neighbor_score_min,
        PERCENTILE(score, $perc_max) as neighbor_score_max,

        AVG(gender_m) as neighbor_gender_m_mean,
        PERCENTILE(gender_m, $perc_min) as neighbor_gender_m_min,
        PERCENTILE(gender_m, $perc_max) as neighbor_gender_m_max,

        AVG(gender_anomaly) as neighbor_gender_anomaly_mean,
        PERCENTILE(gender_anomaly, $perc_min) as neighbor_gender_anomaly_min,
        PERCENTILE(gender_anomaly, $perc_max) as neighbor_gender_anomaly_max,

        AVG(gender_anomaly_log) as neighbor_gender_anomaly_log_mean,
        PERCENTILE(gender_anomaly_log, $perc_min) as neighbor_gender_anomaly_log_min,
        PERCENTILE(gender_anomaly_log, $perc_max) as neighbor_gender_anomaly_log_max,

        AVG(gender_anomaly_moivre_laplace) as neighbor_gender_anomaly_moivre_laplace_mean,
        PERCENTILE(gender_anomaly_moivre_laplace, $perc_min) as neighbor_gender_anomaly_moivre_laplace_min,
        PERCENTILE(gender_anomaly_moivre_laplace, $perc_max) as neighbor_gender_anomaly_moivre_laplace_max,
    FROM $logs_single as logs
    JOIN $src_bandits as bandits
        USING(geohash_cluster)
    GROUP BY logs.%(ext_id_key)s
    ORDER BY %(ext_id_key)s
    ''')

    execute_yql(
        query=query_template,
        yql_client=yql_client,
        params=substitutions,
        set_owners=False,
        syntax_version=1,
    )


def step_4_calculate_homework_statistics(yt_client, yql_client, build_config):
    """Build per-id home/work distance and neighbourhood statistics.

    The YQL query:

    * parses ``predicted_home`` / ``predicted_work`` from
      ``build_config.homework_table`` into ``[latitude, longitude]`` lists;
    * joins them with ``build_config.input_yuid_table`` on ``yuid`` and
      collects the home and work points per ``build_config.ext_id_key``;
    * rounds the points with ``lat_lon_precision`` and, per id, takes the
      most frequent home point and the most frequent work point
      (``TOP_BY(..., cnt, 1)``);
    * computes ``Geo::CalculatePointsDifference`` between those two points
      for ids that have both;
    * maps all home and work points to ``geohash_cluster`` keys, joins with
      ``build_config.bandits_table`` and aggregates the ``hw_*`` averages and
      percentiles per id;
    * FULL JOINs the distance and the ``hw_*`` statistics and replaces
      ``build_config.locations_homework_table`` with the result.

    Args:
        yt_client: Not used by this step; accepted so that all steps are
            called with the same arguments.
        yql_client: Client forwarded to ``execute_yql``.
        build_config: Source of the table paths, precisions and percentile
            levels.
    """
    substitutions = {
        'ext_id_key': build_config.ext_id_key,
        'lat_lon_precision': build_config.lat_lon_precision,
        'hash_lat_precision': build_config.hash_lat_precision,
        'hash_lon_precision': build_config.hash_lon_precision,
        'percentile_max': build_config.percentile_max,
        'percentile_min': build_config.percentile_min,
        'table_in': build_config.input_yuid_table,
        'table_bandits': build_config.bandits_table,
        'table_homework': build_config.homework_table,
        'table_out': build_config.locations_homework_table,
    }

    query_template = dedent('''
    PRAGMA yt.InferSchema = '1';
    $src_cse_yuid = '%(table_in)s';
    $src_homework = '%(table_homework)s';
    $src_bandits = '%(table_bandits)s';
    $out = '%(table_out)s';

    $precision = %(lat_lon_precision)i;
    $geo_lat_precision = %(hash_lat_precision)s;
    $geo_lon_precision = %(hash_lon_precision)s;
    $perc_max = %(percentile_max)s;
    $perc_min = %(percentile_min)s;

    $src_yuid_home =
        SELECT
            yuid,
            IF(predicted_home is NOT NULL,
                IF(predicted_home["latitude"] is NOT NULL,
                    AsList(predicted_home["latitude"], predicted_home["longitude"]),
                    NULL
                ),
            NULL) as home,
            IF(predicted_work is NOT NULL,
                IF(predicted_work["latitude"] is NOT NULL,
                    AsList(predicted_work["latitude"], predicted_work["longitude"]),
                    NULL
                ),
            NULL) as work,
        FROM (
            SELECT
            Yson::ConvertToString(Yson::Parse(yandexuid)) as yuid,
            Yson::ConvertToDoubleDict(Yson::Parse(WeakField(predicted_home, Yson))) as predicted_home,
            Yson::ConvertToDoubleDict(Yson::Parse(WeakField(predicted_work, Yson))) as predicted_work,
            FROM $src_homework
        );

    $middle = SELECT
        cse.%(ext_id_key)s as %(ext_id_key)s,
        AGGREGATE_LIST(home) as home,
        AGGREGATE_LIST(work) as work,
    FROM $src_cse_yuid as cse
        JOIN $src_yuid_home as hw
        ON cse.yuid = hw.yuid
    GROUP BY cse.%(ext_id_key)s;

    $middle_home = SELECT
            %(ext_id_key)s,
            Math::Round(Cast(home[0] as Double), $precision) as lat,
            Math::Round(Cast(home[1] as Double), $precision) as lon,
        FROM (SELECT %(ext_id_key)s, home FROM $middle)
        FLATTEN LIST BY home;

    $middle_home_top = SELECT
            %(ext_id_key)s,
            Cast(TOP_BY(lat, cnt, 1)[0] as Double) as lat,
            Cast(TOP_BY(lon, cnt, 1)[0] as Double) as lon,
        FROM (
            SELECT
                %(ext_id_key)s, lat, lon,
                COUNT(*) as cnt,
            FROM $middle_home
                GROUP BY %(ext_id_key)s, lat, lon
            )
        GROUP BY %(ext_id_key)s;

    $middle_work = SELECT
            %(ext_id_key)s,
            Math::Round(Cast(work[0] as Double), $precision) as lat,
            Math::Round(Cast(work[1] as Double), $precision) as lon,
        FROM (SELECT %(ext_id_key)s, work FROM $middle)
        FLATTEN LIST BY work;

    $middle_work_top = SELECT
            %(ext_id_key)s,
            Cast(TOP_BY(lat, cnt, 1)[0] as Double) as lat,
            Cast(TOP_BY(lon, cnt, 1)[0] as Double) as lon,
        FROM (
            SELECT
                %(ext_id_key)s, lat, lon,
                COUNT(*) as cnt,
            FROM $middle_work
                GROUP BY %(ext_id_key)s, lat, lon
            )
        GROUP BY %(ext_id_key)s;

    $distance =
        SELECT
            h.%(ext_id_key)s ?? w.%(ext_id_key)s as  %(ext_id_key)s,
            Geo::CalculatePointsDifference(h.lat, h.lon, w.lat, w.lon) as distance,
        FROM $middle_home_top as h
        INNER JOIN $middle_work_top as w
            ON h.%(ext_id_key)s = w.%(ext_id_key)s;

    $hw_single = (
        SELECT
            %(ext_id_key)s,
            Cast(Math::Round(lat / $geo_lat_precision) as String) || '_' ||  Cast(Math::Round(lon / $geo_lon_precision) as String) as geohash_cluster,
        FROM (
            SELECT %(ext_id_key)s, lat, lon
                FROM $middle_home
            UNION ALL
            SELECT %(ext_id_key)s, lat, lon
                FROM $middle_work
        ) as src
    );

    $hw_stat = SELECT
            hw.%(ext_id_key)s as %(ext_id_key)s,
            COUNT(*) as hw_cnt,

            AVG(cnt) as hw_cnt_mean,
            PERCENTILE(cnt, $perc_min) as hw_cnt_min,
            PERCENTILE(cnt, $perc_max) as hw_cnt_max,

            AVG(age) as hw_age_mean,
            PERCENTILE(age, $perc_min) as hw_age_min,
            PERCENTILE(age, $perc_max) as hw_age_max,

            AVG(income) as hw_income_mean,
            PERCENTILE(income, $perc_min) as hw_income_min,
            PERCENTILE(income, $perc_max) as hw_income_max,

            AVG(score) as hw_score_mean,
            PERCENTILE(score, $perc_min) as hw_score_min,
            PERCENTILE(score, $perc_max) as hw_score_max,

            AVG(gender_m) as hw_gender_m_mean,
            PERCENTILE(gender_m, $perc_min) as hw_gender_m_min,
            PERCENTILE(gender_m, $perc_max) as hw_gender_m_max,

            AVG(gender_anomaly) as hw_gender_anomaly_mean,
            PERCENTILE(gender_anomaly, $perc_min) as hw_gender_anomaly_min,
            PERCENTILE(gender_anomaly, $perc_max) as hw_gender_anomaly_max,

            AVG(gender_anomaly_log) as hw_gender_anomaly_log_mean,
            PERCENTILE(gender_anomaly_log, $perc_min) as hw_gender_anomaly_log_min,
            PERCENTILE(gender_anomaly_log, $perc_max) as hw_gender_anomaly_log_max,

            AVG(gender_anomaly_moivre_laplace) as hw_gender_anomaly_moivre_laplace_mean,
            PERCENTILE(gender_anomaly_moivre_laplace, $perc_min) as hw_gender_anomaly_moivre_laplace_min,
            PERCENTILE(gender_anomaly_moivre_laplace, $perc_max) as hw_gender_anomaly_moivre_laplace_max,
        FROM $hw_single as hw
        JOIN $src_bandits as bandits
            USING(geohash_cluster)
        GROUP BY hw.%(ext_id_key)s;

    INSERT INTO $out WITH TRUNCATE
    SELECT
        IF(distance.%(ext_id_key)s IS NOT NULL, distance.%(ext_id_key)s, hw.%(ext_id_key)s) as %(ext_id_key)s,
        distance.distance as distance,
        hw_cnt,

        hw_cnt_mean,
        hw_cnt_min,
        hw_cnt_max,
        hw_age_mean,
        hw_age_min,
        hw_age_max,
        hw_income_mean,
        hw_income_min,
        hw_income_max,
        hw_score_mean,
        hw_score_min,
        hw_score_max,
        hw_gender_m_mean,
        hw_gender_m_min,
        hw_gender_m_max,
        hw_gender_anomaly_mean,
        hw_gender_anomaly_min,
        hw_gender_anomaly_max,
        hw_gender_anomaly_log_mean,
        hw_gender_anomaly_log_min,
        hw_gender_anomaly_log_max,
        hw_gender_anomaly_moivre_laplace_mean,
        hw_gender_anomaly_moivre_laplace_min,
        hw_gender_anomaly_moivre_laplace_max,

    FROM $distance as distance
    FULL JOIN $hw_stat as hw
        USING(%(ext_id_key)s)
    ORDER BY %(ext_id_key)s;
    ''')

    execute_yql(
        query=query_template,
        yql_client=yql_client,
        params=substitutions,
        set_owners=False,
        syntax_version=1,
    )


def step_5_merge_tables(yt_client, yql_client, build_config):
    """Merge the intermediate tables and run the final feature mapper.

    Inside one YT transaction this step:

    1. Runs a YQL query that LEFT JOINs ``locations_stat_table_map`` with
       ``locations_bandits_table`` and ``locations_homework_table`` on
       ``build_config.ext_id_key``, adds the ``neighbor_null``, ``hw_null``
       and ``home_or_work_null`` indicator columns, FULL JOINs the result
       with the distinct ids of ``input_yuid_table`` and replaces
       ``build_config.locations_out_merged``.
    2. Reads the first row of the merged table to obtain the column names,
       drops the id column and sorts the rest.
    3. Logs that column order and runs ``FeaturesMapper`` as a map operation
       from ``locations_out_merged`` to ``build_config.out_table``.

    Args:
        yt_client: YT client used for the transaction, the table read and
            the map operation.
        yql_client: Client forwarded to ``execute_yql``.
        build_config: Source of the table paths and the id column name.

    Raises:
        RuntimeError: If the merged table contains no rows, so the feature
            column list cannot be derived.
    """
    yql_query = dedent('''
    $src_input = '%(table_input)s';
    $src_stat = '%(table_logs)s';
    $src_bandits = '%(table_bandits)s';
    $src_hw = '%(table_homework)s';

    $out = '%(table_out)s';

    INSERT INTO $out WITH TRUNCATE
        SELECT
            *
        FROM (
            SELECT
                src.*,
                hw.*,
                Cast((neighbor_cnt is NULL) as Uint8) as neighbor_null,
                Cast((hw_cnt_mean is NULL) as Uint8) as hw_null,
                Cast((distance is NULL ) as Uint8) as home_or_work_null,
            FROM (
                SELECT
                    stat.*,
                    bandits.cnt as neighbor_cnt,
                    neighbor_cnt_mean,neighbor_cnt_min,neighbor_cnt_max,
                    neighbor_age_mean,neighbor_age_min,neighbor_age_max,
                    neighbor_income_mean,neighbor_income_min,neighbor_income_max,
                    neighbor_score_mean,neighbor_score_min,neighbor_score_max,
                    neighbor_gender_m_mean,neighbor_gender_m_max,neighbor_gender_m_min,
                    neighbor_gender_anomaly_mean,neighbor_gender_anomaly_min,neighbor_gender_anomaly_max,
                    neighbor_gender_anomaly_log_mean,neighbor_gender_anomaly_log_min,neighbor_gender_anomaly_log_max,
                    neighbor_gender_anomaly_moivre_laplace_mean,neighbor_gender_anomaly_moivre_laplace_min,neighbor_gender_anomaly_moivre_laplace_max,
                FROM $src_stat as stat
                    LEFT JOIN $src_bandits as bandits ON stat.%(ext_id_key)s = bandits.%(ext_id_key)s
            ) as src
            LEFT JOIN $src_hw as hw on src.%(ext_id_key)s = hw.%(ext_id_key)s
        ) as src
        FULL JOIN (
            SELECT
                %(ext_id_key)s
            FROM $src_input
            GROUP BY %(ext_id_key)s
        ) as inp USING(%(ext_id_key)s)
    ORDER BY %(ext_id_key)s''')

    params = {
        'ext_id_key': build_config.ext_id_key,
        'table_input': build_config.input_yuid_table,
        'table_logs': build_config.locations_stat_table_map,
        'table_bandits': build_config.locations_bandits_table,
        'table_homework': build_config.locations_homework_table,
        'table_out': build_config.locations_out_merged,
    }
    with yt_client.Transaction() as transaction:
        execute_yql(query=yql_query, yql_client=yql_client,
                    params=params, set_owners=False, syntax_version=1,
                    transaction_id=transaction.transaction_id)

        # Only the first row is needed to learn the column names.  Use a
        # default so that an empty table yields a clear error instead of a
        # bare StopIteration leaking out of this function.
        first_row = next(iter(yt_client.read_table(build_config.locations_out_merged)), None)
        if first_row is None:
            raise RuntimeError(
                'Merged table %s is empty: cannot derive the feature column list'
                % build_config.locations_out_merged
            )

        # Sorted so that the feature order is deterministic.
        columns = sorted(set(first_row.keys()) - {build_config.ext_id_key})
        logger.info('Column order in string: ' + ' '.join(columns))
        yt_client.run_map(FeaturesMapper(build_config.ext_id_key, columns),
                          build_config.locations_out_merged, build_config.out_table)


def build_locations_vectors(yt_client, yql_client, build_config, steps_to_run=LOCATIONS_DEFAULT_FEATURES_BUILD_STEPS):
    """Run the selected location-feature build steps in their fixed order.

    Steps always execute in the order step 1 .. step 5, regardless of the
    order of ``steps_to_run``; a step whose enum member is not contained in
    ``steps_to_run`` is skipped.  A start and a finish message is logged
    around the whole build and around every executed step.

    Args:
        yt_client: YT client passed through to every step.
        yql_client: YQL client passed through to every step.
        build_config: Build configuration passed through to every step.
        steps_to_run: Container of ``LocationsFeaturesBuildSteps`` members to
            execute; defaults to all steps.
    """
    logger.info('Start calculate geo logs features')

    # (enum member, step function, start message, finish message)
    pipeline = (
        (LocationsFeaturesBuildSteps.step_1_gather_round_logs,
         step_1_gather_round_logs,
         'Start calculate round geo logs',
         'Finish calculate round geo logs'),
        (LocationsFeaturesBuildSteps.step_2_calculate_user_statistics,
         step_2_calculate_user_statistics,
         'Start calculate geo logs statistics',
         'Finish calculate geo logs statistics'),
        (LocationsFeaturesBuildSteps.step_3_calculate_bandits_statistics,
         step_3_calculate_bandits_statistics,
         'Start calculate bandits statistics',
         'Finish calculate bandits statistics'),
        (LocationsFeaturesBuildSteps.step_4_calculate_homework_statistics,
         step_4_calculate_homework_statistics,
         'Start calculate homework statistics',
         'Finish calculate homework statistics'),
        (LocationsFeaturesBuildSteps.step_5_merge_tables,
         step_5_merge_tables,
         'Start merge table',
         'Finish merge table'),
    )

    for step, run_step, start_message, finish_message in pipeline:
        if step not in steps_to_run:
            continue
        logger.info(start_message)
        run_step(yt_client, yql_client, build_config)
        logger.info(finish_message)

    logger.info('Finish calculate geo logs features')
