from argparse import ArgumentParser
import logging
import os

from google.protobuf.pyext._message import ScalarMapContainer, MessageMapContainer

from travel.hotels.lib.python3.cli.cli import auto_progress_reporter
from travel.hotels.lib.python3.yt import ytlib
from travel.hotels.lib.python3.yt.versioned_path import VersionedPath, StandardCleanupStrategy
from travel.library.python.tools import replace_args_from_env
from travel.library.python.dicts import aviaalliance_repository, carrier_repository, country_repository, \
    currency_repository, district_repository, pointsynonym_repository, region_repository, station_repository, \
    station_code_repository, station_express_alias_repository, settlement_repository, station_to_settlement_repository, \
    supplier_repository, thread_repository, thread_station_repository, thread_tariff_repository, timezone_repository, \
    transport_model_repository, train_tariff_info_repository, readable_timezone_repository


# Log line layout: timestamp | level (padded/truncated to 4 chars) | logger name (padded/truncated to 12 chars) | message.
FORMAT = '%(asctime)-15s | %(levelname)-4.4s | %(name)-12.12s | %(message)s'
# Executed at module top level, so the root logger is configured as soon as this module is loaded.
logging.basicConfig(level=logging.INFO, format=FORMAT)


# Dictionary name -> repository class used to load that dictionary.
# Runner registers a `--<name>-file` command-line option for every key and uses the key
# as the name of the output table under the versioned work path.
RASP_DICTS = {
    'aviaalliance': aviaalliance_repository.AviaAllianceRepository,
    'carrier': carrier_repository.CarrierRepository,
    'country': country_repository.CountryRepository,
    'currency': currency_repository.CurrencyRepository,
    'district': district_repository.DistrictRepository,
    'pointsynonym': pointsynonym_repository.PointSynonymRepository,
    'region': region_repository.RegionRepository,
    'station': station_repository.StationRepository,
    'station_code': station_code_repository.StationCodeRepository,
    'station_express_alias': station_express_alias_repository.StationExpressAliasRepository,
    'settlement': settlement_repository.SettlementRepository,
    'station_to_settlement': station_to_settlement_repository.StationToSettlementRepository,  # NOTE(review): unclear whether this corresponds to TRAVEL_DICT_RASP_SETTLEMENT_TO_STATION_PROD - confirm
    'supplier': supplier_repository.SupplierRepository,
    'thread': thread_repository.ThreadRepository,
    'thread_station': thread_station_repository.ThreadStationRepository,
    'thread_tariff': thread_tariff_repository.ThreadTariffRepository,
    'timezone': timezone_repository.TimezoneRepository,
    'transport_model': transport_model_repository.TransportModelRepository,
    'train_tariff_info': train_tariff_info_repository.TrainTariffInfoRepository,
    'readable_timezone': readable_timezone_repository.ReadableTimezoneRepository,
}


# Dictionary name -> value passed as `index_field=` to the repository constructor.
# Dictionaries that are not listed here are constructed without arguments.
RASP_DICTS_INDEX_FIELDS = {
    'station_express_alias': 'Alias',
    'station_to_settlement': 'StationId',
}


