# -*- coding: utf-8 -*-
import os
import time
import logging
import json

import ydb
from concurrent.futures import TimeoutError
from http import requests_retry_session


# Module-level logger; all messages of this job go through the 'tvm_tags' logger.
logger = logging.getLogger('tvm_tags')

# Substituted into the `type` query parameter of the consumers URL in TvmIterator.
TVM_RESOURCE_ID = 47
# run() keeps only tags whose category id equals this value.
ENV_TAG_CATEGORY = 1
# Batch-size threshold that run() checks before flushing accumulated apps via upload().
BATCH_SIZE = 100

# Query template used by upload(). The `{}` placeholder is filled with the
# table path prefix through str.format; the $tvmApps parameter is a list of
# structs whose rows are written into tvm_tags with REPLACE INTO.
UploadQuery = '''
PRAGMA TablePathPrefix("{}");

DECLARE $tvmApps AS "List<Struct<
    tvm_id: Uint64,
    service_id: Uint64,
    resource_id: Uint64,
    tags: Json?,
    updated_at: Uint64>>";

REPLACE INTO tvm_tags
SELECT
    tvm_id, service_id, resource_id, tags, updated_at
FROM AS_TABLE($tvmApps);
'''


# Query template used by delete_old(). The `{}` placeholder is filled with the
# table path prefix through str.format; rows of tvm_tags whose updated_at is
# strictly below the $updatedAt parameter are deleted.
DeleteQuery = '''
PRAGMA TablePathPrefix("{}");

DECLARE $updatedAt AS Uint64;

DELETE FROM tvm_tags WHERE updated_at < $updatedAt;'''


class TvmApp(object):
    """Plain record describing one TVM application row.

    ``tags`` is kept as a JSON string (or ``None`` when no tags were given),
    and ``updated_at`` is the integer Unix time at which the record was built.
    """

    __slots__ = ('tvm_id', 'service_id', 'resource_id', 'tags', 'updated_at')

    def __init__(self, tvm_id, service_id, resource_id, tags):
        # Any falsy tags value (None, empty list, ...) is stored as None
        # rather than as an empty JSON document.
        serialized_tags = None
        if tags:
            serialized_tags = json.dumps(tags)
        created_at = int(time.time())

        self.tvm_id = tvm_id
        self.service_id = service_id
        self.resource_id = resource_id
        self.tags = serialized_tags
        self.updated_at = created_at


class TvmIterator(object):
    """Iterate over pages of granted TVM resource consumers.

    Each iteration step performs one HTTP GET against the current page URL
    and returns that page's ``results`` list. The URL of the following page
    is taken from the ``next`` field of the response; iteration stops once
    there is no next URL.

    Raises whatever ``response.raise_for_status()`` raises when a page
    request answers with an HTTP error status.
    """

    def __init__(self, oauth_token):
        self.session = requests_retry_session()
        self.session.headers.update({'Authorization': 'OAuth ' + oauth_token})
        self.url = 'https://abc-back.yandex-team.ru/api/v4/resources/consumers/?state=granted&fields=id,service,resource,tags&type=%d' % TVM_RESOURCE_ID

    def next(self):
        """Fetch the current page and return its list of records."""
        if not self.url:
            raise StopIteration()

        response = self.session.get(self.url)
        # An error response with a JSON body has neither 'next' nor 'results';
        # without this check it would look like a legitimate empty last page
        # and the caller would silently stop early.
        response.raise_for_status()
        page = response.json()

        self.url = page.get('next')
        # 'results' may be absent or null; always hand back an iterable.
        return page.get('results') or []

    # Python 3 compatibility
    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


def create_tables(session, path):
    """Create the ``tvm_tags`` table under *path* through *session*.

    All columns are declared optional; ``tvm_id`` is the primary key.
    """
    table_path = os.path.join(path, 'tvm_tags')

    # (column name, underlying data type) in declaration order.
    column_specs = (
        ('tvm_id', ydb.DataType.Uint64),
        ('service_id', ydb.DataType.Uint64),
        ('resource_id', ydb.DataType.Uint64),
        ('tags', ydb.DataType.Json),
        ('updated_at', ydb.DataType.Uint64),
    )

    description = ydb.TableDescription()
    for column_name, data_type in column_specs:
        description = description.with_column(
            ydb.Column(column_name, ydb.OptionalType(data_type))
        )
    description = description.with_primary_keys('tvm_id')

    session.create_table(table_path, description)


