# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import tempfile

from os import path

from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.matching import PointMatching, PointType, parse_point_key
from travel.rasp.bus.db.models.supplier import Supplier
from travel.rasp.bus.db.models.carrier import Carrier
from travel.rasp.bus.db.models.carrier_matching import CarrierMatching
from travel.rasp.bus.db.models.register_type import RegisterType
from travel.proto.dicts.buses.supplier_pb2 import TSupplier
from travel.proto.dicts.buses.carrier_pb2 import TCarrier
from travel.proto.dicts.buses.register_type_pb2 import TRegisterType
from travel.proto.dicts.buses.carrier_matching_pb2 import TCarrierMatching
from travel.proto.dicts.buses.common_pb2 import POINT_KEY_TYPE_SETTLEMENT, POINT_KEY_TYPE_STATION, POINT_KEY_TYPE_UNKNOWN
from travel.proto.dicts.buses.point_matching_pb2 import TPointMatching
from travel.library.python.dicts.buses.supplier_repository import SupplierRepository
from travel.library.python.dicts.buses.carrier_repository import CarrierRepository
from travel.library.python.dicts.buses.carrier_matching_repository import CarrierMatchingRepository
from travel.library.python.dicts.buses.point_matching_repository import PointMatchingRepository
from travel.library.python.dicts.buses.register_type_repository import RegisterTypeRepository


