# -*- coding: utf-8 -*-
import os
import datetime
import collections
from retrying import retry
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log
from yt.wrapper.errors import YtError
from dateutil.relativedelta import relativedelta

from intranet.yandex_directory.src import settings
from intranet.yandex_directory.src import utils
from intranet.yandex_directory.src.yandex_directory.core.models import OrganizationModel
from intranet.yandex_directory.src.yandex_directory.core.models.license import TrackerLicenseLogModel
from intranet.yandex_directory.src.yandex_directory.core.models.service import (
    TrackerBillingStatusModel,
    TrackerBillingErrorModel,
)
from intranet.yandex_directory.src.yandex_directory.common.yt import utils as yt_utils
from intranet.yandex_directory.src.yandex_directory.common.utils import (
    utcnow,
    ensure_date,
    _split_in_batches,
    get_days_in_month,
)
from intranet.yandex_directory.src.yandex_directory.common.db import (
    get_shard_numbers,
    get_main_connection,
    get_meta_connection,
)
from intranet.yandex_directory.src.yandex_directory import app
from intranet.yandex_directory.src.yandex_directory.core.models.organization import (
    OrganizationBillingConsumedInfoModel,
    get_price_and_product_id_for_service,
    OrganizationLicenseConsumedInfoModel,
    organization_type,
)
from intranet.yandex_directory.src.yandex_directory.core.models import (
    ServiceModel,
    PromocodeModel,
)
from intranet.yandex_directory.src.yandex_directory.core.utils import (
    objects_map_by_id,
)
from intranet.yandex_directory.src.yandex_directory.core.models.service import (
    OrganizationServiceModel,
    AuthorizationError,
    ServiceNotFound,
)
from intranet.yandex_directory.src.yandex_directory.core.features import is_feature_enabled, DISABLE_BILLING_TRACKER

BATCH_SIZE = 20000


def save_organizations_consumed_products_to_yt():
    """
    Функция save_organizations_consumed_products_to_yt должна посчитать потребленные
    продукты платными организациями и записать данные в модель OrganizationBillingConsumedInfoModel для всех шардов.
    Далее, нужно проверить не были ли уже сохранены эти данные в нужные кластера YT,
    и если не были для каких-то кластеров, сохраняем их с созданием таблицы транзакционно
    """
    yesterday_date = _get_yesterday_date_string()

    with log.name_and_fields('billing', for_date=yesterday_date):
        # проверим, лежат ли уже данные в каких-то кластерах в YT
        need_to_save_into_clusters = get_yt_clusters_without_billing_data(for_date=yesterday_date)

        # если таблицы есть во всех кластерах, ничего сохранять не нужно
        # т.к. создаются и записываются они транзакционно
        if not need_to_save_into_clusters:
            log.info('Billing data is already saved to all YT clusters')
            return

        # для всех шардов посчитаем потребленные услуги в платных организациях
        for shard in get_shard_numbers():
            with get_main_connection(for_write=True, shard=shard) as main_connection:
                OrganizationLicenseConsumedInfoModel(main_connection).save_user_service_licenses()
                OrganizationBillingConsumedInfoModel(main_connection).calculate(
                    rewrite=True,
                )

        with log.fields(clusters=','.join(need_to_save_into_clusters)):
            log.info('Saving billing data to the YT for all organizations')

            data_to_write, billing_status_ids, orgs_by_shard = _get_tracker_billing_data(
                for_date=yesterday_date,
            )
            has_error = False
            for cluster in need_to_save_into_clusters:
                yt_client = yt_utils.billing_yt_clients[cluster]
                try:
                    _save_billing_data_to_yt_cluster_from_data(
                        cluster=cluster,
                        yt_client=yt_client,
                        data_to_write=data_to_write,
                        for_date=yesterday_date,
                    )
                except Exception:
                    has_error = True
                    # Если не получилось сложить данные в один из кластеров YT, то запишем ошибку.
                    # Данные должны записаться при следующей попытке запуска функции
                    log.trace().error('Can not save billing data to the YT cluster')
                else:
                    log.info('All billing data has been saved')
            if not has_error:
                _mark_billing_info_as_sended(billing_status_ids)
                _set_next_payment(orgs_by_shard)
                log.info('Set next date payments')


