# -*- coding: utf-8 -*-
import os
import argparse
import datetime
import time
import sys
import requests
import json
import yt.wrapper as yt
import itertools
import logging
import random

# Turn on DEBUG-level output for the "Yt" logger (presumably the logger used
# by yt.wrapper -- confirm).
logging.getLogger("Yt").setLevel(logging.DEBUG)
# OS-entropy-backed random source; get_host() uses it to pick a ClickHouse
# host when several are configured.
rnd = random.SystemRandom()


def get_arg_parser():
    """Parse the command line and return the resulting namespace.

    Despite its name, this returns the *parsed arguments*, not the parser.
    Both options are mandatory and are opened for reading by argparse:

    * ``--input-dates`` -- available as ``args.input_dates``
    * ``--config`` -- available as ``args.config``
    """
    cli = argparse.ArgumentParser("UGC CH Pusher")
    # Each option gets its own FileType instance, exactly one per flag.
    for flag in ("--input-dates", "--config"):
        cli.add_argument(flag, type=argparse.FileType("r"), required=True)
    return cli.parse_args()


def get_host(config):
    """Return the ClickHouse endpoint to send the next request to.

    When the config has a ``hosts`` entry, one of its elements is picked at
    random with the module-level ``rnd``; otherwise the single ``host``
    entry is returned.
    """
    try:
        candidates = config["hosts"]
    except KeyError:
        # No host pool configured: fall back to the single endpoint.
        return config["host"]
    return rnd.choice(candidates)


def get_clickhouse_data(
    query, config, auth=None, max_attempts=5, retry_delay=300
):
    """POST ``query`` to a ClickHouse host, retrying on connection errors.

    A host is chosen anew with ``get_host(config)`` before every attempt.

    Args:
        query: statement to send as the request body. Text is encoded as
            UTF-8 before sending; any other value is passed through as is.
        config: pusher config, used only to pick the host.
        auth: optional ``(user, password)`` pair handed to ``requests``.
        max_attempts: total number of attempts before giving up
            (default 5, the previously hard-coded value).
        retry_delay: base delay in seconds; the pause after the N-th failed
            attempt is ``retry_delay * N`` (default 300).

    Returns:
        The ``requests`` response. The HTTP status is NOT checked here;
        callers inspect ``status_code`` themselves.

    Raises:
        requests.exceptions.ConnectionError: if every attempt failed to
            connect. Other request errors are not retried and propagate
            immediately.
    """
    # type(u"") is the text type on both Python 2 (unicode) and 3 (str).
    if isinstance(query, type(u"")):
        # A text body would otherwise be encoded implicitly further down the
        # HTTP stack (ISO-8859-1 in Python 3's http.client), which fails on
        # characters outside that charset. Send explicit UTF-8 bytes.
        query = query.encode("utf-8")

    attempt = 0
    while True:
        host = get_host(config)
        try:
            # NOTE(review): verify=False disables TLS certificate
            # verification -- confirm this is intended.
            return requests.post(
                host,
                auth=auth,
                data=query,
                timeout=1500,
                verify=False,
            )
        except requests.exceptions.ConnectionError as e:
            attempt += 1
            if attempt >= max_attempts:
                # Out of attempts: report that, and do not claim to sleep.
                print(
                    "giving up after {} attempts cause of connection error: {}".format(
                        attempt, e
                    )
                )
                raise
            tosleep = retry_delay * attempt
            print(
                "retrying cause of connection error: {}. Sleeping for {} secs".format(
                    e, tosleep
                )
            )
            time.sleep(tosleep)


def wrap_string(value):
    """Render ``value`` as a single-quoted, backslash-escaped string literal.

    Backslashes and single quotes are escaped with a backslash. ``None`` and
    empty values produce ``''``.

    Only byte strings are decoded (as UTF-8); text is used as is. The
    original called ``.decode`` unconditionally, which raises
    ``AttributeError`` for every text value on Python 3, including the
    ``""`` fallback used for ``None``. Python 2 results are unchanged.

    Args:
        value: bytes, text or ``None``.

    Returns:
        The quoted literal, or ``''`` when formatting the value raises
        ``UnicodeEncodeError`` (on Python 2 this happens for non-ASCII
        text; the value is then dropped rather than failing the whole
        batch, as before).

    Raises:
        UnicodeDecodeError: if ``value`` is bytes that are not valid UTF-8.
    """
    text = value or ""
    try:
        if isinstance(text, bytes):
            text = text.decode("utf-8")
        # Escape backslashes first so the backslash added for quotes is not
        # escaped a second time.
        text = text.replace("\\", "\\\\").replace("'", "\\'")
        return "'{}'".format(text)
    except UnicodeEncodeError:
        return "''"


