# coding=utf-8
import travel.avia.admin.init_project  # noqa

import argparse
import logging
from collections import Counter, defaultdict
from datetime import datetime, timedelta
from itertools import chain

import yt.wrapper as yt
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist

from travel.avia.library.python.avia_data.models.national_version import NationalVersion
from travel.avia.library.python.common.models.partner import ConversionByRedirectType, Partner
from travel.avia.admin.lib.logs import add_stdout_handler, create_current_file_run_log
from travel.avia.admin.lib.pricing import RedirectTypePicker
from travel.avia.admin.lib.yt_helpers import safe_tables_for_daterange, yt_client_fabric, yt_read_tables

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)

# YT paths of the input logs (redirects and partner bookings).
REDIR_LOG_ROOT = '//home/avia/logs/avia-redir-balance-by-day-log'
BOOKING_LOG_ROOT = '//home/avia/logs/avia-partner-booking-log'

# Computed conversions are dumped to the production path only when the
# Django settings say we run in production; any other environment writes
# to the testing path.
CONVERSION_LOG_ROOT = (
    '//home/avia/logs/avia-conversion-log'
    if settings.ENVIRONMENT == 'production'
    else '//home/avia/testing/logs/avia-conversion-log'
)

# Threshold passed to convert_to_models(): a computed conversion below it
# is not stored as a new value.
MINIMAL_CONVERSION = 1e-5


def get_billing_order_ids():
    """Collect billing order ids of partners used for conversion updates.

    Only partners with ``use_in_update_conversions=True`` are considered.

    :return: set of ``billing_order_id`` values.
    """
    rows = Partner.objects.filter(use_in_update_conversions=True).values_list('billing_order_id')
    # values_list() without flat=True yields one-element tuples; flatten them.
    return set(chain.from_iterable(rows))


def _convert_conversions_to_yt_format(date, conversions):
    """Lazily convert conversion models into row dicts for the YT table.

    :param date: object with ``strftime``; formatted as ``YYYY-MM-DD`` into
        the ``date`` column of every row.
    :param conversions: iterable of objects exposing ``updated_at``,
        ``national_version.code``, ``redirect_type.code`` and ``conversion``.
    :return: generator of dicts, one per model, in input order.
    """
    for model in conversions:
        row = {}
        row['date'] = date.strftime('%Y-%m-%d')
        row['updated_at'] = model.updated_at.strftime('%Y-%m-%d %H:%M:%S')
        row['national_version'] = model.national_version.code
        row['redirect_type'] = model.redirect_type.code
        # The stored value may be a non-float numeric type; the YT column is double.
        row['conversion'] = float(model.conversion)
        yield row


def get_bookings(ytc, date, partner_whitelist, conversion_window=1):
    """Return the markers of paid bookings made with whitelisted partners.

    :param ytc: YT client passed through to the table helpers.
    :param date: first date of the booking-log range.
    :param partner_whitelist: container of allowed ``billing_order_id`` values.
    :param conversion_window: number of days added to ``date`` to get the
        range end handed to ``safe_tables_for_daterange``.
    :return: set of ``marker`` values of matching booking records.
    """
    range_end = date + timedelta(days=conversion_window)
    tables = safe_tables_for_daterange(ytc, BOOKING_LOG_ROOT, date, range_end)
    logger.info('Booking tables: %s', tables)

    markers = set()
    for record in yt_read_tables(ytc, tables):
        if record['billing_order_id'] not in partner_whitelist:
            continue
        if record['status'] != 'paid':
            continue
        markers.add(record['marker'])
    return markers


def get_redirects(ytc, date, partner_whitelist, conversion_window=1):
    """Yield redirect-log records that take part in conversion estimation.

    A record is kept only if its ``FILTER`` field is missing or falsy (a set
    ``FILTER`` marks fraud), its ``NATIONAL_VERSION`` is ``'ru'`` and its
    ``BILLING_ORDER_ID`` is in ``partner_whitelist``.

    :param ytc: YT client passed through to the table helpers.
    :param date: first date of the redirect-log range.
    :param partner_whitelist: container of allowed billing order ids.
    :param conversion_window: number of days added to ``date`` to get the
        range end handed to ``safe_tables_for_daterange``.
    :return: generator of raw log records.
    """
    range_end = date + timedelta(days=conversion_window)
    tables = safe_tables_for_daterange(ytc, REDIR_LOG_ROOT, date, range_end)
    logger.info('Redirect tables: %s', tables)

    for record in yt_read_tables(ytc, tables):
        # Checks short-circuit in the order: fraud flag, national version, partner.
        if (
            not record.get('FILTER')
            and record['NATIONAL_VERSION'] == 'ru'
            and record['BILLING_ORDER_ID'] in partner_whitelist
        ):
            yield record


