import os
from collections import namedtuple
import pandas as pd

from yt.wrapper import ypath_join
from datacloud.dev_utils.yql.yql_helpers import execute_yql

from datacloud.ml_utils.dolphin.prepare_cse.helpers import (
    rm_complete_dups_q,
    clean_cse_q,
    get_eids_q,
    rows_sample,
    filter_eids_q,
    join_cids_q,
    filter_by_eids_q,
    list_suffixes,
    MarkCidsMapper,
    mark_eids_reducer,
    compact_reducer,
    make_file_name
)
from datacloud.ml_utils.dolphin.prepare_cse.path_config import PathConfig

# Immutable configuration of the prepare-CSE pipeline. Fields, in order:
#   experiment_name      -- local directory where step 5 writes its TSV files
#                           (the whole config is also handed to PathConfig)
#   path_to_original_cse -- source CSE table copied in step 0
#   aggs_folder          -- YT folder the step-5 TSV files are uploaded to
#   crypta_folder        -- YT folder holding the *_to_cid / cid_to_all tables
#   zeros_vs_ones        -- multiplier: sampled target=0 eids per target=1 eid
#   min_retro_date       -- passed to the cleaning query in step 1
#   no_go_partners       -- partner names passed to the cleaning query in step 1
#   n_folds              -- number of folds given to MarkCidsMapper in step 3
#   val_size             -- divisor used to derive the mapper's val_sample_rate
#   steps                -- step numbers (keys of step_num2method) to run
CleanConfig = namedtuple(
    'CleanConfig',
    'experiment_name path_to_original_cse aggs_folder crypta_folder '
    'zeros_vs_ones min_retro_date no_go_partners n_folds val_size steps'
)


def step0_copy_inputs(yt_client, yql_client, path_config, cconfig, logger):
    """Step 0: copy the pipeline inputs into the experiment's working paths.

    Each crypta table ``<crypta_folder>/<name>`` is copied to the
    ``path_config`` attribute of the same name. The original CSE table is
    then copied to ``path_config.path_to_cse``. Every copy passes
    ``ignore_existing=True``. Presumably this makes a rerun tolerate
    destinations that already exist; confirm against the YT client docs.
    ``yql_client`` and ``logger`` are unused; they are accepted to match the
    common step signature.
    """
    # The source table name doubles as the PathConfig attribute name.
    for table_name in ('id_value_to_cid', 'phone_id_value_to_cid',
                       'email_id_value_to_cid', 'cid_to_all'):
        yt_client.copy(
            ypath_join(cconfig.crypta_folder, table_name),
            getattr(path_config, table_name),
            ignore_existing=True
        )

    yt_client.copy(
        cconfig.path_to_original_cse,
        path_config.path_to_cse,
        ignore_existing=True
    )


def _yql_string_tuple(values):
    """Render ``values`` as a parenthesised YQL tuple of double-quoted strings.

    Backslashes and double quotes inside a value are backslash-escaped, so a
    value cannot terminate its string literal early and break the query.
    An empty ``values`` renders as ``()``.
    """
    quoted = []
    for value in values:
        text = '{}'.format(value)
        text = text.replace('\\', '\\\\').replace('"', '\\"')
        quoted.append('"{}"'.format(text))
    return '({})'.format(', '.join(quoted))


def step1_clean_cse_table(yt_client, yql_client, path_config, cconfig, logger):
    """Step 1: build ``clean_cse_path`` from the copied CSE table.

    Two YQL queries run in sequence:
      1. ``rm_complete_dups_q`` reads ``path_config.path_to_cse`` and writes
         ``path_config.clean_cse_path``. Judging by the name, it removes
         fully duplicated rows; the query text is not visible here.
      2. ``clean_cse_q`` runs in place on ``clean_cse_path``. It receives
         ``cconfig.min_retro_date`` and ``cconfig.no_go_partners``, the
         latter rendered as a YQL tuple literal.

    A warning is logged when ``no_go_partners`` is empty. ``yt_client`` is
    unused; it is accepted to match the common step signature.
    """
    if not cconfig.no_go_partners:
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning('Attention! No go partners list is empty!')
    logger.info('Bad partners are\n{}'.format('\n'.join(cconfig.no_go_partners)))

    execute_yql(query=rm_complete_dups_q, yql_client=yql_client, params=dict(
        in_path=path_config.path_to_cse,
        out_path=path_config.clean_cse_path
    ), set_owners=False, syntax_version=1)

    execute_yql(query=clean_cse_q, yql_client=yql_client, params=dict(
        in_path=path_config.clean_cse_path,
        out_path=path_config.clean_cse_path,
        min_retro_date=cconfig.min_retro_date,
        # Escaped so a partner name containing a quote cannot break the query.
        no_go_partners=_yql_string_tuple(cconfig.no_go_partners)
    ), set_owners=False, syntax_version=1)


