"""
Functions for working with clickhouse data
"""

# Quantile levels, as strings, matching the q<level> columns read from rt_quantiles_buffer.
QUANTILES = tuple(str(level) for level in (50, 75, 80, 85, 90, 95, 98, 99, 100))


def _get_responses_per_second(
    client, db_name, job_id, job_date, job_start, job_end, job_scheme_type
):
    """
    Fetch the per-second load series of the job from rt_microsecond_details_buffer.

    Only rows with an empty tag are read; they are grouped by second and the
    load column is averaged within each second.

    :param client: object with an ``execute(sql)`` method (clickhouse_driver.Client)
    :param db_name: clickhouse db name
    :type db_name: str
    :param job_id: job n
    :type job_id: int
    :param job_date: job start date
    :type job_date: datetime.datetime.date
    :param job_start: job start time (unix timestamp)
    :type job_start: int
    :param job_end: job end time (unix timestamp); None or 0 means no upper bound
    :type job_end: int
    :param job_scheme_type: 'rps' averages the ``reqps`` column,
        any other value averages the ``instances`` column
    :type job_scheme_type: str
    :return: one averaged value per second in time order, empty values replaced with 0
    :rtype: list
    """
    head = """
    select intDiv(toUInt32(time), {compress_ratio})*{compress_ratio} as t, avg({scheme})
        from {db_name}.rt_microsecond_details_buffer
        where job_id={job}
            and job_date=toDate('{job_date}')
            and tag=''
            and time >= toDateTime('{start}')
    """
    tail = """
            group by t
            order by t;
    """
    # A job without a recorded end time gets no upper time bound.
    bounded = job_end not in (None, 0)
    template = head + (" and time <= toDateTime('{end}')" if bounded else "") + tail

    load_column = 'reqps' if job_scheme_type == 'rps' else 'instances'
    query = template.format(
        scheme=load_column,
        db_name=db_name,
        compress_ratio=1,
        job=job_id,
        job_date=job_date,
        start=job_start,
        end=job_end,
    )
    rows = client.execute(query)

    # Second column is the averaged load; falsy values (None, 0) become 0.
    return [row[1] if row[1] else 0 for row in rows]


def _get_codes_set(client, db_name, job_id, job_date, code_type):
    """
    Collect the distinct codes recorded for the job.

    Reads ``{db_name}.{code_type}_codes_buffer`` and casts every code to UInt32
    on the database side.

    :param client: object with an ``execute(sql)`` method (clickhouse_driver.Client)
    :param db_name: clickhouse db name
    :type db_name: str
    :param job_id: job n
    :type job_id: int
    :param job_date: job start date
    :type job_date: datetime.datetime.date
    :param code_type: table prefix, net or proto
    :type code_type: str
    :return: first column of every returned row; empty list when nothing is returned
    :rtype: list
    """
    template = '''select distinct toUInt32(code) from {db_name}.{code_type}_codes_buffer
        where job_id={job} and job_date=toDate('{job_date}')'''
    query = template.format(
        job_date=job_date, job=job_id, code_type=code_type, db_name=db_name
    )
    rows = client.execute(query)
    if not rows:
        return []
    return [row[0] for row in rows]


def get_cases(client, db_name, job_id):
    """
    Return the job's distinct tags (cases) ordered by tag, followed by ''.

    The empty tag is always appended as the last element, regardless of
    whether the query already returned it.
    """
    query = """select distinct tag from {db_name}.rt_microsecond_details_buffer
    where job_id={job} order by tag""".format(job=job_id, db_name=db_name)
    return [row[0] for row in client.execute(query)] + ['']


def _has_nonzero_load(client, db_name, job_id, column):
    """Tell whether the job has at least one details row with a non-zero ``column``."""
    rows = client.execute(
        """
        select any(job_id)
            from {db_name}.rt_microsecond_details_buffer
            where job_id={job}
                and {column}!=0
        """.format(db_name=db_name, job=job_id, column=column)
    )
    # An aggregate without GROUP BY normally yields a single row even when no
    # row matches the filter; any() then holds the column default (0). A
    # non-empty result is therefore not proof of a match: the value itself has
    # to be checked. An empty result (possible with non-default server
    # settings) also means "no match".
    # NOTE: a job whose id is 0 is indistinguishable from "no match" here.
    return bool(rows) and bool(rows[0][0])


