# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import time
from collections import defaultdict, Counter
from datetime import datetime

from sandbox import sdk2
from sandbox.projects.common import binary_task
from sandbox.projects.common import solomon

from sandbox.projects.avia.base import AviaBaseTask
from sandbox.projects.avia.lib import logs
from sandbox.projects.avia.lib.yt_helpers import yt_read_tables, last_logs_tables

logger = logging.getLogger(__name__)


class PartnerMetricsCalculator(object):
    """Compute per-partner redirect/sales metrics bucketed by time.

    Redirect log rows are joined with confirmed orders (redirect ``MARKER`` ==
    order ``label``) and with the partners reference table (redirect
    ``BILLING_ORDER_ID`` == partner ``billing_order_id``).
    The joined rows are grouped into ``discretization``-second time buckets and
    aggregated per partner. Partners with too few redirects are folded into an
    ``'other'`` series. Every bucket also gets a ``'total'`` series.
    """

    # Format of the ISO_EVENTTIME column in the redirect logs.
    _EVENTTIME_FORMAT = '%Y-%m-%d %H:%M:%S'
    _SECONDS_PER_DAY = 24 * 60 * 60
    # A partner gets its own series only if it has at least this many redirects
    # per day of observed data; the rest are reported under 'other'.
    _MIN_PARTNER_ROWS_PER_DAY = 1000
    # Order amounts are divided by this to obtain "turnover".
    # NOTE(review): presumably strips a 20% VAT — confirm.
    _TURNOVER_DIVISOR = 1.2

    def __init__(self, yt_client, discretization, time_window_in_days, orders_table, redirects_folder, partners_table):
        """
        :param yt_client: YT client used for all reads.
        :param discretization: time bucket width in seconds.
        :param time_window_in_days: how many days of orders/redirect logs to read.
        :param orders_table: YT path of the orders table.
        :param redirects_folder: YT directory with daily redirect log tables.
        :param partners_table: YT path of the partners reference table.
        """
        self.yt_client = yt_client
        self.discretization = discretization
        self.time_window_in_days = time_window_in_days
        self.orders_table = orders_table
        self.redirects_folder = redirects_folder
        self.partners_table = partners_table

    def get_stats(self):
        """Read all inputs and return ``{bucket_ts: {series: {metric: value}}}``.

        ``series`` is an important partner code, ``'other'`` or ``'total'``.
        Metrics are ``redirects``, ``sales``, ``turnover``, ``conversion`` and ``money``.
        """
        # Kept as an int: it is used as a YT key bound in _get_orders.
        start_timestamp = int(time.time()) - self.time_window_in_days * self._SECONDS_PER_DAY
        orders = self._get_orders(self.orders_table, start_timestamp)
        redirects = self._get_redirects(self.redirects_folder)
        partners = self._get_partners(self.partners_table)

        combined_stats = self._enrich_redirects(redirects, orders, partners)
        stats_by_time_and_partner = self._group_by_time_and_partner(combined_stats)

        return self._get_stats(stats_by_time_and_partner)

    def _get_orders(self, table, start_timestamp):
        """Return confirmed orders keyed by their ``label``.

        ``lower_key=[start_timestamp]`` bounds the read by the table's first key
        column. This assumes that column is a unix timestamp (TODO confirm).
        If several confirmed orders share a label, the last one read wins.
        """
        columns = ['partner_order_id', 'partner_name', 'order_amount_rub', 'profit_amount_ex_tax', 'label', 'status']
        path = self.yt_client.TablePath(table, columns=columns, lower_key=[start_timestamp])
        return {
            order['label']: order
            for order in self.yt_client.read_table(path)
            if order['status'] == 'confirmed'
        }

    def _get_redirects(self, directory):
        """Lazily yield unfiltered redirects of the 'ru' national version.

        The tables come from ``last_logs_tables`` over ``time_window_in_days``.
        Presumably these are the daily tables for that many last days; confirm.
        """
        tables = last_logs_tables(self.yt_client, directory, self.time_window_in_days)

        redirects = yt_read_tables(
            self.yt_client, tables,
            add_index=True,
            columns=['MARKER', 'FILTER', 'NATIONAL_VERSION', 'ISO_EVENTTIME', 'PRICE', 'BILLING_ORDER_ID']
        )
        for redirect in redirects:
            if not redirect['FILTER'] and redirect['NATIONAL_VERSION'] == 'ru':
                yield redirect

    def _get_partners(self, table):
        """Return a mapping ``billing_order_id -> partner code``."""
        columns = ['billing_order_id', 'code']
        path = self.yt_client.TablePath(table, columns=columns)
        partners = self.yt_client.read_table(path)
        return {partner['billing_order_id']: partner['code'] for partner in partners}

    @staticmethod
    def _enrich_redirects(redirects, orders, partners):
        """Yield each redirect merged with ``partner_code`` and its matching order's columns.

        ``partner_code`` is None when the billing order id is unknown. Order columns
        are simply absent when the redirect's MARKER matches no confirmed order.
        """
        for row in redirects:
            order = orders.get(row['MARKER'], {})
            yield dict(row, partner_code=partners.get(row['BILLING_ORDER_ID']), **order)

    @classmethod
    def _get_important_partners(cls, rows):
        """Return codes of partners that have enough rows to be reported separately.

        The threshold is ``_MIN_PARTNER_ROWS_PER_DAY`` times the number of days
        between the earliest and the latest ISO_EVENTTIME in ``rows``. Falsy
        partner codes (e.g. None) are never important. Empty ``rows`` give an
        empty set.
        """
        if not rows:
            return set()
        # ISO_EVENTTIME strings are fixed-width, so lexicographic min/max is chronological.
        first_row_dt = datetime.strptime(min(row['ISO_EVENTTIME'] for row in rows), cls._EVENTTIME_FORMAT)
        last_row_dt = datetime.strptime(max(row['ISO_EVENTTIME'] for row in rows), cls._EVENTTIME_FORMAT)
        span_seconds = (last_row_dt - first_row_dt).total_seconds()
        min_partner_rows_count = cls._MIN_PARTNER_ROWS_PER_DAY * span_seconds / float(cls._SECONDS_PER_DAY)
        partners_counter = Counter(row['partner_code'] for row in rows)
        return {
            partner_code for partner_code, cnt in partners_counter.items()
            if partner_code and cnt >= min_partner_rows_count
        }

    def _group_by_time_and_partner(self, rows):
        """Group rows into time buckets and, inside each bucket, by partner series.

        Each row's ISO_EVENTTIME is parsed as a naive datetime and converted to
        seconds since 1970-01-01. The result is stored in ``row['timestamp']``
        (rows are mutated). The bucket key is that value rounded DOWN to a
        multiple of ``self.discretization``.
        Each row lands in its partner's list (or ``'other'``) and in ``'total'``.
        """
        rows = list(rows)
        important_partners = self._get_important_partners(rows)
        logger.info('Important partners %r', important_partners)

        epoch = datetime(1970, 1, 1)
        by_time = defaultdict(lambda: defaultdict(list))
        for row in rows:
            partner = row['partner_code'] if row['partner_code'] in important_partners else 'other'
            row['timestamp'] = (datetime.strptime(row['ISO_EVENTTIME'], self._EVENTTIME_FORMAT) - epoch).total_seconds()
            event_time = row['timestamp'] // self.discretization * self.discretization
            by_time[event_time][partner].append(row)
            by_time[event_time]['total'].append(row)
        return by_time

    def _get_stats(self, groups):
        """Map ``{bucket: {series: rows}}`` to ``{bucket: {series: metrics}}``."""
        return {
            eventtime: {
                partner: self._get_stats_for_partner(partner_rows)
                for partner, partner_rows in partners.items()
            }
            for eventtime, partners in groups.items()
        }

    @classmethod
    def _get_stats_for_partner(cls, partner):
        """Aggregate a list of enriched redirect rows into a metrics dict.

        :param partner: iterable of rows produced by ``_enrich_redirects``.
        :returns: dict with ``redirects`` (row count), ``sales`` (distinct
            partner_order_id), ``turnover`` (order amount / ``_TURNOVER_DIVISOR``),
            ``conversion`` (sales per redirect, in percent, 0 when no redirects)
            and ``money`` (CPA profit + CPC price).
        """
        money_cpa = 0
        money_cpc = 0
        redirects = 0
        sales = set()
        order_amount_rub = 0
        for row in partner:
            redirects += 1
            # Order columns are missing for redirects without an order. A null YT
            # value would arrive as None, so treat both cases as 0.
            money_cpa += row.get('profit_amount_ex_tax') or 0
            money_cpc += row['PRICE']
            if 'partner_order_id' in row:
                sales.add(row['partner_order_id'])
            order_amount_rub += row.get('order_amount_rub') or 0

        return {
            'redirects': redirects,
            'sales': len(sales),
            'turnover': order_amount_rub / cls._TURNOVER_DIVISOR,
            'conversion': len(sales) / redirects * 100 if redirects else 0,
            'money': money_cpa + money_cpc
        }


