from functools import partial

import luigi
import yt.wrapper as yt

from lib import graphite_sender
from lib.luigi import yt_luigi
from rtcconf import config
from utils import mr_utils as mr
from utils import utils

# Source suffixes of yuid_with_id_<id_type>_<suffix> tables that
# get_tables_names() always skips, regardless of their row count.
SOURCES_TO_IGNORE = ['auto', 'md5']
# Bucket count for a precision histogram.  No live code visible in this
# module reads it; the histogram logic it was sized for is disabled.
NUMBER_OF_INTERVALS = 5.0

    
def get_tables_names(id_type, val_tables, passport_tables, dir_name):
    """
    Collect source suffixes of the usable yuid_with_id_<id_type>_* tables.

    Lists ``dir_name`` and, for every table whose name starts with
    ``yuid_with_id_<id_type>_``, appends the remaining part of the name (the
    source suffix) in place to ``passport_tables`` when the suffix starts
    with ``passport`` and to ``val_tables`` otherwise.  Sources listed in
    SOURCES_TO_IGNORE and tables with zero rows are skipped.

    :param id_type: id kind embedded in the table name (callers in this
        module pass 'phone' or 'email')
    :param val_tables: list, extended in place with non-passport suffixes
    :param passport_tables: list, extended in place with passport suffixes
    :param dir_name: folder path; its last character is dropped for the
        listing and table names are appended to it directly, so it is
        expected to end with a path separator
    :return: None, results are delivered through the two list arguments
    """
    prefix = 'yuid_with_id_' + id_type + '_'
    for file_name in yt.list(dir_name[:-1]):
        if not file_name.startswith(prefix):
            continue
        # Cut off only the leading prefix: str.replace() would also delete
        # any later occurrence of the same text inside the source name.
        suffix = file_name[len(prefix):]
        if suffix in SOURCES_TO_IGNORE:  # explicitly ignored sources
            continue
        if yt.row_count(dir_name + file_name) == 0:  # ignore empty tables
            continue
        if suffix.startswith('passport'):
            passport_tables.append(suffix)
        else:
            val_tables.append(suffix)

def map_tables_to_short_table(rec):
    """
    Mapper: expand one (yuid, id value) row into edge and vertex records.

    The yuid is read from the 'key' field and the id value from 'id_value'.
    Rows where either of them is missing or falsy produce no output.
    Otherwise the edge is emitted in both directions with to_table == 0,
    followed by both endpoints as bare vertices with to_table == 1.
    """
    yuid, id_value = rec.get('key'), rec.get('id_value')
    if not (id_value and yuid):
        return

    # Edges: every endpoint becomes a key pointing at its neighbour.
    for endpoint, neighbour in ((yuid, id_value), (id_value, yuid)):
        yield {'key': endpoint, 'value': neighbour, 'to_table': 0}

    # Vertices: every endpoint on its own.
    for vertex in (yuid, id_value):
        yield {'key': vertex, 'to_table': 1}

def reduce_tables_to_short_table(keys, recs, source=None):
    """
    Reducer: route mapper output into the edges table and the vertices table.

    Groups with to_table == 1 yield a single vertex row for output table 1.
    Other groups collect every 'value' of the key into one row for output
    table 0, tagged with ``source``.  A key with more than 1001 values
    yields nothing at all.
    """
    if keys['to_table'] == 1:
        yield {'key':keys['key'], '@table_index':1}
        return

    neighbours = []
    for rec in recs:
        if len(neighbours) > 1000:
            # Oversized group: drop the key entirely.
            break
        neighbours.append(rec['value'])
    else:
        # Reached only when the loop was not cut short by the size limit.
        yield {
            'key': keys['key'],
            'values': neighbours,
            'source': source,
            '@table_index': 0,
        }
    
