# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from abc import ABCMeta, abstractmethod
from collections import defaultdict

import sandbox.sandboxsdk.environments as sdk_environments
from sandbox import sdk2
from sandbox.projects.common import solomon

from sandbox.projects.avia.base import AviaBaseTask
from sandbox.projects.avia.lib import logs
from sandbox.projects.avia.lib.datetime_helpers import (
    get_utc_now, _dt_to_unixtime_utc, _round_unixtime, _unixtime_to_msk_dt,
    _dt_to_string, _timestamp_to_utc_string,
)
from sandbox.projects.avia.lib.yt_helpers import (
    yt_read_tables, _get_timestamp_from_table_path, _parse_timestamp, _get_tables, _get_table_time_delta
)

# Module-level logger shared by all functions and task methods in this file.
logger = logging.getLogger(__name__)

# Default UTM sources that UtmSourceRule reports under their own label; any
# other source is reported as 'other'. Wizard is an exception: sources starting
# with 'wizard' or 'unisearch' are always reported as 'wizard', so they are not
# listed here.
IMPORTANT_UTM_SOURCES = ('rasp', 'sovetnik', 'yamain', 'ohm_google', 'doubletrade')  # Wizard is an exception
# Important UTM sources for the legacy 'avia-json-redir-log' log: none, so every
# non-wizard source of that log is reported as 'other'.
OLD_UTM_SOURCES = ()
# Default partner codes that PartnerRule reports under their own label (matched
# against the 'code' column of the partners table); others become 'other'.
IMPORTANT_PARTNERS = {
    'aeroflot',
    'tinkoff1',
    'megotravel',
    'ozon',
    's_seven',
    'utair',
    'onetwotrip',
    'kiwi',
    'citytravel',
    'trip_ru',
    'biletix',
    'ticketsru',
    'kupibilet',
    'superkassa'
}
# Default airline codes that AirlineRule reports under their own label.
IMPORTANT_AIRLINES = ('SU', 'S7', 'DP', 'UT', 'U6', '5N', 'WZ', 'A4', 'TK', 'HY')
# Airline code substitutions applied by AirlineRule before matching: a flight
# whose code is a key here is counted under the mapped code (FV -> SU).
AIRLINES_FIXLIST = {
    'FV': 'SU',
}


class StatCalculator(object):
    """Aggregate log records into per-time-window label counters.

    Every record is assigned to a bucket (its timestamp rounded by
    ``_round_unixtime`` to the aggregation window). Each registered rule
    yields ``(label_key, label_value)`` pairs for the record, and those pairs
    are counted per bucket. Counters accumulate in ``self.stat`` across
    successive ``calc`` calls.
    """

    def __init__(self, time_field='unixtime', time_window=30):
        """
        @param time_field: field with unixtime
        @param time_window: window size for aggregation in minutes
        """
        self._time_field = time_field
        # Stored in seconds, the unit of the record timestamps.
        self._time_window = time_window * 60
        self._rules = []
        # {bucket_timestamp: {label_key: {label_value: count}}}
        self.stat = defaultdict(dict)

    def add_rule(self, rule):
        """Register a rule (IStatRule implementation) applied to every record."""
        self._rules.append(rule)

    def calc(self, records, start_timestamp=None, end_timestamp=None):
        """Count ``records`` into ``self.stat``.

        @param records: iterable of dicts holding the time field plus the
            ``row_index`` and ``table_name`` keys (presumably added by
            ``yt_read_tables(..., add_index=True)`` -- TODO confirm)
        @param start_timestamp: inclusive lower bound for bucket timestamps;
            None means unbounded
        @param end_timestamp: exclusive upper bound for bucket timestamps;
            None means unbounded
        @return: ``(stat, last_index, last_table)``. ``(last_table, last_index)``
            is the furthest (table, row) position among records inside the
            bounds; ``(None, 0)`` when no record was inside them.
        """
        last_index = 0
        last_table = None
        for record in records:
            timestamp = _round_unixtime(int(record[self._time_field]), self._time_window)
            if start_timestamp is not None and timestamp < start_timestamp:
                continue
            if end_timestamp is not None and timestamp >= end_timestamp:
                continue

            for rule in self._rules:
                for label_key, label_value in rule.get_keys(record):
                    counters = self.stat[timestamp].get(label_key)
                    if counters is None:
                        counters = self.stat[timestamp][label_key] = rule.default_label_dict()
                    counters[label_value] += 1
                    if 'total' in counters:
                        counters['total'] += 1

            # Track the furthest position once per record. The comparison is
            # lexicographic: a later table wins whatever its row index, and
            # within the same table a larger row index wins.
            index = record['row_index']
            table = record['table_name']
            if last_table is None or (table, index) > (last_table, last_index):
                last_table, last_index = table, index

        return self.stat, last_index, last_table

    def get_fields(self):
        """Return the set of record fields needed by the time field and all rules."""
        fields = {self._time_field}
        for rule in self._rules:
            fields.update(rule.get_fields())

        return fields


