import logging
from functools import partial

import luigi
import yt.wrapper as yt

import radius_metrics_calc as calc
import radius_metrics_mr as radius_mr
from data_imports.import_logs import app_metrica_day
from data_imports.import_logs.watch_log import graph_watch_log
from infra.radius.radius_splunk import FilterByRadius

from lib import graphite_sender
from lib.luigi import base_luigi_task
from lib.luigi import yt_luigi
from matching.yuid_matching import graph_dict
from rtcconf import config
from utils import mr_utils as mr
from utils import utils

# Named logger used by the functions of this module (some progress output still goes through print).
logger = logging.getLogger('radius_metrics')


def multidevice_percent(rlogin_graph, crypta_graph):
    """Return the share of multi-device groups in each of the two graphs.

    Args:
        rlogin_graph: mapping radius login -> collection of devices.
        crypta_graph: mapping crypta_id -> collection of devices.

    Returns:
        (rlogin_share, crypta_share): fraction of groups holding more than one
        device. An empty graph yields 0.0 instead of raising ZeroDivisionError.
    """
    def share_of_multidevice(graph):
        if not graph:
            return 0.0
        multidevice = sum(1 for devices in graph.values() if len(devices) > 1)
        return multidevice / float(len(graph))

    return share_of_multidevice(rlogin_graph), share_of_multidevice(crypta_graph)


def dump_metrics(logins_to_metrics, table):
    """Write per-login precision/recall rows to the YT `table` and sort it by `prec_b`.

    Args:
        logins_to_metrics: mapping login -> metrics object exposing prec_a, prec_b,
            rec_a, rec_b, crypta_ids, login_yuids and crypta_yuids
            (as returned by calc.get_metrics_per_login).
        table: YT path of the output table.
    """
    def ratio(numerator, denominator):
        # float() avoids Python 2 integer division when the counts are ints;
        # a zero denominator yields 0.0 instead of raising.
        return float(numerator) / denominator if denominator else 0.0

    def to_recs():
        for login, m in logins_to_metrics.iteritems():
            yield {'login': login,
                   'prec_a': m.prec_a, 'prec_b': m.prec_b,
                   'rec_a': m.rec_a, 'rec_b': m.rec_b,
                   'prec': ratio(m.prec_a, m.prec_b),
                   'rec': ratio(m.rec_a, m.rec_b),
                   'login_crypta_ids': list(m.crypta_ids),
                   'login_yuid': list(m.login_yuids),
                   'crypta_yuids': list(m.crypta_yuids)}

    yt.write_table(table, to_recs(), raw=False)

    # show overmatching first
    yt.run_sort(table, sort_by='prec_b')


def to_graphite(prefix, precision, recall, rlogin_count, cid_count, similarity, coverage, date):
    """Push one slice's metrics to graphite under 'radius_metrics.<prefix>'.

    Per-type precision/recall values become 'prec.<type>' / 'rec.<type>' metrics.
    Float values are formatted with 3 decimals. Falsy values (0, 0.0, None) are
    sent as the plain string '0'.
    """
    def fmt(value):
        return '%.3f' % value if value else '0'

    metrics = [('prec.' + name, fmt(value)) for name, value in precision.iteritems()]
    metrics += [('rec.' + name, fmt(value)) for name, value in recall.iteritems()]
    metrics += [
        ('rlogin_count', str(rlogin_count)),
        ('cid_count', str(cid_count)),
        ('similarity', fmt(similarity)),
        ('coverage', fmt(coverage)),
    ]
    graphite_sender.to_graphite_sender_batch('radius_metrics.' + prefix, metrics, date)


