# -*- coding: utf-8 -*-
from collections import namedtuple
from typing import Optional

import httpx
from tvm2.aio.thread_tvm2 import TVM2 as AsyncTVM2
from tvm2.sync.thread_tvm2 import TVM2
from tvmauth import BlackboxTvmId

from .conf import AppTypeSettings
from .utils import import_from_string


_ServiceTicket = namedtuple('_ServiceTicket', ['src'])


class MockAsyncTVM2(AsyncTVM2):
    """Async TVM2 double that fabricates service tickets locally.

    Issued tickets have the form ``'serv:<client id>'`` and
    ``parse_service_ticket`` strips that prefix off again. Neither override
    performs any I/O.
    """

    def __init__(self, *args, **kwargs):
        # AsyncTVM2.__init__ is not called; the constructor arguments are only
        # stored on the instance (presumably for inspection in tests).
        self.args, self.kwargs = args, kwargs

    async def get_service_ticket(self, tvm_client_id):
        """Return the fake ticket ``'serv:<tvm_client_id>'``."""
        fake_ticket = f'serv:{tvm_client_id}'
        return fake_ticket

    async def parse_service_ticket(self, service_ticket):
        """Wrap everything after the 5-character ``'serv:'`` prefix as ``src``.

        The prefix is not validated, and the remainder is kept as a string
        without conversion.
        """
        source_id = service_ticket[5:]
        return _ServiceTicket(src=source_id)


class MockTVM2(TVM2):
    """Synchronous TVM2 double that fabricates service tickets locally.

    Issued tickets have the form ``'serv:<client id>'`` and
    ``parse_service_ticket`` strips that prefix off again. Neither override
    performs any I/O.
    """

    def __init__(self, *args, **kwargs):
        # TVM2.__init__ is not called; the constructor arguments are only
        # stored on the instance (presumably for inspection in tests).
        self.args, self.kwargs = args, kwargs

    def get_service_ticket(self, tvm_client_id):
        """Return the fake ticket ``'serv:<tvm_client_id>'``."""
        fake_ticket = f'serv:{tvm_client_id}'
        return fake_ticket

    def parse_service_ticket(self, service_ticket):
        """Wrap everything after the 5-character ``'serv:'`` prefix as ``src``.

        The prefix is not validated, and the remainder is kept as a string
        without conversion.
        """
        source_id = service_ticket[5:]
        return _ServiceTicket(src=source_id)


# Process-wide caches for the resolved client classes. The getters below fill
# them lazily on first use; resetting one to ``None`` forces a re-import.
_async_tvm2_class = None
_tvm2_class = None


def get_async_tvm2_class(settings: AppTypeSettings) -> type:
    """Return the async TVM2 client *class* named by ``settings.async_tvm2_class_name``.

    The class is imported via ``import_from_string`` only once and cached in
    the module-level ``_async_tvm2_class``.

    NOTE(review): once cached, the class is returned even when a later call
    passes settings that name a different class. Reset ``_async_tvm2_class``
    to ``None`` to switch.

    :param settings: application settings providing ``async_tvm2_class_name``.
    :return: the client class (not an instance).
    """
    global _async_tvm2_class
    if _async_tvm2_class is None:
        _async_tvm2_class = import_from_string(settings.async_tvm2_class_name)
    return _async_tvm2_class


def get_tvm2_class(settings: AppTypeSettings) -> type:
    """Return the sync TVM2 client *class* named by ``settings.tvm2_class_name``.

    The class is imported via ``import_from_string`` only once and cached in
    the module-level ``_tvm2_class``.

    NOTE(review): once cached, the class is returned even when a later call
    passes settings that name a different class. Reset ``_tvm2_class`` to
    ``None`` to switch.

    :param settings: application settings providing ``tvm2_class_name``.
    :return: the client class (not an instance).
    """
    global _tvm2_class
    if _tvm2_class is None:
        _tvm2_class = import_from_string(settings.tvm2_class_name)
    return _tvm2_class


def _build_tvm2_client(tvm2_class, settings: AppTypeSettings, blackbox_name: Optional[str]):
    """Instantiate ``tvm2_class`` with the TVM2 configuration held in ``settings``.

    Shared by the sync and async factories so both kinds of client are built
    from exactly the same settings fields.

    :param tvm2_class: client class to instantiate (sync or async).
    :param settings: application settings providing the client id, secret,
        destinations and allowed client ids.
    :param blackbox_name: member name looked up in ``BlackboxTvmId``; when
        falsy, ``settings.blackbox_name`` is used instead.
    :return: a new ``tvm2_class`` instance. A fresh one is created on every call.
    """
    blackbox_name = blackbox_name or settings.blackbox_name
    # Errors from the ``BlackboxTvmId[...]`` lookup (unknown name) propagate unchanged.
    return tvm2_class(
        client_id=settings.tvm2_client_id,
        secret=settings.tvm2_client_secret,
        blackbox_client=BlackboxTvmId[blackbox_name],
        destinations=settings.get_tvm2_destinations(),
        allowed_clients=settings.tvm2_allowed_client_ids,
    )


