#!/usr/bin/env python

import sys
import traceback
import time
import ydb
from concurrent.futures import TimeoutError


def build_records(items, ctype, version):
    """Lazily yield one upsert-ready record dict for each element of *items*.

    Every record carries the shared ``type`` (*ctype*) and ``version`` values,
    followed by the element's own ``id``, ``data`` and ``metadata`` entries.
    An element missing any of those keys raises KeyError when it is reached.
    """
    for entry in items:
        # Shared columns first, then per-item columns, in the original key order.
        record = dict(type=ctype, version=version)
        for field in ('id', 'data', 'metadata'):
            record[field] = entry[field]
        yield record


def split_to_chunks(iterable, length):
    """Lazily group *iterable* into lists of *length* elements.

    The final list may be shorter; no empty list is ever produced. Each
    yielded chunk is a new list object. If *length* is <= 0, every element
    is yielded as its own single-element chunk.
    """
    batch = []
    for element in iterable:
        batch.append(element)
        if len(batch) < length:
            continue
        yield batch
        # Start a fresh list so the chunk just yielded is never mutated.
        batch = []
    if len(batch) > 0:
        yield batch


class Retrier(object):
    """Run actions against a YDB session, reconnecting and retrying on transient errors.

    Usable as a context manager: entering opens a driver and a session,
    exiting stops the driver.
    """

    def __init__(self, config, max_retries=10):
        """
        :param config: mapping providing 'endpoint', 'database' and
            'auth_token' (read by connect()).
        :param max_retries: number of retries allowed after the first failed
            attempt before the last error is re-raised.
        """
        self.config = config
        self.max_retries = max_retries
        # Initialized here so disconnect() and run() are safe before connect().
        self.driver = None
        self.session = None

    def connect(self):
        """Create a driver from ``self.config``, wait for it, and open a session."""
        connection_params = ydb.DriverConfig(
            self.config['endpoint'],
            self.config['database'],
            auth_token=self.config['auth_token'])
        self.driver = ydb.Driver(connection_params)
        self.driver.wait(timeout=5)
        self.session = self.driver.table_client.session().create()

    def disconnect(self):
        """Stop the driver (if any) and forget the driver and session."""
        if self.driver is not None:
            self.driver.stop()
        self.driver = None
        self.session = None

    def run(self, action):
        """Call ``action(session)`` until it succeeds and return its result.

        On a transient error (timeouts, Aborted, Unavailable, Overloaded,
        ConnectionError) the traceback is written to stderr, the connection is
        dropped, and the action is retried on a fresh connection. After
        ``max_retries`` retries, the last error is re-raised. Any other
        exception propagates immediately.
        """
        retries = 0
        while True:
            try:
                if self.driver is None:
                    self.connect()
                return action(self.session)
            except (TimeoutError, ydb.Aborted, ydb.Unavailable, ydb.Overloaded,
                    ydb.Timeout, ydb.ConnectionError):
                # `print >>sys.stderr` is Python 2-only syntax; write() works on both.
                sys.stderr.write(traceback.format_exc() + "\n")
                self.disconnect()
                retries += 1
                if retries > self.max_retries:
                    raise

    def __enter__(self):
        try:
            self.connect()
        except Exception:
            # __exit__ is not called when __enter__ raises, so release a
            # partially-created driver here before propagating.
            self.disconnect()
            raise
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.disconnect()


def upload_data(config, items, ctype, version, items_per_transaction=10):
    """Upsert *items* into the items table in batches, then record the version.

    Records produced by build_records() are upserted into
    ``config['table_items']``, at most ``items_per_transaction`` rows per
    transaction. Progress is reported on stderr after each batch. Once all
    items are stored, a single (type, current Unix time, version) row is
    upserted into ``config['table_versions']``. Every transaction goes
    through Retrier.run, so transient YDB errors reconnect and retry that
    transaction.

    :param config: connection settings for Retrier plus the 'table_items'
        and 'table_versions' table names.
    :param items: iterable of mappings with 'id', 'data' and 'metadata' keys.
    :param ctype: value written to the 'type' column.
    :param version: value written to the 'version' column.
    :param items_per_transaction: maximum number of rows per transaction.
    """
    with Retrier(config) as retrier:
        query = """
        declare $type as String;
        declare $id as String;
        declare $data as String;
        declare $metadata as String;
        declare $version as String;
        upsert into [{table}] (type, id, data, metadata, version) values ($type, $id, $data, $metadata, $version)
        """.format(table=config['table_items'])
        items_commited = 0
        for chunk in split_to_chunks(build_records(items, ctype, version), items_per_transaction):
            # Bind the current chunk as a default argument so the callback can
            # never observe a later loop value, even if the call is deferred.
            def commit_chunk(session, chunk=chunk):
                prepared_query = session.prepare(query)
                tx = session.transaction().begin()
                for item in chunk:
                    # dict.items() works on Python 2 and 3; iteritems() is 2-only.
                    parameters = {'$' + k: v for k, v in item.items()}
                    tx.execute(prepared_query, parameters)
                tx.commit()
            retrier.run(commit_chunk)
            items_commited += len(chunk)
            # `print >>sys.stderr` is Python 2-only syntax; write() works on both.
            sys.stderr.write("%d items commited\n" % items_commited)

        def commit_version_data(session):
            version_query = """
            declare $type as String;
            declare $timestamp as Int32;
            declare $version as String;
            upsert into [{table}] (type, timestamp, version) values ($type, $timestamp, $version)
            """.format(table=config['table_versions'])
            prepared_query = session.prepare(version_query)
            parameters = {
                "$type": ctype,
                # NOTE(review): $timestamp is declared Int32; int(time.time())
                # exceeds the Int32 range in January 2038.
                "$timestamp": int(time.time()),
                "$version": version
            }
            tx = session.transaction().begin()
            tx.execute(prepared_query, parameters)
            tx.commit()
        retrier.run(commit_version_data)
