# encoding: utf-8
from __future__ import unicode_literals

import json
import logging
import os
from base64 import b64encode
from urllib import quote, unquote
from urlparse import urlparse

from ipaddress import ip_address   # FIXME https://paste.yandex-team.ru/10059442
from tornado import gen, httpclient, web
from ylog.context import log_context

from intranet.webauth.lib import settings
from intranet.webauth.lib.auth_request_handler import (
    LOGIN_HEADER,
    UID_HEADER,
)
from intranet.webauth.lib.authorizer import Authorizer
from intranet.webauth.lib.cache_storage import CacheClient
from intranet.webauth.lib.crypto_utils import (
    decrypt_signed_value,
    get_hmac,
)
from intranet.webauth.lib.utils import request_log_context

# Cache for HBF macro responses; check_yamoney_bypass keys it by request URL.
networks_cache_storage = CacheClient()
logger = logging.getLogger(__name__)

# Third argument of networks_cache_storage.set() for HBF responses
# (presumably a lifetime in seconds -- confirm against CacheClient).
NETWORKS_INFO_CACHE_TIME = 3600
# Host and path prefix of the HBF API that expands a macro into networks.
NETWORKS_INFO_HOST = 'hbf.yandex.net'
NETWORKS_INFO_API = '/macros/'
# Used as both connect_timeout and request_timeout of the HBF HTTP request.
NETWORKS_INFO_TIMEOUT = 2