def _get_table_row_number(table):
    """Return the number of rows selected by the first range of ``table``.

    ``table`` is a table path object exposing ``ranges`` (presumably a YT rich
    path); the result is ``upper_limit.row_index - lower_limit.row_index`` of
    its first range.
    """
    first_range = table.ranges[0]
    upper_row = first_range['upper_limit']['row_index']
    lower_row = first_range['lower_limit']['row_index']
    return upper_row - lower_row


def _process_table_group(yt_client, calculator, tables, start_timestamp, end_timestamp,
                         last_table, offset):
    """Feed one group of tables to ``calculator``; return ``(stat, last_table, offset)``.

    ``calculator.calc`` reports no table when none of the group's records fell
    inside the bounds. In that case the position known from earlier groups is
    kept instead of being reset.
    """
    stat, group_offset, group_table = calculator.calc(
        yt_read_tables(yt_client, tables, add_index=True),
        start_timestamp=start_timestamp, end_timestamp=end_timestamp
    )
    if group_table is not None:
        last_table, offset = group_table, group_offset
    return stat, last_table, offset


def get_stat(yt_client, yt_path, start_dt, end_dt, start_timestamp, end_timestamp,
             calculator, last_offset=None, batch_size=10000):
    """Aggregate the log tables under ``yt_path`` that overlap ``[start_dt, end_dt)``.

    Tables are processed in path order, in groups of at least ``batch_size``
    rows (the last group may be smaller).

    @param calculator: StatCalculator; its counters accumulate across groups
    @param last_offset: forwarded to ``_get_tables``
    @param start_timestamp: inclusive bucket bound forwarded to ``calculator.calc``
    @param end_timestamp: exclusive bucket bound forwarded to ``calculator.calc``
    @param batch_size: minimum number of rows per group
    @return: ``(stat, last_table, offset)``: the calculator's counters, the
        latest table that contributed a counted record and the row index in it.
        ``({}, None, None)`` when no table was found.
    """
    table_time_delta = _get_table_time_delta(yt_client, yt_path)
    logger.info('Table delta = %s', table_time_delta)

    tables = sorted(
        _get_tables(
            yt_client, yt_path,
            start_dt, end_dt,
            columns=list(calculator.get_fields()),
            last_offset=last_offset,
            table_time_delta=table_time_delta,
            logger=logger
        ),
        key=str,
    )

    logger.info('Tables: %r', tables)

    if not tables:
        logger.info('Nothing to do, abort')
        # Keep the arity of the normal return: callers unpack three values.
        return {}, None, None

    stat = {}
    last_table = None
    offset = None
    current_tables = []
    record_number = 0
    table_dt = None
    for table in tables:
        table_dt = _get_timestamp_from_table_path(table)
        # Presumably a table holds up to table_time_delta of data after its
        # timestamp, so a table starting that much before start_dt may still
        # contain in-range records.
        if not (start_dt - table_time_delta <= table_dt < end_dt):
            continue
        current_tables.append(table)
        record_number += _get_table_row_number(table)
        if record_number >= batch_size:
            logger.info('Processing group table_dt=%s, current_tables=%s, record_number=%s',
                        table_dt, current_tables, record_number)
            stat, last_table, offset = _process_table_group(
                yt_client, calculator, current_tables, start_timestamp, end_timestamp,
                last_table, offset
            )
            current_tables = []
            record_number = 0

    if record_number > 0:
        logger.info('Processing group table_dt=%s, current_tables=%s, record_number=%s',
                    table_dt, current_tables, record_number)
        stat, last_table, offset = _process_table_group(
            yt_client, calculator, current_tables, start_timestamp, end_timestamp,
            last_table, offset
        )
    logger.info('stat.len %s, offset=%s, last_table=%s',
                len(stat), offset, last_table)

    return stat, last_table, offset


