# -*- coding: utf-8 -*-

import datetime
from functools import wraps
import logging

from flask.globals import request
from passport.backend.social.common.builders.billing import invalidate_billing_cache as _invalidate_billing_cache
from passport.backend.social.common.chrono import now
from passport.backend.social.common.context import request_ctx
from passport.backend.social.common.db.execute import execute as execute_with_retries
from passport.backend.social.common.db.schemas import (
    business_application_map_table as bamt,
    person_table,
    profile_table,
    sub_table,
)
from passport.backend.social.common.db.utils import (
    get_master_engine,
    get_slave_engine,
)
from passport.backend.social.common.grants import (
    filter_allowed_grants as _filter_allowed_grants,
    get_grants_config as _get_grants_config,
)
from passport.backend.social.common.limits import get_qlimits
from passport.backend.social.common.misc import (
    dump_to_json_string,
    parse_userid,
    split_scope_string,
    trim_message,
    USERID_TYPE_BUSINESS,
)
from passport.backend.social.common.provider_settings import (
    get_profile_addresses,
    providers,
)
from passport.backend.social.common.services_settings import services
from passport.backend.social.common.token.domain import Token
from passport.backend.social.common.token.utils import (
    filter_tokens_granted_to_read,
    find_all_tokens_for_profiles,
)
from passport.backend.social.common.useragent import get_http_pool_manager
from passport.backend.social.common.web_service import Response
from sqlalchemy.sql.expression import (
    and_ as sql_and,
    or_ as sql_or,
    select,
)
from werkzeug.urls import url_decode_stream
from werkzeug.wsgi import LimitedStream


# Module-level logger used by the helper functions in this file.
logger = logging.getLogger('social.api.common')


def execute(query):
    """
    Run an SQL query on the slave engine.

    The query is passed to ``execute_with_retries`` together with the engine
    returned by ``get_slave_engine()``; its result is returned unchanged.
    """
    slave_engine = get_slave_engine()
    return execute_with_retries(slave_engine, query)


def executew(query):
    """
    Run an SQL query on the master engine.

    The query is passed to ``execute_with_retries`` together with the engine
    returned by ``get_master_engine()``; its result is returned unchanged.
    """
    master_engine = get_master_engine()
    return execute_with_retries(master_engine, query)


def profile_dict(profile, include=None):
    """
    Build the API dict describing a profile row.

    :param profile: object exposing ``profile_id``, ``provider_id``, ``uid``,
        ``userid``, ``username`` and ``allow_auth`` attributes.
    :param include: optional dict merged on top of the base fields (its keys
        override the base ones on conflict).
    :return: dict with the profile fields; falsy ``uid``, ``userid`` and
        ``username`` values are reported as ``None``.
    """
    provider_info = providers.get_provider_info_by_id(profile.provider_id)

    # A missing (None) username must not blow up on ``.encode``: treat it the
    # same way as an empty one, consistently with ``uid`` and ``userid``.
    username = profile.username
    encoded_username = username.encode('utf-8') if username else None

    p = {
        'profile_id': profile.profile_id,
        "provider": provider_info['name'],
        "provider_code": provider_info['code'],
        "uid": profile.uid or None,
        "userid": profile.userid or None,
        "username": encoded_username or None,
        'allow_auth': bool(profile.allow_auth),
    }

    if include:
        p.update(include)

    return p


def token_dict(token):
    """
    Convert a token record into the dict produced by ``Token.to_json_dict()``.

    The record's ``scope`` string is split with ``split_scope_string`` and
    passed to ``Token`` as ``scopes``; the remaining fields are copied as is.
    """
    scopes = split_scope_string(token.scope)
    copied_fields = (
        'token_id',
        'uid',
        'profile_id',
        'application_id',
        'value',
        'secret',
        'expired',
        'created',
        'verified',
        'confirmed',
    )
    token_kwargs = dict((field, getattr(token, field)) for field in copied_fields)
    token_kwargs['scopes'] = scopes
    return Token(**token_kwargs).to_json_dict()


