# -*- coding: utf-8 -*-
import enum

import mpfs.engine.process

from mpfs.common.util import chunks2
from mpfs.dao.shard_endpoint import ShardEndpoint, ShardType
from mpfs.metastorage.postgres.query_executer import PGQueryExecuter
from mpfs.dao.session import Session


# Loggers shared by every migration in this module: ``error_log`` receives
# failures (logged with tracebacks), ``default_log`` receives progress messages.
error_log, default_log = (
    mpfs.engine.process.get_error_log(),
    mpfs.engine.process.get_default_log(),
)


class MigrationMode(enum.Enum):
    """How a collection migration is executed.

    Members can be looked up by their string value, e.g.
    ``MigrationMode('skip-switch') is MigrationMode.SKIP_SHARD_SWITCH``.
    """

    # Validation run; the per-user migration skips its post-copy count check
    # in this mode. NOTE(review): exact semantics depend on the shard clients.
    VALIDATE = 'validate'
    # NOTE(review): presumably copies data without switching the user's shard.
    SKIP_SHARD_SWITCH = 'skip-switch'
    # Full migration; the default mode of the migration classes in this module.
    REAL = 'real'


def get_current_user_shard_endpoint(uid):
    """Return the ShardEndpoint where ``uid``'s data currently lives.

    A user that usrctl reports as being in Postgres gets a POSTGRES endpoint
    for the shard id resolved by PGQueryExecuter. Any other user gets a MONGO
    endpoint for the ``'shard'`` field of their usrctl info.
    """
    if mpfs.engine.process.usrctl().is_user_in_postgres(uid):
        pg_shard_id = PGQueryExecuter().get_shard_id(uid)
        return ShardEndpoint(ShardType.POSTGRES, pg_shard_id)

    mongo_shard_id = mpfs.engine.process.usrctl().info(uid)['shard']
    return ShardEndpoint(ShardType.MONGO, mongo_shard_id)


def reset_caches():
    """Reset the PGQueryExecuter cache, then clear the Session cache.

    Called at the start of every migration run. NOTE(review): presumably done
    so that shard lookups are not served from stale cached state.
    """
    executer = PGQueryExecuter()
    executer.reset_cache()
    Session.clear_cache()


class BaseCollectionMigration(object):
    """Base class for moving one user's records of a collection between shards.

    Subclasses must set ``dao_cls`` and ``shard_client_cls``. The generic flow
    in :meth:`run` clears the destination, copies the records through
    ``shard_client_cls`` instances and compares the per-shard record counts.
    """

    # Migration name; not referenced by this base class.
    name = None
    # DAO class, instantiated once per migration object in ``__init__``.
    dao_cls = None
    # Shard client class, built as ``shard_client_cls(dao, shard_endpoint, mode=mode)``.
    shard_client_cls = None

    def __init__(self, copy_chunk_size=20, mode=MigrationMode.REAL):
        """
        :param copy_chunk_size: chunk size passed to ``chunks2`` when copying.
        :param mode: a :class:`MigrationMode`; forwarded to every shard client.
        """
        self._copy_chunk_size = copy_chunk_size
        self.dao = self.dao_cls()
        self.mode = mode

    def run(self, uid, dst_shard_endpoint):
        """Migrate ``uid``'s records of this collection to ``dst_shard_endpoint``.

        :return: True on success. False if source and destination are the
            same endpoint, if the post-migration count check fails (the check
            is skipped in ``MigrationMode.VALIDATE``), or if any exception is
            raised. Exceptions are logged to ``error_log`` and not propagated.
        """
        reset_caches()
        migration_name = self.__class__.__name__

        try:
            src_endpoint = self.get_src_shard_endpoint(uid)
            if src_endpoint == dst_shard_endpoint:
                raise RuntimeError('Same source and destination is not allowed')

            default_log.info('[Migration]: Migration "%s" from shard "%s" to "%s"',
                             migration_name, src_endpoint.get_name(), dst_shard_endpoint.get_name())

            # Wipe the destination first. NOTE(review): presumably so that
            # leftovers of an earlier attempt don't skew the count check.
            self.clear_destination_data(uid, dst_shard_endpoint)
            self._migration(uid, dst_shard_endpoint)

            must_check_count = self.mode != MigrationMode.VALIDATE
            if must_check_count and not self.check_migrated_count(uid, src_endpoint, dst_shard_endpoint):
                default_log.info('[Migration]: Migration `%s` failed because migrated count doesn\'t match',
                                 migration_name)
                if src_endpoint.is_mongo():
                    self._print_current_status_of_db(uid, dst_shard_endpoint)
                return False
        except Exception as exc:
            error_log.error('[Migration]: Error occurred (uid "%s"): %s', uid, exc, exc_info=True)
            return False

        return True

    def transfer_data(self, uid, src_shard_client, dst_shard_client):
        """Read ``uid``'s records from the source client and insert them chunk by chunk."""
        records = src_shard_client.find_by_uid(uid)
        for batch in chunks2(records, self._copy_chunk_size):
            dst_shard_client.insert_chunk(batch)

    def _migration(self, uid, dst_shard_endpoint):
        """Copy ``uid``'s records from their current shard to ``dst_shard_endpoint``.

        The source endpoint is resolved again here instead of being passed in.
        The source client gets ``fix_before_migration`` and the destination
        client gets ``prepare_for_insert`` before the data is transferred.
        """
        client_cls = self.shard_client_cls
        src_client = client_cls(self.dao, self.get_src_shard_endpoint(uid), mode=self.mode)
        dst_client = client_cls(self.dao, dst_shard_endpoint, mode=self.mode)

        default_log.info('[Migration]: Migrating %s records', src_client.count_by_uid(uid))

        src_client.fix_before_migration(uid)
        dst_client.prepare_for_insert(uid)
        self.transfer_data(uid, src_client, dst_client)

    @staticmethod
    def get_src_shard_endpoint(uid):
        """Return the endpoint of the shard currently holding ``uid``'s data."""
        return get_current_user_shard_endpoint(uid)

    @classmethod
    def get_pg_tables(cls):
        """Return a one-element list with the Postgres table object of ``dao_cls``'s item class."""
        item_cls = cls.dao_cls.dao_item_cls
        return [item_cls.postgres_table_obj]

    def check_parent_is_alive(self):
        """Placeholder liveness check; currently always returns True."""
        # TODO: either come up with a real implementation or remove this method.
        return True

    def clear_destination_data(self, uid, shard_endpoint):
        """Remove ``uid``'s records of this collection from ``shard_endpoint``."""
        self.shard_client_cls(self.dao, shard_endpoint, mode=self.mode).remove_by_uid(uid)

    def check_migrated_count(self, uid, src_shard_endpoint, dst_shard_endpoint):
        """Log ``uid``'s record counts on both shards and return whether they are equal."""
        client_cls = self.shard_client_cls
        src_client = client_cls(self.dao, src_shard_endpoint, mode=self.mode)
        dst_client = client_cls(self.dao, dst_shard_endpoint, mode=self.mode)

        src_total, dst_total = src_client.count_by_uid(uid), dst_client.count_by_uid(uid)

        default_log.info('[Migration]: Migration count for `%s` (src shard = %d, dst shard = %d)',
                         self.__class__.__name__, src_total, dst_total)
        return src_total == dst_total

    def _print_current_status_of_db(self, uid, dst_shard_endpoint):
        """Hook that ``run`` calls after a failed count check when the source is Mongo; a no-op here."""


