# -*- coding: utf-8 -*-
import logging
from datetime import date, timedelta, datetime
from itertools import chain


from django.core.management import BaseCommand
from django.conf import settings

import yt.wrapper as ytw
import yt.logger as ytlogger


from ._loggers import setup_logger
from travel.avia.admin.lib import yt_helpers as yth


# Log line layout used for the root logging config below and for the
# formatters that the command attaches to its own loggers.
LOGGER_FMT_STR = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
# NOTE(review): not referenced again in the visible part of this module (the
# command builds its own Formatter instances); possibly imported elsewhere.
formatter = logging.Formatter(LOGGER_FMT_STR)
# Import-time side effect: configures the root logger with the same format.
logging.basicConfig(format=LOGGER_FMT_STR)
log = logging.getLogger(__name__)


# Values of settings.ENVIRONMENT in which the command is allowed to run.
ALLOWED_ENVS = ["production", "dev"]


# YT path handed to yth.tables_for_daterange as the location of the source logs.
LOG_PATH = '//home/rasp/logs/avia-suggests-log'
# Schema the output report table is created with.
SCHEMA = [
    {'name': 'count', 'type': 'int64'},
    {'name': 'need_country', 'type': 'boolean'},
    {'name': 'query', 'type': 'string'},
]


class Trie:
    """Character trie that accumulates a count per inserted token.

    Every node is a plain dict that maps a single character to a child node.
    The reserved key ``'_empty'`` (longer than one character, so it cannot
    clash with a character of a string token) holds the count accumulated for
    the token that ends at that node.
    """

    def __init__(self):
        # Root node: stands for the empty prefix and carries no counter.
        self.graph = {}

    def add(self, token, count):
        """Add ``count`` to the counter of ``token``, creating nodes as needed.

        A falsy token (empty string or ``None``) is ignored: the root node has
        no counter, so there is nowhere to store a count for it.
        """
        if not token:
            return

        current_node = self.graph

        for letter in token:
            if letter not in current_node:
                current_node[letter] = {
                    '_empty': 0,
                }
            current_node = current_node[letter]

        current_node['_empty'] += count

    def generate_minimal_empties(self):
        """Yield ``(token, count)`` for every minimal token in the trie.

        A token is minimal when its own counter is positive and none of its
        proper prefixes has a positive counter. Longer tokens below a minimal
        one are not visited, so their counts are not part of the result.
        """
        return self._generate_minimal_empties_from_node(self.graph, '')

    def _generate_minimal_empties_from_node(self, node, path):
        """Yield the minimal tokens found below ``node``.

        ``path`` is the prefix that leads to ``node`` and is prepended to every
        yielded token.

        The depth-first walk keeps an explicit stack instead of recursing, so
        a very long token cannot exhaust the interpreter recursion limit.
        ``items()`` is used rather than ``iteritems()`` so the class works on
        both Python 2 and Python 3.
        """
        # Each stack entry is (prefix of the node, iterator over its children).
        stack = [(path, iter(node.items()))]

        while stack:
            prefix, children = stack[-1]

            for letter, next_node in children:
                if letter == '_empty':
                    # The counter slot, not a child node.
                    continue

                next_path = prefix + letter
                if next_node['_empty'] > 0:
                    # Minimal token found; do not look at its extensions.
                    yield next_path, next_node['_empty']
                else:
                    # Descend now; the parent iterator resumes after the
                    # subtree is finished, which keeps the order depth-first.
                    stack.append((next_path, iter(next_node.items())))
                    break
            else:
                # All children of this node have been handled.
                stack.pop()


@ytw.aggregator
def word_count(records):
    """Mapper: count zero-``count`` log records per (query, needCountry) pair.

    The function consumes the whole ``records`` iterator in one call (it is
    decorated as a YT aggregator) and pre-aggregates in memory, so it emits
    one row per distinct pair instead of one row per log record.

    Records are skipped when ``rawQuery`` is empty or missing (``None``), or
    when ``count`` is non-zero. A ``None`` query must not be emitted: the
    report step iterates over the query's characters and would fail on it.
    NOTE(review): ``count`` is presumably the number of suggests returned for
    the query -- confirm against the log format.

    Yields dicts with keys ``query``, ``need_country`` and ``num``.
    """
    saved = {}
    for r in records:
        raw_query = r['rawQuery']
        if not raw_query or int(r['count']) != 0:
            continue

        key = raw_query, r['needCountry']
        saved[key] = saved.get(key, 0) + 1

    # items() rather than iteritems(): works on both Python 2 and Python 3.
    for (raw_query, need_country), value in saved.items():
        yield {
            'query': raw_query,
            'need_country': need_country,
            'num': value,
        }