def save_conversions_to_db(conversions):
    """Persist every conversion model by calling its ``save()`` method.

    :param conversions: iterable of model instances to save, one by one.
    """
    logger.info('Save conversions to db')

    for model in conversions:
        model.save()

    logger.info('All conversion have been written to database')


def save_conversions_to_yt(date, conversions, ytc):
    """Write the conversions to the per-date table under CONVERSION_LOG_ROOT.

    The table is named after ``date`` (``YYYY-MM-DD``) and is created with an
    explicit schema when it does not exist yet.

    :param date: date used both for the table name and the ``date`` column.
    :param conversions: iterable of conversion models to dump.
    :param ytc: YT client used for ``exists``/``create``/``write_table``.
    """
    table_name = date.strftime('%Y-%m-%d')
    table = yt.ypath_join(CONVERSION_LOG_ROOT, table_name)
    logger.info('Dump to YT: %s', table)

    if not ytc.exists(table):
        # Column order matters for the schema, so keep it as an ordered sequence.
        columns = (
            ('date', 'string'),
            ('national_version', 'string'),
            ('updated_at', 'string'),
            ('redirect_type', 'string'),
            ('conversion', 'double'),
        )
        attributes = {
            'schema': [{'name': name, 'type': type_name} for name, type_name in columns],
            'optimize_for': 'scan',
        }
        ytc.create('table', table, recursive=True, attributes=attributes)

    rows = _convert_conversions_to_yt_format(date, conversions)
    ytc.write_table(table, rows)


def compute_conversions(all_redirects, booked):
    """Compute the booked/total ratio for every redirect type.

    :param all_redirects: mapping ``redirect type -> number of redirects``.
    :param booked: mapping ``redirect type -> number of booked redirects``;
        a type missing from it counts as zero bookings.
    :return: dict ``redirect type -> conversion`` (always a float); the
        conversion is ``0.0`` when the redirect count is not positive.
    """
    # ``items()`` (rather than the Python-2-only ``iteritems()``) works on
    # both Python 2 and Python 3 mappings.
    return {
        redirect_type: float(booked.get(redirect_type, 0)) / count if count > 0 else 0.0
        for redirect_type, count in all_redirects.items()
    }


def convert_to_models(conversions, minimal_conversion, timestamp):
    """Turn computed conversions into ``ConversionByRedirectType`` models.

    For every redirect type the existing model of the ``ru`` national version
    is looked up:

    * found and the new conversion is large enough: the model gets the new
      ``conversion`` and ``updated_at`` values;
    * found but the conversion is too small: an error is logged and the model
      is returned unchanged (it keeps its previous values);
    * not found and the conversion is large enough: a new, unsaved model is
      built;
    * not found and the conversion is too small: ``KeyError`` is raised.

    Nothing is saved here; the caller decides what to do with the models.

    :param conversions: mapping ``redirect type -> conversion``.
    :param minimal_conversion: smallest conversion accepted as a new value.
    :param timestamp: value stored in ``updated_at`` of changed/new models.
    :return: list with one model per entry of ``conversions``.
    :raises KeyError: conversion is too small and no previous model exists.
    """
    result = []
    ru_version = NationalVersion.objects.get(code='ru')
    # ``items()`` (rather than the Python-2-only ``iteritems()``) works on
    # both Python 2 and Python 3 mappings.
    for redirect_type, conversion in conversions.items():
        is_large_enough = conversion >= minimal_conversion

        # Only the lookup can legitimately raise ObjectDoesNotExist, so the
        # try block is limited to it.
        try:
            cr_model = ConversionByRedirectType.objects.get(
                national_version=ru_version,
                redirect_type=redirect_type,
            )
        except ObjectDoesNotExist:
            if not is_large_enough:
                raise KeyError(
                    'Conversion {} of {} is too small and previous had not been found'.format(
                        conversion,
                        redirect_type
                    ))

            cr_model = ConversionByRedirectType(
                national_version=ru_version,
                redirect_type=redirect_type,
                updated_at=timestamp,
                conversion=conversion
            )
        else:
            if is_large_enough:
                cr_model.conversion = conversion
                cr_model.updated_at = timestamp
            else:
                # Keep the previous value: the untouched model is still returned.
                logger.error(
                    'Conversion %f of %s is too small, will not be updated',
                    conversion,
                    redirect_type,
                )

        result.append(cr_model)

    return result