def person_dict(person):
    """
    Build the API dict for a person record.

    ``birthday`` is converted with ``str()`` unless it is ``None``; the
    record's ``phone`` attribute is exposed under the ``phone_number`` key.
    """
    birthday = person.birthday
    if birthday is not None:
        birthday = str(birthday)

    result = dict(
        profile_id=person.profile_id,
        firstname=person.firstname,
        lastname=person.lastname,
        nickname=person.nickname,
        birthday=birthday,
        gender=person.gender,
        email=person.email,
    )
    result['phone_number'] = person.phone
    return result


def userid_map_dict(userid_map_item, app):
    """
    Build the API dict for a business-to-application userid mapping item.

    :param userid_map_item: object with ``business_id``, ``business_token``,
        ``application_id`` and ``userid`` attributes.
    :param app: application object (``identifier``, ``name``) or a falsy value
        when the application is not known; in that case the ``application``
        name is reported as an empty string.
    """
    # When an application is given it must be the one the item refers to.
    assert not app or app.identifier == userid_map_item.application_id

    copied_fields = ('business_id', 'business_token', 'application_id', 'userid')
    result = dict((field, getattr(userid_map_item, field)) for field in copied_fields)
    if app:
        result['application'] = app.name
    else:
        result['application'] = ''
    return result


def jsonify_code(data, code=200):
    """
    Build an ``application/json`` response with the given status code.

    When ``data`` is a dict carrying a truthy ``error`` entry, the response is
    additionally marked with ``api_status = 'error'`` and ``api_error_code``
    (the error's ``name``), and the response body is logged at debug level.
    """
    response = Response(status=code, mimetype='application/json')
    response.data = dump_to_json_string(data)

    error_info = data.get('error', None) if isinstance(data, dict) else None
    if not error_info:
        return response

    response.api_status = 'error'
    response.api_error_code = error_info.get('name', None)
    logger.debug('Response: %s' % trim_message(response.data, cut=False))
    return response


def make_error_dict(**kwargs):
    """
    Build the error payload for an API error response.

    The payload always carries ``request_id`` (the ``id`` attribute of the
    current request, ``None`` when absent) plus every keyword argument given.
    If no non-empty ``description`` was passed and ``name`` looks like
    ``<parameter>-empty``, a default "parameter is required" description is
    generated.

    :return: dict with the error fields.
    """
    error = {
        'request_id': getattr(request, 'id', None)
    }
    error.update(kwargs)
    if not kwargs.get('description', None):
        name = kwargs.get('name')
        empty_suffix = '-empty'
        if name and name.endswith(empty_suffix):
            # Strip only the trailing suffix so that parameter names which
            # themselves contain a hyphen are reported in full.
            parameter = name[:-len(empty_suffix)]
            error['description'] = 'Parameter `%s` is required' % parameter
    return error


def not_found(name='profile-not-found', description=''):
    """
    Build a 404 error response with the given error name and description.
    """
    return error(name=name, description=description, code=404)


def internal_error(description='', code=500):
    """
    Build an ``internal-exception`` error response (status 500 by default).
    """
    return error(name='internal-exception', description=description, code=code)


def database_error():
    """
    Build the internal error response used for database failures.
    """
    return internal_error(description='Database failed')


def application_unknown_error():
    """
    Build the ``application-unknown`` error response (default error status).
    """
    error_name = 'application-unknown'
    return error(name=error_name)


def provider_unknown_error():
    """
    Build the ``provider-unknown`` error response (default error status).
    """
    error_name = 'provider-unknown'
    return error(name=error_name)


def error(name='', description='', code=400):
    """
    Build a JSON error response.

    The body is ``{"error": <error dict>}`` where the error dict comes from
    ``make_error_dict``; ``code`` is the HTTP status (400 by default).
    """
    payload = {'error': make_error_dict(name=name, description=description)}
    return jsonify_code(payload, code=code)


