# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

"""
Запуск в проде/тестинге:
    Y_PYTHON_ENTRY_POINT="travel.avia.ticket_daemon_api.scripts.examine_cache:main" /app/app

Запуск на дев-машинке:
    ./tools/examine-cache.sh
"""

import argparse
import json
import logging
import sys
import zlib
from datetime import datetime

import django
django.setup()

import six
from django.conf import settings
import ydb

from travel.avia.library.python.avia_data.models import AmadeusMerchant
from travel.avia.library.python.common.models.geo import Settlement, Station
from travel.avia.library.python.common.models.partner import DohopVendor, Partner
from travel.avia.library.python.common.models_utils.geo import Point
from travel.avia.library.python.ticket_daemon.ydb.django import utils as django_ydb_utils
from travel.avia.ticket_daemon_api.jsonrpc.lib.ydb.cache import RESULTS_TABLE_NAME, REDIRECT_DATA_TABLE_NAME, ServiceCache
from travel.avia.ticket_daemon_api.jsonrpc.query import Query

logger = logging.getLogger(__name__)

# Columns read from the results table: listed in the SELECT built by
# select_cache_entries() and passed as ``columns=`` to ServiceCache.get() in default().
YDB_COLUMNS = (
    'created_at', 'date_backward', 'date_forward', 'expires_at',
    'klass', 'lang', 'meta', 'national_version',
    'partner_code', 'passengers', 'point_from', 'point_to',
    'variants',
)
# Columns holding zlib-compressed JSON; main() unpacks them before printing.
YDB_ZIPPED_COLUMNS = ('redirect_data', 'variants')


def select_cache_entries(where, limit):
    """Fetch the newest cache rows matching a raw YQL ``WHERE`` condition.

    Runs two statements in one transaction, both filtered by the same ``where``
    clause, ordered by ``created_at`` descending and capped at ``limit`` rows:
    one over the results table (selecting ``YDB_COLUMNS``) and one over the
    redirect-data table. Each results row is then given a ``redirect_data`` key
    holding the newest redirect data fetched for the same ``partner_code``, or
    ``None`` when none was fetched.

    :param where: YQL boolean expression inserted verbatim into both statements.
    :param limit: maximum number of rows fetched from each table.
    :returns: the rows of the results-table result set, with ``redirect_data`` added.
    """
    # SECURITY NOTE: ``where`` is formatted into the query text unescaped. That is the
    # point of this operator-only debugging CLI; never feed it untrusted input.
    query = """
    PRAGMA TablePathPrefix("{path}");

    SELECT {columns}
    FROM {results}
    WHERE {where}
    ORDER BY created_at DESC
    LIMIT {limit};

    SELECT redirect_data, partner_code
    FROM {redirect_data}
    WHERE {where}
    ORDER BY created_at DESC
    LIMIT {limit};
    """.format(
        path=ServiceCache.DATABASE,
        results=RESULTS_TABLE_NAME,
        redirect_data=REDIRECT_DATA_TABLE_NAME,
        columns=', '.join(YDB_COLUMNS),
        where=where,
        limit=limit,
    )

    def callee(session):
        result_sets = session.transaction(ydb.SerializableReadWrite()).execute(query, commit_tx=True)
        variants = result_sets[0].rows
        # Rows are paired by partner_code only. Redirect rows arrive newest-first, so
        # keep the first one seen per partner; overwriting (as a plain dict
        # comprehension would) leaves the oldest row instead.
        redirect_data_by_partner = {}
        for row in result_sets[1].rows:
            redirect_data_by_partner.setdefault(row['partner_code'], row['redirect_data'])
        for v in variants:
            v['redirect_data'] = redirect_data_by_partner.get(v['partner_code'])
        return variants

    with django_ydb_utils.get_session_pool() as session_pool:
        return session_pool.retry_operation_sync(callee)