def get_scheme_type(client, db_name, job_id):
    """
    Detect the load scheme type of the job from its recorded details.

    :param client: object with an ``execute(sql)`` method (clickhouse_driver.Client)
    :param db_name: clickhouse db name
    :type db_name: str
    :param job_id: job n
    :type job_id: int
    :return: 'rps' when the job has rows with non-zero ``reqps``; otherwise
        'instances' when it has rows with non-zero ``threads``; otherwise 'rps'
    :rtype: str
    """
    if _has_nonzero_load(client, db_name, job_id, 'reqps'):
        return 'rps'

    if _has_nonzero_load(client, db_name, job_id, 'threads'):
        return 'instances'

    # For cases of manual phout upload (?)
    return 'rps'


def get_instances_data(client, db_name, job_id, job_date, job_start, job_end, job_scheme_type,
                       _responses_per_second=None):
    """
    Get the per-second average thread count of the job (overall only, no cases).

    :param client: object with an ``execute(sql)`` method (clickhouse_driver.Client)
    :param db_name: clickhouse db name
    :param job_id: job n
    :param job_date: job start date
    :param job_start: job start time
    :param job_end: job end time; None or 0 means no upper bound
    :param job_scheme_type: only forwarded to _get_responses_per_second
    :param _responses_per_second: precomputed series; when falsy it is fetched
        with _get_responses_per_second
    :return: ``{'data': {'ts': [...], 'threads': [...], 'responses_per_second': ...}}``
    :rtype: dict
    """
    template = '''
    select intDiv(toUInt32(time), {compress_ratio})*{compress_ratio} as t, round(avg(threads), 3) as data
        from {db_name}.rt_microsecond_details_buffer
            where job_id={job}
                and job_date=toDate('{job_date}')
                and tag=''
                and time >= toDateTime('{start}')'''
    # A job without a recorded end time gets no upper time bound.
    if job_end not in (None, 0):
        template += " and time <= toDateTime('{end}')"
    template += " group by t order by t;"

    query = template.format(
        db_name=db_name, job=job_id, job_date=job_date,
        start=job_start, end=job_end, compress_ratio=1
    )
    rows = list(client.execute(query) or [])
    timestamps = [row[0] for row in rows]
    threads = [row[1] for row in rows]

    if not _responses_per_second:
        _responses_per_second = _get_responses_per_second(
            client=client, db_name=db_name,
            job_id=job_id, job_date=job_date,
            job_start=job_start, job_end=job_end, job_scheme_type=job_scheme_type
        )

    return {
        'data': {
            'ts': timestamps,
            'threads': threads,
            'responses_per_second': _responses_per_second,
        }
    }