def step2_sample_eids(yt_client, yql_client, path_config, cconfig, logger):
    """Step 2: select a target-balanced set of external ids and filter the CSE.

    Flow:
      * ``get_eids_q`` with target=1 writes ``ones_table``, which is kept whole.
      * ``get_eids_q`` with target=0 writes ``zeros_table``. That table is
        then overwritten by ``rows_sample`` of itself, sized
        ``round(zeros_vs_ones * rowcount(ones_table))``.
      * Both tables are merged into ``eids_table``.
      * ``filter_eids_q`` keeps the rows of ``clean_cse_path`` that match
        ``eids_table`` and writes them to ``cse_interesting_eids``.

    ``logger`` is unused; it is accepted to match the common step signature.
    """
    def extract_eids(destination, target_value):
        # Pull the external ids with the given target out of the clean CSE.
        execute_yql(query=get_eids_q, yql_client=yql_client, params=dict(
            in_path=path_config.clean_cse_path,
            out_path=destination,
            target=target_value
        ), set_owners=False, syntax_version=1)

    extract_eids(path_config.ones_table, 1)
    positives_count = yt_client.row_count(path_config.ones_table)

    extract_eids(path_config.zeros_table, 0)

    zeros_wanted = int(round(cconfig.zeros_vs_ones * positives_count))
    # NOTE(review): the sample is written back over the table it was drawn
    # from; this assumes rows_sample materialises its rows rather than
    # streaming them lazily from zeros_table. Confirm.
    sampled_zeros = rows_sample(
        yt_client=yt_client,
        table_path=path_config.zeros_table,
        sample_size=zeros_wanted
    )
    yt_client.write_table(path_config.zeros_table, sampled_zeros)

    yt_client.run_merge(
        [path_config.zeros_table, path_config.ones_table],
        path_config.eids_table
    )

    filter_params = dict(
        in_path=path_config.clean_cse_path,
        eids_table=path_config.eids_table,
        out_path=path_config.cse_interesting_eids
    )
    execute_yql(query=filter_eids_q, yql_client=yql_client, params=filter_params,
                set_owners=False, syntax_version=1)


def step3_mark_eids(yt_client, yql_client, path_config, cconfig, logger):
    """Step 3: attach cids to the selected eids and compute a mark per external_id.

    ``join_cids_q`` joins ``cse_interesting_eids`` with ``id_value_to_cid``
    and writes ``eid2cid``. A map-reduce then runs ``MarkCidsMapper`` and
    ``mark_eids_reducer``, reducing by ``external_id``, into ``eid2mark``.
    The result is finally sorted by ``external_id``. ``logger`` is unused;
    it is accepted to match the common step signature.
    """
    join_params = dict(
        cse_interesting_eids=path_config.cse_interesting_eids,
        out_path=path_config.eid2cid,
        id_value_to_cid=path_config.id_value_to_cid
    )
    execute_yql(query=join_cids_q, yql_client=yql_client, params=join_params,
                set_owners=False, syntax_version=1)

    joined_rows = yt_client.row_count(path_config.eid2cid)
    # NOTE(review): this is rows / val_size. That is the inverse of a usual
    # "sample rate" (val_size / rows). Confirm how MarkCidsMapper interprets
    # val_sample_rate before changing it. val_size == 0 raises ZeroDivisionError.
    sample_rate = float(joined_rows) / cconfig.val_size
    marker = MarkCidsMapper(nfolds=cconfig.n_folds, val_sample_rate=sample_rate)

    yt_client.run_map_reduce(
        marker,
        mark_eids_reducer,
        path_config.eid2cid,
        path_config.eid2mark,
        reduce_by=['external_id'],
        spec={'title': 'Mark external ids'}
    )
    yt_client.run_sort(
        path_config.eid2mark,
        sort_by=['external_id'],
        spec={'title': 'Mark external ids / sort after'}
    )


