import asyncio
import json
import logging
import traceback
import typing

from travel.rasp.pathfinder_proxy.const import CacheType, Status
from travel.rasp.pathfinder_proxy.logs import log_proxy_result
from travel.rasp.pathfinder_proxy.services.interline_service import InterlineService
from travel.rasp.pathfinder_proxy.services.ticket_daemon_service import TicketDaemonService
from travel.rasp.pathfinder_proxy.services.train_api_service import TrainApiService

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


async def merge_async_iterators(iterators: typing.Iterable[typing.AsyncIterator]) -> typing.AsyncIterator:
    """Yield items from several async iterators as soon as any of them produces one.

    Each iterator is drained by its own task. Items are yielded in arrival
    order, so there is no ordering guarantee between iterators.

    If some iterators raise, the remaining ones still run to completion (their
    items are still yielded). The first exception is then re-raised to the
    consumer after every iterator has finished.

    If the consumer stops early (``break``, an exception, or the generator being
    closed), any still-running producer tasks are cancelled.

    :param iterators: async iterators to merge; consumed once.
    :raises Exception: the first exception raised by any of the iterators.
    """
    # original idea from https://stackoverflow.com/a/49377261/3485638
    queue = asyncio.Queue()
    # Private end-of-stream marker; cannot collide with a real item.
    finished = object()

    async def put_to_queue(iterator: typing.AsyncIterator) -> None:
        error = None
        try:
            async for value in iterator:
                await queue.put((value, None))
        except Exception as exc:  # handed to the consumer instead of dying silently in the task
            error = exc
        finally:
            # Always announce completion so the consumer never waits on a dead producer.
            # The queue is unbounded, so put_nowait cannot block or fail.
            queue.put_nowait((finished, error))

    tasks = [asyncio.create_task(put_to_queue(iterator)) for iterator in iterators]
    pending = len(tasks)
    first_error = None

    try:
        # Count end-of-stream markers rather than polling task.done(): polling
        # races with producers that finish without emitting anything.
        while pending:
            value, error = await queue.get()
            if value is not finished:
                yield value
                continue
            pending -= 1
            if error is not None and first_error is None:
                first_error = error
    finally:
        # No-op for finished tasks; stops producers if the consumer bailed out early.
        for task in tasks:
            task.cancel()

    if first_error is not None:
        raise first_error


class PriceCollector:
    """Request prices for transfer variants from several services and cache the progress.

    Prices are requested concurrently from the train API, the ticket daemon
    and the interline service. Every time any of them yields, the whole
    ``transfer_variants`` structure is written to the cache with
    ``Status.QUERYING``. After all services are exhausted, a final snapshot is
    written with ``Status.DONE`` or ``Status.ERROR``.

    NOTE(review): the yielded values are discarded and the same
    ``transfer_variants`` object is re-serialized each time. Presumably the
    services fill in prices in place; confirm against the service
    implementations.
    """

    def __init__(self, cache, train_api_client, ticket_daemon_client, settings):
        self._cache = cache
        self._train_api_service = TrainApiService(train_api_client, cache, settings)
        self._ticket_daemon_service = TicketDaemonService(ticket_daemon_client, cache)
        self._interline_service = InterlineService(ticket_daemon_client, cache)
        # Strong references to in-flight collection tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced background task
        # may be garbage-collected before it completes.
        self._background_tasks = set()

    async def _request_prices_for_transfer_variants(self, transfer_variants, query):
        """Collect prices for ``transfer_variants`` and publish every update to the cache.

        :param transfer_variants: JSON-serializable variants; dumped to the cache
            after each update and once more with the final status.
        :param query: mapping with the keys ``point_from``, ``point_to``, ``when``,
            ``tld``, ``language`` and ``transport_types`` (used as the cache key).
        """
        async def save_to_cache(status: Status) -> None:
            logger.debug('Saving with %s', status)
            await self._cache.set_cache(
                CacheType.TRANSFERS_WITH_PRICES,
                query['point_from'],
                query['point_to'],
                query['when'],
                query['tld'],
                query['language'],
                json.dumps({
                    'transfer_variants': transfer_variants,
                    'status': status.value
                }, ensure_ascii=False),
                query['transport_types']
            )

        logger.debug('Preparing: %s', query)
        services = [self._train_api_service, self._ticket_daemon_service, self._interline_service]
        iterators = [
            service.iter_variants_with_tariffs(transfer_variants, query['tld'], query['language'])
            for service in services
        ]

        try:
            async for _ in merge_async_iterators(iterators):
                await save_to_cache(Status.QUERYING)
        except Exception:
            # Route the traceback through logging instead of printing to stderr.
            logger.exception('Failed to collect prices for %s', query)
            final_status = Status.ERROR
        else:
            final_status = Status.DONE

        log_proxy_result(
            cache_type=CacheType.TRANSFERS_WITH_PRICES,
            point_from=query['point_from'],
            point_to=query['point_to'],
            when=query['when'],
            tld=query['tld'],
            language=query['language'],
            transport_types=query['transport_types'],
            status=final_status.value,
            transfer_variants=transfer_variants
        )

        await save_to_cache(final_status)
        logger.debug('Prepared: %s with status: %s', query, final_status.value)

    def collect(self, transfer_variants, query):
        """Start price collection in the background and return the scheduled task.

        The task is kept referenced until it finishes. Any exception that
        escapes it is logged rather than silently lost.
        """
        task = asyncio.ensure_future(self._request_prices_for_transfer_variants(transfer_variants, query))
        self._background_tasks.add(task)
        task.add_done_callback(self._on_collect_done)
        return task

    def _on_collect_done(self, task):
        """Drop the finished task's reference and log any exception it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error('Price collection task failed', exc_info=error)