def calculate_metrics_and_persist(prefix, rlogin_graph, crypta_graph, output_yt_folder, date):
    """Compute radius metrics for one slice, print them, send them to graphite and dump per-login rows.

    Args:
        prefix: slice name, e.g. 'all_<vertices_type>'; used in log lines, the graphite
            path and the output table name.
        rlogin_graph: mapping radius login -> devices (ground truth groups).
        crypta_graph: mapping crypta_id -> devices (predicted groups).
        output_yt_folder: YT folder (ending with '/') for the 'metrics_<prefix>' table.
        date: date string forwarded to graphite.

    Raises:
        Exception: if no radius login matched any crypta id.
    """
    def log(msg):
        print('%s | %s' % (prefix, msg))

    output_table = output_yt_folder + 'metrics_%s' % prefix

    if not rlogin_graph or not crypta_graph:
        # Nothing to compare: still create the (empty) table; the task output accepts empty tables
        yt.create_table(output_table)
        return

    rlogin_count = len(rlogin_graph)
    cid_count = len(crypta_graph)
    log('rlogins count: %d' % rlogin_count)
    log('crypta_ids count: %d' % cid_count)

    logins_to_metrics = calc.get_metrics_per_login(rlogin_graph, crypta_graph)
    if not logins_to_metrics:
        raise Exception('Radius logins didn\'t match to any crypta id in %s' % prefix)

    login_metrics = logins_to_metrics.values()
    precisions = [(lm.prec_a, lm.prec_b) for lm in login_metrics]
    recalls = [(lm.rec_a, lm.rec_b) for lm in login_metrics]
    covers = [lm.cover for lm in login_metrics]

    precision, recall, f1_score = calc.average_metrics(precisions, recalls)
    # float() avoids Python 2 integer division truncating coverage to 0 or 1
    coverage = float(sum(covers)) / rlogin_count
    similarity = calc.similarity(rlogin_graph, crypta_graph)

    mult_rlogin, mult_crypta = multidevice_percent(rlogin_graph, crypta_graph)

    log('Multi-device rlogins %: {}'.format(mult_rlogin))
    log('Multi-device crypta_ids %: {}'.format(mult_crypta))
    log('Precision:')
    for prec_type, prec_value in precision.iteritems():
        log("\t{}: {}".format(prec_type, prec_value))
    log('Recall:')
    for rec_type, rec_value in recall.iteritems():
        log("\t{}: {}".format(rec_type, rec_value))
    log('F1-score: {}'.format(f1_score))
    log('Coverage: {}'.format(coverage))
    log('Similarity: {}'.format(similarity))

    to_graphite(prefix, precision, recall, rlogin_count, cid_count, similarity, coverage, date)

    dump_metrics(logins_to_metrics, output_table)


def filter_graph_threshold(graph, max_yuids):
    """Return a new dict keeping only the groups with strictly fewer than `max_yuids` members."""
    kept = {}
    for crypta_id, yuids in graph.items():
        if len(yuids) < max_yuids:
            kept[crypta_id] = yuids
    return kept


def enrich_with_ua_and_filter(yuid_table, out_table, workdir, yuid_all_dict):
    """Join `yuid_table` with user-agent info from `yuid_all_dict` into `out_table`.

    The full dict is not reduced directly (performance). It is first mapped down to
    the yuids present in `yuid_table`. Only that small temporary table is sorted and
    reduced with radius_mr.reduce_add_ua_profile. The caller expects this reducer to
    keep only records that have a user agent.
    """
    relevant_yuids = {rec['yuid'] for rec in yt.read_table(yuid_table, raw=False)}
    ua_tmp = workdir + 'yuid_ua_tmp'

    ua_filter = partial(radius_mr.filter_by_yuids, yuids=relevant_yuids)
    yt.run_map(ua_filter, yuid_all_dict, ua_tmp)
    # compact the (probably small) map output into fewer chunks
    yt.run_merge(ua_tmp, ua_tmp, spec={'combine_chunks': True})

    mr.sort_all([yuid_table, ua_tmp], sort_by='yuid')
    yt.run_reduce(radius_mr.reduce_add_ua_profile, [yuid_table, ua_tmp], out_table, reduce_by='yuid')


