# -*- coding: utf-8 -*-
import csv
import functools
import logging
import os
import requests

from collections import defaultdict
from datetime import datetime

from django.utils import timezone

from celery import Celery

import yt.wrapper as yt


# Module-level logger shared by the task and its helpers.
log = logging.getLogger(__name__)

# Celery application for this module; its configuration is loaded from the
# Django settings object ('django.conf:settings').
celery_app = Celery('avia_tasks')
celery_app.config_from_object('django.conf:settings')


@celery_app.task()
def top_positions_task(data, job):
    log.info(data)

    def logs_tables(path, from_date, to_date):
        """Return the YT tables under *path* named by a date in [from_date, to_date].

        Every table found by ``yt.search`` is considered; its last path
        component must parse as ``YYYY-MM-DD``, otherwise it is skipped.
        Both bounds are inclusive.
        """
        selected = []
        for table_path in yt.search(path, node_type="table"):
            name = table_path.split('/')[-1]
            try:
                day = datetime.strptime(name, '%Y-%m-%d').date()
            except ValueError:
                # Not a daily table (e.g. a differently named node) - ignore it.
                continue

            if not (from_date <= day <= to_date):
                continue
            selected.append(table_path)

        return selected

    def fix_point(point_key, stations_map):
        """Translate a station key (starting with ``'s'``) through *stations_map*.

        Non-station keys, and falsy keys, are returned unchanged. A station key
        that is absent from *stations_map* yields ``None``.
        """
        is_station = bool(point_key) and point_key.startswith('s')
        return stations_map.get(point_key) if is_station else point_key

    def map_top_100(national_version, stations_map, record):
        """YT mapper: emit ``{'direction': 'FROM_TO'}`` for a qualifying search.

        A record qualifies when its ``service`` is ``'ticket'``, its
        ``national_version`` equals *national_version* case-insensitively, it
        has a truthy ``yandexuid``, and both endpoints resolve via ``fix_point``
        (station keys are replaced by the value from *stations_map*).
        """
        origin = fix_point(record.get('from_id'), stations_map)
        destination = fix_point(record.get('to_id'), stations_map)

        # Evaluate every check up front (as before) and only then combine them.
        is_ticket = record.get('service') == 'ticket'
        user_id = record.get('yandexuid')
        same_version = record.get('national_version', '').lower() == national_version.lower()

        if is_ticket and same_version and user_id and origin and destination:
            yield {
                'direction': '%s_%s' % (origin, destination)
            }

    def reduce_top_100(key, records):
        """YT reducer: count how many records share one ``direction`` key."""
        total = sum(1 for _ in records)
        yield {'direction': key['direction'], 'count': total}

    def human_direction(direction):
        """Render a ``FROM_TO`` direction key as ``'<from> - <to>'``.

        Each endpoint is replaced by its title from ``settlements_map`` (the
        key itself when no title is known) and encoded as UTF-8.
        """
        from_key, to_key = direction.split('_')
        titles = [settlements_map.get(k, k) for k in (from_key, to_key)]
        encoded = tuple(title.encode('utf-8') for title in titles)
        return '%s - %s' % encoded

    def build_top_directions():
        """Fill ``tmp_table`` with the number of searches per direction.

        The search-log tables (``yt_tables``) are mapped with ``map_top_100``,
        the result is sorted by ``direction`` and then reduced in place with
        ``reduce_top_100``.
        """
        mapper = functools.partial(map_top_100, data['national_version'], stations_map)

        yt.run_map(
            mapper,
            source_table=yt_tables,
            destination_table=tmp_table,
            format=yt.DsvFormat(),
            spec=dict(data_size_per_job=settings.YT_DATA_SIZE_PER_JOB),
        )

        # The reducer needs its input grouped by the reduce key.
        yt.run_sort(source_table=tmp_table, sort_by='direction')

        yt.run_reduce(
            reduce_top_100,
            tmp_table,
            tmp_table,
            format=yt.DsvFormat(),
            reduce_by='direction',
            spec=dict(data_size_per_job=settings.YT_DATA_SIZE_PER_JOB),
        )

    def get_sorted_top_directions():
        """Return the ``data['quantity']`` most-searched directions.

        Reads ``(count, direction)`` pairs from ``tmp_table`` and sorts them in
        descending order; ties on count are ordered by direction, descending.
        """
        pairs = [
            (int(row['count']), row['direction'])
            for row in yt.read_table(tmp_table, format=yt.JsonFormat(), raw=False)
        ]
        pairs.sort(reverse=True)
        return pairs[:data['quantity']]

    def append_result(result_key, direction, partner_code, position, results):
        """Store *position* at ``results[result_key][direction][partner_code]``.

        Missing intermediate levels are created on demand; the per-direction
        level is a ``defaultdict(float)``, so unknown partners read as 0.0.
        *results* is mutated in place.
        """
        by_direction = results.setdefault(result_key, {})
        by_partner = by_direction.get(direction)
        if by_partner is None:
            by_partner = by_direction[direction] = defaultdict(float)
        by_partner[partner_code] = position

    def save_top_directions(top_directions, shows_count_results, file_name):
        """Write the top-directions report as a ';'-separated CSV.

        Columns are the human-readable direction, the direction key and the
        search count, followed by one column per partner (sorted partner
        codes) holding that partner's show count for the direction. A partner
        with no shows for a direction gets 0.

        :param top_directions: ``(count, direction)`` pairs, already ordered.
        :param shows_count_results: ``{direction: {partner: show_count}}``.
        :param file_name: output path; overwritten if it exists.
        :return: the path of the written file.
        """
        # Every partner shown for at least one direction gets its own column.
        all_partners = sorted({
            partner
            for partners in shows_count_results.values()
            for partner in partners
        })

        # ``with`` guarantees the data is flushed and the file closed before the
        # caller re-opens it to build the e-mail attachment.
        with open(file_name, 'w') as csvfile:
            csv_writer = csv.writer(csvfile, delimiter=';', quotechar='"', quoting=csv.QUOTE_ALL)
            csv_writer.writerow([u'Direction', u'Direction key', u'Searches'] + all_partners)

            for count, direction in top_directions:
                # .get(): directions without any shows are valid and must not
                # depend on shows_count_results being a defaultdict.
                shows = shows_count_results.get(direction, {})
                row = [human_direction(direction), direction, count]
                row.extend(shows.get(partner, 0) for partner in all_partners)
                csv_writer.writerow(row)

        return csvfile.name

    def save_results(results, file_name, date_format):
        """Write average partner positions per period and direction to a CSV.

        One row is written per ``(period, direction)`` for every period key in
        *results* (ascending) and every direction in ``top_directions``. The
        partner columns follow ``partner_codes`` and hold the average position
        with one decimal, or ``'-'`` when there is no (or a zero) value.

        :param results: ``{period_date: {direction: {partner: avg_position}}}``.
        :param file_name: output path; overwritten if it exists.
        :param date_format: ``strftime`` format used for the period column.
        :return: the path of the written file.
        """
        # Fix the partner column order once so the header and every row agree.
        codes = list(partner_codes)

        # ``with`` guarantees the data is flushed and the file closed before the
        # caller re-opens it to build the e-mail attachment.
        with open(file_name, 'w') as csvfile:
            csv_writer = csv.writer(csvfile, delimiter=';', quotechar='"', quoting=csv.QUOTE_ALL)
            header = [u'Date', u'Direction', u'Direction key'] + [partner_codes[code].encode('utf-8') for code in codes]
            csv_writer.writerow(header)

            for event_key in sorted(results):
                by_direction = results[event_key]
                for direction in top_directions:
                    by_partner = by_direction.get(direction, {})

                    cells = []
                    for code in codes:
                        avg_position = float(by_partner.get(code, 0))
                        cell = '%.1f' % avg_position if avg_position else '-'
                        # Decimal comma - presumably for spreadsheets using a
                        # comma decimal separator.
                        cells.append(cell.replace('.', ','))

                    row = [event_key.strftime(date_format), human_direction(direction), direction] + cells
                    csv_writer.writerow(row)

        return csvfile.name

    def get_partner_codes(partner_code, timeout=60):
        """Return ``{code: title}`` for active partners and Dohop vendors.

        Two RASP API endpoints are queried in order:

        * ``/partner/codes/`` - a partner is kept when its
          ``billing_datasource_id_production`` is truthy;
        * ``/dohop_vendors/codes/`` - a vendor is kept when ``enabled`` is
          truthy (a vendor overrides a partner with the same code).

        :param partner_code: when truthy, only this code is considered.
        :param timeout: seconds passed to ``requests.get`` so a stalled API
            cannot block the worker indefinitely.
        :raises requests.HTTPError: if an endpoint answers with a non-2xx status.
        """
        def fetch(url):
            # NOTE(review): TLS certificate verification stays disabled
            # (verify=False) as before; consider enabling it.
            response = requests.get(url, verify=False, timeout=timeout)
            response.raise_for_status()
            return response.json()

        sources = [
            ('%s/partner/codes/' % settings.RASP_API_HOST, 'billing_datasource_id_production'),
            ('%s/dohop_vendors/codes/' % settings.RASP_API_HOST, 'enabled'),
        ]

        partner_codes = {}
        for url, active_flag in sources:
            for partner in fetch(url):
                code = partner['code']

                if partner_code and partner_code != code:
                    continue

                if partner[active_flag]:
                    partner_codes[code] = partner['title']

        return partner_codes

    def show_map(top_directions, national_version, record):
        """YT mapper: keep ticket-show records for the requested directions.

        :param top_directions: container of ``FROM_TO`` direction keys; a
            ``frozenset`` gives O(1) lookups, a list also works.
        :param national_version: compared with the record's ``national`` field
            case-insensitively (the search-log mapper ``map_top_100`` also
            lowercases both sides); a record without ``national`` never matches.
        :param record: one show-log row; reads ``city_from_key``,
            ``city_to_key``, ``partner``, ``position``, ``iso_eventtime`` and
            ``national``.
        """
        direction = '%s_%s' % (
            record['city_from_key'],
            record['city_to_key'],
        )

        if direction not in top_directions:
            return

        if (record.get('national') or '').lower() != national_version.lower():
            return

        event_time = record['iso_eventtime']
        yield {
            'direction': direction,
            'partner': record['partner'],
            'position': record['position'],
            # Prefixes later parsed as '%Y-%m-%d' and '%Y-%m' respectively.
            'day_key': event_time[:10],
            'month_key': event_time[:7],
        }

    def show_day_reduce(event_key, key, records):
        """YT reducer: average ``position`` over one (direction, partner, period) group.

        Used for both the daily and the monthly aggregation; *event_key* names
        the period column (e.g. ``'day_key'`` or ``'month_key'``) that is
        copied from *key* into the output.

        Positions that are missing or not integers are skipped. ``avg_position``
        is the mean formatted with one decimal, or ``0`` when no valid
        position remained.
        """
        positions = []
        for r in records:
            try:
                positions.append(int(r.get('position')))
            except (TypeError, ValueError):
                # TypeError: field absent (int(None)); ValueError: not a number.
                continue

        yield {
            'direction': key['direction'],
            'partner': key['partner'],
            'avg_position': '%.1f' % (float(sum(positions)) / len(positions)) if positions else 0,
            event_key: key[event_key]
        }

    def parse_yt_show_logs(show_log_tables, top_directions):
        """Compute average partner positions per day and per month.

        Show-log records for *top_directions* are mapped with ``show_map``.
        The mapped table is then reduced twice with ``show_day_reduce``: once
        grouped by ``day_key`` and once by ``month_key``.

        :param show_log_tables: YT show-log tables to read.
        :param top_directions: direction keys (``FROM_TO``) to keep.
        :return: ``(day_results, month_results)``, each shaped
            ``{date: {direction: {partner: avg_position}}}`` (built with
            ``append_result``). Month keys are the first day of the month.
            Both are empty when there are no tables or no directions.
        """
        # Nothing can match without logs or without directions: skip YT work.
        if not show_log_tables or not top_directions:
            return {}, {}

        map_tmp_table = yt.create_temp_table()
        day_tmp_table = yt.create_temp_table()
        month_tmp_table = yt.create_temp_table()

        yt.run_map(
            # frozenset: show_map tests membership once per log record.
            functools.partial(show_map, frozenset(top_directions), data['national_version']),
            source_table=show_log_tables,
            destination_table=map_tmp_table,
            format=yt.DsvFormat(),
            spec={"data_size_per_job": settings.YT_DATA_SIZE_PER_JOB},
        )

        # The same mapped table is sorted and reduced once per aggregation period.
        for period_key, destination in (('day_key', day_tmp_table), ('month_key', month_tmp_table)):
            yt.run_sort(
                source_table=map_tmp_table,
                sort_by=['direction', 'partner', period_key],
            )

            yt.run_reduce(
                functools.partial(show_day_reduce, period_key),
                source_table=map_tmp_table,
                destination_table=destination,
                format=yt.DsvFormat(),
                spec={"data_size_per_job": settings.YT_DATA_SIZE_PER_JOB},
                reduce_by=['direction', 'partner', period_key],
            )

        def _read_averages(table, period_key, period_format):
            # Load one reduced table into a nested results dict.
            results = {}
            for record in yt.read_table(table, format=yt.JsonFormat(), raw=False):
                period = datetime.strptime(record[period_key], period_format).date()
                append_result(
                    period,
                    record['direction'],
                    record['partner'],
                    float(record['avg_position']),
                    results,
                )
            return results

        day_results = _read_averages(day_tmp_table, 'day_key', '%Y-%m-%d')
        month_results = _read_averages(month_tmp_table, 'month_key', '%Y-%m')

        return day_results, month_results

    def get_stations_map(timeout=60):
        """Return ``{station_key: settlement_point_key}`` for all airports.

        Data comes from the RASP API ``/station/airports/`` endpoint.

        :param timeout: seconds passed to ``requests.get`` so a stalled API
            cannot block the worker indefinitely.
        :raises requests.HTTPError: on a non-2xx response.
        """
        # NOTE(review): TLS certificate verification stays disabled
        # (verify=False) as before; consider enabling it.
        response = requests.get(
            '%s/station/airports/' % settings.RASP_API_HOST,
            verify=False,
            timeout=timeout,
        )
        response.raise_for_status()
        stations = response.json()

        return {k: v['settlement_point_key'] for k, v in stations.items()}

    def get_settlements_map(timeout=60):
        """Return ``{settlement_key: title}`` for all settlements.

        Data comes from the RASP API ``/settlement/list/`` endpoint.

        :param timeout: seconds passed to ``requests.get`` so a stalled API
            cannot block the worker indefinitely.
        :raises requests.HTTPError: on a non-2xx response.
        """
        # NOTE(review): TLS certificate verification stays disabled
        # (verify=False) as before; consider enabling it.
        response = requests.get(
            '%s/settlement/list/' % settings.RASP_API_HOST,
            verify=False,
            timeout=timeout,
        )
        response.raise_for_status()
        settlements = response.json()

        return {k: v['title'] for k, v in settlements.items()}

    def show_count_redurce(key, record):
        """YT reducer: count the shows in one ``(direction, partner)`` group.

        Despite its singular name, *record* is the iterable of all rows in
        the group.
        """
        shows = sum(1 for _ in record)
        yield {
            'direction': key['direction'],
            'partner': key['partner'],
            'count': shows,
        }

    def build_shows_count(top_directions):
        """Count ticket shows per (direction, partner) over ``show_log_tables``.

        :param top_directions: direction keys (``FROM_TO``) to keep.
        :return: ``defaultdict`` of ``defaultdict(int)`` shaped
            ``{direction: {partner: show_count}}``; missing entries read as 0.
        """
        shows_count_results = defaultdict(lambda: defaultdict(int))

        # Nothing can match without logs or without directions: skip the YT
        # round trip (and the temp table) entirely.
        if not show_log_tables or not top_directions:
            return shows_count_results

        show_tmp_table = yt.create_temp_table()

        yt.run_map(
            # frozenset: show_map tests membership once per log record.
            functools.partial(show_map, frozenset(top_directions), data['national_version']),
            source_table=show_log_tables,
            destination_table=show_tmp_table,
            format=yt.DsvFormat(),
            spec={"data_size_per_job": settings.YT_DATA_SIZE_PER_JOB},
        )

        yt.run_sort(
            source_table=show_tmp_table,
            sort_by=['direction', 'partner'],
        )

        # Reduce in place: the table is overwritten with one row per group.
        yt.run_reduce(
            show_count_redurce,
            source_table=show_tmp_table,
            destination_table=show_tmp_table,
            format=yt.DsvFormat(),
            spec={"data_size_per_job": settings.YT_DATA_SIZE_PER_JOB},
            reduce_by=['direction', 'partner'],
        )

        for record in yt.read_table(show_tmp_table, format=yt.JsonFormat(), raw=False):
            count = int(record['count'])
            shows_count_results[record['direction']][record['partner']] = count

        return shows_count_results

    # START MAIN
    from email.mime.text import MIMEText

    from django.conf import settings
    from django.core.mail.message import EmailMultiAlternatives

    from travel.avia.stat_admin.lib.jobs import get_or_create_job_storage_dir, get_zipped_results

    log.info('Start')

    range_key = '%s_%s' % (data['start_date'], data['end_date'])

    log.info('Dates: %s' % (range_key.replace('_', ' - ')))
    log.info('Email: %s' % (data['mailto']))

    yt.config['proxy']['url'] = settings.YT_PROXY
    yt.config['token'] = settings.YT_TOKEN
    yt.config['read_retries']['enable'] = True
    yt.config['clear_local_temp_files'] = True

    log.info('Check output dir')
    storage_dir = get_or_create_job_storage_dir(data['job'])

    log.info('Prepare data')

    partner_codes = get_partner_codes(data.get('partner'))
    stations_map = get_stations_map()
    settlements_map = get_settlements_map()
    range_key = '%s_%s' % (data['start_date'], data['end_date'])

    yt_tables = logs_tables(
        '//home/rasp/logs/rasp-users-search-log',
        data['start_date'],
        data['end_date'],
    )

    jobs_count = len(yt_tables) * 20

    if jobs_count < 100:
        jobs_count = 100

    show_log_tables = logs_tables(
        '//home/rasp/logs/rasp-tickets-show-log',
        data['start_date'],
        data['end_date'],
    )

    tmp_table = yt.create_temp_table(
        path=settings.TOP_DIRECTIONS_TMP_PATH,
        prefix='rasp_min_price_'
    )

    try:
        log.info('Build top')
        if yt_tables:
            build_top_directions()
            sorted_top_directions = get_sorted_top_directions()
        else:
            sorted_top_directions = []

        top_directions = [d[1] for d in sorted_top_directions]

        shows_count_results = build_shows_count(top_directions)

        log.info('Calculate positions')
        day_results, month_results = parse_yt_show_logs(
            show_log_tables, top_directions
        )

        log.info('Save results positions to %s' % storage_dir)

        day_results_filename = save_results(day_results, os.path.join(storage_dir, 'results_day_%s.csv' % range_key), '%m.%d.%Y')
        month_results_filename = save_results(month_results, os.path.join(storage_dir, 'results_month_%s.csv' % range_key), '%B %Y')
        top_results_filename = save_top_directions(sorted_top_directions, shows_count_results, os.path.join(storage_dir, 'top_directions_%s.csv' % range_key))

        log.info('Send results to email')

        mail = EmailMultiAlternatives(
            subject=u'Топ-%s + места в выдаче %s, %s' % (data['quantity'], range_key.replace('_', ' - '), data['national_version']),
            body=u'Даты: %s\n\nФайлы во вложении:\n\n' % (range_key.replace('_', ' - ')),
            from_email=settings.SERVER_EMAIL,
            to=['%s@yandex-team.ru' % data['mailto']],
        )

        for file_name in [day_results_filename, month_results_filename, top_results_filename]:
            with open(file_name, 'r') as f:
                data_tab = f.read().decode('utf-8')

            attachment = MIMEText(data_tab.encode('cp1251'), 'csv', 'cp1251')
            attachment.add_header(
                'Content-Disposition', 'attachment',
                filename=os.path.basename(file_name)
            )

            mail.attach(attachment)

        try:
            mail.send()

        except Exception:
            log.exception("ERROR")

        log.info('Done')

    except Exception as e:
        log.exception('Error')

        job.success = False
        job.state = 'Failed: %s' % e.message
        job.ended = timezone.now()
        job.save()

        return

    job.ended = timezone.now()
    job.success = True
    job.state = 'Finished'
    job.binary = get_zipped_results(job.id)

    job.save()


# Keep the task available under its legacy name 'data.tasks.top_positions_task'.
# The first line makes newly queued tasks carry the legacy name.
# The second line registers the task under that name as well, so that celery
# can find it when reading such a task from the queue.
# NOTE(review): the registry key on the right-hand side hard-codes this
# module's import path; presumably the lookup fails at import time if the
# module is ever moved or renamed.
top_positions_task.name = 'data.tasks.top_positions_task'
celery_app.tasks['data.tasks.top_positions_task'] = celery_app.tasks['travel.avia.stat_admin.data.tasks.top_positions_task']
