import typing

import time
from cachetools import TTLCache
from clickhouse_driver import Client
from load.projects.lunaparkapi.handlers import report_data as rd
from load.projects.lunaparkapi.handlers.report_data.monitoring_data import get_monitoring_data
from load.projects.lunaparkapi.handlers.report_data.report_data import _get_responses_per_second

from load.projects.cloud.loadtesting.config import ENV_CONFIG
from load.projects.cloud.loadtesting.db.connection import get_clickhouse_client
from load.projects.cloud.loadtesting.db.tables import JobTable
from yandex.cloud.priv.loadtesting.v1 import tank_job_pb2 as messages
from yandex.cloud.priv.loadtesting.v2 import test_pb2 as test_messages


# Chart types produced for a load test, each with localized display texts:
#     CHARTS[chart_type][lang] == {'name': ..., 'description': ...}
# Insertion order matters: charts are emitted in this order.
# Table rows are (chart type, ru name, ru description, en name, en description).
CHARTS = {
    chart_type: {
        'ru': {'name': ru_name, 'description': ru_description},
        'en': {'name': en_name, 'description': en_description},
    }
    for chart_type, ru_name, ru_description, en_name, en_description in (
        # FIXME: every description is still empty and needs to be written.
        ('QUANTILES', 'Квантили времен ответов', '', 'Response time quantiles', ''),
        ('INSTANCES', 'Тестирующие потоки для всего теста', '', 'Instances for the whole test', ''),
        ('NET_CODES', 'Сетевые коды ответов', '', 'Net response codes', ''),
        ('PROTO_CODES', 'HTTP коды ответов', '', 'HTTP response codes', ''),
    )
}


def get_job_cases(job: JobTable, client: Client) -> typing.List[str]:
    """Return the sorted, de-duplicated case names recorded for ``job``.

    Side effect: if the freshly fetched list differs from ``job.cases``, the
    job object's ``cases`` attribute is overwritten with it.
    """
    raw_cases = rd.get_cases(client=client, db_name=ENV_CONFIG.CLICKHOUSE_DBNAME, job_id=job.n)
    unique_cases = sorted(set(raw_cases))
    if unique_cases != job.cases:
        job.cases = unique_cases
    return unique_cases


def get_job_scheme_type(job: JobTable, client: Client, cache=TTLCache(maxsize=1024, ttl=360)) -> str:
    """Return the scheme type of ``job``, memoized by job number ``job.n``.

    The default ``cache`` is created once, when the function is defined, and
    is therefore shared by every call that does not pass its own cache. That
    sharing is deliberate: it is the memoization store, holding up to 1024
    entries for 360 (timer units, seconds by default) each.
    A falsy cached value is treated as a miss and fetched again.
    """
    cached_scheme = cache.get(job.n)
    if cached_scheme:
        return cached_scheme
    fresh_scheme = rd.get_scheme_type(client=client, db_name=ENV_CONFIG.CLICKHOUSE_DBNAME, job_id=job.n)
    cache[job.n] = fresh_scheme
    return fresh_scheme


def get_job_end_time(job: JobTable) -> typing.Optional[int]:
    """Return ``job.finished_at`` as whole epoch seconds, or ``None`` if unset.

    The conversion goes through ``time.mktime``, so the datetime's fields are
    interpreted as local time (the same conversion ``_Charts`` applies to
    ``started_at``).
    """
    # NOTE(review): if finished_at is timezone-aware, mktime still treats its
    # wall-clock fields as local time; confirm the column stores naive values.
    if not job.finished_at:
        return None
    return int(time.mktime(job.finished_at.timetuple()))


def chart_data_to_message(
    job: JobTable, data: dict, chart_type: str, instances: typing.List[float],
    name: str, description: str
) -> messages.TankChart:
    """Convert a report-data payload into a ``TankChart`` protobuf message.

    The series are read from ``data['data']``. Any missing key falls back to an
    empty value. Every ``(case, metric)`` pair found in ``data['data']['cases']``
    becomes one ``MetricData`` entry. Responses-per-second and ``instances``
    values are truncated to ``int``.
    """
    payload = data.get('data', {})
    responses = [int(rate) for rate in payload.get('responses_per_second', [])]
    cases_data = [
        messages.MetricData(case_name=case_name, metric_name=metric_name, metric_value=metric_value)
        for case_name, case_metrics in payload.get('cases', {}).items()
        for metric_name, metric_value in case_metrics.items()
    ]
    return messages.TankChart(
        chart_type=chart_type,
        job_id=job.id,
        name=name,
        description=description,
        ts=payload.get('ts', []),
        responses_per_second=responses,
        threads=[int(thread) for thread in instances],
        cases_data=cases_data,
    )


