# coding=utf-8
import travel.avia.admin.init_project  # noqa

import argparse
import logging
from collections import defaultdict, namedtuple
from datetime import date, datetime, timedelta

import yt.wrapper as yt
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from yql.api.v1.client import YqlClient
from yql.client.parameter_value_builder import YqlParameterValueBuilder

from travel.avia.library.python.avia_data.models.national_version import NationalVersion
from travel.avia.library.python.common.models.partner import ClickPriceMultiplierByRedirectType, MINIMAL_AVG_CHECK_COEFF, MAXIMAL_AVG_CHECK_COEFF
from travel.avia.admin.lib.logs import add_stdout_handler, create_current_file_run_log
from travel.avia.admin.lib.pricing import RedirectTypePicker
from travel.avia.admin.lib.yql_helpers import log_errors
from travel.avia.admin.lib.yt_helpers import yt_client_fabric

logger = logging.getLogger(__name__)

# YT table with CPA orders; it is substituted into YQL_QUERY_TEMPLATE.
CPA_ORDERS_TABLE_PATH = '//home/travel/prod/cpa/avia/orders'

# YT directory used as the default for --output-yt-path; production and every other
# environment get separate roots.
if settings.ENVIRONMENT == 'production':
    YT_MULTIPLIERS_LOG_ROOT = '//home/avia/logs/avia-click-price-multipliers-log'
else:
    YT_MULTIPLIERS_LOG_ROOT = '//home/avia/testing/logs/avia-click-price-multipliers-log'

# An idealized reference average check ("a spherical average check in a vacuum").
# It is chosen so that the click-price correction coefficient (which depends on
# the weighted average check) falls within the range
# [MINIMAL_AVG_CHECK_COEFF, MAXIMAL_AVG_CHECK_COEFF]
DEFAULT_NORMALIZER = 9000.
# Per-day weights for the weighted average: index 0 is the current day and is not counted,
# the 7 preceding days get the weights 64/127, 32/127, ..., 1/127 respectively
DAY_WEIGHTS = [0., 64./127., 32./127., 16./127., 8./127., 4./127., 2./127., 1./127]

# YQL query executed against the `hahn` cluster. It takes confirmed orders with at least one
# adult seat, filtered by order_date between $start_date and $end_date, groups them by
# (order_date, label_category) and returns per group:
#   * order_per_pax - average order amount in RUB per passenger (adults + children), as Uint32;
#   * day_coeff - the same average divided by $normalizer and passed through Math::Round(..., -5)
#     (presumably rounding to 5 decimal places - confirm against the YQL Math UDF docs).
# order_date is created_at (seconds) converted to a Date; label_category is derived from the
# utm/wizard label columns by the CASE expression. The orders table path is substituted once,
# when the module is imported.
YQL_QUERY_TEMPLATE = """
USE `hahn`;
DECLARE $start_date AS Date;
DECLARE $end_date AS Date;
DECLARE $normalizer AS Float;

SELECT
    label_category,
    order_date,
    CAST(AVG(order_amount_rub / (label_adult_seats+label_children_seats)) AS Uint32) as order_per_pax,
    Math::Round(AVG(order_amount_rub / (label_adult_seats+label_children_seats)) / $normalizer, -5) as day_coeff
FROM `{cpa_order_path}`
WHERE
    status = 'confirmed'
    AND label_adult_seats > 0
    AND order_date >= $start_date
    AND order_date <= $end_date
group by
    CAST(DateTime::FromSeconds(CAST(created_at AS Uint32)) as Date) as order_date,
    Case
        WHEN label_utm_source = 'rasp' AND label_utm_medium = 'redirect' THEN 'direct_rasp'
        WHEN label_utm_source = 'sovetnik' AND label_utm_content = 'redirect' THEN 'direct_sovetnik'
        WHEN (label_utm_source = 'wizard_ru' OR label_utm_source = 'unisearch_ru') AND label_wizardredirkey IS NOT NULL THEN 'direct_wizard'
        WHEN (label_utm_source = 'wizard_ru' OR label_utm_source = 'unisearch_ru') AND label_wizardredirkey IS NULL THEN 'indirect_wizard'
        ELSE 'indirect'
    END AS label_category
order by label_category, order_date;
""".format(cpa_order_path=CPA_ORDERS_TABLE_PATH)