def start_mr_operations(self, id_type, val_tables, passport_tables):
    """
    Launch the map-reduce operations that extract edges and vertices.

    One operation is started per source in ``val_tables``, plus a single
    combined 'passport' operation reading every table in ``passport_tables``
    (when there are any).  Each operation writes a ``pairs_<source>`` and a
    ``values_<source>`` table, which are created with a schema beforehand.
    Operations are started with sync=False and handed back to the caller.

    :param self: task providing in_f() / out_f() folder lookups
    :param id_type: 'phone' or 'email' in this module's callers
    :param val_tables: source suffixes processed one operation each
    :param passport_tables: passport source suffixes processed together
    :return: list of started operations
    """
    id_type_folder = id_type + 's'

    input_prefix = self.in_f('dict') + 'yuid_with_id_' + id_type + '_'
    pairs_prefix = self.out_f(id_type_folder) + 'pairs_'
    values_prefix = self.out_f(id_type_folder) + 'values_'

    def launch(source, inputs):
        # Start one asynchronous map-reduce writing both tables of a source.
        return yt.run_map_reduce(
            map_tables_to_short_table,
            partial(reduce_tables_to_short_table, source=source),
            inputs,
            [pairs_prefix + source, values_prefix + source],
            reduce_by=['key', 'to_table'],
            sync=False
        )

    mr.mkdir(self.out_f(id_type_folder))

    # Output tables must exist for every source, 'passport' included.
    output_sources = list(val_tables)
    if passport_tables:
        output_sources.append('passport')
    for source in output_sources:
        mr.create_table_with_schema(
            pairs_prefix + source,
            {'key':'string', 'source':'string', 'values':'any'},
            True,
        )
        mr.create_table_with_schema(
            values_prefix + source,
            {'key':'string'},
            True,
        )

    operations = []
    if passport_tables:
        passport_inputs = [input_prefix + suffix for suffix in passport_tables]
        operations.append(launch('passport', passport_inputs))
    for source in val_tables:
        operations.append(launch(source, input_prefix + source))

    mr.mkdir(self.out_f(id_type + 's_out'))

    return operations


class YuidPairsPreprocBySourceStatsTask(yt_luigi.BaseYtTask):
    """
    Prepare per-source pairs_* / values_* tables for phones and emails.

    Reads the yuid_with_id_* dictionaries, runs the extraction map-reduces
    for both id types inside one transaction and sorts every produced table
    by 'key'.
    """

    date = luigi.Parameter()

    def output_folders(self):
        day_folder = config.GRAPH_FOLDER + self.date
        phones = day_folder + '/precision_of_phones_by_sources/'
        emails = day_folder + '/precision_of_emails_by_sources/'
        return {
            'phones': phones,
            'emails': emails,
            'phones_out': phones + 'results/',
            'emails_out': emails + 'results/',
        }

    def input_folders(self):
        return {'dict': config.GRAPH_YT_DICTS_FOLDER}

    def requires(self):
        # Imported here rather than at module level (presumably to avoid an
        # import cycle -- confirm before moving it).
        from matching.yuid_matching import graph_dict
        return [graph_dict.YuidAllIdDictsTask(self.date)]

    def output(self):
        # Folder targets take the path without its trailing character.
        return [
            yt_luigi.YtFolderTarget(self.out_f(folder)[:-1])
            for folder in ('phones', 'emails')
        ]

    def run(self):
        # id type -> (plain source suffixes, passport source suffixes)
        sources = {}
        for id_type in ('phone', 'email'):
            plain, passport = [], []
            get_tables_names(id_type, plain, passport, self.in_f('dict'))
            sources[id_type] = (plain, passport)

        with yt.Transaction():
            operations = []
            for id_type in ('phone', 'email'):
                plain, passport = sources[id_type]
                operations += start_mr_operations(self, id_type, plain, passport)
            utils.wait_all(operations)

            # Passport tables are listed unconditionally; the ones that were
            # never created are filtered out by the existence check below.
            tables_to_sort = [
                self.out_f(folder) + kind + 'passport'
                for folder in ('phones', 'emails')
                for kind in ('pairs_', 'values_')
            ]
            for folder, id_type in (('phones', 'phone'), ('emails', 'email')):
                for suffix in sources[id_type][0]:
                    tables_to_sort.append(self.out_f(folder) + 'pairs_' + suffix)
                    tables_to_sort.append(self.out_f(folder) + 'values_' + suffix)

            mr.sort_all(
                filter(yt.exists, tables_to_sort),
                sort_by='key'
            )


