# -*- coding: utf-8 -*-
from enum import Enum, unique

from datacloud.dev_utils.yt.yt_utils import get_yt_client, create_folders
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.features.geo.helpers import (
    fetch_geo_logs_reducer, filter_geo_logs, distance_reducer,
    single_reducer, DistancesFilterReducer, DistanceToFMapper, FeaturesCompactReducer,
    BinaryFeaturesReducer, PointsToGeohashReducerRetro
)

# Module-level logger used by the build steps and by build_vectors below.
logger = get_basic_logger()


@unique
class GeoFeaturesBuildSteps(Enum):
    """Stages of the geo feature build pipeline.

    Every member's value is the same string as its name, so a stage can be
    looked up by name, e.g. ``GeoFeaturesBuildSteps('step_4_calc_features')``.
    """
    step_0_grep_logs = 'step_0_grep_logs'
    step_1_filter_logs = 'step_1_filter_logs'
    step_2_calculate_distances = 'step_2_calculate_distances'
    step_3_filter_distances = 'step_3_filter_distances'
    # Values must be plain strings: a trailing comma here would silently turn
    # the value into a one-element tuple and break lookup by name.
    step_4_calc_features = 'step_4_calc_features'
    step_5_compact_features = 'step_5_compact_features'
    step_6_add_binary_features = 'step_6_add_binary_features'


# Default step selection for ``build_vectors``: every pipeline stage.
# Iterating an Enum yields its members in definition order, and @unique on the
# enum guarantees there are no aliases, so this is the same tuple as listing
# each member explicitly.
GEO_FEATURES_BUILD_STEPS = tuple(GeoFeaturesBuildSteps)


def step_0_grep_logs(build_config, yt_client=None):
    """Join the raw geo log tables with the input yuid table.

    Does nothing when ``build_config.fetched_logs`` already exists. Otherwise
    it selects the tables listed under ``build_config.LOCAL_LOGS_DIR`` whose
    base name falls within ``[min_date, max_date]``. Either bound may be
    ``None`` to leave that side open, and names are compared as strings. Then,
    in one YT transaction, it:

    1. sorts ``input_yuid_table`` in place by ``yuid``;
    2. reduces it together with the selected log tables by ``yuid`` using
       ``fetch_geo_logs_reducer``, writing ``fetched_logs``;
    3. sorts ``fetched_logs`` by ``external_id``, ``timestamp_of_log``.

    Args:
        build_config: build configuration providing table paths, date bounds,
            ``tag`` and ``cloud_nodes_spec``.
        yt_client: YT client; ``get_yt_client()`` is used when not given.
    """
    yt_client = yt_client or get_yt_client()
    create_folders((build_config.logs_dir,), yt_client)

    if yt_client.exists(build_config.fetched_logs):
        logger.info('Geo logs already greped. Fast forwarding...')
        return

    min_date, max_date = build_config.min_date, build_config.max_date

    def _within_dates(table_path):
        # Compare the last path component against the bounds; presumably the
        # log tables are named by date (e.g. YYYY-MM-DD) — confirm.
        table_name = table_path.split('/')[-1]
        if min_date is not None and table_name < min_date:
            return False
        if max_date is not None and table_name > max_date:
            return False
        return True

    # Materialize as a list: the selection is measured with len(), indexed and
    # concatenated below, none of which work on Python 3's lazy filter().
    logs_tables = [
        table
        for table in yt_client.list(build_config.LOCAL_LOGS_DIR, absolute=True)
        if _within_dates(table)
    ]

    if len(logs_tables) < 6:
        logger.info(' Using log tables:\n{}'.format('\n'.join(logs_tables)))
    else:
        # Long selections are abbreviated to the first and last table.
        logger.info(' Using log tables:\n{}\n...\n{}'.format(logs_tables[0], logs_tables[-1]))

    with yt_client.Transaction():
        # NOTE: sorts the caller-provided input table in place.
        yt_client.run_sort(
            build_config.input_yuid_table,
            sort_by=['yuid'],
            spec=dict(
                title='[{}] Sort input yuid'.format(build_config.tag),
                **build_config.cloud_nodes_spec
            )
        )

        yt_client.run_reduce(
            fetch_geo_logs_reducer,
            [build_config.input_yuid_table] + logs_tables,
            build_config.fetched_logs,
            reduce_by=['yuid'],
            spec=dict(
                title='[{}] Join logs with input_yuid'.format(build_config.tag),
                **build_config.cloud_nodes_spec
            )
        )

        yt_client.run_sort(
            build_config.fetched_logs,
            sort_by=['external_id', 'timestamp_of_log'],
            spec=dict(
                title='[{}] Join logs / sort after'.format(build_config.tag),
                **build_config.cloud_nodes_spec
            )
        )


