import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple

from sqlalchemy.dialects.postgresql import insert
from starlette.applications import Starlette

from src.budget_position.models import budget_position_table, budget_position_assignment_table
from src.common import FieldMapper
from src.common.logbroker.consumer import LogbrokerConsumer
from src.config import settings
from src.departments.models import department_table


logger = logging.getLogger(__name__)


def _get_model_data(data: str) -> Iterable[Tuple[str, Dict]]:
    log_broker_object = json.loads(data)

    for received_model in json.loads(log_broker_object['data']):
        received_fields = received_model['fields']
        received_fields['id'] = log_broker_object['id']
        received_fields['source_action'] = log_broker_object['action']
        received_fields['source_datetime'] = log_broker_object['creation_time']
        received_fields['pk'] = received_model['pk']
        yield received_model['model'], received_fields


class PullLogBrokerData:
    """Pull one batch of model changes from Logbroker and store it in Postgres.

    Only models listed in ``model_mapping`` are stored; any other model name
    is logged and skipped. Rows are written with
    ``INSERT ... ON CONFLICT DO NOTHING``, so rows that already exist are left
    untouched.
    """

    # Source model name (the ``model`` key of each received entry) -> settings:
    #   'table'         -- SQLAlchemy table the mapped rows are inserted into;
    #   'custom_fields' -- optional, passed to ``FieldMapper.map(custom=...)``;
    #   'aliases'       -- optional, passed to ``FieldMapper.map(aliases=...)``;
    #                      apparently target column -> source field name.
    # NOTE(review): declaration order is also the table insertion order used by
    # ``run``; keep tables that others may reference (FKs — confirm in schema)
    # above the tables that reference them.
    model_mapping = {
        'django_intranet_stuff.department': {'table': department_table, 'custom_fields': {'department_id': 'pk'}},
        'budget_position.budgetposition': {'table': budget_position_table},
        'budget_position.budgetpositionassignment': {
            'table': budget_position_assignment_table,
            'aliases': {
                'budget_position_id': 'budget_position',
                'change_registry_id': 'change_registry',
                'previous_assignment_id': 'previous_assignment',
                'person_id': 'person',
                'vacancy_id': 'vacancy',
                'department_id': 'department',
                'value_stream_id': 'value_stream',
                'geography_id': 'geography',
                'bonus_id': 'bonus',
                'reward_id': 'reward',
                'review_id': 'review',
            },
        },
    }

    def __init__(self, mapper: Optional[FieldMapper] = None) -> None:
        """
        :param mapper: converts received model fields into table rows; when
            omitted (or falsy) a ``FieldMapper`` with system fields
            ``{'id', 'source', 'source_id'}`` is created.
        """
        super().__init__()
        self._mapper = mapper or FieldMapper(system_fields={'id', 'source', 'source_id'})

    async def run(self, app: Starlette) -> bool:
        """Read one batch from Logbroker, insert it, then commit the read.

        :param app: Starlette application; ``app.state.engine`` supplies the
            database connection.
        :return: ``False`` if ``consumer.read()`` returned nothing (falsy),
            ``True`` once the batch has been inserted and committed.

        The consumer is always stopped, even when reading, mapping or
        inserting raises.
        """
        consumer = await LogbrokerConsumer.create(settings.LOGBROKER_TOPIC_NAME, settings.LOGBROKER_CONSUMER_NAME)
        try:
            messages_with_cookie = await consumer.read()
            if not messages_with_cookie:
                return False

            model_values = self._collect_values(messages_with_cookie.messages_data)
            await self._insert_values(app.state.engine, model_values)

            # Commit only after every insert succeeded. An uncommitted batch is
            # presumably delivered again (confirm with Logbroker semantics);
            # ON CONFLICT DO NOTHING skips rows that were already written.
            await consumer.commit(messages_with_cookie.cookie)
            return True
        finally:
            await consumer.stop()

    def _collect_values(self, messages: Iterable[str]) -> Dict[Any, List[Dict]]:
        """Map every supported model in ``messages`` to rows grouped by table.

        The returned dict is ordered by ``model_mapping`` declaration order,
        not by the order in which models first appear in the batch, so inserts
        always run in the same order.
        """
        rows_by_table: Dict[Any, List[Dict]] = {}
        for message in messages:
            for model_name, model_data in _get_model_data(message):
                mapping = self.model_mapping.get(model_name)
                if mapping is None:
                    logger.info('%s model skipped', model_name)
                    continue

                table = mapping['table']
                mapped = self._mapper.map(
                    source_name='log_broker',
                    target_type=table,
                    source=model_data,
                    custom=mapping.get('custom_fields'),
                    aliases=mapping.get('aliases'),
                )
                rows_by_table.setdefault(table, []).append(mapped)

        ordered: Dict[Any, List[Dict]] = {}
        for mapping in self.model_mapping.values():
            table = mapping['table']
            if table in rows_by_table and table not in ordered:
                ordered[table] = rows_by_table[table]
        return ordered

    async def _insert_values(self, engine, model_values: Dict[Any, List[Dict]]) -> None:
        """Insert each table's rows with one multi-row INSERT ... ON CONFLICT DO NOTHING.

        :param engine: ``app.state.engine``; must provide an async ``acquire()``
            context manager yielding a connection with an async ``execute()``.
        :param model_values: rows per table, inserted in the dict's order.
        """
        if not model_values:
            # Every model in the batch was skipped: no need for a connection.
            return

        async with engine.acquire() as conn:
            for table, values in model_values.items():
                logger.info('Inserting %s', table)
                query = insert(table).values(values).on_conflict_do_nothing()
                await conn.execute(query)
                logger.info('%s %s items inserted', len(values), table)