def _mark_billing_info_as_sended(billing_status_ids):
    for shard, ids in billing_status_ids.items():
        with get_main_connection(for_write=True, shard=shard) as main_connection:
            TrackerBillingStatusModel(main_connection).update(
                filter_data={'id__in': ids},
                update_data={'payment_status': True}
            )

def _get_orgs_metadata(main_connection, orgs_ids):
    if not orgs_ids:
        return {}
    orgs_metadata = OrganizationModel(main_connection).filter(
        id__in=orgs_ids,
    ).fields(
            'billing_info.balance',
            'billing_info.client_id',
            'billing_info.first_debt_act_date',
            'billing_info.person_type',
            'subscription_plan',
            'organization_type',
    )
    return objects_map_by_id(orgs_metadata)


def get_price_for_users(users_count, promocode_id=None):
    """
    Возвращает количество денег за месяц за указанное
    количество пользователей с учетом промокода
    """
    with get_meta_connection() as meta_connection:
        org_pricing_data = get_org_pricing_data(
            meta_connection=meta_connection,
            licenses_count=users_count,
            promocode_id=promocode_id,
        )
        all_prices = app.billing_client.get_products_price()
        all_prices[settings.PRODUCT_ID_FREE] = 0

        total = 0
        for product_id, product_users_count in org_pricing_data.items():
            total += all_prices[product_id] * product_users_count
        total_with_discount = total

        if promocode_id:
            org_pricing_data_without_promo = get_org_pricing_data(
                meta_connection=meta_connection,
                licenses_count=users_count,
                promocode_id=None,
            )

            total = 0
            for product_id, product_users_count in org_pricing_data_without_promo.items():
                total += all_prices[product_id] * product_users_count

    return total, total_with_discount


def get_product_map_for_promocode(meta_connection, promocode_id):
    """
    Ожидается что у промокода в базе для каждого сервиса
    лежит dict product_id соответствующий каждой категории
    например 'tracker': {1: 123, 100: 321, 250: 444}

    и в результате эти данные преобразуются в
    BILLING_PRODUCT_IDS_FOR_TRACKER = RangeDict({
        range(1, 101): 123,
        range(101, 251): 321,
        range(251, 9999999999): 444,
    })
    """
    promocode = PromocodeModel(meta_connection).get(promocode_id)
    if promocode:
        product_ids = promocode.get('product_ids', {}).get('tracker', {})
        if product_ids:
            product_ids = {
                int(category): product_id for
                category, product_id in
                product_ids.items()
            }
            cat_1 = product_ids.get(1, settings.TRACKER_PRODUCT_ID_1)
            cat_2 = product_ids.get(100, settings.TRACKER_PRODUCT_ID_100)
            cat_3 = product_ids.get(250, settings.TRACKER_PRODUCT_ID_250)
            return utils.RangeDict({
                range(1, 101): cat_1,
                range(101, 251): cat_2,
                range(251, 9999999999): cat_3,
            })


def get_org_pricing_data(meta_connection, licenses_count, promocode_id=None):
    """
    Возвращает данные о product_id и количестве лицензий для
    данного product_id (категории)

    {
        product_id: количество лицензий,
        ...
    }
    """
    if licenses_count <= settings.TRACKER_FREE_LICENSES:
        return {
            settings.PRODUCT_ID_FREE: licenses_count
        }
    product_range_map = None
    if promocode_id:
        product_range_map = get_product_map_for_promocode(
            meta_connection=meta_connection,
            promocode_id=promocode_id,
        )

    if not product_range_map:
        product_range_map = settings.BILLING_PRODUCT_IDS_FOR_TRACKER
    return collections.Counter([
        product_range_map[i]
        for i in range(1, licenses_count + 1)
    ])


