# -*- coding: utf-8 -*-
import os
import time
import logging

import ydb
from concurrent.futures import TimeoutError
from http import requests_retry_session


logger = logging.getLogger('ssh-keys')


# YQL template used by upload(): writes a batch of keys into ``ssh_keys``
# via REPLACE INTO. The ``{}`` placeholder receives the table path prefix
# through str.format(); the rows themselves are passed as the ``$staffKeys``
# query parameter (a list of structs with fingerprint, login and updated_at).
UploadQuery = '''
PRAGMA TablePathPrefix("{}");

DECLARE $staffKeys AS "List<Struct<
    fingerprint: String,
    login: String,
    updated_at: Uint64>>";

REPLACE INTO ssh_keys
SELECT
    fingerprint, login, updated_at
FROM AS_TABLE($staffKeys);
'''


# YQL template used by delete_old(): removes every ``ssh_keys`` row whose
# ``updated_at`` is strictly less than the ``$updatedAt`` query parameter.
# The ``{}`` placeholder receives the table path prefix through str.format().
DeleteQuery = '''
PRAGMA TablePathPrefix("{}");

DECLARE $updatedAt AS Uint64;

DELETE FROM ssh_keys WHERE updated_at < $updatedAt;'''


class SshKey(object):
    """One (fingerprint, login) pair stamped with its creation time.

    Both text fields are stored latin-1 encoded; ``updated_at`` holds the
    whole-second Unix time at which the instance was constructed.
    """

    __slots__ = ('fingerprint', 'login', 'updated_at')

    def __init__(self, fingerprint, login):
        # Stored as byte strings -- presumably to match the String columns
        # declared for ssh_keys; confirm against the YDB SDK if this changes.
        wire_encoding = 'latin-1'
        self.fingerprint = fingerprint.encode(wire_encoding)
        self.login = login.encode(wire_encoding)

        created = time.time()
        self.updated_at = int(created)


class KeysIterator(object):
    """Iterate over pages of SSH keys fetched from the Staff API.

    Every iteration step performs one HTTP GET against the current page URL
    and returns a list of ``SshKey`` objects built from that page (the list
    may be empty). Iteration stops once a response carries no ``links.next``
    URL.
    """

    def __init__(self, oauth_token):
        """Create the HTTP session and point it at the first page.

        :param oauth_token: token sent as ``Authorization: OAuth <token>``.
        """
        self.session = requests_retry_session()
        self.session.headers.update({'Authorization': 'OAuth ' + oauth_token})
        self.url = 'https://staff-api.yandex-team.ru/v3/persons?_fields=keys.fingerprint,login&official.is_dismissed=false'

    def next(self):
        """Fetch the next page and return its keys as a list of ``SshKey``.

        Raises ``StopIteration`` when there is no further page, and whatever
        ``raise_for_status()`` raises when the server answers with an HTTP
        error status.
        """
        if not self.url:
            raise StopIteration()

        response = self.session.get(self.url)
        # Fail loudly on HTTP errors: an error body is usually valid JSON
        # without 'result'/'links', which would otherwise look like a normal,
        # empty end of data. Assumes a requests-style response object.
        response.raise_for_status()
        resp = response.json()

        # ``or`` fallbacks tolerate fields that are present but JSON null.
        self.url = (resp.get('links') or {}).get('next')
        keys = []
        for res in resp.get('result') or []:
            login = res.get('login')
            if not login:
                continue
            for key in res.get('keys') or []:
                fp = key.get('fingerprint')
                if not fp:
                    continue
                keys.append(SshKey(fp, login))
        return keys

    # Python 3 compatibility
    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


def create_tables(session, path):
    """Create the ``ssh_keys`` table at ``<path>/ssh_keys`` via ``session``.

    The table has optional ``fingerprint`` and ``login`` String columns and
    an optional ``updated_at`` Uint64 column; the primary key is
    (fingerprint, login).
    """
    table_path = os.path.join(path, 'ssh_keys')

    column_types = (
        ('fingerprint', ydb.DataType.String),
        ('login', ydb.DataType.String),
        ('updated_at', ydb.DataType.Uint64),
    )
    description = ydb.TableDescription()
    for column_name, column_type in column_types:
        column = ydb.Column(column_name, ydb.OptionalType(column_type))
        description = description.with_column(column)
    description = description.with_primary_keys('fingerprint', 'login')

    session.create_table(table_path, description)


def upload(session, path, keys):
    """Write ``keys`` into the ``ssh_keys`` table under ``path``.

    Prepares ``UploadQuery`` with ``path`` as the table path prefix and
    executes it in a single serializable read-write transaction that is
    committed immediately. ``keys`` is passed as the ``$staffKeys`` parameter.
    """
    statement = session.prepare(UploadQuery.format(path))
    tx = session.transaction(ydb.SerializableReadWrite())
    parameters = {'$staffKeys': keys}
    tx.execute(statement, parameters, commit_tx=True)


def delete_old(session, path):
    """Delete ``ssh_keys`` rows under ``path`` not refreshed for 24 hours.

    Prepares ``DeleteQuery`` with ``path`` as the table path prefix and runs
    it in one immediately committed serializable read-write transaction,
    using "now minus 86400 seconds" as the ``$updatedAt`` cutoff.
    """
    statement = session.prepare(DeleteQuery.format(path))
    tx = session.transaction(ydb.SerializableReadWrite())
    seconds_per_day = 86400
    cutoff = int(time.time()) - seconds_per_day
    tx.execute(statement, {'$updatedAt': cutoff}, commit_tx=True)


def run(endpoint, database, path, auth_token):
    """Sync SSH keys from the Staff API into YDB and purge stale rows.

    Connects to YDB, ensures the ``ssh_keys`` table exists under ``path``,
    uploads every key returned by ``KeysIterator`` in batches, then deletes
    rows that were not refreshed within the window used by ``delete_old``.

    :param endpoint: YDB endpoint passed to ``ydb.ConnectionParams``.
    :param database: YDB database passed to ``ydb.ConnectionParams``.
    :param path: directory holding the ``ssh_keys`` table; used both for
        table creation and as the table path prefix of every query.
    :param auth_token: token used for the YDB connection and also as the
        OAuth token for the Staff API.
    :raises RuntimeError: if the YDB driver is not ready within 5 seconds.
    """
    connection_params = ydb.ConnectionParams(endpoint, database=database, auth_token=auth_token)
    driver = None
    try:
        try:
            driver = ydb.Driver(connection_params)
            driver.wait(timeout=5)
        except TimeoutError:
            raise RuntimeError('Connect failed to YDB')
        session = driver.table_client.session().create()

        create_tables(session, path)
        logger.info('start update')
        batch = []
        updated = 0
        for keys in KeysIterator(auth_token):
            updated += len(keys)
            batch.extend(keys)
            if len(batch) > 100:
                # Query against ``path`` -- the same prefix create_tables()
                # used -- so uploads hit the table that was just ensured.
                upload(session, path, batch)
                del batch[:]

        # Flush whatever is left below the batch threshold.
        if batch:
            upload(session, path, batch)

        logger.info('updated %d keys', updated)

        logger.info('start cleanup')
        delete_old(session, path)
        logger.info('done')
    finally:
        # Always release the driver's connections, including on failure.
        if driver is not None:
            driver.stop()