def get_quantiles_data(client, db_name, job_id, job_date, job_start, job_end, job_cases, job_scheme_type,
                       _responses_per_second=None):
    """
    Get per-second quantiles of the job from clickhouse, one series set per case.

    For every tag in ``job_cases`` one query is executed: the job's timeline
    (seconds present in rt_microsecond_details_buffer for the empty tag) is
    left-joined with the per-second maximum of every quantile column of
    rt_quantiles_buffer for that tag.

    :param client: object with an ``execute(sql)`` method (clickhouse_driver.Client)
    :param db_name: clickhouse db name
    :param job_id: job n
    :param job_date: job start date
    :param job_start: job start time
    :param job_end: job end time; None or 0 means no upper bound
    :param job_cases: iterable of tags; '' is reported under the key 'overall'
    :param job_scheme_type: only forwarded to _get_responses_per_second
    :param _responses_per_second: precomputed series; when falsy it is fetched
        with _get_responses_per_second
    :return: ``{'data': {'cases': {case: {quantile: series}}, 'ts': series,
        'responses_per_second': ...}}``; 'ts' is present only when the query
        for the last case returned rows
    :rtype: dict
    """
    data = {'cases': {}}

    # The result columns are decoded by position through QUANTILES, so the
    # select lists are generated from it to keep both in sync.
    outer_columns = ', '.join('mq' + quantile for quantile in QUANTILES)
    inner_columns = ',\n                    '.join(
        'round(max(q{quantile}), 3) as mq{quantile}'.format(quantile=quantile)
        for quantile in QUANTILES
    )
    # A job without a recorded end time gets no upper time bound.
    end_filter = ''
    if job_end not in (None, 0):
        end_filter = " and time <= toDateTime('{end}')".format(end=job_end)

    sql = '''
        select t, {outer_columns} from
            (
                select intDiv(toUInt32(time), {compress_ratio}) as t
                from {db_name}.rt_microsecond_details_buffer
                where job_id={job}
                    and job_date=toDate('{job_date}')
                    and tag = ''
                    and time >= toDateTime('{start}'){end_filter}
                group by t
                order by t
            )
            all left join
            (
                select
                    {inner_columns},
                    intDiv(toUInt32(time), 1) as t
                from {db_name}.rt_quantiles_buffer
                where job_id={job}
                    and tag='{tag}'
                    and job_date=toDate('{job_date}')
                    and time >= toDateTime('{start}'){end_filter}
                group by t
                order by t
            ) using t;'''

    columns = ('ts',) + QUANTILES
    # Initialised up front so that an empty job_cases does not leave it unbound.
    series = []
    for tag in job_cases:
        tag_key = 'overall' if tag == '' else tag
        query = sql.format(
            outer_columns=outer_columns, inner_columns=inner_columns,
            db_name=db_name, job=job_id, job_date=job_date, start=job_start,
            end_filter=end_filter, compress_ratio=1, tag=tag
        )
        # Transpose rows into one tuple per column, in `columns` order.
        series = list(zip(*client.execute(query)))
        if series:
            data['cases'][tag_key] = {
                name: series[position] for position, name in enumerate(columns) if name != 'ts'
            }
        else:
            data['cases'][tag_key] = {name: [] for name in QUANTILES}

    # Every query shares the same timeline subquery, so the timestamps are
    # taken from the last executed one.
    if series:
        data['ts'] = series[0]

    if not _responses_per_second:
        _responses_per_second = _get_responses_per_second(
            client, db_name, job_id, job_date, job_start, job_end, job_scheme_type
        )
    data['responses_per_second'] = _responses_per_second

    return {'data': data}


def _get_codes_data(client, db_name, job_id, job_date, job_start, job_end, job_cases, job_scheme_type,
                    code_type, _responses_per_second=None):
    """
    Shared implementation of get_net_codes_data and get_proto_codes_data.

    For every (case, code) pair one query is executed: the job's timeline
    (seconds present in rt_microsecond_details_buffer for the empty tag) is
    left-joined with the per-second sum of ``cnt`` from
    ``{db_name}.{code_type}_codes_buffer`` for that tag and code.

    :param code_type: table prefix, 'net' or 'proto'
    :return: ``{'data': {'cases': {case: {str(code): [values]}}, 'ts': [...],
        'responses_per_second': ...}}``
    :rtype: dict
    """
    data = {'cases': {}}
    codes = _get_codes_set(
        client=client, db_name=db_name, job_id=job_id, job_date=job_date, code_type=code_type
    )

    # A job without a recorded end time gets no upper time bound.
    end_filter = ''
    if job_end not in (None, 0):
        end_filter = " and time <= toDateTime('{end}')".format(end=job_end)

    sql = '''
    select t, data from (
        select intDiv(toUInt32(time), {compress_ratio})*{compress_ratio} as t
            from {db_name}.rt_microsecond_details_buffer
            where job_id={job}
            and job_date=toDate('{job_date}')
            and tag = ''
            and time >= toDateTime('{start}'){end_filter}
            group by t
            order by t
        ) all left join (
        select round(sum(cnt)/{compress_ratio}, 3) as data,
               intDiv(toUInt32(time), {compress_ratio})*{compress_ratio} as t
            from {db_name}.{code_type}_codes_buffer
            where job_id={job}
            and job_date=toDate('{job_date}')
            and tag='{tag}'
            and code={code}
            and time >= toDateTime('{start}'){end_filter}
            group by t
            order by t
        )
        using t
    '''
    # TODO: move all loops to sql part
    rows = []
    for tag in job_cases:
        tag_key = 'overall' if tag == '' else tag
        per_code = data['cases'][tag_key] = {}
        for code in codes:
            rows = client.execute(
                sql.format(
                    db_name=db_name, code_type=code_type,
                    job=job_id, job_date=job_date,
                    start=job_start, end_filter=end_filter,
                    tag=tag, compress_ratio=1, code=code
                )
            )
            per_code[str(code)] = [row[1] for row in rows]
    # Every query shares the same timeline subquery, so the timestamps are
    # taken from the last executed one (empty when no query was executed).
    data['ts'] = [row[0] for row in rows]

    if not _responses_per_second:
        _responses_per_second = _get_responses_per_second(
            client, db_name, job_id, job_date, job_start, job_end, job_scheme_type
        )
    data['responses_per_second'] = _responses_per_second

    return {'data': data}


