# -*- coding: utf-8 -*-
from collections import defaultdict
from itertools import chain

from intranet.yandex_directory.src.yandex_directory import app
from intranet.yandex_directory.src.yandex_directory.core.models.service import (
    OrganizationServiceModel,
    trial_status,
)
from intranet.yandex_directory.src.yandex_directory.core.models.organization import (
    OrganizationModel,
    subscription_plan,
    organization_type,
    vip_reason,
)


def _paid_organizations(main_connection):
    """Collect the ``id`` of every organization on the paid subscription plan.

    :param main_connection: connection used to build the ``OrganizationModel``.
    :return: whatever ``scalar()`` yields for the ``id`` column
        (iterated by the caller).
    """
    paid = OrganizationModel(main_connection).filter(
        subscription_plan=subscription_plan.paid,
    )
    return paid.fields('id').scalar()


def _partner_organizations(main_connection):
    """Collect the ``id`` of every organization whose type is a partner type.

    ``organization_type.partner_types`` is passed straight to ``filter``;
    presumably the model treats a collection value as an IN match (verify).

    :param main_connection: connection used to build the ``OrganizationModel``.
    :return: whatever ``scalar()`` yields for the ``id`` column.
    """
    partners = OrganizationModel(main_connection).filter(
        organization_type=organization_type.partner_types,
    )
    return partners.fields('id').scalar()


def _organizations_with_paid_services_with_licenses(main_connection):
    """Collect ``org_id`` of enabled, licensed services whose trial has expired.

    An organization may appear more than once if several of its services
    match (presumably; the caller merges ids into sets, so duplicates are
    harmless).

    :param main_connection: connection used to build the
        ``OrganizationServiceModel``.
    :return: whatever ``scalar()`` yields for the ``org_id`` column.
    """
    services = OrganizationServiceModel(main_connection).filter(
        trial_status=trial_status.expired,
        enabled=True,
        has_user_licenses=True,
    )
    return services.fields('org_id').scalar()


def _organizations_with_trial_services(main_connection):
    """Collect ``org_id`` of every service whose trial is still in progress.

    No ``enabled`` filter is applied here, unlike the paid-services query.

    :param main_connection: connection used to build the
        ``OrganizationServiceModel``.
    :return: whatever ``scalar()`` yields for the ``org_id`` column.
    """
    in_trial = OrganizationServiceModel(main_connection).filter(
        trial_status=trial_status.in_progress,
    )
    return in_trial.fields('org_id').scalar()


def _organizations_with_many_users(main_connection):
    """Collect the ``id`` of organizations with more users than the VIP threshold.

    The threshold is ``app.config['VIP_MANY_USERS_COUNT']``, and the
    comparison is strict (``user_count__gt``).

    :param main_connection: connection used to build the ``OrganizationModel``.
    :return: whatever ``scalar()`` yields for the ``id`` column.
    """
    threshold = app.config['VIP_MANY_USERS_COUNT']
    large_orgs = OrganizationModel(main_connection).filter(user_count__gt=threshold)
    return large_orgs.fields('id').scalar()


def get_current_vip_reasons(main_connection):
    """Map every organization id to the set of VIP reasons currently stored.

    :param main_connection: connection used to build the ``OrganizationModel``.
    :return: dict ``{org_id: set(vip reasons)}`` covering all organizations.
    """
    organizations = OrganizationModel(main_connection).fields('id', 'vip').all()
    # NOTE(review): if the ``vip`` column can be NULL, ``set(None)`` would
    # raise TypeError and abort the whole recalculation; treat a missing
    # value as "no reasons". No-op when the column is never NULL.
    return {o['id']: set(o['vip'] or ()) for o in organizations}


def calculated_vip_reasons(main_connection):
    """Compute the data-derived VIP reasons for all organizations.

    Each reason is backed by one or more query helpers. An organization gets
    the reason if its id appears in the output of any of those helpers.
    ``vip_reason.whitelist`` is not computed here.

    :param main_connection: connection passed to every query helper.
    :return: ``defaultdict(set)`` mapping org id -> set of ``vip_reason``
        values. Organizations that match no rule are absent, but indexing
        the result with them yields an empty set.
    """
    queries_for = {
        vip_reason.partner: (_partner_organizations,),
        vip_reason.paid: (
            _organizations_with_paid_services_with_licenses,
            _paid_organizations,
        ),
        vip_reason.trial_service: (_organizations_with_trial_services,),
        vip_reason.many_users: (_organizations_with_many_users,),
    }

    reasons = defaultdict(set)
    for reason, queries in queries_for.items():
        # Run every query for this reason up front, then merge the ids.
        results = [query(main_connection) for query in queries]
        for org_id in chain.from_iterable(results):
            reasons[org_id].add(reason)
    return reasons


def update_vip_reasons_for_all_orgs(main_connection):
    """Recalculate VIP reasons for every organization and persist the changes.

    Runs inside ``main_connection.begin_nested()``. For every organization
    returned by ``get_current_vip_reasons``, the freshly calculated reasons
    are compared with the stored ones, ignoring ``vip_reason.whitelist``.
    That reason is not produced by ``calculated_vip_reasons``, so it is
    carried over unchanged when present. Only organizations whose reasons
    differ are written via ``OrganizationModel.update_vip_reasons``.

    :param main_connection: connection used for all reads and writes.
    """
    whitelist = {vip_reason.whitelist}
    with main_connection.begin_nested():
        current_vip_reasons = get_current_vip_reasons(main_connection)
        new_vip_reasons = calculated_vip_reasons(main_connection)

        for org_id, current_vip_reason in current_vip_reasons.items():
            # Use .get() plus a copy: this neither inserts into the calculated
            # defaultdict nor mutates one of its sets below.
            new_vip_reason = set(new_vip_reasons.get(org_id, ()))
            if new_vip_reason == current_vip_reason - whitelist:
                continue

            # The whitelist reason is not derived from data; keep it if set.
            new_vip_reason |= current_vip_reason & whitelist

            OrganizationModel(main_connection).update_vip_reasons(org_id, new_vip_reason)
