# -*- encoding: utf-8 -*-

import os
import sys
import time
from datetime import datetime, timedelta
from collections import defaultdict
from optparse import OptionParser
from pytz import timezone

import yt.wrapper as yt
import yt.logger_config as yt_logger_config
import yt.logger as yt_logger


# Extra YT attributes to set on merged LOGS_PATH tables, keyed by log name
# (the second-to-last component of the destination path). Each value is
# iterated as a sequence of (attr_name, attr_value) pairs.
EXTRA_ATTRS = {}


# The script exits immediately unless settings.ENVIRONMENT is one of these.
ALLOWED_ENVS = ['production', 'dev']
# Import tables are moved here before being merged into LOGS_PATH.
STAGING_PATH = '//home/rasp/staging'
# Source tree; tables are expected at IMPORT_PATH/<log name>/<date>/<time>.
IMPORT_PATH = '//home/rasp/import'
# Merge destinations live at LOGS_PATH/<log name>/<date>.
LOGS_PATH = '//home/rasp/logs'
# Empty map nodes under IMPORT_PATH are removed once they are this many days old.
MAP_NODE_TTL_DAYS = 2
# NOTE(review): not referenced by the code visible in this file; kept in case
# something else imports it.
LOCK_TIME_MS = 2000
# Import tables modified less than this many minutes ago are left in place.
HOLD_IMPORT_TABLES_MINS = 15
# Maximum number of staging tables merged into one destination per run; the
# rest wait for the next run. NOTE(review): presumably chosen to stay under a
# YT input-table limit -- confirm.
MAX_MERGE_BATCH_SIZE = 149
# Path prefixes under IMPORT_PATH that are never imported. This must be a tuple
# (note the trailing comma) so that str.startswith checks every prefix; without
# the comma it is a single string.
EXCLUDE_FROM_IMPORT = ('//home/rasp/import/statbox-logbroker-state',)


def get_5_min_tables(log):
    """Collect import tables that are ready to be moved to staging.

    Walks every table under IMPORT_PATH and skips:
      * tables whose path starts with one of the EXCLUDE_FROM_IMPORT prefixes;
      * tables modified less than HOLD_IMPORT_TABLES_MINS minutes ago.

    The last three path components are taken as log name, date and time, so
    the layout is assumed to be IMPORT_PATH/<name>/<date>/<time>.

    :param log: logger used to report tables skipped for being too fresh.
    :returns: list of dicts with keys ``table_name``, ``table_date``,
        ``table_time`` and ``table`` (the full path), ordered by full path.
    """
    utc_timezone = timezone('UTC')
    msk_timezone = timezone('Europe/Moscow')
    hold_period = timedelta(minutes=HOLD_IMPORT_TABLES_MINS)

    tables = []
    for table in yt.search(IMPORT_PATH, node_type=['table'], attributes=["modification_time"]):
        if table.startswith(EXCLUDE_FROM_IMPORT):
            continue

        table_name, table_date, table_time = table.split('/')[-3:]

        # modification_time is treated as UTC. Only the first 19 characters
        # (up to whole seconds) are parsed; the result is shown in Moscow time.
        modification_date = datetime.strptime(
            table.attributes['modification_time'][:19], '%Y-%m-%dT%H:%M:%S'
        ).replace(tzinfo=utc_timezone).astimezone(msk_timezone)

        if modification_date + hold_period > datetime.now(msk_timezone):
            log.info('Skip %s %s', table, modification_date.strftime('%H:%M'))
            continue

        tables.append({
            'table_name': table_name,
            'table_date': table_date,
            'table_time': table_time,
            'table': table,
        })

    # Sort explicitly by path. Sorting the dicts themselves relies on Python 2
    # dict comparison (which, since paths are unique, reduces to comparing the
    # 'table' values) and raises TypeError on Python 3.
    return sorted(tables, key=lambda info: info['table'])


