# -*- coding: utf-8 -*-
import tqdm
import time

from itertools import repeat, islice

from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log
    
from intranet.yandex_directory.src.yandex_directory.common.db import (
    get_main_connection,
    get_meta_connection,
    get_shards_by_org_type,
)
from intranet.yandex_directory.src.yandex_directory.core.models.service import (
    MAILLIST_SERVICE_SLUG,
)
from intranet.yandex_directory.src.yandex_directory.core.models.service import enable_service


# Numeric id of the maillist service in the ``services`` table.
# NOTE(review): hard-coded; presumably corresponds to MAILLIST_SERVICE_SLUG —
# confirm the id is the same in every environment before relying on it.
MAILLIST_SERVICE_ID = 246


def get_progress():
    """Compute how far the migration to the new maillist service has gone.

    Counts, across all shards, every organization and the organizations that
    have the maillist service (``MAILLIST_SERVICE_SLUG``) enabled.

    Returns:
        A ``(fraction, remaining)`` tuple. ``fraction`` is the share of
        organizations with the service enabled (``1.0`` when there are no
        organizations at all, i.e. nothing is left to migrate) and
        ``remaining`` is the number of organizations still without it.
    """
    shards = get_shards_by_org_type(organization_type=None)

    total_query = 'SELECT COUNT(*) FROM organizations'
    # Same criterion as the candidate query in iterate_over_org_ids: the service
    # is matched by slug and only enabled rows count, so organizations with a
    # disabled record are reported as "left" rather than as migrated.
    enabled_query = """
        SELECT COUNT(DISTINCT organization_services.org_id)
        FROM organization_services
        JOIN services
            ON services.id = organization_services.service_id
        WHERE services.slug = %(service_slug)s
          AND organization_services.enabled
    """
    params = dict(service_slug=MAILLIST_SERVICE_SLUG)

    total = 0
    enabled = 0

    for shard in shards:
        with get_main_connection(shard=shard) as main_connection:
            total += main_connection.execute(total_query).fetchall()[0][0]
            enabled += main_connection.execute(enabled_query, params).fetchall()[0][0]

    if not total:
        # No organizations at all: avoid ZeroDivisionError, nothing is left.
        return 1.0, 0
    return float(enabled) / total, total - enabled
    

def enable_maillist_for_organizations(org_num, max_org_size=100, delay=0.3, dry_run=True):
    """Enable the maillist service for organizations that do not have it yet.

    Candidates come from ``iterate_over_all_orgs`` (interleaved across shards).
    When the loop ends, whether normally, on Ctrl+C or on an error, the
    current migration progress is printed.

    Args:
        org_num: maximum number of organizations to process; a falsy value
            (``0`` / ``None``) means no limit.
        max_org_size: passed to ``iterate_over_all_orgs``; only organizations
            with fewer users than this are selected.
        delay: seconds to sleep before each real ``enable_service`` call.
        dry_run: when true, only print the organizations that would be
            enabled; nothing is written and no delay is applied.
    """
    # The bigml service can sync about 250 organizations per minute; above that,
    # events pile up in its queues. So the service is enabled with a delay
    # between organizations (0.3 s -> at most ~200 organizations per minute).
    org_ids = iterate_over_all_orgs(max_org_size)
    if org_num:
        org_ids = islice(org_ids, 0, org_num)

    org_ids = tqdm.tqdm(org_ids)

    try:
        for shard, org_id in org_ids:
            if dry_run:
                print('DRY RUN: would enable', MAILLIST_SERVICE_SLUG, 'for org', org_id, 'on shard', shard)
                continue

            if delay:
                time.sleep(delay)

            with get_main_connection(shard=shard, for_write=True) as main_connection, \
                 get_meta_connection() as meta_connection:
                try:
                    enable_service(
                        meta_connection,
                        main_connection,
                        org_id,
                        MAILLIST_SERVICE_SLUG,
                    )
                except Exception as exc:
                    # Best effort: report the failure and continue with the next org.
                    print('ERROR', org_id, exc)
                    log.trace().error('Unable to enable service')
    except KeyboardInterrupt:
        print('\nGraceful exit')
    finally:
        try:
            progress, not_enabled = get_progress()
        except Exception:
            # A failing progress query must not mask the loop's own exception.
            log.trace().error('Unable to calculate migration progress')
        else:
            print('\nCurrently migrated: {0:.2f}%'.format(progress * 100))
            print('Organization left: {}'.format(not_enabled))




