#!/skynet/python/bin/python

import os
import csv
import time
import random
import socket
import logging
import argparse
import threading
from itertools import chain
from cStringIO import StringIO

import pkg_resources
pkg_resources.require('requests')
pkg_resources.require('skynet-heartbeat-server-service')

import requests  # noqa

from ya.skynet.services.heartbeatserver.bulldozer import helper  # noqa


# queries to initialize DB:
#   on each host:
#       CREATE TABLE IF NOT EXISTS portoshell_connections (
#           uuid FixedString(36),
#           timestamp DateTime,
#           host String,
#           transport String,
#           user String,
#           acc_user String,
#           acc_host String,
#           slot_type String,
#           slot_info String,
#           streaming Nullable(UInt8),
#           timeout Nullable(Float64),
#           interactive_cmd Nullable(UInt8),
#           api_mode Nullable(UInt8),
#           extra_files Array(String),
#           command Nullable(String),
#           auth_type Nullable(String),
#           auth_bits Nullable(UInt32),
#           watch_parent Nullable(UInt8),
#           width Nullable(UInt32),
#           height Nullable(UInt32),
#           forward_agent Nullable(UInt8),
#           reverse_forwardings Nullable(UInt8),
#           direct_forwardings Nullable(UInt8),
#           subsystem Nullable(String)
#       ) ENGINE = ReplicatedMergeTree('/table_portoshell_connections', '{replica}')
#       PARTITION BY toYYYYMM(timestamp) ORDER BY (uuid, timestamp)


def parse_hosts_from_url(database_uri):
    """Extract the ClickHouse host list from a ``clickhouse://h1,h2,...`` URI.

    Whitespace around each host is stripped.  Empty entries (produced by a
    trailing or doubled comma) are skipped, because an empty host would later
    be turned into an unusable ``https:///?database=...`` URL.

    :param database_uri: string of the form ``clickhouse://host1,host2``
    :return: non-empty list of host strings, in their original order
    :raises Exception: if the URI does not start with ``clickhouse://`` or
        contains no hosts at all
    """
    clickhouse_prefix = 'clickhouse://'
    if not database_uri.startswith(clickhouse_prefix):
        raise Exception("Invalid database URI")

    candidates = (host.strip() for host in database_uri[len(clickhouse_prefix):].split(','))
    hosts = [host for host in candidates if host]
    if not hosts:
        # Fail fast: with no hosts every batch would be silently undeliverable.
        raise Exception("No hosts in database URI")
    return hosts


def make_urls(hosts, db):
    """Build ClickHouse HTTPS endpoint URLs, split by host-name prefix.

    The first three characters of this machine's FQDN are compared with the
    start of every host name (presumably a datacenter code, judging by the
    original ``dc`` variable name -- confirm against the naming scheme).

    :param hosts: iterable of ClickHouse host names
    :param db: database name placed into the ``database`` query parameter
    :return: tuple ``(main, secondary)``: URLs of hosts sharing the local
        prefix, and URLs of all the other hosts
    """
    url_template = 'https://%(host)s/?database=%(db)s'
    local_prefix = socket.getfqdn()[:3]

    nearby = [
        url_template % {'host': name, 'db': db}
        for name in hosts
        if name.startswith(local_prefix)
    ]
    remote = [
        url_template % {'host': name, 'db': db}
        for name in hosts
        if not name.startswith(local_prefix)
    ]
    return nearby, remote


