import travel.avia.admin.init_project  # noqa

import argparse
import logging
from collections import defaultdict
from datetime import datetime, timedelta

import yt.wrapper as yt
from django.conf import settings

from travel.avia.admin.lib.logs import create_current_file_run_log, add_stdout_handler
from travel.avia.admin.lib.statface_helpers import post_data_to_stat
from travel.avia.admin.lib.yt_helpers import yt_client_fabric, safe_tables_for_daterange


logger = logging.getLogger(__name__)
# Runs at import time.
# NOTE(review): presumably sets up a log file for this script run -- confirm in travel.avia.admin.lib.logs.
create_current_file_run_log()


# Values of settings.ENVIRONMENT in which main() agrees to run (unless --skip-env-check is passed).
ALLOWED_ENVS = ['production', 'dev']
# YT path handed to safe_tables_for_daterange() to pick the redirect-log tables for the date range.
REDIR_LOG = '//home/avia/logs/avia-json-redir-log'
# Report name used by send_partner_stat() (rows keyed by pp, billing_client_id and redirect_type).
PARTNER_REPORT = 'ticket.yandex/Redirects/BySourceAndPartner'
# Report name used by send_aggregated_stat() (rows keyed by pp and redirect_type).
AGGREGATED_REPORT = 'ticket.yandex/Redirects/BySource'

# Date format of command-line dates. Not referenced by the code visible here:
# _parse_date() spells the same format inline.
_DATE_FORMAT = '%Y-%m-%d'
# Subtracted from today's date in main() when --yesterday is given.
_ONE_DAY = timedelta(days=1)


def _parse_date(datestr):
    """Convert a ``YYYY-MM-DD`` string into a ``datetime.date``.

    A falsy value (empty string, ``None``) means "today". Both branches
    return a plain ``date``: the fallback must not return a full ``datetime``,
    otherwise callers would receive different types depending on the input.

    :param datestr: date in ``%Y-%m-%d`` format, or a falsy value.
    :return: the parsed date, or today's date when *datestr* is falsy.
    :raises ValueError: if *datestr* is non-empty and does not match the format.
    """
    if not datestr:
        return datetime.today().date()
    return datetime.strptime(datestr, '%Y-%m-%d').date()


def parse_args():
    """Declare the script's command-line options and parse the process arguments.

    Options:
      -v / --verbose     boolean flag
      --yesterday        boolean flag
      --left-date        converted by _parse_date, defaults to today's date
      --right-date       converted by _parse_date, defaults to today's date
      --skip-env-check   boolean flag

    :return: the namespace returned by ``ArgumentParser.parse_args()``.
    """
    cli = argparse.ArgumentParser()

    cli.add_argument('-v', '--verbose', action='store_true')
    cli.add_argument('--yesterday', action='store_true')

    # Both range boundaries share the same converter and default to "today".
    for date_option in ('--left-date', '--right-date'):
        cli.add_argument(date_option, type=_parse_date, default=datetime.today().date())

    cli.add_argument('--skip-env-check', action='store_true')

    namespace = cli.parse_args()
    return namespace


def get_redirect_type(record):
    """Classify a redirect-log record into a redirect type label.

    Rules are checked in priority order and the first match wins:
      1. truthy 'wizardRedirKey'                                -> 'wizard_direct'
      2. 'utm_source' starts with 'wizard' or 'unisearch'       -> 'wizard_indirect'
      3. truthy 'stid'                                          -> 'sovetnik_direct'
      4. 'utm_source' == 'rasp' and 'utm_medium' == 'redirect'  -> 'rasp'
      5. anything else                                          -> 'other'

    :param record: mapping with a ``get`` method (one log row).
    :return: one of the labels listed above.
    """
    # A missing or None utm_source is treated as an empty string.
    source = record.get('utm_source') or ''

    if record.get('wizardRedirKey'):
        label = 'wizard_direct'
    elif source.startswith(('wizard', 'unisearch')):
        label = 'wizard_indirect'
    elif record.get('stid'):
        label = 'sovetnik_direct'
    elif source == 'rasp' and record.get('utm_medium') == 'redirect':
        label = 'rasp'
    else:
        label = 'other'

    return label


def fix_unixtime(unixtime):
    """Round a unix timestamp (in seconds) down to a multiple of one hour.

    :param unixtime: number of seconds.
    :return: *unixtime* minus its remainder modulo 3600.
    """
    seconds_in_hour = 60 * 60
    seconds_past_hour = unixtime % seconds_in_hour
    return unixtime - seconds_past_hour


def get_redir_stat(yt_client, tables):
    """Count redirect-log rows per hour, pp, partner and redirect type.

    Every table is read through ``yt_client.read_table`` in JSON format.
    Rows with a falsy 'pp' or a falsy 'billing_client_id' are skipped.
    The 'unixtime' field is read from every row (before the skip checks),
    so a row without it raises ``KeyError``.

    :param yt_client: object exposing ``read_table(table, format=...)``.
    :param tables: iterable of tables to read.
    :return: nested defaultdict  time -> pp -> partner -> type -> count,
        where time is the row's unixtime rounded down by fix_unixtime().
    """
    # Named factories for each nesting level of the counter tree.
    def _type_counts():
        return defaultdict(int)

    def _partner_level():
        return defaultdict(_type_counts)

    def _pp_level():
        return defaultdict(_partner_level)

    counters = defaultdict(_pp_level)

    for table_path in tables:
        rows = yt_client.read_table(table_path, format=yt.JsonFormat())
        for row in rows:
            hour = fix_unixtime(row['unixtime'])

            pp_value = row.get('pp')
            client_id = row.get('billing_client_id')
            if not (pp_value and client_id):
                continue

            kind = get_redirect_type(row)
            counters[hour][pp_value][client_id][kind] += 1

    return counters


