from __future__ import print_function

import os
import logging
import argparse

import yt.wrapper as yt

from crypta.graph.mrcc import lib
from crypta.lib.python.yql.client import create_yql_client, get_yt_client

# Attach a console (stderr) handler to the mrcc library logger and open both
# the handler and the logger up to DEBUG, so every library message is shown.
console = logging.StreamHandler()
console.setFormatter(logging.Formatter('[%(levelname)s] %(asctime)s :: %(message)s'))
console.setLevel(logging.DEBUG)
lib.logger.setLevel(logging.DEBUG)
lib.logger.addHandler(console)


def make_arg_parser():
    """Build the command-line parser for the connected-components tool.

    Returns:
        argparse.ArgumentParser exposing: the source/destination table paths,
        the names of the two edge columns (``-u``/``-v``), the output
        component-id column (``-cid``), the iteration limit (``--max_iter``)
        and two boolean switches selecting the YQL implementations
        (``--yql``, ``--vdata``).
    """
    arg_parser = argparse.ArgumentParser(
        description='Find connected components on MR',
    )

    # Table locations.
    arg_parser.add_argument(
        '-s', '--source',
        required=True,
        help='YT path to source table',
    )
    arg_parser.add_argument(
        '-d', '--destination',
        help='YT path to destination table (if not set, equal to source path)',
    )

    # Column names.
    arg_parser.add_argument(
        '-u',
        default='u',
        help='Name of u-column in source table (default: %(default)s)',
    )
    arg_parser.add_argument(
        '-v',
        default='v',
        help='Name of v-column in source table (default: %(default)s)',
    )
    arg_parser.add_argument(
        '-cid',
        default='component_id',
        help=('Name of connected component id column in destination table'
              ' (default: %(default)s)'),
    )

    # Algorithm / execution settings.
    arg_parser.add_argument(
        '--max_iter',
        type=int,
        default=15,
        help=('Maximum number of iterations to get converged graph'
              ' (default: %(default)s)'),
    )
    arg_parser.add_argument(
        '--yql',
        dest='yql',
        action='store_true',
        default=False,
        help='Find connected component via YQL (default via YT)',
    )
    arg_parser.add_argument(
        '--vdata',
        dest='vdata',
        action='store_true',
        default=False,
        help='Keep v_data column in destination table (run on yql)',
    )
    return arg_parser


def find_connected_components(MrccClass, client, args):
    """Run one connected-components implementation with CLI-supplied settings.

    Args:
        MrccClass: implementation class; it is constructed with ``client`` as
            its only argument and must expose ``find_connected_components``.
        client: YT or YQL client handed to ``MrccClass``.
        args: parsed CLI namespace from ``make_arg_parser``.

    Returns:
        Whatever ``MrccClass(...).find_connected_components`` returns; ``main``
        treats a falsy value as "did not converge".
    """
    finder = MrccClass(client)
    call_kwargs = dict(
        source=args.source,
        # An empty/unset destination is normalised to None.
        destination=args.destination or None,
        u_name=args.u,
        v_name=args.v,
        component_id=args.cid,
        max_iter=args.max_iter,
    )
    return finder.find_connected_components(**call_kwargs)


def _require_env(name):
    """Return the value of the environment variable ``name``.

    An explicit check is used instead of ``assert`` so the validation still
    happens when Python runs with ``-O`` (which strips assert statements).

    Raises:
        SystemExit: if the variable is unset or empty; the message names the
            variable the user has to configure.
    """
    value = os.getenv(name)
    if not value:
        raise SystemExit('Please, configure {} environment variable'.format(name))
    return value


def main():
    """Command-line entry point: find connected components of an edge table.

    Reads YT_PROXY (required) and YT_POOL (optional) from the environment.
    With ``--yql`` or ``--vdata`` the search runs through a YQL client and
    requires YQL_TOKEN; otherwise it runs through a YT client inside a YT
    transaction and requires YT_TOKEN.

    Raises:
        SystemExit: if a required environment variable is missing, or if the
            chosen implementation reports that it did not converge.
    """
    args = make_arg_parser().parse_args()

    yt_proxy = _require_env('YT_PROXY')
    yt_pool = os.getenv('YT_POOL')  # optional; may be None

    if args.yql or args.vdata:
        yql_token = _require_env('YQL_TOKEN')
        # --vdata takes precedence over --yql: it selects the implementation
        # that keeps the v_data column in the destination table.
        if args.vdata:
            mrcc_class = lib.VDataMRConnectedComponentsYQL
        else:
            mrcc_class = lib.ExtendedMRConnectedComponentsYQL
        yql_client = create_yql_client(yt_proxy=yt_proxy, token=yql_token, pool=yt_pool)
        converged = find_connected_components(mrcc_class, yql_client, args)
    else:
        yt_token = _require_env('YT_TOKEN')
        # NOTE(review): yt.Transaction() is opened on the global yt.wrapper
        # client, not on the client built below; presumably both target the
        # same cluster via the YT_PROXY/YT_TOKEN environment — confirm.
        with yt.Transaction() as transaction:
            # Connection settings attached as plain attributes on the client.
            # NOTE(review): these are not YtClient config entries; presumably
            # the mrcc library reads them — verify before removing.
            client_attrs = {
                'proxy': yt_proxy,
                'transaction_id': str(transaction.transaction_id),
                'token': yt_token,
                'pool': yt_pool,
            }
            yt_client = get_yt_client(yt_proxy, yt_token, transaction.transaction_id)
            for attr_name, attr_value in client_attrs.items():
                setattr(yt_client, attr_name, attr_value)
            converged = find_connected_components(
                lib.ExtendedMRConnectedComponentsYT, yt_client, args)

    # Explicit check (not assert) so a failed run is reported even under -O.
    if not converged:
        raise SystemExit('Not converged')


if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not on import.
    main()
