from intranet.yandex_directory.src.yandex_directory.common.commands.base import (
    Option,
    BaseCommand,
)
from intranet.yandex_directory.src.yandex_directory.common.db import get_meta_connection, get_main_connection, get_shard, get_shard_numbers
from intranet.yandex_directory.src.yandex_directory.core.models.service import (
    disable_licensed_services_by_trial,
)
from intranet.yandex_directory.src.yandex_directory.core.utils import only_attrs


class Command(BaseCommand):
    """Run ``disable_licensed_services_by_trial`` for organizations.

    With ``--org-id`` only that organization is handled. Without it, every
    shard is scanned for organizations that have an enabled service with an
    expired trial and no matching row in ``user_service_licenses``, and each
    such organization is handled in turn.
    """

    name = 'disable-service-trial-expired'
    option_list = (
        Option('--org-id', '-o', dest='org_id', type=int, required=False, help='Organization id.'),
    )

    def run(self, org_id):
        """Process a single organization, or every matching one on all shards.

        :param org_id: organization id from ``--org-id``; any falsy value
            (including a missing option) switches to the scan over all shards.
        """
        # Single-organization mode: look up its shard and stop there.
        if org_id:
            self._process_org(self._get_shard(org_id), org_id)
            return

        for shard in get_shard_numbers():
            # The candidate list is read first and the read connection is
            # released before any organization is processed.
            with get_main_connection(shard) as main_connection:
                sql = """
                            select distinct os.org_id
                            from organization_services os
                                     left join user_service_licenses usl
                                               on os.org_id = usl.org_id
                                                   and os.service_id = usl.service_id
                            where os.resource_id is not null
                              and os.trial_status = 'expired'
                              and os.trial_expires is not null
                              and os.enabled = true
                              and usl.service_id is null
                    """
                rows = main_connection.execute(sql).fetchall()
                found_org_ids = only_attrs(rows, 'org_id')

            total = len(found_org_ids)
            for position, current_org_id in enumerate(found_org_ids, 1):
                # "\r" rewrites the same terminal line to act as a progress counter.
                progress = "\r shard: {shard} | {current}/{total}".format(
                    shard=shard,
                    current=position,
                    total=total,
                )
                print(progress, end='')
                self._process_org(shard, current_org_id)
            # Finish the progress line for this shard.
            print()

    @staticmethod
    def _process_org(shard, org_id):
        """Call ``disable_licensed_services_by_trial`` for one organization.

        Opens a meta connection and a main connection on ``shard``; both are
        requested with a positional ``True`` flag (presumably "for write" --
        confirm against ``common.db``).
        """
        with get_meta_connection(True) as meta, get_main_connection(shard, True) as main:
            disable_licensed_services_by_trial(meta, main, org_id)

    @staticmethod
    def _get_shard(org_id):
        """Return the result of ``get_shard`` for ``org_id`` using a meta connection."""
        with get_meta_connection() as meta:
            return get_shard(meta, org_id)
