# -*- coding: utf-8 -*-

from intranet.yandex_directory.src.yandex_directory.common.commands.base import AllShardsCommand, Option
from intranet.yandex_directory.src.yandex_directory.common.db import (
    get_meta_connection,
    get_main_connection,
)


def slice(ids, at_once):
    """
    Yield consecutive chunks of ``ids`` holding at most ``at_once`` items each.

    Chunks are taken with plain slice indexing, so ``ids`` must support it.
    Iteration ends as soon as a slice comes back empty.

    NOTE: inside this module the name shadows the builtin ``slice``.
    """
    offset = 0
    chunk = ids[offset:offset + at_once]
    while chunk:
        yield chunk
        # Advance the window and fetch the next chunk before re-testing.
        offset += at_once
        chunk = ids[offset:offset + at_once]


def get_dismissed(main_connection, ids):
    """
    Return the set of ids from ``ids`` whose ``users`` row has
    ``is_dismissed = TRUE`` according to ``main_connection``.

    An empty ``ids`` returns an empty set without running a query.
    The ids are rendered with ``str()`` straight into the SQL text, and the
    SELECT is issued with ``FOR UPDATE``.
    """
    if not ids:
        return set()

    query = '''
                SELECT 
                    id
                FROM users
                WHERE 
                    id IN ({ids_formatted}) 
                    AND is_dismissed = TRUE 
                FOR UPDATE
            '''.format(ids_formatted=', '.join(str(user_id) for user_id in ids))

    # Only the first column (the user id) of every returned row is kept.
    dismissed_ids = set()
    for row in main_connection.execute(query):
        dismissed_ids.add(row[0])
    return dismissed_ids


def set_dismissed(meta_connection, ids, dismissed):
    """
    Set ``users.is_dismissed`` to TRUE or FALSE for the given ``ids`` through
    ``meta_connection``.

    Nothing is executed when ``ids`` is empty. A truthy ``dismissed`` writes
    TRUE, anything else writes FALSE. The ids are rendered with ``str()``
    straight into the SQL text.
    """
    if not ids:
        return

    id_list = ', '.join(str(user_id) for user_id in ids)
    if dismissed:
        flag = 'TRUE'
    else:
        flag = 'FALSE'

    statement = '''
        UPDATE users
        SET 
            is_dismissed = {dismissed} 
        WHERE 
            id IN ({ids_formatted})
    '''.format(ids_formatted=id_list, dismissed=flag)
    meta_connection.execute(statement)


class Command(AllShardsCommand):
    """
    Set the is_dismissed flag in the meta database.
    One-off command. Delete me.

    For the current shard, users whose ``is_dismissed`` is still NULL in the
    meta database are fetched in batches; for each batch the flag is read from
    the shard's main database and written back to the meta database.

    ``self.shard`` and ``self.meta_connection`` are not set in this class;
    presumably they are provided by ``AllShardsCommand`` -- TODO confirm.
    """
    name = 'set_dismissed_in_metabase'
    # Maximum number of user ids fetched and processed per batch.
    PROCESS_AT_ONCE = 10000
    option_list = [
        Option('--silent', '-s', dest='silent', action='store_true', help='Silent mode'),
    ]
    silent = None
    need_writable_database_connections = []  # none db to write by default

    def run(self, silent=False):
        """
        Process every pending batch of user ids for the current shard.

        :param silent: when truthy, progress messages are not printed.
        """
        self.silent = silent
        self.do_print('Shard {}:'.format(self.shard))
        for unprocessed_ids in self.still_got_ids():
            self.do_print('Got {ids} IDs to process on shard {shard}'.format(
                ids=len(unprocessed_ids), shard=self.shard
            ))
            self.proccess_ids(unprocessed_ids, self.shard)
        # Reported once the shard is exhausted (also when there was nothing
        # to do), not after every single batch.
        self.do_print('Done with shard {}.'.format(self.shard))

    def still_got_ids(self):
        """
        Yield lists of up to ``PROCESS_AT_ONCE`` user ids of the current shard
        whose ``is_dismissed`` is NULL in the meta database.

        The same query is re-run for every batch, so the generator only ends
        once the query returns no rows.
        NOTE(review): this relies on the updates made by ``proccess_ids``
        being visible to ``self.meta_connection``; otherwise the same batch
        would be returned forever -- confirm.
        """
        while True:
            result = self.meta_connection.execute('''
                SELECT users.id FROM users
                JOIN organizations ON users.org_id = organizations.id
                WHERE organizations.shard = {shard}
                AND users.is_dismissed IS NULL
                LIMIT {limit}
            '''.format(limit=self.PROCESS_AT_ONCE, shard=self.shard)).fetchall()
            if not result:
                # A plain return ends the generator; raising StopIteration
                # here turns into RuntimeError on Python 3.7+ (PEP 479).
                return
            yield [row[0] for row in result]

    def proccess_ids(self, ids, shard):
        """
        Copy the dismissed state of ``ids`` from the shard's main database to
        the meta database.

        :param ids: user ids to process.
        :param shard: shard whose main database is queried.
        :return: tuple ``(dismissed, active)`` of id sets; ``active`` holds
            every id from ``ids`` that was not reported as dismissed.
        """
        with get_main_connection(shard=shard, for_write=True) as main_connection:
            with get_meta_connection(for_write=True) as meta_connection:
                dismissed = get_dismissed(main_connection, ids)
                active = set(ids) - dismissed
                set_dismissed(meta_connection, ids=dismissed, dismissed=True)
                set_dismissed(meta_connection, ids=active, dismissed=False)
        return dismissed, active

    def do_print(self, msg):
        """Print ``msg`` unless silent mode is on."""
        if not self.silent:
            print(msg)
