import argparse
import logging
import logging.handlers
import os
import time

from .component import Component
from .db import Database
from .deblock import Deblock
from .model import Log, StorageRevisionCollection
from .rpc import RPC
from .utils import human_secs
from .web import Web
from .job import JobManager
from .evoq.manager import EvoqManager
from .hotbackup import HotBackup
from .stats import Stats

from infra.qyp.qdm.src.lib import yp_client
from infra.qyp.qdm.src.lib import vmproxy_client
from infra.qyp.qdm.src.lib import qnotifier_client
from infra.qyp.qdm.src.lib import mds_client

import gevent
import gevent.event
import gevent.pywsgi
import gevent.socket
import gevent.pool
import gevent.queue


MAX_DELETE_ITERATIONS = 20
DEBLOCK_POOL_SIZE = 30


class QDM(Component):
    """QDM daemon: owns the db handles, servers, managers and cleanup jobs.

    Construction only stores the parsed arguments; ``initialize()`` builds and
    starts the sub-components and registers the periodic cleanup jobs that are
    implemented as private methods of this class.
    """

    def __init__(self, args):
        """Store the parsed command-line ``args``; real setup happens in ``initialize()``.

        :param args: namespace with the options defined by ``parse_args()``.
        """
        super(QDM, self).__init__(logname='qdm')

        self.args = args
        # Set at the very end of initialize(); the RPC server waits for it.
        self.active = gevent.event.Event()
        self.hostname = gevent.socket.gethostname()
        self.log = logging.getLogger('qdm')
        self.mds_client = None

    def initialize(self):
        """Create and start every sub-component, then mark the daemon active.

        Two database handles are created, both wrapped in a Deblock proxy:
        ``self.db`` (connected with autocommit) and ``self.dbt`` (connected
        without autocommit, used for explicit transactions). The RPC server is
        started as early as possible but only serves once ``self.active`` is set.

        :raises RuntimeError: if a trivial greenlet cannot be spawned and joined.
        """
        dbobj = Database(
            host=self.args.db_host, port=self.args.db_port,
            user=self.args.db_user, password=self.args.db_password,
            dbname=self.args.db_name
        )
        db_deblock = Deblock(keepalive=None, logger=dbobj.log.getChild('deblock'))
        self.db = db_deblock.make_proxy(dbobj, put_deblock='deblock')

        # Check greenlets. Done explicitly instead of with ``assert`` so the
        # check (and the spawn it performs) is not stripped under ``python -O``.
        if not gevent.spawn(lambda: True).get():
            raise RuntimeError('Greenlets are not working correctly!')

        self.mds_client = mds_client.MdsClient(self.args.cli_tvm_secret)
        self.rpc = RPC(
            self, self.db, self.active, self.args.rpc_host, self.args.rpc_port, self.args.cli_tvm_secret
        )
        self.rpc.start()  # start RPC asap, but it will wait for .active flag

        self.db.connect(autocommit=True)
        self.db.migrate()

        dbtobj = Database(
            host=self.args.db_host, port=self.args.db_port,
            user=self.args.db_user, password=self.args.db_password,
            dbname=self.args.db_name,
            logname='dbt'
        )
        dbt_deblock = Deblock(keepalive=None, logger=dbtobj.log.getChild('deblock'))
        self.dbt = dbt_deblock.make_proxy(dbtobj, put_deblock='deblock')
        self.dbt.connect(autocommit=False)

        self.web = Web(self, self.db, self.hostname, self.args.vmproxy_oauth_token)
        if self.args.enable_ssl:
            self.web.enable_ssl(self.args.ssl_key, self.args.ssl_cert)
        self.web.listen()

        qnotifier_cli = qnotifier_client.QNotifierClient(self.args.srv_tvm_secret)
        vmproxy_cli = vmproxy_client.VmproxyClient(self.args.vmproxy_oauth_token)
        yp_cli = yp_client.YpClient(self.args.yp_oauth_token)

        self.evoq_mngr = EvoqManager(
            db=self.db,
            dbt=self.dbt,
            qnotifier_client=qnotifier_cli,
            vmproxy_client=vmproxy_cli,
            yp_client=yp_cli,
            max_active_jobs=self.args.evoq_max_active_jobs,
            disable_vm_leaving=self.args.disable_vm_leaving,
            parent=self
        )
        self.evoq_mngr.start()

        self.hotbackup = HotBackup(
            self.db, self.dbt,
            self.args.srv_tvm_secret, self.args.yp_oauth_token, self.args.vmproxy_oauth_token,
            parent=self
        )
        self.hotbackup.start()

        self.stats = Stats(self.db)

        # Each job is (run_if_fail, run_if_ok); judging by the inline comments
        # these are re-run intervals in seconds after a failed / successful run.
        self.job_mngr = JobManager(self.db, self.dbt, parent=self)
        self.job_mngr.register_job('block_cleaner', self._clean_orphan, 3600, 86400)
        self.job_mngr.register_job('data_cleaner', self._clean_data, 3600, 86400)
        self.job_mngr.register_job('session_cleaner', self._clean_sessions, 3600, 3600)
        self.job_mngr.register_job('evoq_planner', self.evoq_mngr.scheduler, 600, 1200)  # run planner every 10 mins
        self.job_mngr.register_job('evoq_ticker', self.evoq_mngr.ticker, 300, 310)       # run ticker every 5mins
        # self.job_mngr.register_job('stat_worker', self.stats.worker, 300, 600)
        self.job_mngr.register_job('vmsync', self.hotbackup.vmsync, 3600, 3600)                    # every 1hr

        if not self.args.disable_hot_backups:
            self.job_mngr.register_job('hotbackup_scheduler', self.hotbackup.scheduler, 1800, 300)    # every 5 min
        else:
            self.log.info('hotbackups DISABLED')

        self.active.set()

    def run_forever(self):
        """Idle the current greenlet until KeyboardInterrupt, then stop and wait for shutdown."""
        try:
            while True:
                gevent.sleep(1)
        except KeyboardInterrupt:
            self.log.info('Got KeyboardInterrupt')
            self.stop().join()

    def stop(self):
        """Stop the component via ``Component.stop`` and return its result."""
        self.log.info('Requested stop signal, stopping everything')
        ret = super(QDM, self).stop()
        self.log.info('Component stopped')
        return ret

    def _log_db(self, who, log, severity, message, args):
        """Emit ``message % args`` on ``log`` and persist it as a ``Log`` record.

        Every call writes a row through ``self.db``, debug-level ones included.

        :param who: source tag stored with the record (the callers pass the job name).
        :param log: logger whose method named ``severity`` is called.
        :param severity: one of 'debug', 'info', 'warning', 'error'.
        :param message: %-style format string.
        :param args: tuple of format arguments.
        """
        assert severity in ('debug', 'info', 'warning', 'error')

        logger = getattr(log, severity)
        logger(message, *args)

        log_record = Log(self.db, who, time.time(), severity, message, args)
        log_record.save()

    def _get_orphan_blocks(self, count_only=False):
        """Find storage_block rows not referenced by any storage_data row.

        :param count_only: when true, return only the number of such blocks.
        :returns: the count, or up to 10000 rows of (id, hash, mds_key, mds_storage).
        """
        if count_only:
            return self.db.query_one_col(
                'SELECT COUNT(sb.id) FROM storage_block sb '
                '   LEFT JOIN storage_data sd ON sd.block_id = sb.id '
                'WHERE sd.block_id IS NULL'
            )
        else:
            return self.db.query(
                'SELECT sb.id, sb.hash, sb.mds_key, sb.mds_storage FROM storage_block sb '
                '   LEFT JOIN storage_data sd ON sd.block_id = sb.id '
                'WHERE sd.block_id IS NULL '
                'LIMIT 10000'
            )

    def _clean_orphan(self):
        """
        Try to remove all orphan blocks or at least MAX_DELETE_ITERATIONS * 10k
        """
        orphan_blocks_number = self._get_orphan_blocks(count_only=True)
        if not orphan_blocks_number:
            return

        # Forget bookkeeping rows whose MDS removal never completed; remove_success
        # is set to true only after a successful MDS drop in _clean_orphan_once.
        with self.dbt:
            self.dbt.execute('DELETE FROM storage_block_removed WHERE remove_success = false')

        # Shared pool of Deblock workers used for the MDS calls.
        deblocker_pool = gevent.queue.Queue()
        for _ in range(DEBLOCK_POOL_SIZE):
            deblocker_pool.put(Deblock(keepalive=None))

        try:
            iterations_left = MAX_DELETE_ITERATIONS
            while orphan_blocks_number and iterations_left:
                self._clean_orphan_once(deblocker_pool)
                orphan_blocks_number = self._get_orphan_blocks(count_only=True)
                iterations_left -= 1
        finally:
            # Always stop the workers, even when a pass raised. get() waits for
            # workers still checked out by in-flight greenlets; those return
            # them via the try/finally in drop_block_mds.
            for _ in range(DEBLOCK_POOL_SIZE):
                deblocker_pool.get().stop()

    def _clean_orphan_once(self, deblocker_pool):
        """
        Remove only 10k orphan blocks at once

        Orphan blocks whose mds_key is still used by a referenced block are
        deleted from the db only. For the remaining keys the MDS object is
        dropped through a pooled Deblock worker, and on success the block rows
        are deleted from the db. Progress is tracked in storage_block_removed.

        :param deblocker_pool: gevent queue of Deblock workers; every worker
            taken from it is put back, even if the MDS call raises.
        """
        log = self.log

        with self.dbt:
            orphan_blocks = self._get_orphan_blocks()

            cnt_drop_db = 0

            # mds_key -> ids of orphan blocks stored under that key
            blocks_to_drop_by_mds_key = {}
            # block id -> (mds_key, mds_storage)
            blocks_by_id = {}

            for block_id, _, block_mds_key, block_mds_storage in orphan_blocks:
                blocks_to_drop_by_mds_key.setdefault(block_mds_key, []).append(block_id)
                blocks_by_id[block_id] = (block_mds_key, block_mds_storage)

            # Orphans may have vanished since the caller counted them. Like the
            # `if drop_ids:` guards, never issue IN (...) / VALUES with an empty
            # list (presumably invalid SQL).
            if not blocks_by_id:
                return

            drop_ids = []

            for _, block_mds_key in self.dbt.query(
                'SELECT DISTINCT ON (sb.mds_key) sb.id, sb.mds_key FROM storage_block sb '
                '   JOIN storage_data sd ON sd.block_id = sb.id '
                'WHERE sb.mds_key IN (%S)',
                (list(blocks_to_drop_by_mds_key), )
            ):
                # These blocks exist in other active datas. Just drop block from db, do not remove from mds
                for drop_block_id in blocks_to_drop_by_mds_key.pop(block_mds_key):
                    self._log_db(
                        'block_cleaner', log, 'debug', 'drop block %r (id %d) from db only',
                        (blocks_by_id.pop(drop_block_id), drop_block_id)
                    )
                    cnt_drop_db += 1
                    drop_ids.append(drop_block_id)

            if drop_ids:
                self.dbt.execute('DELETE FROM storage_block WHERE id IN (%S)', (drop_ids, ))

        # Every block still listed here must also be removed from MDS.
        cnt_drop_mds = sum(len(ids) for ids in blocks_to_drop_by_mds_key.values())

        self._log_db(
            'block_cleaner', log, 'info',
            'Duplicate blocks removed %d, blocks to remove from mds %d, blocks to remove from db %d',
            (cnt_drop_db, cnt_drop_mds, cnt_drop_mds)
        )

        # All orphans were handled by the db-only path: no MDS work, nothing to record.
        if not blocks_by_id:
            return

        # Record the pending MDS removals before attempting them.
        values = [
            (block_id, mds_key, mds_storage)
            for block_id, (mds_key, mds_storage) in blocks_by_id.items()
        ]
        with self.dbt:
            self.dbt.execute('INSERT INTO storage_block_removed ('
                             '    id, mds_key, mds_storage'
                             ') VALUES %S', (values, ))

        def drop_block_mds(key):
            """Drop ``key`` from MDS on a pooled Deblock worker; return ``(key, result)``."""
            deblock_obj = deblocker_pool.get(block=True)
            try:
                op_result = deblock_obj.apply(self.mds_client.drop_block_mds, key)
            finally:
                # Give the worker back even on failure so the pool never shrinks
                # and _clean_orphan can still collect and stop every worker.
                deblocker_pool.put(deblock_obj)
            return key, op_result

        pool = gevent.pool.Pool(DEBLOCK_POOL_SIZE)
        blocks_to_delete = []
        cannot_remove = []
        for mds_key, result in pool.imap_unordered(drop_block_mds, list(blocks_to_drop_by_mds_key)):
            if result:
                blocks_to_delete.extend(blocks_to_drop_by_mds_key[mds_key])
            else:
                cannot_remove.extend(blocks_to_drop_by_mds_key[mds_key])
        self._log_db(
            'block_cleaner', log, 'info',
            'Cannot remove %d, blocks to remove from db %d',
            (len(cannot_remove), len(blocks_to_delete))
        )

        with self.dbt:
            # Failed MDS drops: forget the bookkeeping row; the storage_block row
            # is kept, so the block stays orphaned and is retried on a later pass.
            if cannot_remove:
                self.dbt.execute('DELETE FROM storage_block_removed WHERE id IN (%S)', (cannot_remove, ))
            if blocks_to_delete:
                self.dbt.execute('UPDATE storage_block_removed '
                                 'SET remove_success = true WHERE id IN (%S)', (blocks_to_delete, ))
                self.dbt.execute('DELETE FROM storage_block WHERE id IN (%S)', (blocks_to_delete, ))

    def _clean_data(self):
        """Prune old storage revisions and the storage_block rows they leave behind.

        1. For VMs with more than ``StorageRevisionCollection.MAX_ACTIVE_REVS``
           active revisions, archive the old ones.
        2. Delete (up to 10000) archived revisions and draft revisions older than
           7 days, together with their storage_data, revision_user and
           revision_group rows; sessions pointing at them are detached.
        3. Delete from the db orphan storage blocks whose mds_key is still used by
           a referenced block (no MDS call needed).
        """
        log = self.log.getChild('data_cleaner')

        log.debug('Data cleaner run')

        cnt_rev_archived = 0
        cnt_rev_removed = 0
        cnt_datas_removed = 0
        cnt_new_orphan_blocks = 0

        orphan_blocks_before = self._get_orphan_blocks(count_only=True)
        log.debug('Found %d orphan blocks before cleaner run', orphan_blocks_before)

        with self.dbt:
            # Find and mark vm's with too much active revisions and mark old revisions as archive
            # Normally, this should be made each time new revision created, but to clean bd from any
            # trash we do this here also
            for vm_id, _active_count in self.dbt.query(
                'SELECT act.vm_id, act.cnt FROM ( '
                '   SELECT vm_id, count(state) as cnt '
                '   FROM storage_revision '
                '   WHERE state = %s '
                '   GROUP BY vm_id '
                ') as act '
                'WHERE act.cnt > %s',
                ('active', StorageRevisionCollection.MAX_ACTIVE_REVS)
            ):
                with self.dbt:
                    revs = StorageRevisionCollection(self.dbt, vm_id)
                    archived = revs.archive_old_revisions()

                    self._log_db(
                        'data_cleaner', log, 'info',
                        'Marked %d obsolete revisions for vm %r as archived',
                        (archived, vm_id)
                    )
                    cnt_rev_archived += archived
                    self.dbt.commit()

        with self.dbt:
            # Drop old draft revisions, which were created > 7 days ago and still in draft state
            # This could happen if dl/ul job was scheduled but was never run or does not finished in 7 days
            # Also drop revisions already marked as archive
            for vm_id, rev_id, key, state, create_ts in self.dbt.query(
                'SELECT vm_id, rev_id, key, state, create_ts '
                'FROM storage_revision '
                'WHERE (state = %s AND create_ts < %s) or state = %s '
                'LIMIT 10000',
                ('draft', time.time() - 86400 * 7, 'archive')
            ):
                with self.dbt:
                    if state == 'archive':
                        self._log_db(
                            'data_cleaner', log, 'info',
                            'Drop revision %r (rev %r) for vm %r, marked as archived',
                            (key, rev_id, vm_id)
                        )
                    else:
                        self._log_db(
                            'data_cleaner', log, 'info',
                            'Drop revision %r (rev %r) for vm %r, too old (%s ago) and still in draft state',
                            (key, rev_id, vm_id, human_secs(int(time.time() - create_ts)))
                        )

                    storage_datas_count = self.dbt.execute(
                        'DELETE FROM storage_data WHERE vm_id = %s AND rev_id = %s',
                        (
                            vm_id, rev_id
                        )
                    ).rowcount

                    if storage_datas_count:
                        self._log_db(
                            'data_cleaner', log, 'info',
                            'Dropped %d storage datas for revision %r (rev %r) for vm %r',
                            (storage_datas_count, key, rev_id, vm_id)
                        )
                        cnt_datas_removed += storage_datas_count

                    # Break FK constraint from session to storage_revision if we have some
                    self.dbt.execute(
                        'UPDATE session SET rev_id = NULL WHERE vm_id = %s AND rev_id = %s',
                        (vm_id, rev_id)
                    )

                    self.dbt.execute(
                        'DELETE FROM revision_user WHERE vm_id = %s AND rev_id = %s',
                        (vm_id, rev_id)
                    )

                    self.dbt.execute(
                        'DELETE FROM revision_group WHERE vm_id = %s AND rev_id = %s',
                        (vm_id, rev_id)
                    )

                    self.dbt.execute(
                        'DELETE FROM storage_revision WHERE key = %s', (key, )
                    )
                    cnt_rev_removed += 1
                    self.dbt.commit()

            new_orphan_blocks = self._get_orphan_blocks(count_only=True)
            cnt_new_orphan_blocks = new_orphan_blocks - orphan_blocks_before

            self._log_db(
                'data_cleaner', log, 'info',
                'Rev cleaner finished: archived revs %d dropped revs %d, datas %d. New orphan blocks %d',
                (cnt_rev_archived, cnt_rev_removed, cnt_datas_removed, cnt_new_orphan_blocks)
            )

        with self.dbt:
            # Next step is to drop orphan storage_block rows from the db
            # Phase 1: find orphan storage blocks, but with mds_key used in other storage blocks
            orphan_blocks = self._get_orphan_blocks()

            blocks_to_drop_by_mds_key = {}
            blocks_by_id = {}

            drop_ids = []

            for block_id, block_hash, block_mds_key, _ in orphan_blocks:
                blocks_to_drop_by_mds_key.setdefault(block_mds_key, []).append(block_id)
                blocks_by_id[block_id] = (block_hash, block_mds_key)

            # No orphans: skip the lookup rather than issue IN (...) with an empty list.
            if not blocks_to_drop_by_mds_key:
                return

            # Find datas in storage_data searching by mds_key for our orphan blocks
            for _, block_mds_key in self.dbt.query(
                'SELECT DISTINCT ON (sb.mds_key) sb.id, sb.mds_key FROM storage_block sb '
                '   JOIN storage_data sd ON (sd.block_id = sb.id) '
                'WHERE sb.mds_key IN (%S)',
                (list(blocks_to_drop_by_mds_key), )
            ):
                # Since this mds key exist in other datas -- it is safe to just drop them from db
                # Removal from MDS is not needed
                for drop_block_id in blocks_to_drop_by_mds_key[block_mds_key]:
                    self._log_db(
                        'data_cleaner', log, 'debug', 'Drop block %r (id %d) from db only -- exist in other datas',
                        (blocks_by_id[drop_block_id], drop_block_id)
                    )
                    drop_ids.append(drop_block_id)

            if drop_ids:
                self.dbt.execute('DELETE FROM storage_block WHERE id IN (%S)', (drop_ids, ))

    def _clean_sessions(self):
        """Archive stale sessions and the draft revisions they abandoned.

        - Sessions not in 'archive' state and unmodified for more than 7 days
          are set to 'archive'.
        - Draft revisions whose session was last modified more than 6 hours ago
          are set to 'archive', and so is that session if it is 'active'.
        """
        log = self.log.getChild('session_cleaner')

        log.debug('Session cleaner run')

        cnt_archived_new = 0
        cnt_archived_active = 0
        cnt_archived_revisions = 0

        now = int(time.time())

        with self.dbt:
            # Rows are read through the autocommit handle while updates go
            # through dbt and are committed one by one.
            for key, session_type, vm_id, rev_id, node_id, state, state_ts, modify_ts in self.db.query(
                'SELECT key, type, vm_id, rev_id, node_id, state, state_ts, modify_ts FROM session '
                'WHERE state != %s AND modify_ts < %s ORDER BY modify_ts asc',
                ('archive', now - 86400 * 7)
            ):
                self._log_db(
                    'session_cleaner', log, 'info',
                    'Mark %s session %s as archived (vm %s, node %s) -- it has %s state and last modified %s ago',
                    (session_type, key, vm_id, node_id, state, human_secs(time.time() - modify_ts))
                )

                self.dbt.execute('UPDATE session SET state = %s WHERE key = %s', ('archive', key))
                self.dbt.commit()

                if state == 'new':
                    cnt_archived_new += 1
                else:
                    cnt_archived_active += 1

            for key, vm_id, rev_id, state, modify_ts in self.db.query(
                'SELECT s.key, s.vm_id, s.rev_id, s.state, s.modify_ts FROM SESSION AS s '
                'JOIN storage_revision AS sr ON (s.vm_id = sr.vm_id AND s.rev_id = sr.rev_id) '
                'WHERE sr.state = %s and s.modify_ts < %s',
                ('draft', now - 60 * 60 * 6)
            ):
                self._log_db(
                    'session_cleaner', log, 'info',
                    'Mark revision (vm %s, rev %s) as archived -- its session last modified %s ago',
                    (vm_id, rev_id, human_secs(time.time() - modify_ts))
                )

                if state == 'active':
                    self.dbt.execute('UPDATE session SET state = %s WHERE key = %s', ('archive', key))
                    cnt_archived_active += 1
                self.dbt.execute(
                    'UPDATE storage_revision SET state = %s WHERE vm_id = %s and rev_id = %s',
                    ('archive', vm_id, rev_id)
                )
                cnt_archived_revisions += 1

                self.dbt.commit()

            self._log_db(
                'session_cleaner', log, 'info',
                'Finished: marked %d new and %d active sessions as archived, mark %d draft revisions as archived',
                (cnt_archived_new, cnt_archived_active, cnt_archived_revisions)
            )