def populate_pricing_data(org_id, org_pricing_data, client_id, for_date, promocode_id):
    """
    Возвращает данные в формате пригодном для записи в табличку
    биллинга в YT
    """
    license_multiplier = get_days_in_month(for_date)

    if not isinstance(for_date, str):
        for_date = for_date.strftime('%Y-%m-%d')
    return [
        {
            'client_id': client_id,
            'product_id': product_id,
            'quantity': quantity * license_multiplier,  # DIR-9622 так как биллинг потом поделит на это число
            'date': for_date,
            'promocode_id': promocode_id,
            'service': 'tracker',
            'org_id': org_id,
        }
        for product_id, quantity in org_pricing_data.items()
        if product_id != settings.PRODUCT_ID_FREE
    ]


def get_promocodes_days(main_connection, org_id, to_date, from_date=None):
    """
    Возвращает id промокодов и количество дней которые они
    использовались за отчетный период
    """
    to_date = ensure_date(to_date)
    if from_date:
        from_date = ensure_date(from_date)
    else:
        from_date = to_date - relativedelta(months=1)

    promocodes = OrganizationBillingConsumedInfoModel(main_connection).filter(
        org_id=org_id,
        for_date__gte=ensure_date(from_date),
        for_date__lt=ensure_date(to_date),
        service='tracker',
    ).scalar('promocode_id')
    promocodes = [promocode for promocode in promocodes if promocode]

    return collections.Counter(promocodes)

def get_promocode_id(main_connection, org_id, to_date, from_date=None):
    """
    Возвращает id промокода, который использовался больше всего дней
    за отчетный период или None
    """
    promocodes_by_days = get_promocodes_days(
        main_connection=main_connection,
        org_id=org_id,
        to_date=to_date,
        from_date=from_date,
    )

    if promocodes_by_days:
        return promocodes_by_days.most_common(1)[0][0]

def _set_next_payment(orgs_by_shard):
    for shard, orgs_data in orgs_by_shard.items():
        with get_main_connection(for_write=True, shard=shard) as main_connection:
            for org_id, payment_dates in orgs_data.items():
                last_payment = ensure_date(sorted(payment_dates)[-1])
                future_paid_date = TrackerBillingStatusModel(main_connection).filter(
                    org_id=org_id,
                    payment_status=False,
                    payment_date__gt=last_payment,
                ).one()
                if not future_paid_date:
                    TrackerBillingStatusModel(main_connection).create(
                        org_id=org_id,
                        payment_date=last_payment + relativedelta(months=1),
                    )


