# -*- encoding: utf-8 -*-
import os
import sys
import time
from datetime import datetime
from optparse import OptionParser

import yt.wrapper as yt

# Cypress directory whose prefix is stripped from a table path when its backup name is built.
ROOT_DIR = '//home/rasp'
# Directory that receives the original (pre-deduplication) tables.
BACKUP_PATH = ROOT_DIR + '/backup'
# Fields excluded from the sort/reduce key, i.e. not compared when looking for duplicate rows.
IGNORE_FIELDS = ['source_uri', '_stbx', 'subkey']


def uniq_reduce(key, records):
    """Reducer that collapses a key group to its first record.

    Args:
        key: reduce key of the group; not used.
        records: iterable of the records that share ``key``.

    Yields:
        The first record of ``records``; nothing when ``records`` is empty.
    """
    missing = object()  # private sentinel, so a falsy or None first record is still kept
    remaining = iter(records)

    first = next(remaining, missing)
    if first is missing:
        return

    # Consume the rest of the group before yielding, as before, but without
    # buffering it: a group with many duplicates no longer has to fit in memory.
    for _ in remaining:
        pass

    yield first


def yt_last_logs_tables(path, modification_date):
    """Return the sorted paths of tables under ``path`` modified on a given day.

    Args:
        path: node passed to ``yt.search`` as the search root.
        modification_date: ``datetime.date`` compared with the date part
            (first 10 characters, ``YYYY-MM-DD``) of each table's
            ``modification_time`` attribute. When it is falsy no table is
            selected and an empty list is returned.

    Returns:
        Sorted list of the matching table paths, lower-cased.
    """
    matched = []
    found = yt.search(path, node_type="table", attributes="modification_time")

    for node in found:
        if not modification_date:
            continue

        day_text = node.attributes['modification_time'][:10]
        if datetime.strptime(day_text, '%Y-%m-%d').date() != modification_date:
            continue

        # NOTE(review): title().lower() yields a lower-cased copy of the path;
        # presumably meant only to get a plain string - confirm that no
        # processed path contains upper-case characters.
        matched.append(node.title().lower())

    matched.sort()
    return matched


def main():
    """Remove duplicate rows from a YT table, or from the tables of a directory.

    Command-line options:
        -p / --path: path of a table or of a map_node (required).
        -m / --modification_date: day in YYYY-MM-DD form. Required when
            --path is a map_node: only the tables found under it whose
            modification_time falls on that day are processed.

    Every selected table is moved under BACKUP_PATH (with a Unix-timestamp
    suffix), sorted by all of its fields except IGNORE_FIELDS into a temporary
    table, and then reduced back into the original path with one row kept per
    distinct combination of those fields.

    Exits with status 1 when the path does not exist or has an unsupported
    node type, and through ``OptionParser.error`` when an option is missing
    or malformed.
    """
    import travel.avia.admin.init_project  # noqa

    import logging

    from django.conf import settings

    from travel.avia.admin.lib.logs import add_stdout_handler, create_current_file_run_log
    from travel.avia.admin.lib.yt_helpers import configure_wrapper

    log = logging.getLogger(__name__)
    create_current_file_run_log()

    optparser = OptionParser()
    optparser.add_option("-p", "--path", dest="path")
    optparser.add_option("-m", "--modification_date", dest="modification_date")
    options, _args = optparser.parse_args()

    # Fail with a usage message instead of passing None on to the YT client.
    if not options.path:
        optparser.error('--path is required')

    add_stdout_handler(log)

    configure_wrapper(yt)

    log.info('Start')

    if not yt.exists(options.path):
        log.error('Table not found: %s', options.path)
        sys.exit(1)

    node_type = yt.get_attribute(options.path, "type")
    if node_type == "map_node":
        # The date is only needed (and only validated) for directories, so a
        # single-table run keeps working without it.
        if not options.modification_date:
            optparser.error('--modification_date is required when --path is a map_node')
        try:
            modification_date = datetime.strptime(options.modification_date, "%Y-%m-%d").date()
        except ValueError:
            optparser.error('--modification_date must have the YYYY-MM-DD format')
        source_tables = yt_last_logs_tables(options.path, modification_date)
    elif node_type == "table":
        source_tables = [options.path]
    else:
        log.error('Unknown node_type "%s" for %s', node_type, options.path)
        sys.exit(1)

    log.info('Found %s tables', len(source_tables))

    for source_table in source_tables:
        old_rows_count = int(yt.get_attribute(source_table, 'row_count'))

        log.info('Process %s with %s records', source_table, old_rows_count)
        backup_table_name = _backup_table_path(source_table, int(time.time()))

        back_up_path = os.path.dirname(backup_table_name)

        if not yt.exists(back_up_path):
            log.info('Create map_node: %s', back_up_path)
            yt.create("map_node", path=back_up_path, recursive=True)

        # From here until the reduce below finishes, the original rows exist
        # only in the backup table; nothing moves them back if a later step raises.
        log.info('Move %s to %s', source_table, backup_table_name)
        yt.move(source_table, backup_table_name)

        tmp_table = yt.create_temp_table()

        # Compute the set of fields present in the table.
        # This is rather slow, but I could not come up with anything more reliable.
        log.info('Calc all fields in %s', backup_table_name)
        table_fields_set = set()

        for x, record in enumerate(yt.read_table(backup_table_name, format=yt.JsonFormat(), raw=False)):
            if x % 100000 == 0:
                log.info('Records processed: %s', x)

            table_fields_set.update(record.keys())

        for field in IGNORE_FIELDS:
            table_fields_set.discard(field)

        # The same list (in the same order) must be used for both the sort and
        # the reduce, so it is materialized once.
        # NOTE(review): it is empty for a table without rows or with only
        # ignored fields - confirm how run_sort behaves with an empty sort_by.
        current_fields_list = list(table_fields_set)
        log.info('Sort table: %s by %s', backup_table_name, current_fields_list)

        yt.run_sort(
            source_table=backup_table_name,
            destination_table=tmp_table,
            sort_by=current_fields_list,
        )

        log.info('Final reduce %s to %s', tmp_table, source_table)
        yt.run_reduce(
            uniq_reduce,
            tmp_table,
            source_table,
            format=yt.DsvFormat(),
            reduce_by=current_fields_list,
            spec={"data_size_per_job": settings.YT_DATA_SIZE_PER_JOB},
        )
        new_rows_count = int(yt.get_attribute(source_table, 'row_count'))
        delta_rows_count = old_rows_count - new_rows_count

        log.info('%s duplicates removed (%s / %s)', delta_rows_count, old_rows_count, new_rows_count)

    log.info('Done')


def _backup_table_path(source_table, timestamp):
    """Return the path under BACKUP_PATH that ``source_table`` is moved to.

    The path relative to ROOT_DIR is kept, with ``_<timestamp>`` appended. A
    table outside ROOT_DIR keeps its whole path (minus the leading slashes)
    rather than having an arbitrary number of characters cut off its front.
    """
    root_prefix = ROOT_DIR + '/'
    if source_table.startswith(root_prefix):
        relative_path = source_table[len(root_prefix):]
    else:
        relative_path = source_table.lstrip('/')

    return '{backup_path}/{table_path}_{timestamp}'.format(
        backup_path=BACKUP_PATH,
        table_path=relative_path,
        timestamp=timestamp,
    )