def prepare_yuid_radius_data(radius_logs, watch_logs, workdir, yuid_dict):
    """Build yuid -> radius login links from radius and watch logs.

    All tables are written under `workdir` (a YT folder ending with '/'). The main
    result is 'yuid_rlogin_final_stable'; logins dropped by remove_large_rlogins go to
    'yuid_rlogin_final_large', and conflicting links go to 'yuid_rlogin_raw_conflict'.
    """
    sorted_watch = workdir + 'watch_log'
    sorted_radius = workdir + 'radius_log'
    raw_links = workdir + 'yuid_rlogin_raw'
    resolved_links = workdir + 'yuid_rlogin'
    conflict_links = workdir + 'yuid_rlogin_raw_conflict'
    enriched_links = workdir + 'yuid_rlogin_final'
    stable_links = workdir + 'yuid_rlogin_final_stable'
    large_links = workdir + 'yuid_rlogin_final_large'

    # Join radius and watchlog: sort both sides concurrently, then reduce by ip
    pending_sorts = [
        yt.run_sort(watch_logs, sorted_watch, sort_by=['ip', 'timestamp'], sync=False),
        yt.run_sort(radius_logs, sorted_radius, sort_by=['ip', 'timestamp'], sync=False),
    ]
    utils.wait_all(pending_sorts)
    yt.run_reduce(radius_mr.join_logs_by_ip_and_ts,
                  [sorted_watch, sorted_radius],
                  [raw_links],
                  sort_by=['ip', 'timestamp'], reduce_by='ip')

    # Radius login events may be missing sometimes; resolve per-yuid login conflicts
    yt.run_sort(raw_links, sort_by=['yuid', 'timestamp'])
    yt.run_reduce(radius_mr.remove_login_conflicts_by_hits_freq,
                  raw_links,
                  [resolved_links, conflict_links],
                  sort_by=['yuid', 'timestamp'], reduce_by='yuid')

    # Enrich with user-agent info from the internal yuid dict
    enrich_with_ua_and_filter(resolved_links, enriched_links, workdir, yuid_dict)

    # Avoid inner-Yandex botnets: split off logins with too many yuids (max_count=100)
    yt.run_sort(enriched_links, sort_by='login')
    yt.run_reduce(partial(radius_mr.remove_large_rlogins, max_count=100),
                  enriched_links,
                  [stable_links, large_links],
                  reduce_by='login')


def prepare_device_radius_data(radius_logs, device_watch_logs, workdir):
    '''Build distinct (devid, login, ua_profile) links from radius and device watch logs.

    Args:
        radius_logs: YT table(s) of format
                     |ip|login|rec_type|timestamp|;
        device_watch_logs: YT table(s) of format
                     |ip|key(deviceid)|subkey|timestamp|ua_profile|
        workdir: YT folder (ending with '/'); the result is written to
                 workdir + 'device_rlogin_final'.
    '''
    prepared_watch = workdir + 'device_watch_log'
    sorted_radius = workdir + 'radius_log'
    raw_links = workdir + 'device_rlogin_raw'
    final_links = workdir + 'device_rlogin_final'

    # Prepare watchlog
    yt.run_map(radius_mr.prepare_device_watch_log, device_watch_logs, prepared_watch)

    # Join radius and device watchlog; both inputs are sorted by (ip, timestamp) first
    yt.run_sort(prepared_watch, sort_by=['ip', 'timestamp'])
    yt.run_sort(radius_logs, sorted_radius, sort_by=['ip', 'timestamp'])
    yt.run_reduce(radius_mr.join_device_and_rlogin_for_logged_in_yuids,
                  [prepared_watch, sorted_radius],
                  [raw_links],
                  sort_by=['ip', 'timestamp'], reduce_by='ip')

    # Keep distinct (devid, login, ua_profile) rows via radius_mr.distinct
    yt.run_sort(raw_links, sort_by=['devid', 'login', 'ua_profile'])
    yt.run_reduce(partial(radius_mr.distinct, distinct_by=['devid', 'login', 'ua_profile']),
                  raw_links,
                  final_links,
                  reduce_by=['devid', 'login', 'ua_profile'])


