# coding: utf-8

from base64 import urlsafe_b64encode
import getpass
import os
import struct
import time

import paramiko
import requests

from .base import VaultCLIBaseCommand


class OAuthError(Exception):
    """Failure while obtaining an OAuth token.

    The message is stored with an ``'OAuth. '`` prefix so CLI users can see
    which stage failed. Extra positional/keyword arguments are forwarded to
    ``Exception`` unchanged.
    """

    def __init__(self, msg, *args, **kwargs):
        # String concatenation (not formatting) keeps the original contract:
        # a non-str ``msg`` raises TypeError.
        prefixed_message = 'OAuth. ' + msg
        super(OAuthError, self).__init__(prefixed_message, *args, **kwargs)


class OAuthCommand(VaultCLIBaseCommand):
    """
    To obtain the OAuth token for the Vault API via ssh key
    """
    # Token acquisition algorithm: https://wiki.yandex-team.ru/oauth/token/#granttypesshkey
    # The client secret is stored in plain text on purpose: the token is issued
    # to a person (the ssh-key owner), not to an application.

    OAUTH_GET_TOKEN_URL = 'https://oauth.yandex-team.ru/token'
    OAUTH_CLIENT_ID = 'ce68fbebc76c4ffda974049083729982'
    OAUTH_CLIENT_PWD = 'b6cce2b79f784fa38033ff69f5061d53'

    def auth_args(self):
        """Register the CLI arguments that select the ssh key and login."""
        self.add_base_argument('rsa_agent_key_num')
        self.add_base_argument('rsa_agent_key_hash')
        self.add_base_argument('rsa_private_key_file')
        self.add_base_argument('rsa_login')

    def process(self, cli_args, client, debug=False, *args, **kwargs):
        """Fetch an OAuth token for the resolved login and echo it.

        The login is taken from the CLI arguments; if absent it falls back to
        the ``SUDO_USER`` environment variable (presumably so that running
        under sudo uses the invoking user's login) and then to the current
        OS user.
        """
        rsa_login, rsa_auth = self._get_rsa_auth_from_args(cli_args)
        rsa_login = rsa_login or os.environ.get('SUDO_USER') or getpass.getuser()
        code = self._fetch_oauth_token(rsa_auth, rsa_login)
        self.echo(code)

    def _extract_sign(self, signature_string):
        """Return the raw signature blob from an SSH wire-format signature.

        The input (bytes, or a paramiko ``Message`` which is serialized first)
        is a sequence of length-prefixed fields: a 4-byte big-endian length
        followed by that many bytes. The second field is returned; presumably
        the first one is the algorithm name (e.g. ``ssh-rsa``).

        Raises:
            OAuthError: if a length prefix is truncated or fewer than two
                fields are present.
        """
        if isinstance(signature_string, paramiko.message.Message):
            signature_string = signature_string.asbytes()

        parts = []
        offset = 0
        total = len(signature_string)
        # Walk the buffer by offset instead of re-slicing the tail each time,
        # which would copy the remaining data on every field.
        while offset < total:
            if offset + 4 > total:
                raise OAuthError('Malformed ssh signature: truncated length prefix')
            (length,) = struct.unpack('>I', signature_string[offset:offset + 4])
            offset += 4
            parts.append(signature_string[offset:offset + length])
            offset += length

        if len(parts) < 2:
            raise OAuthError('Malformed ssh signature: signature field not found')
        return parts[1]

    def _make_ssh_sign(self, rsa_key, rsa_login, ts):
        """Sign ``<ts><client_id><login>`` with *rsa_key*.

        Returns the raw signature, URL-safe base64 encoded.
        """
        data = '{}{}{}'.format(ts, self.OAUTH_CLIENT_ID, rsa_login)
        if not isinstance(data, bytes):
            # On Python 3 the formatted payload is text; paramiko's
            # sign_ssh_data is documented to take bytes (a file-based RSAKey
            # hands it straight to the cryptography backend).
            data = data.encode('utf-8')
        sign = self._extract_sign(rsa_key.sign_ssh_data(data))
        return urlsafe_b64encode(sign)

    def _fetch_oauth_token(self, rsa_auth, rsa_login):
        """Exchange an ssh-key signature for an OAuth access token.

        Every key yielded by ``rsa_auth()`` is tried in order. A key rejected
        with ``invalid_grant`` is skipped. The first accepted key's
        ``access_token`` is returned.

        Raises:
            OAuthError: on a 5xx response, a non-JSON or unexpected response
                body, any OAuth error other than ``invalid_grant``, or when no
                key produced an accepted signature.
        """
        # One timestamp is shared by all attempts; it is both sent and signed.
        ts = int(time.time())

        for rsa_key in rsa_auth():
            r = requests.post(
                self.OAUTH_GET_TOKEN_URL,
                data={
                    'grant_type': 'ssh_key',
                    'client_id': self.OAUTH_CLIENT_ID,
                    'client_secret': self.OAUTH_CLIENT_PWD,
                    'login': rsa_login,
                    'ts': ts,
                    'ssh_sign': self._make_ssh_sign(rsa_key, rsa_login, ts),
                },
                timeout=30,
            )

            if r.status_code >= 500:
                raise OAuthError('HTTP error {}'.format(r.status_code))

            try:
                rv = r.json()
            except ValueError:
                # e.g. an HTML error page from a proxy instead of the OAuth JSON.
                raise OAuthError(
                    'HTTP error {}: response is not valid JSON'.format(r.status_code)
                )
            if not isinstance(rv, dict):
                raise OAuthError('Unexpected response: {!r}'.format(rv))

            error = rv.get('error')
            if not error:
                if 'access_token' not in rv:
                    raise OAuthError('Response does not contain access_token')
                return rv['access_token']

            if error == 'invalid_grant':
                # The signature made with this key was rejected; try the next one.
                continue
            raise OAuthError(rv.get('error_description', error))

        raise OAuthError('SSH sign is not valid or ssh-keys not found')
