import itertools
import logging
from collections import defaultdict
from functools import partial

import luigi
import yt.wrapper as yt

import graph_stat
from lib.luigi import yt_luigi
from rtcconf import config
from utils import mr_utils as mr
from utils import utils

# Accuracy buckets (the values produced by normalize_accuracy) mapped to the
# index of the reducer output table that collects that bucket. The reducers in
# this module emit the value as '@table_index'; the mob tables of the yuid
# reducer are additionally shifted by len(ACC_INDICES).
ACC_INDICES = {150: 0, 500: 1, 1000: 2, 10000: 3}


def has_fields(rec, fields):
    """Tell whether every name in ``fields`` is present in ``rec`` with a truthy value.

    A field that is missing, or present but falsy (``None``, ``''``, ``0``),
    makes the check fail. An empty ``fields`` sequence yields ``True``.
    """
    return all(name in rec and rec.get(name) for name in fields)


def map_geo_yuid(rec):
    """Mapper turning a geo log record into per-yuid rows with a normalized accuracy.

    Output tables (selected through '@table_index'):
      0 - one row for every accepted record;
      1 - the same row again, only when the record has a non-empty 'ip';
      2 - the original record with an added '_error' message, when processing
          raised an exception.

    Records lacking a truthy 'yauid', 'lat' or 'lon' are skipped. Records whose
    'type' starts with 'map_navig' or 'analyzer_track' get accuracy 150; for all
    others 'acc' is parsed and bucketed by normalize_accuracy. A value of 'acc'
    that makes the parsing raise ValueError causes the record to be dropped
    silently; any other exception (for example the TypeError raised by
    ``float(None)`` when 'acc' is absent) sends the record to table 2.
    """
    try:
        if not has_fields(rec, ['yauid', 'lat', 'lon']):
            return

        raw_acc = rec.get('acc', None)
        event_type = rec.get('type', '')

        if event_type.startswith(('map_navig', 'analyzer_track')):
            acc_norm = 150
        else:
            try:
                acc_norm = normalize_accuracy(int(float(raw_acc)))
            except ValueError:
                return

        ip = rec.get('ip', '')
        base_row = {'key': rec['yauid'],
                    'acc': acc_norm,
                    'geo_source_types': rec.get('geo_source_types', ''),
                    'ip': ip}

        # table 0 always receives the row, table 1 only when an ip is known
        for table_index in ((0, 1) if ip else (0,)):
            out_row = dict(base_row)
            out_row['@table_index'] = table_index
            yield out_row
    except Exception as e:
        rec['_error'] = str(e)
        rec['@table_index'] = 2
        yield rec


def normalize_accuracy(acc):
    """Snap an accuracy value to one of the buckets 150, 500, 1000 or 10000.

    A falsy value (``None``, ``0``) is treated as unknown and mapped to the
    widest bucket, 10000. Otherwise the smallest bucket that is greater than
    or equal to ``acc`` is returned, with 10000 as the catch-all.
    """
    if not acc:
        return 10000

    for bucket in (150, 500, 1000):
        if acc <= bucket:
            return bucket
    return 10000

def reduce_split_by_acc_ranges_keep_geo_unique(_, recs):
    """Reducer joining geo hits of one key with its desk/mob yuid record.

    The incoming rows are of two kinds: rows carrying 'device_type'
    ('desk' or 'mob') and geo rows carrying 'acc', 'geo_source_types'
    (comma-separated) and 'hits'.

    Geo rows are accumulated per accuracy threshold from ACC_INDICES: a row
    counts towards every threshold that is greater than or equal to its 'acc',
    so a wider bucket always includes the more accurate ones.

    When the key has any geo rows, one output row per reached threshold is
    emitted for the first desk record (tables 0..len(ACC_INDICES)-1) and then
    for the first mob record (same indices shifted by len(ACC_INDICES)).
    'hits' in the output is the number of geo hits, not the device hits.
    """
    rows = list(recs)

    sources_by_acc = defaultdict(set)
    hits_by_acc = defaultdict(int)
    reached_accs = set()

    # all less accurate activity includes more accurate
    for row in rows:
        if 'acc' not in row:
            continue
        for threshold in ACC_INDICES:
            if row['acc'] <= threshold:
                reached_accs.add(threshold)
                sources_by_acc[threshold].update(row['geo_source_types'].split(','))
                hits_by_acc[threshold] += row['hits']

    # nothing to emit for a yuid without geo
    if not reached_accs:
        return

    for device, table_offset in (('desk', 0), ('mob', len(ACC_INDICES))):
        device_rows = [r for r in rows if r.get('device_type') == device]
        if not device_rows:
            continue

        yuid_row = device_rows[0]
        for threshold in reached_accs:
            yield {'key': yuid_row['key'], 'device_type': yuid_row['device_type'],
                   'source_types': yuid_row['source_types'],
                   'geo_source_types': ','.join(sorted(sources_by_acc[threshold])),
                   'hits': hits_by_acc[threshold],  # count geo hits, not device hits
                   'aggr_acc': threshold,
                   '@table_index': ACC_INDICES[threshold] + table_offset}


