import argparse
import os
import re
import requests
import time
import traceback
import urllib3
import json
import yt.wrapper as yt
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
import kikimr.public.sdk.python.persqueue.auth as auth
from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult
from yt.logger import LOGGER as LOGGER

# Lists the distinct partitions of {db}.{table} from system.parts, one per output line.
QUERY_SELECT_PARTITIONS = "SELECT partition FROM system.parts WHERE database = '{db}' AND table = '{table}' GROUP BY partition ORDER BY partition"
# Statement template: REPLACE PARTITION {partition} of {db}.{dst_table} FROM {db}.{src_table}.
QUERY_MOVE_PARTITION = "ALTER TABLE {db}.{dst_table} REPLACE PARTITION {partition} FROM {db}.{src_table}"
# Matched against the text of a failed CREATE TABLE; group 1 is the space-free token
# between "Replica " and " already exists" (import_table treats it as a ZooKeeper path).
REPLICA_ALREADY_EXISTS_PATTERN = "Replica ([^ ]+) already exists"
# Matches a ReplicatedMergeTree(...) engine clause. Groups:
#   1 - first quoted argument (non-empty), used by fix_ddl as the path
#   2 - second quoted argument (may be empty), used by fix_ddl as the replica
#   3 - optional ", date" argument; fix_ddl rewrites the clause only when it is present
#   4 - optional parenthesised, comma-separated list of identifiers (letters and underscores)
#   5 - optional trailing integer argument
ENGINE_PATTERN = re.compile("ReplicatedMergeTree\\( *\\'([^\\']+)\\' *, *\\'([^\\']*)\\' *(, *date *)?(?:, *\\(([a-zA-Z_]+ *(?:, *[a-zA-Z_]+ *)*)\\) *)?(, *[0-9]+ *)?\\)")


def request(config, query, data=None, encoding=None):
    """Execute a query over the ClickHouse HTTPS interface and return the response text.

    Args:
        config: dict providing 'ch_address', 'ch_username' and 'ch_password'.
        query: SQL text, sent as the ``query`` URL parameter.
        data: optional request body (e.g. rows for an INSERT).
        encoding: optional value for the Content-Encoding header describing ``data``.

    Returns:
        The response body as text.

    Raises:
        Exception: if ClickHouse answers with any status code other than 200.
    """
    url = 'https://{host}:8443/'.format(host=config['ch_address'])
    headers = {
        'X-ClickHouse-User': config['ch_username'],
        'X-ClickHouse-Key': config['ch_password'],
    }
    if encoding is not None:
        headers["Content-Encoding"] = encoding

    # Pass the SQL through `params` so that it is URL-encoded. Splicing it into
    # the URL by hand corrupts queries containing '+', '&', '#' or '%'.
    # NOTE: verify=False disables TLS certificate verification for this call.
    res = requests.post(url, params={'query': query}, data=data, headers=headers, verify=False)
    if res.status_code != 200:
        LOGGER.error("Bad response: " + res.text)
        raise Exception("Clickhouse responsed with code " + str(res.status_code) + ": " + res.text)
    return res.text


def fix_ddl(query):
    """Normalise a CREATE TABLE spec before it is sent to ClickHouse.

    When the engine clause matches ENGINE_PATTERN and carries the ``date``
    argument, the clause is replaced with one using explicit
    PARTITION BY / ORDER BY / SETTINGS. Otherwise the
    ``parts_to_throw_insert`` setting is appended unless the text already
    mentions it.
    """
    engine = ENGINE_PATTERN.search(query)
    if engine is not None and engine.group(3):
        rewritten_engine = "ReplicatedMergeTree('{path}', '{replica}') PARTITION BY toYYYYMM(date) ORDER BY ({columns}) SETTINGS parts_to_throw_insert = 100000".format(
            path=engine.group(1),
            replica=engine.group(2),
            columns=engine.group(4))
        return ENGINE_PATTERN.sub(rewritten_engine, query)

    if "parts_to_throw_insert" in query:
        return query
    return query + " SETTINGS parts_to_throw_insert = 100000"


def tmp_table_spec(tableSpec):
    """Return tableSpec with ReplicatedMergeTree('<path>', '<replica>'...) turned into MergeTree(...).

    The two leading quoted arguments are removed; any remaining engine
    arguments are kept as they are.
    """
    # Engine name plus its two quoted (non-empty) leading arguments.
    replicated_head = r"ReplicatedMergeTree\( *'[^']+' *, *'[^']+' *"
    # First the form that has further arguments after the second quoted one...
    without_leading_args = re.sub(replicated_head + ",", "MergeTree(", tableSpec)
    # ...then the form where the two quoted arguments are the only ones.
    return re.sub(replicated_head + r"\)", "MergeTree()", without_leading_args)


