
import os
import json
import glob
import argparse
import datetime
import yt.wrapper as yt
from collections import OrderedDict, defaultdict, namedtuple

# Composite key of the aggregated totals: segment prefix, SKU name, value type.
SkuKey = namedtuple('SkuKey', 'segment sku val_type')


def parse_args():
    """Parse this uploader's command-line options.

    Returns:
        argparse.Namespace with ``yt_cluster``, ``yt_root_path`` and
        ``data_path`` (all required) plus ``yt_table_name`` and
        ``date_string`` (None when not given).
    """
    # (flag, help text, required?) -- the dest is the flag without the
    # leading dashes and with '-' turned into '_'.
    option_specs = (
        ('--yt-cluster', 'yt cluster', True),
        ('--yt-root-path', 'yt root path', True),
        ('--yt-table-name', 'yt table name', False),
        ('--data-path', 'data path', True),
        ('--date-string', 'date string', False),
    )

    arg_parser = argparse.ArgumentParser(description='Put raw data to yt')
    for flag, help_text, is_required in option_specs:
        arg_parser.add_argument(flag, help=help_text, required=is_required,
                                dest=flag[2:].replace('-', '_'))

    return arg_parser.parse_args()


def fix_path(path):
    """Return *path* with ``~`` expanded and made canonical via ``realpath``.

    ``os.path.realpath`` makes the path absolute and resolves ``..`` and
    symlinks.
    """
    expanded = os.path.expanduser(path)
    return os.path.realpath(expanded)


def get_current_date_string():
    """Return today's date, in local time, formatted as ``YYYY/MM/DD``."""
    return datetime.datetime.now().strftime("%Y/%m/%d")


def put_raw_daily_data(yt_client, data, yt_root_path, string_date, table_name):
    """Upload one batch of raw rows into the day's folder on YT.

    Creates ``<yt_root_path>/daily/<string_date>`` (with parents) and ensures
    a ``raw_<table_name>`` table with the raw-data schema exists there.
    ``create`` is called with ``ignore_existing=True``, so an existing table is
    left as is. It then writes *data* into that table.

    Args:
        yt_client: YT client used for ``mkdir``, ``create`` and ``write_table``.
        data: rows to write; presumably a list of dicts matching the schema
            below (TODO confirm against the input files).
        yt_root_path: YT root under which the ``daily`` tree lives.
        string_date: day sub-path, e.g. ``"2020/01/31"``.
        table_name: suffix of the raw table name.
    """
    raw_schema = [
        {"name": "invnum", "type": "int64"},
        {"name": "segment", "type": "string"},
        {"name": "fqdn", "type": "string"},
        {"name": "elementary_resources", "type": "any"},
        {"name": "elementary_resources_costs", "type": "any"},
        {"name": "quota", "type": "any"},
        {"name": "quota_costs", "type": "any"},
        {"name": "quota_cost_operations", "type": "any"},
        {"name": "segmentation", "type": "any"},
    ]

    day_path = "/".join((yt_root_path, "daily", string_date))
    yt_client.mkdir(day_path, recursive=True)

    raw_table_path = "/".join((day_path, "raw_" + table_name))
    yt_client.create("table", raw_table_path,
                     ignore_existing=True,
                     attributes={"schema": raw_schema})
    yt_client.write_table(raw_table_path, data)


def aggregate_row(row, total_dict, source_type):
    """Add the per-SKU values of one raw row into a running total.

    The row's ``segment`` is truncated to its first two dot-separated parts
    (``"a.b.c"`` -> ``"a.b"``; ``"a"`` stays ``"a"``). Every value at
    ``row[source_type][sku][val_type]`` is then added to
    ``total_dict[SkuKey(segment, sku, val_type)]``.

    Args:
        row: raw-table row; must contain a string ``segment``.
        total_dict: mapping keyed by SkuKey that supports ``+=`` on missing
            keys (e.g. ``defaultdict(float)``); updated in place.
        source_type: name of the nested-dict column to aggregate, e.g.
            ``'quota'``. A missing or null column contributes nothing.
    """
    # Slicing to [:2] already copes with segments that have a single part.
    segment = '.'.join(row['segment'].split('.')[:2])

    # ``or {}`` also covers a column that is present but null.
    # ``.items()`` (not the Python-2-only ``.iteritems()``) works on 2 and 3.
    for sku, sku_vals in (row.get(source_type) or {}).items():
        for val_type, val in sku_vals.items():
            total_dict[SkuKey(segment, sku, val_type)] += val


def aggregated_value_dict_to_list(data_dict):
    """Flatten a ``{SkuKey: value}`` mapping into table rows.

    Args:
        data_dict: mapping whose keys expose ``segment``, ``sku`` and
            ``val_type`` attributes.

    Returns:
        list of dicts with keys ``segment``, ``sku``, ``val_type`` and
        ``value``, one per entry, in the mapping's iteration order.
    """
    # ``.items()`` instead of the Python-2-only ``.iteritems()``.
    return [
        {
            "segment": key.segment,
            "sku": key.sku,
            "val_type": key.val_type,
            "value": value,
        }
        for key, value in data_dict.items()
    ]


def aggregated_segment_dict_to_list(data_dict):
    """Flatten ``{segment: {"processed": n, "unprocessed": m}}`` into rows.

    Missing or None counters are reported as 0.

    Returns:
        list of dicts with keys ``segment``, ``processed`` and
        ``unprocessed``, one per entry, in the mapping's iteration order.
    """
    # ``.items()`` instead of the Python-2-only ``.iteritems()``.
    return [
        {
            "segment": segment,
            "processed": counts.get("processed", 0) or 0,
            "unprocessed": counts.get("unprocessed", 0) or 0,
        }
        for segment, counts in data_dict.items()
    ]


