import time
from contextlib import contextmanager

from psycopg2 import pool as pgpool
from psycopg2.errors import UniqueViolation

from infra.rtc.docker_registry.docker_torrents.exceptions import DublicateMapping
from infra.rtc.docker_registry.docker_torrents.torrent_config import TorrentsConfig


class TorrentDatabase:
    """PostgreSQL-backed store for digest -> rbtorrent_id mappings and the brew queue.

    A digest row in ``mapping_table`` moves through these states:

    * queued  -- ``brew_lock IS NULL AND brewed IS NULL`` (``put_queue_brew``)
    * brewing -- ``brew_lock`` holds the unix time the row was claimed
      (``pop_queue``) and ``brewed IS NULL``
    * brewed  -- ``brewed`` holds the unix time ``rbtorrent_id`` was stored
      (``update_mapping`` / ``new_mapping``)

    The class also reads a legacy mapping table, the MDS/MFS blob tables and the
    ``roles``/``users``/``repos`` tables used for access checks.
    """

    def __init__(self, config: TorrentsConfig):
        """Create the connection pool from ``config.db`` and ensure the schema exists.

        :param config: service configuration; ``config.db`` supplies connection
            settings and table names, ``config.logger`` the logger.
        """
        self.config = config
        # Convert before comparing: config values may arrive as strings, and
        # ``'5' > 1`` raises TypeError on Python 3. The timeout is truncated to
        # whole seconds with a minimum of 1.
        self.connect_timeout = max(int(config.db.get('connect_timeout', 1)), 1)
        self.pg_pool = pgpool.ThreadedConnectionPool(
            1,
            int(config.db['max_pool_size']),
            database=str(config.db['database']),
            user=str(config.db['username']),
            password=str(config.db['password']),
            host=str(config.db['host']),
            port=int(config.db['port']),
            target_session_attrs='read-write',
            connect_timeout=self.connect_timeout
        )
        self.mapping_table = str(config.db['table_name'])
        self.old_mapping_table = str(config.db['old_table_name'])
        self.mds_table = str(config.db['mds_table'])
        self.mfs_table = str(config.db['mfs_table'])
        self.blob_path_template = str(config.db['blob_path_template'])
        self.blob_path_digest_index = int(config.db['blob_path_digest_index'])
        self.logger = config.logger
        self.init_db()

    @contextmanager
    def _cursor(self, commit: bool = False):
        """Borrow a pooled connection and yield a cursor on it.

        The connection is always returned to the pool, even if the ``with``
        body raises. When ``commit`` is true, the transaction is committed after
        the body completes without an exception. Otherwise nothing is committed
        here. NOTE(review): any open transaction is left for the pool's
        ``putconn`` to reset.
        """
        conn = self.pg_pool.getconn()
        try:
            with conn.cursor() as cursor:
                yield cursor
            if commit:
                conn.commit()
        finally:
            self.pg_pool.putconn(conn)

    def init_db(self) -> None:
        """Create the mapping table and its indexes if they do not exist yet."""
        statements = (
            """CREATE TABLE IF NOT EXISTS {table} (
            digest varchar(128) PRIMARY KEY,
            rbtorrent_id varchar(64),
            added BIGINT,
            brew_lock BIGINT,
            brewed BIGINT);""",
            """CREATE UNIQUE INDEX IF NOT EXISTS {table}_rbtorrent_idx ON {table}
            USING btree (rbtorrent_id);""",
            """CREATE INDEX IF NOT EXISTS {table}_added_sort ON {table}
            USING btree (added);""",
            """CREATE INDEX IF NOT EXISTS {table}_brew_lock_compare ON {table}
            USING btree (brew_lock);""",
            """CREATE INDEX IF NOT EXISTS {table}_brew_lock_null ON {table}(brew_lock)
            WHERE brew_lock IS NOT NULL;""",
            """CREATE INDEX IF NOT EXISTS {table}_brewed_null ON {table}(brewed)
            WHERE brewed IS NOT NULL;""",
        )
        with self._cursor(commit=True) as cursor:
            for statement in statements:
                cursor.execute(statement.format(table=self.mapping_table))

    def search_digest_mapping(self, layers_digest: list) -> dict:
        """Return ``{digest: rbtorrent_id}`` for the given digests found in the mapping table.

        Digests that are queued but not yet brewed map to ``None``. Digests
        absent from the table are omitted.

        :param layers_digest: digests to look up (any iterable of strings).
        """
        self.logger.debug('Searching digests {} in torrent database'.format(layers_digest))
        with self._cursor() as cursor:
            cursor.execute(
                """SELECT digest, rbtorrent_id FROM {table} WHERE digest = ANY(%s);""".format(
                    table=self.mapping_table
                ),
                # psycopg2 adapts a list (not a tuple or set) to a SQL ARRAY.
                (list(layers_digest), )
            )
            rows = cursor.fetchall()
        return {str(digest): rbtorrent_id for digest, rbtorrent_id in rows if digest is not None}

    def search_old_mds_key_mapping(self, mds_keys: list) -> dict:
        """Return ``{mds_key: rbtorrent_id}`` from the legacy mapping table.

        Rows where either column is NULL are skipped.

        :param mds_keys: MDS keys to look up (any iterable of strings).
        """
        self.logger.debug('Searching mds_keys: {}'.format(mds_keys))
        with self._cursor() as cursor:
            # NOTE(review): the legacy column is spelled ``rbtorent_id``
            # (single "r"); confirm against the legacy schema before "fixing" it.
            cursor.execute(
                """SELECT mds_key, rbtorent_id FROM {table} WHERE mds_key = ANY(%s);""".format(
                    table=self.old_mapping_table
                ),
                (list(mds_keys),)
            )
            rows = cursor.fetchall()
        return {
            str(mds_key): str(rbtorrent_id)
            for mds_key, rbtorrent_id in rows
            if mds_key is not None and rbtorrent_id is not None
        }

    def get_queued(self, offset: int = 0, limit: int = 1000) -> (list, int):
        """Return one page of queued digests plus the total number of queued digests.

        Queued means never claimed and not brewed. Page rows are single-element
        lists ``[digest]`` ordered by ``added``.
        """
        with self._cursor() as cursor:
            cursor.execute(
                """SELECT digest FROM {table}
                WHERE brew_lock IS NULL AND brewed IS NULL ORDER BY added OFFSET %s LIMIT %s;""".format(
                    table=self.mapping_table
                ), (offset, limit))
            page = [[str(row[0])] for row in cursor.fetchall()]
            cursor.execute(
                """SELECT COUNT(*) FROM {table}
                WHERE brew_lock IS NULL AND brewed IS NULL;""".format(
                    table=self.mapping_table
                ))
            total = int(cursor.fetchone()[0])
        return page, total

    def get_brewing(self, offset: int = 0, limit: int = 1000) -> (list, int):
        """Return one page of claimed-but-unbrewed digests plus their total count.

        Page rows are ``[digest, brew_lock]`` ordered by ``added``.
        """
        with self._cursor() as cursor:
            cursor.execute(
                """SELECT digest, brew_lock FROM {table}
                WHERE brew_lock IS NOT NULL AND brewed IS NULL ORDER BY added OFFSET %s LIMIT %s;""".format(
                    table=self.mapping_table
                ), (offset, limit))
            page = [[str(digest), int(brew_lock)] for digest, brew_lock in cursor.fetchall()]
            cursor.execute(
                """SELECT COUNT(*) FROM {table}
                WHERE brew_lock IS NOT NULL AND brewed IS NULL;""".format(
                    table=self.mapping_table
                ))
            total = int(cursor.fetchone()[0])
        return page, total

    def pop_queue(self, lock_timeout=300, count=10) -> list:
        """Claim up to ``count`` unbrewed digests for brewing and return them, oldest first.

        A digest is claimable when it was never claimed, or when its claim
        (``brew_lock``) is at least ``lock_timeout`` seconds old. Claiming sets
        ``brew_lock`` to the current unix time and is committed before returning.

        ``FOR UPDATE SKIP LOCKED`` gives concurrent callers disjoint sets of
        rows. Without row locks, two workers running at the same moment could
        both receive, and both brew, the same digests. Requires PostgreSQL 9.5+.
        """
        current_time = int(time.time())
        with self._cursor(commit=True) as cursor:
            cursor.execute(
                """
                WITH picked AS (
                    SELECT digest FROM {table}
                    WHERE brewed IS NULL AND (brew_lock IS NULL OR brew_lock <= %s)
                    ORDER BY added
                    LIMIT %s
                    FOR UPDATE SKIP LOCKED
                ), claimed AS (
                    UPDATE {table} AS t SET brew_lock = %s
                    FROM picked
                    WHERE t.digest = picked.digest
                    RETURNING t.digest, t.added
                )
                SELECT digest FROM claimed ORDER BY added;
                """.format(
                    table=self.mapping_table
                ),
                (current_time - lock_timeout, count, current_time))
            rows = cursor.fetchall()
        return [str(row[0]) for row in rows]

    def put_queue_brew(self, digest: str) -> None:
        """Insert ``digest`` as a new queued row, with no torrent id, lock or brew time.

        Inserting an existing digest raises psycopg2's ``UniqueViolation``
        because ``digest`` is the primary key.
        """
        self.logger.debug('Adding to queue digest: {}'.format(digest))
        current_time = int(time.time())
        with self._cursor(commit=True) as cursor:
            cursor.execute(
                """INSERT INTO {table} (digest, added, rbtorrent_id, brew_lock, brewed)
                VALUES (%s, %s, %s, %s, %s);""".format(table=self.mapping_table),
                (digest, current_time, None, None, None)
            )

    def new_mapping(self, digest: str, rbtorrent_id: str) -> None:
        """Insert an already-brewed mapping, e.g. one migrated from the legacy table.

        ``added``, ``brew_lock`` and ``brewed`` are all set to the current time.

        :raises DublicateMapping: if ``digest`` (or ``rbtorrent_id``, which has a
            unique index) already exists.
        """
        self.logger.debug('Put digest: {} -> rbtorrent_id: {} from old data'.format(
            digest, rbtorrent_id
        ))
        current_time = int(time.time())
        try:
            with self._cursor(commit=True) as cursor:
                cursor.execute(
                    """INSERT INTO {table} (digest, rbtorrent_id, added, brew_lock, brewed)
                    VALUES (%s, %s, %s, %s, %s)""".format(
                        table=self.mapping_table
                    ),
                    (digest, rbtorrent_id, current_time, current_time, current_time)
                )
        except UniqueViolation as error:
            raise DublicateMapping(digest) from error

    def update_mapping(self, digest: str, rbtorrent_id: str) -> None:
        """Store ``rbtorrent_id`` for ``digest`` and mark it brewed now.

        Does nothing if the digest is not in the table.
        """
        self.logger.debug('Update digest: {} -> rbtorrent_id: {}'.format(digest, rbtorrent_id))
        current_time = int(time.time())
        with self._cursor(commit=True) as cursor:
            cursor.execute(
                """UPDATE {table} SET rbtorrent_id = %s, brewed = %s WHERE digest = %s;""".format(
                    table=self.mapping_table
                ),
                (rbtorrent_id, current_time, digest)
            )

    def get_mds_keys(self, digests) -> dict:
        """Resolve blob digests to MDS keys via the MFS path table.

        Each digest is expanded with ``blob_path_template`` (``short_digest`` is
        its first two characters, ``full_digest`` the whole string). For every
        matching ``mfs_table`` path, the result maps the path component at
        ``blob_path_digest_index`` (presumably the digest; verify against the
        template) to the MDS ``key`` field. The value is ``None`` when the MDS
        record is marked deleted or no MDS row joins. Paths not present in
        ``mfs_table`` are omitted.

        :param digests: iterable of digest strings.
        """
        paths = [
            self.blob_path_template.format(short_digest=digest[0:2], full_digest=digest)
            for digest in digests
        ]
        if not paths:
            # Nothing to look up; skip the database round trip.
            return {}
        with self._cursor() as cursor:
            cursor.execute(
                """SELECT {mfs_table}.path, {mds_table}.mdsfileinfo::json->'key', {mds_table}.deleted
                FROM {mfs_table} LEFT JOIN {mds_table} ON {mfs_table}.key = {mds_table}.key
                WHERE {mfs_table}.path = ANY(%s)""".format(mfs_table=self.mfs_table, mds_table=self.mds_table),
                (paths,))
            rows = cursor.fetchall()
        result = {}
        for path, mds_key, deleted in rows:
            digest = path.split('/')[self.blob_path_digest_index]
            # psycopg2 returns a boolean column as a Python bool, while a text
            # column would hold 't'. Treat both as deleted.
            # NOTE(review): confirm the type of the ``deleted`` column.
            is_deleted = deleted is True or deleted == 't'
            result[digest] = None if is_deleted else mds_key
        return result

    def drop_mapping(self, digest: str) -> None:
        """Delete the row for ``digest``, if any."""
        self.logger.debug('Removing digest: {}'.format(digest))
        with self._cursor(commit=True) as cursor:
            cursor.execute(
                """DELETE FROM {table} WHERE digest = %s;""".format(
                    table=self.mapping_table
                ),
                (digest, )
            )

    @staticmethod
    def _repo_prefix(repo: str) -> str:
        """Return ``'<first path component>/'`` for ``repo``, or ``''`` if it contains no ``/``."""
        namespace, separator, _ = repo.partition('/')
        return namespace + '/' if separator else ''

    def repo_has_role(self, repo: str, role: str) -> bool:
        """Return True if any ``roles`` row grants ``role`` on ``repo`` or on its namespace prefix."""
        self.logger.debug('Checking {} role for {} on all users'.format(role, repo))
        prefix = self._repo_prefix(repo)
        with self._cursor() as cursor:
            cursor.execute(
                """SELECT roles.id FROM roles
                LEFT JOIN repos ON repos.id = roles.repoid
                WHERE roles.role = %s AND (repos.name = %s OR repos.name = %s)
                LIMIT 1;""",
                (role, repo, prefix)
            )
            return cursor.fetchone() is not None

    def check_role(self, repo: str, user: str, role: str) -> bool:
        """Return True if ``user`` holds ``role`` on ``repo`` or on its namespace prefix."""
        self.logger.debug('Checking {} role for {} on user {}'.format(role, repo, user))
        prefix = self._repo_prefix(repo)
        with self._cursor() as cursor:
            cursor.execute(
                """SELECT roles.id FROM roles
                LEFT JOIN users ON users.id = roles.userid
                LEFT JOIN repos ON repos.id = roles.repoid
                WHERE roles.role = %s AND users.name = %s AND (repos.name = %s OR repos.name = %s)
                LIMIT 1;""",
                (role, user, repo, prefix)
            )
            return cursor.fetchone() is not None