def step4_split_cse(yt_client, yql_client, path_config, cconfig, logger):
    """Step 4: split the filtered CSE into one table per mark/suffix.

    For each suffix from ``list_suffixes``, ``filter_by_eids_q`` selects the
    rows of ``cse_interesting_eids`` whose external id carries that mark in
    ``eid2mark`` and writes them to ``make_cse_table(suffix)``. That table is
    then sorted by (external_id, retro_date, partner, target). ``logger`` is
    unused; it is accepted to match the common step signature.
    """
    split_sort_key = ['external_id', 'retro_date', 'partner', 'target']

    for mark in list_suffixes(cconfig=cconfig):
        split_table = path_config.make_cse_table(mark)
        query_params = dict(
            cse=path_config.cse_interesting_eids,
            marked_eids=path_config.eid2mark,
            mark=mark,
            out_path=split_table
        )
        execute_yql(query=filter_by_eids_q, yql_client=yql_client, params=query_params,
                    set_owners=False, syntax_version=1, yt_client=yt_client)
        yt_client.run_sort(split_table, sort_by=split_sort_key)


def step5_compact_cse(yt_client, yql_client, path_config, cconfig, logger):
    """Step 5: compact every split CSE table, export it as TSV and upload it.

    For each suffix from ``list_suffixes``:
      * reduce ``make_cse_table(suffix)`` with ``compact_reducer``, keyed by
        (external_id, retro_date, partner, target), into
        ``make_cse_compacted_table(suffix)``;
      * sort that table by (external_id, retro_date);
      * read it into a DataFrame, replace NaN with '' and write it as a
        tab-separated file ``<experiment_name>/<make_file_name(suffix)>``
        locally;
      * upload the local file to ``<aggs_folder>/<file name>`` on YT.

    The local ``experiment_name`` directory is created if it is missing.
    ``yql_client`` and ``logger`` are unused; they are accepted to match the
    common step signature.
    """
    configs = [(
        path_config.make_cse_table(suffix),
        path_config.make_cse_compacted_table(suffix),
        make_file_name(suffix)
    ) for suffix in list_suffixes(cconfig=cconfig)]

    local_dir = cconfig.experiment_name
    # DataFrame.to_csv does not create missing directories.
    if local_dir and not os.path.isdir(local_dir):
        os.makedirs(local_dir)

    for cse, cse_compacted, file_name in configs:
        yt_client.run_reduce(
            compact_reducer,
            cse,
            cse_compacted,
            reduce_by=['external_id', 'retro_date', 'partner', 'target']
        )
        yt_client.run_sort(cse_compacted, sort_by=['external_id', 'retro_date'])
        df = pd.DataFrame(yt_client.read_table(cse_compacted))
        df = df.fillna('')
        file_path = os.path.join(local_dir, file_name)
        df.to_csv(file_path, sep='\t', index=False)
        # Binary mode: the upload should receive bytes; a text-mode handle
        # yields str on Python 3.
        with open(file_path, 'rb') as fd:
            yt_client.write_file(ypath_join(cconfig.aggs_folder, file_name), fd)


# Pipeline steps in execution order. A function's position in this list is
# its step number, which is what ``CleanConfig.steps`` refers to.
step_num2method = dict(enumerate([
    step0_copy_inputs,
    step1_clean_cse_table,
    step2_sample_eids,
    step3_mark_eids,
    step4_split_cse,
    step5_compact_cse,
]))


def main(yt_client, yql_client, cconfig, logger):
    """Run the selected steps of the prepare-CSE pipeline, in the given order.

    Args:
        yt_client: YT client passed to every step.
        yql_client: YQL client passed to every step.
        cconfig: ``CleanConfig``; ``cconfig.steps`` lists the step numbers
            (keys of ``step_num2method``) to execute.
        logger: logger used for progress messages and passed to every step.

    Raises:
        ValueError: if ``cconfig.steps`` contains a number that is not a key
            of ``step_num2method``. This is checked before any step runs, so
            a bad config cannot leave the pipeline half-executed.
    """
    # Explicit check rather than ``assert``: asserts are stripped under -O.
    unknown_steps = [step for step in cconfig.steps if step not in step_num2method]
    if unknown_steps:
        raise ValueError('Unknown pipeline steps: {}'.format(unknown_steps))

    logger.info('Running prepare CSE pipeline')
    path_config = PathConfig(cconfig)

    for step in cconfig.steps:
        logger.info('Start of step {}'.format(step))
        step_num2method[step](yt_client, yql_client, path_config, cconfig, logger)
        logger.info('Step {} done'.format(step))
