from dataclasses import dataclass
from datetime import timedelta
from typing import Literal, Optional, Union
from urllib.parse import urljoin

from lxml import etree
from smb.common.http_client import collect_errors

from smb.common.aiotvm import HttpClientWithTvm, TvmClient

__all__ = ["MDSClient", "MDSInstallation", "UploadFileResult", "mds_installations"]


@dataclass(eq=True, frozen=True)
class MDSInstallation:
    """Immutable set of endpoint URLs describing one MDS installation."""

    # Read endpoint used when a client is not configured as ``inner``.
    outer_read_url: str
    # Read endpoint used when a client is configured as ``inner``.
    inner_read_url: str
    # Endpoint that upload requests are sent to.
    write_url: str


# Registry of known MDS installations, keyed by the environment name that
# ``MDSClient`` accepts. Reads use HTTPS on port 443 (outer: ``storage``
# host, inner: ``storage-int`` host); writes use plain HTTP on port 1111 of
# the ``storage-int`` host.
mds_installations: "dict[str, MDSInstallation]" = {
    # Testing environment (``mdst`` domain).
    "testing": MDSInstallation(
        outer_read_url="https://storage.mdst.yandex.net:443",
        inner_read_url="https://storage-int.mdst.yandex.net:443",
        write_url="http://storage-int.mdst.yandex.net:1111",
    ),
    # Production environment (``mds`` domain).
    "production": MDSInstallation(
        outer_read_url="https://storage.mds.yandex.net:443",
        inner_read_url="https://storage-int.mds.yandex.net:443",
        write_url="http://storage-int.mds.yandex.net:1111",
    ),
}


@dataclass(frozen=True)
class UploadFileResult:
    """Immutable result of a successful MDS upload.

    Attributes:
        root_url: Base read URL that ``download_link`` is resolved against.
        response: Parsed root element of the upload response XML; its
            ``key`` attribute is used to build the download link.
        namespace: MDS namespace the file was uploaded to.
    """

    __slots__ = ("root_url", "response", "namespace")

    root_url: str
    # ``etree.XML`` is a parser *function*, not a type; what it returns is an
    # lxml ``_Element``. The annotation is quoted so it is not evaluated when
    # the class is created.
    response: "etree._Element"
    namespace: str

    @property
    def download_link(self) -> str:
        """Return the download URL for the uploaded file.

        The relative reference ``get-<namespace>/<key>?disposition=1`` is
        resolved against ``root_url`` with ``urljoin``; for a ``root_url``
        with no path this yields ``<root_url>/get-<namespace>/<key>?...``.

        Raises:
            KeyError: if the response element has no ``key`` attribute.
        """
        return urljoin(
            self.root_url,
            f"get-{self.namespace}/{self.response.attrib['key']}?disposition=1",
        )


class MDSClient(HttpClientWithTvm):
    """TVM-authenticated client that uploads files to an MDS namespace.

    Requests are sent to the installation's write URL; download links in the
    returned ``UploadFileResult`` are rooted at :attr:`read_url`, which is the
    inner or outer read URL depending on ``inner``.
    """

    __slots__ = ("installation", "namespace", "inner")

    installation: MDSInstallation
    namespace: str
    inner: bool

    def __init__(
        self,
        installation: Literal["testing", "production"],
        namespace: str,
        tvm_client: TvmClient,
        tvm_destination: Union[str, int],
        tvm_source: Optional[Union[str, int]] = None,
        inner: bool = False,
    ):
        """Create a client bound to one installation and namespace.

        Args:
            installation: Key into ``mds_installations``.
            namespace: MDS namespace used in upload and download paths.
            tvm_client: TVM client forwarded to the base HTTP client.
            tvm_destination: TVM destination forwarded to the base client.
            tvm_source: Optional TVM source forwarded to the base client.
            inner: If true, download links use the inner read URL instead
                of the outer one.

        Raises:
            KeyError: if ``installation`` is not a key of ``mds_installations``.
        """
        # Resolve once so the base client's URL and the stored installation
        # are guaranteed to come from the same entry.
        try:
            resolved = mds_installations[installation]
        except KeyError:
            raise KeyError(
                f"Unknown MDS installation {installation!r}; "
                f"expected one of {sorted(mds_installations)}"
            ) from None

        super().__init__(
            url=resolved.write_url,
            tvm=tvm_client,
            tvm_source=tvm_source,
            tvm_destination=tvm_destination,
        )
        self.installation = resolved
        self.namespace = namespace
        self.inner = inner

    @collect_errors
    async def upload_file(
        self,
        *,
        file_content: bytes,
        file_name: Optional[str] = None,
        expire: Optional[timedelta] = None,
        timeout: Optional[int] = None,
    ) -> UploadFileResult:
        """POST ``file_content`` to ``/upload-<namespace>[/<file_name>]``.

        Args:
            file_content: Raw bytes sent as the request body.
            file_name: Optional name appended to the upload path.
            expire: Optional lifetime, sent as the ``expire`` query parameter
                in whole seconds (``"<N>s"``; fractional seconds are
                truncated). Falsy values (``None`` or a zero ``timedelta``)
                omit the parameter.
            timeout: Optional timeout forwarded to ``self.request``. Falsy
                values (``None`` or ``0``) are not forwarded.

        Returns:
            ``UploadFileResult`` holding the parsed XML response, this
            client's namespace and :attr:`read_url`.
        """
        # Keep the caller's ``timeout`` (an int) separate from the kwargs we
        # forward; only pass it when set, presumably so the base client's own
        # default applies otherwise.
        request_kwargs = {"timeout": timeout} if timeout else {}
        params = {"expire": f"{int(expire.total_seconds())}s"} if expire else {}

        post_uri = f"/upload-{self.namespace}"
        # The metric name is a literal template, not interpolated, so it does
        # not vary per namespace/file (presumably to bound metric cardinality).
        metric_name = "/upload-{namespace}"
        if file_name:
            # NOTE(review): file_name is inserted without URL-escaping; names
            # containing '?', '#' or '/' would alter the request URI. Confirm
            # whether callers sanitize it.
            post_uri += f"/{file_name}"
            metric_name += "/{file_name}"

        response = await self.request(
            method="POST",
            uri=post_uri,
            expected_statuses=[200],
            data=file_content,
            params=params,
            metric_name=metric_name,
            **request_kwargs,
        )

        # Assumes ``self.request`` returns the raw response body (bytes/str)
        # that ``etree.XML`` can parse — TODO confirm against the base client.
        return UploadFileResult(
            root_url=self.read_url,
            namespace=self.namespace,
            response=etree.XML(response),
        )

    @property
    def write_url(self) -> str:
        """URL that uploads are sent to (the installation's write URL)."""
        return self.installation.write_url

    @property
    def read_url(self) -> str:
        """Root URL for download links: inner read URL if ``inner`` else outer."""
        return (
            self.installation.inner_read_url
            if self.inner
            else self.installation.outer_read_url
        )