def prepare_yuid_crypta_id_data(vertices_table, workdir):
    """Keep only vertices of office yuids, writing them to workdir + 'yuid_cid_office_relevant'.

    Args:
        vertices_table: YT table with the crypta_id graph vertices.
        workdir: YT folder (ending with '/') that contains 'yuid_rlogin_final_stable'.
    """
    # raw=False explicitly: records are indexed as dicts below, as everywhere else in
    # this module, so do not rely on the global default of the raw option.
    office_yuids_dict = {r['yuid']: r
                         for r in yt.read_table(workdir + 'yuid_rlogin_final_stable', raw=False)}

    # crypta_id table processing
    yt.run_map(partial(radius_mr.prepare_vertices, office_yuids_dict=office_yuids_dict),
               vertices_table, workdir + 'yuid_cid_office_relevant')


def prepare_device_crypta_id_data(vertices_table, workdir):
    """Keep only vertices of office device ids, writing them to workdir + 'devid_cid_office_relevant'.

    Args:
        vertices_table: YT table with the crypta_id graph vertices.
        workdir: YT folder (ending with '/') that contains 'device_rlogin_final'.
    """
    # raw=False explicitly: records are indexed as dicts below, as everywhere else in
    # this module, so do not rely on the global default of the raw option.
    office_devids_dict = {r['devid']: r
                          for r in yt.read_table(workdir + 'device_rlogin_final', raw=False)}

    # crypta_id table processing
    yt.run_map(partial(radius_mr.prepare_device_vertices, office_devids_dict=office_devids_dict),
               vertices_table, workdir + 'devid_cid_office_relevant')


def get_available_log_tables(date, ndays, watch_log_table_name):
    """Resolve radius and watch log tables for the `ndays` dates before `date`.

    Only dates listed in BOTH the radius log folder and the output folder are used.
    Every candidate table is additionally checked with yt.exists.

    Returns:
        (radius_logs, watch_logs): lists of YT table paths; both empty when the radius
        folder is missing or the two folders share no dates.
    """
    radius_folder = config.RADIUS_LOG_YT_FOLDER
    output_folder = config.YT_OUTPUT_FOLDER

    # [:-1] drops the last character of the folder path (presumably a trailing '/')
    if not yt.exists(radius_folder[:-1]):
        logger.error('No radius log data is available. Stopping execution')
        return [], []

    # Radius auth server logs: proven IP-login linkage by time.
    # Watch logs: IP-yuid linkage by time, mined from fingerprints and already filtered by radius ips.
    common_dates = sorted(set(mr.list_dates_before(radius_folder, date, ndays)) &
                          set(mr.list_dates_before(output_folder, date, ndays)))
    if not common_dates:
        logger.error('No data is available to calculate metrics. Stopping execution')
        return [], []

    radius_candidates = (radius_folder + dt + '/radius_log' for dt in common_dates)
    radius_logs = [table for table in radius_candidates if yt.exists(table)]
    logger.info("Radius log tables for %d days: %s", ndays, '\n'.join(radius_logs))

    watch_candidates = (output_folder + dt + '/' + watch_log_table_name for dt in common_dates)
    watch_logs = [table for table in watch_candidates if yt.exists(table)]
    logger.info("Watch log %s tables for %d days: %s", watch_log_table_name, ndays, '\n'.join(watch_logs))

    return radius_logs, watch_logs