class UpdateAvgCheckError(Exception):
    """Base exception for failures of the average-check multiplier update."""


class GetDailyNumbersError(UpdateAvgCheckError):
    """Raised when the YQL query that collects the daily numbers does not succeed."""


# One row of the daily query result: traffic category, order day and that day's coefficient.
DailyCoeff = namedtuple('DailyCoeff', 'label_category order_date day_coeff')


def _convert_coeff_models_to_yt_format(coeff_date, coeff_models):
    """Lazily turn multiplier models into row dicts for the YT log table.

    :param coeff_date: date written to the 'date' column of every row.
    :param coeff_models: iterable of objects exposing updated_at, national_version.code,
        redirect_type.code and multiplier.
    :return: generator of dicts with the keys date, updated_at, national_version,
        redirect_type and multiplier (the latter converted to float).
    """
    for model in coeff_models:
        row = {'date': coeff_date.strftime('%Y-%m-%d')}
        row['updated_at'] = model.updated_at.strftime('%Y-%m-%d %H:%M:%S')
        row['national_version'] = model.national_version.code
        row['redirect_type'] = model.redirect_type.code
        row['multiplier'] = float(model.multiplier)
        yield row


def get_daily_numbers(start_date, end_date, orders_window, normalizer):
    """Run the daily-average YQL query and group its rows by category and calculation date.

    :param start_date: value bound to the $start_date query parameter.
    :param end_date: value bound to the $end_date query parameter.
    :param orders_window: number of following days each daily record contributes to.
    :param normalizer: value bound to the $normalizer query parameter.
    :return: mapping category -> calculation date -> list of DailyCoeff records,
        or None when the query yields no result tables.
    :raises GetDailyNumbersError: if the query does not finish successfully.
    """
    client = YqlClient(token=settings.YQL_TOKEN)
    request = client.query(YQL_QUERY_TEMPLATE, syntax_version=1)

    builder = YqlParameterValueBuilder
    query_parameters = builder.build_json_map({
        '$start_date': builder.make_date(start_date),
        '$end_date': builder.make_date(end_date),
        '$normalizer': builder.make_float(normalizer),
    })
    request.run(parameters=query_parameters)
    request.wait_progress()

    if not request.is_ok:
        log_errors(request, logger)
        raise GetDailyNumbersError('Error while running YQL')

    logger.info('YQL query is done')

    # Every result table is folded into the same accumulator.
    daily_numbers = None
    for result_table in request.get_results():
        result_table.fetch_full_data()
        daily_numbers = _update_daily_numbers_for_yt_table(daily_numbers, result_table, orders_window)
    return daily_numbers


def _update_daily_numbers_for_yt_table(current_results, table, orders_window):
    """Add the rows of one result table to the per-category, per-day accumulator.

    Every row is attached to each of the `orders_window` days that follow its order date,
    as long as that day does not go past the day after the latest order date in the table.

    :param current_results: accumulator from previous tables; a falsy value starts a new one.
    :param table: object whose `rows` yield
        (label_category, order_date, order_per_pax, day_coeff) tuples.
    :param orders_window: how many following days a record is attached to.
    :return: mapping category -> calculation date -> list of DailyCoeff.
    """
    accumulated = current_results or defaultdict(lambda: defaultdict(list))

    daily_coeffs = [
        DailyCoeff(category, order_day, coeff)
        for category, order_day, _order_per_pax, coeff in table.rows
    ]
    if not daily_coeffs:
        return accumulated

    # The current day is not used in the calculations, so the last calculation day is
    # the day right after the latest day that has at least one order.
    last_calculation_day = max(item.order_date for item in daily_coeffs) + timedelta(days=1)

    for item in daily_coeffs:
        for shift in range(1, orders_window + 1):
            calculation_day = item.order_date + timedelta(days=shift)
            if calculation_day <= last_calculation_day:
                accumulated[item.label_category][calculation_day].append(item)
    return accumulated


