import argparse
import datetime
import logging
import os
import random
import shlex
import socket
import subprocess
from pathlib import Path


# One formatter shared by both handlers.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Log file receives everything from DEBUG upwards.
fh = logging.FileHandler('ppchouse-copy.log')
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)

# Console (stderr) only shows INFO and above.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)

# Everything is attached to the root logger.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(fh)
logger.addHandler(ch)


class FormatInfo:
    """Names of the date column and the time column of a copied table."""

    def __init__(self, date_col, time_col):
        # date_col is compared with `= '<date>'`; time_col is range-checked
        # with BETWEEN in the queries built by TableInfo.
        self.date_col, self.time_col = date_col, time_col


class TableInfo:
    """Describes how to copy one table, one (date, hour) slice at a time.

    Attributes:
        name: table name; also the first component of task file names and,
            by default, the base of the source/target table name ``<name>_lr``.
        date: first date (``YYYY-MM-DD``) for which tasks are generated.
        fmt: FormatInfo naming the date column and the time column.
        custom_sql: optional SELECT template with ``{date}``, ``{from_date}``
            and ``{to_date}`` placeholders, used instead of ``SELECT *``.
        custom_select_table: optional override of the source table name.
        custom_insert_table: optional override of the destination table name.
    """

    def __init__(self, name, date, fmt, custom_sql=None, custom_select_table=None, custom_insert_table=None):
        self.name = name
        self.date = date
        self.fmt = fmt
        self.custom_sql = custom_sql
        self.custom_select_table = custom_select_table
        self.custom_insert_table = custom_insert_table

    def _hour_bounds(self, date, hour):
        """Return the quoted ``'date HH:00:00'`` / ``'date HH:59:59'`` literals.

        ``hour`` may be an int or a string such as ``'05'``.
        """
        hour_str = str(hour).rjust(2, '0')
        from_date = "'{} {}:00:00'".format(date, hour_str)
        to_date = "'{} {}:59:59'".format(date, hour_str)
        return from_date, to_date

    def _where_clause(self, date, hour):
        """Return the WHERE clause restricting rows to one hour of ``date``."""
        from_date, to_date = self._hour_bounds(date, hour)
        return "WHERE {} = '{}' AND {} BETWEEN {} AND {}".format(
            self.fmt.date_col, date, self.fmt.time_col, from_date, to_date)

    def select_count(self, date, hour):
        """Return a query counting the source rows of the given hour.

        The count always reads ``select_table()`` directly, even when
        ``custom_sql`` is set.
        """
        return "SELECT count() FROM {} {}".format(self.select_table(), self._where_clause(date, hour))

    def select(self, date, hour):
        """Return the query producing the rows of the given hour.

        Uses ``custom_sql`` (with ``date``, ``from_date`` and ``to_date``
        substituted) when set, otherwise ``SELECT *`` from ``select_table()``.
        """
        if self.custom_sql is not None:
            from_date, to_date = self._hour_bounds(date, hour)
            return self.custom_sql.format(
                date=date,
                from_date=from_date,
                to_date=to_date
            )
        return "SELECT * FROM {} {}".format(self.select_table(), self._where_clause(date, hour))

    def insert(self):
        """Return the INSERT statement that consumes Native-format input."""
        return "INSERT INTO {} FORMAT Native".format(self.insert_table())

    def select_table(self):
        """Return the source table name (``<name>_lr`` unless overridden)."""
        if self.custom_select_table is not None:
            return self.custom_select_table
        return self.name + "_lr"

    def insert_table(self):
        """Return the destination table name (``<name>_lr`` unless overridden)."""
        if self.custom_insert_table is not None:
            return self.custom_insert_table
        return self.name + "_lr"