class Runner:
    """Command-line tool that dumps RASP dictionary files into YT tables.

    The constructor parses the command line and creates the YT client;
    ``run`` loads every dictionary that was given a file and writes it as a
    table under a versioned YT path.
    """

    # Name of a FieldDescriptor type constant -> YT column type for non-repeated fields.
    # TYPE_GROUP has no entry, so get_schema raises for group fields.
    _YT_TYPE_BY_PROTO_TYPE = {
        'TYPE_BOOL': 'boolean',
        'TYPE_BYTES': 'string',
        'TYPE_DOUBLE': 'double',
        'TYPE_ENUM': 'string',  # enums are written by value name, see convert_single_value
        'TYPE_FIXED32': 'uint32',
        'TYPE_FIXED64': 'uint64',
        'TYPE_FLOAT': 'double',
        'TYPE_INT32': 'int32',
        'TYPE_INT64': 'int64',
        'TYPE_MESSAGE': 'any',
        'TYPE_SFIXED32': 'int32',
        'TYPE_SFIXED64': 'int64',
        'TYPE_SINT32': 'int32',
        'TYPE_SINT64': 'int64',
        'TYPE_STRING': 'string',
        'TYPE_UINT32': 'uint32',
        'TYPE_UINT64': 'uint64',
    }

    def __init__(self):
        """Parse command-line arguments and create the YT client.

        Besides the fixed YT options, one ``--<dict_name>-file`` option is
        registered for every key of RASP_DICTS. The argument list handed to
        argparse is whatever ``replace_args_from_env()`` returns.
        """
        # dict name -> argparse destination attribute of its `--<name>-file` option
        self.dict2arg = dict()
        parser = ArgumentParser()
        parser.add_argument('--yt-proxy', default='hahn')
        parser.add_argument('--yt-token')
        parser.add_argument('--yt-token-path')
        parser.add_argument('--yt-output-path', default=ytlib.get_default_user_path('rasp_dicts'))
        parser.add_argument('--transfer-to-cluster', default=None)
        for dict_name in RASP_DICTS:
            self.dict2arg[dict_name] = parser.add_argument(f'--{dict_name}-file').dest
        args = parser.parse_args(replace_args_from_env())

        self.args = args
        yt_config = {
            'token': args.yt_token,
            'token_path': args.yt_token_path,
        }
        self.yt_client = ytlib.create_client(proxy=args.yt_proxy, config=yt_config)
        self.work_dir = self.args.yt_output_path
        # Sum of the sizes of all dumped input files, reported at the end of run().
        self.total_size_bytes = 0

    def run(self):
        """Dump every dictionary that was given a file, then optionally transfer the results.

        Dictionaries without a file argument are skipped with a log message.
        """
        with self.yt_client.Transaction():
            cleanup_strategy = StandardCleanupStrategy(7, False)
            versioned_path = VersionedPath(self.args.yt_output_path, yt_client=self.yt_client, cleanup_strategy=cleanup_strategy)
            with versioned_path as work_path:
                for dict_name, dict_class in RASP_DICTS.items():
                    file_path = getattr(self.args, self.dict2arg[dict_name])
                    if not file_path:
                        logging.info(f"Skip dict {dict_name}, no file given")
                        continue
                    self.read_and_dump(dict_name, dict_class, file_path, ytlib.join(work_path, dict_name))

        # The transfer is started only after the transaction block has been left.
        if self.args.transfer_to_cluster:
            versioned_path.transfer_results(self.args.transfer_to_cluster, self.args.yt_token, self.args.yt_proxy)

        logging.info(f'Total file size is {self.total_size_bytes}')

    def read_and_dump(self, dict_name, repo_class, file_path, yt_table_path):
        """Load one dictionary file and write its records to a freshly created YT table.

        :param dict_name: key of RASP_DICTS; also selects the optional index field
        :param repo_class: repository class used to load the file
        :param file_path: local path of the dictionary file
        :param yt_table_path: YT path of the table to (re)create and fill
        """
        file_size = os.path.getsize(file_path)
        self.total_size_bytes += file_size
        logging.info(f"Dumping file {file_path} to {yt_table_path}, size is {file_size}")

        # Not all repositories accept index_field in the constructor, so it is passed only when configured.
        index_field = RASP_DICTS_INDEX_FIELDS.get(dict_name)
        repo = repo_class(index_field=index_field) if index_field else repo_class()
        repo.load_from_file(file_path)

        schema = self.get_schema(repo_class.get_proto_class())
        ytlib.recreate_table(yt_table_path, self.yt_client, ytlib.schema_from_dict(schema))
        # Rows are produced lazily while the progress reporter iterates over the repository.
        rows = (self.convert_message_value(v)
                for v in auto_progress_reporter(repo.itervalues(), total=repo.size()))
        self.yt_client.write_table(yt_table_path, rows)

    def get_schema(self, pb_class):
        """Build a ``{field name: YT type}`` mapping from a protobuf message class.

        Repeated fields (maps included) become ``'any'``; every other field is
        mapped through ``_YT_TYPE_BY_PROTO_TYPE``.

        :param pb_class: class exposing ``DESCRIPTOR.fields``
        :raises Exception: if a non-repeated field has a type without a mapping (e.g. TYPE_GROUP)
        """
        schema = {}
        for f in pb_class.DESCRIPTOR.fields:
            if f.label == f.LABEL_REPEATED:
                schema[f.name] = 'any'
                continue
            # Resolve the type constants on the descriptor itself, exactly as the comparison chain did.
            yt_types = {getattr(f, type_attr): yt_type
                        for type_attr, yt_type in self._YT_TYPE_BY_PROTO_TYPE.items()}
            if f.type not in yt_types:
                raise Exception(f"Unknown type {f.type} of field {f.full_name}")
            schema[f.name] = yt_types[f.type]
        return schema

    def convert_single_value(self, fd, v):
        """Convert one field value into a plain Python value.

        Messages are converted recursively into dicts, enum numbers are
        replaced with the enum value name, everything else is returned as is.

        :param fd: field descriptor describing ``v``
        :param v: the value to convert
        :raises Exception: if the field is of TYPE_GROUP
        """
        if fd.type == fd.TYPE_GROUP:
            raise Exception(f"TYPE_GROUP is not supported in {fd.full_name}")
        if fd.type == fd.TYPE_MESSAGE:
            return self.convert_message_value(v)
        if fd.type == fd.TYPE_ENUM:
            return fd.enum_type.values_by_number[v].name
        return v

    def convert_message_value(self, pb):
        """Convert a protobuf message into a dict keyed by field name.

        Map fields become dicts with stringified keys, other repeated fields
        become lists, and every element goes through ``convert_single_value``.

        :param pb: message instance exposing ``DESCRIPTOR.fields``
        :return: dict with one entry per declared field
        """
        row = {}
        for fd in pb.DESCRIPTOR.fields:
            v = getattr(pb, fd.name)
            if fd.label != fd.LABEL_REPEATED:
                row[fd.name] = self.convert_single_value(fd, v)
            elif isinstance(v, (ScalarMapContainer, MessageMapContainer)):
                # The descriptor of a map field describes the synthetic entry message, so the
                # values must be converted with the entry's `value` field descriptor. This keeps
                # message values recursive and writes enum values by name, like everywhere else.
                value_fd = fd.message_type.fields_by_name['value']
                row[fd.name] = {str(k): self.convert_single_value(value_fd, el) for k, el in v.items()}
            else:
                row[fd.name] = [self.convert_single_value(fd, el) for el in v]
        return row


if __name__ == '__main__':
    # Script entry point: Runner() parses the command line, run() performs the dump.
    Runner().run()