def get_row_count(config, db_name, table_name):
    """Return the row count of db_name.table_name as an int.

    Despite the first log message, emptiness is not checked here; the
    function only runs ``SELECT count()`` and parses the answer.

    Raises:
        Exception: if the response cannot be parsed as an integer.
    """
    qualified_name = db_name + "." + table_name
    LOGGER.info("Checking that table " + qualified_name + " is not empty")
    raw_count = request(config, "SELECT count() FROM " + qualified_name)
    try:
        parsed_count = int(raw_count)
    except ValueError:
        raise Exception("Could not get count of rows in imported table " + qualified_name)
    LOGGER.info("Sucessfully imported " + str(parsed_count) + " rows into table " + table_name)
    return parsed_count


def _create_result_table(config, db_name, table_name, create_spec):
    """Create the destination table, recovering once from a stale replica record.

    If the CREATE fails with a "Replica ... already exists" error whose path
    ends in ``/replicas/<ch_address>``, that replica is dropped with
    SYSTEM DROP REPLICA and the CREATE is retried with the same DDL.
    Any other failure is re-raised unchanged.
    """
    create_query = "CREATE TABLE " + db_name + "." + table_name + " " + create_spec
    try:
        request(config, create_query)
    except Exception as e:
        match = re.search(REPLICA_ALREADY_EXISTS_PATTERN, str(e))
        if match is None:
            raise
        zk_path = match.group(1)
        suffix = '/replicas/' + config['ch_address']
        if not zk_path.endswith(suffix):
            # the conflicting replica is not this host; do not touch it
            raise
        zk_path = zk_path[:-len(suffix)]
        # drop bad replica
        request(config, "SYSTEM DROP REPLICA '{replica}' FROM ZKPATH '{zk_path}'".format(replica=config['ch_address'], zk_path=zk_path))
        # retry create with the same fixed DDL as the first attempt, not the raw spec
        request(config, create_query)


def _move_partitions(config, db_name, src_table, dst_table):
    """Issue QUERY_MOVE_PARTITION for every partition that system.parts lists for src_table."""
    partitions = request(config, QUERY_SELECT_PARTITIONS.format(db=db_name, table=src_table)).split("\n")
    for partition in partitions:
        if not partition:
            # skip the empty trailing line of the response
            continue
        if partition != 'tuple()':
            # every partition value except the literal tuple() is passed as a quoted string
            partition = "'" + partition + "'"
        request(config, QUERY_MOVE_PARTITION.format(db=db_name, src_table=src_table, dst_table=dst_table, partition=partition))


def import_table(task_spec, config):
    """Import one YT table into ClickHouse via a temporary local table.

    Steps: drop any previous destination/temporary tables, create the
    temporary ``imprt_<table>`` table with a non-replicated engine, insert
    every YT row's ``data`` field into it, optimize it, create the
    destination table, transfer all partitions into it, verify that the row
    counts match and finally drop the temporary table.

    Args:
        task_spec: dict with 'ytTable' ("<cluster>:<path>"), 'chTable'
            ("<db>.<table>"), 'chCreateSpec', 'chInsertSpec' and
            'dataCompressionMode'. It is updated in place with 'started',
            'row_count' and 'finished'.
        config: dict with the ClickHouse connection settings and 'yt_token'.

    Raises:
        Exception: if a ClickHouse request fails or if the destination row
            count differs from the temporary table's row count.
    """
    compression = task_spec["dataCompressionMode"]
    yt_cluster, yt_path = task_spec["ytTable"].split(":", 1)

    encoding = None
    yt_config = {}
    if compression != "NONE":
        # Any mode other than "NONE" is sent to ClickHouse as gzip-encoded data.
        # NOTE(review): the "identity" proxy encoding presumably keeps the YT
        # client from re-encoding the stored payload - confirm.
        yt_config = {"proxy": {"content_encoding": "identity"}}
        encoding = "gzip"

    client = yt.YtClient(proxy=yt_cluster, token=config["yt_token"], config=yt_config)
    db_name, table_name = task_spec["chTable"].split(".", 1)

    tmp_table_name = "imprt_" + table_name
    task_spec["started"] = int(time.time())
    LOGGER.info("Creating table " + table_name)
    request(config, "DROP TABLE IF EXISTS " + db_name + "." + table_name + " NO DELAY")
    request(config, "DROP TABLE IF EXISTS " + db_name + "." + tmp_table_name + " NO DELAY")

    time.sleep(1)

    create_spec = fix_ddl(task_spec["chCreateSpec"])
    request(config, "CREATE TABLE " + db_name + "." + tmp_table_name + " " + tmp_table_spec(create_spec))
    insert_query = "INSERT INTO " + db_name + "." + tmp_table_name + " " + task_spec["chInsertSpec"]
    count = 0
    for line in client.read_table(yt_path, raw=False, format=yt.DsvFormat()):
        count += 1
        # one INSERT request per YT row; the row's "data" field is the request body
        request(config, insert_query, data=line["data"], encoding=encoding)
        LOGGER.info("Processed line " + str(count))

    LOGGER.info("Optimizing table " + tmp_table_name)
    request(config, "OPTIMIZE TABLE " + db_name + "." + tmp_table_name)
    row_count = get_row_count(config, db_name, tmp_table_name)

    LOGGER.info("Creating result table and moving partitions")
    _create_result_table(config, db_name, table_name, create_spec)
    _move_partitions(config, db_name, tmp_table_name, table_name)

    new_row_count = get_row_count(config, db_name, table_name)
    task_spec["row_count"] = new_row_count
    if row_count != new_row_count:
        raise Exception("Problem occured when moving partition from table " + tmp_table_name + " to " + table_name +
                        ". Old row count = " + str(row_count) + ", new row count = " + str(new_row_count))

    LOGGER.info("Clearing temp table")
    request(config, "DROP TABLE " + db_name + "." + tmp_table_name)
    task_spec["finished"] = int(time.time())
    LOGGER.info("Finished successfully")