class IStatRule(object):
    """Interface of a rule that turns a log record into counter labels.

    StatCalculator asks every rule for the ``(label_key, label_value)`` pairs
    of a record, for the record fields the rule reads, and for a fresh
    counter dict whenever a new ``label_key`` appears in a time bucket.
    """
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_keys(self, record):
        """Return an iterable of ``(label_key, label_value)`` pairs for ``record``."""

    @abstractmethod
    def get_fields(self):
        """Return the record field names this rule reads."""

    @abstractmethod
    def default_label_dict(self):
        """Return a new ``{label_value: 0}`` dict.

        It must contain every label_value that get_keys can yield, because
        StatCalculator increments those entries in place.
        """


class UtmSourceRule(IStatRule):
    """Label records by traffic source taken from ``utm_source``/``utm_medium``."""

    def __init__(self, important_utm_sources):
        # Sources reported under their own label; anything else is 'other'.
        self.important_utm_sources = frozenset(important_utm_sources)

    def get_keys(self, record):
        source_type = self.get_type_by_utm_source(record.get('utm_source'), record.get('utm_medium'))
        yield ('utm_source', source_type)

    def get_fields(self):
        return ['utm_source', 'utm_medium']

    def default_label_dict(self):
        # Important sources plus the fixed 'other', 'wizard' and 'total' counters.
        labels = self.important_utm_sources.union(('other', 'wizard', 'total'))
        return dict.fromkeys(labels, 0)

    def get_type_by_utm_source(self, utm_source, utm_medium):
        """Map a (utm_source, utm_medium) pair to a label.

        Missing sources become 'other'. Sources prefixed with 'wizard' or
        'unisearch' become 'wizard'. Otherwise '<source>_<medium>' is returned
        if it is important, then the bare source if that is important, and
        'other' as the last resort.
        """
        if utm_source is None:
            return 'other'
        if utm_source.startswith(('wizard', 'unisearch')):
            return 'wizard'

        combined = '{}_{}'.format(utm_source, utm_medium)
        for candidate in (combined, utm_source):
            if candidate in self.important_utm_sources:
                return candidate
        return 'other'


class PartnerRule(IStatRule):
    """Label records by partner code, resolved from ``billing_client_id``."""

    def __init__(self, yt_client, partners_table, partner_codes):
        """
        @param yt_client: client used to read ``partners_table``
        @param partners_table: table whose rows have ``billing_client_id`` and ``code``
        @param partner_codes: partner codes that get their own label
        """
        wanted_codes = set(partner_codes)
        self.partners = {}
        for row in yt_client.read_table(partners_table):
            if row['code'] in wanted_codes:
                self.partners[row['billing_client_id']] = row['code']

    def get_keys(self, record):
        yield ('partner', self.get_partner_code(record.get('billing_client_id')))

    def get_partner_code(self, billing_client_id):
        """Return the wanted partner code for ``billing_client_id``, else 'other'."""
        code = self.partners.get(billing_client_id)
        return code if code else 'other'

    def get_fields(self):
        return ['billing_client_id']

    def default_label_dict(self):
        labels = dict.fromkeys(self.partners.values(), 0)
        labels['other'] = 0
        return labels