def parse_point(value):
    """Resolve a CLI point argument into a geo object.

    The value is tried, in order, as: a point key (as given), then - lowercased -
    a settlement IATA code, a settlement Sirena id, a station IATA code and a
    station Sirena code. The first successful lookup wins; any exception from a
    lookup just moves on to the next one.

    :raises ValueError: when no lookup succeeds.
    """
    try:
        return Point.get_by_key(value)
    except Exception:
        pass

    code = value.lower()
    fallbacks = (
        lambda: Settlement.objects.get(iata__iexact=code),
        lambda: Settlement.objects.get(sirena_id__iexact=code),
        lambda: Station.objects.filter(code_set__system__code='iata').distinct().get(code_set__code__iexact=code),
        lambda: Station.objects.filter(code_set__system__code='sirena').distinct().get(code_set__code__iexact=code),
    )
    for lookup in fallbacks:
        try:
            return lookup()
        except Exception:
            continue

    raise ValueError('invalid point value: {}'.format(code))


class _PartnerCodeError(argparse.ArgumentTypeError, ValueError):
    """Raised by parse_partner() for an unknown partner code.

    argparse prints the message of an ``ArgumentTypeError`` raised from a ``type=``
    callable, but replaces a plain ``ValueError``'s message with a generic
    "invalid <func> value" text. Deriving from both keeps the list of known codes
    visible on the command line while remaining a ``ValueError`` for other callers.
    """


def parse_partner(code):
    """Resolve a partner code given on the command line into a partner object.

    Lookup order:
      * ``dohop_<id>``: ``DohopVendor`` with that integer ``dohop_id``
        (a non-integer suffix raises ``ValueError`` from ``int()``);
      * ``amadeus_*``: ``AmadeusMerchant`` with exactly that code;
      * otherwise, or when the lookup above found nothing: ``Partner`` with that code.

    :raises ValueError: when nothing matches; the message lists the known plane
        partner codes and is displayed by argparse.
    """
    if code.startswith('dohop_'):
        dohop_id = int(code[len('dohop_'):])
        try:
            return DohopVendor.objects.get(dohop_id=dohop_id)
        except DohopVendor.DoesNotExist:
            pass

    if code.startswith('amadeus_'):
        try:
            return AmadeusMerchant.objects.get(code=code)
        except AmadeusMerchant.DoesNotExist:
            pass

    try:
        return Partner.objects.get(code=code)
    except Partner.DoesNotExist:
        pass

    raise _PartnerCodeError(
        'unknown partner code "{}", possible codes {}'.format(code, sorted(get_partner_codes())))


def get_partner_codes():
    """Return a lazy ``values_list`` queryset of codes of partners whose transport type code is 'plane'."""
    plane_partners = Partner.objects.filter(t_type__code='plane')
    return plane_partners.values_list('code', flat=True)


def parse_date(dt):
    """Parse a ``YYYY-MM-DD`` string into a ``date``; a falsy value yields ``None``."""
    if not dt:
        return None
    parsed = datetime.strptime(dt, '%Y-%m-%d')
    return parsed.date()


def build_query(args):
    """Build a ``Query`` for the search described by parsed CLI arguments.

    :param args: namespace produced by the parser in ``default()`` (points, dates,
        service class, passenger counts, national version).
    :returns: a ``Query`` re-created via ``Query.from_key`` from the key of the
        freshly built one.
    """
    query_params = {
        'point_from': args.point_from,
        'point_to': args.point_to,
        'date_forward': args.date_forward,
        'date_backward': args.date_backward,
        'klass': args.service_class,
        'passengers': {
            'adults': args.adults,
            'children': args.children,
            # Was ``args.children``: the --infants option was silently ignored.
            'infants': args.infants,
        },
        'national_version': args.national_version,
        'service': 'ticket',
        'lang': 'ru',
        'base_qid': 'examine-cache-qid',  # prevent init_multiple_queries
    }

    q = Query(**query_params)
    # Round-trip through the key so the query has the shape of one rebuilt from a
    # key (presumably the form the cache lookups expect - TODO confirm).
    q = Query.from_key(q.key(), service=q.service, lang=q.lang)

    return q