class AdminDictsDumper(object):
    """Dump admin DB tables into binary protobuf dictionary files.

    Each ``dump_*_dict`` method reads every row of one model through
    ``session_scope()``. It converts each row into the matching protobuf
    message and adds it to a repository indexed by ``id``. It then writes the
    repository to ``<directory>/<name>.bin`` and returns that path.
    """

    # Model attributes copied verbatim (when not None) into the proto message
    # of the same field name.
    _SUPPLIER_FIELDS = (
        'id', 'code', 'name', 'register_type_id', 'register_number', 'legal_name',
        'actual_address', 'legal_address', 'timetable', 'taxation_number', 'first_name',
        'middle_name', 'last_name', 'hidden',
    )
    _CARRIER_FIELDS = (
        'id', 'name', 'register_type_id', 'register_number', 'inn', 'legal_name',
        'actual_address', 'legal_address', 'timetable', 'first_name', 'middle_name',
        'last_name', 'hidden',
    )
    _CARRIER_MATCHING_FIELDS = ('id', 'supplier_id', 'code', 'carrier_id')
    _POINT_MATCHING_FIELDS = (
        'id', 'supplier_id', 'supplier_point_id', 'parent_id', 'title', 'description',
        'country', 'city_id', 'country_code', 'city_title', 'region',
        'region_code', 'district', 'extra_info', 'timezone_info',
        'disabled', 'outdated', 'in_segments',
    )
    _REGISTER_TYPE_FIELDS = ('id', 'code', 'title', 'description')

    def __init__(self, directory=None):
        """Create a dumper writing into ``directory``.

        :param directory: output directory; when falsy, the system temporary
            directory (``tempfile.gettempdir()``) is used.
        """
        self.directory = directory or tempfile.gettempdir()

    def _get_bin_path(self, name):
        """Return the output path ``<directory>/<name>.bin``."""
        return path.join(self.directory, "{}.bin".format(name))

    @staticmethod
    def _copy_fields(source, target, field_names):
        """Copy each named attribute from ``source`` to ``target``.

        Attributes whose value is ``None`` are skipped, leaving the target
        field untouched (i.e. at its proto default).
        """
        for field_name in field_names:
            value = getattr(source, field_name)
            if value is not None:
                setattr(target, field_name, value)

    @staticmethod
    def _copy_coordinates(source, target):
        """Copy ``latitude``/``longitude`` into ``target.coordinates``.

        Coordinates are copied only when both values are present. The check is
        explicitly against ``None`` so that a legitimate ``0.0`` (equator /
        prime meridian) is not mistaken for a missing value.
        """
        if source.latitude is not None and source.longitude is not None:
            target.coordinates.Latitude = source.latitude
            target.coordinates.Longitude = source.longitude

    def _dump_dict(self, name, model, proto_cls, repository_cls, field_names, fill_extra=None):
        """Dump every row of ``model`` into ``<directory>/<name>.bin``.

        :param name: dictionary name, used in log messages and the file name.
        :param model: DB model class queried through ``session_scope()``.
        :param proto_cls: protobuf message class built for each row.
        :param repository_cls: repository class, constructed with
            ``index_field='id'``.
        :param field_names: attributes copied verbatim (when not None).
        :param fill_extra: optional callable ``(row, proto_obj)`` for fields
            that need conversion rather than a plain copy.
        :returns: path of the written file.
        """
        logging.info("Dumping %s dict", name)
        repository = repository_cls(index_field='id')
        with session_scope() as session:
            for obj in session.query(model):
                proto_obj = proto_cls()
                self._copy_fields(obj, proto_obj, field_names)
                if fill_extra is not None:
                    fill_extra(obj, proto_obj)
                repository.add_object(proto_obj)
        fn = self._get_bin_path(name)
        repository.dump_to_file(fn)
        logging.info("Done as %s", fn)
        return fn

    def dump_supplier_dict(self):
        """Dump ``Supplier`` rows to ``supplier.bin``; return the file path."""
        return self._dump_dict('supplier', Supplier, TSupplier, SupplierRepository,
                               self._SUPPLIER_FIELDS)

    def dump_carrier_dict(self):
        """Dump ``Carrier`` rows to ``carrier.bin``; return the file path."""
        return self._dump_dict('carrier', Carrier, TCarrier, CarrierRepository,
                               self._CARRIER_FIELDS)

    def dump_carrier_matching_dict(self):
        """Dump ``CarrierMatching`` rows to ``carrier_matching.bin``; return the file path."""
        return self._dump_dict('carrier_matching', CarrierMatching, TCarrierMatching,
                               CarrierMatchingRepository, self._CARRIER_MATCHING_FIELDS)

    @staticmethod
    def _point_type_to_proto(point_type):
        """Map a DB ``PointType`` to the proto point-key type enum.

        ``CITY`` maps to ``POINT_KEY_TYPE_SETTLEMENT`` and ``STATION`` maps to
        ``POINT_KEY_TYPE_STATION``. Anything else maps to
        ``POINT_KEY_TYPE_UNKNOWN``.
        """
        if point_type == PointType.CITY:
            return POINT_KEY_TYPE_SETTLEMENT
        elif point_type == PointType.STATION:
            return POINT_KEY_TYPE_STATION
        return POINT_KEY_TYPE_UNKNOWN

    def _fill_point_matching_extras(self, obj, proto_obj):
        """Fill the point-matching fields that need conversion, not plain copy.

        Sets the point type and the parsed ``point_key``. A key that parses as
        ``PointType.INVALID`` is skipped. Coordinates are copied when both are
        present.
        """
        proto_obj.type = self._point_type_to_proto(obj.type)
        if obj.point_key:
            point_type, point_id = parse_point_key(obj.point_key)
            if point_type != PointType.INVALID:
                proto_obj.point_key.type = self._point_type_to_proto(point_type)
                proto_obj.point_key.id = point_id
        self._copy_coordinates(obj, proto_obj)

    def dump_point_matching_dict(self):
        """Dump ``PointMatching`` rows to ``point_matching.bin``; return the file path."""
        return self._dump_dict('point_matching', PointMatching, TPointMatching,
                               PointMatchingRepository, self._POINT_MATCHING_FIELDS,
                               fill_extra=self._fill_point_matching_extras)

    def dump_register_type_dict(self):
        """Dump ``RegisterType`` rows to ``register_type.bin``; return the file path."""
        return self._dump_dict('register_type', RegisterType, TRegisterType,
                               RegisterTypeRepository, self._REGISTER_TYPE_FIELDS)

    def dump_all(self):
        """Dump every dictionary and return the written file paths.

        Paths are returned in the order: supplier, carrier, carrier_matching,
        point_matching, register_type.
        """
        return [
            self.dump_supplier_dict(),
            self.dump_carrier_dict(),
            self.dump_carrier_matching_dict(),
            self.dump_point_matching_dict(),
            self.dump_register_type_dict(),
        ]
