# coding=utf-8
from __future__ import unicode_literals, print_function

import travel.avia.admin.init_project  # noqa

import argparse
import logging
from collections import Counter, defaultdict
from datetime import datetime, timedelta, date

import yt.wrapper as yt
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from yql.api.v1.client import YqlClient
from yql.client.parameter_value_builder import YqlParameterValueBuilder

from travel.avia.library.python.avia_data.models.national_version import NationalVersion
from travel.avia.admin.avia_scripts.conversion.booking_yql import build_booking_data_yql
from travel.avia.admin.avia_scripts.conversion.intervals import join_intervals, DateInterval
from travel.avia.library.python.common.models.partner import ConversionByRedirectType, Partner
from travel.avia.library.python.common.utils.date import daterange
from travel.avia.admin.lib.logs import add_stdout_handler, create_current_file_run_log
from travel.avia.admin.lib.pricing import RedirectTypePicker
from travel.avia.admin.lib.yql_helpers import log_errors
from travel.avia.admin.lib.yt_helpers import yt_client_fabric, safe_dates_and_tables_for_daterange

logger = logging.getLogger(__name__)

# Root YT directory of the redirect log. It is read both by get_redirects() and
# by the RANGE(...) clause of MAX_UPDATED_AT_MSK_DATE_YQL below
# (presumably one table per day, named by date - TODO confirm).
REDIR_LOG_ROOT = '//home/avia/logs/avia-redir-balance-by-day-log'
# Root YT directory of the partner booking log.
# NOTE(review): not referenced anywhere in the visible part of this file - confirm it is still needed.
BOOKING_LOG_ROOT = '//home/avia/logs/avia-partner-booking-log'

# YT directory the computed conversions are dumped to (default of --output-yt-path).
# Every environment except production writes to a separate "testing" directory.
if settings.ENVIRONMENT == 'production':
    CONVERSION_LOG_ROOT = '//home/avia/logs/avia-conversion-log'
else:
    CONVERSION_LOG_ROOT = '//home/avia/testing/logs/avia-conversion-log'

# Threshold passed to convert_to_models() by main(): conversions below it are not
# stored into the models.
MINIMAL_CONVERSION = 1e-5

# YT table with CPA orders; used by the booking query and by the YQL below.
CPA_ORDER_YT_PATH = '//home/travel/prod/cpa/avia/orders'
# YQL template executed by get_partner_intervals(). For every billing_order_id found
# in the CPA orders table it returns three columns:
#   * billing_order_id,
#   * msk_date - MAX(updated_at) formatted as a 'YYYY-MM-DD' date in Europe/Moscow,
#   * redirects_cnt - number of redirects in the table(s) selected by $date
#     (NULL when the LEFT JOIN finds no redirects for that order id).
# NOTE(review): the query filters redirects on the column `NATIONAL`, while
# get_redirects() reads the field `NATIONAL_VERSION` from the same log - confirm
# that both names really exist in the redirect tables.
MAX_UPDATED_AT_MSK_DATE_YQL = """
USE `hahn`;
DECLARE $date AS String;

$redirects_count = (
    SELECT
        BILLING_ORDER_ID as billing_order_id,
        count(MARKER) as cnt
    FROM RANGE(`{redirects_log_path}`, $date, $date)
    WHERE NATIONAL = 'ru'
    GROUP BY BILLING_ORDER_ID
);

$last_book_date = (
    SELECT
        billing_order_id,
        DateTime::Format('%Y-%m-%d')(AddTimezone(CAST(MAX(updated_at) as DateTime), 'Europe/Moscow')) as msk_date
    FROM `{cpa_order_path}`
    GROUP BY billing_order_id
);

SELECT
    booking.billing_order_id as billing_order_id,
    booking.msk_date as msk_date,
    redirects.cnt as redirects_cnt
FROM $last_book_date as booking
LEFT JOIN $redirects_count as redirects
ON booking.billing_order_id = redirects.billing_order_id;
""".format(cpa_order_path=CPA_ORDER_YT_PATH, redirects_log_path=REDIR_LOG_ROOT)


class UpdateConversionError(Exception):
    """Common base class of the errors declared by this script."""