def step_1_filter_logs(build_config, yt_client):
    """Clean the fetched geo logs in two reduce passes within one transaction.

    Pass one reduces ``fetched_logs`` per ``external_id`` with
    ``filter_geo_logs``. Rows arrive ordered by ``timestamp_of_log``, and the
    operation is titled "Filter logs by last timestamp". The output
    ``filtered_logs1`` is then sorted by ``external_id``.

    Pass two reduces ``filtered_logs1`` per ``external_id`` with
    ``PointsToGeohashReducerRetro`` ("Filter logs by clustering"). The output
    ``filtered_logs2`` is then sorted by ``external_id``.

    Args:
        build_config: build configuration providing table paths, ``tag`` and
            ``cloud_nodes_spec``.
        yt_client: YT client used to run the operations.
    """
    def op_spec(title_template):
        # Operation spec: the tagged title plus the shared node settings.
        return dict(
            title=title_template.format(build_config.tag),
            **build_config.cloud_nodes_spec
        )

    with yt_client.Transaction():
        # Pass 1: timestamp-based filtering.
        yt_client.run_reduce(
            filter_geo_logs,
            build_config.fetched_logs,
            build_config.filtered_logs1,
            reduce_by=['external_id'],
            sort_by=['external_id', 'timestamp_of_log'],
            spec=op_spec('[{}] Filter logs by last timestamp'),
        )
        yt_client.run_sort(
            build_config.filtered_logs1,
            sort_by=['external_id'],
            spec=op_spec('[{}] Filter logs by last timestamp / sort after'),
        )

        # Pass 2: clustering-based filtering.
        yt_client.run_reduce(
            PointsToGeohashReducerRetro(),
            build_config.filtered_logs1,
            build_config.filtered_logs2,
            reduce_by=['external_id'],
            spec=op_spec('[{}] Filter logs by clustering'),
        )
        yt_client.run_sort(
            build_config.filtered_logs2,
            sort_by=['external_id'],
            spec=op_spec('[{}] Filter logs by clustering / sort after'),
        )


def step_2_calculate_distances(build_config, yt_client):
    """Compute distances between filtered log points and resolved addresses.

    Within one transaction, this step:

    1. sorts ``resolved_addrs_table`` by ``external_id``;
    2. reduces ``filtered_logs2`` per ``external_id`` with ``distance_reducer``,
       joining ``resolved_addrs_table`` as a foreign table on ``external_id``,
       and writes ``distances_table``;
    3. sorts ``distances_table`` by ``external_id``, ``type``, ``distance``.

    Args:
        build_config: build configuration providing table paths, ``tag`` and
            ``cloud_nodes_spec``.
        yt_client: YT client used to run the operations.
    """
    def op_spec(title_template):
        # Operation spec: the tagged title plus the shared node settings.
        return dict(
            title=title_template.format(build_config.tag),
            **build_config.cloud_nodes_spec
        )

    with yt_client.Transaction():
        yt_client.run_sort(
            build_config.resolved_addrs_table,
            sort_by=['external_id'],
            spec=op_spec('[{}] Sort resolved addrs table'),
        )

        # The address table takes part as a foreign input, joined by external_id.
        foreign_addrs = yt_client.TablePath(
            build_config.resolved_addrs_table,
            attributes={'foreign': True},
        )
        yt_client.run_reduce(
            distance_reducer,
            [build_config.filtered_logs2, foreign_addrs],
            build_config.distances_table,
            reduce_by=['external_id'],
            join_by=['external_id'],
            spec=op_spec('[{}] Calculate distances'),
        )

        yt_client.run_sort(
            build_config.distances_table,
            sort_by=['external_id', 'type', 'distance'],
            spec=op_spec('[{}] Calculate distances / sort after'),
        )


