# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals
"""
To run in a dev environment (replace {env-type} with the target environment type):

QLOUD_PROJECT=1 QLOUD_DATACENTER=sas YENV_TYPE={env-type} \
DJANGO_SETTINGS_MODULE="travel.rasp.train_api.docker.local_settings" \
Y_PYTHON_ENTRY_POINT="travel.rasp.train_api.app:script" bin/app/app archive_order_logs

"""

import os
import logging
import uuid

import yt.wrapper
from concurrent.futures import ProcessPoolExecutor
from django.conf import settings
from enum import IntEnum
from pymongo import ReturnDocument

from common.db.mongo import ConnectionProxy
from travel.rasp.train_api.helpers.ydb import ydb_provider
from travel.rasp.train_api.train_purchase.utils.logs import YDB_ORDER_LOGS_TABLE_NAME

log = logging.getLogger(__name__)

# YT table template; migrate_all formats it with the worker index, so each
# worker process appends to its own chunk table.
YT_LOGS_PATH = '//home/rasp/logs/train-order-logs/order_logs_chunk_{}'
# Number of worker processes (and therefore of YT chunk tables); can be
# overridden with the LOGS_MAX_WORKERS environment variable.
MAX_WORKERS = int(os.getenv('LOGS_MAX_WORKERS', 2))
# Page size used when reading an order's logs from YDB.
YDB_MAX_ROWS = 1000
# Schema of the YT chunk tables. The column names are also the attribute names
# copied from every YDB result row when building YT rows.
SCHEMA = [
    {'name': 'order_uid', 'type': 'string'},
    {'name': 'created_utc', 'type': 'string'},
    {'name': 'name', 'type': 'string'},
    {'name': 'levelname', 'type': 'string'},
    {'name': 'message', 'type': 'string'},
    {'name': 'created', 'type': 'string'},
    {'name': 'process', 'type': 'int64'},
    {'name': 'context', 'type': 'string'},
    {'name': 'exception', 'type': 'string'},
]

# First page of an order's logs, ordered by created_utc.
# Format args: {0} database used as TablePathPrefix, {1} logs table name, {2} page size.
YDB_INITIAL_LOGS_QUERY = '''
DECLARE $order_uid AS "Utf8";
PRAGMA TablePathPrefix("{0}");

SELECT *
FROM {1}
WHERE order_uid = $order_uid
ORDER BY created_utc LIMIT {2};
'''

# Following pages: keyset pagination on created_utc strictly after the last row seen.
# Same format args as the initial query.
# NOTE(review): rows sharing one created_utc value across a page boundary would be
# skipped by the strict ">" unless created_utc is unique per order (e.g. part of the
# primary key) — verify against the table schema.
YDB_PAGINATED_LOGS_QUERY = '''
DECLARE $order_uid AS "Utf8";
DECLARE $created_utc AS "Utf8";
PRAGMA TablePathPrefix("{0}");

SELECT *
FROM {1}
WHERE
    order_uid = $order_uid AND
    created_utc > $created_utc
ORDER BY created_utc LIMIT {2};
'''

# Deletes all logs of one order.
# Format args: {0} database used as TablePathPrefix, {1} logs table name.
YDB_CLEAR_LOGS_QUERY = '''
DECLARE $order_uid AS "Utf8";
PRAGMA TablePathPrefix("{0}");
DELETE FROM {1}
WHERE order_uid = $order_uid;
'''