def isoformat_date(date_int):
    """Turn a Unix timestamp into a quoted ``YYYY-MM-DD`` literal.

    Timestamps that are not strictly positive are replaced by 0 before the
    conversion. The conversion uses ``datetime.date.fromtimestamp``, so the
    resulting day is the local date for that timestamp. The ISO string is
    then quoted with ``wrap_string``.
    """
    timestamp = 0
    if date_int > 0:
        timestamp = date_int
    day = datetime.date.fromtimestamp(timestamp)
    return wrap_string(day.isoformat())


def _render_cell(key, value, column_types, config):
    """Render one cell of a row as the text to place inside a VALUES tuple."""
    if key == config["key_column"] and config.get("isoformat_key_column"):
        # Key column holds a timestamp that must be sent as a quoted date.
        return isoformat_date(value)
    if column_types[key] == "string":
        return wrap_string(value)
    # Everything else is sent bare; falsy values (None, 0, ...) become "0".
    return str(value or 0)


def gen_query(dates, config):
    """Yield one ``(v1,v2,...)`` tuple string per row of every dated table.

    For each entry of ``dates`` the table ``<yt_path>/<date>`` is read
    through ``config["yt_client"]`` in JSON format, and its ``schema``
    attribute supplies the column types. Only the columns listed in
    ``config["rows"]`` are emitted, in that order.

    If ``config["skip"]`` is set, that many leading rows are skipped in
    EACH table (the row counter restarts per table).
    """
    for date in dates:
        table = config["yt_path"] + "/" + date
        client = config["yt_client"]
        column_types = dict(
            (column["name"], column["type"])
            for column in client.get_attribute(table, "schema")
        )
        for index, row in enumerate(client.read_table(table, format="json")):
            if config.get("skip") and index < config["skip"]:
                continue
            cells = (
                _render_cell(key, row[key], column_types, config)
                for key in config["rows"]
            )
            yield "({})".format(",".join(cells))


def _raise_for_status(answer):
    """Raise ``Exception`` unless ClickHouse answered with HTTP 200."""
    if answer.status_code != 200:
        raise Exception(
            "http error {}: {}".format(answer.status_code, answer.text)
        )


def main():
    """Copy the configured YT tables for the requested dates into ClickHouse.

    Reads the date list and the JSON config named on the command line, and
    takes the ClickHouse credentials from the ``CH_USER`` / ``CH_PASSWORD``
    environment variables.

    In "replace" mode (also used when ``mode`` is missing or empty) the rows
    whose key column falls on one of the requested dates are deleted first
    with ``alter table ... delete``. Any other mode only appends. Rows are
    then inserted in batches of ``bulk_size`` (default 50000), and the
    running batch number is printed after each insert.

    Raises:
        KeyError: if a required config key or environment variable is
            missing.
        Exception: if ClickHouse answers a statement with a non-200 status.
    """
    args = get_arg_parser()
    # argparse.FileType opened both files; close them as soon as they are
    # parsed instead of leaking the handles until interpreter exit.
    with args.input_dates as dates_file:
        dates = json.load(dates_file)
    with args.config as config_file:
        config = json.load(config_file)

    if not dates:
        # An empty list would render the delete filter as `in ()`, and there
        # is nothing to insert either, so stop before touching ClickHouse.
        print("no dates to process")
        return

    config["yt_client"] = yt.YtClient(proxy=config["yt_proxy"])
    bulk_size = config.get("bulk_size", 50000)
    auth = (os.environ["CH_USER"], os.environ["CH_PASSWORD"])

    # Fixed head of every insert statement; the batch of value tuples is
    # appended to it. Concatenation (rather than a second .format pass)
    # keeps braces in table/column names from being parsed as placeholders.
    insert_prefix = u"insert into {ch_table} {rows} Values \n".format(
        ch_table=config["ch_table"],
        rows="({})".format(", ".join(config["rows"])),
    )

    if (config.get("mode") or "replace") == "replace":
        print("replace mode")
        to_delete = u'alter table {ch_table} delete where toDate("{key_column}") in ({dates})'.format(
            ch_table=config["ch_table"],
            key_column=config["key_column"],
            dates=",".join("'{date}'".format(date=date) for date in dates),
        )
        _raise_for_status(get_clickhouse_data(to_delete, config, auth))
    else:
        print("append mode")

    query_gen = gen_query(dates, config)
    count = 0
    while True:
        # Pull the next batch lazily; an empty join means the generator is
        # exhausted.
        text = ", \n".join(itertools.islice(query_gen, bulk_size))
        if not text:
            break
        _raise_for_status(
            get_clickhouse_data(insert_prefix + text, config, auth)
        )
        count += 1
        print(count)


# Run the pusher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
