# -*- coding: utf-8 -*-

from __future__ import (
    absolute_import,
    unicode_literals,
)

from argparse import FileType as ArgParseFileType
import contextlib
import functools
import logging
import time

from passport.backend.core.builders.blackbox.blackbox import Blackbox
from passport.backend.core.conf import settings
from passport.backend.core.dbmanager.manager import safe_execute_queries
from passport.backend.core.dbmanager.transaction_manager import full_transaction
from passport.backend.core.differ.differ import diff
from passport.backend.core.serializers.eav.serialize import serialize_eav
from passport.backend.core.utils.blackbox import get_many_accounts_by_uids
from passport.backend.dbscripts import context
from passport.backend.dbscripts.hide_my_social_display_name import settings as local_settings
from passport.backend.dbscripts.utils import EntryPoint
from passport.backend.utils.common import chop


# Logger used by every function of this script.
log = logging.getLogger('passport.backend.dbscripts.hide_my_social_display_name')


def run(worker_pool, uids, _file, dry_run):
    """Process accounts in parallel, one bounded chunk of UIDs at a time.

    :param worker_pool: pool exposing ``workers_count()`` and ``map_async()``.
    :param uids: UIDs given on the command line; used when non-empty.
    :param _file: open text file with one UID per line; read only when
        ``uids`` is empty. Blank lines are skipped, malformed lines are logged
        and skipped.
    :param dry_run: when true, queries are only logged, not executed.
    """
    if not uids and _file:
        # Lines read from a file are strings like '123\n'; convert them to the
        # same int UIDs that argparse produces for --uids.
        uids = _read_uids(_file)

    started_at = time.time()
    processed_count = 0
    per_request = settings.UIDS_IN_BLACKBOX_REQUEST

    for chunk in chop(uids, worker_pool.workers_count() * per_request):
        # Pool.map_async applies list() to its iterable argument, so it is fed
        # one bounded chunk at a time, split into Blackbox-request-sized pieces.
        async_result = worker_pool.map_async(
            functools.partial(crunch, dry_run),
            chop(chunk, per_request),
        )
        # Poll with a timeout instead of blocking in get(); presumably so the
        # main process stays interruptible (NOTE(review): confirm intent).
        while not async_result.ready():
            async_result.wait(1)
        # Re-raises an exception raised inside a worker.
        async_result.get()

        processed_count += len(chunk)
        elapsed = time.time() - started_at
        # Guard against a zero interval on coarse clocks.
        if elapsed > 0:
            log.info('Performance is %d accounts per second', int(processed_count / elapsed))


def _read_uids(lines):
    """Yield int UIDs from ``lines``, skipping blank lines and logging malformed ones."""
    for line in lines:
        text = line.strip()
        if not text:
            continue
        try:
            yield int(text)
        except ValueError:
            log.error('Invalid uid in input file: %r', text)


def crunch(dry_run, chunk):
    """Worker task: load the accounts of ``chunk`` and update the eligible ones.

    :param dry_run: forwarded to the batch update; true means "only log queries".
    :param chunk: UIDs to process in a single Blackbox request.
    :raises Exception: anything unexpected is logged with its traceback and
        re-raised so the caller sees the failure.
    """
    try:
        found_accounts, _unknown_uids = get_accounts(chunk)
        batch_set_dont_use_displayname_as_public_name(found_accounts, dry_run)
    except Exception:
        log.exception('Unhandled exception')
        raise


def get_accounts(uids):
    """Load accounts for ``uids`` from Blackbox via ``get_many_accounts_by_uids``.

    Requests first name, last name, the dont_use_displayname_as_public_name
    flag, aliases and the display name.

    :returns: the result of ``get_many_accounts_by_uids`` (unpacked by callers
        as a pair); on any error the failure is logged and ``([], set())`` is
        returned instead.
    """
    requested_attributes = [
        'person.firstname',
        'person.lastname',
        'person.dont_use_displayname_as_public_name',
    ]
    userinfo_args = dict(
        attributes=requested_attributes,
        dbfields=[],
        need_aliases=True,
        need_display_name=True,
    )
    try:
        return get_many_accounts_by_uids(uids, Blackbox(), userinfo_args)
    except Exception:
        joined_uids = ', '.join(str(uid) for uid in uids)
        log.error('Error occured while processing %s' % joined_uids, exc_info=True)
        return [], set()