class _Charts:
    """Build the load-test charts listed in ``CHARTS`` for a single job.

    The constructor queries the data shared by all charts: the instances
    series, responses per second, and the job's cases. ``get`` then produces
    one ``TankChart`` message per ``CHARTS`` entry.
    """

    def __init__(self, client, job: JobTable, lang):
        """
        :param client: ClickHouse client used for every query.
        :param job: job to chart. ``job.cases`` may be updated as a side effect
            of ``get_job_cases``.
        :param lang: key selecting the localized texts in ``CHARTS`` (``'ru'`` or ``'en'``).
        """
        self.job = job
        self.lang = lang
        self.client = client

        # Arguments common to every rd.get_*_data call for this job.
        self.rd_kwargs = dict(
            client=client,
            db_name=ENV_CONFIG.CLICKHOUSE_DBNAME,
            job_id=job.n,
            job_date=job.started_at.date(),
            job_start=int(time.mktime(job.started_at.timetuple())),
            job_end=get_job_end_time(self.job),
            job_scheme_type=get_job_scheme_type(self.job, self.client)
        )
        self.instances_data = rd.get_instances_data(**self.rd_kwargs)
        self.resp_per_second = [int(t) for t in _get_responses_per_second(**self.rd_kwargs)]
        # The per-chart queries additionally receive the precomputed RPS series
        # and the job's cases. The two queries above were issued without them.
        self.rd_kwargs.update(dict(
            _responses_per_second=self.resp_per_second,
            job_cases=get_job_cases(self.job, self.client)
        ))

        self.threads = [int(t) for t in self.instances_data.get('data', {}).get('threads', [])]

    def get(self):
        """Return one chart message per ``CHARTS`` entry, in ``CHARTS`` order."""
        return [
            self._get_chart_message(chart_type, texts[self.lang]['name'], texts[self.lang]['description'])
            for chart_type, texts in CHARTS.items()
        ]

    def _get_chart_message(self, chart_type: str, name: str, description: str):
        """Fetch the data for ``chart_type`` and convert it into a chart message.

        :raises ValueError: if ``chart_type`` is not a known chart type.
        """
        if chart_type == 'QUANTILES':
            data = rd.get_quantiles_data(**self.rd_kwargs)
        elif chart_type == 'NET_CODES':
            data = rd.get_net_codes_data(**self.rd_kwargs)
        elif chart_type == 'PROTO_CODES':
            data = rd.get_proto_codes_data(**self.rd_kwargs)
        elif chart_type == 'INSTANCES':
            data = self._instances_chart_data()
        else:
            raise ValueError(f'unknown chart type: {chart_type}')

        return self.data_to_proto(data, chart_type, name, description)

    def _instances_chart_data(self):
        """Return a copy of ``instances_data`` with the threads series exposed as case ``overall``.

        Works on a copy so ``self.instances_data`` is not mutated. Like
        ``__init__``, it tolerates a payload that has no ``'data'`` key.
        """
        series = dict(self.instances_data.get('data', {}))
        series['cases'] = {'overall': {'instances': self.threads}}
        return {**self.instances_data, 'data': series}

    def data_to_proto(self, data, chart_type, name, description):
        """Convert fetched chart ``data`` into a ``TankChart`` message for this job."""
        return chart_data_to_message(self.job, data, chart_type, self.threads, name, description)


class MonitoringCharts(_Charts):
    """Build monitoring charts, one per ``(target_host, metric_type)`` pair.

    Runs the full ``_Charts`` construction, whose RPS and threads series are
    attached to every chart. Then it loads the job's monitoring data frame once
    via ``get_monitoring_data``.
    """

    def __init__(self, client, job, lang):
        super().__init__(client, job, lang)
        self.data = self._get_all_data()
        self.host_list = self._get_host_list()
        self.monitoring_charts_list = self._get_monitoring_charts_list()

    def _get_all_data(self, from_time=None, to_time=None):
        """Load the job's monitoring data, forwarding the optional ``from_time``/``to_time`` bounds."""
        return get_monitoring_data(self.client, self.job, from_time, to_time)

    def _get_host_list(self):
        """Return the distinct ``target_host`` values, or ``[]`` when there is no data."""
        if self.data.empty:
            return []
        return self.data.target_host.unique().tolist()

    def _get_monitoring_charts_list(self):
        """Return the ``(target_host, metric_type)`` group keys present in the data, or ``[]`` when empty."""
        if self.data.empty:
            return []
        return list(self.data.groupby(['target_host', 'metric_type']).groups)

    def _data_to_proto(self, host, metric_type, name: str = '', description: str = ''):
        """Build a ``MonitoringChart`` for the rows of ``host`` with ``metric_type``.

        :param name: chart name. When empty (the default), ``metric_type`` is used.
        :param description: chart description.
        :return: one ``MetricData`` per distinct ``metric_name``, in order of first
            appearance. ``ts`` is the distinct timestamps of all selected rows.
        """
        chart_rows = self.data.loc[(self.data.target_host == host) & (self.data.metric_type == metric_type)]
        metric_data = []
        # NOTE(review): each metric's values are not explicitly aligned to `ts`.
        # This assumes every metric has exactly one sample per timestamp; confirm.
        # The loop variable must not be called `name`: that would clobber the parameter.
        for metric_name in chart_rows.metric_name.unique():
            metric_rows = chart_rows.loc[chart_rows.metric_name == metric_name]
            metric_data.append(test_messages.MetricData(
                metric_name=metric_name,
                metric_value=metric_rows.value.to_list()
            ))
        return test_messages.MonitoringChart(
            monitored_host=host,
            test_id=self.job.id,
            name=name or metric_type,
            description=description,
            ts=chart_rows.ts.unique().tolist(),
            responses_per_second=self.resp_per_second,
            threads=[int(t) for t in self.threads],
            metric_data=metric_data,
            x_axis_label='',
            y_axis_label=''
        )

    def get(self):
        """Return one monitoring chart per ``(host, metric_type)`` pair found in the data."""
        return [self._data_to_proto(host, metric_type) for host, metric_type in self.monitoring_charts_list]


def get_test_charts(job: JobTable, lang: str):
    """Open a ClickHouse client and return the load-test charts for ``job``.

    :param job: job to build charts for.
    :param lang: ``'ru'`` or ``'en'``; selects chart names and descriptions from ``CHARTS``.
    :return: one chart message per ``CHARTS`` entry.
    """
    with get_clickhouse_client() as client:
        return _Charts(client, job, lang).get()


def get_monitoring_charts(job: JobTable, lang: str):
    """Open a ClickHouse client and return the monitoring charts for ``job``.

    The client is used only for the duration of the chart building.
    """
    with get_clickhouse_client() as ch_client:
        builder = MonitoringCharts(ch_client, job, lang)
        return builder.get()
