# -*- encoding: utf-8 -*-
import six
import typing  # noqa

from travel.avia.dump_data.lib import model_classes as mc
from travel.avia.dump_data.lib.mysql_connector import MysqlConnector  # noqa


class Collector(object):
    """Build and hold model objects described by a declaration mapping.

    Every entry of the declaration maps a model name to a dict whose
    ``'model'`` key selects which class from ``model_classes`` to
    instantiate.  The models are created once, in the constructor, and are
    exposed through :meth:`iter_by_models`.
    """

    def __init__(self, model_declaration, mysql_connector, yt_client):
        # type: (dict, MysqlConnector, typing.Any) -> None
        """Store the connectors and instantiate every declared model.

        :param model_declaration: mapping of model name to its declaration
            dict (see :meth:`_new_model` for the keys that are read).
        :param mysql_connector: connector handed to every created model.
        :param yt_client: client handed to ``MinPriceModel`` instances only.
        :raises ValueError: if a declaration names an unsupported model.
        """
        # The connectors must be assigned first: the factory methods called
        # from _init_models read them from ``self``.
        self.mysql_connector = mysql_connector
        self.yt_client = yt_client
        self._models = self._init_models(model_declaration)

    def _init_models(self, model_declaration):
        # type: (dict) -> typing.List[mc.BaseModel]
        """Return a list with one model per entry of ``model_declaration``.

        The list follows the iteration order of the mapping.
        """
        return [
            self._new_model(name, declaration)
            for name, declaration in six.iteritems(model_declaration)
        ]

    def _new_model(self, name, declaration):
        # type: (str, dict) -> mc.BaseModel
        """Create the model selected by ``declaration['model']``.

        :raises ValueError: if ``declaration['model']`` is not one of
            ``MysqlModel``, ``MinPriceModel`` or ``RouteCountModel``.
        """
        model_class = declaration['model']

        if model_class == mc.MysqlModel:
            return self._new_mysql_model(name, declaration)
        if model_class == mc.MinPriceModel:
            return self._new_min_price_model(name, declaration)
        if model_class == mc.RouteCountModel:
            return self._new_route_count_model(name, declaration)

        # Name the offending entry so a bad declaration is easy to locate.
        # ValueError is an Exception subclass, so existing handlers still work.
        raise ValueError(
            'Unknown importer model for {!r}: {!r}'.format(name, model_class)
        )

    def _new_mysql_model(self, name, declaration):
        # type: (str, dict) -> mc.MysqlModel
        """Create a ``MysqlModel`` and register its declared fields.

        Reads ``'db_table'``, ``'proto_model'`` and ``'fields'`` from the
        declaration; each item of ``'fields'`` is passed to
        ``model.add_fields`` as keyword arguments.
        """
        model = mc.MysqlModel(
            connector=self.mysql_connector,
            name=name,
            db_table=declaration['db_table'],
            proto_model=declaration['proto_model']
        )

        for field in declaration['fields']:
            model.add_fields(**field)

        return model

    def _new_min_price_model(self, name, declaration):
        # type: (str, dict) -> mc.MinPriceModel
        """Create a ``MinPriceModel``; it receives both connector and YT client."""
        return mc.MinPriceModel(
            name=name,
            connector=self.mysql_connector,
            yt_client=self.yt_client,
            proto_model=declaration['proto_model']
        )

    def _new_route_count_model(self, name, declaration):
        # type: (str, dict) -> mc.RouteCountModel
        """Create a ``RouteCountModel`` bound to the MySQL connector."""
        return mc.RouteCountModel(
            name=name,
            connector=self.mysql_connector,
            proto_model=declaration['proto_model']
        )

    def iter_by_models(self):
        # type: () -> typing.Iterable[mc.BaseModel]
        """Return the models built in the constructor.

        The internal list itself is returned (not a copy).
        """
        return self._models