def exp_decay(element_number):
    """Return the exponential-decay weight ``1 / 2**element_number`` as a float.

    :param element_number: position of the element; 0 gives weight 1.0 and
        every next position halves it.
    """
    denominator = pow(2.0, element_number)
    return 1.0 / denominator


def normalization_factor(number_of_elements):
    """Return the sum of ``exp_decay(n)`` for ``n`` in ``0 .. number_of_elements - 1``.

    Dividing each individual decay weight by this value makes the weights of
    ``number_of_elements`` consecutive positions add up to one.
    """
    positions = range(number_of_elements)
    return sum(map(exp_decay, positions))


def redir_date(redirect):
    """Return the date of a redirect record.

    Takes the first whitespace-separated token of the record's
    ``ISO_EVENTTIME`` field and parses it with ``_parse_date``.

    :param redirect: redirect-log record (mapping with ``ISO_EVENTTIME``).
    """
    event_time = redirect['ISO_EVENTTIME']
    day_part = event_time.split()[0]
    return _parse_date(day_part)


def iter_dates(date, max_delta):
    """Yield ``abs(max_delta)`` consecutive dates starting at ``date``.

    :param date: first date yielded.
    :param max_delta: number of days to cover; a positive value walks
        forward in time, a negative one walks backward, zero yields nothing.
    """
    direction = -1 if max_delta <= 0 else 1
    # A signed range step gives offsets 0, 1, 2, ... or 0, -1, -2, ...
    for offset in range(0, max_delta, direction):
        yield date + timedelta(days=offset)


def build_conversions(ytc, date, partner_whitelist, days_shift, skip_writing_to_db, conversion_window=1):
    """Estimate conversions per redirect type and store them.

    Redirects are read for the range starting at ``date - days_shift``; a
    redirect counts as booked when its ``MARKER`` is among the booking
    markers. A per-day conversion is computed for every day of the range and
    the days are combined with exponentially decaying weights, the most
    recent day having the largest weight.

    The resulting models are saved to the database (unless
    ``skip_writing_to_db`` is set) and always dumped to the YT table named
    after ``date``.

    :param ytc: YT client.
    :param date: processing date; also names the output YT table.
    :param partner_whitelist: container of allowed billing order ids.
    :param days_shift: number of days subtracted from ``date`` to get the
        first redirect date.
    :param skip_writing_to_db: when true, models are not saved to the database.
    :param conversion_window: length of the redirect range in days; bookings
        are read for one day more.
    """
    logger.info('Start conversion estimation')
    logger.info('Partners in white list = {}'.format(partner_whitelist))
    logger.info('Date = {}, days_shift = {}, conversion_window = {}'.format(
        date, days_shift, conversion_window,
    ))

    redirect_date_start = date - timedelta(days=days_shift)
    # Bookings are read for a range one day longer than the redirect range.
    bookings = get_bookings(ytc, redirect_date_start, partner_whitelist, conversion_window + 1)
    logger.info('Booked: %d', len(bookings))

    # Per-date counters: date -> redirect type -> number of redirects.
    redirect_type_picker = RedirectTypePicker()
    booked_redirects = defaultdict(Counter)
    all_redirects = defaultdict(Counter)
    redirects = get_redirects(ytc, redirect_date_start, partner_whitelist, conversion_window)
    for redirect in redirects:
        redirect_type = redirect_type_picker.get_from_balance_log(redirect)
        current_redirect_date = redir_date(redirect)
        all_redirects[current_redirect_date][redirect_type] += 1
        if redirect['MARKER'] in bookings:
            booked_redirects[current_redirect_date][redirect_type] += 1
    logger.info('Redirects: %s', all_redirects)

    conversions = defaultdict(float)
    norm_factor = normalization_factor(conversion_window + 1)
    # Dates are walked from the newest to the oldest, so index 0 (weight
    # exp_decay(0)) belongs to the most recent date.
    # NOTE: norm_factor assumes all conversion_window + 1 dates are present;
    # when a date is skipped below, the applied weights sum to less than one.
    for i, d in enumerate(reversed(list(iter_dates(redirect_date_start, conversion_window + 1)))):
        if d not in all_redirects:
            logger.info('Date {} not in {}'.format(d, all_redirects.keys()))
            continue
        date_conversions = compute_conversions(
            all_redirects[d],
            booked_redirects.get(d, {}),
        )
        weight = exp_decay(i) / norm_factor
        logger.info('Date {} weight = {}. exp_decay = {}, norm_factor = {}'.format(
            d, weight, exp_decay(i), norm_factor,
        ))
        # ``items()`` (rather than the Python-2-only ``iteritems()``) works
        # on both Python 2 and Python 3 mappings.
        for redir_type, conversion in date_conversions.items():
            logger.info('Redirect {} {} conversion = {}'.format(redir_type, d, conversion))
            conversions[redir_type] += conversion * weight

    logger.info('Conversions: %r', conversions)

    now = datetime.now()
    cr_models = convert_to_models(conversions, MINIMAL_CONVERSION, now)
    if skip_writing_to_db:
        logger.info('Skip writing to database')
    else:
        save_conversions_to_db(cr_models)

    # The YT dump happens regardless of skip_writing_to_db.
    save_conversions_to_yt(
        date,
        cr_models,
        ytc,
    )

    logger.info('Conversions estimated')


