import logging
import os
from _struct import Struct
from typing import Iterator, Optional

# noinspection PyUnresolvedReferences
import travel.proto.resourcestorage.resource_pb2 as resource_pb2
# noinspection PyUnresolvedReferences
import travel.trains.search_api.api.tariffs.tariffs_pb2 as tariffs_pb2
from travel.library.python.s3_client import S3Client


class SnapshotReader:
    """Read a tariff snapshot stored in S3 as a stream of length-prefixed protobufs.

    The snapshot is located through a ``resource.pb`` object (a serialized
    ``ResourceMeta``) under a caller-supplied prefix. Its ``key`` and
    ``version`` fields name the ``<key>/<version>/data.pb`` object, which is
    a sequence of records: a length prefix packed with ``__size_fmt__``
    followed by that many bytes of a serialized ``DirectionTariffInfo``.
    """

    # Record length prefix: little-endian signed 32-bit integer (4 bytes).
    __size_fmt__ = '<i'

    def __init__(self, s3_client: S3Client):
        """:param s3_client: client used to fetch the metadata and the data stream."""
        self.s3_client = s3_client
        self.size_struct = Struct(self.__size_fmt__)

    def get_tariffs(self, prefix: str) -> Iterator[tariffs_pb2.DirectionTariffInfo]:
        """Lazily yield every tariff of the snapshot described by ``<prefix>/resource.pb``.

        This is a generator: nothing is fetched from S3 until the first item
        is requested.

        :param prefix: key prefix under which the snapshot's ``resource.pb`` lives.
        :return: iterator over ``DirectionTariffInfo`` messages in stream order.
        :raises EOFError: if ``data.pb`` ends in the middle of a record.
        :raises ValueError: if a record has a negative length prefix.
        """
        meta_path = os.path.join(prefix, 'resource.pb')
        logging.info('Getting meta from %s', meta_path)
        meta_raw = self.s3_client.read(meta_path)
        meta = resource_pb2.ResourceMeta()
        meta.ParseFromString(meta_raw)

        tariffs_path = os.path.join(meta.key, meta.version, 'data.pb')
        logging.info('Getting tariffs from %s', tariffs_path)

        tariffs_reader = self.s3_client.get_reader(tariffs_path)
        try:
            for tariff_raw in self._iter_messages(tariffs_reader):
                tariff = tariffs_pb2.DirectionTariffInfo()
                tariff.ParseFromString(tariff_raw)
                yield tariff
        finally:
            # NOTE(review): the reader's concrete type is not visible here;
            # release it if it supports close() so an iteration abandoned by
            # the caller does not keep the underlying stream open until GC.
            close = getattr(tariffs_reader, 'close', None)
            if callable(close):
                close()

    def _iter_messages(self, reader) -> Iterator[bytes]:
        """Yield the raw payload of every length-prefixed record until a clean EOF.

        A zero length is a valid record (an all-default protobuf message
        serializes to zero bytes), so only EOF ends the iteration.

        :param reader: file-like object exposing ``read(n)`` that returns bytes.
        :raises EOFError: if the stream ends in the middle of a record.
        :raises ValueError: if a length prefix is negative.
        """
        while True:
            message_size = self._get_message_size(reader)
            if message_size is None:
                return
            payload = self._read_fully(reader, message_size)
            if len(payload) < message_size:
                raise EOFError(
                    f'Truncated message: expected {message_size} bytes, '
                    f'got {len(payload)}'
                )
            yield payload

    def _get_message_size(self, reader) -> Optional[int]:
        """Read and decode the next record length prefix.

        :param reader: file-like object exposing ``read(n)`` that returns bytes.
        :return: the record length, or ``None`` if the stream is already at EOF.
        :raises EOFError: if the stream ends inside the length prefix.
        :raises ValueError: if the decoded length is negative.
        """
        expected = self.size_struct.size
        raw_size = self._read_fully(reader, expected)
        if not raw_size:
            return None
        if len(raw_size) < expected:
            raise EOFError(
                f'Truncated message size: expected {expected} bytes, '
                f'got {len(raw_size)}'
            )
        (message_size,) = self.size_struct.unpack(raw_size)
        if message_size < 0:
            # Only a corrupt stream produces this; passing it on to read()
            # would consume everything that remains as a single "message".
            raise ValueError(f'Invalid negative message size: {message_size}')
        return message_size

    @staticmethod
    def _read_fully(reader, size: int) -> bytes:
        """Read up to ``size`` bytes, retrying on short reads.

        ``read(n)`` on a stream may legally return fewer than ``n`` bytes
        before EOF, so keep reading until ``size`` bytes are collected or
        ``read`` returns nothing. The result is shorter than ``size`` only
        when EOF was reached.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = reader.read(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)