def calculate_all_metrics(base_workdir, cid_workdir, vertices_type, date, calculate_device_metrics):
    """Compute and persist radius metrics for all / desktop / mobile yuid slices.

    When `calculate_device_metrics` is truthy and device tables are non-empty, it also
    computes 'only_devices' and 'with_devices' (yuid graph merged with device graph).
    Metric tables are written to `base_workdir` with the '_<vertices_type>' postfix.
    The inputs are read from the 'yuid_cid/' and 'devid_cid/' subfolders of `cid_workdir`.
    """
    # Filtered tables size are defined by the number of office IPs and shouldn't be more than 20k
    logger.info('Fetching yuid dicts to memory...')
    yuid_rlogin_recs = list(yt.read_table(cid_workdir + 'yuid_cid/yuid_rlogin_final_stable', raw=False))
    logger.info("Fetched %d yuid_rlogins to memory", len(yuid_rlogin_recs))
    yuid_cid_recs = list(yt.read_table(cid_workdir + 'yuid_cid/yuid_cid_office_relevant', raw=False))
    logger.info("Fetched %d yuid_cids to memory", len(yuid_cid_recs))

    def yuid_splices(*record_filter):
        # real rlogin groups of yuids and predicted crypta_id groups (multimaps);
        # the optional positional filter is forwarded only when given
        return radius_mr.get_rlogin_and_cid_splices(
            yuid_rlogin_recs, yuid_cid_recs, *record_filter,
            add_info=['device_type', 'ua_profile'])

    rlogin_graph, crypta_graph = yuid_splices()
    rlogin_graph_d, crypta_graph_d = yuid_splices(lambda record: record['device_type'] == 'desk')
    rlogin_graph_m, crypta_graph_m = yuid_splices(lambda record: record['device_type'] == 'phone')

    postfix = '_' + vertices_type
    yuid_slices = [
        ("===All===> ", 'all', rlogin_graph, crypta_graph),
        ('===Desktop===> ', 'desktop', rlogin_graph_d, crypta_graph_d),
        ('===Mobile===> ', 'mobile', rlogin_graph_m, crypta_graph_m),
    ]
    for banner, slice_name, slice_rlogins, slice_cids in yuid_slices:
        print(banner + vertices_type)
        calculate_metrics_and_persist(slice_name + postfix, slice_rlogins, slice_cids, base_workdir, date)

    if not calculate_device_metrics:
        return

    # device id graph
    device_rlogin_recs = list(yt.read_table(cid_workdir + 'devid_cid/device_rlogin_final', raw=False))
    logger.info('Fetched %d device_rlogins to memory', len(device_rlogin_recs))
    devid_cid_recs = list(yt.read_table(cid_workdir + 'devid_cid/devid_cid_office_relevant', raw=False))
    logger.info("Fetched %d devid_cid to memory", len(devid_cid_recs))

    if not (device_rlogin_recs and devid_cid_recs):
        return

    device_rlogin_graph, device_crypta_graph = radius_mr.get_rlogin_and_cid_splices(
        device_rlogin_recs, devid_cid_recs, id_key='devid', add_info=['ua_profile'])
    rlogin_full_graph = radius_mr.merge_yuid_and_devid_graphs(rlogin_graph, device_rlogin_graph)
    crypta_full_graph = radius_mr.merge_yuid_and_devid_graphs(crypta_graph, device_crypta_graph)

    print('===Only Devices===> ' + vertices_type)
    calculate_metrics_and_persist('only_devices' + postfix,
                                  radius_mr.devid_to_yuid(device_rlogin_graph),
                                  radius_mr.devid_to_yuid(device_crypta_graph),
                                  base_workdir, date)
    print('===With Devices===> ' + vertices_type)
    calculate_metrics_and_persist('with_devices' + postfix,
                                  radius_mr.devid_to_yuid(rlogin_full_graph),
                                  radius_mr.devid_to_yuid(crypta_full_graph),
                                  base_workdir, date)



def metrics_diff(m1, m2, workdir):
    """Join two per-login metrics tables by login into workdir + 'metrics_diff'.

    Args:
        m1, m2: YT metrics tables in the format written by dump_metrics.
        workdir: YT folder (ending with '/') for the output table.
    """
    # TODO: calculate diffs in regular process for easy investigation
    mr.sort_all([m1, m2], sort_by='login')

    # Column names must match what dump_metrics writes ('login_yuid', not 'login_yuids')
    join_columns = ["crypta_yuids", "login_crypta_ids", "login_yuid",
                    "prec", "prec_a", "prec_b", "rec", "rec_a", "rec_b"]

    yt.run_reduce(partial(mr.join_left_right, l_cols=join_columns, r_cols=join_columns),
                  [m1, m2], workdir + 'metrics_diff',
                  reduce_by='login')
    # NOTE(review): assumes join_left_right suffixes left-table columns with '_1' — confirm
    yt.run_sort(workdir + 'metrics_diff', sort_by='prec_b_1')


