# -*- coding: utf-8 -*-
from enum import Enum, unique
from textwrap import dedent
from datacloud.dev_utils.yt.yt_utils import get_yt_client, create_folders
from datacloud.dev_utils.yql.yql_helpers import execute_yql, create_yql_client
from datacloud.dev_utils.data.data_utils import array_tostring
from datacloud.dev_utils.time.patterns import FMT_DATE
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.features.time_hist.helpers import ActivityHistReducer


@unique
class TimeHistFeaturesBuildSteps(Enum):
    """Identifiers of the stages of the time_hist feature pipeline.

    Every member's value is the same string as its name. The spelling
    "calulate" is part of the public member name and value and is kept as-is.
    """

    step_1_calulate_timezone = 'step_1_calulate_timezone'
    step_2_aggregates_by_days_category = 'step_2_aggregates_by_days_category'
    step_3_prepare_vectors = 'step_3_prepare_vectors'


# By default every stage runs. Iterating the enum yields its members in
# declaration order, and @unique guarantees there are no aliases to skip.
TIMEHIST_DEFAULT_FEATURES_BUILD_STEPS = tuple(TimeHistFeaturesBuildSteps)


def step_1_calulate_timezone(build_config, yt_client=None, yql_client=None):
    """Write the most frequent region (and its timezone name) per id.

    Joins ``build_config.input_yuid`` with the region logs in
    ``build_config.region_dir`` (tables between ``min_log_date`` and
    ``max_log_date``) on ``yuid``. It keeps only log rows whose ``log_date``
    falls in ``[retro_date - days_to_take days, retro_date]``, takes the mode
    of ``user_region`` per ``ext_id_key``, and writes
    ``(ext_id_key, user_region, timezone_name)`` sorted by ``ext_id_key`` into
    ``build_config.timezones_table`` (truncating it).

    ``retro_date`` is ``build_config.snapshot_date`` in normal mode. In retro
    mode it is the last 10 characters of each row's ``external_id``.

    Args:
        build_config: pipeline configuration object.
        yt_client: YT client; ``get_yt_client()`` is used when falsy.
        yql_client: YQL client; created from ``yt_client`` when falsy.
    """
    if not yt_client:
        yt_client = get_yt_client()
    if not yql_client:
        yql_client = create_yql_client(yt_client=yt_client)

    create_folders(build_config.data_dir, yt_client)

    # NOTE(review): the WHERE filter on r.log_date runs after the LEFT JOIN, so ids
    # with no matching region rows (NULL log_date) are dropped. The join therefore
    # behaves like an INNER JOIN -- confirm this is intended.
    query_template = dedent("""
        %(custom_pragmas)s
        $region_dir = '%(region_dir)s';
        $input_yuid = '%(input_yuid)s';
        $min_log_date = '%(min_log_date)s';
        $max_log_date = '%(max_log_date)s';

        $output_table = '%(output_table)s';

        $days_to_take = %(days_to_take)s;

        $retro_date = ($id)->{
            RETURN String::Reverse(String::Substring(String::Reverse($id), 0, 10))
        };

        $days_in_window = ($date, $retro_date, $window)->{
            RETURN $date <= $retro_date
               AND CAST($date AS DATE) >= CAST($retro_date AS DATE) - DateTime::IntervalFromDays($window)
        };

        $get_timezone = ($id)->{
            RETURN Geo::RegionById(CAST($id AS Int32)).timezone_name
        };

        $mode_region = (
            SELECT
                %(ext_id_key)s,
                MODE(r.user_region)[0].Value AS user_region
            FROM $input_yuid AS i
            LEFT JOIN RANGE($region_dir, $min_log_date, $max_log_date) AS r
            ON i.yuid = r.yuid
            WHERE $days_in_window(r.log_date, %(retro_date)s, $days_to_take)
            GROUP BY i.%(ext_id_key)s AS %(ext_id_key)s
        );

        INSERT INTO $output_table WITH TRUNCATE
        SELECT
            %(ext_id_key)s,
            user_region,
            $get_timezone(user_region) AS timezone_name
        FROM $mode_region
        ORDER BY %(ext_id_key)s
    """)

    cloud_nodes_pragmas = dedent("""
        PRAGMA yt.PoolTrees = "physical";
        PRAGMA yt.TentativePoolTrees = "cloud";
    """)

    if build_config.use_cloud_nodes:
        pragmas = cloud_nodes_pragmas
    else:
        pragmas = ''

    # In retro mode the reference date comes from each id itself.
    # NOTE(review): this reads `external_id` directly rather than `ext_id_key`;
    # confirm that retro inputs always carry an `external_id` column.
    if build_config.is_retro:
        retro_date_expr = "$retro_date(i.external_id)"
    else:
        retro_date_expr = "'{}'".format(build_config.snapshot_date)

    query_params = {
        'custom_pragmas': pragmas,
        'region_dir': build_config.region_dir,
        'input_yuid': build_config.input_yuid,
        'output_table': build_config.timezones_table,
        'days_to_take': build_config.days_to_take,
        'ext_id_key': build_config.ext_id_key,
        'retro_date': retro_date_expr,
        'min_log_date': build_config.min_log_date,
        'max_log_date': build_config.max_log_date,
    }

    execute_yql(
        query=query_template,
        yql_client=yql_client,
        params=query_params,
        set_owners=False,
        syntax_version=1,
    )


