# -*- coding: utf-8 -*-

import logging
from re import compile as re_compile
from urlparse import (
    urlparse,
    urlsplit,
)

from passport.backend.core.logging_utils.helpers import mask_sessionid
from passport.backend.social.broker.exceptions import (
    AccessTokenError,
    ApplicationUnknownError,
    ConsumerUnknownError,
    HostInvalidError,
    PkceCodeInvalidError,
    PkceMethodInvalidError,
    PkceVerifierInvalidError,
    ProviderUnknownError,
    RetpathInvalidError,
    SessionInvalidError,
    TaskIdInvalidError,
    UserIpInvalidError,
)
from passport.backend.social.broker.misc import (
    check_retpath,
    check_url,
    RE_TLD,
    SSO_PASSPORT_PUSH_GRAMMAR,
)
from passport.backend.social.common import validators
from passport.backend.social.common.context import request_ctx
from passport.backend.social.common.misc import (
    split_scope_string,
    urlparse_qs,
)
from passport.backend.social.common.pkce import (
    fix_pkce_method_case,
    is_valid_pkce_method,
    PKCE_METHOD_PLAIN,
)
from passport.backend.social.common.provider_settings import providers
from passport.backend.social.common.providers.Facebook import Facebook
from passport.backend.social.common.social_config import social_config
from passport.backend.social.common.task import build_provider_for_task
from passport.backend.social.common.useragent import Url


logger = logging.getLogger('social.broker.args_processor')

# Lower-cased strings that ArgsProcessor.get_bool() interprets as True.
TRUE_VALUES = ['1', 'y', 'true', 'yes', 't']

# Inclusive length bounds for a PKCE code_challenge; the upper bound also
# caps the length of a code_verifier.
PKCE_MIN_LENGTH, PKCE_MAX_LENGTH = 43, 128

# Runs of characters that are stripped from list-valued GET parameters:
# everything except ASCII letters, digits, "_", "," and "-".
PARAM_FORBIDDEN_LETTERS_RE = re_compile(r'[^A-Za-z0-9_,-]+')

# Hosts (yandex.com and www.yandex.com) whose retpath gets "redirect=0"
# appended by ArgsProcessor.fix_morda_retpath().
MORDA_COM_HOSTNAME = re_compile(r'^(www\.yandex\.com|yandex\.com)$')