def reduce_geo_ip(key, recs):
    """Reducer aggregating geo hits of a single ip per accuracy threshold.

    Each incoming row carries 'acc', 'geo_source_types' (a comma-separated
    string) and 'hits'. A row counts towards every threshold of ACC_INDICES
    that is greater than or equal to its 'acc', so a wider bucket always
    includes the more accurate ones.

    :param key: reduce key; its 'ip' value becomes the 'key' of the output rows
    :param recs: iterable of grouped geo rows for this ip
    :return: generator of one row per reached threshold with the sorted,
             comma-joined set of geo source types, the summed hits, the
             threshold as 'acc' and the matching '@table_index'
    """
    out_geo_sources_per_aggr_acc = defaultdict(set)
    out_geo_hits_per_aggr_acc = defaultdict(int)
    aggregated_geo_accs = set()

    for geo_rec in recs:
        # 'geo_source_types' is a comma-separated string: it has to be split
        # before set.update(), otherwise the set is filled with the individual
        # characters of the string instead of the source type names.
        geo_source_types = geo_rec['geo_source_types'].split(',')
        for acc in ACC_INDICES:
            if geo_rec['acc'] <= acc:
                aggregated_geo_accs.add(acc)
                out_geo_sources_per_aggr_acc[acc].update(geo_source_types)
                out_geo_hits_per_aggr_acc[acc] += geo_rec['hits']

    for aggr_geo_acc in aggregated_geo_accs:
        yield {'key': key['ip'],
               'geo_source_types': ','.join(sorted(out_geo_sources_per_aggr_acc[aggr_geo_acc])),
               'hits': out_geo_hits_per_aggr_acc[aggr_geo_acc],
               'acc': aggr_geo_acc,
               '@table_index': ACC_INDICES[aggr_geo_acc]}


def filter_has_ip(rec):
    """Mapper that lets a record through only when its 'ip' value is truthy."""
    if not rec['ip']:
        return
    yield rec


def yuid_geo_stat_params():
    """Enumerate every (device, accuracy, flatten, source_columns) combination
    for which per-yuid source statistics are calculated.

    Returns a lazy ``itertools.product`` iterator over the four axes below.
    """
    devices = ['desk', 'mob']
    accuracy_suffixes = ['150', '500', '1000', 'all']
    flatten_modes = [True, False]
    source_column_sets = [
        ['source_types'],
        ['geo_source_types'],
        ['source_types', 'geo_source_types'],
    ]
    return itertools.product(devices, accuracy_suffixes, flatten_modes, source_column_sets)


def ip_geo_stat_params():
    """Enumerate every (device, accuracy, flatten, source_columns) combination
    for which per-ip source statistics are calculated.

    The device axis holds a single empty string, so the tuples have the same
    shape as the ones produced for yuids. Returns a lazy ``itertools.product``
    iterator.
    """
    devices = ['']
    accuracy_suffixes = ['150', '500', '1000', 'all']
    flatten_modes = [True, False]
    source_column_sets = [['geo_source_types']]
    return itertools.product(devices, accuracy_suffixes, flatten_modes, source_column_sets)