# Available (date column, time column) layouts for the tables in table_list.
format_default = FormatInfo(date_col='log_date', time_col='log_time')
format_date_datetime = FormatInfo(date_col='date', time_col='datetime')
format_timestamp = FormatInfo(date_col='date', time_col='timestamp')
format_ess_logic_objects = FormatInfo(date_col='binlog_date', time_col='binlog_time')

# Start dates shared by several table_list entries; run_gen creates tasks
# from a table's start date onwards.
date_0_3, date_1_3, date_2_3, date_3_3 = (
    '2019-09-01',
    '2018-09-01',
    '2017-09-01',
    '2016-09-01',
)

# Every table that can be copied. Each entry gives the table name (also used
# as the first component of task file names), the first date run_gen creates
# tasks for, and the column layout. By default rows are read from and written
# to "<name>_lr" (see TableInfo.select_table / insert_table).
#
# Entries with custom_sql read a legacy source table and reshape its rows for
# another destination table. The \` and \' escapes inside those raw strings
# are deliberate: run_copy wraps the query in double quotes and runs it via
# /bin/bash -c, so bash strips the backslash in front of ` while \' is passed
# on unchanged (an escaped quote inside a ClickHouse string literal).
table_list = [
    TableInfo('balance',                 '2015-08-27',       format_default),
    # Legacy binlog_queries_mergetree rows copied into binlog_queries_v2_lr:
    # datetime is shifted by +3 hours (presumably a timezone fix -- confirm)
    # and source gets a 'production:' prefix.
    TableInfo(
        'binlog_queries',
        date_1_3,
        format_date_datetime,
        custom_select_table='binlog_queries_mergetree',
        custom_insert_table='binlog_queries_v2_lr',
        custom_sql=r"""
SELECT
    date,
    addHours(datetime, 3) as datetime,
    reqid,
    method,
    service,
    concat('production:', source) as source,
    gtid,
    gtid_src,
    gtid_scn,
    query_seq_num,
    query
FROM (
    SELECT * FROM binlog_queries_mergetree WHERE date = '{date}' AND datetime BETWEEN {from_date} AND {to_date}
)"""),
    TableInfo('binlog_queries_v2',           date_1_3,       format_date_datetime),
    TableInfo('binlog_rows_v2',              date_3_3,       format_date_datetime),
    # Legacy binlog_rows_mergetree rows copied into binlog_rows_v2_lr with the
    # same datetime/source adjustments. operation is shifted by -1 onto the
    # INSERT/UPDATE/DELETE enum, gtid is blanked, change_seq_num is set to 0
    # and row.is_null is converted to a FALSE/TRUE enum.
    TableInfo(
        'binlog_rows',
        date_3_3,
        format_date_datetime,
        custom_select_table='binlog_rows_mergetree',
        custom_insert_table='binlog_rows_v2_lr',
        custom_sql=r"""
SELECT
    date,
    addHours(datetime, 3) as datetime,
    reqid,
    method,
    service,
    concat('production:', source) as source,
    db,
    table,
    CAST(toInt8(operation) - 1, 'Enum8(\'INSERT\' = 0, \'UPDATE\' = 1, \'DELETE\' = 2)') as operation,
    '' as gtid,
    gtid_src,
    gtid_scn,
    query_seq_num,
    toUInt32(0) as change_seq_num,
    primary_key,
    primary_key_schema,
    row.name,
    row.value,
    arrayMap(x -> CAST(toInt8(x), 'Enum8(\'FALSE\' = 0, \'TRUE\' = 1)'), row.is_null) as \`row.is_null\`
FROM (
    SELECT * FROM binlog_rows_mergetree WHERE date = '{date}' AND datetime BETWEEN {from_date} AND {to_date}
)"""),
    TableInfo('bsexport_data',               date_2_3,       format_default),
    # Old bsexport_data_v3_mergetree rows copied into bsexport_data_lr with an
    # empty par_type column appended.
    TableInfo(
        'bsexport_data_old',
        date_2_3,
        format_default,
        custom_select_table='bsexport_data_v3_mergetree',
        custom_insert_table='bsexport_data_lr',
        custom_sql=r"""
SELECT *, '' as par_type FROM bsexport_data_v3_mergetree WHERE log_date = '{date}' AND log_time BETWEEN {from_date} AND {to_date}
"""),  # up to 2019-04-21 (exclusive)
    TableInfo('bsexport_prices',             date_2_3,       format_default),
    # Old bsexport_prices_v3_mergetree rows copied into bsexport_prices_lr with
    # an empty par_type column appended.
    TableInfo(
        'bsexport_prices_old',
        date_2_3,
        format_default,
        custom_select_table='bsexport_prices_v3_mergetree',
        custom_insert_table='bsexport_prices_lr',
        custom_sql=r"""
SELECT *, '' as par_type FROM bsexport_prices_v3_mergetree WHERE log_date = '{date}' AND log_time BETWEEN {from_date} AND {to_date}
"""),  # up to 2019-05-07 14 (inclusive)
    TableInfo('bsimport_active_orders',      date_1_3,       format_default),
    TableInfo('campaign_balance',            '2005-07-15',   format_default),
    TableInfo('dbshards_ids',                date_3_3,       format_default),
    TableInfo('ess_logic_objects',           date_1_3,       format_ess_logic_objects),
    TableInfo('ess_moderation',              date_3_3,       format_default),
    TableInfo('java_threaddump',             date_1_3,       format_default),
    TableInfo('mails',                       date_1_3,       format_default),
    TableInfo('messages',                    date_1_3,       format_default),
    TableInfo('metrics',                     '2019-12-11',   format_default),
    TableInfo('moderate',                    date_3_3,       format_default),
    TableInfo('mysql_audit',                 date_3_3,       format_default),
    TableInfo('mysql_ptkill',                date_3_3,       format_default),
    TableInfo('nginx_access',                date_2_3,       format_default),
    TableInfo('ppclog_api',                  date_3_3,       format_default),
    TableInfo('ppclog_cmd',                  date_3_3,       format_default),
    TableInfo('ppclog_price',                date_1_3,       format_default),
    TableInfo('trace',                       date_1_3,       format_default),
]