def step_3_filter_distances(build_config, yt_client):
    """Deduplicate and then prune the computed distances.

    Within one transaction, this step:

    1. reduces ``distances_table`` by ``external_id``, ``type``, ``distance``
       with ``single_reducer`` (titled "Reduce duplicates in distances"),
       writes ``distances_filtered`` and sorts it;
    2. reduces ``distances_filtered`` per ``external_id``/``type``, with rows
       ordered by ``distance``, using
       ``DistancesFilterReducer(max_distances_in_category``). The result is
       written back into the same table, which is then sorted again.

    Args:
        build_config: build configuration providing table paths,
            ``max_distances_in_category``, ``tag`` and ``cloud_nodes_spec``.
        yt_client: YT client used to run the operations.
    """
    def op_spec(title_template):
        # Operation spec: the tagged title plus the shared node settings.
        return dict(
            title=title_template.format(build_config.tag),
            **build_config.cloud_nodes_spec
        )

    with yt_client.Transaction():
        # Deduplication pass.
        yt_client.run_reduce(
            single_reducer,
            build_config.distances_table,
            build_config.distances_filtered,
            reduce_by=['external_id', 'type', 'distance'],
            spec=op_spec('[{}] Reduce duplicates in distances'),
        )
        yt_client.run_sort(
            build_config.distances_filtered,
            sort_by=['external_id', 'type', 'distance'],
            spec=op_spec('[{}] Reduce duplicates in distances / sort after'),
        )

        # Pruning pass: input and output are the same table.
        # NOTE(review): relies on YT reading the input before the output is
        # overwritten — confirm this is supported for reduce.
        yt_client.run_reduce(
            DistancesFilterReducer(build_config.max_distances_in_category),
            build_config.distances_filtered,
            build_config.distances_filtered,
            reduce_by=['external_id', 'type'],
            sort_by=['external_id', 'type', 'distance'],
            spec=op_spec('[{}] Filter distances'),
        )
        yt_client.run_sort(
            build_config.distances_filtered,
            sort_by=['external_id', 'type', 'distance'],
            spec=op_spec('[{}] Filter distances / sort after'),
        )


def step_4_calc_features(build_config, yt_client):
    """Turn filtered distances into flat feature rows.

    Within one transaction, this step maps ``distances_filtered`` through
    ``DistanceToFMapper(distance_thresh)`` into ``features_flatten``. It then
    sorts that table by ``external_id``, ``type``, ``feature``.

    Args:
        build_config: build configuration providing table paths,
            ``distance_thresh``, ``tag`` and ``cloud_nodes_spec``.
        yt_client: YT client used to run the operations.
    """
    def op_spec(title_template):
        # Operation spec: the tagged title plus the shared node settings.
        return dict(
            title=title_template.format(build_config.tag),
            **build_config.cloud_nodes_spec
        )

    with yt_client.Transaction():
        mapper = DistanceToFMapper(build_config.distance_thresh)
        yt_client.run_map(
            mapper,
            build_config.distances_filtered,
            build_config.features_flatten,
            spec=op_spec('[{}] Calc features'),
        )
        yt_client.run_sort(
            build_config.features_flatten,
            sort_by=['external_id', 'type', 'feature'],
            spec=op_spec('[{}] Calc features / sort after'),
        )


def step_5_compact_features(build_config, yt_client):
    """Collapse the flat feature rows into one row per ``external_id``.

    Within one transaction, this step reduces ``features_flatten`` per
    ``external_id``, with rows ordered by ``type`` and ``feature``, using
    ``FeaturesCompactReducer``. The reducer is built from ``addrs_types``,
    ``max_distances_in_category``, ``features_fillna`` and
    ``features_sort_order``. The output ``features_table`` is then sorted by
    ``external_id``.

    Args:
        build_config: build configuration providing table paths, the reducer
            settings above, ``tag`` and ``cloud_nodes_spec``.
        yt_client: YT client used to run the operations.
    """
    def op_spec(title_template):
        # Operation spec: the tagged title plus the shared node settings.
        return dict(
            title=title_template.format(build_config.tag),
            **build_config.cloud_nodes_spec
        )

    with yt_client.Transaction():
        compactor = FeaturesCompactReducer(
            build_config.addrs_types,
            build_config.max_distances_in_category,
            build_config.features_fillna,
            build_config.features_sort_order,
        )
        yt_client.run_reduce(
            compactor,
            build_config.features_flatten,
            build_config.features_table,
            reduce_by=['external_id'],
            sort_by=['external_id', 'type', 'feature'],
            spec=op_spec('[{}] Compact features'),
        )
        yt_client.run_sort(
            build_config.features_table,
            sort_by=['external_id'],
            spec=op_spec('[{}] Compact features / sort after'),
        )