def upload(session, path, tvmApps):
    """Write *tvmApps* into ``tvm_tags`` in a single committed transaction.

    *path* is substituted into the query as the table path prefix and
    *tvmApps* is bound to the ``$tvmApps`` query parameter.
    """
    prepared = session.prepare(UploadQuery.format(path))
    tx = session.transaction(ydb.SerializableReadWrite())
    parameters = {'$tvmApps': tvmApps}
    tx.execute(prepared, parameters, commit_tx=True)


def delete_old(session, path):
    """Delete ``tvm_tags`` rows whose ``updated_at`` is more than an hour old.

    *path* is substituted into the query as the table path prefix. The
    cut-off bound to ``$updatedAt`` is the current Unix time minus 3600
    seconds; the deletion runs in a single committed transaction.
    """
    prepared = session.prepare(DeleteQuery.format(path))
    tx = session.transaction(ydb.SerializableReadWrite())
    cutoff = int(time.time()) - 3600
    tx.execute(prepared, {'$updatedAt': cutoff}, commit_tx=True)


def _extract_env_tags(tvm):
    """Return the tags of consumer record *tvm* that belong to ENV_TAG_CATEGORY.

    Each kept tag becomes a ``{'id': int, 'name': text}`` dict. Tags without
    a category (or category id) are dropped silently; tags of another
    category are dropped with a warning.
    """
    tags = []
    for tag in tvm.get('tags') or []:
        category = tag.get('category') or {}
        category_id = category.get('id')
        if not category_id:
            continue

        if int(category_id) != ENV_TAG_CATEGORY:
            logger.warning('skip tag w/ wrong category: %r', tvm)
            continue

        name = tag.get('name') or {}
        tags.append(dict(
            id=int(tag.get('id') or 0),
            # Keep the name as text: encoding it to bytes makes json.dumps
            # fail on Python 3 and on non-ASCII names on Python 2.
            name=name.get('en') or '',
        ))
    return tags


def _parse_tvm_app(tvm):
    """Build a TvmApp from consumer record *tvm*.

    Returns None, after logging a warning, when the record lacks a resource
    id, a service id or an external TVM id.
    """
    resource_id = tvm.get('id')
    if not resource_id:
        logger.warning('skip tvm resource w/o resource_id: %r', tvm)
        return None

    service_id = (tvm.get('service') or {}).get('id')
    if not service_id:
        logger.warning('skip tvm resource w/o service: %r', tvm)
        return None

    tvm_id = (tvm.get('resource') or {}).get('external_id')
    if not tvm_id:
        logger.warning('skip tvm resource w/o tvm_id: %r', tvm)
        return None

    return TvmApp(
        tvm_id=int(tvm_id),
        service_id=int(service_id),
        resource_id=int(resource_id),
        tags=_extract_env_tags(tvm),
    )


def run(endpoint, database, path, auth_token):
    """Synchronise TVM consumer records into the ``tvm_tags`` table.

    Connects to YDB at *endpoint* / *database*, makes sure the table exists
    under *path*, uploads every valid record returned by TvmIterator in
    batches of BATCH_SIZE, and finally removes stale rows via delete_old.
    *auth_token* is passed both to the YDB connection and to TvmIterator.

    Raises RuntimeError when the YDB driver is not ready within 5 seconds.
    Any other error (HTTP, YDB) propagates, in which case the cleanup step
    is not reached.
    """
    connection_params = ydb.ConnectionParams(endpoint, database=database, auth_token=auth_token)
    driver = ydb.Driver(connection_params)
    try:
        try:
            driver.wait(timeout=5)
        except TimeoutError:
            raise RuntimeError('Connect failed to YDB')
        session = driver.table_client.session().create()

        create_tables(session, path)
        logger.info('start update')
        batch = []
        updated = 0
        for tvms in TvmIterator(auth_token):
            for tvm in tvms:
                tvm_app = _parse_tvm_app(tvm)
                if tvm_app is None:
                    continue

                batch.append(tvm_app)
                updated += 1
                if len(batch) >= BATCH_SIZE:
                    # Use the same prefix the table was created under, so the
                    # queries address the table create_tables() made.
                    upload(session, path, batch)
                    del batch[:]

        if batch:
            upload(session, path, batch)

        logger.info('updated %d tmv resources', updated)

        logger.info('start cleanup')
        delete_old(session, path)
        logger.info('done')
    finally:
        # Release the driver's background resources on success and failure.
        driver.stop()
