import logging
from copy import deepcopy
from datetime import datetime, timedelta

import clickhouse.client as clickhouse

# Module-wide logger, named after this module (standard ``logging`` convention).
logger = logging.getLogger(name=__name__)


class ClickhouseUploader(object):
    """Upload dependency/version and dogma data to one or more ClickHouse hosts.

    Each public ``write_*`` method performs the same upload against every host in
    ``hosts``. The payload is deep-copied for each host because some writers
    (``_write_versions``) modify their input in place.
    """

    # Maximum number of rows sent in a single INSERT for the incremental tables.
    BATCH_SIZE = 10000
    # Only commits newer than this many days are read back from ClickHouse for
    # de-duplication in ``_write_commits``.
    COMMIT_LOOKBACK_DAYS = 5

    def __init__(self, token, hosts):
        """
        :param token: password used for the ``statadhoc`` ClickHouse user.
        :param hosts: re-iterable collection of host names; every write goes to all of them.
        """
        self.token = token
        self.hosts = hosts

    def get_connection(self, host):
        """Open an SSL connection to ``host`` on port 8443 as user ``statadhoc``."""
        # NOTE(review): connections opened here are never closed by this class;
        # if the client exposes close(), consider closing after each host.
        return clickhouse.connect(host=host, port=8443, username="statadhoc", password=self.token, ssl=True)

    def write_packages(self, lines, package_cache):
        """Upload dependency rows and package releases to every host.

        :param lines: iterable of dicts keyed by the ``version.dependencies`` column names.
        :param package_cache: mapping ``{package_name: {version: release_date}}`` where
            ``release_date`` is an ISO-8601-style timestamp string.
        """
        for host in self.hosts:
            # Report progress through the module logger, like the rest of this class.
            logger.info("Working with %s", host)
            conn = self.get_connection(host)

            self._write_data(conn, deepcopy(lines))
            self._write_versions(conn, deepcopy(package_cache))

    def write_dogma(self, dogma_commits, dogma_alive):
        """Upload dogma commits and "alive" rows to every host.

        :param dogma_commits: set of ``(commit_time, committer, repo_vcs_name,
            repo_vcs_type, branch_name)`` tuples.
        :param dogma_alive: set of ``(fielddate, repo_vcs_name, repo_vcs_type)`` tuples.
        """
        for host in self.hosts:
            logger.info("Working with %s", host)
            conn = self.get_connection(host)

            self._write_commits(conn, deepcopy(dogma_commits))
            self._write_alive(conn, deepcopy(dogma_alive))

    @staticmethod
    def _insert(cursor, table, columns, data, batch_size=None):
        """Insert ``data`` into ``table`` and log the row count.

        :param cursor: cursor whose ``execute(query, data=...)`` performs the insert.
        :param table: fully qualified table name.
        :param columns: column names, in the same order as the values in each row.
        :param data: list of rows to insert; nothing is executed when it is empty.
        :param batch_size: if given, rows are sent in INSERTs of at most this many rows;
            otherwise all rows go in a single INSERT.
        """
        logger.info("Inserting %s rows into %s", len(data), table)
        if not data:
            return

        query = "INSERT INTO {} ({})".format(table, ", ".join(columns))
        if batch_size is None:
            cursor.execute(query, data=data)
            return

        for start in range(0, len(data), batch_size):
            logger.info("Starting from record number %s, %s percent", start, start * 100 / len(data))
            cursor.execute(query, data=data[start : start + batch_size])

    def _write_data(self, connection, lines):
        """Insert every dependency row into ``version.dependencies``.

        :param connection: open connection as returned by ``get_connection``.
        :param lines: iterable of dicts keyed by the column names below.
        """
        columns = [
            "fielddate",
            "file_version",
            "is_deprecated",
            "last_commit_time",
            "name",
            "project",
            "repo_version",
            "url",
            "vcs_type",
            "version",
            "is_devdependency",
            "lock_version",
            "path",
            "lock_variant",
        ]
        # Boolean columns are stored as ints in CH. Select them by name rather than
        # by position so that editing ``columns`` cannot convert the wrong field.
        bool_columns = {"is_deprecated", "is_devdependency"}

        data = [
            [int(line[col]) if col in bool_columns else line[col] for col in columns]
            for line in lines
        ]

        cur = connection.cursor()
        self._insert(cur, "version.dependencies", columns, data)

    def _write_versions(self, connection, versions):
        """Insert package releases not yet present in ``version.releases``.

        ``versions`` is modified in place (dates normalised, known releases removed),
        which is why callers pass a copy.

        :param connection: open connection as returned by ``get_connection``.
        :param versions: mapping ``{package_name: {version: release_date_string}}``.
        """
        columns = ["package_name", "version", "release_date"]

        # Keep "YYYY-MM-DD?HH:MM:SS" (first 19 chars) and turn the ISO "T"
        # separator into a space; anything after the seconds is dropped.
        for releases in versions.values():
            for vers, release_date in releases.items():
                releases[vers] = release_date[:19].replace("T", " ")

        cur = connection.cursor()
        cur.execute("SELECT package_name, version FROM version.releases")

        # Drop releases that are already stored so only new ones are inserted.
        for package, version in cur.fetchall():
            versions.get(package, {}).pop(version, None)

        data = [
            [lib, vers, release_date]
            for lib, releases in versions.items()
            for vers, release_date in releases.items()
        ]
        self._insert(cur, "version.releases", columns, data)

    def _write_commits(self, connection, commits):
        """Insert commits not already present in ``version.commits``, oldest first.

        :param connection: open connection as returned by ``get_connection``.
        :param commits: set of tuples ordered like ``columns`` below.
        """
        columns = [
            "commit_time",
            "committer",
            "repo_vcs_name",
            "repo_vcs_type",
            "branch_name",
        ]

        since = datetime.now() - timedelta(days=self.COMMIT_LOOKBACK_DAYS)
        cur = connection.cursor()
        cur.execute(
            "SELECT {} FROM version.commits WHERE commit_time>toDateTime('{}')".format(
                ",".join(columns),
                since.strftime("%Y-%m-%dT%H:%M:%S"),
            )
        )
        # NOTE(review): de-duplication relies on fetched rows comparing equal to the
        # input tuples (same value types). Commits older than the lookback window are
        # never in ``existing`` and would be re-inserted if present in ``commits``.
        existing = {tuple(row) for row in cur.fetchall()}

        data = sorted(commits - existing, key=lambda row: row[0])
        self._insert(cur, "version.commits", columns, data, self.BATCH_SIZE)

    def _write_alive(self, connection, alive):
        """Insert "alive" rows not already stored for their date into ``version.alive``.

        :param connection: open connection as returned by ``get_connection``.
        :param alive: set of ``(fielddate, repo_vcs_name, repo_vcs_type)`` tuples.
        """
        columns = [
            "fielddate",
            "repo_vcs_name",
            "repo_vcs_type",
        ]

        if not alive:
            # Nothing to upload. Without this guard the lookup below would run with
            # fielddate=None, i.e. ``toDate('None')``.
            logger.info("Inserting %s rows into version.alive", 0)
            return

        # NOTE(review): assumes all rows share one fielddate; rows with any other
        # date are not de-duplicated against the table.
        fielddate = next(iter(alive))[0]

        cur = connection.cursor()
        cur.execute(
            "SELECT {} FROM version.alive WHERE fielddate=toDate('{}')".format(
                ",".join(columns),
                fielddate,
            )
        )
        existing = {tuple(row) for row in cur.fetchall()}

        self._insert(cur, "version.alive", columns, list(alive - existing), self.BATCH_SIZE)