def iterate_over_org_ids(shard, max_org_size, chunk_size=10, exclude_org_ids=(34,)):
    """Yield ids of organizations on *shard* that are candidates for the maillist service.

    An organization is selected when all of the following hold:

    * its ``user_count`` is below *max_org_size*;
    * it has no enabled organization_services row for ``MAILLIST_SERVICE_SLUG``;
    * none of its departments or groups has a non-null ``label`` with a null ``uid``.

    Ids are read in ascending order, *chunk_size* rows per query, using keyset
    pagination on ``organizations.id``.

    Args:
        shard: shard to read from.
        max_org_size: exclusive upper bound for ``organizations.user_count``.
        chunk_size: number of ids fetched per query.
        exclude_org_ids: ids that are never yielded. The default skips org
            ``34``. NOTE(review): the reason for that exclusion is not
            documented here; confirm before changing it.
    """
    query = """
        SELECT id
        FROM organizations
        WHERE id > %(max_org_id)s
        AND user_count < %(max_org_size)s
        AND NOT EXISTS (
            SELECT 1
            FROM organization_services
            JOIN services
                ON services.id = organization_services.service_id
            WHERE organization_services.org_id = organizations.id
              AND services.slug = %(service_slug)s
              AND organization_services.enabled
        )
        AND NOT EXISTS (
            SELECT 1
            FROM departments
            WHERE label is not null
            AND uid is null
            AND org_id = organizations.id
        )
        AND NOT EXISTS (
            SELECT 1
            FROM groups
            WHERE label is not null
            AND uid is null
            AND org_id = organizations.id
        )
        ORDER BY id
        LIMIT %(chunk_size)s
    """
    max_org_id = 0

    while True:
        # Fetch one chunk and release the connection before yielding, so that
        # no read transaction is held open while the consumer processes ids.
        with get_main_connection(shard=shard, no_transaction=False) as main_connection:
            params = dict(
                max_org_id=max_org_id,
                chunk_size=chunk_size,
                max_org_size=max_org_size,
                service_slug=MAILLIST_SERVICE_SLUG,
            )
            rows = main_connection.execute(
                query,
                params,
            ).fetchall()

        if not rows:
            return

        for row in rows:
            org_id = row[0]
            # Advance the cursor even for excluded ids; otherwise the same
            # chunk would be fetched again forever.
            max_org_id = org_id
            if org_id not in exclude_org_ids:
                yield org_id


def interleave(iterables):
    """Yield items from *iterables* in round-robin order.

    One item is taken from each source in turn; sources that run out are
    dropped, and iteration stops once every source is exhausted.

    >>> list(interleave([[1, 2, 3], ['a'], [10, 20]]))
    [1, 'a', 10, 2, 20, 3]
    """
    active = [iter(source) for source in iterables]

    while active:
        survivors = []
        for source in active:
            try:
                item = next(source)
            except StopIteration:
                continue
            survivors.append(source)
            yield item
        active = survivors


def iterate_over_all_orgs(max_org_size):
    """Return an iterator of ``(shard, org_id)`` pairs interleaved across all shards.

    NOTE: each shard's candidate ids are fully fetched into a list before
    iteration starts, so the whole scan happens up front.

    Args:
        max_org_size: passed through to ``iterate_over_org_ids``.
    """
    per_shard = []
    for shard in get_shards_by_org_type(organization_type=None):
        shard_org_ids = iterate_over_org_ids(shard, max_org_size)
        per_shard.append([(shard, org_id) for org_id in shard_org_ids])
    return interleave(per_shard)


# Example usage. Guarded so that importing this module (e.g. to reuse
# get_progress) does not start a write migration over up to 1,000,000 orgs.
if __name__ == '__main__':
    # Enabling with delay=0.3 keeps events from piling up in the bigml queues.
    enable_maillist_for_organizations(1000000, dry_run=False, delay=0.3)
    # Dry-run pass over organizations with fewer than 10000 users.
    enable_maillist_for_organizations(1000000, dry_run=True, delay=0, max_org_size=10000)