# Index of table_list by name; run_task resolves the first component of a
# task file name through this mapping.
tables = {entry.name: entry for entry in table_list}


class Credentials:
    """Connection settings used to build a ``clickhouse-client`` argv."""

    # Base argv; with_host() copies it, so this class-level list is never mutated.
    clickhouse_client = ['clickhouse-client']

    def __init__(self, user=None, password=None, port=None, database=None, secure=False):
        self.user = user
        self.password = password
        self.port = port
        self.database = database
        self.secure = secure

    def with_host(self, host):
        """Return a new clickhouse-client argv list for connecting to ``host``.

        Options whose value is None are omitted; ``--secure`` is added only
        when ``secure`` is truthy. ``port`` may be an int or a string. It is
        converted with ``str()`` because both ``subprocess`` and
        ``' '.join`` reject non-string arguments.
        """
        result = list(Credentials.clickhouse_client)
        # Client settings applied to every connection.
        result += ['--max_block_size', '1024']
        result += ['--max_insert_block_size', '131072']
        result += ['--insert_quorum', '0']
        result += ['-h', host]
        if self.user is not None:
            result += ['-u', self.user]
        if self.port is not None:
            result += ['--port', str(self.port)]
        if self.password is not None:
            result += ['--password', self.password]
        if self.database is not None:
            result += ['-d', self.database]
        if self.secure:
            result += ['--secure']
        return result