def _get_tracker_billing_data(for_date):
    """
    Выводит данные для отправки YT для каждой организации
    у которой сегодня наступил день оплаты и нет условий для которых
    использование трекера бесплатно
    """
    orgs_billing_data = []
    billing_status_ids = collections.defaultdict(list)
    orgs_by_shard = collections.defaultdict(lambda: collections.defaultdict(list))
    with get_meta_connection() as meta_connection:
        for shard in get_shard_numbers():
            with get_main_connection(shard=shard, for_write=True) as main_connection:
                orgs_to_process = TrackerBillingStatusModel(main_connection).filter(
                    payment_date__lte=for_date,
                    payment_status=False,
                )
                orgs_metadata = _get_orgs_metadata(
                    main_connection,
                    [row['org_id'] for row in orgs_to_process]
                )
                for org_billing_data in orgs_to_process:
                    billing_status_ids[shard].append(org_billing_data['id'])
                    org_id = org_billing_data['org_id']
                    payment_date = org_billing_data['payment_date']
                    orgs_by_shard[shard][org_id].append(payment_date.strftime('%Y-%m-%d'))
                    org_metadata = orgs_metadata.get(org_id)

                    try:
                        tracker_service = OrganizationServiceModel(main_connection).get_by_slug(
                            org_id, 'tracker',
                            fields=['*']
                        )
                    except (AuthorizationError, ServiceNotFound):
                        pass
                    else:
                        if tracker_service['trial_expires'] and tracker_service['trial_expires'] >= payment_date:
                            # сервис еще в триале, не будем отгружать данные
                            continue

                    base_licenses = calculate_base_licenses(
                        main_connection=main_connection,
                        org_id=org_id,
                        for_date=payment_date - relativedelta(months=1),
                    )
                    max_period_licenses_count = calculate_licensed_users(
                        meta_connection=meta_connection,
                        org_id=org_id,
                        to_date=payment_date,
                        base_licenses=base_licenses,
                    )
                    if (
                        max_period_licenses_count <= settings.TRACKER_FREE_LICENSES or
                        (org_metadata and org_metadata['organization_type'] in organization_type.free_types) or
                        is_feature_enabled(meta_connection, org_id, DISABLE_BILLING_TRACKER)
                    ):
                        log.info('Skipping org_id: {}, free type or free licenses count'.format(org_id))
                        continue

                    if org_metadata and org_metadata['billing_info']:
                        client_id = org_metadata['billing_info'].get('client_id')
                    else:
                        client_id = org_billing_data['client_id']

                    if not client_id:
                        # пишем в таблицу ошибок
                        TrackerBillingErrorModel(main_connection).create(
                            org_id=org_id,
                            payment_date=payment_date,
                            error='no_client_id',
                        )

                        # если почему-то не нашли client_id для организации, запишем её в лог и кинем ошибку
                        log.trace().error('Organization {} has no information about billing client'.format(org_id))
                        continue

                    promocode_id = get_promocode_id(
                        main_connection=main_connection,
                        org_id=org_id,
                        to_date=payment_date,
                    )

                    org_pricing_data = get_org_pricing_data(
                        meta_connection=meta_connection,
                        licenses_count=max_period_licenses_count,
                        promocode_id=promocode_id,
                    )

                    rows_for_billing = populate_pricing_data(
                        org_id=org_id,
                        org_pricing_data=org_pricing_data,
                        client_id=client_id,
                        for_date=for_date,
                        promocode_id=promocode_id,

                    )
                    orgs_billing_data.extend(rows_for_billing)
    return orgs_billing_data, billing_status_ids, orgs_by_shard


def _get_yesterday_date_string():
    return (utcnow().date() - datetime.timedelta(days=1)).strftime('%Y-%m-%d')


@retry(stop_max_attempt_number=3, wait_incrementing_increment=50, retry_on_exception=lambda x: isinstance(x, YtError))
def _save_billing_data_to_yt_cluster_from_data(cluster, yt_client, data_to_write, for_date):
    """
    Сохраняет данные потребленных продуктов со всех шардов в один кластер YT.
    В YT создание таблицы и запись данных производится транзакционно

    Args:
        cluster (str) - строковое название кластера YT
        yt_client (YtClient) - инстанс клиента YT
        for_date (str) - дата для записи
        data_to_write (list) - данные для записи
    """
    schema = [
        {'name': 'client_id', 'type': 'int64'},
        {'name': 'product_id', 'type': 'int64'},
        {'name': 'quantity', 'type': 'int64'},
        {'name': 'date', 'type': 'string'},
        {'name': 'org_id', 'type': 'int64'},
        {'name': 'service', 'type': 'string'},
        {'name': 'promocode_id', 'type': 'string'},
    ]
    table_path = _get_table_path(for_date=for_date)

    with log.fields(yt_cluster=cluster):
        # данные целиком записываются в YT-транзакции для всех шардов
        with yt_client.Transaction():
            yt_utils.create_table(
                table=table_path,
                schema=schema,
                client=yt_client,
            )
            for batch in _split_in_batches(data_to_write, BATCH_SIZE):
                yt_utils.append_rows_to_table(
                    table=table_path,
                    rows_data=batch,
                    client=yt_client
                )
                log.info('Billing data has been saved to the YT from one shard')
        log.info('Billing data has been saved to the YT cluster')