def get_net_codes_data(client, db_name, job_id, job_date, job_start, job_end, job_cases, job_scheme_type,
                       _responses_per_second=None):
    """
    Get per-second counts of every net code of the job, per case.

    :param client: object with an ``execute(sql)`` method (clickhouse_driver.Client)
    :param db_name: clickhouse db name
    :param job_id: job n
    :param job_date: job start date
    :param job_start: job start time
    :param job_end: job end time; None or 0 means no upper bound
    :param job_cases: iterable of tags; '' is reported under the key 'overall'
    :param job_scheme_type: only forwarded to _get_responses_per_second
    :param _responses_per_second: precomputed series; when falsy it is fetched
        with _get_responses_per_second
    :return: ``{'data': {'cases': {case: {str(code): [values]}}, 'ts': [...],
        'responses_per_second': ...}}``
    :rtype: dict
    """
    return _get_codes_data(
        client, db_name, job_id, job_date, job_start, job_end, job_cases, job_scheme_type,
        code_type='net', _responses_per_second=_responses_per_second
    )


def get_proto_codes_data(client, db_name, job_id, job_date, job_start, job_end, job_cases, job_scheme_type,
                         _responses_per_second=None):
    """
    Get per-second counts of every proto code of the job, per case.

    Takes the same parameters and returns the same structure as
    get_net_codes_data, reading ``proto_codes_buffer`` instead of
    ``net_codes_buffer``.
    """
    return _get_codes_data(
        client, db_name, job_id, job_date, job_start, job_end, job_cases, job_scheme_type,
        code_type='proto', _responses_per_second=_responses_per_second
    )


def get_comparison_quantiles_data(client, tests_ids, metrics_names):
    """
    Get quantiles data of several tests for comparison from clickhouse.

    :param client: object with a ``query_dataframe(query, params)`` method
    :param tests_ids: job ids, passed to the query as the ``tests_ids`` parameter
    :param metrics_names: quantile levels to fetch; every one must be in QUANTILES
    :return: the fetched frame melted to the columns
        test_id, tag, t, metric_name, metric_value
    :raises ValueError: when ``metrics_names`` is empty or holds a level
        that is not in QUANTILES
    """
    requested = set(metrics_names)
    unknown_metrics = requested - set(QUANTILES)
    if unknown_metrics:
        raise ValueError('Unknown metrics: {unknown_metrics}'.format(unknown_metrics=unknown_metrics))
    if not requested:
        # Without metrics the select list below would end with a dangling comma.
        raise ValueError('No metrics requested')

    # Iterate in QUANTILES order rather than set order, so the generated query
    # and the column order of the result do not vary between processes.
    selected = [quantile for quantile in QUANTILES if quantile in requested]

    quantiles_outer_select = ', '.join('metric{quantile}'.format(quantile=quantile) for quantile in selected)
    quantiles_inner_select = ', '.join(
        'round(max(q{quantile}), 3) as metric{quantile}'.format(quantile=quantile)
        for quantile in selected
    )

    sql = '''
        select job_id as test_id, tag, t, {quantiles_outer_select} from
            (
                select job_id, intDiv(toUInt32(time), 1) as t
                from rt_microsecond_details_buffer
                where job_id in %(tests_ids)s
                             and tag = ''
                group by job_id, t
                order by job_id, t
            ) as all
            left join
            (
                select
                    job_id, tag, intDiv(toUInt32(time), 1) as t,
                    {quantiles_inner_select}
                from rt_quantiles_buffer
                where job_id in %(tests_ids)s
                group by job_id, tag, t
                order by job_id, tag, t
            ) as quantiles
            using job_id, t;'''

    query = sql.format(
        quantiles_outer_select=quantiles_outer_select,
        quantiles_inner_select=quantiles_inner_select,
    )
    fetched_data = client.query_dataframe(query, {'tests_ids': tests_ids})
    # Wide (one column per metric) -> long (one row per metric value).
    return fetched_data.melt(
        id_vars=['test_id', 'tag', 't'],
        var_name="metric_name",
        value_name="metric_value",
    )


