# -*- coding: utf-8 -*-
import os
import sys
import logging
from argparse import ArgumentParser, Namespace
from logging.config import dictConfig
from multiprocessing.dummy import Pool
from multiprocessing import TimeoutError
from contextlib import closing

from sqlalchemy.orm import Session

from cloud_logger import default_config
from travel.avia.price_index.lib.settings import settings

# It is crucial to set up logging before any other local modules are imported,
# because some of their loggers are initialized at import time.
if not os.getenv('DEPLOY_STAGE_ID'):
    # Not a deployed run: plain logging to stdout, at DEBUG level when the
    # DEBUG environment variable is set to a non-empty value, INFO otherwise.
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG if os.getenv('DEBUG') else logging.INFO,
    )
else:
    # Deployed run (DEPLOY_STAGE_ID is set): use the cloud_logger dict config
    # rooted at settings.LOG_ROOT.
    dictConfig(default_config(settings.LOG_ROOT))
    # Additionally route ERROR-and-above records to Sentry, but only when a
    # DSN is provided; raven is imported lazily so it is needed only then.
    sentry_dsn = os.getenv('SENTRY_DSN')
    if sentry_dsn:
        from raven.conf import setup_logging as raven_setup_logging
        from raven.handlers.logging import SentryHandler

        sentry_handler = SentryHandler(sentry_dsn)
        sentry_handler.setLevel(logging.ERROR)
        raven_setup_logging(sentry_handler)

from travel.avia.price_index.lib.national_version_provider import national_version_provider
from travel.avia.price_index.lib.rates_provider import rates_provider
from travel.avia.price_index.lib.currency_provider import currency_provider
from travel.avia.price_index.lib import result_operation
from travel.avia.price_index.lib.search_results_queue_consumer.consumer import (
    SearchResultsConsumer,
    DEFAULT_DC_ENDPOINTS,
)
from travel.avia.price_index.models.result import Result
from travel.avia.price_index.views.helpers import master_wrapper

# Logger for this module, named after the module itself via ``__name__``.
logger = logging.getLogger(__name__)


def _parse_arguments():
    """Parse the consumer's command-line options from ``sys.argv``.

    Options, in registration order (which is also the ``--help`` order):

    * ``--consumer-id`` / ``--topic-id`` -- Logbroker consumer and topic paths.
    * ``--read-timeout`` -- handed to ``SearchResultsConsumer`` as ``timeout``.
    * ``--indexing-wait-timeout`` -- how long to wait for each indexing batch.
    * ``--indexing-pool-size`` -- number of indexing pool workers.
    * ``--proto-parsing-pool-size`` -- forwarded to ``consumer.read_batch``.
    * ``--enable-indexing`` -- boolean; only the string ``true`` (any case)
      turns it on, every other value turns it off.
    * ``-e`` / ``--endpoints`` -- comma-separated Logbroker endpoints. The
      default is a string, so argparse also runs it through ``type`` and
      ``args.endpoints`` is always a list.

    :return: ``argparse.Namespace`` with the parsed values.
    """
    parser = ArgumentParser()

    # String options: Logbroker consumer and topic paths.
    string_options = (
        ('--consumer-id', '/avia/development/default/search-results-consumer'),
        ('--topic-id', '/avia/testing/search-results-queue'),
    )
    for flag, default in string_options:
        parser.add_argument(flag, default=default, type=str)

    # Integer options: timeouts and pool sizes.
    int_options = (
        ('--read-timeout', 10),
        ('--indexing-wait-timeout', 60),
        ('--indexing-pool-size', 16),
        ('--proto-parsing-pool-size', 8),
    )
    for flag, default in int_options:
        parser.add_argument(flag, default=default, type=int)

    parser.add_argument(
        '--enable-indexing',
        default=False,
        type=lambda value: str(value).lower() == 'true',
    )

    parser.add_argument(
        '-e', '--endpoints',
        default=','.join(DEFAULT_DC_ENDPOINTS),
        type=lambda raw: raw.split(','),
        help='logbroker endpoint to read from. Usage: -e sas.logbroker.yandex.net,vla.logbroker.yandex.net',
    )

    return parser.parse_args()