def save_models_to_db(coeff_models):
    """Call save() on every multiplier model, logging before and after.

    :param coeff_models: iterable of model instances to save.
    """
    logger.info('Saving click-price adjustment multipliers to db')

    for model in coeff_models:
        model.save()

    logger.info('All click-price adjustment multipliers have been written to database')


def save_results_to_yt(date, coeff_models, ytc, output_yt_path):
    """Write the multipliers to the YT table named after `date` under `output_yt_path`.

    The table is created with an explicit schema when it does not exist yet.

    :param date: day used both for the table name and for the 'date' column.
    :param coeff_models: multiplier models to dump.
    :param ytc: YT client used for exists/create/write_table.
    :param output_yt_path: YT directory that holds the per-day tables.
    """
    day_table = yt.ypath_join(output_yt_path, date.strftime('%Y-%m-%d'))
    logger.info('Dumping to YT: %s', day_table)

    if not ytc.exists(day_table):
        columns = (
            ('date', 'string'),
            ('national_version', 'string'),
            ('updated_at', 'string'),
            ('redirect_type', 'string'),
            ('multiplier', 'double'),
        )
        table_attributes = {
            'schema': [{'name': column, 'type': column_type} for column, column_type in columns],
            'optimize_for': 'scan',
        }
        ytc.create('table', day_table, recursive=True, attributes=table_attributes)

    rows = _convert_coeff_models_to_yt_format(date, coeff_models)
    ytc.write_table(day_table, rows)


def convert_to_models(redirect_type_to_multiplier_dict, updated_at):
    """Build ClickPriceMultiplierByRedirectType models for the 'ru' national version.

    For every redirect type code an existing model is fetched and updated in memory;
    when there is none, a new unsaved model is created. Nothing is written to the
    database here.

    :param redirect_type_to_multiplier_dict: mapping redirect type code -> multiplier.
    :param updated_at: timestamp stored in `updated_at` of every returned model.
    :return: list of model instances, one per entry of the input mapping.
    :raises ObjectDoesNotExist: if the 'ru' national version is missing; also propagated
        from the redirect type lookup if that lookup raises it (depends on
        RedirectTypePicker.get_by_code, which is not visible here).
    """
    result = []
    ru_version = NationalVersion.objects.get(code='ru')
    redirect_type_picker = RedirectTypePicker()
    for redirect_type, multiplier in redirect_type_to_multiplier_dict.items():
        # Resolved outside of the try block on purpose: a failure here must not be mistaken
        # for "the multiplier row does not exist yet", which would otherwise create a model
        # with an unbound redirect type or with the one left over from the previous iteration.
        model_redirect_type = redirect_type_picker.get_by_code(redirect_type)
        try:
            coeff_model = ClickPriceMultiplierByRedirectType.objects.get(
                national_version=ru_version,
                redirect_type=model_redirect_type,
            )
        except ObjectDoesNotExist:
            coeff_model = ClickPriceMultiplierByRedirectType(
                national_version=ru_version,
                redirect_type=model_redirect_type,
                updated_at=updated_at,
                multiplier=multiplier,
            )
        else:
            coeff_model.updated_at = updated_at
            coeff_model.multiplier = multiplier
        result.append(coeff_model)

    return result


