#!/usr/bin/python
# -*- coding: utf-8 -*-
from argparse import ArgumentParser
from datetime import (
    date,
    timedelta,
)
import json
import logging
import logging.config
import os
import sys
import traceback

from passport.backend.ch_stat_loader.ch_stat_loader import settings
from passport.backend.ch_stat_loader.ch_stat_loader.am_notifications import get_AMnotificator
from passport.backend.ch_stat_loader.ch_stat_loader.am_reports_config import AM_REPORTS
from passport.backend.ch_stat_loader.ch_stat_loader.query import (
    APIKeysQuery,
    get_login_sdk_frequent_reporters,
    QuotaExceededClickhouseException,
    set_api_keys,
    set_client_id_mapping_getter,
)
from passport.backend.ch_stat_loader.ch_stat_loader.report import (
    is_report_published,
    Report,
)
from passport.backend.ch_stat_loader.ch_stat_loader.utils import (
    from_str_to_date,
    retriable_n,
)
from passport.backend.ch_stat_loader.ch_stat_loader.yt_utils import run_exclusively
import requests
from statface_client import StatfaceClient


# Module-level logger for the runner; logging itself is configured in main()
# through logging.config.dictConfig(settings.LOGGING_CONFIG).
log = logging.getLogger('ch_stat_loader.runner')


class Runner(object):
    def __init__(self, reports):
        self.reports = reports

    def build_queries(self):
        joinable_queries = {}
        target_queries = []
        for report in self.reports:
            for query in report.get_queries():
                query_joined = False
                if query.can_be_joined:
                    join_key = query.join_key
                    if join_key in joinable_queries:
                        joinable_queries[join_key].join(query)
                        query_joined = True
                    else:
                        joinable_queries[join_key] = query
                if not query_joined:
                    target_queries.append(query)

        return target_queries

    def run(self):
        exceptions = []
        for query in sorted(self.build_queries(), key=lambda q: q.priority, reverse=True):
            try:
                results = query.execute()
                query.notify_reports(results)
            except Exception as e:
                exceptions.append(sys.exc_info())
                if isinstance(e, QuotaExceededClickhouseException):
                    break

        for report in self.reports:
            try:
                report.publish()
            except Exception:
                exc_info = sys.exc_info()
                exceptions.append((
                    exc_info[0],
                    '{}\nin report \'{}\''.format(exc_info[1], report.path),
                    exc_info[2],
                ))

        if len(exceptions) == 0:
            return
        elif len(exceptions) == 1:
            raise exceptions[0][0], exceptions[0][1], exceptions[0][2]
        else:
            exception_strings = [
                '{}{}: {}'.format(''.join(traceback.format_tb(ex[2])), ex[0].__name__, ex[1])
                for ex in exceptions
            ]
            raise Exception('Multiple exceptions while collecting reports:\n\n{}'.format('\n\n'.join(exception_strings)))


def parse_args(argv):
    """Parse the loader's command line.

    argv is the full argument vector including the program name (normally
    ``sys.argv``); its first element is skipped.

    Returns the ``argparse.Namespace`` with ``date_start``, ``date_end``,
    ``skip_api_keys_update``, ``force_recalculate``, ``send_notifications``,
    ``notifications``, ``reports``, ``only_show_paths`` and
    ``publish_production``. By default the date range is
    [yesterday, today).
    """
    # Read "today" once so that both defaults refer to the same day even when
    # the process starts right around midnight.
    today = date.today()
    yesterday = today - timedelta(days=1)

    parser = ArgumentParser()
    parser.add_argument(
        '--date_start',
        type=lambda string: from_str_to_date(string).date(),
        metavar='YYYY-MM-DD',
        nargs='?',
        # With nargs='?' a bare flag yields ``const``; without it the value
        # would be None and the date arithmetic in main() would fail.
        const=yesterday,
        default=yesterday,
    )
    parser.add_argument(
        '--date_end',
        type=lambda string: from_str_to_date(string).date(),
        metavar='YYYY-MM-DD',
        nargs='?',
        const=today,
        default=today,
    )
    parser.add_argument('--skip_api_keys_update', action='store_true')
    parser.add_argument('--force_recalculate', action='store_true')
    parser.add_argument('--send_notifications', action='store_true')
    parser.add_argument(
        '--notifications',
        nargs='+',
        default=[],
        metavar='notification-name',
        help='Notifications to process. All notifications will be processed if argument is absent.',
    )
    parser.add_argument(
        'reports',
        metavar='REPORT_PATH',
        # Normalised the same way main() normalises generated report paths.
        type=lambda string: string.strip('/').lower(),
        nargs='*',
        help='path to target reports (without host)',
    )
    parser.add_argument(
        '--only_show_paths',
        action='store_true',
        help='Prints paths to target reports (without host). Does not do anything else.',
    )
    parser.add_argument(
        '-pp',
        '--publish_production',
        action='store_true',
        help='publish to Statface production instead of beta',
    )
    return parser.parse_args(argv[1:])


def ensure_yt_settings_path_exist(yt):
    """Create the settings document in YT if it does not exist yet.

    A freshly created document gets an empty ``api-keys`` list and an empty
    ``client-id-mapping`` dict as attributes. An existing node is left as is.
    """
    settings_path = settings.YT_SETTINGS_PATH
    if yt.exists(settings_path):
        return

    yt.create('document', settings_path, recursive=True)
    for attribute, initial_value in (('api-keys', []), ('client-id-mapping', {})):
        yt.set_attribute(settings_path, attribute, initial_value)