# Environment selection. DEBUG defaults to "1", so the script runs in debug
# mode (one test write host) unless DEBUG is explicitly set to another value.
# Prod mode reads from localhost and writes to the shard list below.
# get_host() picks write hosts by their DC prefix (man-/sas-/vla-).
if os.getenv("DEBUG", "1") == "1":
    logger.info("Starting in debug mode")

    read_host = "ppchouse04e.yandex.ru"
    read_credentials = Credentials(user="readonly")

    write_hosts = ["man-ivwlnuip3uudyaqs.db.yandex.net"]  # devtest
    write_credentials = Credentials(user="direct_test", password="", port="9440", database="directlogs", secure=True)
else:
    logger.info("Starting in prod mode")

    read_host = "localhost"
    read_credentials = Credentials(user="readonly")

    # NOTE(review): former dynamic host discovery, kept commented out; json and
    # urlopen are not imported in this file, so it cannot simply be re-enabled.
    # write_hosts = json.loads(urlopen('http://ppchouse-cloud.direct.yandex.net:8085/ready-write-hosts').read().decode('utf-8'))

    write_hosts = [  # new shards
        "man-0xov3wf3itthcp8t.db.yandex.net",
        "man-3q18dzy2i48qtnbn.db.yandex.net",
        "man-ckj9l6wzifymqhxy.db.yandex.net",
        "man-cqthbfipnqk095t8.db.yandex.net",
        "man-eiubqw234dn5ijov.db.yandex.net",
        "man-gxs08tkudps8ik5g.db.yandex.net",
        "man-imcwkqus9itz0uys.db.yandex.net",
        "man-kjm21785uche67gr.db.yandex.net",
        "man-t6b03tvcv8qvsr25.db.yandex.net",
        "sas-61sdcnvg3kqmkr2m.db.yandex.net",
        "sas-7wykxlvz6z49324f.db.yandex.net",
        "sas-8fq9xph50lex4y0y.db.yandex.net",
        "sas-aqrbglxze1rpq8oz.db.yandex.net",
        "sas-b6u66j7zfdgrv9tz.db.yandex.net",
        "sas-cu1prb05ld22aukf.db.yandex.net",
        "sas-i1nukp4xhfk9s6gb.db.yandex.net",
        "sas-o93fllsu85wbrw7t.db.yandex.net",
        "sas-ridm8s4esntt6cru.db.yandex.net",
        "sas-ve8qj4wu2sln7273.db.yandex.net",
        "sas-xo9borevmruwmft5.db.yandex.net",
        "vla-5j37e3ccqdg63tr5.db.yandex.net",
        "vla-6777b0pvb3dq4ykw.db.yandex.net",
        "vla-9jgmqxbnlc3i1ret.db.yandex.net",
        "vla-9kfuw1e9e7f11uxe.db.yandex.net",
        "vla-gmggbwzfqipna2am.db.yandex.net",
        "vla-jppkh0o0vs4enmj8.db.yandex.net",
        "vla-qszhwbcnfv1gzs8y.db.yandex.net",
        "vla-s5hklok72e2a70jy.db.yandex.net",
        "vla-sgmnsu8h9q2qyjyf.db.yandex.net",
        "vla-vimjkqp3riodj1r0.db.yandex.net"
    ]

    # NOTE(review): if CLOUD_USER / CLOUD_PASSWORD are unset, os.getenv returns
    # None and Credentials.with_host silently omits -u / --password.
    write_credentials = Credentials(user=os.getenv("CLOUD_USER"), password=os.getenv("CLOUD_PASSWORD"), port="9440", database="directdb", secure=True)


# Maps the last character of a ppchouse host name to the DC prefix of the
# write hosts it should copy into (used by main() as the default --dc).
distination_dc_by_ppchouse_host = dict(
    f='vla',
    e='vla',
    k='man',
    i='sas',
)


# Task state directories (relative to the working directory). run_task moves
# each task file from todo/ to running/, then to done/, failed/ or back to todo/.
todo_dir = Path("todo")
done_dir = Path("done")
running_dir = Path("running")
failed_dir = Path("failed")