class GetPartnerBookingsError(UpdateConversionError):
    """The YQL query that fetches partner bookings did not finish successfully."""


class GetPartnerIntervalsError(UpdateConversionError):
    """The YQL query that computes partner intervals did not finish successfully."""


def get_billing_order_ids():
    """Return the set of billing order ids of partners used for the conversion calculation.

    Only partners that are enabled, fetchable by the daemon and flagged with
    ``use_in_update_conversions`` are considered. Partners without a
    ``billing_order_id`` are reported with a warning and left out.

    :return: set of non-NULL ``billing_order_id`` values.
    """
    candidates = Partner.objects.filter(
        use_in_update_conversions=True,
        can_fetch_by_daemon=True,
        disabled=False,
    )

    # Split the candidates in a single pass: usable partners vs. codes of broken ones.
    usable = []
    broken_codes = []
    for candidate in candidates:
        if candidate.billing_order_id is None:
            broken_codes.append(candidate.code)
        else:
            usable.append(candidate)

    if broken_codes:
        logger.warning('Партнеры, которые должны участвовать в расчете конверсии, но имеют billing_order_id=NULL: %s',
                       broken_codes)

    logger.info('Calculate conversion for partners:')
    order_ids = set()
    for partner in usable:
        logger.info('\tid=%s code=%s billing_order_id=%s', partner.id, partner.code, partner.billing_order_id)
        order_ids.add(partner.billing_order_id)

    return order_ids


