# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from argparse import ArgumentParser
from collections import OrderedDict
import json
import logging.config
import pwd
import os
import sys

import six
from yt.wrapper import YtClient
import yt.wrapper as yt

from travel.cpa.lib.common import hide_secrets
from travel.cpa.lib.lib_logging import LOG_CONFIG, get_logger
from travel.cpa.lib.lib_yt import backup_table, create_temp_table, get_table_schema, ProcessedTables
from travel.cpa.lib.metered_task_executor import MeteredTaskExecutor
from travel.library.python.tools import replace_args_from_env
from travel.library.python.schematized import Schematized
from travel.library.python.schematized.fields import Float, String, UInt64, Int64
from travel.hotels.lib.python3.yt.versioned_path import KeepLastNCleanupStrategy


# Module-level logger used by every step of this task.
LOG = get_logger('merge_train_refunds')

# Columns that Executor.mapper removes from every input row; the same list is
# passed as `hide_fields` to get_table_schema when the temporary table schemas
# are built.
LOGFELLER_FIELDS_TO_HIDE = [
    '_logfeller_index_bucket',
    '_logfeller_timestamp',
    '_rest',
    '_stbx',
    '_timestamp',
    'iso_eventtime',
    'source_uri',
]

# Sort key of the final refunds table: used both for the sorted temporary table
# schema and for the last run_sort in Executor.try_execute.
SORT_BY = ['RefundedAt']


class TrainRefund(Schematized):
    """Declarative description of one train refund row.

    The ordered ``__fields__`` mapping is what ``get_yt_schema()`` is called on
    in ``Executor.__init__`` to obtain the set of known columns.
    """

    __fields__ = OrderedDict()

    # refund timestamp, identifiers and ticket count
    __fields__['RefundedAt'] = UInt64()
    __fields__['OrderId'] = String()
    __fields__['OrderPrettyId'] = String()
    __fields__['OrderRefundId'] = String()
    __fields__['NumberOfTickets'] = Int64()

    # money values: every amount column is followed by its currency column
    __fields__['RefundTicketMoney_Amount'] = Float()
    __fields__['RefundTicketMoney_Currency'] = String()
    __fields__['RefundFeeMoney_Amount'] = Float()
    __fields__['RefundFeeMoney_Currency'] = String()
    __fields__['RefundInsuranceMoney_Amount'] = Float()
    __fields__['RefundInsuranceMoney_Currency'] = String()
    __fields__['PartnerRefundFeeMoney_Amount'] = Float()
    __fields__['PartnerRefundFeeMoney_Currency'] = String()