class CheckOAuthTokenHandler(web.RequestHandler):
    """Handler that authorizes a request by its OAuth token.

    The token is taken from the ``Webauth-Authorization`` header or, failing
    that, from the encrypted client token cookie.  Requests that come from a
    whitelisted network to a whitelisted host (see ``check_yamoney_bypass``)
    are accepted without a token.
    """

    # disables response caching
    def compute_etag(self):
        return None

    def get_state_string(self):
        """Return a URL-quoted base64 encoding of 16 fresh random bytes."""
        return quote(b64encode(os.urandom(16)))

    def get_domain(self):
        """Return the netloc of the URL in the ``Webauth-Retpath`` header.

        Returns ``None`` when the header is absent or empty.
        """
        retpath = self.request.headers.get(b'Webauth-Retpath')
        if not retpath:
            return None

        return urlparse(retpath).netloc

    @staticmethod
    def parse_network(raw_network):
        """Parse ``"<address>/<mask>"`` into a ``(packed_ip, packed_mask)`` pair.

        The mask is parsed as an address (e.g. ``255.255.0.0``), not as a
        prefix length.

        Raises:
            ValueError: if the string does not contain exactly one ``/`` or
                either part is not a valid IP address.
        """
        raw_ip, raw_mask = raw_network.split('/')
        return ip_address(raw_ip).packed, ip_address(raw_mask).packed

    @staticmethod
    def masked_ip(ip, mask):
        """Return the byte-wise AND of two packed addresses as a byte string.

        The result is as long as the shorter argument.
        """
        # bytearray yields integers on both Python 2 and 3, unlike iterating
        # a byte string directly.
        return bytes(bytearray(
            x & m for (x, m) in zip(bytearray(ip), bytearray(mask))
        ))

    @classmethod
    def ip_in_network(cls, ip, network):
        """Tell whether packed ``ip`` belongs to ``network``.

        ``network`` is a ``(packed_ip, packed_mask)`` pair as returned by
        ``parse_network``.
        """
        network_ip, mask = network
        if len(ip) != len(mask):
            return False  # IPv4 != IPv6
        return cls.masked_ip(ip, mask) == cls.masked_ip(network_ip, mask)

    @classmethod
    def _parse_networks(cls, data):
        """Parse a decoded HBF response into a list of networks.

        Malformed entries are logged and skipped one by one, so a single bad
        entry does not discard the remaining networks of the macro.
        """
        if not isinstance(data, (list, tuple)):
            logger.error('Unexpected HBF response format (%s)', data)
            return []

        networks = []
        for address in data:
            try:
                networks.append(cls.parse_network(address))
            except (AttributeError, TypeError, ValueError):
                logger.exception('Exception while parsing HBF response (%s)', address)
        return networks

    @gen.coroutine
    def check_yamoney_bypass(self):
        """Check whether the request may skip token authorization.

        The bypass applies when the domain from ``Webauth-Retpath`` has
        macros configured in ``settings.WEBAUTH_NETWORK_WHITELISTS`` and the
        user IP belongs to one of the networks these macros expand to.  The
        expansion is requested from HBF and cached per URL.

        Returns (via ``gen.Return``) ``True`` if the bypass applies, ``False``
        otherwise, including when the user IP or the domain is unknown.
        """
        raw_ip = Authorizer(headers=self.request.headers).get_user_ip()
        if not raw_ip:
            raise gen.Return(False)
        try:
            ip = ip_address(unicode(raw_ip)).packed
        except ValueError:
            # The bypass is only a shortcut: a malformed address must not fail
            # the whole request, regular token authorization still applies.
            logger.warning('Unparseable user IP (%s), network whitelist check skipped', raw_ip)
            ip = None
        if ip is None:
            raise gen.Return(False)

        domain = self.get_domain()
        if not domain:
            raise gen.Return(False)

        rt_macroses = settings.WEBAUTH_NETWORK_WHITELISTS.get(domain)
        if not rt_macroses:
            raise gen.Return(False)

        networks = []
        for macros in rt_macroses:
            url = (
                'https://{}{}{}?trypo_format=dottedquad'.format(
                    NETWORKS_INFO_HOST,
                    NETWORKS_INFO_API,
                    macros,
                )
            )
            data = networks_cache_storage.get(url)
            if data is None:
                request = httpclient.HTTPRequest(
                    url=url,
                    connect_timeout=NETWORKS_INFO_TIMEOUT,
                    request_timeout=NETWORKS_INFO_TIMEOUT,
                    ca_certs=settings.WEBAUTH_ROOT_CERT_LOCATION,
                    validate_cert=settings.WEBAUTH_HTTP_CLIENT_VALIDATE_SSL,
                )
                response = None
                try:
                    http_client = httpclient.AsyncHTTPClient()
                    response = yield http_client.fetch(request)
                    data = json.loads(response.body)
                    networks_cache_storage.set(url, data, NETWORKS_INFO_CACHE_TIME)
                except Exception as err:
                    # Best effort: any HBF failure just leaves this macro out.
                    logger.error('Bad HBF response (%s):  %s', url, response if response else err)
                    continue
            networks.extend(self._parse_networks(data))

        raise gen.Return(any(self.ip_in_network(ip, network) for network in networks))

    @gen.coroutine
    def get(self):
        """Authorize the request and answer with a status code and headers.

        Responds 200 when the whitelist bypass applies or the authorizer
        accepts the token, 401 when no token is supplied, and 403 or 401 when
        the authorizer rejects the request.
        """
        with log_context(**request_log_context(self.request)):
            logger.info('Incoming request')

            yamoney_status = yield self.check_yamoney_bypass()
            if yamoney_status:
                logger.info('Accepted request from whitelisted network to whitelisted host')
                self.set_status(200)
                self.accept()
                raise gen.Return()

            token_header = self.request.headers.get('Webauth-Authorization', default=None)
            if token_header is None:
                # Fall back to the token stored in the encrypted cookie.
                encrypted_token = self.get_cookie(settings.WEBAUTH_CLIENT_TOKEN_COOKIE, default=None)
                if encrypted_token is not None:
                    token = decrypt_signed_value(unquote(encrypted_token))
                    if token is not None:
                        token_header = 'OAuth %s' % token

            if token_header is None:
                self.set_status(401)
                logger.debug('Denied request. (No login) Reasons: No OAuth token in headers or cookies')
                self.decline(['No OAuth token in headers or cookies'])
                raise gen.Return()

            auth = Authorizer.from_request(
                self.request,
                forced_domain='internal',
                scopes_to_check=[settings.WEBAUTH_OAUTH_SCOPE],
                simulated=True,
            )
            auth.query_arguments['required'] = 'token'
            auth.query_arguments.pop('optional', None)
            auth.query_arguments['idm_role'] = self.request.headers.get('Webauth-Idm-Role', default='')
            # The authorizer reads the token from the regular Authorization
            # header, so move it there from the Webauth-specific one.
            auth.headers['Authorization'] = token_header
            auth.headers.pop('Webauth-Authorization', None)
            result, info = yield auth.get()
            if result:
                self.set_status(200)
                self.accept(*info)
            else:
                # NOTE(review): a 3-item info presumably carries
                # (reasons, login, uid), i.e. the user is known but denied --
                # confirm against Authorizer.get().
                self.set_status(403 if len(info) == 3 else 401)
                self.decline(*info)

    def accept(self, login='', uid=''):
        """Finish the response as accepted, exposing login and uid in headers."""
        self.set_header(b'Content-type', 'text/plain')
        self.set_header(LOGIN_HEADER, login)
        self.set_header(UID_HEADER, uid)
        self.set_header(b'Connection', 'Close')

        with log_context(login=login):
            logger.info('Accepted request')

        self.finish(b'Auth completed\n')

    def decline(self, reasons, login='', uid=''):
        """Finish the response as declined.

        Sets the OAuth application id, a fresh CSRF token, a CSRF state built
        from the retpath and the HMAC of that token, and the denial reasons
        joined with ``'; '``.  ``uid`` is accepted but not used.
        """
        self.set_header(b'Webauth-Oauth-App-Id', settings.WEBAUTH_OAUTH_APPLICATION_ID)
        csrf_token = self.get_state_string()
        self.set_header(b'Webauth-Csrf-Token', csrf_token)

        retpath = self.request.headers.get(b'Webauth-Retpath', default='')
        csrf_state = quote(quote(retpath) + ':' + quote(get_hmac(csrf_token)))
        self.set_header(b'Webauth-Csrf-State', csrf_state)

        reasons_string = '; '.join(reasons)
        self.set_header(b'Webauth-Denial-Reasons', reasons_string)

        with log_context(login=login, reasons=reasons):
            logger.info('Denied request')

        self.finish()