def parse_args():
    """Parse the QDM daemon's command-line options.

    Any string option whose value has the form ``env:NAME`` is replaced by the
    value of environment variable ``NAME``, or ``None`` if it is unset.

    :returns: the ``argparse.Namespace`` with ``env:`` references resolved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--db-host', required=True)
    parser.add_argument('--db-port', required=True)
    parser.add_argument('--db-name', required=True)
    parser.add_argument('--db-user', required=True)
    parser.add_argument('--db-password', required=True)
    parser.add_argument('--logfile', default='-')
    parser.add_argument('--rpc-host', default='localhost')
    parser.add_argument('--rpc-port', type=int, default=2183)
    parser.add_argument('--cli-tvm-secret', required=True)
    parser.add_argument('--srv-tvm-secret', required=True)
    parser.add_argument('--yp-oauth-token', required=True)
    parser.add_argument('--vmproxy-oauth-token', required=True)
    parser.add_argument('--enable-ssl', action='store_true')
    parser.add_argument('--ssl-key')
    parser.add_argument('--ssl-cert')
    parser.add_argument('--evoq-max-active-jobs', required=True, type=int)
    parser.add_argument('--disable-hot-backups', action='store_true', default=False)
    parser.add_argument('--disable-vm-leaving', action='store_true', default=False)

    args = parser.parse_args()

    # vars() yields exactly the parsed options (unlike dir(), which also lists
    # dunder names and Namespace methods). Snapshot it since we setattr below.
    for attr, value in list(vars(args).items()):
        if isinstance(value, str) and value.startswith('env:'):
            setattr(args, attr, os.environ.get(value.split(':', 1)[1]))

    return args


def main():
    """Entry point: parse arguments, configure logging, then run the QDM daemon forever."""
    args = parse_args()

    logging.getLogger('urllib3').setLevel(logging.ERROR)

    basic_config = dict(
        level=logging.DEBUG,
        format='%(asctime)s %(levelname)-7s [%(name)-24s]  %(message)s',
    )
    # '-' keeps basicConfig's default stream handler; any other value is a file
    # path rotated at midnight with seven backups kept.
    if args.logfile != '-':
        rotating = logging.handlers.TimedRotatingFileHandler(
            args.logfile,
            when='midnight',
            backupCount=7,
        )
        basic_config['handlers'] = [rotating]
    logging.basicConfig(**basic_config)

    daemon = QDM(args)
    daemon.initialize()
    daemon.start()
    daemon.run_forever()
