# -*- coding: utf-8 -*-
import os
import argparse
import datetime
import time
import math
import sys
import requests
import json
import yt.wrapper as yt
import itertools
import logging
import random
from collections import defaultdict

# from concurrent.futures import ThreadPoolExecutor

# Lower the threshold of the logger named "Yt" to DEBUG (presumably the logger
# used by the yt.wrapper client -- TODO confirm).
logging.getLogger("Yt").setLevel(logging.DEBUG)
# OS-entropy-backed random generator; get_host() uses it to pick a host.
rnd = random.SystemRandom()


def get_arg_parser():
    """Parse the process command line and return the resulting namespace.

    Despite the name, the parsed arguments are returned, not the parser.

    Recognised options:
        --input-dates  required; path opened for reading by argparse.
        --config       required; path opened for reading by argparse.
        --mod          optional integer.
        --snapshot     optional string.
    """
    cli = argparse.ArgumentParser(prog="UGC CH Pusher")

    # Both mandatory options are handed over as already-opened text files.
    readable_file = argparse.FileType("r")
    for flag in ("--input-dates", "--config"):
        cli.add_argument(flag, type=readable_file, required=True)

    cli.add_argument("--mod", type=int)
    cli.add_argument("--snapshot")

    namespace = cli.parse_args()
    return namespace


def get_host(config):
    """Pick the endpoint to send the next request to.

    When the config has a "hosts" entry, one of its elements is chosen at
    random with the module-level ``rnd`` generator; otherwise the single
    "host" entry is returned.
    """
    if "hosts" not in config:
        return config["host"]
    pool = config["hosts"]
    return rnd.choice(pool)


def get_clickhouse_data(query, config, auth=None, max_attempts=5, backoff=300):
    """POST *query* to a ClickHouse HTTP endpoint, retrying on connection errors.

    A host is re-selected with ``get_host(config)`` before every attempt.
    Only ``requests.exceptions.ConnectionError`` triggers a retry; any other
    exception propagates immediately.

    Args:
        query: request body to send.
        config: configuration mapping passed to ``get_host``.
        auth: optional value forwarded to ``requests.post(auth=...)``.
        max_attempts: total number of attempts before giving up. At least one
            attempt is always made.
        backoff: base delay in seconds; the pause after the n-th failed
            attempt is ``backoff * n``.

    Returns:
        The ``requests`` response object. The HTTP status is NOT checked here;
        callers inspect ``status_code`` themselves.

    Raises:
        requests.exceptions.ConnectionError: when the last attempt fails.
    """
    attempt = 0
    while True:
        attempt += 1
        host = get_host(config)
        try:
            # NOTE(review): verify=False turns off TLS certificate verification.
            # Kept as-is to avoid breaking existing deployments; confirm it is
            # intentional.
            return requests.post(
                host, auth=auth, data=query, timeout=1500, verify=False
            )
        except requests.exceptions.ConnectionError as e:
            if attempt >= max_attempts:
                # Out of attempts: report that and re-raise without claiming
                # that we are about to sleep.
                print(
                    "giving up after {} attempts cause of connection error: {}".format(
                        attempt, e
                    )
                )
                raise
            tosleep = backoff * attempt
            print(
                "retrying cause of connection error: {}. Sleeping for {} secs".format(
                    e, tosleep
                )
            )
            time.sleep(tosleep)


def wrap_string(value):
    """Render *value* as a single-quoted string literal for the insert query.

    Backslashes and single quotes are backslash-escaped. ``bytes`` input is
    decoded as UTF-8 first. Falsy input (``None``, ``""``, ``b""``) and bytes
    that are not valid UTF-8 both produce the empty literal ``''``.

    Args:
        value: ``str``, ``bytes`` or ``None``.

    Returns:
        The quoted and escaped literal as ``str``.
    """
    if isinstance(value, bytes):
        try:
            value = value.decode("utf8")
        except UnicodeError:
            # bytes.decode raises UnicodeDecodeError on Python 3; catching the
            # common base class keeps the intended "fall back to an empty
            # literal" behaviour actually reachable.
            return "''"
    escaped = (value or "").replace("\\", "\\\\").replace("'", "\\'")
    return "'{}'".format(escaped)


def isoformat_date(date_int):
    """Turn a Unix timestamp into a quoted ``YYYY-MM-DD`` literal.

    Timestamps that are not strictly positive are treated as 0. The calendar
    date comes from ``datetime.date.fromtimestamp``, so it is the local date
    of the running process. The result is quoted with ``wrap_string``.
    """
    if date_int > 0:
        timestamp = date_int
    else:
        timestamp = 0
    iso_day = datetime.date.fromtimestamp(timestamp).isoformat()
    return wrap_string(iso_day)


def gen_query(ytc, tp, schema, config):
    """Build the VALUES part of an insert statement from a table slice.

    Every record read from ``ytc.read_table(tp, format="json")`` becomes one
    parenthesised tuple holding the columns listed in ``config["rows"]``, in
    that order. Tuples are joined with ``", \\n"``.

    Per-cell rendering:
        * the key column, when ``config["isoformat_key_column"]`` is truthy,
          goes through ``isoformat_date``;
        * columns whose ``schema`` type is "string" go through ``wrap_string``;
        * anything else is ``str(value)``, with falsy values rendered as "0".
    """

    def render_cell(column, raw):
        # The key-column check comes first, so it wins over the schema type.
        if column == config["key_column"] and config.get("isoformat_key_column"):
            return isoformat_date(raw)
        if schema[column] == "string":
            return wrap_string(raw)
        return str(raw or 0)

    tuples = [
        "({})".format(
            ",".join(render_cell(column, record[column]) for column in config["rows"])
        )
        for record in ytc.read_table(tp, format="json")
    ]
    return ", \n".join(tuples)