def summator(key, records):
    """Reducer/combiner: collapse all records of one key into a single row.

    The emitted row is a copy of ``key`` extended with ``num``, the sum of the
    ``num`` fields of ``records``. ``key`` itself is left untouched.
    """
    merged = dict(key)

    total = 0
    for record in records:
        total += record['num']
    merged['num'] = total

    yield merged


class Command(BaseCommand):
    """Management command that builds the report of empty suggest queries.

    It counts log records with a zero ``count`` over the requested number of
    days with a YT map-reduce, reduces the queries to the minimal ones with
    ``Trie`` and writes them, largest count first, to a new timestamped table
    under ``//home/rasp/suggest_reports``.
    """
    help = 'Building report for empty suggests'

    def add_arguments(self, parser):
        """Register the command-line options.

        ``--stdout`` is forwarded to ``setup_logger`` as the stdout-handler
        flag; ``-d/--days`` is the number of days to process (default 30).
        """
        parser.add_argument(
            '--stdout', action='store_true', default=False,
            dest='add_stdout_handler', help='Add stdout handler',
        )

        parser.add_argument(
            '-d', '--days', type=int, default=30,
            dest='days', help='number of days to process',
        )

    def handle(self, *db_names, **options):
        """Run the report; positional arguments (``db_names``) are not used.

        Any exception raised while building the report is logged with its
        traceback and not re-raised, so the command itself still finishes
        normally.
        """
        setup_logger(
            log,
            options.get('verbosity'),
            options.get('add_stdout_handler'),
            formatter=logging.Formatter(LOGGER_FMT_STR),
        )

        setup_logger(
            ytlogger.LOGGER,
            options.get('verbosity'),
            options.get('add_stdout_handler'),
            formatter=logging.Formatter(LOGGER_FMT_STR),
        )

        try:
            log.info('Start')
            current_env = settings.ENVIRONMENT
            if current_env not in ALLOWED_ENVS:
                log.info('Can\'t work in %s', current_env)
                return

            # Take "today" once: two separate calls could land on different
            # days around midnight and produce an inconsistent range that
            # includes the current day.
            today = date.today()
            start_date = today - timedelta(options.get('days'))
            end_date = today - timedelta(days=1)

            yth.configure_wrapper(ytw)

            source_tables = yth.tables_for_daterange(
                ytw, LOG_PATH, start_date, end_date
            )
            # Read only the columns the mapper needs.
            source_tables = [
                ytw.TablePath(input_table, columns=['needCountry', 'count', 'rawQuery'])
                for input_table in source_tables
            ]

            log.info('Count empty queries')

            word_count_table = ytw.create_temp_table()
            ytw.run_map_reduce(
                source_table=source_tables,
                destination_table=word_count_table,
                mapper=word_count,
                reducer=summator,
                reduce_combiner=summator,
                reduce_by=['query', 'need_country'],
            )

            log.info('Building tries')

            records = self._build_report_records(word_count_table)

            output_table_path = '//home/rasp/suggest_reports/{}'.format(
                datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
            )

            log.info('Writing results to %s', output_table_path)

            ytw.create('table', output_table_path, recursive=True, attributes={
                'schema': SCHEMA,
            })
            ytw.set_attribute(output_table_path, 'optimize_for', 'scan')
            # Record on the table how many days of logs the report covers.
            ytw.set_attribute(output_table_path, 'days', options.get('days'))
            ytw.write_table(output_table_path, records)

            log.info('End')

        except Exception as e:
            # Best effort: the failure is only logged, never propagated.
            log.exception('Exception: %r', e)

    def _build_report_records(self, word_count_table):
        """Turn the aggregated counts table into report rows, largest count first.

        Queries are put into two separate tries depending on the truthiness of
        their ``need_country`` flag; only the minimal tokens of each trie are
        reported.
        """
        with_countries_trie = Trie()
        without_countries_trie = Trie()

        for record in ytw.read_table(word_count_table):
            if record['need_country']:
                with_countries_trie.add(record['query'], record['num'])
            else:
                without_countries_trie.add(record['query'], record['num'])

        with_countries_records = (
            {'query': token, 'count': count, 'need_country': True}
            for token, count in with_countries_trie.generate_minimal_empties()
        )

        without_countries_records = (
            {'query': token, 'count': count, 'need_country': False}
            for token, count in without_countries_trie.generate_minimal_empties()
        )

        return sorted(
            chain(with_countries_records, without_countries_records),
            key=lambda x: x['count'],
            reverse=True,
        )