def yuid_intersection(keys, recs, source=None):
    """
    Reducer: compare what ``source`` says about a key with everyone else.

    Each input row carries a 'source' and a list of 'values' linked to the
    key.  Values from rows of ``source`` form the checked set, values from
    all other rows form the reference set.  If either set is empty nothing
    is emitted.  Otherwise:

    * output table 0 gets the number of values present in both sets;
    * output table 1 gets one row per value that only ``source`` reported,
      keyed by that value and remembering the original key in 'yuid_key'.
    """
    key = keys['key']
    checked, reference = set(), set()
    for rec in recs:
        bucket = checked if rec['source'] == source else reference
        bucket.update(rec['values'])

    if not (checked and reference):
        return

    yield {'key': key, 'true_ids': len(checked & reference), '@table_index': 0}
    for unconfirmed in checked - reference:
        yield {'key': unconfirmed, 'yuid_key': key, '@table_index': 1}

def idtype_intersection(keys, recs):
    """
    Reducer: classify unconfirmed values as errors or anomalies.

    Rows with a truthy 'yuid_key' are values awaiting classification; any
    other row in the group marks the value as known to the reference data.
    For every awaiting row one record keyed by its 'yuid_key' is emitted:
    error_ids == 1 when the value is known to the reference data,
    anomaly_ids == 1 when it is not.  Groups without awaiting rows emit
    nothing.
    """
    awaiting = []
    known_to_reference = False
    for rec in recs:
        origin = rec.get('yuid_key', False)
        if origin:
            awaiting.append(origin)
        else:
            known_to_reference = True

    error_ids = 1 if known_to_reference else 0
    anomaly_ids = 1 - error_ids
    for origin in awaiting:
        yield {'key': origin, 'anomaly_ids': anomaly_ids, 'error_ids': error_ids}

def statistics_for_yuid(keys, recs):
    """
    Reducer: turn the per-key counters into precision and anomaly ratios.

    Sums 'true_ids', 'error_ids' and 'anomaly_ids' over the group (missing
    fields count as 0).  When true + error is non-zero it emits one row:

    * 'p' = true / (true + error)
    * 'a' = (true + error + anomaly) / (true + error)

    Groups with no true and no error ids emit nothing.
    """
    totals = {'true_ids': 0, 'error_ids': 0, 'anomaly_ids': 0}
    for rec in recs:
        for field in totals:
            totals[field] += rec.get(field, 0)

    checked = totals['true_ids'] + totals['error_ids']
    if not checked:
        return

    yield {
        'p': float(totals['true_ids']) / checked,
        'a': float(checked + totals['anomaly_ids']) / checked,
    }

@yt.aggregator
def count_statistics_main_step(records):
    """
    Aggregating mapper: pre-sum precision and anomaly over a record stream.

    Consumes rows of the form {'p': precision, 'a': anomaly} and emits a
    single row holding the running sums and the row counts, so that the
    final reducer can compute the averages.  The per-bucket precision
    histogram that used to accompany these sums is currently disabled.
    """
    totals = {
        'precision_avg_sum': 0,
        'precision_avg_count': 0,
        'anomaly_avg_sum': 0,
        'anomaly_avg_count': 0,
    }

    for record in records:
        precision, anomaly = record['p'], record['a']

        totals['precision_avg_sum'] += precision
        totals['precision_avg_count'] += 1

        totals['anomaly_avg_sum'] += anomaly
        totals['anomaly_avg_count'] += 1

    yield totals

def count_statistics_final_step(_, records):
    """
    Final reducer: merge partial sums into average precision and anomaly.

    Every input row carries 'precision_avg_sum', 'precision_avg_count',
    'anomaly_avg_sum' and 'anomaly_avg_count'.  Exactly one row with
    'precision_avg' and 'anomaly_avg' is emitted; an average whose count is
    zero is reported as 0.  The reduce key (first argument) is ignored.
    """
    precision_sum = precision_count = 0
    anomaly_sum = anomaly_count = 0

    for chunk in records:
        precision_sum += chunk['precision_avg_sum']
        precision_count += chunk['precision_avg_count']

        anomaly_sum += chunk['anomaly_avg_sum']
        anomaly_count += chunk['anomaly_avg_count']

    if precision_count:
        precision_avg = precision_sum / precision_count
    else:
        precision_avg = 0

    if anomaly_count:
        anomaly_avg = anomaly_sum / anomaly_count
    else:
        anomaly_avg = 0

    yield {'precision_avg': precision_avg, 'anomaly_avg': anomaly_avg}