def calculate_weighted_avg_per_category(daily_numbers, output_date_start, orders_window, end_date):
    """Compute the clamped weighted-average coefficient per category and calculation date.

    Only calculation dates inside [output_date_start, end_date] are kept.

    :param daily_numbers: mapping category -> calculation date -> list of daily records.
    :param output_date_start: first calculation date to include.
    :param orders_window: passed through to calculate_weighted_avg.
    :param end_date: last calculation date to include.
    :return: defaultdict mapping category -> {calculation date: coefficient}.
    """
    weighted = defaultdict(dict)
    for category, records_by_day in sorted(daily_numbers.items()):
        for coeff_date, records in sorted(records_by_day.items()):
            if not (output_date_start <= coeff_date <= end_date):
                continue
            raw_coeff = calculate_weighted_avg(coeff_date, orders_window, records)
            weighted[category][coeff_date] = within_limits(raw_coeff)
    return weighted


def within_limits(coeff):
    """Clamp `coeff` to [MINIMAL_AVG_CHECK_COEFF, MAXIMAL_AVG_CHECK_COEFF].

    :param coeff: coefficient to clamp.
    :return: the nearest limit when `coeff` lies outside the range, otherwise `coeff` itself.
    """
    if coeff < MINIMAL_AVG_CHECK_COEFF:
        clamped = MINIMAL_AVG_CHECK_COEFF
    elif coeff > MAXIMAL_AVG_CHECK_COEFF:
        clamped = MAXIMAL_AVG_CHECK_COEFF
    else:
        clamped = coeff
    return clamped


def calculate_weighted_avg(coeff_date, orders_window, data, day_weights=None):
    """Return the weighted average of daily coefficients for `coeff_date`.

    A record contributes when its order date lies 1..orders_window days before
    `coeff_date`; its weight is taken from `day_weights` by that distance in days.
    Distances beyond the end of `day_weights` use the last weight. The sum is
    renormalized by the total weight of the records that were actually present,
    so missing days do not pull the average towards zero.

    :param coeff_date: date the coefficient is calculated for.
    :param orders_window: maximum age (in days) of a record that is still counted.
    :param data: iterable of records with `order_date` and `day_coeff` attributes.
    :param day_weights: sequence of weights indexed by the age in days; defaults to
        the module-level DAY_WEIGHTS.
    :return: weighted average as float; 0.0 when no record carries a positive weight.
    """
    if day_weights is None:
        day_weights = DAY_WEIGHTS

    covered_portion = 0.
    result = 0.
    for elem in data:
        dates_diff = (coeff_date - elem.order_date).days
        if dates_diff <= 0 or dates_diff > orders_window:
            continue
        day_weight = day_weights[dates_diff if dates_diff < len(day_weights) else -1]
        covered_portion += day_weight
        result += day_weight * elem.day_coeff

    # Dividing by a total weight of exactly 1.0 changes nothing, so there is no need for
    # a separate exact float comparison with 1; only a zero total has to be guarded.
    if covered_portion > 0:
        result = result / covered_portion
    return result


def last_day_results(coeffs_per_category):
    """Pick, for every category, the coefficient of its latest calculation date.

    :param coeffs_per_category: mapping category -> {calculation date: coefficient}.
    :return: dict mapping category -> coefficient stored under the greatest date key.
    """
    return {
        category: coeffs_by_day[sorted(coeffs_by_day)[-1]]
        for category, coeffs_by_day in coeffs_per_category.items()
    }


