import datetime
import logging

import aniso8601
import google.protobuf.text_format as text_format

import infra.callisto.libraries.yt as yt_utils
import infra.callisto.controllers.utils.nested_obj as nested_obj


class TargetTable(yt_utils.SortedYtTable):
    """Sorted YT table holding the write history of per-slot target configurations.

    Each row is one pushed configuration: ``time`` (an ISO-8601 string from
    ``datetime.isoformat()``, the ascending sort key), ``slot_id``, and the
    flattened fields of ``configuration_class`` as laid out by
    ``nested_obj.yt_schema``.
    """

    def __init__(self, yt_client, path, configuration_class, readonly=True):
        """Build the schema from *configuration_class* and open the table.

        :param yt_client: client passed through to ``SortedYtTable``.
        :param path: path of the table.
        :param configuration_class: class providing ``load_json_flat`` (used to
            decode rows) and whose instances provide ``dump_json_flat`` and
            ``revision``.
        :param readonly: forwarded to the base class; ``push`` consults
            ``self._readonly`` (presumably set from this flag by the base).
        """
        # Assigned before the base constructor, presumably because the base
        # runs ``_on_init_hook`` (``ensure_table``) from ``__init__`` and needs
        # ``schema`` by then -- TODO confirm against SortedYtTable.
        self._conf_class = configuration_class
        self.schema = [
            {'name': 'time', 'type': 'string', 'sort_order': 'ascending'},
            {'name': 'slot_id', 'type': 'string'},
        ] + nested_obj.yt_schema(configuration_class)

        super(TargetTable, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        """Create the table if needed and make sure it is mounted."""
        self.ensure_table()
        self.ensure_mounted()

    def slots_ids(self):
        """Return the set of distinct ``slot_id`` values present in the table."""
        request = 'slot_id from [{}] group by slot_id'.format(self.path)
        return {row['slot_id'] for row in self._select_rows(request)}

    def _top_n_rows(self, slot_id, n):
        """Return raw rows of the *n* most recently written revisions of *slot_id*.

        The latest write time of every revision is looked up first; the *n*
        newest of those times are then used to fetch the full rows. Returns
        ``[]`` when *n* is not positive or the slot has no rows.
        """
        # Without this guard a negative n would slice below as "all but the
        # oldest -n keys" instead of selecting nothing.
        if n <= 0:
            return []
        keys_request = 'revision, max(time) from [{}] where slot_id = "{}" group by revision'.format(
            self.path, slot_id
        )
        # Timestamps written by ``_write`` are isoformat() strings, which sort
        # chronologically as plain strings, so reverse order means newest first.
        latest_times = sorted(
            (row['max(time)'] for row in self._select_rows(keys_request)),
            reverse=True,
        )[:n]
        if not latest_times:
            return []
        request = '* from [{}] where slot_id = "{}" and time in ({})'.format(
            self.path, slot_id, yt_utils.wrap_strings(latest_times),
        )
        return self._select_rows(request)

    def top_n(self, slot_id, n):
        """Return the *n* most recent configurations of *slot_id*, decoded.

        Order follows whatever ``_select_rows`` yields for the final query.
        """
        return [
            self._conf_class.load_json_flat(row)
            for row in self._top_n_rows(slot_id, n)
        ]

    def head_target_and_time(self, slot_id):
        """Return ``(configuration, write_time)`` of the newest row, or ``(None, None)``."""
        for row in self._top_n_rows(slot_id, 1):
            return self._conf_class.load_json_flat(row), aniso8601.parse_datetime(row['time'])
        return None, None

    def push(self, slot_id, conf):
        """Record *conf* for *slot_id* unless it equals the current head.

        In read-only mode the push is only logged at debug level.

        :raises AssertionError: if a head exists and ``conf.revision`` is lower
            than the head's revision.
        """
        head = self._head(slot_id)
        if head != conf:
            # Explicit raise instead of ``assert`` so the check is not stripped
            # under ``python -O``; the exception type is kept for callers.
            if head and conf.revision < head.revision:
                raise AssertionError(
                    'revision regression for slot {}: {} < {}'.format(
                        slot_id, conf.revision, head.revision,
                    )
                )
            if self._readonly:
                _log.debug('ro mode, push %s %s', slot_id, conf)
            else:
                self._write(slot_id, conf)

    def _head(self, slot_id):
        """Return the newest configuration of *slot_id*, or None if there is none."""
        rows = self.top_n(slot_id, 1)
        return rows[0] if rows else None

    def _write(self, slot_id, conf):
        """Insert *conf* as a new row stamped with the current (naive, local) time."""
        row = dict(conf.dump_json_flat(), slot_id=slot_id, time=datetime.datetime.now().isoformat())
        self._insert_rows([row])


class ConfigsTable2(yt_utils.SortedYtTable):
    """Sorted YT table keeping timestamped protobuf configs per slot.

    Rows are keyed by ``(slot_id, time)``. ``time`` is an ISO-8601 string from
    ``datetime.isoformat()``, and ``configs`` holds the message serialized in
    protobuf text format.
    """

    schema = [
        {'name': 'slot_id', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'time',    'type': 'string', 'sort_order': 'ascending'},
        {'name': 'configs', 'type': 'string'},
    ]
    # NOTE(review): not referenced inside this class; presumably a format
    # template consumed by callers -- confirm before changing.
    default_config_path = 'all/{short_host}:{port}.cfg'

    def __init__(self, yt_client, path, proto, readonly=True):
        """Open the table.

        :param yt_client: client passed through to ``SortedYtTable``.
        :param path: path of the table.
        :param proto: protobuf message class; instantiated to parse stored configs.
        :param readonly: forwarded to the base class.
        """
        # Set before the base constructor so the instance is complete even if
        # the base class invokes hooks from its own ``__init__``.
        self._proto = proto
        super(ConfigsTable2, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        """Create the table if needed and make sure it is mounted."""
        self.ensure_table()
        self.ensure_mounted()

    def slot_ids(self):
        """Return the distinct ``slot_id`` values present in the table, as a list."""
        request = 'slot_id from [{}] group by slot_id'.format(self.path)
        return [row['slot_id'] for row in self._select_rows(request)]

    def slot_configs(self, slot_id):
        """Return the newest config of *slot_id* parsed into a ``proto`` message, or None."""
        request = '* from [{}] where slot_id = "{}" order by time desc limit 1'.format(self.path, slot_id)
        for row in self._select_rows(request):
            return text_format.Parse(row['configs'], self._proto())
        return None

    def update(self, slot_id, configs):
        """Append *configs* for *slot_id*, stamped with the current (naive, local) time."""
        row = {
            'slot_id': slot_id,
            'time': datetime.datetime.now().isoformat(),
            'configs': text_format.MessageToString(configs),
        }
        # ``_insert_rows`` takes a sequence of rows (elsewhere in this module it
        # is called with a list); a bare dict would be iterated as its keys.
        self._insert_rows([row])


# Module-level logger. It is looked up at call time inside methods, so
# defining it after the classes that use it is fine.
_log = logging.getLogger(name=__name__)