class PrepareYuidRadiusData(base_luigi_task.BaseTask):
    """Luigi task building yuid -> radius login links for `date` (see prepare_yuid_radius_data)."""
    date = luigi.Parameter()

    def requires(self):
        day = self.date
        return [graph_watch_log.ImportWatchLogDayTask(date=day, run_date=day),
                graph_dict.YuidAllIdDictsTask(date=day),
                FilterByRadius(date=day)]

    def run(self):
        yuid_with_all_dict = yt.TablePath(config.GRAPH_YT_DICTS_FOLDER + 'yuid_with_all',
                                          columns=radius_mr.YUID_ALL_COLUMNS)

        workdir = config.RADIUS_METRICS_YT_FOLDER + self.date + '/yuid_rlogin/'
        mr.mkdir(workdir)

        radius_logs, yuid_watch_logs = get_available_log_tables(
            self.date, int(config.STORE_DAYS), 'raw_links/watch_log_filtered_by_radius')
        if not (radius_logs and yuid_watch_logs):
            logger.error('Can\'t prepare yuid radius data')
            return
        prepare_yuid_radius_data(radius_logs, yuid_watch_logs, workdir, yuid_with_all_dict)

    def output(self):
        stable_table = config.RADIUS_METRICS_YT_FOLDER + self.date + '/yuid_rlogin/yuid_rlogin_final_stable'
        return [yt_luigi.YtTarget(stable_table)]


class PrepareDevidRadiusData(base_luigi_task.BaseTask):
    """Luigi task building device id -> radius login links for `date` (see prepare_device_radius_data)."""
    date = luigi.Parameter()

    def requires(self):
        day = self.date
        return [app_metrica_day.ImportAppMetrikaDayTask(date=day, run_date=day),
                FilterByRadius(date=day)]

    def run(self):
        workdir = config.RADIUS_METRICS_YT_FOLDER + self.date + '/devid_rlogin/'
        mr.mkdir(workdir)

        radius_logs, device_watch_logs = get_available_log_tables(
            self.date, int(config.STORE_DAYS), 'mobile/mmetrika_log_filtered_by_radius')
        if not (radius_logs and device_watch_logs):
            logger.error('Can\'t prepare device radius data')
            return
        prepare_device_radius_data(radius_logs, device_watch_logs, workdir)

    def output(self):
        final_table = config.RADIUS_METRICS_YT_FOLDER + self.date + '/devid_rlogin/device_rlogin_final'
        return [yt_luigi.YtTarget(final_table)]