def get_full_sids_list(sids):
    """
    Combine the default subscriptions with explicit per-profile overrides.

    :param sids: iterable of objects with ``sid`` and ``value`` attributes.
        ``value == 1`` adds the sid to the result, ``value == 0`` removes it.
    :return: list of ``{'sid': <sid>}`` dicts.
    """
    # Work on a copy: the list handed out by ``get_default_sids()`` may be
    # shared, and must not be modified by per-profile overrides.
    resulting_sids = list(services.get_default_sids())

    for subscription in sids:
        subscription_val = subscription.value
        subscription_id = subscription.sid
        if subscription_val == 1:
            if subscription_id not in resulting_sids:
                resulting_sids.append(subscription_id)
        elif subscription_val == 0:
            if subscription_id in resulting_sids:
                resulting_sids.remove(subscription_id)

    return [{'sid': s} for s in resulting_sids]


def get_current_subscription_state(sid_query, sid):
    """
    Look up the stored subscription row and derive the effective state.

    :param sid_query: where-clause selecting the subscription row in
        ``sub_table``.
    :param sid: service id checked against the default subscriptions.
    :return: tuple ``(subscription, is_subscribed_by_default,
        is_currently_subscribed)`` where ``subscription`` is the fetched row
        (or whatever ``fetchone()`` returns when nothing matches) and both
        flags are booleans.
    """
    subscription = execute(select([sub_table]).where(sid_query)).fetchone()

    is_subscribed_by_default = sid in services.get_default_sids()
    if subscription:
        # An explicit row overrides the default.
        is_currently_subscribed = subscription['value'] == 1
    else:
        # No stored row: fall back to the default. Always a real bool, never
        # a leaked ``None``/row object.
        is_currently_subscribed = is_subscribed_by_default
    return subscription, is_subscribed_by_default, is_currently_subscribed


def get_arg(arg):
    """
    Return the request argument ``arg`` or an empty string.

    ``request.values`` is consulted first. If it holds no non-empty value and
    the request stream is a ``LimitedStream``, the stream is URL-decoded and
    the argument is looked up there.
    """
    found = request.values.get(arg, '')
    if found:
        return found

    if not isinstance(request.stream, LimitedStream):
        return ''

    decoded = url_decode_stream(request.stream)
    return decoded.get(arg, '')


def expand_provider(profile):
    """
    Return ``code``, ``id`` and ``name`` of the profile's provider.

    An empty dict is returned when no provider info is found for
    ``profile.provider_id``.
    """
    info = providers.get_provider_info_by_id(
        profile.provider_id,
        include_provider_class=False,
    )
    if not info:
        return {}
    return dict((key, info[key]) for key in ('code', 'id', 'name'))


def batch_include_tokens(profiles):
    """
    Collect readable tokens for the given profiles.

    Tokens are loaded from the slave engine for all profile ids and filtered
    with ``filter_tokens_granted_to_read``.

    :return: dict ``profile_id -> {'tokens': [<token json dict>, ...]}`` with
        an entry (possibly with an empty list) for every given profile.
    """
    ids = [profile.profile_id for profile in profiles]
    readable_tokens = filter_tokens_granted_to_read(
        find_all_tokens_for_profiles(ids, get_slave_engine()),
        filter_allowed_grants_func=filter_allowed_grants,
    )

    result = dict((profile_id, {'tokens': []}) for profile_id in ids)
    for readable in readable_tokens:
        token_list = result[readable.profile_id]['tokens']
        token_list.append(readable.to_json_dict())
    return result


