# -*- coding: utf-8 -*-
import os
import time
import logging
import requests

from OpenSSL import crypto
import ydb
from concurrent.futures import TimeoutError


# Module logger, registered under the fixed name 'crls'.
logger = logging.getLogger('crls')

# CA identifier -> list of (pyOpenSSL file type, CRL download URL) pairs.
# The file type tells crypto.load_crl how the downloaded body is encoded
# (FILETYPE_ASN1 for DER, FILETYPE_PEM for PEM). run() refreshes every CA
# listed here, and get_reworked() rejects unknown CAs or CAs with an empty list.
# The key itself is stored as the `ca` column value of each uploaded row.
Crls = {
    'yandex_ca': [
        (crypto.FILETYPE_ASN1, 'http://crls.yandex.net/certum/l4.crl'),
        (crypto.FILETYPE_PEM, 'http://crls.yandex.net/combinedcrl'),
        (crypto.FILETYPE_ASN1, 'http://crls.yandex.net/certum/ycasha2.crl')
    ],
    'yandex_internal_root_ca': [
        (crypto.FILETYPE_ASN1, 'http://crls.yandex.ru/YandexInternalRootCA/YandexInternalRootCA.crl')
    ],
    'yandex_internal_ca': [
        (crypto.FILETYPE_ASN1, 'http://crls.yandex.ru/YandexInternalCA/YandexInternalCA.crl'),
    ],
    'yandex_rcca': [
        (crypto.FILETYPE_ASN1, 'http://crls.yandex.ru/YandexRCCA/YandexRCCA.crl')
    ],
    'yandex_clca': [
        (crypto.FILETYPE_ASN1, 'http://crls.yandex.ru/YandexCLCA/YandexCLCA.crl')
    ]
}

# Bulk upsert of revoked serials. The `{}` placeholder is filled via
# str.format with the directory that holds the `crls` table. $crls is bound
# to a list of Crl objects whose attributes mirror the declared struct
# fields. REPLACE writes each row keyed by (serial, ca), overwriting any
# existing row with the same key.
UploadQuery = '''
PRAGMA TablePathPrefix("{}");

DECLARE $crls AS "List<Struct<
    serial: String,
    ca: String,
    updated_at: Uint64>>";

REPLACE INTO crls
SELECT
    serial, ca, updated_at
FROM AS_TABLE($crls);
'''


# Cleanup query: removes every row whose updated_at (seconds since the epoch,
# as written by Crl) is older than the $updatedAt cutoff, i.e. serials that
# recent runs have not re-uploaded. The `{}` placeholder is the table
# directory, as in UploadQuery.
DeleteQuery = '''
PRAGMA TablePathPrefix("{}");

DECLARE $updatedAt AS Uint64;

DELETE FROM crls WHERE updated_at < $updatedAt;'''


class Crl(object):
    """One revoked certificate, shaped as a row of the ``crls`` table.

    Instances are passed as elements of the ``$crls`` list parameter of
    UploadQuery. The attribute names match that query's struct fields,
    presumably because the ydb SDK reads them by name, so keep them in sync.
    """
    __slots__ = ('serial', 'ca', 'updated_at')

    def __init__(self, serial, ca):
        """
        :param serial: certificate serial number as ``bytes`` or text.
            Text is encoded as latin-1.
        :param ca: CA identifier (as passed by get_reworked, a key of
            ``Crls``) as ``bytes`` or text. Text is encoded as latin-1.
        """
        self.serial = self._as_bytes(serial)
        self.ca = self._as_bytes(ca)
        # Whole seconds since the epoch. delete_old() compares this column
        # against a cutoff to purge rows that were not refreshed.
        self.updated_at = int(time.time())

    @staticmethod
    def _as_bytes(value):
        """Return *value* unchanged if it is already bytes, else encode it as latin-1."""
        # pyOpenSSL's Revoked.get_serial() returns bytes, and on Python 3
        # bytes has no .encode(). Calling it unconditionally raised AttributeError.
        if isinstance(value, bytes):
            return value
        return value.encode('latin-1')