def prepare_client_id_mapping(yt, target_date):
    """Load the client_id -> name mapping from YT and fill in missing names.

    Names for client ids returned by get_login_sdk_frequent_reporters for
    target_date that are not yet in the stored mapping are fetched from the
    OAuth client info endpoint. The completed mapping is written back to the
    ``client-id-mapping`` attribute and returned with UTF-8 encoded values.
    """
    ensure_yt_settings_path_exist(yt)
    stored_mapping = dict(yt.get_attribute(settings.YT_SETTINGS_PATH, 'client-id-mapping'))
    # Work with decoded text in memory; values are encoded again before saving.
    client_id_mapping = {
        client_id: raw_name.decode('utf-8')
        for client_id, raw_name in stored_mapping.items()
    }
    log.debug('Client id mapping from YT: %s', client_id_mapping)

    @retriable_n(10, 10)
    def get_name_by_client_id(client_id):
        response = requests.get('https://oauth.yandex.ru/client/%s/info' % client_id)
        name = json.loads(response.text)['name']
        log.debug('Loaded name for unknown client_id: %s -> %s', client_id, name)
        return name

    unknown_client_ids = get_login_sdk_frequent_reporters(target_date) - set(client_id_mapping)
    for client_id in unknown_client_ids:
        client_id_mapping[client_id] = get_name_by_client_id(client_id)

    client_id_mapping = {
        client_id: name.encode('utf-8')
        for client_id, name in client_id_mapping.items()
    }
    yt.set_attribute(settings.YT_SETTINGS_PATH, 'client-id-mapping', client_id_mapping)

    return client_id_mapping


def prepare_api_keys(yt, target_date, skip_api_keys_update):
    """Return the set of known API keys, optionally refreshed for target_date.

    The keys stored in the ``api-keys`` YT attribute are merged with
    settings.KNOWN_GOOD_API_KEYS. Unless skip_api_keys_update is true, an
    APIKeysQuery for [target_date, target_date + 1 day) is executed, its keys
    (minus settings.KNOWN_BANNED_API_KEYS) are added, and the sorted result is
    written back to YT.
    """
    ensure_yt_settings_path_exist(yt)
    settings_path = settings.YT_SETTINGS_PATH
    api_keys = set(yt.get_attribute(settings_path, 'api-keys'))
    log.debug('API keys from YT: %s', api_keys)
    api_keys.update(settings.KNOWN_GOOD_API_KEYS)

    if skip_api_keys_update:
        return api_keys

    next_day = target_date + timedelta(days=1)
    query = APIKeysQuery(date_start=target_date, date_end=next_day, report=None)
    first_row = query.execute().get_slice(['api_keys'])[0]
    fresh_keys = set(int(raw_key) for raw_key in first_row['api_keys'])
    fresh_keys -= settings.KNOWN_BANNED_API_KEYS

    new_keys = fresh_keys - api_keys
    if new_keys:
        log.debug('Detected new API keys: %s', new_keys)

    api_keys |= fresh_keys
    yt.set_attribute(settings_path, 'api-keys', sorted(api_keys))

    return api_keys


@run_exclusively(lock_path=settings.YT_CH_STAT_LOADER_LOCK_NODE)
def main():
    """Script entry point: build the requested reports, run and publish them.

    For every day in [date_start, date_end) the API keys are prepared and one
    Report is created per configured report type, skipping reports that were
    not requested on the command line or that are already published (unless
    --force_recalculate is given). With --only_show_paths the report paths are
    printed and nothing is executed. Notifications are sent at the end when
    --send_notifications is given.
    """
    logging.config.dictConfig(settings.LOGGING_CONFIG)
    yt = settings.get_yt()
    args = parse_args(sys.argv)

    statface_host = settings.STATFACE_HOST if args.publish_production else settings.STATFACE_BETA_HOST
    statface_client = StatfaceClient(
        host=statface_host,
        username=settings.STATFACE_ROBOT_LOGIN,
        password=settings.STATFACE_ROBOT_PASSWORD,
    )

    am_notificator = get_AMnotificator(statface_client)
    if args.notifications:
        log.debug('Active notifications: {}'.format(args.notifications))
        am_notificator.set_active_notifications(args.notifications)
    am_notificator.reset_notifications()

    def load_client_id_mapping(target_date):
        # A new YT client is obtained for every call, as in the getter contract.
        return prepare_client_id_mapping(settings.get_yt(), target_date)

    set_client_id_mapping_getter(load_client_id_mapping)

    prefix = 'Passport.All/own/am'
    one_day = timedelta(days=1)
    reports = []
    total_days = (args.date_end - args.date_start).days
    for day_index in range(total_days):
        target_date = args.date_start + timedelta(days=day_index)

        set_api_keys(prepare_api_keys(yt, target_date, args.skip_api_keys_update))

        for config in AM_REPORTS:
            subfolder = config['name']
            for report_config in config['reports']:
                for type_ in report_config['types']:
                    # Custom reports are published under their own name.
                    name = report_config['name'] if type_ == 'custom' else type_
                    path = os.path.join(prefix, subfolder, name)

                    if args.reports and path.strip('/').lower() not in args.reports:
                        log.debug('Skipping %s (%s)', path, target_date)
                        continue
                    if not args.force_recalculate and is_report_published(statface_client, path, target_date):
                        log.debug('Skipping (already published) %s (%s)', path, target_date)
                        continue

                    reports.append(Report(
                        group_title=config['description'],
                        path=path,
                        type_=type_,
                        config=report_config,
                        date_start=target_date,
                        date_end=target_date + one_day,
                        custom_types=report_config.get('custom_types'),
                        statface_client=statface_client,
                    ))

    if args.only_show_paths:
        print('\n'.join(report.path for report in reports))
        return

    runner = Runner(reports)
    runner.run()

    # Drop the references to the runner and reports before sending notifications.
    del runner, reports

    if args.send_notifications:
        am_notificator.send_notifications(args.date_start, args.date_end)


# Run the loader only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