def batch_include_subs(profiles):
    """
    Collect the effective subscriptions for the given profiles.

    Subscription rows are selected for all profile ids in one query, grouped
    per profile and expanded with ``get_full_sids_list``.

    :return: dict ``profile_id -> {'subscriptions': [{'sid': ...}, ...]}``
        with an entry for every given profile.
    """
    ids = [profile.profile_id for profile in profiles]

    joined_tables = profile_table.join(sub_table)
    query = select([sub_table], from_obj=[joined_tables])
    query = query.where(profile_table.c.profile_id.in_(ids))
    rows = execute(query).fetchall()

    result = dict((profile_id, {'subscriptions': []}) for profile_id in ids)
    for row in rows:
        result[row.profile_id]['subscriptions'].append(row)

    # Replace the raw rows with the full sid list (defaults + overrides).
    for profile_id in result.keys():
        entry = result[profile_id]
        entry['subscriptions'] = get_full_sids_list(entry['subscriptions'])

    return result


def batch_include_person(profiles):
    """
    Collect person data for the given profiles.

    :return: dict ``profile_id -> {'person': <person dict>}``; profiles
        without a person row get no entry.
    """
    ids = [profile.profile_id for profile in profiles]
    joined_tables = profile_table.join(person_table)
    query = select([person_table], from_obj=[joined_tables])
    query = query.where(profile_table.c.profile_id.in_(ids))

    rows = execute(query).fetchall()
    return dict((row.profile_id, {'person': person_dict(row)}) for row in rows)


def batch_include_userid_map_and_addresses(profiles):
    """
    Collect ``userid_map`` and ``addresses`` for the given profiles.

    For profiles whose userid is not of the business type, addresses are built
    directly from the profile's own userid and username. For business-type
    userids the matching rows of the business/application map table are
    loaded, reported as ``userid_map``, and addresses are built from the
    userid of the mapping with the smallest ``application_id``.

    :param profiles: iterable of profile rows (``profile_id``, ``provider_id``,
        ``userid``, ``username``, ``uid``).
    :return: dict ``profile_id -> {'userid_map': [...], 'addresses': [...]}``
        with an entry for every given profile.
    """
    output = dict((p.profile_id, {'userid_map': [], 'addresses': []}) for p in profiles)

    # (profile, (business_id, business_token)) pairs for business profiles;
    # remembered so that every userid is parsed only once.
    business_profiles = []

    for profile in profiles:
        userid_type, business_info = parse_userid(profile.userid)
        if userid_type == USERID_TYPE_BUSINESS:
            business_id, business_token = business_info
            business_profiles.append((profile, (business_id, business_token)))
        else:
            output[profile.profile_id]['addresses'] = get_profile_addresses(
                profile.provider_id,
                userid=profile.userid,
                username=profile.username,
                profile_id=profile.profile_id,
                uid=profile.uid,
            )

    if not business_profiles:
        return output

    all_business_tokens = set(business_key for _profile, business_key in business_profiles)
    clauses = (
        sql_and(bamt.c.business_id == business_id, bamt.c.business_token == business_token)
        for business_id, business_token in all_business_tokens
    )
    query = select([bamt]).where(sql_or(*clauses))

    business_mapping_items = execute(query).fetchall()

    apps, _ = providers.get_many_applications_by_ids([i.application_id for i in business_mapping_items])
    app_id_to_app = {a.identifier: a for a in apps}

    # Group the mapping rows by (business_id, business_token), keeping the
    # original row order, instead of rescanning all rows for every profile.
    items_by_business_key = {}
    for item in business_mapping_items:
        business_key = (item.business_id, item.business_token)
        items_by_business_key.setdefault(business_key, []).append(item)

    for profile, business_key in business_profiles:
        userid_map = output[profile.profile_id]['userid_map']
        for item in items_by_business_key.get(business_key, ()):
            app = app_id_to_app.get(item.application_id)
            userid_map.append(userid_map_dict(item, app))

    # Add addresses for every profile with a non-empty userid_map.
    profile_by_profile_id = dict((p.profile_id, p) for p in profiles)

    for profile_id, data in output.iteritems():
        if not data['userid_map']:
            continue
        # The best choice is the userid for meta-application 0.
        # If there is none, use the main application (the one with the
        # smallest application_id).
        userid = min(data['userid_map'], key=lambda x: x['application_id'])['userid']

        profile = profile_by_profile_id[profile_id]
        data['addresses'] = get_profile_addresses(
            profile.provider_id,
            userid=userid,
            profile_id=profile_id,
            uid=profile.uid,
        )

    return output