def get_daily_agregated_data(yt_client, yt_root_path, string_date):
    """Aggregate every ``raw_*`` table of one day into per-segment totals.

    All nodes under ``<yt_root_path>/daily/<string_date>`` whose names start
    with ``raw_`` are read. Each row's segment is truncated to its first two
    dot-separated parts. The ``quota``, ``quota_costs`` and
    ``quota_cost_operations`` values are summed per (segment, sku, val_type)
    via ``aggregate_row``.

    A row is counted as ``processed`` for its segment when its ``quota_costs``
    column is non-empty, otherwise as ``unprocessed``.

    Returns:
        tuple ``(quota_rows, quota_cost_rows, quota_cost_operation_rows,
        segment_count_rows)``. Each element is a list of row dicts.
    """
    yt_day_path = yt_root_path + "/daily/" + string_date

    total_quota = defaultdict(float)
    total_quota_costs = defaultdict(float)
    total_quota_cost_operations = defaultdict(float)
    total_segment_count = defaultdict(lambda: defaultdict(int))

    for node_name in yt_client.list(yt_day_path):
        # The day folder also holds the total_* tables; aggregate raw uploads only.
        if not node_name.startswith('raw_'):
            continue

        for row in yt_client.read_table(yt_day_path + "/" + node_name):
            # Slicing to [:2] already copes with segments that have a single part.
            segment = '.'.join(row['segment'].split('.')[:2])

            # ``or {}`` maps a missing or null column to "no cost data".
            has_costs = bool(row.get('quota_costs') or {})
            total_segment_count[segment]['processed' if has_costs else 'unprocessed'] += 1

            aggregate_row(row, total_quota, 'quota')
            aggregate_row(row, total_quota_costs, 'quota_costs')
            aggregate_row(row, total_quota_cost_operations, 'quota_cost_operations')

    return (aggregated_value_dict_to_list(total_quota),
            aggregated_value_dict_to_list(total_quota_costs),
            aggregated_value_dict_to_list(total_quota_cost_operations),
            aggregated_segment_dict_to_list(total_segment_count))


def write_agregated_value_data(yt_client, table_path, data):
    """Write aggregated (segment, sku, val_type, value) rows to *table_path*.

    The table is created with a fixed schema when missing
    (``ignore_existing=True`` leaves an existing table alone). *data* is then
    passed to ``write_table``.
    """
    value_schema = [
        {"name": column, "type": column_type}
        for column, column_type in (
            ("segment", "string"),
            ("sku", "string"),
            ("val_type", "string"),
            ("value", "double"),
        )
    ]
    yt_client.create("table", table_path,
                     ignore_existing=True,
                     attributes={"schema": value_schema})

    yt_client.write_table(table_path, data)


def write_yt_segment_data(yt_client, table_path, data):
    """Write per-segment processed/unprocessed counters to *table_path*.

    The table is created with a fixed schema when missing
    (``ignore_existing=True`` leaves an existing table alone). *data* is then
    passed to ``write_table``.
    """
    segment_schema = [
        {"name": column, "type": column_type}
        for column, column_type in (
            ("segment", "string"),
            ("processed", "int64"),
            ("unprocessed", "int64"),
        )
    ]
    yt_client.create("table", table_path,
                     ignore_existing=True,
                     attributes={"schema": segment_schema})

    yt_client.write_table(table_path, data)


def update_daily_aggregates(yt_client, yt_root_path, string_date):
    """Recompute the day's aggregates and write them next to the raw tables.

    Writes ``total_quota``, ``total_quota_costs``,
    ``total_quota_cost_operations`` and ``total_segment_count`` into
    ``<yt_root_path>/daily/<string_date>``, in that order.
    """
    day_dir = yt_root_path + "/daily/" + string_date
    (quota, quota_costs, quota_cost_operations,
     segment_count) = get_daily_agregated_data(yt_client, yt_root_path, string_date)

    value_tables = (
        ('/total_quota', quota),
        ('/total_quota_costs', quota_costs),
        ('/total_quota_cost_operations', quota_cost_operations),
    )
    for suffix, rows in value_tables:
        write_agregated_value_data(yt_client, day_dir + suffix, rows)

    write_yt_segment_data(yt_client, day_dir + '/total_segment_count', segment_count)


def get_data(path):
    """Load the JSON document stored at *path*.

    JSON objects are decoded as ``OrderedDict`` so their key order is kept.
    """
    with open(path, 'r') as json_file:
        raw_text = json_file.read()
    return json.loads(raw_text, object_pairs_hook=OrderedDict)


if __name__ == '__main__':
    args = parse_args()

    yt_cluster = args.yt_cluster
    root_path = args.yt_root_path
    client = yt.YtClient(yt_cluster)

    data_path = fix_path(args.data_path)

    # Pick the day once. Every uploaded file and the aggregation step must land
    # in the same daily folder, even if the run crosses midnight. Defining it
    # before the loop also keeps it bound when the loop body never runs.
    date_string = args.date_string if args.date_string else get_current_date_string()

    # --data-path may be a glob pattern; sort for a deterministic upload order.
    file_paths = sorted(glob.glob(data_path))
    if not file_paths:
        raise SystemExit('No input files match --data-path: %s' % data_path)

    for file_path in file_paths:
        # NOTE(review): with an explicit --yt-table-name every matched file goes
        # to the same raw_<name> table. write_table on a plain path replaces the
        # table contents, so only the last file would survive. Confirm intended.
        yt_table_name = args.yt_table_name if args.yt_table_name else os.path.splitext(os.path.basename(file_path))[0]

        file_rows = get_data(file_path)
        put_raw_daily_data(client, file_rows, root_path, date_string, yt_table_name)

    update_daily_aggregates(client, root_path, date_string)