def parse_args():
    """Parse the command line of the collector.

    Options: ``--db`` (required), ``-c/--count-threshold`` (int, 1000),
    ``-t/--time-threshold`` (int, 60), ``-v/--verbose`` (flag) and
    ``--credentials`` (path to a ``user:password`` file).

    :return: the ``argparse.Namespace`` built from ``sys.argv``
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--db', required=True, help='db name')

    # Both batching thresholds are plain integers that differ only in
    # flags, help text and default value.
    int_options = (
        (('-c', '--count-threshold'), 'count threshold for sending the batch', 1000),
        (('-t', '--time-threshold'), 'time threshold for sending the batch', 60),
    )
    for flags, description, fallback in int_options:
        cli.add_argument(*flags, help=description, type=int, default=fallback)

    cli.add_argument('-v', '--verbose', action='store_true', default=False, help='increase logging level')
    cli.add_argument('--credentials', help='path to file with colon-separated clickhouse credentials')
    return cli.parse_args()


class ReportCollection(object):
    """Batch portoshell connection reports and insert them into ClickHouse.

    Reports are serialized as tab-separated rows into an in-memory buffer.
    The buffer is flushed with an ``INSERT ... FORMAT TabSeparated`` HTTP
    request once more than ``count_threshold`` rows are queued or more than
    ``time_threshold`` seconds have passed since the previous flush.

    ``op_lock`` guards the buffer, its csv writer and the row counter: both
    :meth:`add` and :meth:`watcher` (run in a separate thread) touch them.

    NOTE: the ``'clickhouse'`` csv dialect has to be registered before an
    instance is created, because the constructor already builds a writer.
    """

    def __init__(self, hosts, db, count_threshold, time_threshold, user, password):
        """
        :param hosts: ClickHouse host names
        :param db: database name
        :param count_threshold: flush when more than this many rows are queued
        :param time_threshold: flush when this many seconds passed since the
            last flush; also the polling period of :meth:`watcher`
        :param user: ClickHouse user, sent as ``X-ClickHouse-User``
        :param password: ClickHouse password, sent as ``X-ClickHouse-Key``
        """
        self.op_lock = threading.Lock()
        self.queue = 0
        self.reports = None
        self.reports_writer = None
        self.reinit_queue()
        self.primary_urls, self.secondary_urls = make_urls(hosts, db)
        self.auth = {
            'X-ClickHouse-User': user,
            'X-ClickHouse-Key': password,
        }
        self.count_threshold = count_threshold
        self.time_threshold = time_threshold
        self.last_time = time.time()

    def reinit_queue(self):
        """Swap in a fresh empty buffer and reset the row counter.

        Callers are expected to hold ``op_lock`` (except the constructor).

        :return: tuple ``(queued_rows, old_buffer)`` describing the batch
            that was just detached; ``old_buffer`` is ``None`` on first call
        """
        new_reports = StringIO()
        new_reports_writer = csv.writer(new_reports, 'clickhouse')

        reports, self.reports, self.reports_writer = self.reports, new_reports, new_reports_writer
        queue, self.queue = self.queue, 0

        return queue, reports

    def watcher(self):
        """Flush stale batches forever; meant to run in a background thread.

        Wakes up every ``time_threshold`` seconds and sends the pending rows
        if nothing has been sent for longer than ``time_threshold`` seconds.
        """
        while True:
            time.sleep(self.time_threshold)
            with self.op_lock:
                if self.queue and time.time() - self.last_time > self.time_threshold:
                    self.send_report()
                    self.last_time = time.time()

    def add(self, report):
        """Serialize one report into the pending batch, flushing if needed.

        :param report: dict with the mandatory keys ``uuid``, ``timestamp``,
            ``host``, ``transport``, ``user``, ``acc_user``, ``acc_host``,
            ``slot_type``, ``slot_info`` and optional nullable fields
        :raises KeyError: if a mandatory key is missing
        :raises TypeError, ValueError: if a value cannot be converted
        """
        def clean_str(s):
            # Tabs/newlines/double quotes cannot be written by the csv writer
            # (QUOTE_NONE, no escapechar), so they are replaced up front.
            return str(s).replace('\t', ' ').replace('\n', ' ').replace('\r', '').replace('"', "'")

        def get(val, typ, deflt=None):
            # Optional field: ``\N`` is written for a missing value unless an
            # explicit default is given.
            res = report.get(val)
            if res is None:
                return r'\N' if deflt is None else clean_str(deflt)
            elif typ is list:
                return repr([clean_str(s) for s in res])
            elif typ is bool:
                return int(bool(res))
            else:
                return typ(res)

        row = [
            # uuid is sanitized too: a stray tab or quote must not break the row.
            clean_str(report['uuid']),
            int(report['timestamp']),
            clean_str(report['host']),
            clean_str(report['transport']),
            clean_str(report['user']),
            clean_str(report['acc_user']),
            clean_str(report['acc_host']),
            clean_str(report['slot_type']),
            clean_str(report['slot_info']),
            get('streaming', bool),
            get('timeout', float),
            get('interactive_cmd', bool),
            get('api_mode', bool),
            get('extra_files', list, []),
            get('command', clean_str),
            # Free-form string: sanitize it like every other string column.
            get('auth_type', clean_str),
            get('auth_bits', int),
            get('watch_parent', bool),
            get('width', int),
            get('height', int),
            get('forward_agent', bool),
            get('reverse_forwardings', bool),
            get('direct_forwardings', bool),
            get('subsystem', clean_str),
        ]

        with self.op_lock:
            # The write must happen under the lock: the watcher thread may
            # swap the buffer/writer at any moment, and a row written to a
            # detached buffer would be lost.
            self.reports_writer.writerow(row)
            self.queue += 1
            if self.queue > self.count_threshold or time.time() - self.last_time > self.time_threshold:
                self.send_report()
                self.last_time = time.time()

    def make_query(self):
        """Detach the pending batch and build the INSERT statement for it.

        :return: tuple ``(query, data)``; ``(None, None)`` when there is
            nothing to send
        """
        queue, reports = self.reinit_queue()

        if reports is None or not queue:
            logging.debug("no reports to send")
            return None, None

        header = (
            "INSERT INTO portoshell_connections (uuid, timestamp, host, transport, user, acc_user, acc_host, slot_type, slot_info, "
            "streaming, timeout, interactive_cmd, api_mode, extra_files, command, auth_type, auth_bits, watch_parent, "
            "width, height, forward_agent, reverse_forwardings, direct_forwardings, subsystem) FORMAT TabSeparated"
        )
        return header, reports.getvalue()

    def send_report(self):
        """POST the pending batch to the first ClickHouse endpoint that accepts it.

        Primary URLs are tried first, then secondary ones, each group in
        random order.  Delivery is best-effort: if every endpoint fails the
        batch is dropped and an error is logged.  Callers hold ``op_lock``.
        """
        query, data = self.make_query()
        if not query:
            return

        logging.info("sending connections: %d bytes", len(data))
        urls = chain(
            random.sample(self.primary_urls, len(self.primary_urls)),
            random.sample(self.secondary_urls, len(self.secondary_urls)),
        )

        for url in urls:
            try:
                # NOTE(review): TLS verification is disabled and no timeout is
                # set; a hung connection blocks add() and watcher() because
                # the caller holds op_lock.
                response = requests.post(
                    url=url,
                    headers=self.auth,
                    params={"query": query},
                    data=data,
                    verify=False,
                )
                response.raise_for_status()
            except Exception:
                logging.exception("failed to send report to %r", url)
            else:
                logging.debug("sent report to %r", url)
                break
        else:
            # Every endpoint failed (or none is configured): make the loss visible.
            logging.error("report was not delivered to any host, %d bytes dropped", len(data))


def main():
    """Entry point: read reports from the communicator and batch them into ClickHouse.

    Hosts come from the ``DATABASE_URI`` environment variable.  Credentials
    come from the ``--credentials`` file (``user:password``) or, without it,
    from ``CLICKHOUSE_USER`` / ``CLICKHOUSE_PASSWORD``.
    """
    args = parse_args()
    hosts = parse_hosts_from_url(os.environ['DATABASE_URI'])

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    if args.credentials:
        # Close the file deterministically instead of leaking the handle.
        with open(args.credentials) as credentials_file:
            user, password = credentials_file.read().strip().split(':', 1)
    else:
        user = os.environ['CLICKHOUSE_USER']
        password = os.environ['CLICKHOUSE_PASSWORD']

    # Must be registered before ReportCollection is built: its constructor
    # already creates a csv writer with this dialect.
    csv.register_dialect('clickhouse', delimiter='\t', lineterminator='\n', doublequote=True, quoting=csv.QUOTE_NONE)

    # NOTE(review): ready()/discard() semantics belong to the heartbeat
    # helper; presumably they acknowledge or reject the current message.
    com = helper.Communicator().ready()
    collection = ReportCollection(hosts, args.db, args.count_threshold, args.time_threshold, user, password)

    # Daemon thread, so it never keeps the process alive on its own.
    flusher = threading.Thread(target=collection.watcher)
    flusher.daemon = True
    flusher.start()

    for host, _, data in com.read():
        try:
            logging.debug("got report from %s: %s", host, data['report'])
            collection.add(data['report'])
            com.ready()
        except (KeyError, TypeError, ValueError, csv.Error) as ex:
            # Malformed report (including one the csv writer cannot
            # serialize): reject it and keep the daemon running.
            com.discard(repr(ex))