def pick_and_run_task(config):
    """Run the first task under config["yt_tasks_path"] whose spec is not done.

    Before looking at tasks, the current YT_OPERATION_ID is compared with the
    third ':'-separated field of the ``clickhouseUploaderOperationId``
    attribute on the parent of the tasks path; a mismatch raises.

    Returns:
        The processed spec (with ``done`` set to True), or None when every
        listed task is already done.
    """
    locke = yt.YtClient(proxy="locke", token=config["yt_token"])
    tasks_dir = config["yt_tasks_path"]

    # check for operation id against the attribute stored on the parent node
    current_operation = os.environ["YT_OPERATION_ID"]
    parent_dir = tasks_dir.rpartition("/")[0]
    owner_attribute = "clickhouseUploaderOperationId"
    registered_operation = locke.get(parent_dir, attributes=[owner_attribute]).attributes[owner_attribute]
    if registered_operation.split(":")[2] != current_operation:
        raise Exception("Current operation id " + current_operation + " differs from operation, stored in locke: " + registered_operation)

    with locke.Transaction():
        for task_name in locke.list(tasks_dir):
            task_path = tasks_dir + "/" + task_name
            spec = locke.get(task_path, attributes=["spec"]).attributes["spec"]
            if spec["done"]:
                continue

            LOGGER.info("Found spec to process: " + str(spec))
            # lock the task node inside the transaction before importing
            locke.lock(task_path)
            import_table(spec, config)
            spec["done"] = True
            locke.set(task_path + "/@spec", spec)
            return spec

    return None


def log_result(config, spec, success, error):
    """Report the outcome of an import attempt to Logbroker and, optionally, to a YT table.

    Args:
        config: dict with 'yt_tasks_path' (last path component is
            "<dc>-<shard>"), 'ch_address', 'lb_producer', 'lb_max_seq_no',
            'yt_log_table' ("<cluster>:<path>" or None) and 'yt_token'.
            'lb_max_seq_no' is incremented in place.
        spec: the task spec, or None when the failure happened before a spec
            was available; the spec-derived fields are then reported as None.
        success: whether the import succeeded.
        error: error text (e.g. a traceback) or None.

    Raises:
        RuntimeError: if the Logbroker write is not acknowledged.
    """
    dc, shard = config["yt_tasks_path"].split("/")[-1].split("-", 1)
    # The failure path calls this with spec=None, and a failed import may not
    # have set 'finished'/'row_count' yet, so missing values become None
    # instead of raising inside the caller's exception handler.
    spec_fields = spec if spec is not None else {}
    log_record = {
        "dc": dc,
        "shard": shard,
        "ch_address": config["ch_address"],
        "yt_table": spec_fields.get("ytTable"),
        "table": spec_fields.get("chTable"),
        "started": spec_fields.get("started"),
        "finished": spec_fields.get("finished"),
        "row_count": spec_fields.get("row_count"),
        "success": success,
        "error": error
    }
    # first write into lb
    config["lb_max_seq_no"] = config["lb_max_seq_no"] + 1
    lb_response = config["lb_producer"].write(config["lb_max_seq_no"], json.dumps(log_record))
    lb_write_result = lb_response.result(timeout=10)
    if not lb_write_result.HasField("ack"):
        raise RuntimeError("Message write failed with error {}".format(lb_write_result))

    # the YT log table only receives records that belong to a concrete task
    if config["yt_log_table"] is None or spec is None:
        return
    cluster, path = config["yt_log_table"].split(":", 1)
    log_client = yt.YtClient(proxy=cluster, token=config["yt_token"])
    log_client.write_table(yt.TablePath(path, append=True), [log_record], format=yt.JsonFormat())


