from enum import Enum, unique
from textwrap import dedent
from datacloud.dev_utils.yql.yql_helpers import execute_yql, create_yql_client
from datacloud.dev_utils.yt.yt_utils import get_yt_client, create_folders
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.features.phone_range.helpers import PhoneRangeReducer


# YQL pragmas selecting the YT pool trees ("physical", with "cloud" as a
# tentative pool tree). Substituted into the %(custom_pragmas)s slot at the top
# of the queries below when build_config.use_cloud_nodes is set.
CLOUD_NODES_PRAGMAS = (
    '\n'
    'PRAGMA yt.PoolTrees = "physical";\n'
    'PRAGMA yt.TentativePoolTrees = "cloud";\n'
)


@unique
class PhoneRangeFeaturesBuildSteps(Enum):
    """Named steps of the phone-range feature build.

    Every member's value equals its name, so a step can be looked up from a
    plain string, e.g. ``PhoneRangeFeaturesBuildSteps('step_1_compute_phone_range')``.
    """

    # Dispatched to step_1_compute_phone_range() by build_phone_range_vectors.
    step_1_compute_phone_range = 'step_1_compute_phone_range'
    # Dispatched to step_2_compute_change_region() by build_phone_range_vectors.
    step_2_compute_change_region = 'step_2_compute_change_region'


# By default every declared step runs, in declaration order (Enum iteration
# order). If a future step should be opt-in, list the defaults explicitly.
PHONERANGE_DEFAULT_FEATURES_BUILD_STEPS = tuple(PhoneRangeFeaturesBuildSteps)


def _phone_query_template(build_config):
    """Return the dedented YQL query text used by step_1_compute_phone_range.

    The text still contains ``%(name)s`` placeholders, which are filled from
    the ``params`` dict handed to ``execute_yql`` (presumably via
    %-formatting — confirm in ``execute_yql``). In ``pure_external_id`` mode,
    ``build_config.raw_data_table`` is substituted right away with
    ``str.format`` before dedenting.
    """
    if not build_config.pure_external_id:
        # Plain mode: keep the input key as-is and attach the resolved phone.
        return dedent("""
            %(custom_pragmas)s
            $rainbow_table = '%(rainbow_table)s';
            $input_table = '%(input_table)s';
            $output_table = '%(output_table)s';

            INSERT INTO $output_table WITH TRUNCATE
            SELECT
                %(ext_id_key)s,
                COALESCE(r.phone, '') AS phone
            FROM $input_table AS i
            LEFT JOIN $rainbow_table AS r
            ON r.phone_md5 = i.id_value
            ORDER BY %(ext_id_key)s
        """)

    # Pure external-id mode: the output key becomes "<ext_id>_<retro_date>"
    # taken from raw_data_table, and only rows with id_type = 'phone_md5' are
    # kept.
    # NOTE(review): input rows with no match in raw_data_table get a NULL key
    # (LEFT JOIN) — confirm that is intended.
    return dedent("""
        %(custom_pragmas)s
        $rainbow_table = '%(rainbow_table)s';
        $input_table = '%(input_table)s';
        $raw_data_table = '{}';
        $output_table = '%(output_table)s';

        INSERT INTO $output_table WITH TRUNCATE
        SELECT
            t.%(ext_id_key)s || '_' || t.retro_date AS %(ext_id_key)s,
            COALESCE(r.phone, '') AS phone
        FROM $input_table AS i
        LEFT JOIN $raw_data_table AS t
        ON t.%(ext_id_key)s = i.%(ext_id_key)s
        LEFT JOIN $rainbow_table AS r
        ON r.phone_md5 = i.id_value
        WHERE i.id_type = 'phone_md5'
        ORDER BY %(ext_id_key)s
    """.format(build_config.raw_data_table))


def step_1_compute_phone_range(yt_client, yql_client, build_config):
    """Compute phone-range features into ``build_config.features_table``.

    Outline:
      1. ensure ``build_config.data_dir`` exists;
      2. load the whole ``build_config.phone_ranges_info`` table into memory;
      3. inside one YT transaction, run a YQL query that resolves each input
         row's ``id_value`` against ``phone_md5`` in
         ``build_config.rainbow_table``, yielding ``phone`` (``''`` when
         unresolved). The result goes to a temporary table ordered by
         ``ext_id_key``;
      4. reduce that table by ``ext_id_key`` with ``PhoneRangeReducer`` into
         ``build_config.features_table``, then sort it by the same key.

    :param yt_client: YT client; ``get_yt_client()`` is used when falsy.
    :param yql_client: YQL client; ``create_yql_client()`` is used when falsy.
    :param build_config: build configuration providing the table paths and
        flags read below.
    """
    if not yt_client:
        yt_client = get_yt_client()
    if not yql_client:
        yql_client = create_yql_client()

    create_folders(build_config.data_dir, yt_client)

    # The reducer takes the ranges as a constructor argument, so the table is
    # materialized up front.
    ranges_rows = list(yt_client.read_table(build_config.phone_ranges_info))

    with yt_client.Transaction() as transaction:
        with yt_client.TempTable(prefix=build_config.root) as phones_table:
            query = _phone_query_template(build_config)
            query_params = dict(
                custom_pragmas=(CLOUD_NODES_PRAGMAS
                                if build_config.use_cloud_nodes else ''),
                rainbow_table=build_config.rainbow_table,
                output_table=phones_table,
                input_table=build_config.input_table,
                ext_id_key=build_config.ext_id_key,
            )
            execute_yql(
                query=query,
                yql_client=yql_client,
                params=query_params,
                set_owners=False,
                syntax_version=1,
                transaction_id=transaction.transaction_id,
            )

            # NOTE(review): run_reduce/run_sort get no explicit transaction;
            # presumably they join it through the client's active Transaction
            # context — confirm against yt.wrapper semantics.
            key = build_config.ext_id_key
            reducer = PhoneRangeReducer(phone_ranges_info=ranges_rows, ext_id_key=key)
            yt_client.run_reduce(
                reducer,
                phones_table,
                build_config.features_table,
                reduce_by=[key],
                spec=dict(
                    title='[{}] calculate phone range features'.format(build_config.tag),
                    **build_config.cloud_nodes_spec
                ),
            )
            yt_client.run_sort(
                build_config.features_table,
                sort_by=[key],
                spec=dict(
                    title='[{}] sort features'.format(build_config.tag),
                    **build_config.cloud_nodes_spec
                ),
            )


