import os
from multiprocessing.pool import ThreadPool as Pool

import yt.wrapper as yt
import yt.yson as yson

# YT path prefixes consumed by process(): history tables are read from
# '<in path>/<table_name>_history' and the merged result is written to
# '<out path>/<table_name>'.
# The *_MARKET_* pair is referenced only by a commented-out call in run().
BASE_IN_PATH = '//home/taxi/testing/export/taxi-logistic-platform-production'
BASE_MARKET_IN_PATH = '//home/taxi/testing/export/taxi-logistic-platform-market-production'
BASE_OUT_PATH = '//home/taxi-delivery/analytics/production/logistic-platform'
BASE_MARKET_OUT_PATH = '//home/taxi-delivery/analytics/production/logistic-platform-market'


def get_layer_paths(yt_client):
    """Return the ``layer_paths`` of the parent YT operation, if one is set.

    The parent operation id is taken from the ``NV_YT_OPERATION_ID``
    environment variable. When the variable is unset or empty, no request is
    made and ``None`` is returned.

    Args:
        yt_client: client object exposing ``get_operation_attributes``.

    Returns:
        The value stored under ``full_spec.tasks.task.layer_paths`` of the
        parent operation, or ``None`` when there is no parent operation.

    Raises:
        KeyError: if the operation attributes lack the expected nesting.
    """
    parent_operation = os.environ.get('NV_YT_OPERATION_ID')
    if not parent_operation:
        return None

    full_spec = yt_client.get_operation_attributes(parent_operation)['full_spec']
    layers = full_spec['tasks']['task']['layer_paths']
    print(f'Layers of parent task {layers}')
    return layers


# Tables to process, as (table_name, unique_key) pairs.
# run() unpacks each pair and hands it to process(), which reads the
# '<table_name>_history' folder and reduces its rows by unique_key.
# main() sizes its thread pool by the length of this list.
tables = [
    ('employers', 'employer_id'),
    ('node_reservations', 'reservation_id'),
    ('operator_events', 'internal_event_id'),
    ('payment', 'id'),
    ('planned_nodes', 'node_id'),
    ('planned_transfers', 'transfer_id'),
    #('requested_actions', 'action_id'),
    ('requests', 'request_id'),
    ('resource_items', 'internal_item_id'),
    ('resource_places', 'internal_place_id'),
    ('resources', 'resource_id'),
    ('station_tags', 'tag_id'),
    ('stations', 'station_id'),
]


def main():
    """Process every entry of ``tables`` concurrently, one worker thread each.

    ``Pool.map`` blocks until all tables are processed and re-raises the first
    worker exception. The pool is used as a context manager so its worker
    threads are released even when a worker fails.
    """
    with Pool(len(tables)) as pool:
        pool.map(run, tables)


def mapper(row):
    """Drop 'remove' history records and stringify the values of the rest.

    Rows whose ``history_action`` equals ``'remove'`` produce no output.
    For every other row a new dict is yielded in which each value is passed
    through ``str()``, except ``None`` and ``dict`` values, which are kept
    as they are.

    Args:
        row: mapping of column name to value.

    Yields:
        At most one converted copy of ``row``.
    """
    if row.get('history_action') == 'remove':
        return

    converted = {}
    for column, value in row.items():
        if value is None or isinstance(value, dict):
            converted[column] = value
        else:
            converted[column] = str(value)
    yield converted


def reducer(key, rows):
    """Collapse all history rows of one entity into its latest state.

    The row with the greatest ``history_event_id`` is emitted, with
    ``history_start_timestamp`` set to the earliest known timestamp of the
    entity.

    A row may already carry ``history_start_timestamp``: process() feeds the
    previously merged output table back in on incremental runs. That value is
    used in preference to the row's own ``history_timestamp``, so the start
    time found on an earlier run is not replaced by a later one.

    Args:
        key: reduce key of the group (unused).
        rows: iterable of row dicts sharing the same key; each must contain
            ``history_event_id`` and ``history_timestamp`` convertible by
            ``int()``.

    Yields:
        The latest row (mutated in place) with ``history_start_timestamp``
        stored as a string.
    """

    def start_timestamp(row):
        carried = row.get('history_start_timestamp')
        return int(carried if carried is not None else row['history_timestamp'])

    rows = list(rows)
    result = max(rows, key=lambda row: int(row['history_event_id']))
    result['history_start_timestamp'] = str(min(start_timestamp(row) for row in rows))
    yield result


def get_first_row(yt_client, table):
    """Read the row at index 0 of ``table``.

    Args:
        yt_client: client object exposing ``TablePath`` and ``read_table``.
        table: path of the table to read.

    Returns:
        The first row of the table, or an empty dict when the table has no
        rows. The default passed to ``next()`` keeps an empty table from
        surfacing as a bare ``StopIteration`` in the caller.
    """
    first_row_path = yt_client.TablePath(table, exact_index=0)
    return next(iter(yt_client.read_table(first_row_path)), {})