class SendAviaPartnerMetricsToSolomon(binary_task.LastBinaryTaskRelease, AviaBaseTask):
    """
    Send Avia partner metrics to solomon.

    Reads orders, redirect logs and the partners table from YT, aggregates them
    with PartnerMetricsCalculator and pushes the resulting sensors to Solomon.
    """
    _yt_client = None
    # Accumulated sensors are flushed once at least this many are buffered. The
    # check happens after each time bucket, so a batch may slightly exceed it.
    _SOLOMON_BATCH_SIZE = 1000

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('Map reduce settings') as mr_block:
            yt_cluster = sdk2.parameters.String('MapReduce cluster', default='hahn', required=True)
            yt_token = sdk2.parameters.YavSecret('yt_token', required=True, default='sec-01dfxmszhq27tk66hm107d0ycd')
            partners_table = sdk2.parameters.String('Partners list table', required=True,
                                                    default='//home/rasp/reference/partner')
            orders_table = sdk2.parameters.String('Orders table', required=True,
                                                  default='//home/travel/prod/cpa/avia/orders')
            redirects_folder = sdk2.parameters.String('Redirects folder', required=True,
                                                      default='//home/avia/logs/avia-redir-balance-by-day-log')

        with sdk2.parameters.Group('Solomon settings') as solomon_settings:
            solomon_project = sdk2.parameters.String('Solomon project', required=True, default='avia')
            solomon_cluster = sdk2.parameters.String('Solomon cluster', required=True, default='yt')
            solomon_service = sdk2.parameters.String('Solomon service', required=True, default='redirects')

        with sdk2.parameters.Group('Settings') as date_block:
            time_window = sdk2.parameters.Integer('Time window in days', required=True, default=3)
            discretization = sdk2.parameters.Integer('Discretization in seconds', required=True, default=2 * 60 * 60)

        ext_params = binary_task.binary_release_parameters(none=True)

    class Requirements(sdk2.Requirements):
        # https://wiki.yandex-team.ru/sandbox/clients/#client-tags-multislot
        cores = 1  # exactly 1 core
        ram = 8192  # 8GiB or less

        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

    def _get_yt_client(self):
        """Lazily create and cache a YT client for the configured cluster and token."""
        if self._yt_client is None:
            # Imported lazily so the module loads without yt.wrapper until it is needed.
            import yt.wrapper
            self._yt_client = yt.wrapper.YtClient(
                proxy=self.Parameters.yt_cluster,
                token=self.Parameters.yt_token.data()['token']
            )
            # Not read within this class; kept in case other code relies on it.
            self._yt_json_format = yt.wrapper.JsonFormat()
        return self._yt_client

    def send_data_to_solomon(self, data):
        """Push metrics to Solomon in batches.

        :param data: mapping ``{timestamp: {partner: {sensor_name: value}}}``,
            as returned by ``PartnerMetricsCalculator.get_stats``. Each leaf
            becomes one sensor labelled with ``sensor`` and ``partner``.
        """
        common_labels = {
            'project': self.Parameters.solomon_project,
            'cluster': self.Parameters.solomon_cluster,
            'service': self.Parameters.solomon_service,
        }
        logger.info(common_labels)
        sensors = []
        for eventtime, partners in data.items():
            for partner, partner_sensors in partners.items():
                for sensor, value in partner_sensors.items():
                    sensors.append(
                        {
                            'ts': eventtime,
                            'labels': {'sensor': sensor, 'partner': partner},
                            'value': value,
                        }
                    )
            if len(sensors) >= self._SOLOMON_BATCH_SIZE:
                self._push_sensors(common_labels, sensors)
                sensors = []
        # Skip the trailing request when everything was already flushed or there was no data.
        if sensors:
            self._push_sensors(common_labels, sensors)

    def _push_sensors(self, common_labels, sensors):
        """Log one batch of sensors and push it to Solomon."""
        logger.info(sensors)
        # NOTE(review): solomon_token is not defined in this class; presumably
        # provided by AviaBaseTask — confirm.
        solomon.push_to_solomon_v2(self.solomon_token, params=common_labels, sensors=sensors)

    def on_execute(self):
        """Compute partner metrics from YT and push them to Solomon."""
        super(SendAviaPartnerMetricsToSolomon, self).on_execute()
        logs.configure_logging(logs.get_sentry_dsn(self))
        logger.info('Start')

        yt_client = self._get_yt_client()
        partner_metrics_calculator = PartnerMetricsCalculator(
            yt_client,
            self.Parameters.discretization,
            self.Parameters.time_window,
            self.Parameters.orders_table,
            self.Parameters.redirects_folder,
            self.Parameters.partners_table
        )
        stat = partner_metrics_calculator.get_stats()
        self.send_data_to_solomon(stat)
        logger.info('Done')