def get_async_tvm2_client(settings: AppTypeSettings, blackbox_name: Optional[str] = None) -> AsyncTVM2:
    """Build an async TVM2 client from ``settings``.

    :param settings: application settings (see ``_build_tvm2_client``).
    :param blackbox_name: optional ``BlackboxTvmId`` member name overriding
        ``settings.blackbox_name``.
    :raises RuntimeError: if the configured async client class resolved to ``None``.
    """
    async_tvm2_class = get_async_tvm2_class(settings)
    if async_tvm2_class is None:
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        raise RuntimeError('settings.async_tvm2_class_name did not resolve to a class')
    return _build_tvm2_client(async_tvm2_class, settings, blackbox_name)


def get_tvm2_client(settings: AppTypeSettings, blackbox_name: Optional[str] = None) -> TVM2:
    """Build a synchronous TVM2 client from ``settings``.

    :param settings: application settings (see ``_build_tvm2_client``).
    :param blackbox_name: optional ``BlackboxTvmId`` member name overriding
        ``settings.blackbox_name``.
    :raises RuntimeError: if the configured client class resolved to ``None``.
    """
    tvm2_class = get_tvm2_class(settings)
    if tvm2_class is None:
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        raise RuntimeError('settings.tvm2_class_name did not resolve to a class')
    return _build_tvm2_client(tvm2_class, settings, blackbox_name)


def _require_ticket(ticket, tvm2_client_id: int):
    """Return ``ticket`` unchanged, or raise if the TVM2 client produced none.

    :raises ValueError: if ``ticket`` is falsy (e.g. ``None`` or an empty string).
    """
    if not ticket:
        raise ValueError(f'Cannot get service ticket for {tvm2_client_id}')
    return ticket


async def get_async_service_ticket(settings: AppTypeSettings, tvm2_client_id: int, blackbox_name: Optional[str] = None):
    """Fetch a service ticket for ``tvm2_client_id`` using an async TVM2 client.

    NOTE(review): a new client is constructed on every call via
    ``get_async_tvm2_client``. Consider reusing one if construction turns out
    to be expensive.

    :param settings: application settings used to build the client.
    :param tvm2_client_id: TVM id passed to ``get_service_ticket``.
    :param blackbox_name: optional ``BlackboxTvmId`` member name override.
    :raises ValueError: if the client returned an empty ticket.
    """
    client = get_async_tvm2_client(settings, blackbox_name)
    ticket = await client.get_service_ticket(tvm2_client_id)
    return _require_ticket(ticket, tvm2_client_id)


def get_service_ticket(settings: AppTypeSettings, tvm2_client_id: int, blackbox_name: Optional[str] = None):
    """Fetch a service ticket for ``tvm2_client_id`` using a synchronous TVM2 client.

    NOTE(review): a new client is constructed on every call via
    ``get_tvm2_client``. Consider reusing one if construction turns out to be
    expensive.

    :param settings: application settings used to build the client.
    :param tvm2_client_id: TVM id passed to ``get_service_ticket``.
    :param blackbox_name: optional ``BlackboxTvmId`` member name override.
    :raises ValueError: if the client returned an empty ticket.
    """
    client = get_tvm2_client(settings, blackbox_name)
    ticket = client.get_service_ticket(tvm2_client_id)
    return _require_ticket(ticket, tvm2_client_id)


class HttpxTvmAuth(httpx.Auth):
    """httpx auth hook that attaches a TVM2 service ticket to every request.

    On each auth flow a ticket for ``tvm2_client_id`` is obtained and stored
    in the ``X-Ya-Service-Ticket`` request header. Synchronous clients go
    through ``auth_flow``; async clients go through ``async_auth_flow``, which
    uses the non-blocking ticket helper.
    """

    # Name of the request header that carries the service ticket.
    header = 'X-Ya-Service-Ticket'

    def __init__(self, settings: AppTypeSettings, tvm2_client_id: int, blackbox_name: Optional[str] = None):
        """
        :param settings: application settings used to build the TVM2 client.
        :param tvm2_client_id: TVM id the service ticket is requested for.
        :param blackbox_name: optional ``BlackboxTvmId`` member name overriding
            ``settings.blackbox_name``.
        """
        self.settings = settings
        self.tvm2_client_id = tvm2_client_id
        self.blackbox_name = blackbox_name

    def auth_flow(self, request):
        """Sync flow: set the ticket header and send the request once.

        :raises ValueError: propagated from ``get_service_ticket`` when no ticket is issued.
        """
        ticket = get_service_ticket(self.settings, self.tvm2_client_id, self.blackbox_name)
        request.headers[self.header] = ticket
        yield request

    async def async_auth_flow(self, request):
        """Async flow: set the ticket header and send the request once.

        :raises ValueError: propagated from ``get_async_service_ticket`` when no ticket is issued.
        """
        ticket = await get_async_service_ticket(self.settings, self.tvm2_client_id, self.blackbox_name)
        request.headers[self.header] = ticket
        yield request