def geo_stat_by_geo_log(dt, in_f, out_f):
    """Build per-yuid and per-ip geo coverage statistics for one day.

    Steps, in order:
      1. map the RT Geo log of ``dt`` into 'geo_log', 'geo_log_has_ip' and
         'geo_log_errors' under ``out_f``;
      2. group geo hits per yuid, join them with the desk/mob yuid tables from
         ``in_f`` and split the result into per-accuracy tables in 'yuid/';
      3. do the same grouping and splitting per ip in 'ip/';
      4. start the sum_by_sources operations for all yuid and ip tables, wait
         for all of them, then sum the sources to totals in both folders.

    The function returns without doing anything when the RT Geo log or any of
    the two yuid input tables does not exist.

    :param dt: date string appended to config.STATBOX_RTGEO_FOLDER
    :param in_f: folder (with trailing slash) holding 'desk_yuids_rus' and 'mob_yuids_rus'
    :param out_f: folder (with trailing slash) that receives all produced tables
    """
    rtgeo_log = config.STATBOX_RTGEO_FOLDER + dt
    desk_yuids = in_f + 'desk_yuids_rus'
    mob_yuids = in_f + 'mob_yuids_rus'

    if not mr.all_exists(rtgeo_log, desk_yuids, mob_yuids):
        return

    geo_log = out_f + 'geo_log'
    geo_log_has_ip = out_f + 'geo_log_has_ip'
    yt.run_map(map_geo_yuid, rtgeo_log,
               [geo_log, geo_log_has_ip, out_f + 'geo_log_errors'])

    # ---- per yuid sources ----
    yuid_out_f = out_f + 'yuid/'
    mr.mkdir(yuid_out_f)

    # the 'key' column holds the yuid
    yuid_group_columns = ('key', 'geo_source_types', 'acc')
    yuid_hits_grouped = yuid_out_f + 'geo_hits_grouped'

    yt.run_sort(geo_log, sort_by=list(yuid_group_columns))
    yt.run_reduce(partial(graph_stat.group_by_and_count_hits, group_columns=list(yuid_group_columns)),
                  geo_log, yuid_hits_grouped,
                  reduce_by=list(yuid_group_columns))
    mr.sort_all([yuid_hits_grouped, desk_yuids, mob_yuids], 'key')

    # output order must follow '@table_index': desk tables first, then mob tables
    yuid_geo_tables = [yuid_out_f + name for name in (
        'desk_yuids_geo_150', 'desk_yuids_geo_500',
        'desk_yuids_geo_1000', 'desk_yuids_geo_all',
        'mob_yuids_geo_150', 'mob_yuids_geo_500',
        'mob_yuids_geo_1000', 'mob_yuids_geo_all',
    )]
    yt.run_reduce(reduce_split_by_acc_ranges_keep_geo_unique,
                  [desk_yuids, mob_yuids, yuid_hits_grouped],
                  yuid_geo_tables,
                  reduce_by='key')

    ops = [
        graph_stat.sum_by_sources(yuid_out_f, device + '_yuids_geo_' + acc, source_columns, flatten=flatten)
        for device, acc, flatten, source_columns in yuid_geo_stat_params()
    ]

    # ---- per ip sources ----
    ip_out_f = out_f + 'ip/'
    mr.mkdir(ip_out_f)

    ip_group_columns = ('ip', 'geo_source_types', 'acc')
    ip_hits_grouped = ip_out_f + 'geo_hits_grouped_ip'

    yt.run_sort(geo_log_has_ip, sort_by=list(ip_group_columns))
    yt.run_reduce(partial(graph_stat.group_by_and_count_hits, group_columns=list(ip_group_columns)),
                  geo_log_has_ip, ip_hits_grouped,
                  reduce_by=list(ip_group_columns))

    yt.run_sort(ip_hits_grouped, sort_by='ip')
    ip_geo_tables = [ip_out_f + name for name in (
        'ips_geo_150', 'ips_geo_500', 'ips_geo_1000', 'ips_geo_all',
    )]
    yt.run_reduce(reduce_geo_ip,
                  ip_hits_grouped,
                  ip_geo_tables,
                  reduce_by='ip')

    for device, acc, flatten, source_columns in ip_geo_stat_params():
        ops.append(graph_stat.sum_by_sources(ip_out_f, device + 'ips_geo_' + acc,
                                             source_columns, flatten=flatten))

    # yuid and ip operations are awaited together
    utils.wait_all(ops)

    totals_plan = (
        (yuid_out_f, ['source_types', 'geo_source_types']),
        (yuid_out_f, ['source_types']),
        (yuid_out_f, ['geo_source_types']),
        (ip_out_f, ['geo_source_types']),
    )
    for folder, source_columns in totals_plan:
        graph_stat.sum_sources_to_total_in_dir(folder, source_columns)