def build_avg_check_coeff(ytc, start_date, end_date, skip_writing_to_db, orders_window, normalizer, output_yt_path):
    """Calculate the click-price multipliers and store them in the database and in YT.

    Steps: fetch the daily numbers, compute the weighted averages per category, take the
    latest day of every category, convert the values to models, optionally save them to
    the database and always dump them to the YT table for today.

    :param ytc: YT client used for the dump.
    :param start_date: first day of orders to read.
    :param end_date: last day of orders to read and last calculation day.
    :param skip_writing_to_db: when true, the models are not saved to the database.
    :param orders_window: number of preceding days used for one coefficient.
    :param normalizer: divisor applied to the average check in the query.
    :param output_yt_path: YT directory for the per-day output tables.
    """
    run_started_at = datetime.now()
    logger.info('Start calculating daily averages')
    run_parameters = 'start_date = {}, end_date = {}, orders_window = {}, normalizer={}'.format(
        start_date, end_date, orders_window, normalizer)
    logger.info(run_parameters)

    daily_numbers = get_daily_numbers(start_date, end_date, orders_window, normalizer)
    logger.info('Got daily numbers for categories: %s', daily_numbers.keys())
    for category, records_by_day in daily_numbers.items():
        logger.info('Got daily numbers for category: %s', category)
        for day, day_records in sorted(records_by_day.items()):
            logger.info('Day %s:', day)
            for day_record in day_records:
                logger.info('Record %s:', day_record)

    # The first `orders_window` days have no complete history, so output starts after them.
    first_output_day = start_date + timedelta(days=orders_window)
    weighted = calculate_weighted_avg_per_category(daily_numbers, first_output_day, orders_window, end_date)
    for category, coeff_by_day in weighted.items():
        logger.info('Got results for category: %s', category)
        for day, coeff in sorted(coeff_by_day.items()):
            logger.info('Day %s: %f', day, coeff)

    coeff_models = convert_to_models(last_day_results(weighted), run_started_at)

    if not skip_writing_to_db:
        save_models_to_db(coeff_models)
    else:
        logger.info('Skip writing to database')

    save_results_to_yt(date.today(), coeff_models, ytc, output_yt_path)

    logger.info('Click-price adjustment multipliers have been successfuly calculated')


def _parse_date(dt):
    """Parse a 'YYYY-MM-DD' string into a date; a falsy value gives None.

    :param dt: date string or a falsy value.
    :return: datetime.date or None.
    :raises ValueError: if a non-empty string does not match the format.
    """
    if not dt:
        return None
    parsed = datetime.strptime(dt, '%Y-%m-%d')
    return parsed.date()


def main():
    """Command-line entry point: parse the arguments and run the multiplier calculation.

    Defaults: end date is today, start date is 30 days before the end date, the orders
    window is 7 days, the normalizer is DEFAULT_NORMALIZER and the output path is
    YT_MULTIPLIERS_LOG_ROOT. Any exception is logged and re-raised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--start-date', type=_parse_date, help='start date, default = today - 30 days')
    parser.add_argument('--end-date', type=_parse_date, help='end date, default = today')
    parser.add_argument('--skip-writing-to-db', action='store_true', help='do not save conversions to db')
    # Without an explicit type argparse keeps command-line values as strings, which breaks
    # range(orders_window) and timedelta(days=orders_window) further down.
    parser.add_argument('--orders-window',
                        type=int,
                        default=7,
                        help='1 - calculate coeff, using orders created in a single day, 7 - in a week')
    parser.add_argument('--normalizer',
                        type=float,
                        default=DEFAULT_NORMALIZER,
                        help='coeff normalizer, to keep the coeff value around 1')
    parser.add_argument('--output-yt-path', default=YT_MULTIPLIERS_LOG_ROOT, help='Yt path to store the result')

    args = parser.parse_args()

    # A non-positive window selects no orders at all and would only produce empty output.
    if args.orders_window < 1:
        parser.error('--orders-window must be a positive integer')

    if args.verbose:
        add_stdout_handler(logger)

    create_current_file_run_log()

    try:
        logger.info('Start')

        end_date = args.end_date if args.end_date else datetime.now().date()
        start_date = args.start_date if args.start_date else end_date - timedelta(days=30)

        build_avg_check_coeff(
            yt_client_fabric.create(),
            start_date,
            end_date,
            args.skip_writing_to_db,
            args.orders_window,
            args.normalizer,
            args.output_yt_path,
        )

        logger.info('End')
    except Exception:
        # Top-level boundary: record the traceback in the run log, then let the failure
        # propagate. KeyboardInterrupt/SystemExit are not reported as errors.
        logger.exception('ERROR')
        raise