class AirlineRule(IStatRule):
    """Label records by the airlines found in their flight-number fields.

    Each field is a ';'-separated list whose entries start with the airline
    code followed by a space (format inferred from the parsing below, e.g.
    ``'SU 1234;SU 5678'``).
    """
    FIELDS = ('forward_numbers', 'backward_numbers')

    def __init__(self, airline_codes, fix_list):
        """
        @param airline_codes: codes reported under their own label; any other
            code is reported as 'other'
        @param fix_list: mapping ``code -> code to count it under``
        """
        self.codes = frozenset(airline_codes)
        self.fix_list = fix_list

    def get_keys(self, record):
        for airline in self.get_airlines(record):
            yield ('airline', airline)

    def get_airlines(self, record):
        """Return the set of airline labels present in ``record``.

        Missing, null or empty fields contribute nothing; for example, a
        record with no backward flights presumably has an empty
        ``backward_numbers``. Previously such a field added a spurious 'other'.
        """
        airlines = set()
        for field in self.FIELDS:
            # ``or ''`` also covers a column that is present but null.
            numbers = record.get(field) or ''
            for flight in numbers.split(';'):
                flight = flight.strip()
                if not flight:
                    # ''.split(';') == [''], so empty entries must be skipped.
                    continue
                airline = flight.split(' ')[0]
                airline = self.fix_list.get(airline, airline)
                airlines.add(airline if airline in self.codes else 'other')
        return airlines

    def get_fields(self):
        return self.FIELDS

    def default_label_dict(self):
        return {code: 0 for code in self.codes | {'other'}}