def get_comparison_codes_data(client, tests_ids, metrics_names, chart_type):
    """
    Get per-second code counts of several tests for comparison from clickhouse.

    :param client: object with a ``query_dataframe(query, params)`` method
    :param tests_ids: job ids, passed to the query as the ``tests_ids`` parameter
    :param metrics_names: codes to fetch, passed to the query as the
        ``metrics_names`` parameter
    :param chart_type: 'NET_CODES' reads net_codes_buffer,
        'PROTO_CODES' reads proto_codes_buffer
    :return: whatever ``client.query_dataframe`` returns for the query, with the
        columns test_id, tag, t, metric_name, metric_value
    :raises ValueError: for any other ``chart_type``
    """
    tables = {
        'NET_CODES': 'net_codes_buffer',
        'PROTO_CODES': 'proto_codes_buffer',
    }
    if chart_type not in tables:
        raise ValueError('Unknown chart_type')
    db_table = tables[chart_type]

    # Both sides carry job_id, so it has to be part of the join key: joining
    # on t alone would attach one test's codes to every other test that has a
    # row for the same second.
    sql = """
    select job_id as test_id, tag, t, code as metric_name, data as metric_value from (
        select job_id, intDiv(toUInt32(time), 1) as t
            from rt_microsecond_details_buffer
            where job_id in %(tests_ids)s
                 and tag = ''
            group by job_id, t
            order by job_id, t
        ) as all
        left join
        (
            select
                job_id,
                tag,
                code,
                intDiv(toUInt32(time), 1) as t,
                round(sum(cnt), 3) as data
            from {db_table}
            where job_id in %(tests_ids)s
                and code in %(metrics_names)s
            group by job_id, tag, t, code
            order by job_id, tag, t, code
        ) as codes_data
        using job_id, t
    """

    query = sql.format(db_table=db_table)
    return client.query_dataframe(query, {'tests_ids': tests_ids, 'metrics_names': metrics_names})


def get_comparison_instances_data(client, tests_ids, metrics_names):
    """
    Get the per-second average thread count of several tests for comparison.

    :param client: object with a ``query_dataframe(query, params)`` method
    :param tests_ids: job ids, passed to the query as the ``tests_ids`` parameter
    :param metrics_names: must be exactly ``['instances']``
    :return: the fetched frame (test_id, t, metric_value) with two constant
        columns added: metric_name='instances' and tag=''
    :raises ValueError: when ``metrics_names`` is anything else
    """
    if metrics_names != ['instances']:
        raise ValueError('Unknown metrics')

    query = """
    select
        job_id as test_id,
        intDiv(toUInt32(time), 1) as t,
        round(avg(threads), 3) as metric_value
    from rt_microsecond_details_buffer
        where job_id in %(tests_ids)s
            and tag=''
    group by job_id, t
    order by job_id, t;"""
    frame = client.query_dataframe(query, {'tests_ids': tests_ids})

    # Constant columns, added in this order.
    for column, constant in (('metric_name', 'instances'), ('tag', '')):
        frame[column] = constant
    return frame