def _move_import_to_staging(log):
    """Move every ready import table into STAGING_PATH.

    Each table returned by get_5_min_tables is exclusively locked and moved to
    ``STAGING_PATH/<name>__<date>__<time>`` inside its own transaction. Tables
    that cannot be locked are logged and left in place for a later run.
    """
    tables = get_5_min_tables(log)
    tables_len = len(tables)

    for number, table_info in enumerate(tables, 1):
        source_table = table_info['table']
        staging_table = '/'.join([
            STAGING_PATH,
            '%s__%s__%s' % (
                table_info['table_name'],
                table_info['table_date'],
                table_info['table_time'],
            ),
        ])

        with yt.Transaction():
            log.info('Lock %s', source_table)
            try:
                yt.lock(source_table, mode="exclusive")
            except Exception:
                # The table stays in IMPORT_PATH and is retried on the next run.
                log.info('Can\'t lock %s', source_table)
                continue

            log.info('Move %s/%s: %s to %s', number, tables_len, source_table, staging_table)
            yt.move(source_table, staging_table)


def _group_staging_by_destination(log):
    """Group staging tables by the LOGS_PATH table they should be merged into.

    Staging names are parsed back as ``<name>__<date>__<time>``; the destination
    is ``LOGS_PATH/<name>/<date>``. Names that do not split into three parts are
    logged and ignored.

    :returns: dict mapping destination path -> list of staging table paths,
        each list in sorted order.
    """
    grouped = defaultdict(list)

    for staging_table in sorted(yt.search(STAGING_PATH, node_type=['table'])):
        # rsplit keeps any '__' inside the log name itself in table_name,
        # instead of failing to unpack and leaving the table in staging forever.
        parts = staging_table.split('/')[-1].rsplit('__', 2)
        if len(parts) != 3:
            log.info('Skip unexpected staging table %s', staging_table)
            continue

        table_name, table_date, _ = parts
        destination_table = os.path.join(LOGS_PATH, table_name, table_date)
        grouped[destination_table].append(staging_table)

    return grouped


def _merge_into_destination(log, destination_table, staging_tables):
    """Append up to MAX_MERGE_BATCH_SIZE staging tables to destination_table.

    Everything runs in one transaction, in this order:
      1. create the destination if it is missing;
      2. lock the destination exclusively;
      3. lock as many staging tables as possible;
      4. merge them into the destination;
      5. set EXTRA_ATTRS;
      6. remove the merged staging tables.

    Nothing is merged if the destination or all staging tables cannot be locked.
    """
    with yt.Transaction(timeout=60000 * 5):
        # The table has to exist before it can be locked exclusively.
        if not yt.exists(destination_table):
            log.info('Create table %s', destination_table)
            yt.create("table", destination_table, recursive=True)

            # For some reason a lock taken inside the transaction right after
            # creation does not see the new table, so wait before locking.
            log.info('Sleep 10 seconds')
            time.sleep(10)

        try:
            log.info('Lock destination: %s', destination_table)
            yt.lock(destination_table, mode='exclusive')
        except Exception:
            log.info('Can\'t lock %s', destination_table)
            return

        # Lock the staging tables and keep only those locked successfully.
        locked_tables = []
        for staging_table in staging_tables:
            if len(locked_tables) >= MAX_MERGE_BATCH_SIZE:
                # The remaining tables stay in staging for the next run.
                break

            try:
                log.info('Lock source: %s', staging_table)
                yt.lock(staging_table, mode='exclusive')
            except Exception:
                log.info('Can\'t lock %s', staging_table)
                continue

            locked_tables.append(staging_table)

        if not locked_tables:
            log.info('Skip merge for %s (%s locked)', destination_table, len(locked_tables))
            return

        log.info('Merge %s staging to %s', len(locked_tables), destination_table)

        # The destination is itself one of the inputs, so its existing rows
        # are kept alongside the staged ones.
        # Disabled for now: mode='unordered', spec={'combine_chunks': True}
        yt.run_merge(
            source_table=[destination_table] + locked_tables,
            destination_table=destination_table,
        )

        # Attributes are best-effort: a failure is logged and the transaction
        # still commits the merged data.
        try:
            destination_log_name = destination_table.split('/')[-2]
            for attr_name, attr_val in EXTRA_ATTRS.get(destination_log_name, []):
                log.info('Set %s: %s to %s', attr_name, attr_val, destination_table)
                yt.set_attribute(destination_table, attr_name, attr_val)
        except Exception:
            log.exception('ERROR WHEN SETTING ATTR')

        # Remove the staging tables that have just been merged.
        for staging_table in locked_tables:
            log.info("Remove %s", staging_table)
            yt.remove(staging_table)


