import argparse
import logging

import yt.wrapper as yt

from crypta.dmp.adobe.bin.common.python import config_fields
from crypta.dmp.common.data.python import bindings
from crypta.lib.python.logging import logging_helpers
from crypta.lib.python import yaml_config
from crypta.lib.python.yql import executer


# Module-level logger; its output is set up at runtime by main() via
# logging_helpers.configure_stdout_logger on the root logger.
logger = logging.getLogger(__name__)


# YQL query template, filled in with str.format() by main().
# Placeholders:
#   filtered_table             - destination table, overwritten (INSERT ... WITH TRUNCATE)
#   src_table                  - yandexuid bindings table; all of its columns are copied (src.*)
#   geo_table                  - yandexuid geo table
#   src_table_yandexuid_field  - yandexuid column name in src_table
#   geo_table_yandexuid_field  - yandexuid column name in geo_table
#   geo_table_country_id_field - country id column name in geo_table
#   country_ids                - comma-separated country ids to exclude (main() passes the EU ids)
# The INNER JOIN drops source rows that have no matching geo row; the WHERE clause
# then drops rows whose geo country id is in the excluded list.
# NOTE(review): values are substituted verbatim (no quoting/escaping), so they must
# come from trusted config. If the geo table can hold several rows per yandexuid,
# the join will duplicate source rows - confirm geo is unique by yandexuid.
QUERY_TEMPLATE = """
PRAGMA SimpleColumns;

INSERT INTO `{filtered_table}` WITH TRUNCATE
SELECT
    src.*
FROM `{src_table}` as src
INNER JOIN `{geo_table}` as geo
ON geo.{geo_table_yandexuid_field} == src.{src_table_yandexuid_field}
WHERE NOT ListHas(AsList({country_ids}), geo.{geo_table_country_id_field});
"""


def parse_args():
    """Parse the command line of this script.

    Only ``--config`` is accepted, and it is required. argparse passes the raw
    path string to ``yaml_config.load``; the returned Namespace's ``config``
    attribute holds whatever that loader returns (presumably the parsed YAML
    mapping that main() indexes with ``config_fields`` keys).
    """
    arg_parser = argparse.ArgumentParser(
        description="Filter EU yandexuids and yandexuids with no geo",
    )
    arg_parser.add_argument(
        "--config",
        type=yaml_config.load,  # argparse applies this converter to the path string
        required=True,
        help="Config file path",
    )
    namespace = arg_parser.parse_args()
    return namespace


def main():
    """Filter the yandexuid bindings table by geo and write the sorted result.

    Steps, performed inside a single ``yt.Transaction`` context:
      1. Run ``QUERY_TEMPLATE`` through YQL (given the transaction id
         explicitly). It inner-joins the bindings table with the geo table and
         keeps only rows whose country id is not in the configured EU ids. The
         result goes to a temporary table under the configured tmp dir.
      2. Ensure the parent directory of the output table exists, then sort the
         temporary table by the bindings yandexuid field into the output table.
         The output table is declared with ``bindings.get_yandexuid_schema()``
         and ``optimize_for=scan``.
    """
    logging_helpers.configure_stdout_logger(logging.getLogger())

    args = parse_args()
    # NOTE(review): this logs the whole parsed config; confirm it holds no secrets.
    logger.info("args: %s", args)
    config = args.config

    proxy = config[config_fields.YT_PROXY]
    pool = config[config_fields.YT_POOL]
    yt.config.set_proxy(proxy)
    yt.config["pool"] = pool

    tmp_dir = config[config_fields.YT_TMP_DIR]
    run_yql = executer.get_executer(proxy, pool, tmp_dir)

    with yt.Transaction() as transaction, yt.TempTable(tmp_dir) as joined_table:
        yandexuid_field = config[config_fields.BINDINGS_YANDEXUID_FIELD]

        # Dict literal entries are evaluated in order, so the config lookups
        # happen in the same sequence as the keyword arguments they replace.
        query_params = {
            "src_table_yandexuid_field": yandexuid_field,
            "geo_table_yandexuid_field": config[config_fields.GEO_TABLE_YANDEXUID_FIELD],
            "geo_table_country_id_field": config[config_fields.GEO_TABLE_COUNTRY_ID_FIELD],
            "country_ids": ",".join(map(str, config[config_fields.EU_COUNTRY_IDS])),
            "src_table": config[config_fields.YANDEXUID_BINDINGS_TABLE],
            "geo_table": config[config_fields.YANDEXUID_GEO_TABLE],
            "filtered_table": joined_table,
        }
        query = QUERY_TEMPLATE.format(**query_params)

        logger.info("Running YQL query:\n%s", query)
        run_yql(query, transaction=transaction.transaction_id, syntax_version=1)

        output_table = yt.TablePath(
            config[config_fields.FILTERED_YANDEXUID_BINDINGS_TABLE],
            schema=bindings.get_yandexuid_schema(),
            attributes={"optimize_for": "scan"},
        )
        # Make sure the output table's parent directory exists before sorting into it.
        yt.create("map_node", yt.ypath_dirname(output_table), recursive=True, ignore_existing=True)
        yt.run_sort(joined_table, output_table, sort_by=yandexuid_field)

    logger.info("Completed successfully")
