#!/usr/bin/python

import hmac
import os
import time
from hashlib import sha256

import requests
import ticket_parser2 as tp2
from ticket_parser2.low_level import ServiceContext as TVMServiceContext
from ticket_parser2.exceptions import TicketParsingException as TVMTicketParsingException

# TVM application id of the vmagent service (presumably what callers pass as
# client_id / src_id / dst_id to TVMAuthContext -- confirm against callers).
VMAGENT_ID = 2000468
# TVM application id of the vmproxy service (same usage assumption as above).
VMPROXY_ID = 2000469
# File to which TVMAuthContext writes its randomly generated local token.
LOCAL_TOKEN_PATH = '/tmp/vmagent.token'


class AuthError(Exception):
    """Raised when a request's TVM service ticket or local token is missing or invalid."""


class TVMAuthContext(object):
    """TVM service-ticket authentication helper.

    Construction loads the client secret, generates a random local token
    (written to LOCAL_TOKEN_PATH with owner-only permissions) and fetches
    the TVM public keys.  The ticket-issuing side (vmproxy) calls
    get_request_ticket(); the receiving side (vmagent) calls
    verify_request_ticket() / verify_local_token().
    """

    # Seconds for which downloaded TVM public keys are reused before re-fetching.
    _KEYS_TTL = 3600
    # Seconds for which an issued service ticket is reused (59 minutes);
    # presumably chosen to stay below the ticket lifetime -- confirm.
    _TICKET_TTL = 3540
    # Timeout (seconds) for TVM API calls so a hung connection cannot block forever.
    _HTTP_TIMEOUT = 10

    _KEYS_URL = 'https://tvm-api.yandex.net/2/keys?lib_version={version}'
    _TICKET_URL = 'https://tvm-api.yandex.net/2/ticket'

    def __init__(self, client_id, secret=None, secret_file=None):
        """Create the context and eagerly fetch TVM keys.

        Args:
            client_id: TVM id of this service.
            secret: TVM client secret; overridden when secret_file is given.
            secret_file: path to a file holding the secret.  Surrounding
                whitespace (e.g. a trailing newline) is stripped.

        Raises:
            IOError/OSError: if the secret file cannot be read or the local
                token file cannot be written.
            requests.RequestException: if fetching the TVM keys fails.
        """
        self._client_id = client_id
        self._secret = secret

        if secret_file:
            with open(secret_file, 'r') as f:
                # A trailing newline would otherwise become part of the secret.
                self._secret = f.read().strip()

        # SHA-256 of 64 random bytes -> 64-char hex token (same format as before).
        self._local_token = sha256(os.urandom(64)).hexdigest()
        self._write_local_token(LOCAL_TOKEN_PATH, self._local_token)

        self._keys = None
        self._keys_ts = None
        self._svc_ctx = None
        self._ticket_cache = {}

        self._refresh_keys()

    @staticmethod
    def _write_local_token(path, token):
        """Write *token* to *path*, readable and writable by the owner only.

        The path lives in a shared directory, so symlinks are refused
        (O_NOFOLLOW where available) and the mode is forced to 0600 even if
        the file already existed with wider permissions.
        """
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, 'O_NOFOLLOW', 0)
        fd = os.open(path, flags, 0o600)
        with os.fdopen(fd, 'w') as f:
            # os.open's mode only applies on creation; tighten pre-existing files too.
            os.fchmod(f.fileno(), 0o600)
            f.write(token)

    def _refresh_keys(self):
        """Re-download TVM keys and rebuild the service context when stale.

        Does nothing while the cached keys are younger than _KEYS_TTL seconds.

        Raises:
            requests.RequestException: on network failure or non-2xx response.
        """
        now = time.time()

        if self._keys and self._keys_ts and now < self._keys_ts + self._KEYS_TTL:
            return

        resp = requests.get(self._KEYS_URL.format(version=tp2.__version__),
                            timeout=self._HTTP_TIMEOUT)
        resp.raise_for_status()
        keys = resp.content
        svc_ctx = TVMServiceContext(self._client_id, self._secret, keys)

        # Commit only after the new context was built successfully, so a
        # failed refresh leaves the previous keys/context/timestamp consistent.
        self._keys = keys
        self._svc_ctx = svc_ctx
        self._keys_ts = time.time()

    def get_request_ticket(self, dst_id):
        """Return a service ticket for calling TVM service *dst_id* (vmproxy side).

        Tickets are cached per destination and re-requested once they are
        older than _TICKET_TTL seconds.

        Raises:
            requests.RequestException: on network failure or non-2xx response.
            KeyError: if the response carries no ticket for *dst_id*.
        """
        self._refresh_keys()
        ts = int(time.time())
        ticket, issued_ts = self._ticket_cache.get(dst_id, (None, None))

        if ticket and issued_ts and ts <= issued_ts + self._TICKET_TTL:
            return ticket

        resp = requests.post(
            self._TICKET_URL,
            data={
                'grant_type': 'client_credentials',
                'src': self._client_id,
                'dst': dst_id,
                'ts': ts,
                'sign': self._svc_ctx.sign(ts, dst_id),
            },
            timeout=self._HTTP_TIMEOUT,
        )
        resp.raise_for_status()
        ticket = resp.json()[str(dst_id)]["ticket"]

        self._ticket_cache[dst_id] = (ticket, ts)
        return ticket

    def verify_request_ticket(self, ticket_body, src_id):
        """Check that *ticket_body* is a valid service ticket from *src_id* (vmagent side).

        Raises:
            AuthError: if the ticket cannot be validated or was issued to a
                different source service.
        """
        self._refresh_keys()
        try:
            ticket = self._svc_ctx.check(ticket_body)
        except TVMTicketParsingException as e:
            raise AuthError('Cannot validate TVM ticket: {} , {}'
                            .format(e.message, e.debug_info))

        if ticket.src != src_id:
            raise AuthError("Unauthorized try")

    def verify_local_token(self, token):
        """Check *token* against this process's local token.

        Uses a constant-time comparison so the token cannot be recovered by
        timing.  A missing (None) or non-ASCII token is treated as a mismatch.

        Raises:
            AuthError: if the token does not match.
        """
        try:
            matches = hmac.compare_digest(self._local_token, token)
        except TypeError:
            # compare_digest rejects None, mixed str/bytes and non-ASCII str.
            matches = False

        if not matches:
            raise AuthError('Invalid local token supplied')

    def extract_tvm_auth(self, headers):
        """Return ``(service_ticket, local_token)`` taken from request *headers*.

        Either element may be None, but not both.  *headers* must provide
        ``getheader(name, default)`` (presumably a Python 2 mimetools.Message
        from BaseHTTPRequestHandler -- confirm).

        Raises:
            AuthError: if neither header is present.
        """
        tvm_hdr = headers.getheader('x-ya-service-ticket', None)
        local_token = headers.getheader('local-token', None)

        if tvm_hdr or local_token:
            return tvm_hdr, local_token

        raise AuthError('No TVM info supplied')