def _convert_conversions_to_yt_format(date, conversions):
    """Lazily turn conversion models into plain dict rows for the YT table.

    :param date: object with ``strftime``; written to the ``date`` column as ``YYYY-MM-DD``.
    :param conversions: iterable of models exposing ``updated_at``,
        ``national_version.code``, ``redirect_type.code`` and ``conversion``.
    :return: generator of dicts, one per model.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    for model in conversions:
        row = {'date': date.strftime('%Y-%m-%d')}
        row['updated_at'] = model.updated_at.strftime(timestamp_format)
        row['national_version'] = model.national_version.code
        row['redirect_type'] = model.redirect_type.code
        # The stored value may not be a float already, so cast explicitly.
        row['conversion'] = float(model.conversion)
        yield row


def get_bookings(yql_client, interval_by_billing_order_id):
    """Run the booking YQL query and collect the first column of every result row.

    The caller matches the returned values against the ``MARKER`` field of redirects.

    :param yql_client: YQL client used to run the query.
    :param interval_by_billing_order_id: mapping passed to ``build_booking_data_yql``.
    :return: set of first-column values of all result tables.
    :raises GetPartnerBookingsError: if the query did not finish successfully.
    """
    logger.info('Start get bookings')
    request = yql_client.query(
        build_booking_data_yql(CPA_ORDER_YT_PATH, interval_by_billing_order_id),
        syntax_version=1,
    )
    request.run()
    request.wait_progress()

    if not request.is_ok:
        log_errors(request, logger)
        raise GetPartnerBookingsError('Error while running YQL')

    logger.info('YQL query is done')
    booked_markers = set()
    for result_table in request.get_results():
        result_table.fetch_full_data()
        booked_markers.update(row[0] for row in result_table.rows)
    logger.info('Got bookings')
    return booked_markers


def get_redirects(yt_client, interval, interval_by_billing_order_id):
    """Yield redirect log records of ``interval`` that are relevant for the calculation.

    A record is yielded only if it has no truthy ``FILTER`` value, belongs to the
    ``ru`` national version, its ``BILLING_ORDER_ID`` is a key of
    ``interval_by_billing_order_id`` and the table's day lies inside that
    partner's own interval.

    :param yt_client: YT client used to read the tables.
    :param interval: object with ``start``/``end``; ``end`` is excluded.
    :param interval_by_billing_order_id: mapping billing order id -> interval.
    """
    dates_and_tables = safe_dates_and_tables_for_daterange(
        yt_client, REDIR_LOG_ROOT, interval.start, interval.end, include_end=False
    )
    logger.info('Redirect tables: %s', [table for _day, table in dates_and_tables])
    for day, table in dates_and_tables:
        records = yt_client.read_table(yt.TablePath(table), format=yt.JsonFormat())
        for record in records:
            # A truthy FILTER marks a fraud record; other national versions are ignored.
            if record.get('FILTER') or record['NATIONAL_VERSION'] != 'ru':
                continue

            order_id = record['BILLING_ORDER_ID']
            if order_id in interval_by_billing_order_id and day in interval_by_billing_order_id[order_id]:
                yield record


def save_conversions_to_db(conversions):
    """Persist every conversion model by calling its ``save()`` method.

    :param conversions: iterable of model instances.
    """
    logger.info('Save conversions to db')
    for model in conversions:
        # Each model is saved individually, in iteration order.
        model.save()

    logger.info('All conversion have been written to database')


def save_conversions_to_yt(date, conversions, ytc, yt_path):
    """Write the conversions to the YT table ``<yt_path>/<YYYY-MM-DD>``.

    The table is created with an explicit schema when it does not exist yet.

    :param date: object with ``strftime``; gives the table name and the ``date`` column.
    :param conversions: iterable of conversion models.
    :param ytc: YT client.
    :param yt_path: YT directory that holds the per-day tables.
    """
    table = yt.ypath_join(yt_path, date.strftime('%Y-%m-%d'))
    logger.info('Dump to YT: %s', table)

    if not ytc.exists(table):
        columns = (
            ('date', 'string'),
            ('national_version', 'string'),
            ('updated_at', 'string'),
            ('redirect_type', 'string'),
            ('conversion', 'double'),
        )
        table_attributes = {
            'schema': [{'name': column_name, 'type': column_type} for column_name, column_type in columns],
            'optimize_for': 'scan',
        }
        ytc.create('table', table, recursive=True, attributes=table_attributes)

    rows = _convert_conversions_to_yt_format(date, conversions)
    ytc.write_table(table, rows)


def compute_conversions(all_redirects, booked):
    """Compute the booking conversion for every redirect type.

    Uses ``.items()`` instead of the Python-2-only ``.iteritems()`` so that any
    mapping (``dict``, ``Counter``, ...) is accepted on both Python 2 and 3,
    matching the other loops of this file that already use ``.items()``.

    :param all_redirects: mapping redirect type -> total number of redirects.
    :param booked: mapping redirect type -> number of booked redirects;
        missing types count as 0. Types absent from ``all_redirects`` are ignored.
    :return: dict redirect type -> ``booked / total`` as a float; ``0.0`` when
        the total is not positive.
    """
    conversions = {}
    for redirect_type, count in all_redirects.items():
        if count > 0:
            # float() keeps the division a true division on Python 2 as well.
            conversions[redirect_type] = float(booked.get(redirect_type, 0)) / count
        else:
            # Always return a float so callers get a uniform value type.
            conversions[redirect_type] = 0.0
    return conversions


def convert_to_models(conversions, minimal_conversion, timestamp):
    """Build ``ConversionByRedirectType`` models for the ``ru`` national version.

    For every redirect type:

    * an existing model is updated (``conversion`` and ``updated_at``) when the
      new conversion is at least ``minimal_conversion``; otherwise an error is
      logged and the existing model is returned unchanged;
    * a missing model is created when the conversion is large enough.

    The models are only built/modified in memory here; nothing is saved.

    :param conversions: mapping redirect type -> conversion value.
    :param minimal_conversion: smallest conversion that may be stored.
    :param timestamp: value assigned to ``updated_at`` of changed/new models.
    :return: list of models, one per redirect type.
    :raises KeyError: if a conversion is below ``minimal_conversion`` and no
        previously stored model exists for that redirect type.
    """
    ru_version = NationalVersion.objects.get(code='ru')
    result = []
    # .items() works on both Python 2 and 3 (unlike .iteritems()).
    for redirect_type, conversion in conversions.items():
        is_large_enough = conversion >= minimal_conversion

        # Keep the try body minimal: only the lookup can raise ObjectDoesNotExist.
        try:
            cr_model = ConversionByRedirectType.objects.get(
                national_version=ru_version,
                redirect_type=redirect_type,
            )
        except ObjectDoesNotExist:
            cr_model = None

        if cr_model is None:
            if not is_large_enough:
                raise KeyError(
                    'Conversion {} of {} is too small and previous had not been found'.format(
                        conversion,
                        redirect_type
                    ))
            cr_model = ConversionByRedirectType(
                national_version=ru_version,
                redirect_type=redirect_type,
                updated_at=timestamp,
                conversion=conversion
            )
        elif is_large_enough:
            cr_model.conversion = conversion
            cr_model.updated_at = timestamp
        else:
            # Keep the previously stored value; it is still returned to the caller.
            logger.error(
                'Conversion %f of %s is too small, will not be updated',
                conversion,
                redirect_type,
            )

        result.append(cr_model)

    return result


def exp_decay(element_number):
    """Return the weight ``1 / 2**element_number`` as a float."""
    denominator = 2. ** element_number
    return 1. / denominator


def normalization_factor(number_of_elements):
    """Return the sum of ``exp_decay(n)`` for ``n`` in ``range(number_of_elements)``.

    Dividing each ``exp_decay(n)`` by this value makes the weights add up to 1.
    """
    total = 0
    for position in range(number_of_elements):
        total += exp_decay(position)
    return total


def redir_date(redirect):
    """Return the date of a redirect record.

    The date is the first whitespace-separated token of ``ISO_EVENTTIME``,
    parsed with ``_parse_date``.
    """
    event_time_parts = redirect['ISO_EVENTTIME'].split()
    date_part = event_time_parts[0]
    return _parse_date(date_part)


def iter_dates(date, max_delta):
    """Yield ``abs(max_delta)`` consecutive dates starting from ``date``.

    The dates go forward when ``max_delta`` is positive and backward otherwise;
    ``date`` itself is the first yielded value. Nothing is yielded for 0.
    """
    step = 1 if max_delta > 0 else -1
    # range(0, max_delta, step) produces 0, step, 2*step, ... - abs(max_delta) offsets.
    for offset in range(0, max_delta, step):
        yield date + timedelta(days=offset)


def _parse_date(dt):
    """Parse a ``YYYY-MM-DD`` string into a ``date``; return ``None`` for a falsy value."""
    if not dt:
        return None
    parsed = datetime.strptime(dt, '%Y-%m-%d')
    return parsed.date()


def get_partner_intervals(yql_client, partner_whitelist, today, conversion_window):
    """Compute the date interval used for every whitelisted partner.

    Runs ``MAX_UPDATED_AT_MSK_DATE_YQL`` for yesterday and, for each returned
    ``billing_order_id`` present in ``partner_whitelist``, picks the last day of
    the interval:

    * the partner had no redirects yesterday -> the interval ends yesterday;
    * otherwise -> it ends on the date of the last booking update, but not
      later than yesterday.

    A partner that has redirects but no booking update date is skipped with a
    warning. Previously such a row crashed the whole run with an unrelated
    ``TypeError`` when date arithmetic was applied to ``None``.

    :param yql_client: YQL client used to run the query.
    :param partner_whitelist: container of allowed billing order ids.
    :param today: current date; today itself is never part of an interval.
    :param conversion_window: number of days every interval covers.
    :return: dict billing order id -> ``DateInterval`` with ``end`` excluded.
    :raises GetPartnerIntervalsError: if the query did not finish successfully.
    """
    yesterday = today - timedelta(days=1)

    yql_query = yql_client.query(MAX_UPDATED_AT_MSK_DATE_YQL, syntax_version=1)
    yql_query.run(
        parameters=YqlParameterValueBuilder.build_json_map({
            '$date': YqlParameterValueBuilder.make_string(yesterday.strftime('%Y-%m-%d')),
        }),
    )
    yql_query.wait_progress()

    if not yql_query.is_ok:
        log_errors(yql_query, logger)
        raise GetPartnerIntervalsError('Error while running YQL')

    logger.info('YQL query is done')
    interval_by_billing_order_id = {}
    for table in yql_query.get_results():
        table.fetch_full_data()
        for billing_order_id, msk_max_updated_at_str, yesterday_redirects_count in table.rows:
            if billing_order_id not in partner_whitelist:
                continue

            if yesterday_redirects_count:
                end_date = _parse_date(msk_max_updated_at_str)
                if end_date is None:
                    # Same policy as for partners without billing_order_id: warn and exclude.
                    logger.warning(
                        'billing_order_id %s has redirects but no booking update date, skip it',
                        billing_order_id,
                    )
                    continue
                # Today's data is not complete yet, so never go past yesterday.
                end_date = min(end_date, yesterday)
            else:
                end_date = yesterday

            interval_by_billing_order_id[billing_order_id] = DateInterval(
                start=end_date - timedelta(days=conversion_window - 1),
                end=end_date + timedelta(days=1)  # DateRange excludes end
            )
    return interval_by_billing_order_id


def log_partner_intervals(interval_by_billing_order_id):
    """Log the interval of every partner as ``[start, end)``, ordered by billing order id."""
    logger.info('Partner Intervals:')
    day_format = '%Y-%m-%d'
    for order_id in sorted(interval_by_billing_order_id):
        partner_interval = interval_by_billing_order_id[order_id]
        start_text = partner_interval.start.strftime(day_format)
        end_text = partner_interval.end.strftime(day_format)
        logger.info('\t%s: [%s, %s)', order_id, start_text, end_text)


def log_joined_intervals(intervals):
    """Log every joined interval as ``[start, end)`` in the given order."""
    logger.info('Joined Intervals:')
    day_format = '%Y-%m-%d'
    for joined in intervals:
        start_text = joined.start.strftime(day_format)
        end_text = joined.end.strftime(day_format)
        logger.info('\t[%s, %s)', start_text, end_text)


def log_redirects(booked_redirects, all_redirects):
    """Log ``booked / total`` redirect counts per partner, day and redirect type.

    Lines are ordered by billing order id, then by day, then by redirect type.
    ``-`` is printed instead of the booked count when nothing was booked.

    :param booked_redirects: mapping ``(day, billing_order_id)`` -> Counter of booked redirects.
    :param all_redirects: mapping ``(day, billing_order_id)`` -> Counter of all redirects.
    """
    logger.info('Redirects:')
    nothing_booked = Counter()  # shared read-only fallback
    ordered_keys = sorted(all_redirects, key=lambda key: (key[1], key[0]))
    for key in ordered_keys:
        day, billing_order_id = key
        totals = all_redirects[key]
        booked_counts = booked_redirects.get(key, nothing_booked)
        for redirect_type in sorted(totals):
            logger.info('\t%s, %s, %s: %s / %s', billing_order_id, day.strftime('%Y-%m-%d'), redirect_type,
                        booked_counts.get(redirect_type, '-'), totals[redirect_type])


def calculte_conversions(yt_client, today, partner_whitelist, conversion_window):
    """Calculate the weighted conversion for every redirect type.

    Steps:

    1. get the per-partner intervals and join them for reading the logs;
    2. fetch the booked markers and count all / booked redirects per
       ``(day, billing_order_id)`` and redirect type;
    3. sum the counters over partners by "day number" (0 is the last day of a
       partner's interval, 1 the day before it, ...);
    4. compute the conversion of every day number and combine them with the
       weights ``exp_decay(i) / normalization_factor(conversion_window)``.

    :param yt_client: YT client used to read the redirect log.
    :param today: current date.
    :param partner_whitelist: billing order ids allowed to take part.
    :param conversion_window: number of days taken into account.
    :return: ``defaultdict(float)`` redirect type -> weighted conversion.
    """
    logger.info('Conversion log: %s', CONVERSION_LOG_ROOT)
    logger.info('Processing %s', today.strftime('%Y-%m-%d'))

    yql_client = YqlClient(token=settings.YQL_TOKEN)
    interval_by_billing_order_id = get_partner_intervals(yql_client, partner_whitelist, today, conversion_window)
    log_partner_intervals(interval_by_billing_order_id)

    joined_intervals = join_intervals(interval_by_billing_order_id.values())
    log_joined_intervals(joined_intervals)

    booked_markers = get_bookings(yql_client, interval_by_billing_order_id)

    # Count redirects per (day, partner) and redirect type.
    picker = RedirectTypePicker()
    booked_redirects = defaultdict(Counter)
    all_redirects = defaultdict(Counter)
    for joined_interval in joined_intervals:
        for redirect in get_redirects(yt_client, joined_interval, interval_by_billing_order_id):
            redirect_type = picker.get_from_balance_log(redirect)
            key = (redir_date(redirect), redirect['BILLING_ORDER_ID'])
            all_redirects[key][redirect_type] += 1
            if redirect['MARKER'] in booked_markers:
                booked_redirects[key][redirect_type] += 1
    log_redirects(booked_redirects, all_redirects)

    # Aggregate over partners by day number, counted back from the end of each interval.
    sum_redirects = defaultdict(Counter)
    sum_bookings = defaultdict(Counter)
    for billing_order_id, interval in interval_by_billing_order_id.items():
        days_newest_first = list(daterange(interval.start, interval.end))
        days_newest_first.reverse()
        for day_number, day in enumerate(days_newest_first):
            key = (day, billing_order_id)
            if key in all_redirects:
                sum_redirects[day_number] += all_redirects[key]
                sum_bookings[day_number] += booked_redirects.get(key, Counter())
            else:
                logger.info('billing_order_id %s and date %s not in redirects',
                            billing_order_id, day.strftime('%Y-%m-%d'))

    date_conversions = {
        day_number: compute_conversions(sum_redirects[day_number], sum_bookings.get(day_number, {}))
        for day_number in range(conversion_window)
    }

    # Weighted sum over day numbers; weights are normalized over the whole window.
    conversions = defaultdict(float)
    norm_factor = normalization_factor(conversion_window)
    for day_number in date_conversions:
        decay = exp_decay(day_number)
        weight = decay / norm_factor
        logger.info('Date number %s weight = %s. exp_decay = %s, norm_factor = %s',
                    day_number, weight, decay, norm_factor)
        for redirect_type, conversion in date_conversions[day_number].iteritems():
            logger.info('Redirect %s date number %s conversion = %s', redirect_type, day_number, conversion)
            conversions[redirect_type] += conversion * weight
    return conversions


def log_conversions(conversions):
    """Log every ``redirect type: conversion`` pair of the mapping."""
    logger.info('Conversions:')
    for pair in conversions.items():
        # pair is (redirect_type, conversion)
        logger.info('\t%s: %s', *pair)


def main():
    """Script entry point: calculate conversions, store them in the DB and dump them to YT.

    Command line options:

    * ``-v/--verbose`` - also log to stdout;
    * ``--skip-writing-to-db`` - do not save the models to the database
      (the YT dump is still written);
    * ``--conversion-window`` - number of days used for the calculation;
    * ``--output-yt-path`` - YT directory for the resulting table.

    Any exception raised during the calculation is logged and re-raised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--skip-writing-to-db', action='store_true', help='do not save conversions to db')
    parser.add_argument('--conversion-window', default=7, type=int,
                        help='1 - for calculation of conversion during 1 day, 7 - for week')
    parser.add_argument('--output-yt-path', default=CONVERSION_LOG_ROOT, help='Yt path to store result')
    options = parser.parse_args()

    if options.verbose:
        add_stdout_handler(logger)
    create_current_file_run_log()

    logger.info('Start')
    partner_whitelist = get_billing_order_ids()
    try:
        today = date.today()
        yt_client = yt_client_fabric.create()

        conversions = calculte_conversions(yt_client, today, partner_whitelist, options.conversion_window)
        log_conversions(conversions)

        cr_models = convert_to_models(conversions, MINIMAL_CONVERSION, datetime.now())
        if not options.skip_writing_to_db:
            save_conversions_to_db(cr_models)
        else:
            logger.info('Skip writing to database')

        # The YT dump is written regardless of --skip-writing-to-db.
        save_conversions_to_yt(today, cr_models, yt_client, options.output_yt_path)

        logger.info('Conversions estimated')
    except Exception:
        logger.exception('ERROR')
        raise

    logger.info('End')