def aggregate_by_partner(redir_stat):
    """Sum redirect counts over partners and add a per-platform total.

    The input mapping is not modified. Counts are summed key by key, so a
    'total' key already present in the innermost dicts of *redir_stat* would
    be treated as a regular type and inflate the result.

    :param redir_stat: mapping  time -> pp -> partner -> type -> count
    :return: nested defaultdict  time -> pp -> type -> count; each time/pp
        bucket additionally holds a 'total' key with the sum of its
        per-type counts.
    """
    # time -> pp -> type -> count
    aggregated = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

    # items()/values() are used instead of iteritems()/itervalues(): the
    # latter exist only on Python 2, the former work on both 2 and 3.
    for time, pp_stat in redir_stat.items():
        for pp, partners in pp_stat.items():
            for type_stat in partners.values():
                for redirect_type, count in type_stat.items():
                    aggregated[time][pp][redirect_type] += count

    # add sum for each platform
    for pp_stat in aggregated.values():
        for type_counts in pp_stat.values():
            # The sum is fully computed before the 'total' key is inserted.
            type_counts['total'] = sum(type_counts.values())

    return aggregated


def send_partner_stat(redir_stat):
    """Flatten per-partner redirect counts into rows and post them to PARTNER_REPORT.

    One row is produced for every (time, pp, partner, type) leaf of
    *redir_stat*. The data is posted with ``scale='h'`` and with ``beta``
    set whenever ``settings.ENVIRONMENT`` is not 'production'.

    :param redir_stat: mapping  time -> pp -> partner -> type -> count
    """
    # items() instead of iteritems(): the latter exists only on Python 2,
    # while items() works on both Python 2 and 3.
    records = [
        {
            'fielddate': time,
            'pp': pp,
            'billing_client_id': billing_client_id,
            'redirect_type': redirect_type,
            'count': count,
        }
        for time, pp_stat in redir_stat.items()
        for pp, partners in pp_stat.items()
        for billing_client_id, partner_stat in partners.items()
        for redirect_type, count in partner_stat.items()
    ]

    post_data_to_stat(
        PARTNER_REPORT,
        records,
        scale='h',
        beta=settings.ENVIRONMENT != 'production',
    )


def send_aggregated_stat(redir_stat):
    """Flatten partner-aggregated redirect counts into rows and post them to AGGREGATED_REPORT.

    One row is produced for every (time, pp, type) leaf of *redir_stat*.
    The data is posted with ``scale='h'`` and with ``beta`` set whenever
    ``settings.ENVIRONMENT`` is not 'production'.

    :param redir_stat: mapping  time -> pp -> type -> count
    """
    # items() instead of iteritems(): the latter exists only on Python 2,
    # while items() works on both Python 2 and 3.
    records = [
        {
            'fielddate': time,
            'pp': pp,
            'redirect_type': redirect_type,
            'count': count,
        }
        for time, pp_stat in redir_stat.items()
        for pp, stat_by_type in pp_stat.items()
        for redirect_type, count in stat_by_type.items()
    ]

    post_data_to_stat(
        AGGREGATED_REPORT,
        records,
        scale='h',
        beta=settings.ENVIRONMENT != 'production',
    )


def main():
    """Script entry point: collect redirect statistics from YT and post both reports.

    Steps:
      1. Parse the command line; with --verbose, attach a stdout handler to the logger.
      2. Pick the date range: yesterday only (--yesterday) or --left-date..--right-date.
      3. Stop early unless settings.ENVIRONMENT is in ALLOWED_ENVS or
         --skip-env-check was given.
      4. Read the redirect-log tables for the range and count the records.
      5. Build the partner-aggregated view, add per-partner totals to the
         detailed view, and send both.
    """
    args = parse_args()
    if args.verbose:
        add_stdout_handler(logger)

    logger.info('Start')

    if args.yesterday:
        left_date = right_date = datetime.today().date() - _ONE_DAY
    else:
        left_date = args.left_date
        right_date = args.right_date

    if not (args.skip_env_check or settings.ENVIRONMENT in ALLOWED_ENVS):
        logger.info('Can work only in %s', ', '.join(ALLOWED_ENVS))
        return

    yt_client = yt_client_fabric.create()
    tables = safe_tables_for_daterange(yt_client, REDIR_LOG, left_date, right_date)

    # time -> pp -> partner -> type -> count
    redir_stat = get_redir_stat(yt_client, tables)

    # time -> pp -> type -> count
    # Must be computed before the per-partner 'total' keys are added below,
    # otherwise those totals would be summed in as if they were a redirect type.
    redir_aggregated = aggregate_by_partner(redir_stat)

    # add sum for each partner
    # values() instead of itervalues(): the latter exists only on Python 2.
    for pp_stat in redir_stat.values():
        for partners in pp_stat.values():
            for partner_stat in partners.values():
                # The sum is fully computed before the 'total' key is inserted.
                partner_stat['total'] = sum(partner_stat.values())

    send_partner_stat(redir_stat)
    send_aggregated_stat(redir_aggregated)

    logger.info('End')