class SendAviaRedirectsToSolomon(AviaBaseTask):
    """
    Send Avia redirects to solomon.

    Aggregates the selected YT log per time window (see StatCalculator) and
    pushes the counters to Solomon. Progress is stored as attributes of
    ``solomon_timestamp_path``, keyed by the log name: the newest window
    timestamp, the row offset and the last table timestamp.
    """
    _yt_client = None

    # Maximum number of sensors sent in a single Solomon push.
    _SOLOMON_BATCH_SIZE = 1000

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('Map reduce settings') as mr_block:
            yt_cluster = sdk2.parameters.String('MapReduce cluster', default='hahn', required=True)
            yt_user = sdk2.parameters.String('MapReduce user', required=True)
            yt_token_vault = sdk2.parameters.String('Token vault name', required=True, default='YT_TOKEN')

            yt_path = sdk2.parameters.String('Directory', required=True,
                                             default='//home/logfeller/logs/avia-json-redir-log/stream/5min')
            solomon_timestamp_path = sdk2.parameters.String('Solomon table for timestamp storage', required=True,
                                                            default='//home/avia/dev/kateov/solomon_timestamps')
            partners_table = sdk2.parameters.String('Partners list table', required=True,
                                                    default='//home/rasp/reference/partner')

        with sdk2.parameters.Group('Solomon settings') as solomon_settings:
            solomon_project = sdk2.parameters.String('Solomon project', required=True, default='avia')
            solomon_cluster = sdk2.parameters.String('Solomon cluster', required=True, default='yt')
            solomon_service = sdk2.parameters.String('Solomon service', required=True, default='redirects')

        with sdk2.parameters.Group('Settings') as date_block:
            with sdk2.parameters.RadioGroup('Log type') as log_name:
                log_name.values['avia-json-redir-log-stream'] = log_name.Value('avia-json-redir-log-stream',
                                                                               default=True)
                log_name.values['rasp-popular-flights-log'] = log_name.Value('rasp-popular-flights-log')
                log_name.values['avia-json-redir-log'] = log_name.Value('avia-json-redir-log')

            start_time = sdk2.parameters.String('Start date (or time)', required=False,
                                                default='2019-02-25T00:00:00')
            end_time = sdk2.parameters.String('End date (default now)', required=False)

            offset = sdk2.parameters.Integer('Offset in table', required=False, default=None)
            time_window = sdk2.parameters.Integer('Time window', required=True, default=20)

            important_partners = sdk2.parameters.String('Partners', required=False,
                                                        default=' '.join(IMPORTANT_PARTNERS))
            important_utm_sources = sdk2.parameters.String('UTM sources', required=False,
                                                           default=' '.join(IMPORTANT_UTM_SOURCES))
            important_airlines = sdk2.parameters.String('Airlines', required=False,
                                                        default=' '.join(IMPORTANT_AIRLINES))

        with sdk2.parameters.Group('Debug settings') as debug_settings:
            debug_run = sdk2.parameters.Bool('Debug run', default=False, required=True)

    class Requirements(sdk2.Requirements):
        # https://wiki.yandex-team.ru/sandbox/clients/#client-tags-multislot
        cores = 1  # exactly 1 core
        ram = 8192  # 8GiB or less

        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

        environments = [sdk_environments.PipEnvironment('yandex-yt', version='0.10.8'),
                        sdk_environments.PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
                        sdk_environments.PipEnvironment('raven')]

    def _get_yt_client(self):
        """Return the YtClient for the configured cluster, creating and caching it on first use."""
        if self._yt_client is None:
            import yt.wrapper
            self._yt_client = yt.wrapper.YtClient(
                proxy=self.Parameters.yt_cluster,
                token=sdk2.Vault.data(self.Parameters.yt_user, self.Parameters.yt_token_vault or 'YT_TOKEN')
            )
            self._yt_json_format = yt.wrapper.JsonFormat()
        return self._yt_client

    def send_data_to_solomon(self, data):
        """Push ``{timestamp: {label_name: {label_value: count}}}`` to Solomon.

        Sensors are sent in batches of at most ``_SOLOMON_BATCH_SIZE``.
        """
        common_labels = {
            'project': self.Parameters.solomon_project,
            'cluster': self.Parameters.solomon_cluster,
            'service': self.Parameters.solomon_service,
        }
        logger.info(common_labels)
        sensors = []
        for timestamp, labels in data.items():
            ts = _timestamp_to_utc_string(timestamp)
            for label_name, values in labels.items():
                for label_value, series_value in values.items():
                    sensors.append(
                        {
                            'ts': ts,
                            'labels': {'sensor': 'redirects', label_name: label_value},
                            'value': series_value,
                        }
                    )
                    if len(sensors) >= self._SOLOMON_BATCH_SIZE:
                        self._push_sensors(common_labels, sensors)
                        sensors = []
        if sensors:
            self._push_sensors(common_labels, sensors)

    def _push_sensors(self, common_labels, sensors):
        """Send one batch of sensors to Solomon."""
        logger.info('Pushing %s sensors to Solomon', len(sensors))
        # NOTE(review): solomon_token is not defined in this class; presumably
        # it is provided by AviaBaseTask.
        solomon.push_to_solomon_v2(self.solomon_token, common_labels, sensors, common_labels=())

    def get_max_timestamp(self, data):
        """Return the largest bucket timestamp (key) of ``data``, or None when it is empty."""
        return max(data) if data else None

    def _get_saved_timestamp(self):
        """Return the saved window timestamp for this log, or None if it cannot be read."""
        try:
            return _parse_timestamp(
                self._get_yt_client().get_attribute(self.Parameters.solomon_timestamp_path, self.Parameters.log_name)
            )
        except Exception as e:
            # Best effort: no saved state yet is a normal situation.
            logger.warning('Cannot read saved timestamp %s: %s', self.Parameters.log_name, e)
            return None

    def _get_saved_offset(self):
        """Return the saved row offset for this log, or None if it cannot be read."""
        attribute = '{}-offset'.format(self.Parameters.log_name)
        try:
            return self._get_yt_client().get_attribute(self.Parameters.solomon_timestamp_path, attribute)
        except Exception as e:
            # Best effort: no saved state yet is a normal situation.
            logger.warning('Cannot read saved offset %s: %s', attribute, e)
            return None

    @staticmethod
    def _split_list_parameter(value, default):
        """Split a space-separated parameter; use ``default`` when it is empty or None."""
        # str.split() without a separator yields [] for blank input, so the
        # fallback actually triggers (''.split(' ') would give ['']).
        return (value or '').split() or default

    def _build_rules(self, yt_client):
        """Return the stat rules for the selected log type, or None if it is unknown.

        Only the selected log type's rules are built, so the partners table is
        read only when it is needed.
        """
        log_name = self.Parameters.log_name
        if log_name == 'avia-json-redir-log-stream':
            return [
                PartnerRule(yt_client, self.Parameters.partners_table,
                            self._split_list_parameter(self.Parameters.important_partners, IMPORTANT_PARTNERS)),
                UtmSourceRule(self._split_list_parameter(self.Parameters.important_utm_sources,
                                                         IMPORTANT_UTM_SOURCES)),
            ]
        if log_name == 'rasp-popular-flights-log':
            return [
                AirlineRule(self._split_list_parameter(self.Parameters.important_airlines, IMPORTANT_AIRLINES),
                            AIRLINES_FIXLIST),
            ]
        if log_name == 'avia-json-redir-log':
            return [UtmSourceRule(OLD_UTM_SOURCES)]
        return None

    def _save_progress(self, saved_dt, stat, max_processed_table, new_offset):
        """Store the newest window timestamp, row offset and last table timestamp.

        Nothing is written unless the new timestamp is newer than ``saved_dt``.
        """
        new_dt = _unixtime_to_msk_dt(self.get_max_timestamp(stat))
        last_table_timestamp = None
        if max_processed_table is not None:
            last_table_timestamp = _get_timestamp_from_table_path(max_processed_table)
        logger.info('saved_dt=%s, new_dt=%s, new_offset=%s, last_table_timestamp=%s',
                    saved_dt, new_dt, new_offset, last_table_timestamp)

        if saved_dt is not None and not saved_dt < new_dt:
            logger.info('Saved timestamp is not older than %s, attributes kept', new_dt)
            return

        yt_client = self._get_yt_client()
        path = self.Parameters.solomon_timestamp_path
        log_name = self.Parameters.log_name
        yt_client.set_attribute(path, log_name, _dt_to_string(new_dt))
        yt_client.set_attribute(path, '{}-offset'.format(log_name), new_offset)
        if last_table_timestamp:
            yt_client.set_attribute(path, '{}-table'.format(log_name), _dt_to_string(last_table_timestamp))

        logger.info('New timestamp: %s', _dt_to_string(new_dt))
        logger.info('New offset: %s', new_offset)

    def on_execute(self):
        logs.configure_logging(logs.get_sentry_dsn(self))
        logger.info('Start')

        yt_client = self._get_yt_client()

        if not yt_client.exists(self.Parameters.solomon_timestamp_path):
            yt_client.create('map_node', self.Parameters.solomon_timestamp_path, recursive=True)
        saved_dt = self._get_saved_timestamp()
        saved_offset = self._get_saved_offset()

        # NOTE(review): start_time has a non-empty default, so saved_dt is only
        # used when the parameter is explicitly cleared.
        start_dt = _parse_timestamp(self.Parameters.start_time) or saved_dt

        end_dt = _parse_timestamp(self.Parameters.end_time) or get_utc_now()

        # time_window is in minutes; StatCalculator buckets records by
        # time_window * 60 seconds. The bounds use the same unit so that they
        # align with bucket boundaries.
        window_seconds = self.Parameters.time_window * 60
        start_timestamp = _round_unixtime(_dt_to_unixtime_utc(start_dt), window_seconds)
        end_timestamp = _round_unixtime(_dt_to_unixtime_utc(end_dt), window_seconds)
        logger.info('saved_dt=%s, start_dt=%s, end_dt: %s, start_timestamp=%s, end_timestamp=%s',
                    saved_dt, start_dt, end_dt, start_timestamp, end_timestamp)
        logger.info('Last saved offset: %s', saved_offset)

        rules = self._build_rules(yt_client)
        if not rules:
            logger.error('No rules for %s. Abort', self.Parameters.log_name)
            return

        stat_calculator = StatCalculator(time_window=self.Parameters.time_window)
        for rule in rules:
            stat_calculator.add_rule(rule)

        if self.Parameters.offset is not None:
            last_offset = self.Parameters.offset
        else:
            last_offset = saved_offset or 0

        stat, max_processed_table, new_offset = get_stat(
            yt_client, self.Parameters.yt_path, start_dt, end_dt, start_timestamp, end_timestamp,
            calculator=stat_calculator,
            last_offset=last_offset
        )

        if stat:
            logger.info('Sending to Solomon')
            logger.info(stat)
            self.send_data_to_solomon(stat)
            self._save_progress(saved_dt, stat, max_processed_table, new_offset)
        else:
            logger.info('No data processed, no attributes saved')

        logger.info('Done')
