import argparse
import logging
import os

import yt.wrapper as yt


from travel.avia.analytics.price_changes.python.lib.answer_calculator import AnswerCalculator


_KEY_COLUMNS = [
    # query
    "from_settlement_id",
    "from_airport_id",
    "to_settlement_id",
    "to_airport_id",
    "forward_date",
    "backward_date",
    "adults",
    "children",
    "infants",
    "class_id",

    # partner
    "partner_id",
    "vendor_id",

    "forward_route",
    "backward_rote",
]


def create_parser():
    """Build the command-line parser for this script.

    The parsed namespace exposes ``src_table``, ``dst_table``, ``threshold``,
    ``yt_proxy`` and ``yt_token``.

    :return: configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    # The tables used to be accepted only as ``-src-table`` / ``-dst-table``.
    # Those spellings are kept so existing invocations keep working, and the
    # conventional double-dash forms are added to match the other options.
    # Both tables are mandatory: without them the job has nothing to read or
    # write, so fail with a usage error instead of passing None further on.
    parser.add_argument(
        '--src-table', '-src-table',
        dest='src_table',
        required=True,
        help='path of the source table',
    )
    parser.add_argument(
        '--dst-table', '-dst-table',
        dest='dst_table',
        required=True,
        help='path of the destination table (recreated if it exists)',
    )
    parser.add_argument(
        '--threshold',
        default=1,
        type=int,
        help='integer threshold handed to the reducer (default: 1)',
    )
    parser.add_argument(
        '--yt-proxy',
        default='hahn',
        help='YT proxy to connect to (default: hahn)',
    )
    parser.add_argument(
        '--yt-token',
        help='YT token; falls back to $YT_TOKEN, then to ~/.yt/token',
    )

    return parser


def parse_args():
    """Parse the process command line with the parser from ``create_parser``.

    :return: the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    return create_parser().parse_args()


def create_yt_client(args):
    """Create a YT client from the parsed command-line arguments.

    The token is looked up in this order: ``args.yt_token``, the ``YT_TOKEN``
    environment variable, the ``~/.yt/token`` file.

    :param args: namespace providing ``yt_proxy`` and ``yt_token``.
    :return: ``yt.YtClient`` configured with the proxy and the token.
    :raises ValueError: if no non-empty token can be found.
    """
    token = args.yt_token or os.getenv('YT_TOKEN')
    if not token:
        token_file = os.path.join(os.path.expanduser('~'), '.yt', 'token')
        # isfile() is already False for a missing path, so no separate
        # exists() check is needed.
        if not os.path.isfile(token_file):
            raise ValueError('Can not find YT token')

        # Read in text mode: binary mode yields ``bytes`` on Python 3, while
        # the token is meant to be passed on as a plain string.
        with open(token_file, 'r') as inp:
            token = inp.read().strip()

    # Covers a token file that is empty or contains only whitespace.
    if not token:
        raise ValueError('Can not find YT token')

    return yt.YtClient(proxy=args.yt_proxy, token=token)


def create_dst_table(ytc, dst_table, src_table):
    """Create ``dst_table`` with the source columns plus an ``answer`` column.

    Only the ``type`` and ``name`` of every source column are carried over;
    any other per-column attributes of the source schema are dropped.

    :param ytc: YT client used to read the schema and create the table.
    :param dst_table: path of the table to create.
    :param src_table: path of the table whose ``schema`` attribute is copied.
    """
    columns = []
    for column in ytc.get_attribute(src_table, 'schema'):
        columns.append({'type': column['type'], 'name': column['name']})

    # Extra int64 column appended after all the source columns.
    columns.append({'type': 'int64', 'name': 'answer'})

    logging.info('Creating table: %s', dst_table)
    logging.info('Schema: %s', columns)

    table_attributes = {
        'schema': columns,
        'optimize_for': 'scan',
    }
    ytc.create('table', dst_table, attributes=table_attributes)


def process(ytc, src_table, dst_table, threhold):
    """Rebuild ``dst_table`` from ``src_table`` with a map-reduce operation.

    Everything runs inside one ``ytc.Transaction()``: an existing destination
    table is removed, a fresh one is created via ``create_dst_table``, and a
    map-reduce without a mapper is started. Rows are reduced by
    ``_KEY_COLUMNS`` and sorted by those columns followed by ``unixtime``.

    :param ytc: YT client.
    :param src_table: path of the input table.
    :param dst_table: path of the output table (recreated).
    :param threhold: value forwarded to ``AnswerCalculator``; the spelling is
        part of the existing interface and is kept as is.
    """
    with ytc.Transaction():
        # Start from a clean destination table.
        destination_present = ytc.exists(dst_table)
        if destination_present:
            ytc.remove(dst_table)
        create_dst_table(ytc, dst_table, src_table)

        reducer = AnswerCalculator(threhold=threhold)
        sort_columns = _KEY_COLUMNS + ['unixtime']
        ytc.run_map_reduce(
            mapper=None,
            reducer=reducer,
            reduce_by=_KEY_COLUMNS,
            sort_by=sort_columns,
            source_table=src_table,
            destination_table=dst_table,
        )


def main():
    """Script entry point: parse arguments, build the client, run the job.

    Logs ``Start`` before argument parsing and ``End`` after ``process``
    returns.
    """
    logging.info('Start')

    cli_args = parse_args()
    client = create_yt_client(cli_args)
    process(
        client,
        cli_args.src_table,
        cli_args.dst_table,
        cli_args.threshold,
    )

    logging.info('End')


if __name__ == '__main__':
    # No explicit handler is passed: basicConfig() installs a StreamHandler
    # writing to stderr by default. The former ``hander=`` keyword was a
    # misspelling that Python 3 rejects with ValueError.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    main()