def batch_expand_provider(profiles):
    """
    Build the ``provider`` expansion for the given profiles.

    :return: dict ``profile_id -> {'provider': <expand_provider result>}``.
    """
    logger.debug("Expanding with 'provider'.")
    return dict(
        (profile.profile_id, {'provider': expand_provider(profile)})
        for profile in profiles
    )


# Maps a name accepted in the `include` request argument to the batch function
# producing that data; each function takes the list of profiles and returns
# a dict keyed by profile_id.
BATCH_INCLUDE_FUNCTIONS = {
    'tokens': batch_include_tokens,
    'subscriptions': batch_include_subs,
    'person': batch_include_person,
    'userid_map_and_addresses': batch_include_userid_map_and_addresses,
}

# Same as above, for names accepted in the `expand` request argument.
BATCH_EXPAND_FUNCTIONS = {
    'provider': batch_expand_provider,
}


def batch_include_for_profile(profiles):
    """
    Gather all requested extra data for the given profiles.

    The comma-separated ``include`` and ``expand`` request arguments select
    entries of ``BATCH_INCLUDE_FUNCTIONS`` and ``BATCH_EXPAND_FUNCTIONS``
    respectively; unknown names are ignored. ``userid_map_and_addresses`` is
    always requested.

    :param profiles: list of profile rows.
    :return: dict ``profile_id -> dict`` with the merged extra data.
    """
    res = dict([(p.profile_id, {}) for p in profiles])

    def update_result(function_dict, updaters, updater_params):
        # Strip before de-duplicating, so that e.g. "tokens, tokens" does not
        # run the same batch function (and its queries) twice.
        update_names = set(name.strip() for name in (updaters or '').split(','))
        update_names.add('userid_map_and_addresses')

        for update_name in update_names:
            update_func = function_dict.get(update_name)
            if not update_func:
                continue
            update_results = update_func(*updater_params)
            for profile_id, data in update_results.iteritems():
                res[profile_id].update(data)

    update_result(BATCH_INCLUDE_FUNCTIONS, get_arg('include'), (profiles,))
    update_result(BATCH_EXPAND_FUNCTIONS, get_arg('expand'), (profiles,))
    return res


def get_profiles_from_db_batch(*args, **kwargs):
    """
    Load profiles and return them as API dicts with the requested includes.

    Arguments are forwarded to ``_select_profile_rows_from_database``; rows of
    unknown providers are dropped before the dicts are built.
    """
    selected_rows = _select_profile_rows_from_database(*args, **kwargs)
    valid_rows = _profile_rows_to_valid_profile_rows(selected_rows)
    profile_dicts = _profile_rows_to_profile_dict_with_includes(valid_rows)
    return profile_dicts


def _select_profile_rows_from_database(query=None, clause=tuple()):
    """
    Fetch profile rows from the slave database.

    :param query: ready-made query to run; when ``None`` a query over
        ``profile_table`` is built from ``clause``, ordered by ``profile_id``
        and limited by the ``profiles`` query limit.
    :param clause: where-clause used only when ``query`` is ``None``.
    :return: the fetched rows, or an empty list when nothing was fetched.
    """
    if query is None:
        filtered = select([profile_table]).where(clause)
        ordered = filtered.order_by(profile_table.c.profile_id)
        query = ordered.limit(get_qlimits()['profiles'])

    rows = execute(query).fetchall()
    if rows:
        return rows
    return list()