@retry(stop_max_attempt_number=3, wait_incrementing_increment=50, retry_on_exception=lambda x: isinstance(x, YtError))
def _save_billing_data_to_yt_cluster(cluster, yt_client, for_date):
    """
    Сохраняет данные потребленных продуктов со всех шардов в один кластер YT.
    В YT создание таблицы и запись данных производится транзакционно

    Args:
        cluster (str) - строковое название кластера YT
        yt_client (YtClient) - инстанс клиента YT
        for_date (str) - дата в виде строки, для которой выгружаем данные
    """
    schema = [
        {'name': 'client_id', 'type': 'int64'},
        {'name': 'product_id', 'type': 'int64'},
        {'name': 'quantity', 'type': 'int64'},
        {'name': 'date', 'type': 'string'},
        {'name': 'org_id', 'type': 'int64'},
        {'name': 'service', 'type': 'string'},
        {'name': 'promocode_id', 'type': 'string'},
    ]
    table_path = _get_table_path(for_date=for_date)
    filter_data = {'for_date': for_date}

    with log.fields(yt_cluster=cluster):
        # данные целиком записываются в YT-транзакции для всех шардов
        with yt_client.Transaction():
            yt_utils.create_table(
                table=table_path,
                schema=schema,
                client=yt_client,
            )

            for shard in get_shard_numbers():
                # для каждого шарда считаем данные для Биллинга
                log.info('Saving billing data to the YT from one shard')
                with get_main_connection(shard=shard) as main_connection:
                    total_rows = OrganizationBillingConsumedInfoModel(main_connection).count(
                        filter_data=filter_data
                    )
                    for i in range(0, total_rows, BATCH_SIZE):
                        consumed_products_info = OrganizationBillingConsumedInfoModel(main_connection).find(
                            fields=[
                                'org_id',
                                'for_date',
                                'service',
                                'total_users_count',
                                'organization_billing_info.client_id',
                                'promocode_id',
                                'organization_type',
                            ],
                            filter_data=filter_data,
                            limit=BATCH_SIZE,
                            skip=i,
                        )
                        yt_utils.append_rows_to_table(
                            table=table_path,
                            rows_data=_prepare_consumed_products(consumed_products_info),
                            client=yt_client
                        )
                log.info('Billing data has been saved to the YT from one shard')
        log.info('Billing data has been saved to the YT cluster')


def _prepare_consumed_products(consumed_products_info):
    """
    Подготавливает данные модели OrganizationBillingConsumedInfoModel для отправки в YT
    """
    consumed_products = []
    for info in consumed_products_info:
        log_fields = {
            'org_id': info['org_id'],
            'promocode_id': info['promocode_id'],
            'organization_type': info['organization_type'],
        }
        with log.fields(**log_fields):
            if info['service'] != 'tracker':
                log.info('Don\'t save billing info for {}'.format(info['service']))
                continue
            product_id = get_price_and_product_id_for_service(
                info['total_users_count'],
                info['service'],
                promocode_id=info['promocode_id'],
            )['product_id']
            with get_meta_connection() as meta_connection:
                if info['organization_type'] in organization_type.free_types \
                        or is_feature_enabled(meta_connection, info['org_id'], DISABLE_BILLING_TRACKER) \
                        or product_id == app.config['PRODUCT_ID_FREE']:
                    # не откидываем данные в YT для таких организаций
                    log.info('Organization has free type or with free promocode')
                    continue

            if not info['organization_billing_info'] or not info['organization_billing_info'].get('client_id'):
                # если почему-то не нашли client_id для организации, запишем её в лог и кинем ошибку
                log.trace().error('Organization has no information about billing client')
                raise ValueError('Can not prepare consumed products')

            consumed_products.append({
                'client_id': info['organization_billing_info']['client_id'],
                'product_id': product_id,
                'quantity': info['total_users_count'],
                'date': info['for_date'].strftime('%Y-%m-%d'),
                'promocode_id': info['promocode_id'],
                'service': info['service'],
                'org_id': info['org_id'],
            })
    return consumed_products


