# -*- coding: utf-8 -*-
import json
import os
import sys
from collections import Counter
from concurrent.futures import TimeoutError

import ydb


def select_simple(session_pool, full_path):
    """Fetch ``call_object`` for every call in state CREATED or ACCEPTED.

    The ``calls`` table is resolved relative to ``full_path`` via
    ``PRAGMA TablePathPrefix``. The query runs in a serializable read-write
    transaction that is committed together with the query, and the whole
    operation is wrapped in ``session_pool.retry_operation_sync``.

    :param session_pool: pool providing sessions and the sync retry helper.
    :param full_path: path used as the table path prefix.
    :return: the first result set produced by the query.
    """
    query = """
            PRAGMA TablePathPrefix("{}");
            SELECT call_object
            FROM calls
            WHERE call_state in ("CREATED", "ACCEPTED");
            """.format(full_path)

    def _run_select(session):
        # On success the result sets are returned; on failure an exception
        # propagates to retry_operation_sync.
        tx = session.transaction(ydb.SerializableReadWrite())
        result_sets = tx.execute(query, commit_tx=True)
        first_result_set = result_sets[0]
        return first_result_set

    return session_pool.retry_operation_sync(_run_select)


def sanitize_metric_name(name):
    """Return ``name`` with each space character turned into an underscore."""
    pieces = name.split(' ')
    return '_'.join(pieces)

def _device_field(device_info, key):
    """Return ``device_info[key]``, or ``'UNKNOWN'`` if it is missing or null."""
    value = device_info.get(key)
    return 'UNKNOWN' if value is None else value


def get_stat(rows):
    """Aggregate call rows into a list of ``[metric_name, count]`` pairs.

    Each row must provide a ``'call_object'`` entry holding a JSON string.
    For every decoded call object this counts:

    * ``conferences`` - one per row;
    * ``participants`` - number of participants with a truthy ``is_joined``;
    * ``conference_sizes_<n>`` - histogram keyed by joined-participant count
      (conferences with no participant list count as size 0, so the
      histogram sums to ``conferences``);
    * ``participant_platform_*``, ``participant_device_type_*``,
      ``participant_network_type_*``, ``participant_os_version_*`` - per
      joined participant, taken from its ``device_info``; missing or null
      values are reported as ``UNKNOWN``.

    :param rows: iterable of mapping-like rows, e.g. the ``rows`` of the
        result set returned by ``select_simple``.
    :return: list of ``[name, count]`` pairs where ``name`` is
        ``mediator_dbstat_<key>_txxx`` with spaces replaced by underscores.
    :raises ValueError: if a ``call_object`` is not valid JSON
        (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    counter = Counter()
    for row in rows:
        counter['conferences'] += 1
        call_object = json.loads(row['call_object'])
        # A missing, null or empty participant list still represents a
        # conference; treat it as size 0 instead of skipping it.
        participants = call_object.get('participants') or []

        participants_count = 0
        for participant in participants:
            if not participant.get('is_joined'):
                continue
            participants_count += 1
            # device_info may be absent or explicitly null in the JSON.
            device_info = participant.get('device_info') or {}
            platform = _device_field(device_info, 'platform')
            device_type = _device_field(device_info, 'device_type')
            network_type = _device_field(device_info, 'network_type')
            os_version = "%s_%s" % (platform, _device_field(device_info, 'os_version'))

            counter['participant_platform_%s' % platform] += 1
            counter['participant_device_type_%s' % device_type] += 1
            counter['participant_network_type_%s' % network_type] += 1
            counter['participant_os_version_%s' % os_version] += 1

        counter['participants'] += participants_count
        counter['conference_sizes_%d' % participants_count] += 1

    return [
        [sanitize_metric_name("mediator_dbstat_%s_txxx" % key), value]
        for key, value in counter.items()
    ]


def run(endpoint, database):
    """Connect to YDB, query active calls and return their aggregated stats.

    :param endpoint: YDB endpoint passed to ``ydb.DriverConfig``.
    :param database: database passed to ``ydb.DriverConfig``; also used as
        the table path prefix for the query.
    :return: the ``[metric_name, count]`` pairs produced by ``get_stat``.

    If the driver does not become ready within 5 seconds, discovery
    diagnostics are printed to stderr and the process exits with status 1.
    """
    driver_config = ydb.DriverConfig(
        endpoint, database, credentials=ydb.construct_credentials_from_environ()
    )
    with ydb.Driver(driver_config) as driver:
        try:
            driver.wait(timeout=5)
        except TimeoutError:
            # Diagnostics go to stderr so they stay separate from regular stdout output.
            print("Connect failed to YDB", file=sys.stderr)
            print("Last reported errors by discovery:", file=sys.stderr)
            print(driver.discovery_debug_details(), file=sys.stderr)
            # sys.exit is always available; the bare exit() builtin is
            # injected by the site module and may be absent.
            sys.exit(1)

        with ydb.SessionPool(driver, size=10) as session_pool:
            result_set = select_simple(session_pool, database)
    # NOTE: rows are consumed after the pool and driver are closed; this
    # assumes the result set is already fully fetched.
    return get_stat(result_set.rows)