class Executor(MeteredTaskExecutor):
    """Task that merges new train refund log tables into one deduplicated table.

    Rows of the current refunds table and of the not yet processed source tables
    are collected into a temporary table, reduced by ``OrderRefundId`` (within a
    group the row with the greatest ``RefundedAt`` wins), sorted by ``SORT_BY``
    and moved over the refunds table.
    """

    def __init__(self, yt_client, options, common_labels):
        """Remember the client/options and precompute table paths and schemas.

        :param yt_client: YT client used for every Cypress call and operation.
        :param options: parsed command line options; this class reads
            ``dst_dir``, ``src_dir``, ``backup_dir`` and ``keep_last_tables``.
        :param common_labels: labels forwarded to the base executor.
        """
        super(Executor, self).__init__('merge_train_refunds', options, common_labels)
        self.yt_client = yt_client
        self.options = options
        self.train_refunds_table = yt.ypath_join(self.options.dst_dir, 'train', 'refunds')
        self.sorted_table = yt.ypath_join(self.options.dst_dir, '_temp_train_refunds_sorted')
        self.unsorted_table = yt.ypath_join(self.options.dst_dir, '_temp_train_refunds_unsorted')

        common_schema = OrderedDict()
        common_schema.update(TrainRefund().get_yt_schema())

        # Names of the columns that survive the mapper; everything else is dropped.
        self.schema_fields = set(common_schema.keys())

        self.sorted_schema = get_table_schema(common_schema, hide_fields=LOGFELLER_FIELDS_TO_HIDE, sort_by=SORT_BY)
        self.unsorted_schema = get_table_schema(common_schema, hide_fields=LOGFELLER_FIELDS_TO_HIDE)

        self.processed_table = yt.ypath_join(self.options.dst_dir, 'processed_train_refunds')

    def ensure_tables(self):
        """Create the destination directory and the refunds table if they are missing."""
        dst_dir = self.options.dst_dir
        if not self.yt_client.exists(dst_dir):
            LOG.info('Creating destination directory %s', dst_dir)
            # recursive: parent directories of dst_dir may be missing as well
            self.yt_client.create('map_node', dst_dir, recursive=True)
        if not self.yt_client.exists(self.train_refunds_table):
            LOG.info('Creating train_refunds table: %s', self.train_refunds_table)
            # The table lives in the nested `<dst_dir>/train` node which nothing else
            # creates, so a non-recursive create fails on a clean destination.
            self.yt_client.create('table', self.train_refunds_table, recursive=True)

    @staticmethod
    def get_metric_key(**kwargs):
        """Encode metric labels as a JSON string with a stable key order."""
        return json.dumps(kwargs, sort_keys=True)

    def send_metrics(self, metrics):
        """Send the collected job statistics through ``self.solomon_client``.

        :param metrics: mapping from a key produced by :meth:`get_metric_key`
            (it must contain the ``sensor`` label) to a statistic summary; the
            ``sum`` found under ``$ / completed / sorted_reduce`` is sent.
        """
        for key, value in six.iteritems(metrics):
            labels = json.loads(key)
            sensor = labels.pop('sensor')
            value = value['$']['completed']['sorted_reduce']['sum']
            # NOTE(review): solomon_client is not set in this class - presumably
            # provided by MeteredTaskExecutor.
            self.solomon_client.send(sensor, value, **labels)

    def mapper(self, row):
        """Flatten ``_rest`` into the row and keep only the known schema columns."""
        # `_rest` may be present but null/empty; only merge it when it has content
        rest = row.get('_rest')
        if rest:
            row.update(rest)
        for field in LOGFELLER_FIELDS_TO_HIDE:
            row.pop(field, None)
        unknown_fields = set(row.keys()) - self.schema_fields
        for field in unknown_fields:
            row.pop(field, None)

        yield row

    def reducer(self, key, rows):
        """Keep one row per reduce key and count the processed group.

        :param key: reduce key of the group (not used).
        :param rows: rows of the group in the operation sort order; an empty
            group produces no output.
        """
        # Only the last row is needed, so walk the group instead of copying it
        latest = None
        for latest in rows:
            pass

        # TODO(mbobrov): decide on proper metrics
        processed_train_refunds_count_key = Executor.get_metric_key(sensor='train_refunds.processed')
        yt.write_statistics({processed_train_refunds_count_key: 1})

        # we simply take the latest version of train refund
        if latest is not None:
            yield latest

    def try_execute(self):
        """Run the whole merge; everything except backup cleanup and metric
        sending happens inside a single YT transaction.
        """
        self.ensure_tables()
        with self.yt_client.Transaction():
            processed_tables = ProcessedTables(self.yt_client, self.processed_table, self.options.src_dir)
            new_tables = processed_tables.new_tables
            if not new_tables:
                LOG.info('No tables to process')
                return
            LOG.info('Found %d new table(s)', len(new_tables))

            # The current refunds table goes first, the new log tables follow it
            source_tables = [self.train_refunds_table]
            source_tables.extend(new_tables)

            create_temp_table(self.yt_client, self.unsorted_table, self.unsorted_schema)
            create_temp_table(self.yt_client, self.sorted_table, self.sorted_schema)

            LOG.info('Collecting all train refunds into temporary table')
            self.yt_client.run_map(
                self.mapper,
                source_tables,
                self.unsorted_table,
                format=yt.JsonFormat(control_attributes_mode="iterator", encoding="utf-8"),
                spec={'data_size_per_job': 2000000, 'max_failed_job_count': 1},
            )

            # Order rows so that, per OrderRefundId, the greatest RefundedAt comes last
            LOG.info('Sorting temporary table')
            self.yt_client.run_sort(
                self.unsorted_table, destination_table=self.unsorted_table, sort_by=['OrderRefundId', 'RefundedAt']
            )

            LOG.info('Executing train refunds deduplication')
            operation = self.yt_client.run_reduce(
                self.reducer,
                self.unsorted_table,
                self.unsorted_table,
                sort_by=['OrderRefundId', 'RefundedAt'],
                reduce_by=['OrderRefundId'],
                format=yt.JsonFormat(control_attributes_mode="iterator", encoding="utf-8"),
                spec={'data_size_per_job': 10000, 'max_failed_job_count': 1},
            )

            operation.wait()
            # Statistics written by the reducer are read from the `custom` section
            metrics = operation.get_job_statistics().get('custom', {})
            LOG.info('Collected metrics %r', metrics)

            LOG.info('Sorting temporary table')
            self.yt_client.run_sort(self.unsorted_table, destination_table=self.sorted_table, sort_by=SORT_BY)

            LOG.info('Replacing old train refunds table with temporary table')
            self.yt_client.move(self.sorted_table, self.train_refunds_table, force=True)

            LOG.info('Dropping unsorted temporary table')
            self.yt_client.remove(self.unsorted_table)

            LOG.info('Writing processed tables')
            processed_tables.write_processed_tables()

            LOG.info('Doing train refunds table backup')
            backup_table(self.yt_client, self.train_refunds_table, self.options.backup_dir)

        KeepLastNCleanupStrategy(self.options.keep_last_tables).clean(self.options.backup_dir, self.yt_client)

        # Reached only when new tables were processed (the early return skips it)
        self.send_metrics(metrics)


