#!/usr/bin/env python
# -*- coding: utf-8 -*-

import datetime
from collections import defaultdict
import os
import time
import traceback

import dateutil.parser
from dateutil.tz import tzlocal
from library.python import resource

from crypta.lib.python.juggler.juggler_helpers import report_event_to_juggler
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib import date_helpers
from crypta.profile.utils import (
    loggers,
    yql_utils,
    yt_utils,
)
from crypta.profile.utils.config import config

# Minimum age, in seconds (15 minutes), a broken table must reach before
# an automatic fix is attempted.
FIX_INTERVAL = 15 * 60


class LogParser(object):
    """Parse raw log tables from a YT directory into per-day parsed tables.

    Input tables are grouped by the part of their name before the first
    ``'T'`` (the output table name), processed with a YQL query and merged
    into the matching table inside ``processed_log_dir``. Inputs whose
    processing failed are copied into ``broken_dir`` and retried later.
    """

    def __init__(self, log_name, log_dir, output_schema, query, title=None, udf_resource_dict=None, udf_url_dict=None, yql_libs=None):
        """Set up paths, the query, the YT client and the logger.

        Args:
            log_name: name of the log; used as the sub-directory name for
                parsed/broken tables and in logger, metric and event names.
            log_dir: YT directory containing the raw input tables.
            output_schema: schema passed to YT when creating output tables.
            query: YQL text appended to the bundled ``/query/input.yql``
                resource. It is later formatted with ``input_tables``,
                ``sampling`` and ``intermediate_table``.
            title: YQL operation title; defaults to
                ``'log_parsing <ClassName>'``.
            udf_resource_dict: forwarded to ``yql_utils.query``.
            udf_url_dict: forwarded to ``yql_utils.query``.
            yql_libs: forwarded to ``yql_utils.query``.

        Creates ``processed_log_dir`` and ``broken_dir`` in YT when they do
        not exist yet.
        """
        self.log_name = log_name
        self.log_dir = log_dir

        if config.environment == 'production':
            self.broken_dir = os.path.join(
                config.BROKEN_PARSED_LOGS_YT_DIRECTORY,
                self.log_name,
            )
            self.processed_log_dir = os.path.join(
                config.PARSED_LOGS_YT_DIRECTORY,
                self.log_name,
            )
            # NOTE(review): ``input`` is only defined in production and is
            # not read anywhere in this class - confirm whether subclasses
            # rely on it and what the non-production value should be.
            self.input = '$full_input'
        else:
            self.broken_dir = os.path.join(
                config.TESTING_BROKEN_PARSED_LOGS_DIRECTORY,
                self.log_name,
            )
            self.processed_log_dir = os.path.join(
                config.TESTING_PARSED_LOGS_DIRECTORY,
                self.log_name,
            )

        self.output_schema = output_schema
        self.query = resource.find("/query/input.yql") + query

        self.title = title or 'log_parsing {}'.format(self.__class__.__name__)
        self.udf_resource_dict = udf_resource_dict
        self.udf_url_dict = udf_url_dict
        self.yql_libs = yql_libs

        self.yt = yt_utils.get_yt_client(pool=config.LOG_PARSER_POOL)
        self.logger, self.logger_file_path = loggers.get_file_logger('{}_parser'.format(log_name))

        for directory in (self.processed_log_dir, self.broken_dir):
            if not self.yt.exists(directory):
                self.yt.create_directory(directory)

    def get_input_tables_by_output_table(self, directory, last_processed_table=None):
        """Group the tables of ``directory`` by the output table they feed.

        Args:
            directory: YT directory to list.
            last_processed_table: when given, only tables whose name compares
                strictly greater than this value are returned.

        Returns:
            ``defaultdict(list)`` mapping an output table path (inside
            ``processed_log_dir``, named after the input name up to the
            first ``'T'``) to the list of full input table paths.
        """
        input_tables_by_output_table = defaultdict(list)

        for table in self.yt.list(directory):
            if last_processed_table is not None and not table > last_processed_table:
                continue

            output_table = os.path.join(self.processed_log_dir, table.split('T')[0])
            input_tables_by_output_table[output_table].append(os.path.join(directory, table))

        return input_tables_by_output_table

    def create_output_tables(self, output_tables):
        """Create every table of ``output_tables`` that does not exist yet."""
        for table in output_tables:
            if self.yt.exists(table):
                continue

            self.yt.create_empty_table(
                table,
                compression='default',
                schema=self.output_schema,
                erasure=True,
            )
            self.logger.info('Created table: {}'.format(table))

    def close_table(self, table_to_close):
        """Finalize a parsed table unless it is already marked ``closed``.

        Sorts the table by ``yandexuid`` and ``timestamp``, sets the
        ``closed`` attribute, applies the configured TTL and reports a
        ``task_end`` metric to Graphite.
        """
        if self.yt.get_attribute(table_to_close, 'closed', False):
            return

        self.yt.run_sort(table_to_close, sort_by=['yandexuid', 'timestamp'])
        self.yt.set_attribute(
            table_to_close,
            'closed',
            True,
        )

        yt_helpers.set_ttl(table_to_close, datetime.timedelta(days=config.NUMBER_OF_DAYS_TO_KEEP_PARSED_LOGS), yt_client=self.yt)

        # Timestamp of the day following the table's date; the metric value
        # is the number of seconds elapsed since then.
        midnight_timestamp = date_helpers.from_date_string_to_timestamp(date_helpers.get_tomorrow(os.path.basename(table_to_close)))

        loggers.send_to_graphite(
            'task_end.LogParser_{}'.format(self.log_name),
            int(time.time()) - midnight_timestamp,
            timestamp=midnight_timestamp,
        )

    def process_tables(self, input_tables, output_table, transaction):
        """Run the YQL query over ``input_tables`` and merge into ``output_table``.

        The query writes into a temporary table which is then merged
        (unordered) together with the current contents of ``output_table``
        back into ``output_table``.

        Args:
            input_tables: full paths of the tables to parse.
            output_table: full path of the existing table to append to.
            transaction: YT transaction passed to ``yql_utils.query``.
        """
        with self.yt.TempTable() as intermediate_table:
            prepared_query = self.query.format(
                input_tables=', '.join(['`{}`'.format(table) for table in input_tables]),
                sampling=config.LOG_PARSING_SAMPLING,
                intermediate_table=intermediate_table,
            )

            yql_utils.query(
                query_string=prepared_query,
                transaction=transaction,
                logger=self.logger,
                yt=self.yt,
                title=self.title,
                udf_resource_dict=self.udf_resource_dict,
                udf_url_dict=self.udf_url_dict,
                yql_libs=self.yql_libs,
            )

            self.yt.run_merge(
                [intermediate_table, output_table],
                output_table,
                mode='unordered',
            )

    def _get_broken_tables(self):
        """Return full paths of all tables in ``broken_dir``.

        Tables are returned grouped by output table, with the groups in
        sorted order so the result is deterministic.
        """
        input_tables_by_output_tables = self.get_input_tables_by_output_table(self.broken_dir)

        all_input_tables = []
        for output_table in sorted(input_tables_by_output_tables):
            all_input_tables.extend(input_tables_by_output_tables[output_table])

        return all_input_tables

    def _get_output_table_for_broken_tables(self):
        """Pick the parsed table that re-processed broken tables go into.

        Uses the newest table of ``processed_log_dir``. When that table is
        already closed the table for the following date is used instead, and
        when the directory is empty a table for today's date is used.
        """
        processed_tables = self.yt.list(self.processed_log_dir, absolute=True)
        if not processed_tables:
            return os.path.join(self.processed_log_dir, date_helpers.get_today_date_string())

        latest_table = max(processed_tables)
        if self.yt.get_attribute(latest_table, 'closed', None):
            # A closed table has been sorted and finalized; do not append to it.
            next_date = date_helpers.get_tomorrow(os.path.basename(latest_table))
            return os.path.join(self.processed_log_dir, next_date)

        return latest_table

    def process_broken_tables(self):
        """Re-process every table in ``broken_dir`` inside its own transaction.

        Successfully processed tables are removed from ``broken_dir``.
        Failures are logged with a traceback and not re-raised.
        """
        with self.yt.Transaction() as transaction:
            broken_tables = self._get_broken_tables()
            if not broken_tables:
                # there is no broken tables
                return

            # Resolved only after the emptiness check so that an empty
            # processed directory is never indexed.
            output_table = self._get_output_table_for_broken_tables()

            self.create_output_tables([output_table])
            try:
                self.process_tables(broken_tables, output_table, transaction)
                for table in broken_tables:
                    self.yt.remove(table)

            except Exception:
                self.logger.error(traceback.format_exc())

    def _filter_broken_tables(self, tables):
        """Select the tables that are due for an automatic fix attempt.

        A table qualifies when it has no truthy ``tried_to_fix`` attribute
        and was created more than ``FIX_INTERVAL`` seconds ago.

        Args:
            tables: full paths of candidate tables.

        Returns:
            List of qualifying table paths, in input order.
        """
        now = time.time()

        tables_to_fix = []

        for table in tables:
            # An explicit default keeps tables that never had the attribute
            # set from failing the lookup.
            if self.yt.get_attribute(table, 'tried_to_fix', None):
                continue

            creation_time = self.yt.get_attribute(
                table,
                'creation_time',
            )

            creation_dt = dateutil.parser.parse(creation_time).astimezone(tzlocal())
            creation_timestamp = time.mktime(creation_dt.timetuple())

            if now - creation_timestamp > FIX_INTERVAL:
                tables_to_fix.append(table)

        return tables_to_fix

    def auto_process_broken_tables(self, transaction):
        """Retry broken tables that are old enough and were not retried yet.

        On success the retried tables are removed from ``broken_dir``. On
        failure they are marked with ``tried_to_fix`` so they are not picked
        up again, and the traceback is logged.

        Args:
            transaction: YT transaction passed on to ``process_tables``.
        """
        self.logger.info('Try to fix broken tables')

        broken_tables = self._filter_broken_tables(self._get_broken_tables())

        self.logger.info('Broken tables: {}'.format(broken_tables))
        if not broken_tables:
            # there is no broken tables
            return

        output_table = self._get_output_table_for_broken_tables()

        self.create_output_tables([output_table])
        try:
            self.process_tables(broken_tables, output_table, transaction)
            for table in broken_tables:
                self.yt.remove(table)

        except Exception:
            # Capture the traceback before issuing further YT calls.
            trace_str = traceback.format_exc()

            for table in broken_tables:
                self.yt.set_attribute(
                    table,
                    'tried_to_fix',
                    True,
                )

            self.logger.error(trace_str)

    def _get_broken_not_processed_tables(self):
        """Return tables of ``broken_dir`` that are due for a fix attempt.

        Applies the same selection as ``_filter_broken_tables`` to the full
        listing of ``broken_dir``.
        """
        return self._filter_broken_tables(self.yt.list(self.broken_dir, absolute=True))

    def run(self):
        """Run one parsing pass and report the outcome to Juggler.

        Inside a single YT transaction:

        1. retry eligible broken tables;
        2. process input tables newer than the ``last_processed_table``
           attribute of ``processed_log_dir``, copying the inputs of failed
           groups into ``broken_dir``;
        3. store the new ``last_processed_table`` and close the table dated
           one day before the newest output table, if it exists;
        4. collect broken tables already marked ``tried_to_fix``.

        The event status is ``WARN`` when step 4 found tables, else ``OK``.
        """
        with self.yt.Transaction() as transaction:
            self.auto_process_broken_tables(transaction)

            last_processed_table = self.yt.get_attribute(
                self.processed_log_dir,
                'last_processed_table',
                None,
            )

            input_tables_by_output_tables = self.get_input_tables_by_output_table(
                self.log_dir,
                last_processed_table=last_processed_table,
            )
            output_tables = sorted(input_tables_by_output_tables)
            self.create_output_tables(output_tables)

            if input_tables_by_output_tables:
                latest_output_table = output_tables[-1]
                last_table = os.path.basename(max(input_tables_by_output_tables[latest_output_table]))
                self.logger.info('Last table: {}'.format(os.path.join(self.log_dir, last_table)))

                for output_table in output_tables:
                    input_tables = input_tables_by_output_tables[output_table]
                    try:
                        self.process_tables(input_tables, output_table, transaction)
                    except Exception:
                        self.logger.error(traceback.format_exc())

                        # copy not processed tables
                        for table in input_tables:
                            self.yt.copy(
                                table,
                                os.path.join(self.broken_dir, os.path.basename(table)),
                                recursive=True,
                            )

                self.yt.set_attribute(
                    self.processed_log_dir,
                    'last_processed_table',
                    last_table,
                )

                date_to_close = date_helpers.get_yesterday(os.path.basename(latest_output_table))
                table_to_close = os.path.join(self.processed_log_dir, date_to_close)
                if self.yt.exists(table_to_close):
                    self.close_table(table_to_close)

            broken_tables = list(self.yt.search(
                self.broken_dir,
                node_type=['table'],
                object_filter=lambda node: node.attributes.get('tried_to_fix'),
                attributes=['tried_to_fix'],
            ))

        if broken_tables:
            status = 'WARN'
            description = 'Some tables has failed to process: {}'.format(', '.join(broken_tables))
        else:
            status = 'OK'
            description = 'OK'

        report_event_to_juggler(
            status=status,
            service='{}_log_parsing'.format(self.log_name),
            host=config.CRYPTA_PROFILE_JUGGLER_HOST,
            description=description,
            tags=['log_parsing'],
            logger=self.logger,
        )