class ArgsProcessor(object):
    """Validate and normalize incoming broker request parameters.

    Raw values are read from ``args`` (GET parameters) and ``form`` (POST
    parameters); normalized results are accumulated in ``processed_args``.
    The ``process_*`` methods raise broker exceptions on invalid input.
    """

    # Provider codes for which the first name / last name is not required.
    _FIRSTNAME_OPTIONAL_PROVIDER_CODES = frozenset(['tg'])
    _LASTNAME_OPTIONAL_PROVIDER_CODES = frozenset(['tw', 'tg'])

    def __init__(self, args, form, log_actions=True):
        """
        :param args: GET parameters; needs ``.get`` (and item assignment for
            ``fix_not_yandex_facebook_app``).
        :param form: POST parameters; needs ``.get``.
        :param log_actions: when False, the debug logging guarded by this
            flag is suppressed.
        """
        self.args = args
        self.form = form
        self.log_actions = log_actions

        self.processed_args = dict()

    @classmethod
    def _name_optionality(cls, provider_code):
        """Return ``(is_firstname_optional, is_lastname_optional)`` for a provider code."""
        return (
            provider_code in cls._FIRSTNAME_OPTIONAL_PROVIDER_CODES,
            provider_code in cls._LASTNAME_OPTIONAL_PROVIDER_CODES,
        )

    def _process_provider_info(self, provider_info):
        """Build the task provider from ``provider_info`` into processed_args['provider']."""
        self.processed_args['provider'] = build_provider_for_task(
            code=provider_info['code'],
            name=provider_info['name'],
            id=provider_info['id'],
        )

        # Disable the first-name and/or last-name requirement for some providers.
        # Exact set membership is used: a bare ``('tg')`` is a string, and
        # ``in`` on it would match substrings such as 't' or ''.
        is_firstname_optional, is_lastname_optional = self._name_optionality(provider_info['code'])
        if is_firstname_optional:
            self.processed_args['provider']['is_firstname_optional'] = is_firstname_optional
        if is_lastname_optional:
            self.processed_args['provider']['is_lastname_optional'] = is_lastname_optional

    def _log_selected_application(self):
        """Debug-log the selected provider (when set) and application name."""
        if self.log_actions:
            if self.processed_args['provider']:
                logger.debug(
                    'Selected provider_name = "%s", app_name = "%s"',
                    self.processed_args['provider'],
                    self.processed_args['application'].name,
                )
            else:
                logger.debug('Selected app_name = "%s"', self.processed_args['application'].name)

    def get_bool(self, field, from_form=False):
        """Store in processed_args[field] whether the parameter is a TRUE_VALUES string."""
        if not from_form:
            args = self.args
        else:
            args = self.form
        value = args.get(field, '').strip()
        value = value.lower() in TRUE_VALUES
        self.processed_args[field] = value

    def process_yandex_token(self):
        """Store the stripped POST ``token`` as processed_args['oauth_token']."""
        token = self.form.get('token', '').strip()
        self.processed_args['oauth_token'] = token

    def process_frontend_url(self):
        """Store the required POST ``frontend_url``, normalized to end with '/'.

        :raises SessionInvalidError: if the parameter is missing or blank.
        """
        frontend_url = self.form.get('frontend_url', '').strip()
        if not frontend_url:
            raise SessionInvalidError('"frontend_url" POST parameter has to be passed')
        if not frontend_url.endswith('/'):
            frontend_url += '/'
        if self.log_actions:
            logger.debug('frontend_url = "%s"', frontend_url)
        self.processed_args['frontend_url'] = frontend_url

    def process_scope(self):
        """Store the GET ``scope`` split by split_scope_string()."""
        scope = self.args.get('scope')
        self.processed_args['scope'] = split_scope_string(scope)

    def process_provider_token(self):
        """
        Store provider_token and the optional provider_token_secret (the social
        network's token) together with the raw POST ``scope`` string.

        NOTE(review): unlike process_scope(), ``scope`` is stored here as an
        unsplit string; confirm callers expect that.

        :raises AccessTokenError: if provider_token is empty.
        """
        self.processed_args['provider_token'] = self.form.get('provider_token', '').strip()
        self.processed_args['provider_token_secret'] = self.form.get('provider_token_secret', '').strip()
        self.processed_args['scope'] = self.form.get('scope', '').strip()

        if not self.processed_args['provider_token']:
            raise AccessTokenError('provider_token is empty')

    def process_sid(self):
        """Store the GET ``sid`` as an int, or None when it is not an integer."""
        sid = self.args.get('sid', '').strip()
        try:
            sid = int(sid)
        except ValueError:
            sid = None
        self.processed_args['sid'] = sid

    def process_place(self):
        """Store the GET ``place`` ('query' by default); unknown values become ''."""
        place = self.args.get('place', 'query').strip()
        available = ['fragment', 'query']
        if place and place not in available:
            place = ''
        if self.log_actions:
            logger.debug('place = "%s"', place)
        self.processed_args['place'] = place

    def create_retry_url(self, frontend_url, task_id):
        """Store ``<frontend_url><task_id>/retry`` as processed_args['retry_url']."""
        if self.log_actions:
            logger.debug('Creating a retry_url...')

        self.processed_args['retry_url'] = '%s%s/retry' % (frontend_url, task_id)
        if self.log_actions:
            logger.debug('retry_url = "%s"', self.processed_args['retry_url'])

    def process_hostname_and_tld(self, frontend_url):
        """Extract and store the hostname and TLD of ``frontend_url``.

        :raises HostInvalidError: if there is no hostname, no TLD matched by
            RE_TLD, or a backslash appears in the netloc/username/password.
        """
        parsed = urlparse(frontend_url)
        hostname = parsed.hostname
        if not hostname:
            raise HostInvalidError('Can not get hostname from "%s"' % frontend_url)
        if self.log_actions:
            logger.debug('hostname = "%s"', hostname)

        res = RE_TLD.match(hostname)
        if res:
            self.processed_args['tld'] = res.groups()[0]
        else:
            # Format the message; passing hostname as a second argument would
            # leave the "%s" placeholder unfilled.
            raise HostInvalidError('Can not get a valid TLD from host "%s"' % hostname)

        if any('\\' in i for i in (parsed.netloc, parsed.username, parsed.password) if i):
            raise HostInvalidError('Prohibited symbol "\\" has been found in "%s"' % frontend_url)

        self.processed_args['hostname'] = hostname
        if self.log_actions:
            logger.debug('hostname=%s', hostname)

    def process_consumer(self):
        """Store the required GET ``consumer``.

        :raises ConsumerUnknownError: if it is missing or blank.
        """
        consumer = self.args.get('consumer', '').strip()
        if not consumer:
            raise ConsumerUnknownError('"consumer" GET parameter has to be passed')

        self.processed_args['consumer'] = consumer

        if self.log_actions:
            logger.debug('consumer = "%s"', consumer)

    def process_application_name(self, provider_is_required=True):
        """Look up the application by GET ``application_name`` and store it and its provider.

        :raises ApplicationUnknownError: if the name is missing or unknown.
        :raises ProviderUnknownError: if the application has no provider and
            ``provider_is_required`` is true.
        """
        app_name = self.args.get('application_name', '').strip()
        if not app_name:
            raise ApplicationUnknownError('Application missing')
        app = providers.get_application_by_name(app_name)
        if not app:
            raise ApplicationUnknownError("Application doesn't exist")
        self.processed_args['application'] = app
        request_ctx.application = app

        if not app.provider and provider_is_required:
            raise ProviderUnknownError('Provider of the application is unknown')

        if app.provider:
            self._process_provider_info(app.provider)
        else:
            self.processed_args['provider'] = None

        self._log_selected_application()

    def process_app_and_provider(self, tld):
        """
        Resolve the application and provider objects.

        If neither is given, it is an error.

        If only the provider is given, its default application (for ``tld``)
        is used.

        If only the application is given, the provider is taken from it.

        If both are given, the application identifier may be either in the
        provider's namespace or in the Socialism namespace. The application
        must belong to the given provider, otherwise it is an error.
        """
        provider_code = self.args.get('provider') or ''
        provider_code = provider_code.strip()

        app_name = self.args.get('application') or ''
        app_name = app_name.strip()

        if not provider_code and not app_name:
            raise SessionInvalidError('"provider" or "application" GET parameter required')

        # Resolve provider_info from the request parameters.
        provider_info = None
        if provider_code:
            if provider_code not in providers.providers:
                raise ProviderUnknownError("Provider doesn't exist")
            provider_info = providers.providers[provider_code]

        # Resolve the application from the request parameters.
        if app_name:
            app = providers.get_application_by_name(app_name)
            if not app and provider_info:
                key = (provider_info['id'], app_name)
                app = providers.get_application_by_provider_app_id(*key)

            if not app:
                raise ApplicationUnknownError("App doesn't exist")
            if provider_code:
                # An application without a provider cannot match the requested one;
                # guard so we raise the domain error instead of a TypeError.
                if not app.provider or app.provider['code'] != provider_code:
                    raise ApplicationUnknownError('Inconsistent data: provider_code=%s and application=%s' %
                                                  (provider_code, app_name))
            else:
                # Take provider_info from the application.
                provider_info = app.provider
        else:
            # No application requested: pick one by provider_code (guaranteed non-empty here).
            app = providers.get_application_for_provider(provider_code, tld)
            if not app:
                raise ApplicationUnknownError("Failed to get application for provider=%s, tld=%s" % (provider_code, tld))

        self.processed_args['application'] = app
        request_ctx.application = app

        if not provider_info:
            raise ProviderUnknownError('Provider of the application is unknown')

        self._process_provider_info(provider_info)
        self._log_selected_application()

    def process_retpath(self):
        """Validate and store the required GET ``retpath``.

        An SSO passport push URL is unwrapped to its inner ``retpath`` query
        parameter before validation with check_retpath().

        :raises RetpathInvalidError: if missing or invalid.
        """
        retpath = self.args.get('retpath', '').strip()
        if not retpath:
            raise RetpathInvalidError('"retpath" GET parameter required')

        if self._is_sso_passport_push_url(retpath):
            sso_retpath = urlsplit(retpath)
            sso_retpath = urlparse_qs(sso_retpath.query).get('retpath')
            if not sso_retpath:
                raise RetpathInvalidError('Invalid netloc: "%s"' % retpath)
            retpath = sso_retpath[0]

        check_retpath(retpath, social_config.allowed_retpath_schemes)

        self.processed_args['retpath'] = retpath
        if self.log_actions:
            logger.debug('retpath = "%s"', retpath)

    def _is_sso_passport_push_url(self, url):
        """Return True if check_url() accepts ``url`` under SSO_PASSPORT_PUSH_GRAMMAR."""
        try:
            check_url(
                [SSO_PASSPORT_PUSH_GRAMMAR],
                invalid_hosts=None,
                url=url,
            )
            return True
        except RetpathInvalidError:
            return False

    def fix_morda_retpath(self):
        """Append ``redirect=0`` to a processed retpath whose host is (www.)yandex.com."""
        retpath = self.processed_args.get('retpath')
        if not retpath:
            return
        parsed_url = Url(retpath)
        # Guard against a missing hostname: re.match() rejects None.
        if not MORDA_COM_HOSTNAME.match(parsed_url.hostname or ''):
            return

        if self.log_actions:
            logger.debug('Fix yandex.com retpath')
        parsed_url.add_params([('redirect', '0')])
        self.processed_args['retpath'] = str(parsed_url)

    def process_display(self):
        """Store the stripped GET ``display``."""
        self.processed_args['display'] = self.args.get('display', '').strip()

    def process_session_id(self, enable_logging=True):
        """Store the POST ``Session_id``; only its masked form is logged."""
        self.processed_args['Session_id'] = self.form.get('Session_id', '').strip()
        if enable_logging:
            logger.debug('%s = "%s"', 'Session_id', mask_sessionid(self.processed_args['Session_id']))

    def process_user_ip(self, enable_logging=True, required=True):
        """Store the POST ``user_ip``.

        :raises UserIpInvalidError: if required and missing.
        """
        self.processed_args['user_ip'] = self.form.get('user_ip', '').strip()
        if required and not self.processed_args['user_ip']:
            raise UserIpInvalidError('"user_ip" POST parameter has to be passed')
        if enable_logging:
            logger.debug('%s = "%s"', 'user_ip', self.processed_args['user_ip'])

    def process_yandexuid(self, enable_logging=True):
        """Store the stripped POST ``yandexuid``."""
        self.processed_args['yandexuid'] = self.form.get('yandexuid', '').strip()
        if enable_logging:
            logger.debug('%s = "%s"', 'yandexuid', self.processed_args['yandexuid'])

    def process_yandex_auth_code(self):
        """Store the GET ``yandex_auth_code``; absent or literal 'null' becomes None."""
        code = self.args.get('yandex_auth_code')
        if code is not None:
            code = code.strip()
        code = None if code == 'null' else code
        self.processed_args['yandex_auth_code'] = code

    def process_pkce(self, required=True):
        """Validate and store the PKCE code_challenge and code_challenge_method.

        When the challenge is absent and not required, both are stored as None.

        :raises PkceCodeInvalidError: if missing while required, or its length
            is outside [PKCE_MIN_LENGTH, PKCE_MAX_LENGTH].
        :raises PkceMethodInvalidError: if the method is not recognized.
        """
        code = self.args.get('code_challenge', '').strip()
        method = self.args.get('code_challenge_method', PKCE_METHOD_PLAIN).strip()
        if not code:
            if required:
                raise PkceCodeInvalidError()
            code = None
            method = None
        else:
            if not (PKCE_MIN_LENGTH <= len(code) <= PKCE_MAX_LENGTH):
                raise PkceCodeInvalidError()
            if not is_valid_pkce_method(method):
                raise PkceMethodInvalidError()
            method = fix_pkce_method_case(method)
        self.processed_args['code_challenge'] = code
        self.processed_args['code_challenge_method'] = method

    def process_task_id(self):
        """Validate the POST ``task_id`` with validators.TaskId.

        :raises TaskIdInvalidError: if validation fails.
        """
        task_id = self.form.get('task_id', '').strip()
        try:
            self.processed_args['task_id'] = validators.TaskId().to_python(task_id)
        except validators.Invalid:
            raise TaskIdInvalidError()

    def process_pkce_verifier(self, required=True):
        """Store the POST ``code_verifier``.

        NOTE(review): only the maximum length is enforced here; RFC 7636 also
        specifies a minimum of 43 characters — confirm whether that is intended.

        :raises PkceVerifierInvalidError: if missing while required or longer
            than PKCE_MAX_LENGTH.
        """
        verifier = self.form.get('code_verifier', '').strip()
        if (required and not verifier) or len(verifier) > PKCE_MAX_LENGTH:
            raise PkceVerifierInvalidError()
        self.processed_args['code_verifier'] = verifier

    def process_login_hint(self):
        """Store the GET ``login_hint``, or None when blank."""
        login_hint = self.args.get('login_hint', '').strip()
        self.processed_args['login_hint'] = login_hint or None

    def process_ui_language(self):
        """Store the POST ``ui_language``, or None when blank."""
        ui_language = self.form.get('ui_language', '').strip()
        self.processed_args['ui_language'] = ui_language or None

    def process_passthrough_errors(self):
        """Store the normalized comma-separated GET ``passthrough_errors`` list."""
        self._process_list('passthrough_errors')

    def process_experiments(self):
        """Store the normalized comma-separated GET ``experiments`` list."""
        self._process_list('experiments')

    def _process_list(self, arg_name):
        """Parse a comma-separated GET parameter into a sorted, de-duplicated,
        lower-cased list (forbidden characters stripped), or None when empty."""
        values = self.args.get(arg_name, '').strip()
        if values:
            values = PARAM_FORBIDDEN_LETTERS_RE.sub('', values)
            values = values.split(',')
            values = {e.strip().lower() for e in values}
            values = sorted([e for e in values if e])
        self.processed_args[arg_name] = values or None

    def process_query(self):
        """Parse the POST ``query`` string into a dict keeping the first value per key.

        An unparsable query (ValueError) yields an empty dict.
        """
        query = self.form.get('query', '').strip()
        try:
            query = urlparse_qs(query)
        except ValueError:
            query = dict()
        for key in query:
            query[key] = query[key][0]
        self.processed_args['query'] = query

    def process_flags(self):
        """Store the GET ``flags`` as a set (None when empty)."""
        self._process_set('flags')

    def _process_set(self, arg_name):
        """Like _process_list(), but converts a non-empty result to a set."""
        self._process_list(arg_name)
        if self.processed_args[arg_name]:
            self.processed_args[arg_name] = set(self.processed_args[arg_name])

    def fix_not_yandex_facebook_app(self):
        """Replace a listed non-Yandex Facebook client id in ``self.args['application']``
        with the general Facebook client id.

        Mutates ``self.args`` in place; only applies when GET ``provider`` is the
        Facebook code.
        """
        provider_code = self.args.get('provider') or ''
        provider_code = provider_code.strip()

        app_name = self.args.get('application') or ''
        app_name = app_name.strip()

        if not (provider_code == Facebook.code and app_name in social_config.not_yandex_facebook_client_ids):
            return

        if self.log_actions:
            logger.debug('Fix %s application', social_config.not_yandex_facebook_client_ids[app_name])
        self.args['application'] = social_config.general_facebook_client_id

    def process_user_param(self):
        """Store the GET ``user_param``, or None when blank."""
        self.processed_args['user_param'] = self.args.get('user_param', '').strip() or None


def _does_look_like_native_app_id(provider_id, app_id):
    return (
        provider_id in {1, 2} and app_id.isdigit() or
        provider_id == 5 and app_id.endswith('.apps.googleusercontent.com')
    )