def _parse_date(dt):
    """Parse a ``YYYY-MM-DD`` string into a ``date``.

    :param dt: date string; any falsy value (``None``, ``''``) gives ``None``.
    :return: ``datetime.date`` or ``None``.
    :raises ValueError: the string does not match ``%Y-%m-%d``.
    """
    if not dt:
        return None
    parsed = datetime.strptime(dt, '%Y-%m-%d')
    return parsed.date()


def main():
    """Command-line entry point: estimate conversions for one or more dates.

    With ``--start-date`` (and the then mandatory ``--end-date``) every date
    of the inclusive range is processed in order. Without ``--start-date``
    only today's date is processed. Any error is logged and re-raised.

    :raises ValueError: ``--start-date`` is given without ``--end-date``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--target-day-shift',
                        default=7,
                        type=int,
                        help='shift of day for conversion estimation (2 for conversion-window==1)')
    parser.add_argument('--start-date', type=_parse_date, help='start date')
    parser.add_argument('--end-date', type=_parse_date, help='end date')
    parser.add_argument('--skip-writing-to-db', action='store_true', help='do not save conversions to db')
    # type=int is required: the value is used in arithmetic
    # (conversion_window + 1, timedelta(days=...)), and without it a value
    # given on the command line would arrive as a string.
    parser.add_argument('--conversion-window',
                        default=6,
                        type=int,
                        help='1 - for calculation of conversion during 1 day, 6 - for week')

    args = parser.parse_args()

    if args.verbose:
        add_stdout_handler(logger)

    create_current_file_run_log()

    partner_whitelist = get_billing_order_ids()

    try:
        logger.info('Start')
        logger.info('Conversion log: %s', CONVERSION_LOG_ROOT)
        if args.start_date:
            if not args.end_date:
                raise ValueError('Start date is set and end date is not')

            # Process every date of the inclusive [start_date, end_date] range.
            current_date = args.start_date
            while current_date <= args.end_date:
                logger.info('Processing %s', current_date.strftime('%Y-%m-%d'))
                build_conversions(
                    yt_client_fabric.create(),
                    current_date,
                    partner_whitelist,
                    args.target_day_shift,
                    args.skip_writing_to_db,
                    args.conversion_window
                )
                current_date += timedelta(days=1)

        else:
            now = datetime.now()
            # NOTE(review): the shifted date is only logged here;
            # build_conversions() receives today's date and applies the
            # same shift itself.
            date = now.date() - timedelta(days=args.target_day_shift)
            logger.info('Processing %s', date.strftime('%Y-%m-%d'))
            build_conversions(
                yt_client_fabric.create(),
                now.date(),
                partner_whitelist,
                args.target_day_shift,
                args.skip_writing_to_db,
                args.conversion_window
            )

        logger.info('End')
    except:
        # Bare except is deliberate: everything is logged with a traceback
        # and re-raised unchanged, nothing is swallowed.
        logger.exception('ERROR')
        raise
