import collections
import logging
import json

import infra.callisto.libraries.yt as yt
import infra.callisto.protos.deploy.tables_pb2 as tables_pb2
import search.plutonium.deploy.proto.sources_pb2 as sources_pb2
import search.plutonium.deploy.proto.rescan_pb2 as rescan_pb2  # noqa
import search.plutonium.core.state.proto.state_pb2 as state_pb2


# Flat record for one file of a Plutonium state, as returned by
# RuntimeMetaTable.read(); field order mirrors that table's schema.
# TODO: proto description?
PlutoniumFile = collections.namedtuple(
    'PlutoniumFile',
    'namespace state_id file_id path hash size created_at labels',
)


class BannedGroupsStub:
    """No-op stand-in for BannedGroups: reports no banned groups and drops writes."""

    def list(self):
        """Return an empty collection of banned group names."""
        return tuple()

    def write(self, row):
        """Accept and discard ``row``."""
        return None


class RealtimeConfigStub:
    """No-op stand-in for RealtimeConfig that always reports an empty config."""

    def get_last_version(self):
        """Return a fresh, empty config dict."""
        return dict()


class BannedGroups(yt.SortedYtTable):
    """Dynamic table of groups with a ``disabled`` flag, keyed by group name."""

    schema = [
        {'name': 'group', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'disabled', 'type': 'boolean'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(BannedGroups, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def list(self):
        """Return the set of group names whose ``disabled`` flag is truthy."""
        all_rows = self._select_rows('* from [{}]'.format(self.path))
        return {entry['group'] for entry in all_rows if entry['disabled']}

    def write(self, row):
        """Write a single ``row`` dict via ``_insert_rows``."""
        self._insert_rows([row])


class PublicStatesTable(yt.SortedYtTable):
    """Serialized ``TSaaSStateProto`` blobs keyed by (namespace, state_id)."""

    schema = [
        {'name': 'namespace', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'state_id', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'state', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(PublicStatesTable, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def _get_state_proto(self, state):
        """Decode a ``state`` cell into a ``TSaaSStateProto``.

        A null cell (None) and the ``'PRESENT'`` placeholder both decode to an
        empty proto, the same rule ShardCtrlState._row2proto applies.
        """
        proto = state_pb2.TSaaSStateProto()
        # ``state`` is a non-key column and may be null; ParseFromString(None)
        # would raise instead of yielding an empty message.
        if state not in (None, 'PRESENT'):
            proto.ParseFromString(state)
        return proto

    def list(self, namespaces):
        """Return ``TShardCtrlState`` messages for rows whose namespace is in ``namespaces``.

        The whole table is selected and filtered client-side.
        """
        return [
            tables_pb2.TShardCtrlState(
                Namespace=row['namespace'],
                StateId=row['state_id'],
                State=self._get_state_proto(row['state']),
            )
            for row in self._select_rows(
                '* from [{}]'.format(self.path)
            )
            if row['namespace'] in namespaces
        ]

    def delete(self, states):
        """Delete the rows keyed by each state's (Namespace, StateId).

        No-op for empty ``states``; in read-only mode only logs.
        """
        if not states:
            return
        if not self._readonly:
            to_delete = [
                {'namespace': state.Namespace, 'state_id': state.StateId}
                for state in states
            ]
            self._yt_client.delete_rows(self.path, to_delete, format='json')
        else:
            _log.info('Skip %s change, RO mode', self.path)


class RuntimeMetaTable(yt.SortedYtTable):
    """Per-file metadata of Plutonium states, keyed by (namespace, state_id, file_id)."""

    schema = [
        {'name': 'namespace', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'state_id', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'file_id', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'path', 'type': 'string'},
        {'name': 'hash', 'type': 'string'},
        {'name': 'size', 'type': 'uint64'},
        {'name': 'created_at', 'type': 'timestamp'},
        {'name': 'labels', 'type': 'any'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(RuntimeMetaTable, self).__init__(yt_client, path, readonly)

    @property
    def cluster(self):
        """Proxy URL taken from the YT client's config."""
        proxy_config = self._yt_client.config['proxy']
        return proxy_config['url']

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def read(self, namespace, state_id):
        """Return every file row of ``state_id`` in ``namespace`` as ``PlutoniumFile`` tuples.

        Relies on PlutoniumFile field names matching this table's column names.
        """
        query = '* from [{}] where namespace = "{}" and state_id = "{}"'.format(
            self.path, namespace, state_id
        )
        field_names = PlutoniumFile._fields
        files = []
        for record in self._select_rows(query):
            files.append(PlutoniumFile(**{name: record[name] for name in field_names}))
        return files


class ChunkCtrlStatus(yt.SortedYtTable):
    """Generation status per (Namespace, StateId); ``State`` is a serialized ``TGenerationState``."""

    schema = [
        {'name': 'Namespace', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'StateId', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'State', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(ChunkCtrlStatus, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def _row2proto(self, row):
        """Decode a table row into a ``TGenerationStatus``."""
        return tables_pb2.TGenerationStatus(
            Namespace=row['Namespace'],
            StateId=row['StateId'],
            State=tables_pb2.TGenerationState.FromString(row['State'])
        )

    def list(self, namespaces):
        """Yield ``TGenerationStatus`` for every row whose Namespace is in ``namespaces``.

        Yields nothing for an empty ``namespaces``: no row can match, and
        ``_get_yt_tuple`` would otherwise render an empty ``in ()`` list into the
        query (ShardCtrlState.get_last_states guards the same way).
        """
        if not namespaces:
            return
        request = '* from [{}] where Namespace in {}'.format(
            self.path, _get_yt_tuple(namespaces)
        )
        for row in self._select_rows(request):
            yield self._row2proto(row)

    def get(self, namespace, state_id):
        """Return the stored status for the key.

        If no row exists, return a status whose state is PRESENT_GENERATION_STATE.
        """
        for row in self._lookup_rows([{'Namespace': namespace, 'StateId': state_id}]):
            return self._row2proto(row)

        return tables_pb2.TGenerationStatus(
            Namespace=namespace,
            StateId=state_id,
            State=tables_pb2.TGenerationState(State=tables_pb2.EGenerationState.PRESENT_GENERATION_STATE)
        )


class ChunkCtrlTopology(yt.SortedYtTable):
    """Chunk topology per (Namespace, StateId): a revision and a serialized ``TSource``."""

    schema = [
        {'name': 'Namespace', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'StateId', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'Revision', 'type': 'uint64'},
        {'name': 'ResourceSpec', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(ChunkCtrlTopology, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def get(self, states):
        """Look up the topology rows for ``states`` and return them as ``TChunksMapping``.

        Each item of ``states`` needs ``Namespace``, ``StateId`` and ``State``. The
        ``State`` of the input item with the same key is copied into each result;
        if keys repeat, the last item wins.
        """
        state_by_key = {}
        keys = []
        for item in states:
            namespace, state_id = item.Namespace, item.StateId
            keys.append({'Namespace': namespace, 'StateId': state_id})
            state_by_key[(namespace, state_id)] = item.State

        mappings = []
        for found in self._lookup_rows(keys):
            mappings.append(tables_pb2.TChunksMapping(
                Namespace=found['Namespace'],
                StateId=found['StateId'],
                Revision=found['Revision'],
                ResourceSpec=sources_pb2.TSource.FromString(found['ResourceSpec']),
                State=state_by_key[(found['Namespace'], found['StateId'])],
            ))
        return mappings


class ChunkCtrlTarget(yt.SortedYtTable):
    """Chunk controller target table keyed by (Namespace, StateId).

    ``State`` holds the serialized ``State`` message of a target (see _proto2row).
    """

    schema = [
        {'name': 'Namespace', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'StateId', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'State', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(ChunkCtrlTarget, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def _proto2row(self, target):
        """Flatten a target message into a row dict with a serialized ``State``."""
        row = dict(Namespace=target.Namespace, StateId=target.StateId)
        row['State'] = target.State.SerializeToString()
        return row

    def write(self, targets):
        """Replace the whole table contents with ``targets`` inside one transaction.

        Rows whose key is not among ``targets`` are deleted. Targets whose
        serialized ``State`` differs from the stored value are written.
        Empty ``targets`` is a no-op (the table is NOT cleared). In read-only
        mode only a message is logged.
        """
        if self._readonly:
            _log.info('Skip chunk ctrl target change, RO mode')
            return
        if not targets:
            return

        with self._transaction():
            stored = {}
            for existing in self._select_rows('* from [{}]'.format(self.path)):
                stored[(existing['Namespace'], existing['StateId'])] = existing['State']

            wanted_keys = set((t.Namespace, t.StateId) for t in targets)
            stale_keys = [
                {'Namespace': ns, 'StateId': sid}
                for ns, sid in set(stored) - wanted_keys
            ]

            changed_rows = [
                self._proto2row(t)
                for t in targets
                if stored.get((t.Namespace, t.StateId), '') != t.State.SerializeToString()
            ]

            self._yt_client.delete_rows(self.path, stale_keys, format='json')
            self._insert_rows(changed_rows)


class SnapshotReplicatorTarget(yt.SortedYtTable):
    """Snapshot replicator targets.

    ``Stream`` is the only key column, so the table holds one row per stream.
    """

    schema = [
        {'name': 'Stream', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'SnapshotId', 'type': 'string'},
        {'name': 'State', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(SnapshotReplicatorTarget, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def write(self, targets):
        """Write (via ``_insert_rows``) each target whose ``State`` differs from the stored one.

        :param targets: iterable of row dicts with ``Stream``, ``SnapshotId`` and
            ``State`` keys; changed dicts are written as-is.

        Stored rows are matched on (Stream, SnapshotId); rows are never deleted.
        In read-only mode nothing is written and a message is logged.
        """
        # Materialized because it is traversed more than once below; a one-shot
        # iterator would otherwise be exhausted by the first pass and nothing
        # would get written.
        targets = list(targets)
        if self._readonly:
            _log.info('Skip snapshot replicator target change, RO mode')
            return
        if not targets:
            return

        streams = set(t['Stream'] for t in targets)
        with self._transaction():
            prev_targets = {
                (row['Stream'], row['SnapshotId']): row['State']
                for row in self._select_rows(
                    '* from [{}] where Stream in {}'.format(
                        self.path, _get_yt_tuple(streams)
                    )
                )
            }

            to_update = [
                t
                for t in targets
                if prev_targets.get((t['Stream'], t['SnapshotId']), '') != t['State']
            ]

            self._insert_rows(to_update)


class SnapshotReplicatorStatus(yt.SortedYtTable):
    """Snapshot replicator status keyed by (Stream, SnapshotId); rows are returned as raw dicts."""

    schema = [
        {'name': 'Stream', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'SnapshotId', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'ReplicaSnapshotId', 'type': 'string'},
        {'name': 'State', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(SnapshotReplicatorStatus, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def list(self, streams):
        """Yield the raw row dicts of every stream in ``streams``.

        Yields nothing for an empty ``streams``: no row can match, and
        ``_get_yt_tuple`` would otherwise render an empty ``in ()`` list.
        """
        if not streams:
            return
        request = '* from [{}] where Stream in {}'.format(
            self.path, _get_yt_tuple(streams)
        )
        for row in self._select_rows(request):
            yield row

    def get(self, stream, snapshot_id):
        """Return the raw row dict for the key, or None when it is absent."""
        for row in self._lookup_rows([{'Stream': stream, 'SnapshotId': snapshot_id}]):
            return row
        return None


class PodStatus(yt.SortedYtTable):
    """Reported per-pod resource state keyed by (PodId, Namespace, LocalPath)."""

    schema = [
        {'name': 'PodId', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'Namespace', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'LocalPath', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'ResourceState', 'type': 'string'},
        {'name': 'Annotation', 'type': 'string'},
        {'name': 'Timestamp', 'type': 'uint64'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(PodStatus, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def list(self, namespaces, pods):
        """Yield ``TPodStatus`` for rows whose Namespace is in ``namespaces`` and PodId is in ``pods``.

        The table is paged by primary key. Each query selects rows whose
        (PodId, Namespace, LocalPath) is greater than the last key seen, capped by
        ``limit``. Paging stops once a page does not move the key forward.
        """
        if not namespaces:
            # Nothing can match; also avoids rendering an empty ``in ()`` list.
            return
        # Membership is tested for every scanned row (hundreds of thousands per
        # page), so use a hash set; this also makes a one-shot iterable safe.
        pod_ids = set(pods)
        if not pod_ids:
            return

        rows_limit = 500 * 1000
        namespaces_tuple = _get_yt_tuple(namespaces)
        last_key = collections.OrderedDict([('PodId', ''), ('Namespace', ''), ('LocalPath', '')])
        last_row = None
        while True:
            query = '* from [{}] where ({}) > {} and Namespace in {} limit {}'.format(
                self.path,
                ', '.join(last_key.keys()), tuple(last_key.values()),
                namespaces_tuple,
                int(rows_limit * 0.8)
            )
            for row in self._select_rows(query, input_row_limit=rows_limit, output_row_limit=rows_limit):
                last_row = row
                if row['PodId'] in pod_ids:
                    yield tables_pb2.TPodStatus(
                        PodId=row['PodId'],
                        Namespace=row['Namespace'],
                        LocalPath=row['LocalPath'],
                        ResourceState=tables_pb2.TResourceState.FromString(row['ResourceState']),
                        Annotation=row['Annotation'],
                        Timestamp=row['Timestamp'],
                    )

            if last_row:
                new_key = {key: last_row[key] for key in last_key.keys()}
                if last_key != new_key:
                    # The key advanced: request the next page after it.
                    last_key.update(new_key)
                    continue
            break


class PodTarget(yt.SortedYtTable):
    """Desired per-pod resources keyed by (PodId, Namespace, LocalPath).

    Non-key columns hold serialized protobuf messages (see ``_proto2row``).
    """

    schema = [
        {'name': 'PodId', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'Namespace', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'LocalPath', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'ResourceSpec', 'type': 'string'},
        {'name': 'ResourceLabels', 'type': 'string'},
        {'name': 'DownloadPolicy', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(PodTarget, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def _row2proto(self, row):
        """Decode a table row into a ``TPodTarget``.

        ``ResourceLabels`` and ``DownloadPolicy`` may be missing from the row or
        null. Presumably this happens for rows written before those columns
        existed. Both cases decode to empty messages.
        """
        return tables_pb2.TPodTarget(
            PodId=row['PodId'],
            Namespace=row['Namespace'],
            LocalPath=row['LocalPath'],
            ResourceSpec=sources_pb2.TSource.FromString(row['ResourceSpec']),
            # ``or ''`` also covers an explicit null cell; ``.get(key, '')`` alone
            # would pass None to FromString and raise.
            ResourceLabels=rescan_pb2.TResourceLabels.FromString(row.get('ResourceLabels') or ''),
            DownloadPolicy=tables_pb2.TDownloadPolicy.FromString(row.get('DownloadPolicy') or ''),
        )

    def _proto2row(self, target):
        """Flatten a ``TPodTarget``-like message into a row dict with serialized sub-messages."""
        return {
            'PodId': target.PodId,
            'Namespace': target.Namespace,
            'LocalPath': target.LocalPath,
            'ResourceSpec': target.ResourceSpec.SerializeToString(),
            'ResourceLabels': target.ResourceLabels.SerializeToString(),
            'DownloadPolicy': target.DownloadPolicy.SerializeToString(),
        }

    def lookup(self, keys):
        """Yield decoded targets for the given key dicts, via the raw YT client lookup."""
        for row in self._yt_client.lookup_rows(self.path, keys):
            yield self._row2proto(row)

    def _list(self, namespace):
        """Yield key-only rows (PodId, Namespace, LocalPath) of ``namespace``.

        The table is paged by primary key. Each query selects rows greater than
        the last key seen, and paging stops when a page does not advance it.
        """
        rows_limit = 500 * 1000
        last_key = collections.OrderedDict([('PodId', ''), ('Namespace', namespace), ('LocalPath', '')])
        last_row = None
        while True:
            query = 'PodId, Namespace, LocalPath from [{}] where ({}) > {} and Namespace = "{}" limit {}'.format(
                self.path,
                ', '.join(last_key.keys()), tuple(last_key.values()),
                namespace,
                int(rows_limit * 0.8)  # ? fails with full limit
            )
            for row in self._select_rows(query, input_row_limit=rows_limit, output_row_limit=rows_limit):
                last_row = row
                yield row

            if last_row and last_key != last_row:
                last_key.update(last_row)
            else:
                break

    def update(self, targets):
        """Sync the rows of a single namespace with ``targets``.

        Rows of the namespace whose key is not among ``targets`` are deleted.
        Targets whose key is not in the table yet are inserted. Targets whose key
        already exists are left untouched, even if their non-key fields differ.
        Writes go in batches of 100k rows. In read-only mode only messages are
        logged.

        Raises AssertionError unless all targets share exactly one Namespace
        (so an empty ``targets`` also fails).
        """
        # Materialized: ``targets`` is traversed twice below. With a one-shot
        # iterator ``new_targets`` would come out empty, and every existing row
        # of the namespace would be scheduled for deletion.
        targets = list(targets)
        namespaces = set(t.Namespace for t in targets)
        assert len(namespaces) == 1
        namespace = namespaces.pop()

        # Keys are tuples of (column, value) pairs so they compare equal to the
        # keys built from stored rows below.
        new_targets = {
            (
                ('PodId', t.PodId),
                ('Namespace', t.Namespace),
                ('LocalPath', t.LocalPath),
            ): t
            for t in targets
        }

        _log.debug('new target prepared')

        to_delete = []
        for row in self._list(namespace):
            key = (
                ('PodId', row['PodId']),
                ('Namespace', row['Namespace']),
                ('LocalPath', row['LocalPath'])
            )
            if key not in new_targets:
                to_delete.append(row)
            else:
                # Already present: nothing to insert for this key.
                del new_targets[key]

        _log.debug('targets are filtered')
        _log.debug('add: %s, remove: %s', len(new_targets), len(to_delete))

        if new_targets and not self._readonly:
            batch_size = 100 * 1000
            batch = []
            for target in new_targets.itervalues():
                batch.append(self._proto2row(target))
                if len(batch) >= batch_size:
                    self._insert_rows(batch)
                    batch = []
            if batch:
                self._insert_rows(batch)
        elif new_targets:
            _log.info('Skip update pod targets for namespace %s due to RO mode', namespace)

        if to_delete and not self._readonly:
            batch_size = 100 * 1000
            for begin in xrange(0, len(to_delete), batch_size):
                self._yt_client.delete_rows(
                    self.path,
                    to_delete[begin:begin + batch_size],
                    format='json'
                )
        elif to_delete:
            _log.info('Skip update pod targets for namespace %s due to RO mode', namespace)


class ShardCtrlState(yt.SortedYtTable):
    """Shard controller state per (Namespace, StateId); non-key columns hold serialized protos."""

    schema = [
        {'name': 'Namespace', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'StateId', 'type': 'string', 'sort_order': 'ascending'},
        {'name': 'Status', 'type': 'string'},
        {'name': 'Target', 'type': 'string'},
        {'name': 'Demand', 'type': 'string'},
        {'name': 'State', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(ShardCtrlState, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    @staticmethod
    def _decode_saas_state(row):
        """Decode the optional ``State`` cell.

        An absent cell, a null cell and the ``'PRESENT'`` placeholder all yield
        an empty ``TSaaSStateProto``.
        """
        if 'State' in row and row['State'] not in (None, 'PRESENT'):
            return state_pb2.TSaaSStateProto.FromString(row['State'])
        return state_pb2.TSaaSStateProto()

    def _row2proto(self, row):
        """Decode a table row into a ``TShardCtrlState``."""
        state_cls = tables_pb2.TShardCtrlState
        return state_cls(
            Namespace=row['Namespace'],
            StateId=row['StateId'],
            Status=state_cls.TStateEvents.FromString(row['Status']),
            Target=state_cls.TStateEvents.FromString(row['Target']),
            Demand=state_cls.TDemand.FromString(row['Demand']),
            State=self._decode_saas_state(row),
        )

    def _proto2row(self, state):
        """Flatten a ``TShardCtrlState`` into a row dict with serialized sub-messages."""
        row = {'Namespace': state.Namespace, 'StateId': state.StateId}
        for column in ('Status', 'Target', 'Demand', 'State'):
            row[column] = getattr(state, column).SerializeToString()
        return row

    def get(self, namespace, state_id):
        """Return the state stored under the key, or None when there is no row."""
        found = self._lookup_rows([{'Namespace': namespace, 'StateId': state_id}])
        for row in found:
            return self._row2proto(row)
        return None

    def lookup(self, states):
        """Look up the rows keyed by each item's (Namespace, StateId) and decode them."""
        keys = [{'Namespace': s.Namespace, 'StateId': s.StateId} for s in states]
        return map(self._row2proto, self._lookup_rows(keys))

    def list(self, namespaces):
        """Return decoded rows whose Namespace is in ``namespaces`` (full scan, client-side filter)."""
        decoded = []
        for row in self._select_rows('* from [{}]'.format(self.path)):
            if row['Namespace'] not in namespaces:
                continue
            decoded.append(self._row2proto(row))
        return decoded

    def get_last_states(self, namespaces):
        """For each namespace, return the state whose StateId is the maximum.

        ``max`` is applied to the string StateId column. Returns [] for empty
        ``namespaces``.
        """
        if not namespaces:
            return []

        query = 'Namespace, max(StateId) as StateId from [{}] where Namespace in {} group by Namespace'.format(
            self.path, _get_yt_tuple(namespaces)
        )
        latest_keys = self._select_rows(query)
        return [self._row2proto(row) for row in self._lookup_rows(latest_keys)]

    def update(self, states):
        """Write the ``states`` whose row differs from the stored one.

        Rows are compared on (Namespace, StateId, Status, Target, Demand).
        NOTE(review): the ``State`` column is not part of that comparison, so a
        change confined to ``State`` is not written -- confirm this is intended.
        In read-only mode nothing is written and a message is logged.
        """
        def signature(row):
            return (row['Namespace'], row['StateId'], row['Status'], row['Target'], row['Demand'])

        keys = [{'Namespace': s.Namespace, 'StateId': s.StateId} for s in states]
        stored = set(signature(row) for row in self._lookup_rows(keys))
        _log.debug('prev states are read')

        changed_states = []
        for candidate in (self._proto2row(s) for s in states):
            if signature(candidate) not in stored:
                changed_states.append(candidate)
        _log.debug('changed states are collected')

        if self._readonly:
            _log.info('Skip %s update, RO mode', self.path)
        elif changed_states:
            self._insert_rows(changed_states)

    def delete(self, states):
        """Delete the rows keyed by each state's (Namespace, StateId).

        No-op for empty ``states``; in read-only mode only logs.
        """
        if not states:
            return
        if self._readonly:
            _log.info('Skip remove from %s, RO mode', self.path)
            return
        keys = [{'Namespace': s.Namespace, 'StateId': s.StateId} for s in states]
        self._yt_client.delete_rows(self.path, keys, format='json')


class RealtimeConfig(yt.SortedYtTable):
    """Versioned JSON configs keyed by ``Timestamp``; the newest row is the current one."""

    schema = [
        {'name': 'Timestamp', 'type': 'uint64', 'sort_order': 'ascending'},
        {'name': 'Data', 'type': 'string'},
    ]

    def __init__(self, yt_client, path, readonly=True, tablet_cell_bundle='cajuper'):
        # Assigned before the base constructor runs; presumably the base class
        # reads it while creating/mounting the table -- confirm.
        self.tablet_cell_bundle = tablet_cell_bundle
        super(RealtimeConfig, self).__init__(yt_client, path, readonly)

    def _on_init_hook(self):
        # Calls the base-class ensure_table / ensure_mounted helpers.
        self.ensure_table()
        self.ensure_mounted()

    def get_last_version(self):
        """Return the JSON-decoded ``Data`` of the row with the largest Timestamp.

        Best effort by design: any failure (no rows, query error, invalid JSON)
        is logged as a warning and an empty dict is returned.
        """
        query = 'Timestamp, Data from [{}] order by Timestamp desc limit 1'.format(
            self.path
        )

        try:
            # Materialize so len() works even if _select_rows yields lazily.
            rows = list(self._select_rows(query))
            _log.debug(rows)
            if len(rows) != 1:
                raise ValueError("no configs available")
            return json.loads(rows[0]["Data"])
        except Exception as e:
            _log.warning("realtime config problem: %s", e)
            return {}


def _get_yt_tuple(lst):
    return tuple(lst if len(lst) > 1 else list(lst) * 2)

# Module-level logger. It is defined after the classes, which works because the
# methods above only look up ``_log`` when they run, not at import time.
_log = logging.getLogger(__name__)