def main():
    """Script entry point: fetch provider data, then consume and index results.

    The Logbroker token is taken from the ``LOGBROKER_TOKEN`` environment
    variable (``None`` when unset) and passed to ``SearchResultsConsumer``
    together with the parsed CLI options.

    NOTE(review): ``args.enable_indexing`` is parsed but not read by any code
    visible in this module. Confirm whether it was meant to gate indexing.
    """
    args = _parse_arguments()
    token = os.getenv('LOGBROKER_TOKEN')

    # Fetch data from the three providers, in this order, before consuming.
    for provider in (national_version_provider, rates_provider, currency_provider):
        provider.fetch()

    consumer = SearchResultsConsumer(
        args.consumer_id,
        args.topic_id,
        token=token,
        timeout=args.read_timeout,
        endpoints=args.endpoints,
    )
    with consumer as opened_consumer:
        run_indexing(opened_consumer, args)


def run_indexing(consumer, args):
    # type: (SearchResultsConsumer, Namespace) -> None
    """Read batches from ``consumer`` and index each one on a thread pool.

    Every batch yielded by ``consumer.read_batch`` is fanned out to
    ``_index_single`` on a ``multiprocessing.dummy`` (thread-backed) pool of
    ``args.indexing_pool_size`` workers. The loop waits at most
    ``args.indexing_wait_timeout`` seconds for the batch. A batch that times
    out or raises is logged and the loop moves on to the next one.

    :param consumer: an entered ``SearchResultsConsumer`` to read batches from.
    :param args: parsed command-line options (see ``_parse_arguments``).
    """
    with closing(Pool(processes=args.indexing_pool_size)) as pool:
        for batch in consumer.read_batch(proto_parsing_pool_size=args.proto_parsing_pool_size):
            try:
                # ``get`` (unlike ``wait``) raises ``multiprocessing.TimeoutError``
                # when the deadline passes and re-raises the first worker
                # exception. Both cases then reach the handlers below instead of
                # being silently ignored.
                pool.map_async(_index_single, batch).get(timeout=args.indexing_wait_timeout)
            except TimeoutError:
                # Only the waiting stops; the batch's tasks stay queued/running in the pool.
                logger.error(
                    'Indexing batch did not finish within %s seconds',
                    args.indexing_wait_timeout,
                )
            except Exception:
                logger.exception('Unexpected error during indexing batch')


@master_wrapper
def _index_single(session, search_result):
    # type: (Session, Result) -> None
    """Save or delete the index entry for a single search result.

    ``run_indexing`` maps this over a batch through the thread pool and passes
    only ``search_result``. ``session`` therefore appears to be supplied by
    ``master_wrapper`` (TODO confirm against its definition).

    * falsy ``search_result`` -> nothing is done;
    * non-empty result -> ``result_operation.on_conflict.save``;
    * empty result (``is_empty()`` is true) -> ``result_operation.delete.save``.

    Errors are logged and swallowed, so one bad result does not fail the rest
    of the batch.

    NOTE(review): because errors never leave this function, ``master_wrapper``
    sees a normal return even after a failed ``save``. Confirm it does not try
    to commit a session left in a failed state.
    """
    try:
        if not search_result:
            return
        if not search_result.is_empty():
            logger.info('Indexing for qid: %s', search_result.qid)
            result_operation.on_conflict.save(session, search_result)
            logger.info('Indexed for qid: %s', search_result.qid)
        else:
            logger.info('Removing for qid: %s', search_result.qid)
            result_operation.delete.save(session, search_result)
            logger.info('Removed for qid: %s', search_result.qid)
    except TimeoutError:
        # NOTE(review): ``multiprocessing.TimeoutError`` is normally raised by
        # ``AsyncResult.get``, not by database code, so this branch is likely
        # unreachable. Keep it only if ``result_operation`` can raise it.
        logger.exception('Failed to index search result because of timeout')
    except Exception:
        # ``Exception`` instead of a bare ``except:``, so that SystemExit,
        # KeyboardInterrupt and GeneratorExit are not swallowed.
        logger.exception('Unexpected error during search result indexing: %s', search_result.qid)


# Start the consumer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