def _load_snapshot(path):
    """Return the per-date sets of already pushed chunk ids stored at *path*.

    An empty/``None`` path or a missing file yields an empty mapping.
    """
    already_pushed = defaultdict(set)
    if not path:
        return already_pushed
    try:
        with open(path) as f:
            stored = json.loads(f.read())
    except FileNotFoundError:
        return already_pushed
    for date, chunk_ids in stored.items():
        already_pushed[date] = set(chunk_ids)
    return already_pushed


def _save_snapshot(path, already_pushed):
    """Overwrite the snapshot file at *path* with the current progress."""
    with open(path, "w") as f:
        f.write(json.dumps({k: sorted(v) for k, v in already_pushed.items()}))


def _rows_per_second(rows, started, finished):
    """Return rows / elapsed seconds, or ``inf`` when no time has elapsed."""
    elapsed = (finished - started).total_seconds()
    return rows / elapsed if elapsed > 0 else float("inf")


def main():
    """Push the YT tables listed in --input-dates into a ClickHouse table.

    Flow:
        1. Read the date list and the JSON config given on the command line,
           and the optional progress snapshot (--snapshot).
        2. In "replace" mode (the default when config["mode"] is unset), issue
           an ``alter table ... delete`` for all dates. This happens only when
           --mod is absent or 0 and the snapshot holds no progress, so a
           resumed run does not delete what it pushed earlier.
        3. For every date, read ``<yt_path>/<date>`` in chunks of
           config["bulk_size"] rows (default 50000) and insert each chunk.
           With --mod, only chunks whose id modulo 10 equals --mod are pushed.
        4. After each chunk, record it in the snapshot file (if requested).

    Credentials are taken from the CH_USER / CH_PASSWORD environment
    variables.

    Raises:
        Exception: when ClickHouse answers with a non-200 status.
        KeyError: when a required config key or environment variable is absent.
    """
    args = get_arg_parser()
    dates = json.load(args.input_dates)
    config = json.load(args.config)
    already_pushed = _load_snapshot(args.snapshot)

    ytc = yt.YtClient(proxy=config["yt_proxy"])
    config["yt_client"] = ytc
    bulk_size = config.get("bulk_size", 50000)

    columns = "({})".format(", ".join(config["rows"]))
    auth = (os.environ["CH_USER"], os.environ["CH_PASSWORD"])
    insert_prefix = "insert into {ch_table} {rows} Values \n".format(
        ch_table=config["ch_table"], rows=columns
    )
    to_delete = 'alter table {ch_table} delete where toDate("{key_column}") in ({dates})'.format(
        ch_table=config["ch_table"],
        key_column=config["key_column"],
        dates=",".join(["'{date}'".format(date=date) for date in dates]),
    )

    if (config.get("mode") or "replace") == "replace":
        print("replace mode")
        if (args.mod is None or args.mod == 0) and not already_pushed:
            print("deleting data")
            answer = get_clickhouse_data(to_delete, config, auth)
            if answer.status_code != 200:
                raise Exception(
                    "http error {}: {}".format(answer.status_code, answer.text)
                )
    else:
        print("append mode")

    for date in dates:
        table = config["yt_path"] + "/" + date
        print(f"pushing {table}")
        row_count = ytc.get_attribute(table, "row_count")
        schema = {x["name"]: x["type"] for x in ytc.get_attribute(table, "schema")}
        # Integer ceiling division: the trailing partial chunk must be counted
        # too. Flooring first (row_count // bulk_size) silently dropped it.
        chunks = (row_count + bulk_size - 1) // bulk_size

        chunk_ids = list(range(chunks))
        if args.mod is not None:
            # NOTE(review): the divisor 10 is hard-coded; presumably the number
            # of cooperating --mod shards. Confirm against the launcher.
            chunk_ids = [x for x in chunk_ids if x % 10 == args.mod]
        print(f"preparing to push {len(chunk_ids)} chunks: {chunk_ids}")

        def push_chunk(i):
            """Read rows of chunk *i* from YT, insert them, and return *i*."""
            start_index = i * bulk_size
            end_index = min(bulk_size * (i + 1), row_count)
            rows = end_index - start_index
            tp = ytc.TablePath(table, start_index=start_index, end_index=end_index)
            t1 = datetime.datetime.now()
            text = insert_prefix + gen_query(ytc, tp, schema, config)
            t2 = datetime.datetime.now()
            answer = get_clickhouse_data(text, config, auth)
            t3 = datetime.datetime.now()
            if answer.status_code != 200:
                raise Exception(
                    "http error {}: {}".format(answer.status_code, answer.text)
                )
            rows_per_sec_read = _rows_per_second(rows, t1, t2)
            rows_per_sec_push = _rows_per_second(rows, t2, t3)
            print(f"pushed chunk {i} (rows {start_index} - {end_index}), read: {rows_per_sec_read:.2f} rows/sec, push: {rows_per_sec_push:.2f} rows/sec")
            return i

        for i in chunk_ids:
            if i in already_pushed[date]:
                print("{} is already pushed, skipping".format(i))
                continue
            chunk_id = push_chunk(i)
            already_pushed[date].add(chunk_id)
            if args.snapshot:
                # Persist after every chunk so an interrupted run can resume.
                _save_snapshot(args.snapshot, already_pushed)


# Run the pusher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