# Create the state directories up front. The loop variable is deliberately not
# named `dir`: at module level it would permanently shadow the builtin dir().
for state_dir in (todo_dir, done_dir, running_dir, failed_dir):
    state_dir.mkdir(exist_ok=True)
del state_dir


# Cached todo/ listing used by get_task(), kept sorted in descending order.
files = []


def get_task():
    """Return the lexicographically smallest pending task path, or None.

    The todo directory listing is cached in the module-level ``files`` list.
    It is re-read only once the cache is exhausted, so tasks added in the
    meantime are picked up on the next refill. The cache is sorted once in
    descending order, so each call pops the smallest entry in O(1). Scanning
    with ``min`` and then ``list.remove`` costs O(n) per call and O(n^2) to
    drain a large todo directory.
    """
    global files
    if not files:
        files = sorted(map(str, todo_dir.iterdir()), reverse=True)

    if not files:
        return None

    return Path(files.pop())


def get_host(dc):
    """Pick a random write host whose name starts with ``dc``.

    Returns None when no host in ``write_hosts`` has that prefix.
    """
    candidates = list(filter(lambda name: name.startswith(dc), write_hosts))
    if not candidates:
        return None
    return random.choice(candidates)


def move_task(task, target_dir):
    """Move the task file into ``target_dir``, keeping its file name.

    Returns ``(True, new_path)`` on success. Returns ``(False, None)`` when the
    rename raises FileNotFoundError (e.g. the source file is already gone).
    Other OS errors propagate.
    """
    destination = target_dir / task.name
    try:
        task.replace(destination)
    except FileNotFoundError:
        return False, None
    return True, destination


def check_host_ok(host, timeout=5):
    """Return True if ``host`` answers ``select 1`` within ``timeout`` seconds.

    Runs clickhouse-client with the write credentials directly (no shell) and
    checks that its output starts with ``1``. Every failure is logged and
    reported as False, never raised: non-zero exit, timeout, unexpected
    output, or any other exception.
    """
    args = write_credentials.with_host(host) + ['-q', 'select 1']
    # Mask the value following --password so it never reaches the log file.
    shown = ['***' if prev == '--password' else arg for prev, arg in zip([None] + args, args)]
    logger.info("Running health check for host {} with command: '{}'".format(host, ' '.join(map(str, shown))))
    try:
        output = subprocess.check_output(args, stderr=subprocess.STDOUT, timeout=timeout)
    except subprocess.TimeoutExpired:
        # check_output signals a timeout with subprocess.TimeoutExpired, not the
        # builtin TimeoutError.
        logger.warning("Health check for host {} timed out after {} seconds".format(host, timeout))
        return False
    except subprocess.CalledProcessError as e:
        logger.warning("Health check for host {} exited with code {}, stdout+stderr: {}".format(host, e.returncode, e.output))
        return False
    except Exception:
        logger.exception("Health check for host {} failed with exception".format(host))
        return False

    if not output.startswith(b'1'):
        logger.warning("Health check for host {} returned {} instead of 1".format(host, output))
        return False

    return True


def _shell_join(args, redact=False):
    """Shell-quote each argument and join them with spaces for ``bash -c``.

    With ``redact=True`` the value after ``--password`` is replaced by ``***``
    so the resulting string can be written to the log.
    """
    args = list(args)
    words = []
    for previous, arg in zip([None] + args, args):
        if redact and previous == '--password':
            words.append('***')
        else:
            words.append(shlex.quote(str(arg)))
    return ' '.join(words)


def _client_command(credential_args, query, extra_args=(), redact=False):
    """Build one clickhouse-client invocation as a bash command string.

    Credential/option arguments are shell-quoted, so an empty password or one
    with shell metacharacters survives bash word splitting. The query is still
    wrapped in plain double quotes rather than shell-quoted. The custom SQL in
    table_list escapes backticks with a backslash and relies on bash
    double-quote processing to remove it.
    """
    command = '{} -q "{}"'.format(_shell_join(credential_args, redact), query)
    if extra_args:
        command += ' ' + _shell_join(extra_args, redact)
    return command