class YdbContext(object):
    """Holds one lazily acquired YDB session and the statements prepared on it.

    The session is requested from *provider* (via ``get_session()``) the first
    time :attr:`session` is read and reused afterwards. The three query
    attributes stay ``None`` until :meth:`prepare_queries` is called.
    """

    def __init__(self, provider):
        self._provider = provider
        self._session = None
        # Prepared statements, filled in by prepare_queries().
        self.initial_select_query = self.paginated_select_query = self.clear_query = None

    def prepare_queries(self):
        """Prepare the first-page select, next-page select and delete statements."""
        session = self.session
        database = settings.YDB_DATABASE
        table = YDB_ORDER_LOGS_TABLE_NAME
        self.initial_select_query = session.prepare(
            YDB_INITIAL_LOGS_QUERY.format(database, table, YDB_MAX_ROWS))
        self.paginated_select_query = session.prepare(
            YDB_PAGINATED_LOGS_QUERY.format(database, table, YDB_MAX_ROWS))
        self.clear_query = session.prepare(
            YDB_CLEAR_LOGS_QUERY.format(database, table))

    @property
    def session(self):
        """Return the cached session, asking the provider for one on first use."""
        if self._session is not None:
            return self._session
        self._session = self._provider.get_session()
        return self._session


class LogMigrationStatus(IntEnum):
    """Log-archiving state of an order, written to its ``migrate_logs_status`` field.

    The integer values are stored in Mongo documents, so they must not be renumbered.
    """
    IN_YDB = 2  # NOTE(review): presumably "logs still only in YDB"; not set anywhere in this module
    FAILED = 5  # set by migrate_order when archiving raised; the order stays selectable (not IN_YT)
    IN_YT = 6  # set by migrate_order once the logs were written to YT; such orders are skipped


def migrate_order(order, orders_collection, ydb_context, yt_logs_path):
    """Copy one order's logs from YDB into a YT table, then drop them from YDB.

    Steps: append the order's log rows to *yt_logs_path*, mark the order
    ``IN_YT`` (clearing its ``migrate_process_uid`` claim), then delete the rows
    from YDB. If any of these steps raises, the error is logged and the order is
    marked ``FAILED`` (claim cleared as well).

    NOTE(review): if the YDB cleanup fails after the order was marked IN_YT, the
    status is overwritten with FAILED although the rows already sit in YT; a
    later retry appends them again — confirm duplicates are acceptable.

    :returns: True if the order was fully migrated, False otherwise.
    """
    order_uid = order['uid']
    log.info('Migrating logs for order %s', order_uid)

    def set_status(status):
        # Store the new status and release this worker's claim on the order.
        orders_collection.update_one(
            {'uid': order_uid},
            {'$set': {'migrate_logs_status': status, 'migrate_process_uid': None}}
        )

    try:
        destination = yt.wrapper.TablePath(yt_logs_path, append=True)
        columns = [column['name'] for column in SCHEMA]
        archived_rows = [
            dict((column, getattr(log_row, column)) for column in columns)
            for log_row in get_logs_to_archive(order_uid, ydb_context)
        ]
        yt.wrapper.write_table(destination, archived_rows)
        set_status(LogMigrationStatus.IN_YT)
        clean_logs(order_uid, ydb_context)
        log.info('Migrating logs for order %s. Completed.', order_uid)
    except Exception:
        log.exception('Migrating logs for order %s. Failed.', order_uid)
        set_status(LogMigrationStatus.FAILED)
        return False
    return True


def get_orders_to_archive(orders_collection, process_uuid):
    """Yield orders whose logs are not yet in YT, claiming each one for *process_uuid*.

    Each order is claimed with an atomic ``find_one_and_update`` that requires
    ``migrate_process_uid`` to be empty and sets it to *process_uuid*, so two
    workers cannot claim the same order. Orders are walked in ascending ``uid``
    order, and every lookup after the first resumes strictly after the last
    claimed uid. An order that migrate_order reset to FAILED (claim cleared) is
    therefore not re-claimed by the same generator. Without this, a
    persistently failing order would be picked up again and again forever.
    Such orders are retried on the next run.

    :param orders_collection: Mongo collection of train orders.
    :param process_uuid: unique id of the calling worker process.
    """
    # archiving orders up to the point when we switch to Orchestrator
    last_uid = None
    while True:
        query = {'migrate_logs_status': {'$ne': LogMigrationStatus.IN_YT}, 'migrate_process_uid': None}
        if last_uid is not None:
            # Keyset continuation: never go back to orders this worker already handled.
            query['uid'] = {'$gt': last_uid}
        order = orders_collection.find_one_and_update(
            query,
            {'$set': {'migrate_process_uid': process_uuid}},
            sort=[('uid', 1)],
            return_document=ReturnDocument.AFTER
        )
        if order is None:
            break
        last_uid = order['uid']
        log.info('order: %s', last_uid)
        yield order