def _profile_rows_to_valid_profile_rows(profiles):
    """
    Keep only the profile rows whose provider is known.

    Rows for which ``providers.get_provider_info_by_id`` returns nothing are
    skipped with a debug log message.
    """
    valid_rows = []
    for row in profiles:
        if providers.get_provider_info_by_id(row.provider_id):
            valid_rows.append(row)
            continue
        logger.debug('Unknown provider profile ignored: provider_id = %s' % row.provider_id)
    return valid_rows


def _profile_rows_to_profile_dict_with_includes(profiles):
    """
    Convert profile rows to API dicts, merging in the requested extra data.

    Returns an empty list right away when there are no rows, without running
    any of the batch include functions.
    """
    if not profiles:
        return list()

    extras_by_profile_id = batch_include_for_profile(profiles)
    return [
        profile_dict(row, extras_by_profile_id.get(row.profile_id) or dict())
        for row in profiles
    ]


def required_args(**params):
    """
    Decorator factory declaring required, typed request arguments.

    Each keyword maps an argument name to a converter callable. The wrapped
    view first merges its keyword arguments into ``request.values``, then
    looks every declared argument up there:

    * a missing argument yields a ``<name>-empty`` error response;
    * a converter raising ``ValueError`` yields a ``<name>-invalid`` error
      response;
    * otherwise the converted values are passed to the view as keyword
      arguments.
    """
    def decorator(f):
        @wraps(f)
        def _wrapper(*args, **kwargs):
            request_values = request.values
            request_values.update(kwargs)

            converted = {}
            for arg_name, convert in params.items():
                raw_value = request_values.get(arg_name, None)
                if raw_value is None:
                    return error(
                        name="%s-empty" % arg_name,
                        description='Parameter `%s` is required' % arg_name,
                    )

                try:
                    converted[arg_name] = convert(raw_value)
                except ValueError:
                    return error(
                        name="%s-invalid" % arg_name,
                        description='Parameter `%s` has bad format' % arg_name,
                    )

            kwargs.update(converted)
            return f(*args, **kwargs)

        return _wrapper

    return decorator


def get_timestamp(name='verified', default_current_timestamp=False):
    """
    Read a unix timestamp from the request form and return it as a datetime.

    With ``default_current_timestamp = True`` a missing or zero value is
    treated as the current time. A value of ``1`` is always treated as the
    current time.

    :return: ``datetime.datetime`` or ``None`` when no value is available.
    :raises ValueError: with the message ``<name>-invalid`` when the value is
        not an integer.
    """
    raw_value = request.form.get(name)
    if raw_value is None:
        timestamp = None
    else:
        try:
            timestamp = int(raw_value)
        except ValueError:
            raise ValueError("%s-invalid" % name)

    if timestamp in {0, 1}:
        # It is unclear who needs to pass 0 and 1 here, so such requests are
        # logged.
        logger.debug('get_timestamp(%s) called and value == %d', name, timestamp)

    if timestamp == 1:
        timestamp = now.i()
    elif default_current_timestamp and not timestamp:
        timestamp = now.i()

    if timestamp is None:
        return None
    return datetime.datetime.fromtimestamp(timestamp)


def invalidate_billing_cache(uid, fail_safe=True):
    """
    Invalidate the billing cache for ``uid`` using the shared HTTP pool manager.

    ``fail_safe`` is forwarded unchanged to the underlying implementation.
    """
    http_pool_manager = get_http_pool_manager()
    return _invalidate_billing_cache(http_pool_manager, uid, fail_safe)


def get_grants_config():
    """
    Return the grants config provided by the common grants module.
    """
    grants_config = _get_grants_config()
    return grants_config


def filter_allowed_grants(grants):
    """
    Filter ``grants`` down to those allowed in the current request context.

    The grants config is loaded first, then the filtering is delegated to the
    common implementation together with ``request_ctx.grants_context``.
    """
    config = get_grants_config()
    config.load()
    grants_context = request_ctx.grants_context
    return _filter_allowed_grants(config, grants_context, grants)