def batch_set_dont_use_displayname_as_public_name(accounts, dry_run):
    """Set ``person.dont_use_displayname_as_public_name`` on eligible accounts.

    An account is eligible when the flag is not set yet, the account is social
    and ``is_display_name_autogenerated`` says its display name is the
    generated one. The EAV queries of all eligible accounts are passed together
    to ``execute_in_transaction``.

    Errors are logged, never raised: a failing account is skipped, and a
    failing batch execution is logged with every affected uid.

    :param accounts: accounts loaded by ``get_accounts``.
    :param dry_run: when true, queries are only logged.
    """
    pending_queries = []
    touched_uids = []

    for acc in accounts:
        with context.set({'uid': acc.uid}):
            try:
                # Short-circuit keeps the original check order: the flag first,
                # then is_social, then the (more expensive) name comparison.
                eligible = (
                    not acc.person.dont_use_displayname_as_public_name
                    and acc.is_social
                    and is_display_name_autogenerated(acc)
                )
                if eligible:
                    with serialize_eav_from_context(acc) as acc_queries:
                        acc.person.dont_use_displayname_as_public_name = True
                    pending_queries.extend(acc_queries)
                    touched_uids.append(acc.uid)
            except Exception:
                log.error('Error occured while processing %s' % acc.uid, exc_info=True)

    if not pending_queries:
        return

    uids_text = ', '.join(map(str, touched_uids))
    try:
        execute_in_transaction(pending_queries, dry_run)
    except Exception:
        log.error('Error occured while processing %s' % uids_text, exc_info=True)
    else:
        log.info('Set dont_use_displayname_as_public_name to %s' % uids_text)


def is_display_name_autogenerated(account):
    """Tell whether a social account's display name is the autogenerated one.

    The autogenerated form is ``s:<profile_id>:<provider>:<name>``. ``<name>``
    is "firstname lastname" when both are set, otherwise whichever one is set.

    :param account: a social account loaded with its display name and the
        ``person.firstname`` / ``person.lastname`` attributes.
    :returns: True if the display name equals the autogenerated form. False if
        it differs, or if neither first nor last name is set (then it cannot
        be proven that the display name was generated).
    :raises NotImplementedError: for non-social accounts, or when the account
        has no display name (e.g. it was not requested from Blackbox).
    """
    if not account.is_social:
        # Only generated names of social accounts can be recognised.
        # Must raise, not return: an exception instance is truthy and would be
        # mistaken for "autogenerated".
        raise NotImplementedError('Only social accounts are supported')

    display_name = account.person.display_name
    if not display_name:
        # Most likely display_name was not requested from Blackbox.
        raise NotImplementedError('display_name is missing')

    name_parts = [
        part
        for part in (account.person.firstname, account.person.lastname)
        if part
    ]
    if not name_parts:
        # Skip: nothing proves the display name was generated.
        return False
    autogenerated_name = ' '.join(name_parts)

    autogenerated_display_name = 's:%s:%s:%s' % (
        display_name.profile_id,
        display_name.provider,
        autogenerated_name,
    )
    # display_name is compared directly against a string; kept as `!=` to rely
    # on exactly the same comparison operator as before.
    if display_name != autogenerated_display_name:
        return False

    return True


@contextlib.contextmanager
def serialize_eav_from_context(model):
    """Collect EAV queries for the changes made to ``model`` inside the ``with`` body.

    Yields an initially empty list. After the body finishes normally, the list
    is filled with ``serialize_eav`` queries for the diff between the snapshot
    taken on entry and the model's current state. If the body raises, the list
    stays empty.
    """
    collected = []
    before = model.snapshot()
    yield collected
    changes = diff(before, model)
    collected.extend(serialize_eav(before, model, changes))


@full_transaction
def execute_in_transaction(eav_queries, dry_run=True):
    """Execute ``eav_queries`` with ``safe_execute_queries`` under ``full_transaction``.

    :param eav_queries: queries produced by ``serialize_eav``.
    :param dry_run: when true (the default), each query's repr is only logged at
        DEBUG level and nothing is executed.
    """
    if not dry_run:
        safe_execute_queries(eav_queries)
        return

    for eav_query in eav_queries:
        log.debug('Dry run query %s' % repr(eav_query))


class Main(EntryPoint):
    """Command-line entry point of the hide_my_social_display_name script.

    Sets ``dont_use_displayname_as_public_name`` for social accounts whose
    display name is still the autogenerated one. UIDs come from ``--uids``,
    or from the ``--from`` file when ``--uids`` is empty.
    """

    # NOTE(review): the three attributes below are presumably consumed by
    # EntryPoint (settings module, lock path, pool size) — confirm there.
    SETTINGS = local_settings
    LOCK_NAME = '/passport/hide_my_social_display_name/global_lock'
    WORKER_POOL_SIZE = 9

    def run(self, args):
        # Calls the module-level run(); method names are not in scope here.
        try:
            run(self._worker_pool, args.uids, args.file, args.dry_run)
        finally:
            # argparse.FileType opens the file but never closes it; the whole
            # file has been consumed once run() returns.
            if args.file is not None:
                args.file.close()

    def get_arg_parser(self):
        parser = super(Main, self).get_arg_parser()
        parser.add_argument(
            '--uids', nargs='*', type=int, default=[], metavar='uid',
            help='UIDs to process; takes precedence over --from',
        )
        parser.add_argument(
            '--from', dest='file', metavar='path', type=ArgParseFileType('r'),
            help='file with one UID per line, used when --uids is empty',
        )
        parser.add_argument(
            '--dry-run', action='store_true', default=False,
            help='only log the queries instead of executing them',
        )
        return parser


# Module-level instance of the script's entry point; presumably referenced by
# the console-script / launcher configuration (NOTE(review): confirm).
main = Main()