def get_yt_clusters_without_billing_data(for_date=None, check_empty_tables=False):
    """
    Возвращает список кластеров YT, в которых нет таблицы с биллинговыми данными для даты for_date

    Args:
        for_date (str) - дата в виде строки
        check_empty_tables (bool) - проверять, пустые ли таблицы. Если пустые, считается что их нет
    """
    if for_date is None:
        for_date = _get_yesterday_date_string()

    return yt_utils.get_yt_clusters_without_table(
        table=_get_table_path(for_date),
        check_empty_tables=check_empty_tables,
        yt_clients=list(yt_utils.billing_yt_clients.values()),
    )


def _get_table_path(for_date):
    """
    Возвращает полный путь до YT-таблицы, сформированный из настройки BILLING_YT_TABLES_PATH и даты for_date

    Args:
        for_date (str) - дата в виде строки
    """
    table_path = os.path.join(
        app.config['BILLING_YT_TABLES_PATH'],
        for_date,
    )
    return table_path


def calculate_base_licenses(main_connection, org_id, for_date):
    """
    Возвращает базу количества лицензий (список uid пользователей) - у которых были лицензии в организации
    на определенную дату

    """
    tracker_id = ServiceModel(main_connection).get_by_slug('tracker')['id']
    return OrganizationLicenseConsumedInfoModel(main_connection).filter(
        org_id=org_id,
        for_date=ensure_date(for_date),
        service_id=tracker_id,
    ).scalar('user_id')


def calculate_licensed_users(meta_connection, org_id, to_date, base_licenses=None, from_date=None):
    """
    Вычисляет количество лицензий используемое в организации
    за срок >= to_date - 1 месяц and < to_date
    логика расчета в DIR-9487
    """
    base_licenses = set(base_licenses) if base_licenses else set()

    license_value = len(base_licenses) or 0
    value_for_billing = license_value

    license_not_billing_for = datetime.timedelta(minutes=30)

    to_date = ensure_date(to_date)
    if from_date:
        from_date = ensure_date(from_date)
    else:
        from_date = to_date - relativedelta(months=1)

    log_data = TrackerLicenseLogModel(meta_connection).filter(
        org_id=org_id,
        created_at__gte=from_date,
        created_at__lt=to_date,
    ).order_by('created_at')

    current_users_add_event = {}
    for log_entry in log_data:
        action = log_entry['action']
        uid = log_entry['uid']
        action_date = log_entry['created_at']
        if action == TrackerLicenseLogModel.ADD_ACTION:
            if ensure_date(action_date) == from_date and uid in base_licenses:
                continue

            license_value+=1
            if license_value > value_for_billing:
                value_for_billing+=1
            current_users_add_event[uid] = action_date

        elif action == TrackerLicenseLogModel.DELETE_ACTION:
            license_value -= 1
            add_license_date = None

            if uid in current_users_add_event:
                add_license_date = current_users_add_event[uid]
            else:
                add_license_entry = TrackerLicenseLogModel(meta_connection).filter(
                    org_id=org_id,
                    uid=uid,
                    action=TrackerLicenseLogModel.ADD_ACTION,
                    created_at__lt=action_date,
                ).order_by('-created_at').one()
                if add_license_entry:
                    add_license_date = add_license_entry['created_at']
            if add_license_date and (action_date - add_license_date) <  license_not_billing_for:
                value_for_billing-=1

            if ensure_date(action_date) == from_date and uid in base_licenses:
                base_licenses.discard(uid)
    return value_for_billing