def step_6_add_binary_features(build_config, yt_client):
    """Combine compacted features with the address table into binary features.

    Within one transaction, this step reduces ``features_table`` together with
    ``addresses_table`` per ``external_id``, using ``BinaryFeaturesReducer``
    built from ``addrs_types``, ``max_distances_in_category`` and
    ``features_fillna``. The output ``features_with_binary`` is then sorted by
    ``external_id``.

    Args:
        build_config: build configuration providing table paths, the reducer
            settings above, ``tag`` and ``cloud_nodes_spec``.
        yt_client: YT client used to run the operations.
    """
    def op_spec(title_template):
        # Operation spec: the tagged title plus the shared node settings.
        return dict(
            title=title_template.format(build_config.tag),
            **build_config.cloud_nodes_spec
        )

    with yt_client.Transaction():
        binary_reducer = BinaryFeaturesReducer(
            build_config.addrs_types,
            build_config.max_distances_in_category,
            build_config.features_fillna,
        )
        # NOTE(review): addresses_table is not sorted in this step; presumably
        # it is already sorted by external_id elsewhere — confirm.
        sources = [build_config.features_table, build_config.addresses_table]
        yt_client.run_reduce(
            binary_reducer,
            sources,
            build_config.features_with_binary,
            reduce_by=['external_id'],
            spec=op_spec('[{}] Add binary features'),
        )
        yt_client.run_sort(
            build_config.features_with_binary,
            sort_by=['external_id'],
            spec=op_spec('[{}] Add binary features / sort after'),
        )


def build_vectors(build_config, yt_client=None, build_steps=GEO_FEATURES_BUILD_STEPS):
    """Run the selected geo feature build steps.

    Ensures ``build_config.data_dir`` exists, then runs each step whose enum
    member is contained in ``build_steps``. Steps always execute in pipeline
    order (0 through 6), regardless of their order in ``build_steps``.

    Args:
        build_config: build configuration passed to every step.
        yt_client: YT client; ``get_yt_client()`` is used when not given.
        build_steps: collection of ``GeoFeaturesBuildSteps`` members to run;
            defaults to all of them.
    """
    yt_client = yt_client or get_yt_client()
    create_folders((build_config.data_dir,), yt_client)

    # (step, start-of-step log message, implementation) in pipeline order.
    pipeline = (
        (GeoFeaturesBuildSteps.step_0_grep_logs,
         ' Started step 0 / grep logs', step_0_grep_logs),
        (GeoFeaturesBuildSteps.step_1_filter_logs,
         ' Started step 1 / filter logs', step_1_filter_logs),
        (GeoFeaturesBuildSteps.step_2_calculate_distances,
         ' Started step 2 / calculate distances', step_2_calculate_distances),
        (GeoFeaturesBuildSteps.step_3_filter_distances,
         ' Started step 3 / filter distances', step_3_filter_distances),
        (GeoFeaturesBuildSteps.step_4_calc_features,
         ' Started step 4 / calc features', step_4_calc_features),
        (GeoFeaturesBuildSteps.step_5_compact_features,
         ' Started step 5 / compact features', step_5_compact_features),
        (GeoFeaturesBuildSteps.step_6_add_binary_features,
         ' Started step 6 / add binary features', step_6_add_binary_features),
    )

    for step, start_message, run_step in pipeline:
        if step not in build_steps:
            continue
        logger.info(start_message)
        run_step(build_config, yt_client)

    logger.info(' Geo vectors built!')