def run(pair):
    """Thread-pool entry point: process one ``(table_name, unique_key)`` pair.

    The pair is unpacked and forwarded to process() together with the
    BASE_IN_PATH / BASE_OUT_PATH prefixes.
    """
    table_name, unique_key = pair
    process(
        table_name=table_name,
        unique_key=unique_key,
        in_path=BASE_IN_PATH,
        out_path=BASE_OUT_PATH,
    )
    # process(table_name, unique_key, BASE_MARKET_IN_PATH, BASE_MARKET_OUT_PATH)

def _infer_schema(yt_client, table_paths, unique_key=None):
    """Build a table schema from the first row of each table in ``table_paths``.

    Every column name seen is declared as ``string``, except ``unpacked_data``
    and names starting with ``_unpacked``, which are declared as ``any``.
    A ``history_start_timestamp`` column is always present. When
    ``unique_key`` is given it becomes the leading ascending sort column and
    the schema is marked ``unique_keys``. The remaining columns are emitted
    in sorted order so the resulting schema is deterministic.
    """
    column_names = set()
    for path in table_paths:
        column_names.update(get_first_row(yt_client, path).keys())
    column_names.add('history_start_timestamp')
    # The key column is declared separately (or, without a key, not excluded).
    column_names.discard(unique_key)

    schema = yson.YsonList([])
    if unique_key:
        schema.append({'name': unique_key, 'type': 'string', 'sort_order': 'ascending'})
        schema.attributes['unique_keys'] = True
    for name in sorted(column_names):
        is_any = name == 'unpacked_data' or name.startswith('_unpacked')
        schema.append({'name': name, 'type': 'any' if is_any else 'string'})
    return schema


def process(table_name, unique_key, in_path, out_path):
    """Merge the history tables of one entity into its output table.

    Reads every table in ``<in_path>/<table_name>_history``, runs
    mapper/reducer over them grouped by ``unique_key`` and writes the result,
    sorted by ``unique_key``, to ``<out_path>/<table_name>``.

    The output table stores the last consumed history table in its
    ``_last_processed_table_`` attribute. When the attribute is present, only
    history tables whose path is >= that value are read again, together with
    the current output table; otherwise the output table is recreated and all
    history tables are read.

    Nothing is done when the input folder is missing or empty, or when an
    incremental run finds no history table left to read. In the latter case
    the stored marker is left untouched.

    Args:
        table_name: name of the entity table (without the ``_history`` suffix).
        unique_key: column used as reduce key and sort key of the output.
        in_path: folder containing the ``<table_name>_history`` folder.
        out_path: folder in which the output table is written.

    Raises:
        KeyError: if the ``YT_TOKEN`` environment variable is not set.
    """
    yt_client = yt.YtClient(token=os.environ['YT_TOKEN'], proxy='hahn.yt.yandex.net')

    in_folder = '%s/%s_history' % (in_path, table_name)
    out_table = '%s/%s' % (out_path, table_name)

    if not yt_client.exists(in_folder):
        return

    all_tables = ['%s/%s' % (in_folder, name) for name in sorted(yt_client.list(in_folder))]
    if not all_tables:
        print(f'No history tables in {in_folder} - nothing to process')
        return

    last_processed_table = None
    if yt_client.exists(out_table):
        last_processed_table = yt_client.get_attribute(out_table, '_last_processed_table_', default=None)

    if last_processed_table is None:
        print('Last processed table not found - calculating from scratch')
        new_tables = all_tables
        in_tables = new_tables
        print('Calculating result table schema...')
        out_schema = _infer_schema(yt_client, in_tables, unique_key)
        yt_client.create('table', out_table, force=True, attributes={'schema': out_schema, 'optimize_for': 'scan'})
    else:
        print(f'Last processed table: {last_processed_table}')
        # The last processed table itself is included (>=) and read again.
        new_tables = [path for path in all_tables if path >= last_processed_table]
        if not new_tables:
            # Without this guard the marker below would be set to out_table.
            print(f'No history tables at or after {last_processed_table} - nothing to process')
            return
        in_tables = [out_table] + new_tables
        print('Calculating result table schema...')
        out_schema = _infer_schema(yt_client, in_tables, unique_key)
        yt_client.alter_table(out_table, schema=out_schema)

    for path in in_tables:
        print(f'Processing: {path}')

    # Run the user jobs in the same layers as the parent operation, if any.
    spec = None
    layer_paths = get_layer_paths(yt_client)
    if layer_paths:
        spec = {
            'reducer': {'layer_paths': layer_paths},
            'mapper': {'layer_paths': layer_paths},
        }

    print('Calculating temp file schema...')
    tmp_schema = _infer_schema(yt_client, in_tables)

    print('Merging to resulting table...')
    with yt_client.TempTable(attributes={'schema': tmp_schema, 'optimize_for': 'scan'}) as tmp:
        yt_client.run_map_reduce(mapper, reducer, in_tables, tmp, reduce_by=unique_key, spec=spec)
        yt_client.run_sort(tmp, out_table, sort_by=unique_key)

    print(f'Last processed table: {new_tables[-1]}')
    yt_client.set_attribute(out_table, '_last_processed_table_', new_tables[-1])


# Start processing only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