def _remove_empty_map_nodes(log):
    """Remove empty map nodes under IMPORT_PATH older than MAP_NODE_TTL_DAYS.

    IMPORT_PATH itself is never removed. Removal errors are logged and skipped.
    """
    # modification_time is treated as UTC elsewhere in this script, so compare
    # it against the UTC date rather than the server's local date.
    today = datetime.utcnow().date()

    for map_node in yt.search(IMPORT_PATH, node_type=['map_node'], attributes=["count", "modification_time"]):
        if map_node == IMPORT_PATH:
            # search() may yield the root itself; removing it would make the
            # next run's search over IMPORT_PATH fail.
            continue

        if map_node.attributes['count'] != 0:
            continue

        modification_date = datetime.strptime(map_node.attributes['modification_time'][:10], '%Y-%m-%d').date()
        if modification_date + timedelta(days=MAP_NODE_TTL_DAYS) > today:
            continue

        log.info('Remove empty map_node %s', map_node)
        try:
            yt.remove(map_node)
        except Exception:
            log.exception('ERROR REMOVE %s', map_node)


def main():
    """Entry point: stage ready import tables, merge them into LOGS_PATH,
    then remove stale empty map nodes under IMPORT_PATH.

    Command-line options:
      -v/--verbose  also log to stdout; otherwise the YT client log level is
                    lowered to WARNING.
      -p/--proxy    YT proxy URL to use instead of settings.YT_PROXY.

    Exits with status 0, doing nothing, unless settings.ENVIRONMENT is one of
    ALLOWED_ENVS.
    """
    # Imported for its side effects; presumably it must run before Django
    # settings are used.
    import travel.avia.admin.init_project  # noqa

    import logging

    from django.conf import settings

    from travel.avia.admin.lib.logs import add_stdout_handler, create_current_file_run_log
    from travel.avia.admin.lib.yt_helpers import configure_wrapper

    log = logging.getLogger(__name__)
    create_current_file_run_log()

    optparser = OptionParser()
    optparser.add_option('-v', '--verbose', action='store_true')
    optparser.add_option('-p', '--proxy', dest='proxy', default=settings.YT_PROXY)
    options, _args = optparser.parse_args()

    if options.verbose:
        add_stdout_handler(log)
    else:
        yt_logger_config.LOG_LEVEL = 'WARNING'
        # NOTE(review): presumably yt.logger applies LOG_LEVEL only at import
        # time, hence the reload.
        reload(yt_logger)

    configure_wrapper(yt)
    if options.proxy != settings.YT_PROXY:
        yt.config['proxy']['url'] = options.proxy

    log.info('Start')

    current_env = settings.ENVIRONMENT
    if current_env not in ALLOWED_ENVS:
        log.info('Current ENVIRONMENT %s. Run only %s allowed.', current_env, ', '.join(ALLOWED_ENVS))
        sys.exit()

    # Make sure the staging directory exists.
    if not yt.exists(STAGING_PATH):
        log.info('Create path %s', STAGING_PATH)
        yt.create("map_node", STAGING_PATH)

    _move_import_to_staging(log)

    for destination_table, staging_tables in _group_staging_by_destination(log).items():
        _merge_into_destination(log, destination_table, staging_tables)

    _remove_empty_map_nodes(log)

    log.info('Done')