class PrecisionSourceStatistics(yt_luigi.BaseYtTask):
    """
    Measure how well one source agrees with a sample of other sources.

    ``test_sample`` is a '-'-joined list of source names.  The task reads
    the pairs_* / values_* tables of the given id type and writes a single
    table named ``<source>_against_<test_sample>`` with the average
    precision and anomaly values.
    """

    date = luigi.Parameter()
    source = luigi.Parameter()
    id_type = luigi.Parameter()
    test_sample = luigi.Parameter()

    def _source_folder(self):
        # Folder holding the pairs_* / values_* tables of this id type.
        return ''.join([
            config.GRAPH_FOLDER,
            self.date,
            '/precision_of_',
            self.id_type,
            's_by_sources/',
        ])

    def input_folders(self):
        return {'in': self._source_folder()}

    def output_folders(self):
        return {'out': self._source_folder() + 'results/'}

    def output_table(self):
        return self.out_f('out') + self.source + '_against_' + self.test_sample

    def output(self):
        return [yt_luigi.YtTarget(self.output_table(), allow_empty=True)]

    def run(self):
        in_folder = self.in_f('in')
        sample_sources = self.test_sample.split('-')

        target_pairs = in_folder + 'pairs_' + self.source
        sample_pairs = [in_folder + 'pairs_' + src for src in sample_sources]
        sample_values = [in_folder + 'values_' + src for src in sample_sources]
        result_table = self.output_table()

        with yt.Transaction(), \
                yt.TempTable() as confirmed_table, \
                yt.TempTable() as unconfirmed_table, \
                yt.TempTable() as classified_table, \
                yt.TempTable() as per_key_table:

            # Step 1: per key, count values confirmed by the sample and
            # set aside the values that only the checked source reported.
            yt.run_reduce(
                partial(yuid_intersection, source=self.source),
                [target_pairs] + sample_pairs,
                [confirmed_table, unconfirmed_table],
                reduce_by=['key'],
            )

            mr.sort_all([confirmed_table, unconfirmed_table], sort_by="key")

            # Step 2: split the unconfirmed values into errors / anomalies.
            yt.run_reduce(
                idtype_intersection,
                [unconfirmed_table] + sample_values,
                classified_table,
                reduce_by=['key'],
            )

            yt.run_sort(classified_table, sort_by=['key'])

            # Step 3: per-key precision and anomaly ratios.
            yt.run_reduce(
                statistics_for_yuid,
                [confirmed_table, classified_table],
                per_key_table,
                reduce_by=['key'],
            )

            # Step 4: average the ratios over all keys.
            yt.run_map_reduce(
                count_statistics_main_step,
                count_statistics_final_step,
                per_key_table,
                result_table,
                reduce_by='key',  # fake key
            )


def build_additional_requires(self, id_type):
    """
    Build the PrecisionSourceStatistics tasks for every source of an id type.

    Sources are discovered from the ``pairs_<source>`` tables in the
    ``<id_type>s`` input folder.  For each source up to two comparisons are
    requested:

    * against all other sources joined with '-', when there are any;
    * against 'passport' alone, when the source is not passport itself and
      the ``values_passport`` table exists.

    :param self: task providing ``in_f()`` and ``date``
    :param id_type: 'phone' or 'email' in this module's callers
    :return: list of PrecisionSourceStatistics task instances
    """
    folder = self.in_f(id_type + 's')
    sources = sorted(set(
        name.split('_', 1)[-1]
        for name in yt.list(folder[:-1])
        if name.startswith('pairs_')
    ))

    # The answer does not depend on the source, so ask the cluster once
    # instead of once per loop iteration.
    has_passport_values = yt.exists(folder + 'values_passport')

    requires = []
    for src in sources:
        test_samples = []

        # one against all another
        all_others = '-'.join(s for s in sources if s != src)
        if all_others:
            test_samples.append(all_others)

        # one against passport; skipped when "all others" is already exactly
        # 'passport', which would request the identical task a second time
        if (src != 'passport' and has_passport_values
                and 'passport' not in test_samples):
            test_samples.append('passport')

        for test_sample in test_samples:
            requires.append(
                PrecisionSourceStatistics(
                    date=self.date,
                    source=src,
                    id_type=id_type,
                    test_sample=test_sample
                )
            )

    return requires

