import logging
from datetime import datetime
from typing import List, Optional

import aioch

from maps_adv.common.helpers import dsn_parser


class ClickHouseQueryLog:
    """Collect per-tag query statistics from ClickHouse ``query_log``.

    Every configured host is queried independently. Finished queries are
    grouped by the ``-- tag: <name>`` comment lines they contain and by the
    error code parsed from their ``exception`` column (0 when absent). Each
    group is turned into eight ``IGAUGE`` metrics: min/max/avg/median of
    memory usage and of duration in milliseconds.
    """

    # Connection kwargs forwarded to every ``aioch.Client`` (the DSN minus
    # its ``hosts`` entry).
    _ch_config: dict
    # Nodes to poll; each entry is read via its "host" and "port" keys.
    _hosts: List[dict]

    # Module-scoped logger so failures are attributable to this module rather
    # than being emitted anonymously on the root logger.
    _logger = logging.getLogger(__name__)

    def __init__(self, database_url: str, ssl_cert_file: Optional[str] = None):
        """Parse connection settings from a DSN.

        Args:
            database_url: DSN understood by ``dsn_parser.parse``. Its
                ``hosts`` entry lists the nodes to poll; everything else is
                passed to the client as keyword arguments.
            ssl_cert_file: CA bundle used for secure connections when the
                DSN does not already specify ``ca_certs``.
        """
        self._ch_config = dsn_parser.parse(database_url=database_url)
        self._hosts = self._ch_config.pop("hosts")
        # ``.get`` so a DSN without TLS-related keys does not raise KeyError.
        if self._ch_config.get("secure") and not self._ch_config.get("ca_certs"):
            self._ch_config["ca_certs"] = ssl_cert_file

    def _client(self, host: str, port: int) -> "aioch.Client":
        """Create a client for one node using the shared connection settings."""
        return aioch.Client(host=host, port=port, **self._ch_config)

    async def _disconnect(self, client: "aioch.Client") -> None:
        """Close *client*, logging (never raising) any failure to do so."""
        try:
            await client.disconnect()
        except Exception:
            self._logger.exception("Failed to disconnect ClickHouse client")

    @staticmethod
    def _build_query(from_datetime: datetime, to_datetime: datetime) -> str:
        """Return the aggregation SQL for the ``[from, to]`` time window.

        The bounds are interpolated as integer Unix timestamps; ``int()``
        ensures only an integer literal reaches the SQL text. Note that
        ``datetime.timestamp()`` treats naive datetimes as local time.
        """
        return f"""
        SELECT arraySort(
                   arrayDistinct(
                       extractAll(
                           query,
                           '(?m:^\\\\s*--\\\\s*tag\\\\s*:\\\\s*([a-z_\\\\-0-9]+)\\\\s*;?\\\\s*$)'
                       )
                   )
               ) AS tags,
               toUInt32OrZero(extract(exception, '^Code:\\\\s*([0-9]+),.*')) AS exception_code,
               max(memory_usage) AS max_memory,
               max(query_duration_ms) AS max_duration_ms,
               min(memory_usage) AS min_memory,
               min(query_duration_ms) AS min_duration_ms,
               toInt64(avg(memory_usage)) AS avg_memory,
               toInt64(avg(query_duration_ms)) AS avg_duration_ms,
               toInt64(quantile(0.5)(memory_usage)) AS median_memory,
               toInt64(quantile(0.5)(query_duration_ms)) AS median_duration_ms
        FROM query_log
        WHERE type = 'QueryFinish'
          AND event_time BETWEEN {int(from_datetime.timestamp())} AND {int(to_datetime.timestamp())}
          AND length(tags) > 0
        GROUP BY tags, exception_code
        ORDER BY tags, exception_code
        """  # noqa: E501

    @staticmethod
    def _rows_to_metrics(host: str, port: int, rows: List[tuple]) -> List[dict]:
        """Expand result rows of ``_build_query`` into gauge metric dicts.

        Each row yields eight metrics, in the order min, max, avg, median,
        each for memory then duration_ms.
        """
        metrics = []
        for (
            tags,
            exception_code,
            max_memory,
            max_duration_ms,
            min_memory,
            min_duration_ms,
            avg_memory,
            avg_duration_ms,
            median_memory,
            median_duration_ms,
        ) in rows:
            tags_str = ",".join(map(str, tags))
            for aggregate, kind, value in (
                ("min", "memory", min_memory),
                ("min", "duration_ms", min_duration_ms),
                ("max", "memory", max_memory),
                ("max", "duration_ms", max_duration_ms),
                ("avg", "memory", avg_memory),
                ("avg", "duration_ms", avg_duration_ms),
                ("median", "memory", median_memory),
                ("median", "duration_ms", median_duration_ms),
            ):
                metrics.append(
                    {
                        "labels": {
                            "host": host,
                            "port": port,
                            "exception_code": exception_code,
                            "type": kind,
                            "aggregate": aggregate,
                            "tags": tags_str,
                        },
                        "type": "IGAUGE",
                        "value": value,
                    }
                )
        return metrics

    async def retrieve_metrics_for_queries(
        self, from_datetime: datetime, to_datetime: datetime
    ) -> List[dict]:
        """Collect query metrics from every configured host.

        Collection is best-effort: a host whose query fails is logged and
        skipped, and contributes no metrics. Every client created here is
        disconnected before moving on, so repeated polling does not leak
        connections.

        Args:
            from_datetime: Start of the ``event_time`` window (inclusive).
            to_datetime: End of the ``event_time`` window (inclusive).

        Returns:
            Metric dicts with ``labels`` (host, port, exception_code, type,
            aggregate, tags), ``type`` set to ``"IGAUGE"`` and ``value``.
        """
        sql = self._build_query(from_datetime, to_datetime)
        metrics = []

        for node in self._hosts:
            client = None
            try:
                client = self._client(node["host"], node["port"])
                rows = await client.execute(sql)
                # Build per node so a failure part-way through never leaves
                # half of this node's metrics in the combined result.
                node_metrics = self._rows_to_metrics(
                    node["host"], node["port"], rows
                )
            except Exception:
                # One unreachable node must not hide the others' metrics.
                self._logger.exception(
                    "Failed to collect query_log metrics from %s:%s",
                    node.get("host"),
                    node.get("port"),
                )
                continue
            finally:
                if client is not None:
                    await self._disconnect(client)
            metrics.extend(node_metrics)

        return metrics