def run_copy(table, date, hour, host):
    """Copy one hour of ``table`` on ``date`` from ``read_host`` to ``host``.

    First counts the source rows with ``table.select_count``; an empty hour
    counts as success and nothing is copied. Otherwise the output of
    ``table.select`` (Native format) is piped from the read host into
    ``table.insert()`` on ``host`` through ``/bin/bash -c``.

    Returns True on success and False on any failure. Failures are logged,
    never raised. Passwords are masked in logged command lines.
    """
    read_args = read_credentials.with_host(read_host)
    write_args = write_credentials.with_host(host)

    count_query = table.select_count(date, hour)
    logger.info("Selecting count from {} with command {}".format(
        table.select_table(), _client_command(read_args, count_query, redact=True)))
    try:
        count_output = subprocess.check_output(
            ['/bin/bash', '-c', _client_command(read_args, count_query)],
            stderr=subprocess.STDOUT, universal_newlines=True)
    except subprocess.CalledProcessError as e:
        logger.exception("Count process exited with non-zero exit code, stdout+stderr: \"{}\"".format(e.output))
        return False
    except Exception:
        logger.exception("Unknown error while selecting count")
        return False

    try:
        count = int(count_output)
    except ValueError:
        logger.exception("Could not parse count from output {!r}".format(count_output))
        return False

    if count == 0:
        logger.info("Count is zero")
        return True

    logger.info("Count is {}".format(count))

    select_query = table.select(date, hour)
    insert_query = table.insert()

    def pipeline(redact):
        # pipefail makes the pipeline fail when the reading side fails too,
        # not only when the final (writing) client does.
        return 'set -o pipefail; {} | {}'.format(
            _client_command(read_args, select_query, ['--format', 'Native'], redact),
            _client_command(write_args, insert_query, redact=redact))

    logger.info("Starting copy with command \"/bin/bash -c {}\"".format(pipeline(redact=True)))
    try:
        subprocess.check_output(['/bin/bash', '-c', pipeline(redact=False)], stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        logger.exception("Process exited with non-zero exit code, stdout+stderr: \"{}\"".format(e.output))
        return False
    except Exception:
        logger.exception("Unknown error while copying data")
        return False

    logger.info("Success")
    return True


def run_task(task, args):
    """Execute one task file, moving it between the state directories.

    The file is first moved into ``running_dir``; its name must have the form
    ``<table>.<date>.<hour>``. Afterwards it ends up in:

    * ``done_dir`` if the copy succeeded;
    * ``todo_dir`` if no healthy write host is available (retry later);
    * ``failed_dir`` for a malformed name, an unknown table, a failed copy,
      or a KeyboardInterrupt (which is re-raised).

    Args:
        task: Path of the task file.
        args: parsed CLI arguments; only ``args.dc`` is read.
    """
    # Tracks where the task file currently lives, so the interrupt handler
    # can always move the right path.
    current = task
    try:
        moved, running_task = move_task(task, running_dir)
        if not moved:
            # The file vanished before we could claim it (presumably taken by
            # another worker sharing these directories).
            logger.warning("Task {} could not be moved to '{}'".format(task, running_dir))
            return
        current = running_task

        parts = current.name.split('.')
        if len(parts) != 3:
            logger.warning("Invalid task name {}, expected <table>.<date>.<hour>".format(current.name))
            move_task(current, failed_dir)
            return
        table_name, date, hour = parts

        table = tables.get(table_name)
        if table is None:
            logger.warning("Invalid table {}".format(table_name))
            move_task(current, failed_dir)
            return

        logger.info("Running task: {} {} {}".format(table.name, date, hour))

        write_host = get_host(args.dc)
        if write_host is None:
            logger.warning("No available hosts for dc {}".format(args.dc))
            move_task(current, todo_dir)
            return

        if not check_host_ok(write_host):
            logger.warning("Host {} is not ok".format(write_host))
            move_task(current, todo_dir)
            return

        logger.info("Health check ok, starting copy:")

        if run_copy(table, date, hour, write_host):
            move_task(current, done_dir)
            return

        logger.warning("Task {} with host {} failed".format(current, write_host))
        move_task(current, failed_dir)
    except KeyboardInterrupt:
        logger.info("Run was interrupted")
        move_task(current, failed_dir)
        raise


def run(args):
    """Run a single task (``args.task``) or drain the todo directory.

    With ``args.task`` set, only that file is processed; a missing file is
    logged and ignored. Otherwise tasks are taken from ``get_task()`` and run
    one after another until it returns None.
    """
    logger.info("Starting with args: {}".format(args))

    if args.task is None:
        task = get_task()
        while task is not None:
            run_task(task, args)
            task = get_task()
        logger.info("No tasks left")
        return

    single = Path(args.task)
    if single.exists():
        run_task(single, args)
    else:
        logger.info("No such task: {}".format(single))


def run_gen(args):
    """Create empty task files ``todo/<table>.<YYYY-MM-DD>.<HH>``.

    One task is created per hour for every day from each table's start date
    (``TableInfo.date``, inclusive) up to ``up_to_date`` (exclusive). With
    ``args.table`` set, only that table is processed. An unknown table name
    is reported instead of silently generating nothing. Task files that
    already exist are left untouched and reported.
    """
    up_to_date = '2020-01-01'

    def parse_date(text):
        """Parse a ``YYYY-MM-DD`` string into a ``datetime.date``."""
        return datetime.date(*map(int, text.split('-')))

    if args.table is not None and args.table not in tables:
        print("Unknown table {}, known tables: {}".format(args.table, ', '.join(sorted(tables))))
        return

    end_date = parse_date(up_to_date)

    dest = Path('todo')
    dest.mkdir(exist_ok=True)

    result = []
    for table in tables.values():
        if args.table is not None and table.name != args.table:
            continue

        start_date = parse_date(table.date)
        for offset in range((end_date - start_date).days):
            day = start_date + datetime.timedelta(days=offset)
            result.extend((table.name, day, hour) for hour in range(24))

    print("Generating {} tasks".format(len(result)))

    for table_name, day, hour in result:
        task = dest / '{}.{}.{:02d}'.format(table_name, day, hour)
        if task.exists():
            print("Task {} already exists!".format(task))
        else:
            task.touch()
    print("Done.")


def main():
    """Command-line entry point: generate task files (``--gen``) or run tasks.

    On a host whose name contains ``ppchouse``, the default ``--dc`` is looked
    up in ``distination_dc_by_ppchouse_host`` by the last character of the
    short host name. If no default is found, ``--dc`` becomes mandatory.
    """
    hostname = socket.gethostname()

    dc = None
    if 'ppchouse' in hostname:
        # gethostname() may return a fully-qualified name such as
        # "ppchouse04e.yandex.ru". The DC letter presumably ends the first
        # label, so use that label rather than the last character of the FQDN.
        short_name = hostname.split('.', 1)[0]
        dc = distination_dc_by_ppchouse_host.get(short_name[-1:])

    parser = argparse.ArgumentParser(description="Copy hourly table slices from the read host to the write hosts.")
    parser.add_argument("--task", help="run only this task file instead of draining the todo directory")
    parser.add_argument("--dc", default=dc, required=dc is None,
                        help="name prefix of the write hosts to use (e.g. vla, man, sas)")
    parser.add_argument("--gen", action='store_true', help="generate task files in todo/ instead of running tasks")
    parser.add_argument("--table", help="with --gen: only generate tasks for this table")

    args = parser.parse_args()

    if args.gen:
        run_gen(args)
    else:
        run(args)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