def _grep_logs_subquery(build_config):
    """Return a YQL ``UNION ALL`` of ``yuid``/``timestamp`` over every grep log folder.

    Each branch reads ``RANGE('<grep_root>/<log>', min_log_date, max_log_date)``
    for one entry of ``build_config.log_folders``.

    Raises:
        ValueError: if ``build_config.log_folders`` is empty. Joining zero
            branches would yield ``$grep_logs = ();``, which is not a valid
            subquery and would only fail once submitted to YQL.
    """
    branches = [
        "SELECT yuid, `timestamp` FROM RANGE('{}/{}', '{}', '{}')".format(
            build_config.grep_root, log, build_config.min_log_date, build_config.max_log_date
        )
        for log in build_config.log_folders
    ]
    if not branches:
        raise ValueError('build_config.log_folders is empty: there are no grep logs to read')
    return ' UNION ALL '.join(branches)


def step_2_aggregates_by_days_category(build_config, yt_client=None, yql_client=None):
    """Build the per-id activity histogram into ``build_config.histogram_table``.

    All YT work happens inside a single transaction:

    1. Non-retro mode: join ``build_config.input_yuid`` with the grep logs on
       ``yuid`` and write ``(ext_id_key, timestamp)`` sorted by ``ext_id_key``
       into ``build_config.tmp_grep_table``. Retro mode instead uses
       ``build_config.get_log_tables(yt_client)``.
    2. Reduce those tables plus ``build_config.timezones_table`` (written by
       step 1) by ``ext_id_key`` with ``ActivityHistReducer``, which gets the
       set of ``date`` values read from ``build_config.holidays_table``.
    3. Merge the histogram table into itself with ``combine_chunks``.
    4. Non-retro mode: remove the temporary grep table.

    Args:
        build_config: pipeline configuration object.
        yt_client: YT client; ``get_yt_client()`` is used when falsy.
        yql_client: YQL client; created from ``yt_client`` when falsy.

    Raises:
        ValueError: in non-retro mode, when ``build_config.log_folders`` is empty.
    """
    yt_client = yt_client or get_yt_client()
    yql_client = yql_client or create_yql_client(yt_client=yt_client)

    # Validate/build the grep subquery up front so a bad config fails before any
    # table is read or any transaction is opened.
    grep_logs = None if build_config.is_retro else _grep_logs_subquery(build_config)

    holidays = {row['date'] for row in yt_client.read_table(build_config.holidays_table)}

    with yt_client.Transaction() as transaction:
        if not build_config.is_retro:
            yql_query = dedent("""
                $input_yuid = '%(input_yuid)s';
                $grep_logs = (%(grep_logs)s);
                $output_data = '%(tmp_grep_table)s';

                INSERT INTO $output_data WITH TRUNCATE
                SELECT
                    %(ext_id_key)s,
                    `timestamp`
                FROM $input_yuid AS i
                INNER JOIN $grep_logs AS l
                ON i.yuid = l.yuid
                ORDER BY %(ext_id_key)s
            """)
            # Run the YQL write inside the same transaction so that an abort
            # rolls back the temporary table as well.
            execute_yql(query=yql_query, yql_client=yql_client, params=dict(
                tmp_grep_table=build_config.tmp_grep_table,
                input_yuid=build_config.input_yuid,
                grep_logs=grep_logs,
                ext_id_key=build_config.ext_id_key,
            ), set_owners=False, syntax_version=1, transaction_id=transaction.transaction_id)

            log_tables = [build_config.tmp_grep_table]
        else:
            # list() so get_log_tables may return any iterable (tuple, generator).
            # NOTE(review): reduce needs these tables sorted by ext_id_key -- confirm.
            log_tables = list(build_config.get_log_tables(yt_client))

        yt_client.run_reduce(
            ActivityHistReducer(
                holidays=holidays,
                days_to_take=build_config.days_to_take,
                date_format=FMT_DATE,
                ext_id_key=build_config.ext_id_key,
                # None in retro mode; the reducer then presumably derives the
                # date per id -- see ActivityHistReducer.
                retro_date=None if build_config.is_retro else build_config.snapshot_date,
            ),
            log_tables + [build_config.timezones_table],
            build_config.histogram_table,
            reduce_by=[build_config.ext_id_key],
            spec=dict(
                title='[{}] run histogram'.format(build_config.tag),
                **build_config.cloud_nodes_spec
            ),
        )

        yt_client.run_merge(
            build_config.histogram_table,
            build_config.histogram_table,
            spec={
                'combine_chunks': True,
                'title': '[{}] combine histogram chunks'.format(build_config.tag),
            },
        )

        if not build_config.is_retro:
            yt_client.remove(build_config.tmp_grep_table)