def step_2_compute_change_region(yql_client, build_config):
    """Add a ``change_region`` column to the phone-range features table.

    The features table is rewritten in place (``INSERT ... WITH TRUNCATE``
    reading from the same table). It is left-joined with
    ``build_config.clickstream_region_table`` on ``ext_id_key`` and ordered by
    ``ext_id_key``. Only the columns ``ext_id_key``, ``operator``, ``region``,
    ``region_id`` and ``change_region`` are written back.

    ``change_region`` is evaluated as follows:
      * TRUE when neither ``region_id`` nor ``user_region`` lies inside the
        other, per ``Geo::IsRegionInRegion``;
      * FALSE when one contains the other;
      * NULL when either value is NULL.

    :param yql_client: YQL client; ``create_yql_client()`` is used when falsy,
        matching ``step_1_compute_phone_range``.
    :param build_config: build configuration providing the table paths and
        flags read below.
    """
    # Same fallback as step_1, so callers may pass None here too.
    yql_client = yql_client or create_yql_client()

    yql_query = dedent("""
        %(custom_pragmas)s
        $features_table = '%(features_table)s';
        $clickstream_region_table = '%(clickstream_region_table)s';

        $is_regions_includes = ($r1, $r2)->{
            RETURN CASE
                WHEN $r1 IS NOT NULL
                AND $r2 IS NOT NULL
                THEN Geo::IsRegionInRegion(UNWRAP(CAST($r1 AS Int32)), UNWRAP(CAST($r2 AS Int32)))
                OR Geo::IsRegionInRegion(UNWRAP(CAST($r2 AS Int32)), UNWRAP(CAST($r1 AS Int32)))
                ELSE NULL
            END
        };

        INSERT INTO $features_table WITH TRUNCATE
        SELECT
            f.%(ext_id_key)s AS %(ext_id_key)s,
            operator,
            region,
            region_id,
            NOT $is_regions_includes(region_id, user_region) AS change_region
        FROM $features_table AS f
        LEFT JOIN $clickstream_region_table AS r
        ON r.%(ext_id_key)s = f.%(ext_id_key)s
        ORDER BY %(ext_id_key)s
    """)

    execute_yql(query=yql_query, yql_client=yql_client, params=dict(
        custom_pragmas=CLOUD_NODES_PRAGMAS if build_config.use_cloud_nodes else '',
        # NOTE(review): step_1 writes build_config.features_table, while this
        # step targets build_config.features_table_path — confirm both name
        # the same table.
        features_table=build_config.features_table_path,
        clickstream_region_table=build_config.clickstream_region_table,
        # Not referenced by the query above; presumably ignored by execute_yql.
        input_table=build_config.input_table,
        ext_id_key=build_config.ext_id_key,
    ), set_owners=False, syntax_version=1)


def build_phone_range_vectors(yt_client, yql_client, build_config, logger=None,
                              steps_to_run=PHONERANGE_DEFAULT_FEATURES_BUILD_STEPS):
    """Run the requested phone-range feature build steps in order.

    :param yt_client: YT client; ``get_yt_client()`` is used when falsy.
    :param yql_client: YQL client; ``create_yql_client()`` is used when falsy.
    :param build_config: build configuration passed to every step.
    :param logger: logger for progress messages. When falsy, a basic logger
        named after this module is created.
    :param steps_to_run: collection of ``PhoneRangeFeaturesBuildSteps``
        members to execute (checked with ``in``). Defaults to all steps.

    Step 2 is skipped, with a log message, when
    ``build_config.compute_region_changes`` is falsy or
    ``build_config.clickstream_region_table`` does not exist.
    """
    logger = logger or get_basic_logger(name=__name__, format='%(asctime)s %(message)s')
    logger.info('Start calculate phone range features')

    # Resolve the clients once. step_1 tolerates None, but the step_2 branch
    # below calls yt_client.exists() directly and would fail on None. Sharing
    # the resolved clients also avoids creating them separately per step.
    yt_client = yt_client or get_yt_client()
    yql_client = yql_client or create_yql_client()

    if PhoneRangeFeaturesBuildSteps.step_1_compute_phone_range in steps_to_run:
        logger.info('Start calculate phone range')
        step_1_compute_phone_range(yt_client, yql_client, build_config)
        logger.info('Finish calculate phone range')

    if PhoneRangeFeaturesBuildSteps.step_2_compute_change_region in steps_to_run:
        if not build_config.compute_region_changes:
            logger.info('Compute region in config equals False')
        elif not yt_client.exists(build_config.clickstream_region_table):
            logger.info('Table with clickstream region does not exist')
        else:
            logger.info('Start calculate change region feature')
            step_2_compute_change_region(yql_client, build_config)
            logger.info('Finish calculate change region feature')

    logger.info('Finish calculate phone range features')