class RadiusMetricsForVertices(base_luigi_task.BaseTask):
    """Luigi task computing radius metrics for one vertices configuration.

    `vertices_config` must provide `date`, `vertices_type`, `producing_task` and
    `get_vertices_table()`. Metric tables are written to
    RADIUS_METRICS_YT_FOLDER/<date>/metrics_<slice>_<vertices_type>.
    """
    vertices_config = luigi.Parameter()
    # BoolParameter: a plain Parameter would keep CLI/config values as strings,
    # making "False" truthy and device metrics impossible to switch off.
    calculate_device_metrics = luigi.BoolParameter(default=True)

    def requires(self):
        return [PrepareYuidRadiusData(self.vertices_config.date),
                PrepareDevidRadiusData(self.vertices_config.date),
                self.vertices_config.producing_task]

    def run(self):
        vertices_t = self.vertices_config.get_vertices_table()

        rlogin_in_f = config.RADIUS_METRICS_YT_FOLDER + self.vertices_config.date + '/'
        cid_workdir = rlogin_in_f + self.vertices_config.vertices_type + '/'

        logger.info("Preparing crypta splices...")
        # TODO: make two separate luigi tasks to speed up
        mr.mkdir(cid_workdir + 'yuid_cid')
        # copy radius data to avoid collision in parallel vertices calculation
        mr.copy(rlogin_in_f + 'yuid_rlogin/yuid_rlogin_final_stable', cid_workdir + 'yuid_cid/yuid_rlogin_final_stable')
        prepare_yuid_crypta_id_data(vertices_t, cid_workdir + 'yuid_cid/')

        if self.calculate_device_metrics:
            # TODO: make two separate luigi tasks to speed up
            mr.mkdir(cid_workdir + 'devid_cid')
            # copy radius data to avoid collision in parallel vertices calculation
            mr.copy(rlogin_in_f + 'devid_rlogin/device_rlogin_final', cid_workdir + 'devid_cid/device_rlogin_final')
            prepare_device_crypta_id_data(vertices_t, cid_workdir + 'devid_cid/')

        # TODO: make separate luigi task after two above
        # metrics are written next to the radius data of the same date
        calculate_all_metrics(rlogin_in_f, cid_workdir, self.vertices_config.vertices_type,
                              self.vertices_config.date, self.calculate_device_metrics)

    def output(self):
        # 'only_devices' and 'with_devices' are not always produced, so they are not outputs
        metrics_types = ['all', 'desktop', 'mobile']
        prefix = config.RADIUS_METRICS_YT_FOLDER + self.vertices_config.date + '/metrics_'
        postfix = '_' + self.vertices_config.vertices_type
        # empty tables are created for slices with no data (see calculate_metrics_and_persist)
        return [yt_luigi.YtTarget(prefix + t + postfix, allow_empty=True) for t in metrics_types]


def list_date_tables(folder, table_postfix):
    """List existing per-date tables under `folder`.

    Args:
        folder: YT folder path ending with '/'; its children are matched against a
            leading YYYY-MM-DD date.
        table_postfix: path appended to folder + date, e.g. '/yuid_rlogin/...'.

    Returns:
        list of (date, table_path) pairs for which the table exists.
    """
    import re
    # raw string: '\d' in a plain literal is an invalid escape sequence
    dt_pattern = re.compile(r'(\d{4}-\d{2}-\d{2})')
    dates = [dt for dt in yt.list(folder[:-1]) if dt_pattern.match(dt)]
    dt_tables = [(dt, folder + dt + table_postfix) for dt in dates]
    return [(dt, t) for dt, t in dt_tables if yt.exists(t)]


def login_diff(login_key, recs):
    """Reducer: merge one login's metrics rows from two tables into a single row.

    Columns of the first table get a '_1' suffix and those of the second get '_2'.
    'rec_diff' is rec_2 - rec_1 when both are present. Otherwise it is the sentinel
    -2.0 (login only in the first table) or 100.0 (only in the second).
    """
    first_recs, second_recs = mr.split_left_right(recs)

    diff_rec = {'login': login_key['login']}

    def copy_suffixed(source, suffix):
        for column, value in source.iteritems():
            diff_rec[column + suffix] = value

    if first_recs:
        copy_suffixed(first_recs[0], '_1')
    if second_recs:
        copy_suffixed(second_recs[0], '_2')

    if first_recs and second_recs:
        diff_rec['rec_diff'] = float(diff_rec['rec_2'] - diff_rec['rec_1'])
    elif first_recs:
        diff_rec['rec_diff'] = -2.0
    elif second_recs:
        diff_rec['rec_diff'] = 100.0

    yield diff_rec