class BaseCommonCollectionMigration(BaseCollectionMigration):
    """Migrate a whole common (non per-user) collection from Mongo to Postgres.

    Instead of per-user shard clients, this works on the entire collection
    through ``self.dao.get_mongo_impl()`` and ``self.dao.get_pg_impl()``.
    """

    def __init__(self, copy_chunk_size=1000, mode=MigrationMode.REAL):
        super(BaseCommonCollectionMigration, self).__init__(copy_chunk_size, mode)

    def run(self, skip=0, limit=0, validate=True, **kwargs):
        """Copy the common collection from Mongo to Postgres and verify counts.

        Data is copied only in ``MigrationMode.REAL``. In the other modes this
        method only compares record counts.

        :param skip: forwarded as ``skip`` to the Mongo ``find``.
        :param limit: forwarded as ``limit`` to the Mongo ``find``
            (NOTE(review): presumably 0 means "no limit"; confirm in the DAO).
        :param validate: in REAL mode, require the Postgres table to be empty
            before copying; in any mode, require Mongo and Postgres counts to
            match afterwards. ``MigrationMode.VALIDATE`` always checks counts.
        :param kwargs: accepted and ignored.
        :return: True on success, False on any error (logged, not raised).
        """
        reset_caches()

        try:
            default_log.info('[CommonMigration]: Migration "%s" from common mongo to pg', self.__class__.__name__)

            mongo_before_count = self.dao.get_mongo_impl().count()

            # Explicit checks instead of ``assert``: asserts are stripped under
            # ``python -O``, which would let a broken migration report success.
            if self.mode == MigrationMode.REAL:
                if validate:
                    pg_before_count = self.dao.get_pg_impl().count()
                    if pg_before_count != 0:
                        raise RuntimeError('Destination pg table is not empty (%s records)' % pg_before_count)
                default_log.info('[CommonMigration]: Migrating %s records', mongo_before_count)

                cursor = self.dao.get_mongo_impl().find(skip=skip, limit=limit)
                for chunk in chunks2(cursor, self._copy_chunk_size):
                    self.dao.get_pg_impl().insert(chunk)

            mongo_after_count = self.dao.get_mongo_impl().count()
            if mongo_before_count != mongo_after_count:
                raise RuntimeError('Mongo record count changed during migration (before = %s, after = %s)'
                                   % (mongo_before_count, mongo_after_count))

            pg_after_count = self.dao.get_pg_impl().count()
            if (validate or self.mode == MigrationMode.VALIDATE) and mongo_after_count != pg_after_count:
                raise RuntimeError('Migrated count mismatch (mongo = %s, pg = %s)'
                                   % (mongo_after_count, pg_after_count))

            default_log.info('[CommonMigration]: Migration count for `%s` (src shard = %d, dst shard = %d)',
                             self.__class__.__name__, mongo_before_count, pg_after_count)
        except Exception as exc:
            error_log.error('[CommonMigration]: Error occurred: %s', exc, exc_info=True)
            return False

        return True