def collect_statistics(self, id_type):
    """
    Read the per-source result tables and build graphite metrics from them.

    Every table named ``<source>_against_<sample>`` in the results folder of
    ``id_type`` contributes two metrics taken from its first row:

    * ``<id_type>.precision.<tag>_<source>``
    * ``<id_type>.anomaly.<tag>_<source>``

    where ``<tag>`` is 'passport_true_ag' when the sample is exactly
    'passport' and 'full_ag' otherwise.

    :param self: task providing ``in_f()``
    :param id_type: 'phone' or 'email' in this module's callers
    :return: dict mapping metric name to value
    """
    result = {}

    # 'reuslts' (sic): the spelling must match the keys that the calling
    # task declares in input_folders().
    folder_path = self.in_f(id_type + 's_reuslts')

    for filename in yt.list(folder_path[:-1]):
        if '_against_' not in filename:
            continue
        src, tsample = filename.split('_against_')[:2]

        # Only the first row is needed; the result table holds one summary.
        record = None
        for record in yt.read_table(folder_path + filename, raw=False):
            break

        if record is None:
            # Result tables are declared with allow_empty=True, so an empty
            # one is a legitimate outcome.  Skip just this comparison
            # instead of throwing away the metrics collected so far.
            continue

        sample_tag = 'passport_true_ag' if tsample == 'passport' else 'full_ag'
        result['%s.precision.%s_%s' % (id_type, sample_tag, src)] = record['precision_avg']
        result['%s.anomaly.%s_%s' % (id_type, sample_tag, src)] = record['anomaly_avg']

    return result


class PrecisionYuidsStatisticsMainTask(yt_luigi.BaseYtTask):
    """
    Entry-point task: compute source precision stats and send them out.

    After the preprocessing task has finished it schedules one
    PrecisionSourceStatistics task per comparison as dynamic dependencies,
    pushes the collected metrics to graphite, marks its local output as
    done for the date and removes the intermediate folders.
    """

    date = luigi.Parameter()

    def input_folders(self):
        day_folder = config.GRAPH_FOLDER + self.date
        phones = day_folder + '/precision_of_phones_by_sources/'
        emails = day_folder + '/precision_of_emails_by_sources/'
        # The '*_reuslts' spelling is looked up verbatim by
        # collect_statistics() and by run() below.
        return {
            'dict': config.GRAPH_YT_DICTS_FOLDER,
            'phones': phones,
            'emails': emails,
            'phones_reuslts': phones + 'results/',
            'emails_reuslts': emails + 'results/',
        }

    def requires(self):
        """
        Static requirement: the per-source preprocessing for the same date.
        """
        return YuidPairsPreprocBySourceStatsTask(self.date)

    @staticmethod
    def output_table():
        """
        Path of the local marker file for this task.
        """
        return config.LOCAL_OUTPUT_FOLDER + 'graph_precision_yuid_ph-em_source_stat'

    def output(self):
        return [yt_luigi.TodayFileTarget(self.output_table(), self.date)]

    def run(self):
        # Dynamic dependencies: yielding them suspends run() until all of
        # the comparison tasks are complete.
        comparison_tasks = []
        for id_type in ('phone', 'email'):
            comparison_tasks.extend(build_additional_requires(self, id_type))
        yield comparison_tasks

        metrics = {}
        for id_type in ('phone', 'email'):
            metrics.update(collect_statistics(self, id_type))
        graphite_sender.to_graphite_sender_batch('graph_precision_all', metrics.items(), self.date)

        yt_luigi.TodayFileTarget.done(self.output_table(), self.date)

        # Clean up: result folders first, then their parent folders.
        for folder_key in ('emails_reuslts', 'phones_reuslts', 'emails', 'phones'):
            mr.drop(self.in_f(folder_key)[:-1])



if __name__ == "__main__":
    import os

    # The reducers in this module tag rows with '@table_index', so the YSON
    # format is configured with table-index processing switched on.
    yt.config["tabular_data_format"] = yt.YsonFormat(process_table_index=True)
    yt.config.set_proxy(os.environ.get('RTCRYPTA_MR_SERVER'))

    @yt_luigi.BaseYtTask.event_handler(luigi.Event.START)
    def reset_yt_state_on_start(_):
        # Runs at the start of every task derived from BaseYtTask.
        yt_luigi.reset_global_yt_state()

    # Manual run for a fixed date.
    luigi.build(
        [PrecisionYuidsStatisticsMainTask('2017-07-31')],
        workers=10,
        scheduler_port=int(config.LUIGID_PORT)
    )