def create_lb_producer(config):
    """Start a Logbroker producer and store it in config.

    Reads YT_JOB_ID and YT_SECURE_VAULT_LB_OAUTH_TOKEN from the environment,
    starts the streaming API and a retrying producer for
    config["lb_log_topic"], then sets config["lb_producer"] and
    config["lb_max_seq_no"] (taken from the producer's init response).

    Raises:
        RuntimeError: if the producer start returns a session failure or a
            response without the "init" field.
    """
    job_id = os.environ["YT_JOB_ID"]
    oauth_token = os.environ["YT_SECURE_VAULT_LB_OAUTH_TOKEN"]

    api = pqlib.PQStreamingAPI("vla.logbroker.yandex.net", 2135)
    LOGGER.info("Starting PqLib")
    api_start_result = api.start().result(timeout=10)
    LOGGER.info(" Api started with result: {}".format(api_start_result))

    credentials = auth.OAuthTokenCredentialsProvider(oauth_token)
    # the source id is unique per YT job
    producer_settings = pqlib.ProducerConfigurator(config["lb_log_topic"], 'clickhouse-uploader' + job_id)
    producer = api.create_retrying_producer(producer_settings, credentials_provider=credentials)

    LOGGER.info("Starting Producer")
    start_result = producer.start().result(timeout=10)

    if isinstance(start_result, SessionFailureResult):
        raise RuntimeError("Error occurred on start of producer: {}.".format(start_result))
    if not start_result.HasField("init"):
        raise RuntimeError("Unexpected producer start result from server: {}.".format(start_result))

    LOGGER.info("Producer start result was: {}".format(start_result))
    max_seq_no = start_result.init.max_seq_no
    LOGGER.info("Producer started")

    config["lb_producer"] = producer
    config["lb_max_seq_no"] = max_seq_no

def main():
    """Parse the command line, start the Logbroker producer and process tasks forever.

    Each iteration tries to pick and run one task. Successes and failures are
    reported through log_result; the loop sleeps 60 seconds when there is no
    task or after a failure.
    """
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    parser = argparse.ArgumentParser(description='Upload table from YT to Clockhouse')
    cli_options = (
        ('--ch_address', 'Clickhouse address (without scheme and port)'),
        ('--ch_username', 'Clickhouse username'),
        ('--yt_tasks_path', "Base path for import tasks"),
        ('--yt_log_table', "Table for logging"),
        ('--lb_log_topic', "Logbroker topic for logging"),
    )
    for flag, help_text in cli_options:
        parser.add_argument(flag, help=help_text)
    args = parser.parse_args()

    # secrets come from the environment, everything else from the command line
    config = {
        "yt_token": os.environ["YT_SECURE_VAULT_YT_TOKEN"],
        "yt_tasks_path": args.yt_tasks_path,
        "ch_password": os.environ["YT_SECURE_VAULT_CH_PASSWORD"],
        "ch_address": args.ch_address,
        "ch_username": args.ch_username,
        "yt_log_table": args.yt_log_table,
        "lb_log_topic": args.lb_log_topic
    }
    create_lb_producer(config)
    time.sleep(10)

    while True:
        # reset so the failure report never reuses a previous iteration's spec
        spec = None
        try:
            spec = pick_and_run_task(config)
            if spec is not None:
                LOGGER.info("Successfully imported spec " + str(spec))
                log_result(config, spec, True, None)
            else:
                LOGGER.info("No new tasks found")
                time.sleep(60)
        except Exception as e:
            LOGGER.error("Error occured when import table: " + str(e))
            error = traceback.format_exc()
            LOGGER.error(error)
            log_result(config, spec, False, error)
            time.sleep(60)


if __name__ == "__main__":
    main()