class GeoCoverageStats(yt_luigi.BaseYtTask):
    """
    Luigi task that calculates how many yuids and ips are covered by
    the RT Geo service for a given date.

    Reads the '/stat_new/all/' folder of that date, writes to '/stat_new/geo/'
    and depends on graph_stat.PrepareTodayTotalUsageStats for the same date.
    """
    date = luigi.Parameter()

    def input_folders(self):
        day_folder = config.YT_OUTPUT_FOLDER + self.date
        return {'all_stat': day_folder + '/stat_new/all/'}

    def output_folders(self):
        day_folder = config.YT_OUTPUT_FOLDER + self.date
        return {'geo_stat': day_folder + '/stat_new/geo/'}

    def requires(self):
        return graph_stat.PrepareTodayTotalUsageStats(self.date)

    def run(self):
        source_folder = self.in_f('all_stat')
        target_folder = self.out_f('geo_stat')
        mr.mkdir(target_folder)
        geo_stat_by_geo_log(self.date, source_folder, target_folder)

    def output(self):
        # TODO: add all output tables
        geo_folder = self.out_f('geo_stat')
        total_tables = (
            'yuid/sum_by_source_types/desk_yuids_geo_all_count_total',
            'yuid/sum_by_source_types/mob_yuids_geo_all_count_total',
        )
        return [yt_luigi.YtTarget(geo_folder + table) for table in total_tables]


if __name__ == '__main__':
    # Manual debugging entry point: runs the geo statistics against hard-coded
    # production paths, writing into separate scratch folders.
    yt.config.set_proxy(config.MR_SERVER)
    # The mappers/reducers of this module route rows with '@table_index';
    # presumably this format option is what makes yt honour it - TODO confirm.
    yt.config["tabular_data_format"] = yt.YsonFormat(process_table_index=True)

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')

    all_stat_dir = '//home/crypta/production/state/graph/2017-04-02/stat_new/all/'
    yuid_dir = '//home/crypta/production/state/graph/2017-04-02/stat_new/geo/yuid/'
    yuid_scratch_dir = '//home/crypta/production/state/graph/2017-04-02/stat_new/geo/yuid1/'
    geo_scratch_dir = '//home/crypta/production/state/graph/2017-04-02/stat_new/geo1/'

    for scratch_dir in (yuid_scratch_dir, geo_scratch_dir):
        mr.mkdir(scratch_dir)

    # Scratch snippets kept for re-running single stages by hand.
    #
    # Sum by sources over the production yuid folder:
    # ops = []
    # for device, acc, flatten, source_columns in yuid_geo_stat_params():
    #     table_name = device + '_yuids_geo_' + acc
    #     ops.append(graph_stat.sum_by_sources(yuid_dir, table_name, source_columns, flatten=flatten))
    #
    # utils.wait_all(ops)
    #
    # Re-run only the split reducer into the scratch yuid folder:
    # yt.run_reduce(reduce_split_by_acc_ranges_keep_geo_unique,
    #               [all_stat_dir + 'desk_yuids_rus', all_stat_dir + 'mob_yuids_rus',
    #                yuid_dir + 'geo_hits_grouped'],
    #               [yuid_scratch_dir + 'desk_yuids_geo_150', yuid_scratch_dir + 'desk_yuids_geo_500',
    #                yuid_scratch_dir + 'desk_yuids_geo_1000', yuid_scratch_dir + 'desk_yuids_geo_all',
    #                yuid_scratch_dir + 'mob_yuids_geo_150', yuid_scratch_dir + 'mob_yuids_geo_500',
    #                yuid_scratch_dir + 'mob_yuids_geo_1000', yuid_scratch_dir + 'mob_yuids_geo_all'],
    #               reduce_by='key')
    #
    # Sum by sources over the scratch yuid folder:
    # ops = []
    # for device, acc, flatten, source_columns in yuid_geo_stat_params():
    #     table_name = device + '_yuids_geo_' + acc
    #     ops.append(graph_stat.sum_by_sources(yuid_scratch_dir, table_name, source_columns, flatten=flatten))
    #
    # utils.wait_all(ops)

    # NOTE(review): the geo log date below is 2017-03-02 while every path above
    # points at 2017-04-02 - confirm whether the mismatch is intentional.
    geo_stat_by_geo_log('2017-03-02', all_stat_dir, geo_scratch_dir)