def get_reworked(ca, timeout=60):
    """Yield a :class:`Crl` for every revoked certificate in *ca*'s CRLs.

    Each URL configured for *ca* in ``Crls`` is downloaded and parsed with
    the configured pyOpenSSL file type.

    :param ca: CA identifier, a key of ``Crls``. If it is unknown or has no
        URLs, an error is logged and nothing is yielded.
    :param timeout: per-request timeout in seconds passed to requests, so a
        hung CRL server cannot block the sync forever.
    :raises requests.HTTPError: if a CRL download returns an error status.
    :raises requests.RequestException: on connection errors or timeouts.
    """
    if ca not in Crls or not Crls[ca]:
        logger.error('unknown ca: %s', ca)
        return

    logger.info('update "%s" ca', ca)
    # The context manager closes pooled connections once the generator is
    # exhausted or closed.
    with requests.Session() as session:
        for filetype, url in Crls[ca]:
            logger.info('fetch %s', url)
            resp = session.get(url, timeout=timeout)
            # Fail on HTTP errors rather than handing an error page to the
            # CRL parser and getting an opaque crypto error.
            resp.raise_for_status()
            # NOTE(review): crypto.load_crl is deprecated in recent pyOpenSSL
            # releases. Confirm the pinned version still provides it.
            crl = crypto.load_crl(filetype, resp.content)
            # get_revoked() returns None (not an empty tuple) when the CRL
            # lists no certificates.
            for rvk in crl.get_revoked() or ():
                yield Crl(rvk.get_serial(), ca)


def create_tables(session, path):
    """Create the ``crls`` table inside the *path* directory.

    All columns are optional (nullable). Rows are keyed by the
    ``(serial, ca)`` pair.
    """
    table_path = os.path.join(path, 'crls')
    column_specs = (
        ('serial', ydb.DataType.String),
        ('ca', ydb.DataType.String),
        ('updated_at', ydb.DataType.Uint64),
    )
    description = ydb.TableDescription()
    for column_name, data_type in column_specs:
        description = description.with_column(
            ydb.Column(column_name, ydb.OptionalType(data_type))
        )
    description = description.with_primary_keys('serial', 'ca')
    session.create_table(table_path, description)


def upload(session, path, crls):
    """Upsert a batch of :class:`Crl` rows into the ``crls`` table under *path*.

    The write runs in a single serializable read-write transaction that is
    committed immediately.
    """
    statement = session.prepare(UploadQuery.format(path))
    tx = session.transaction(ydb.SerializableReadWrite())
    parameters = {'$crls': crls}
    tx.execute(statement, parameters, commit_tx=True)


def delete_old(session, path, max_age=3600):
    """Delete ``crls`` rows not refreshed within the last *max_age* seconds.

    Rows whose ``updated_at`` is older than ``now - max_age`` are removed,
    in a single serializable read-write transaction that is committed
    immediately.

    :param session: ydb table session.
    :param path: directory that holds the ``crls`` table.
    :param max_age: retention window in seconds (default: one hour).
    """
    prepared_query = session.prepare(DeleteQuery.format(path))
    # $updatedAt is declared Uint64, so never pass a negative cutoff.
    cutoff = max(0, int(time.time()) - max_age)
    session.transaction(ydb.SerializableReadWrite()).execute(
        prepared_query,
        {'$updatedAt': cutoff},
        commit_tx=True,
    )


def run(endpoint, database, path, auth_token):
    """Refresh the YDB ``crls`` table from every CA configured in ``Crls``.

    Connects to YDB, makes sure the table exists under *path*, uploads every
    revoked serial in batches, then removes rows that delete_old() considers
    stale.

    :param endpoint: YDB endpoint passed to ``ydb.ConnectionParams``.
    :param database: YDB database passed to ``ydb.ConnectionParams``.
    :param path: directory holding the ``crls`` table. It is used both for
        table creation and as ``TablePathPrefix`` for upload and cleanup.
    :param auth_token: token passed to ``ydb.ConnectionParams``.
    :raises RuntimeError: if the driver is not ready within 5 seconds.
    """
    connection_params = ydb.ConnectionParams(endpoint, database=database, auth_token=auth_token)
    driver = ydb.Driver(connection_params)
    try:
        try:
            driver.wait(timeout=5)
        except TimeoutError:
            raise RuntimeError('Connect failed to YDB')
        session = driver.table_client.session().create()

        # Use `path` consistently. Passing `database` to the queries targeted
        # a different table than the one create_tables() created, unless the
        # two happened to be equal.
        create_tables(session, path)

        logger.info('start update')
        updated = _upload_all(session, path)
        logger.info('updated %d crls', updated)

        logger.info('start cleanup')
        delete_old(session, path)
        logger.info('done')
    finally:
        # Release the driver's connections and background threads even on
        # failure. Previously the driver was never stopped.
        driver.stop()


def _upload_all(session, path, batch_size=100):
    """Fetch every configured CA's CRLs, upload them in batches, and return the row count."""
    batch = []
    updated = 0
    for ca in Crls:
        for crl in get_reworked(ca):
            updated += 1
            batch.append(crl)
            if len(batch) >= batch_size:
                upload(session, path, batch)
                batch = []
    if batch:
        upload(session, path, batch)
    return updated