def get_logs_to_archive(order_uid, ydb_context):
    """Yield all YDB log rows of *order_uid*, ordered by ``created_utc``.

    Rows are read in pages of ``YDB_MAX_ROWS``. The first page uses the initial
    query. Each following page uses the paginated query with
    ``created_utc > <last created_utc seen>``. Every page is read in its own
    transaction, committed immediately. Iteration stops after the first page
    shorter than ``YDB_MAX_ROWS``.

    NOTE(review): rows sharing one ``created_utc`` across a page boundary would
    be skipped by the strict ``>`` unless ``created_utc`` is unique per order.
    Skipped rows would later be deleted by clean_logs without reaching YT, so
    verify against the table's primary key.
    """
    last_created_utc = None
    while True:
        if last_created_utc is None:
            query = ydb_context.initial_select_query
            query_params = {'$order_uid': order_uid}
        else:
            query = ydb_context.paginated_select_query
            query_params = {'$order_uid': order_uid, '$created_utc': last_created_utc}
        result_sets = ydb_context.session.transaction().execute(query, query_params, commit_tx=True)

        counter = 0
        for row in result_sets[0].rows:
            last_created_utc = row.created_utc
            counter += 1
            yield row

        if counter < YDB_MAX_ROWS:
            # A short page means there is nothing left. Use ``return``: raising
            # StopIteration inside a generator is turned into RuntimeError on
            # Python 3.7+ (PEP 479), which made every migration fail.
            return


def clean_logs(order_uid, ydb_context):
    """Delete every YDB log row of *order_uid* in a single committed transaction."""
    params = {'$order_uid': order_uid}
    transaction = ydb_context.session.transaction()
    transaction.execute(ydb_context.clear_query, params, commit_tx=True)


def migrate_all(n):
    """Archive logs of every pending order into the YT chunk table number *n*.

    Meant to run as one worker of migrate_concurrent. Each worker claims orders
    under a fresh process uuid and appends only to its own table
    (``YT_LOGS_PATH`` formatted with *n*).

    :param n: worker index substituted into ``YT_LOGS_PATH``.
    """
    yt.wrapper.config.set_proxy('hahn')
    yt_logs_path = YT_LOGS_PATH.format(n)
    orders_collection = ConnectionProxy('train_purchase').train_order
    process_uuid = uuid.uuid4().hex
    # ignore_existing keeps re-runs working when the table is already there,
    # while real failures (permissions, network, invalid schema) now surface
    # instead of being logged as "table already exists" and swallowed.
    yt.wrapper.create_table(yt_logs_path, attributes={'schema': SCHEMA}, ignore_existing=True)
    ydb_context = YdbContext(ydb_provider)
    ydb_context.prepare_queries()
    log.info('Start logs migration')

    migrated_count = 0
    for order in get_orders_to_archive(orders_collection, process_uuid):
        if migrate_order(order, orders_collection, ydb_context, yt_logs_path):
            migrated_count += 1
    log.info('%s orders migrated', migrated_count)
    log.info('Migration complete')


def migrate_concurrent():
    """Run ``MAX_WORKERS`` migrate_all workers in separate processes and wait for all of them.

    Worker ``n`` is given index ``n``. The executor is used as a context manager,
    so the pool is shut down and the remaining workers are awaited even when
    ``result()`` re-raises an exception that escaped a worker.
    """
    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(migrate_all, n) for n in range(MAX_WORKERS)]
        for future in futures:
            # Re-raises any exception that escaped the worker process.
            future.result()


def run():
    """Command entry point: archive order logs from YDB to YT with a pool of worker processes."""
    return migrate_concurrent()