def step_3_prepare_vectors(build_config, yt_client=None):
    """Flatten the working-day histogram into feature vectors and sort them.

    Within one YT transaction, maps ``build_config.histogram_table`` into
    ``build_config.features_table`` and then sorts that table by
    ``build_config.ext_id_key``.

    Only rows with a truthy ``hist_activity_rate`` produce output. Each output
    row is ``{ext_id_key: <id>, 'features': array_tostring(vector)}``, where
    ``vector`` holds 25 numbers:

    * the 24 hourly rates of the ``'working'`` category, keyed "0".."23";
      a missing hour becomes 0.0;
    * followed by ``total_activity_days['working']`` (default 0).

    Args:
        build_config: pipeline configuration object.
        yt_client: YT client; ``get_yt_client()`` is used when falsy.
    """
    if not yt_client:
        yt_client = get_yt_client()

    id_column = build_config.ext_id_key
    cloud_spec = build_config.cloud_nodes_spec

    def mapper(rec):
        activity_rate = rec['hist_activity_rate']
        if not activity_rate:
            return
        working_rates = activity_rate.get('working', {})
        working_days = rec['total_activity_days'].get('working', 0)
        vector = [working_rates.get(str(hour), 0.) for hour in range(24)]
        vector.append(working_days)
        packed = array_tostring(vector)
        yield {id_column: rec[id_column], 'features': packed}

    map_spec = dict(
        title='[{}] select working histogram and total working days'.format(build_config.tag),
        **cloud_spec
    )
    sort_spec = dict(
        title='[{}] sort features'.format(build_config.tag),
        **cloud_spec
    )

    with yt_client.Transaction():
        yt_client.run_map(
            mapper,
            build_config.histogram_table,
            build_config.features_table,
            spec=map_spec,
        )
        yt_client.run_sort(
            build_config.features_table,
            sort_by=id_column,
            spec=sort_spec,
        )


def build_time_hist_vectors(yt_client, yql_client, build_config, logger=None,
                            steps_to_run=TIMEHIST_DEFAULT_FEATURES_BUILD_STEPS):
    """Run the selected stages of the time_hist feature pipeline.

    Stages always execute in the fixed order timezone -> histogram ->
    working-day vectors. A stage runs only if its ``TimeHistFeaturesBuildSteps``
    member is in ``steps_to_run``; the order of ``steps_to_run`` itself is
    irrelevant. Each stage is bracketed by start/finish log messages.

    Args:
        yt_client: YT client forwarded to every stage.
        yql_client: YQL client forwarded to the stages that run YQL.
        build_config: pipeline configuration object.
        logger: logger to use; a basic timestamped logger is created when falsy.
        steps_to_run: collection of ``TimeHistFeaturesBuildSteps`` to execute.
    """
    if not logger:
        logger = get_basic_logger(name=__name__, format='%(asctime)s %(message)s')

    # (step id, start message, stage runner, finish message), in execution order.
    stages = (
        (
            TimeHistFeaturesBuildSteps.step_1_calulate_timezone,
            'Start calculate timezone',
            lambda: step_1_calulate_timezone(build_config, yt_client, yql_client),
            'Finish calculate timezone',
        ),
        (
            TimeHistFeaturesBuildSteps.step_2_aggregates_by_days_category,
            'Start calculate histogram',
            lambda: step_2_aggregates_by_days_category(build_config, yt_client, yql_client),
            'Finish calculate histogram',
        ),
        (
            TimeHistFeaturesBuildSteps.step_3_prepare_vectors,
            'Start calculate working days features',
            lambda: step_3_prepare_vectors(build_config, yt_client),
            'Finish calculate working days features',
        ),
    )

    logger.info('Start calculate time_hist features')
    for step, start_message, run_stage, finish_message in stages:
        if step not in steps_to_run:
            continue
        logger.info(start_message)
        run_stage()
        logger.info(finish_message)
    logger.info('Finish calculate time_hist features')