def handle_verbose(verbose):
    """When ``verbose`` is truthy, send this module's log records to stdout at DEBUG level."""
    if not verbose:
        return
    stdout_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)
    logger.setLevel(logging.DEBUG)


def select():
    """Run the ``select`` mode: a raw ``WHERE`` query over the cache tables.

    Parses ``sys.argv`` for ``-v/--verbose``, a required ``--where`` YQL condition
    and ``--limit`` (1..10, default 10), then returns the rows produced by
    ``select_cache_entries``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--where', dest='where', type=six.text_type, required=True)
    parser.add_argument('--limit', dest='limit', type=int, choices=range(1, 11), default=10)

    options = parser.parse_args()
    handle_verbose(options.verbose)

    logger.info('Start examine cache')
    return select_cache_entries(options.where, options.limit)


def _make_default_parser():
    """Return the argument parser for the default (single search query) mode."""
    point_help = 'point_key or IATA/Sirena code'
    date_help = 'date in format 2021-05-28'

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')

    parser.add_argument('-f', '--point_from', dest='point_from', type=parse_point, required=True, help=point_help)
    parser.add_argument('-t', '--point_to', dest='point_to', type=parse_point, required=True, help=point_help)
    parser.add_argument('-d', '--dateForward', dest='date_forward', type=parse_date, required=True, help=date_help)
    parser.add_argument('-b', '--dateBackward', dest='date_backward', type=parse_date, default=None, help=date_help)
    parser.add_argument('-q', '--class', dest='service_class', type=str, default='economy',
                        choices=('economy', 'business'))
    parser.add_argument('-a', '--adults', dest='adults', type=int, default=1)
    parser.add_argument('-c', '--children', dest='children', type=int, default=0)
    parser.add_argument('-i', '--infants', dest='infants', type=int, default=0)
    parser.add_argument('-p', '--partner', dest='partner', type=parse_partner, required=True)

    parser.add_argument('-n', '--national-version', dest='national_version', type=str, default='ru',
                        choices=settings.AVIA_NATIONAL_VERSIONS)
    return parser


def default():
    """Run the default mode: look up the cached result of one search for one partner.

    Parses ``sys.argv`` (points, dates, class, passengers, partner, national
    version), builds the corresponding ``Query`` and returns what
    ``ServiceCache.get`` yields for it and the partner's code, restricted to
    ``YDB_COLUMNS``.
    """
    args = _make_default_parser().parse_args()
    handle_verbose(args.verbose)

    logger.info('Start examine cache')
    logger.info('%s -> %s %s - %s %s',
                args.point_from, args.point_to, args.date_forward, args.date_backward, args.partner)

    logger.info('build query')
    search_query = build_query(args)
    logger.info('query is %s', search_query)

    logger.info('get result from cache')
    return ServiceCache.get(search_query, args.partner.code, columns=YDB_COLUMNS)


def main():
    """Script entry point.

    ``... select --where <YQL> [--limit N]`` runs a raw filter over the cache tables
    (see ``select``); any other invocation looks up a single search for one
    partner (see ``default``). The rows found have their zlib-compressed JSON
    columns (``YDB_ZIPPED_COLUMNS``) unpacked and are printed to stdout as JSON.
    """
    if len(sys.argv) > 1 and sys.argv[1].lower() == 'select':
        # Drop the sub-command so select()'s own parser sees only its options.
        del sys.argv[1]
        result = select()
    else:
        result = default()

    logger.info('raw result: %s', result)
    # NOTE(review): tolerate a falsy result in case ServiceCache.get returns None on
    # a cache miss - its return contract is not visible from this script.
    for row in result or ():
        for field in YDB_ZIPPED_COLUMNS:
            # select_cache_entries() stores None in redirect_data when no redirect
            # row matched the partner; zlib.decompress(None) would raise TypeError.
            if field in row and row[field] is not None:
                row[field] = json.loads(zlib.decompress(row[field]))
    logger.info('unzipped result: %s', result)
    print(json.dumps(result, indent=2))

    logger.info('Done examine cache')