def main():
    """Command line entry point: parse options, build the YT client, run the task.

    Default destination and backup directories are derived from the name of the
    user the process runs as.
    """
    logging.config.dictConfig(LOG_CONFIG)
    user = pwd.getpwuid(os.getuid()).pw_name

    # (flag, add_argument keyword arguments), registered in this exact order
    cli_arguments = (
        ('--src-dir', dict(default='//logs/travel-test-train-refunds-cpa-export-log/30min')),
        ('--dst-dir', dict(default=yt.ypath_join('//home/travel', user, 'cpa'))),
        ('--backup-dir', dict(default=yt.ypath_join('//home', 'travel', user, 'cpa', 'backup'))),
        ('--keep-last-tables', dict(type=int, default=100)),
        ('--yt-proxy', dict(default='hahn')),
        ('--yt-token', dict(default=None)),
    )

    parser = ArgumentParser()
    for flag, params in cli_arguments:
        parser.add_argument(flag, **params)

    # The executor adds its own options after the task-specific ones
    Executor.configure(parser)
    cli_args = replace_args_from_env(sys.argv[1:])
    options = parser.parse_args(cli_args)

    # Log the effective options with secrets masked, before they are resolved
    LOG.info('Working with %r', hide_secrets(vars(options)))

    common_labels = {'yt_proxy': options.yt_proxy}
    Executor.resolve_secrets(options, common_labels=common_labels)

    # FIXME: tests with local yt fails on attempt to remove nonexistent reducer pickle file
    client_config = {'clear_local_temp_files': False}
    yt_client = YtClient(options.yt_proxy, options.yt_token, config=client_config)

    task = Executor(yt_client, options, common_labels)
    task.execute()


# Run the task only when the module is executed as a script, not on import.
if __name__ == '__main__':
    main()
