# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import json
import logging
import time
from datetime import datetime

from django.conf import settings
from django.core.serializers.json import DjangoJSONEncoder
from enum import IntEnum

from common.db.mongo import ConnectionProxy
from common.utils.iterrecipes import chunker
from travel.rasp.train_api.helpers.ydb import ydb_provider
from travel.rasp.train_api.train_purchase.core.models import TrainOrder
from travel.rasp.train_api.train_purchase.utils.logs import YDB_ORDER_LOGS_TABLE_NAME, YdbLogRecord

# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Mongo connection proxy for the 'train_purchase' database; its ``order_logs``
# collection is the source the migration reads from.
database = ConnectionProxy('train_purchase')


class MigrateLogsStatus(IntEnum):
    """Values stored in an order's ``migrate_logs_status`` field by the logs migration."""

    # Set right before an order's missing log records start being written to YDB.
    STARTED = 1

    # Set once the order has been processed without an exception.
    DONE = 2

    # Set when processing the order raised an exception.
    FAILED = 3


class LogsMigration(object):
    """Copy per-order log records from the Mongo ``order_logs`` collection into YDB.

    Progress is tracked on every ``TrainOrder`` through its
    ``migrate_logs_status`` field, which receives ``MigrateLogsStatus`` values.

    The YDB session and prepared queries are created by ``init_ydb_session()``.
    ``migrate_all()`` and ``migrate_one()`` create them on demand when that
    method has not been called yet, so a forgotten initialisation can no longer
    be mistaken for a per-order failure.
    """

    # YDB session and the queries prepared on it; filled in by init_ydb_session().
    ydb_session = None
    ydb_database = settings.YDB_DATABASE
    ydb_table_name = YDB_ORDER_LOGS_TABLE_NAME
    logs_created_utc_ydb_query = None
    logs_write_ydb_query = None
    # Maximum number of orders fetched and processed in one migrate_all() step.
    objects_in_step = 500
    # Maximum number of log records sent to YDB in a single write query.
    logs_in_write_chunk = 4000
    # Pause (seconds) after every write query; presumably throttles load on YDB.
    write_pause_seconds = 0.05

    def find_orders_to_migrate(self):
        """Return orders without a ``migrate_logs_status`` whose ``reserved_to`` is 2019-01-01 or later."""
        # Logs started being written to YDB on 18.12.21.
        return TrainOrder.objects.filter(
            migrate_logs_status__exists=False, reserved_to__gte=datetime(2019, 1, 1)
        )

    def find_orders_by_migrate_status(self, status):
        """Return orders whose ``migrate_logs_status`` equals ``status``."""
        return TrainOrder.objects.filter(migrate_logs_status=status)

    def get_logs_from_mongo(self, order_uid):
        """Return the Mongo ``order_logs`` documents whose ``context.order_uid`` is ``order_uid``.

        The value is whatever ``find()`` returns; callers iterate over it.
        """
        return database.order_logs.find(filter={'context.order_uid': order_uid})

    def init_ydb_session(self):
        """Acquire a YDB session and prepare the read and write queries on it."""
        self.ydb_session = ydb_provider.get_session()
        # Selects the created_utc of every log row already stored for an order.
        self.logs_created_utc_ydb_query = self.ydb_session.prepare("""
DECLARE $order_uid AS "Utf8";
PRAGMA TablePathPrefix("{0}");

SELECT created_utc
FROM {1}
WHERE order_uid = $order_uid
ORDER BY created_utc;
""".format(self.ydb_database, self.ydb_table_name))
        # Bulk REPLACE of log rows passed as a list of structs in $logsData.
        self.logs_write_ydb_query = self.ydb_session.prepare("""
DECLARE $logsData AS "List<Struct<
    order_uid: Utf8,
    created_utc: Utf8,
    name: Utf8?,
    levelname: Utf8?,
    message: Utf8?,
    created: Utf8?,
    process: Int32?,
    context: Json,
    exception: Json?>>";

PRAGMA TablePathPrefix("{0}");
REPLACE INTO {1}
SELECT
    order_uid,
    created_utc,
    name,
    levelname,
    message,
    created,
    [process],
    context,
    exception
FROM AS_TABLE($logsData);
""".format(self.ydb_database, self.ydb_table_name))

    def _ensure_ydb_session(self):
        """Call ``init_ydb_session()`` if no YDB session has been created yet."""
        if self.ydb_session is None:
            self.init_ydb_session()

    @staticmethod
    def _to_json(value):
        """Serialize ``value`` to a JSON string, keeping non-ASCII characters as is."""
        return json.dumps(value, ensure_ascii=False, cls=DjangoJSONEncoder)

    def get_logs_created_utc_from_ydb(self, order_uid):
        """Return the ``created_utc`` values of the rows the read query yields for ``order_uid``.

        Only the first result set of the query is used.
        """
        self._ensure_ydb_session()
        result_sets = self.ydb_session.transaction().execute(
            self.logs_created_utc_ydb_query, {'$order_uid': order_uid}, commit_tx=True,
        )
        return [row.get('created_utc') for row in result_sets[0].rows]

    def write_logs_ydb(self, rows):
        """Write Mongo log documents to YDB with a single REPLACE query.

        :param rows: iterable of dict-like log documents. Each must contain
            ``created_utc``, ``name``, ``levelname``, ``created``, ``process``
            and a ``context`` holding ``order_uid``; ``message`` and
            ``exception`` are optional.
        :raises KeyError: when a required key is missing from a document.
        """
        self._ensure_ydb_session()
        records = []
        for row in rows:
            context = row.get('context')
            exception = row.get('exception')
            records.append(YdbLogRecord(
                order_uid=context['order_uid'],
                created_utc=row['created_utc'],
                name=row['name'],
                levelname=row['levelname'],
                message=row.get('message'),
                created=row['created'],
                process=row['process'],
                context=self._to_json(context) if context else '',
                exception=self._to_json(exception) if exception else None
            ))
        self.ydb_session.transaction().execute(
            self.logs_write_ydb_query, {'$logsData': records}, commit_tx=True,
        )

    def migrate_one(self, order):
        """Copy the log records of ``order`` that are missing in YDB and record the outcome.

        Nothing is written when the number of Mongo records equals the number
        of distinct ``created_utc`` values already in YDB. Otherwise the records
        whose ``created_utc`` is not in YDB yet are written in chunks of
        ``logs_in_write_chunk`` with a ``write_pause_seconds`` pause after each.

        :returns: True when the order ended up DONE; False when an exception
            occurred, in which case it is logged and the order is marked FAILED.
        """
        # Kept outside the try block on purpose: a session that cannot be
        # created is a setup problem and must not mark the order as FAILED.
        self._ensure_ydb_session()
        try:
            mongo_logs = list(self.get_logs_from_mongo(order.uid))
            stored_created_utc = set(self.get_logs_created_utc_from_ydb(order.uid))
            if len(mongo_logs) != len(stored_created_utc):
                missing_logs = [row for row in mongo_logs if row['created_utc'] not in stored_created_utc]
                if missing_logs:
                    order.update(set__migrate_logs_status=MigrateLogsStatus.STARTED)
                    log.info('Migrating. order={}'.format(order.uid))
                    for chunk in chunker(missing_logs, self.logs_in_write_chunk):
                        self.write_logs_ydb(chunk)
                        time.sleep(self.write_pause_seconds)
            order.update(set__migrate_logs_status=MigrateLogsStatus.DONE)
            return True
        except Exception:
            log.exception('Failed. order={}'.format(order.uid))
            order.update(set__migrate_logs_status=MigrateLogsStatus.FAILED)
            return False

    def migrate_all(self):
        """Migrate pending orders in steps of at most ``objects_in_step`` orders.

        Statistics are logged before and after the run. Processed orders get a
        status and therefore drop out of ``find_orders_to_migrate()``.

        NOTE: the loop ends after the first step in which no order was migrated
        successfully. That covers "nothing left to migrate", but a step in
        which every order failed also ends the run even if unprocessed orders
        remain.
        """
        # Fail fast on setup problems instead of marking orders as FAILED.
        self._ensure_ydb_session()
        log.info('Start migrate logs')
        log.info(self.get_string_stats())
        migrated_count = self.objects_in_step
        while migrated_count > 0:
            migrated_count = 0
            for order in self.find_orders_to_migrate()[:self.objects_in_step]:
                if self.migrate_one(order):
                    migrated_count += 1
            log.info('Step. migrated={0}'.format(migrated_count))
        log.info('Done migrate logs')
        log.info(self.get_string_stats())

    def get_string_stats(self):
        """Return a one-line summary with the number of pending, started, failed and done orders."""
        to_migrate = self.find_orders_to_migrate().count()
        started = self.find_orders_by_migrate_status(MigrateLogsStatus.STARTED).count()
        failed = self.find_orders_by_migrate_status(MigrateLogsStatus.FAILED).count()
        done = self.find_orders_by_migrate_status(MigrateLogsStatus.DONE).count()
        return 'to_migrate={0}, started={1}, failed={2}, done={3}'.format(to_migrate, started, failed, done)


def run():
    """Entry point: create a ``LogsMigration``, open its YDB session and migrate all pending orders."""
    logs_migration = LogsMigration()
    logs_migration.init_ydb_session()
    logs_migration.migrate_all()