def compare_metrics_two_days(day1, day2):
    """Diff 'metrics_all_exact_cluster' of `day1` vs `day2` into a '_diff' table under `day2`.

    The resulting table is sorted by 'rec_diff' (see login_diff).
    """
    first = '//home/crypta/production/state/radius/metrics/%s/metrics_all_exact_cluster' % day1
    second = '//home/crypta/production/state/radius/metrics/%s/metrics_all_exact_cluster' % day2
    diff_table = '//home/crypta/production/state/radius/metrics/%s/metrics_all_exact_cluster_diff' % day2

    inputs = [first, second]
    mr.sort_all(inputs, sort_by='login')
    yt.run_reduce(login_diff, [first, second], diff_table, reduce_by='login')
    yt.run_sort(diff_table, sort_by='rec_diff')


def reduce_rlogin(login_key, recs):
    """Reducer: classify how one login's yuid set changed between two tables.

    Output table indexes:
        0 - both sides have yuids the other lacks; 1 - only the left has extra yuids;
        2 - only the right has extra yuids (each extra yuid is also emitted to table 6
            with its ua_profile); 3 - same yuids; 4 - login only on the left;
        5 - login only on the right.
    """
    recs = list(recs)
    ua_by_yuid = dict((r['yuid'], r['ua_profile']) for r in recs)

    left_recs, right_recs = mr.split_left_right(recs)
    summary = dict(login_key)

    if not left_recs and not right_recs:
        return
    if not right_recs:
        summary['@table_index'] = 4
        yield summary
        return
    if not left_recs:
        summary['@table_index'] = 5
        yield summary
        return

    left_yuids = set(r['yuid'] for r in left_recs)
    right_yuids = set(r['yuid'] for r in right_recs)
    only_left = left_yuids.difference(right_yuids)
    only_right = right_yuids.difference(left_yuids)

    if only_left and only_right:
        summary['@table_index'] = 0
    elif only_left:
        summary['@table_index'] = 1
    elif only_right:
        for yuid in only_right:
            yield {'login': login_key['login'], 'yuid': yuid, 'ua_profile': ua_by_yuid[yuid], '@table_index': 6}
        summary['@table_index'] = 2
    else:
        summary['@table_index'] = 3

    yield summary


def compare_rlogins(day1, day2):
    """Compare yuid_rlogin_final_stable of `day1` and `day2` with reduce_rlogin.

    Results go to the 'yuid_rlogin/diff/' folder of `day2`, one table per
    reduce_rlogin table index (0..6, in the listed order).
    """
    out_dir = '//home/crypta/production/state/radius/metrics/%s/yuid_rlogin/diff/' % day2
    mr.mkdir(out_dir)

    first = '//home/crypta/production/state/radius/metrics/%s/yuid_rlogin/yuid_rlogin_final_stable' % day1
    second = '//home/crypta/production/state/radius/metrics/%s/yuid_rlogin/yuid_rlogin_final_stable' % day2

    # order matters: position == '@table_index' emitted by reduce_rlogin
    output_names = ('yuids_different',
                    'left_yuids_more',
                    'right_yuids_more',
                    'yuids_same',
                    'login_outdated',
                    'login_new',
                    'right_yuids_more_ua_profile')
    outputs = [out_dir + name for name in output_names]

    yt.run_reduce(reduce_rlogin, [first, second], outputs, reduce_by='login')


if __name__ == '__main__':
    yt.config.set_proxy(config.MR_SERVER)
    yt.config["tabular_data_format"] = yt.YsonFormat(process_table_index=True)

    # Ad-hoc investigation: count new yuids per day from the diff tables that
    # compare_rlogins(previous_day, day) produces under each day's yuid_rlogin/diff/.
    # (month, first day, last day exclusive)
    periods = [('2017-05', 17, 31), ('2017-06', 2, 7)]

    new_yuids_count = list()
    for month, first_day, stop_day in periods:
        for day in range(first_day, stop_day):
            day2 = '%s-%02d' % (month, day)
            x = yt.row_count('//home/crypta/production/state/radius/metrics/%s/yuid_rlogin/diff/right_yuids_more_ua_profile' % day2)
            y = yt.row_count('//home/crypta/production/state/radius/metrics/%s/yuid_rlogin/diff/right_yuids_more' % day2)
            new_yuids_count.append((x, y))

    print(new_yuids_count)
