package ru.yandex.intranet.d.dao.aggregates

import com.fasterxml.jackson.core.type.TypeReference
import com.yandex.ydb.table.query.Params
import com.yandex.ydb.table.result.ResultSetReader
import com.yandex.ydb.table.values.ListValue
import com.yandex.ydb.table.values.PrimitiveValue
import com.yandex.ydb.table.values.StructValue
import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.awaitSingleOrNull
import org.springframework.beans.factory.annotation.Qualifier
import org.springframework.stereotype.Component
import ru.yandex.intranet.d.dao.DaoPagination
import ru.yandex.intranet.d.dao.DaoReader
import ru.yandex.intranet.d.dao.JsonFieldHelper
import ru.yandex.intranet.d.dao.WithTx
import ru.yandex.intranet.d.datasource.impl.YdbQuerySource
import ru.yandex.intranet.d.datasource.model.YdbTxSession
import ru.yandex.intranet.d.model.TenantId
import ru.yandex.intranet.d.model.aggregates.ServiceAggregateKey
import ru.yandex.intranet.d.model.aggregates.ServiceAggregateKeyWithEpoch
import ru.yandex.intranet.d.model.aggregates.ServiceAggregateKeyWithEpochPage
import ru.yandex.intranet.d.model.aggregates.ServiceAggregateUsageModel
import ru.yandex.intranet.d.model.usage.ServiceUsageAmounts
import ru.yandex.intranet.d.util.ObjectMapperHolder

/**
 * Service aggregate usage DAO.
 *
 * Reads and writes [ServiceAggregateUsageModel] rows using YQL queries resolved by [YdbQuerySource]
 * under the `yql.queries.serviceAggregateUsage.*` keys. Every query is executed through
 * `YdbTxSession.executeDataQueryRetryable` within the caller-supplied session.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
class ServiceAggregateUsageDao(private val ydbQuerySource: YdbQuerySource,
                               @Qualifier("ydbJsonObjectMapper") private val objectMapper: ObjectMapperHolder) {
    /** Reads/writes the `exact_amounts` JSON column as [ServiceUsageAmounts]. */
    private val exactFieldHelper: JsonFieldHelper<ServiceUsageAmounts> = JsonFieldHelper(objectMapper,
        object : TypeReference<ServiceUsageAmounts>() {})

    /**
     * Loads the row identified by [id].
     *
     * @return the model, or `null` when no row is returned for the key.
     */
    suspend fun getById(session: YdbTxSession, id: ServiceAggregateKey): ServiceAggregateUsageModel? {
        val query = ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getById")
        val params = toKeyParams(id)
        return DaoReader.toModel(session.executeDataQueryRetryable(query, params).awaitSingle(), this::toModel)
    }

    /**
     * Loads the rows for all [ids].
     *
     * Keys are deduplicated and queried in batches of at most [GET_BY_IDS_CHUNK_SIZE], one query per batch;
     * the results of all batches are concatenated in batch order. Returns an empty list for empty [ids]
     * without touching the database.
     */
    suspend fun getByIds(session: YdbTxSession,
                         ids: Collection<ServiceAggregateKey>): List<ServiceAggregateUsageModel> {
        if (ids.isEmpty()) {
            return listOf()
        }
        // The query text does not depend on the chunk, so resolve it once.
        val query = ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByIds")
        return ids.distinct().chunked(GET_BY_IDS_CHUNK_SIZE).flatMap { chunk ->
            DaoReader.toModels(session.executeDataQueryRetryable(query, toKeyListParams(chunk)).awaitSingle(),
                this::toModel)
        }
    }

    /** Inserts or replaces a single row. */
    suspend fun upsertOneRetryable(session: YdbTxSession, value: ServiceAggregateUsageModel) {
        val query = ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.upsertOne")
        val params = toUpsertOneParams(value)
        session.executeDataQueryRetryable(query, params).awaitSingleOrNull()
    }

    /** Inserts or replaces all [values] with a single query; does nothing for an empty collection. */
    suspend fun upsertManyRetryable(session: YdbTxSession, values: Collection<ServiceAggregateUsageModel>) {
        if (values.isEmpty()) {
            return
        }
        val query = ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.upsertMany")
        val params = toUpsertManyParams(values)
        session.executeDataQueryRetryable(query, params).awaitSingleOrNull()
    }

    /**
     * Loads every row of [serviceId] in [tenantId], reading [perPage] rows per query.
     * Subsequent pages continue after the (provider_id, resource_id) of the last row of the previous page.
     */
    suspend fun getByService(session: YdbTxSession, tenantId: TenantId,
                             serviceId: Long, perPage: Long): List<ServiceAggregateUsageModel> {
        return DaoPagination.getAllPages(session,
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServiceFirstPage") },
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServiceNextPage") },
            { limit -> toGetByServiceFirstPageParams(tenantId, serviceId, limit) },
            { limit, lastOnPreviousPage -> toGetByServiceNextPageParams(tenantId, serviceId, limit,
                lastOnPreviousPage.key.providerId, lastOnPreviousPage.key.resourceId)},
            this::toModel,
            perPage
        )
    }

    /**
     * Loads every row of [serviceId] and [providerId] in [tenantId], reading [perPage] rows per query.
     * Subsequent pages continue after the resource_id of the last row of the previous page.
     */
    suspend fun getByServiceAndProvider(session: YdbTxSession, tenantId: TenantId,
                                        serviceId: Long, providerId: String, perPage: Long): List<ServiceAggregateUsageModel> {
        return DaoPagination.getAllPages(session,
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServiceAndProviderFirstPage") },
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServiceAndProviderNextPage") },
            { limit -> toGetByServiceAndProviderFirstPageParams(tenantId, serviceId, providerId, limit) },
            { limit, lastOnPreviousPage -> toGetByServiceAndProviderNextPageParams(tenantId, serviceId, providerId,
                limit, lastOnPreviousPage.key.resourceId)},
            this::toModel,
            perPage
        )
    }

    /**
     * Loads every row of the given [serviceIds] in [tenantId], reading [perPage] rows per query.
     *
     * Service ids are deduplicated and sorted ascending. Next-page queries receive only the service ids greater
     * than the last row's service id; once none remain, the predicate below is true and (presumably) DaoPagination
     * switches to the last-page query, which continues from the last row's full key. Returns an empty list for
     * empty [serviceIds] without touching the database.
     */
    suspend fun getByServices(session: YdbTxSession, tenantId: TenantId,
                              serviceIds: Collection<Long>, perPage: Long): List<ServiceAggregateUsageModel> {
        if (serviceIds.isEmpty()) {
            return listOf()
        }
        val sortedServiceIds = serviceIds.distinct().sorted()
        return DaoPagination.getAllPages(session,
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServicesFirstPage") },
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServicesNextPage") },
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServicesLastPage") },
            { limit -> toGetByServicesFirstPageParams(tenantId, sortedServiceIds, limit) },
            { limit, lastOnPreviousPage -> toGetByServicesNextPageParams(tenantId,
                serviceIdsTail(sortedServiceIds, lastOnPreviousPage.key.serviceId), limit,
                lastOnPreviousPage.key.serviceId, lastOnPreviousPage.key.providerId,
                lastOnPreviousPage.key.resourceId)},
            { limit, lastOnPreviousPage -> toGetByServicesLastPageParams(tenantId, limit,
                lastOnPreviousPage.key.serviceId, lastOnPreviousPage.key.providerId,
                lastOnPreviousPage.key.resourceId)},
            { lastOnPreviousPage -> serviceIdsTail(sortedServiceIds, lastOnPreviousPage.key.serviceId).isEmpty() },
            this::toModel,
            perPage
        )
    }

    /**
     * Returns the provider ids that have rows for [serviceId] in [tenantId], read from the `provider_id`
     * column of a grouped query, [perPage] values per query.
     */
    suspend fun getProviderIdsByService(session: YdbTxSession, tenantId: TenantId, serviceId: Long, perPage: Long): List<String> {
        return DaoPagination.getAllPages(session,
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServiceGroupByProviderIdFirstPage") },
            { ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.getByServiceGroupByProviderIdNextPage") },
            { limit -> toGetByServiceGroupByProviderIdFirstPageParams(tenantId, serviceId, limit) },
            { limit, lastProviderId -> toGetByServiceGroupByProviderIdNextPageParams(
                tenantId, serviceId, limit, lastProviderId)},
            fun (reader: ResultSetReader) = reader.getColumn("provider_id").utf8.toString(),
            perPage
        )
    }

    /**
     * Returns the resource ids that have rows for [serviceId] and [providerId] in [tenantId], read from the
     * `resource_id` column of a grouped query, [perPage] values per query.
     */
    suspend fun getResourceIdsByServiceAndProvider(
        session: YdbTxSession, tenantId: TenantId, serviceId: Long, providerId: String, perPage: Long): List<String> {
        return DaoPagination.getAllPages(session,
            { ydbQuerySource.getQuery(
                "yql.queries.serviceAggregateUsage.getByServiceAndProviderGroupByResourceIdFirstPage") },
            { ydbQuerySource.getQuery(
                "yql.queries.serviceAggregateUsage.getByServiceAndProviderGroupByResourceIdNextPage") },
            { limit -> toGetByServiceAndProviderGroupByResourceIdFirstPageParams(
                tenantId, serviceId, providerId, limit) },
            { limit, lastResourceId -> toGetByServiceAndProviderGroupByResourceIdNextPageParams(
                tenantId, serviceId, providerId, limit, lastResourceId)},
            fun (reader: ResultSetReader) = reader.getColumn("resource_id").utf8.toString(),
            perPage
        )
    }

    /**
     * Reads the first page (at most [limit] keys) of keys with epochs for one provider resource, passing
     * [currentEpoch] to the query (presumably to select rows from older epochs — the filter lives in YQL).
     *
     * @return the page plus the transaction id; `nextFrom` is set only when the page is full.
     */
    suspend fun getKeysForOlderEpochsFirstPage(session: YdbTxSession,
                                               tenantId: TenantId,
                                               providerId: String,
                                               resourceId: String,
                                               currentEpoch: Long,
                                               limit: Long): WithTx<ServiceAggregateKeyWithEpochPage> {
        val params = toGetKeysForOlderEpochsFirstPageParams(tenantId, providerId, resourceId, currentEpoch, limit)
        return readKeysWithEpochPage(session,
            "yql.queries.serviceAggregateUsage.getKeysForOlderEpochsFirstPage", params, limit)
    }

    /**
     * Reads the next page (at most [limit] keys) continuing after [from], as returned in `nextFrom` of the
     * previous page.
     */
    suspend fun getKeysForOlderEpochsNextPage(
        session: YdbTxSession, from: ServiceAggregateKeyWithEpoch, limit: Long): WithTx<ServiceAggregateKeyWithEpochPage> {
        val params = toGetKeysForOlderEpochsNextPageParams(from, limit)
        return readKeysWithEpochPage(session,
            "yql.queries.serviceAggregateUsage.getKeysForOlderEpochsNextPage", params, limit)
    }

    /** Deletes the row identified by [id]. */
    suspend fun deleteByIdRetryable(session: YdbTxSession, id: ServiceAggregateKey) {
        val query = ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.deleteById")
        val params = toKeyParams(id)
        session.executeDataQueryRetryable(query, params).awaitSingleOrNull()
    }

    /** Deletes the rows for all [ids] with a single query; does nothing for an empty collection. */
    suspend fun deleteByIdsRetryable(session: YdbTxSession, ids: Collection<ServiceAggregateKey>) {
        if (ids.isEmpty()) {
            return
        }
        val query = ydbQuerySource.getQuery("yql.queries.serviceAggregateUsage.deleteByIds")
        val params = toKeyListParams(ids)
        session.executeDataQueryRetryable(query, params).awaitSingleOrNull()
    }

    /**
     * Multi-resource variant of [getKeysForOlderEpochsFirstPage]: reads the first page (at most [limit] keys)
     * for all [resourceIds] of [providerId].
     */
    suspend fun getKeysForOlderEpochsMultiResourceFirstPage(session: YdbTxSession,
                                                            tenantId: TenantId,
                                                            providerId: String,
                                                            resourceIds: Collection<String>,
                                                            currentEpoch: Long,
                                                            limit: Long): WithTx<ServiceAggregateKeyWithEpochPage> {
        val params = toGetKeysForOlderEpochsMultiResourceFirstPageParams(tenantId, providerId, resourceIds,
            currentEpoch, limit)
        return readKeysWithEpochPage(session,
            "yql.queries.serviceAggregateUsage.getKeysForOlderEpochsMultiResourceFirstPage", params, limit)
    }

    /**
     * Multi-resource continuation after [from].
     *
     * The resource ids strictly less than `from.key.resourceId` (sorted descending) are passed to the
     * "next page" query. When no such resource ids remain, the "last page" query is used instead; it receives
     * only the position of [from] and no resource id list.
     */
    suspend fun getKeysForOlderEpochsMultiResourceNextPage(
        session: YdbTxSession,
        from: ServiceAggregateKeyWithEpoch,
        resourceIds: Collection<String>,
        currentEpoch: Long,
        limit: Long
    ): WithTx<ServiceAggregateKeyWithEpochPage> {
        val resourceIdsTail = resourceIds.sortedDescending().filter { it < from.key.resourceId }
        return if (resourceIdsTail.isNotEmpty()) {
            readKeysWithEpochPage(session,
                "yql.queries.serviceAggregateUsage.getKeysForOlderEpochsMultiResourceNextPage",
                toGetKeysForOlderEpochsMultiResourceNextPageParams(from, resourceIdsTail, currentEpoch, limit),
                limit)
        } else {
            readKeysWithEpochPage(session,
                "yql.queries.serviceAggregateUsage.getKeysForOlderEpochsMultiResourceLastPage",
                toGetKeysForOlderEpochsMultiResourceLastPageParams(from, limit),
                limit)
        }
    }

    /**
     * Executes the key-with-epoch query registered under [queryKey] with [params] and wraps the rows into a page.
     *
     * `nextFrom` is the last key of the page when the page is full (`size >= limit`), otherwise `null`.
     * `lastOrNull()` keeps an empty page from throwing when [limit] is not positive.
     */
    private suspend fun readKeysWithEpochPage(session: YdbTxSession, queryKey: String, params: Params,
                                              limit: Long): WithTx<ServiceAggregateKeyWithEpochPage> {
        val query = ydbQuerySource.getQuery(queryKey)
        val page = DaoReader.toModelsWithTx(session.executeDataQueryRetryable(query, params).awaitSingle(),
            this::toKeyWithEpoch)
        val nextFrom = if (page.value.size >= limit) page.value.lastOrNull() else null
        return WithTx(ServiceAggregateKeyWithEpochPage(keys = page.value, nextFrom = nextFrom), page.txId)
    }

    /** Maps a full row; a null `exact_amounts` value is treated as a data invariant violation. */
    private fun toModel(reader: ResultSetReader) = ServiceAggregateUsageModel(
        key = toKey(reader),
        lastUpdate = reader.getColumn("last_update").timestamp,
        epoch = reader.getColumn("epoch").int64,
        exactAmounts = checkNotNull(exactFieldHelper.read(reader.getColumn("exact_amounts"))) {
            "Column exact_amounts must not be null"
        }
    )

    /** Maps the four primary-key columns of a row. */
    private fun toKey(reader: ResultSetReader) = ServiceAggregateKey(
        tenantId = TenantId(reader.getColumn("tenant_id").utf8),
        serviceId = reader.getColumn("service_id").int64,
        providerId = reader.getColumn("provider_id").utf8,
        resourceId = reader.getColumn("resource_id").utf8
    )

    /** Maps the primary-key columns plus the `epoch` column. */
    private fun toKeyWithEpoch(reader: ResultSetReader) = ServiceAggregateKeyWithEpoch(
        key = toKey(reader),
        epoch = reader.getColumn("epoch").int64
    )

    private fun toKeyParams(id: ServiceAggregateKey) = Params.of(
        "\$id", toKeyStruct(id)
    )

    private fun toKeyListParams(ids: Collection<ServiceAggregateKey>) = Params.of(
        "\$ids", ListValue.of(*ids.map { toKeyStruct(it) }.toTypedArray())
    )

    /** Builds the struct with the primary-key columns of [id]. */
    private fun toKeyStruct(id: ServiceAggregateKey) = StructValue.of(mapOf(
        "tenant_id" to PrimitiveValue.utf8(id.tenantId.id),
        "service_id" to PrimitiveValue.int64(id.serviceId),
        "provider_id" to PrimitiveValue.utf8(id.providerId),
        "resource_id" to PrimitiveValue.utf8(id.resourceId)
    ))

    private fun toUpsertOneParams(value: ServiceAggregateUsageModel) = Params.of(
        "\$value", toUpsertStruct(value)
    )

    private fun toUpsertManyParams(values: Collection<ServiceAggregateUsageModel>) = Params.of(
        "\$values", ListValue.of(*values.map { toUpsertStruct(it) }.toTypedArray())
    )

    /** Builds the struct with every column of [value]; `exact_amounts` is serialized as JSON. */
    private fun toUpsertStruct(value: ServiceAggregateUsageModel) = StructValue.of(
        mapOf(
            "tenant_id" to PrimitiveValue.utf8(value.key.tenantId.id),
            "service_id" to PrimitiveValue.int64(value.key.serviceId),
            "provider_id" to PrimitiveValue.utf8(value.key.providerId),
            "resource_id" to PrimitiveValue.utf8(value.key.resourceId),
            "last_update" to PrimitiveValue.timestamp(value.lastUpdate),
            "epoch" to PrimitiveValue.int64(value.epoch),
            "exact_amounts" to exactFieldHelper.write(value.exactAmounts)
        )
    )

    private fun toGetByServiceFirstPageParams(tenantId: TenantId, serviceId: Long, limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_id", PrimitiveValue.int64(serviceId),
            "\$limit", PrimitiveValue.uint64(limit))

    private fun toGetByServiceNextPageParams(tenantId: TenantId, serviceId: Long, limit: Long, fromProviderId: String,
                                             fromResourceId: String) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_id", PrimitiveValue.int64(serviceId),
            "\$limit", PrimitiveValue.uint64(limit),
            "\$from_provider_id", PrimitiveValue.utf8(fromProviderId),
            "\$from_resource_id", PrimitiveValue.utf8(fromResourceId))

    private fun toGetByServicesFirstPageParams(tenantId: TenantId, serviceIds: Collection<Long>, limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_ids", ListValue.of(*serviceIds.map { PrimitiveValue.int64(it) }.toTypedArray()),
            "\$limit", PrimitiveValue.uint64(limit))

    private fun toGetByServicesNextPageParams(tenantId: TenantId, serviceIds: Collection<Long>, limit: Long,
                                              fromServiceId: Long, fromProviderId: String,
                                              fromResourceId: String) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_ids", ListValue.of(*serviceIds.map { PrimitiveValue.int64(it) }.toTypedArray()),
            "\$limit", PrimitiveValue.uint64(limit),
            "\$from_service_id", PrimitiveValue.int64(fromServiceId),
            "\$from_provider_id", PrimitiveValue.utf8(fromProviderId),
            "\$from_resource_id", PrimitiveValue.utf8(fromResourceId))

    private fun toGetByServicesLastPageParams(tenantId: TenantId, limit: Long, fromServiceId: Long,
                                              fromProviderId: String, fromResourceId: String) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$limit", PrimitiveValue.uint64(limit),
            "\$from_service_id", PrimitiveValue.int64(fromServiceId),
            "\$from_provider_id", PrimitiveValue.utf8(fromProviderId),
            "\$from_resource_id", PrimitiveValue.utf8(fromResourceId))

    private fun toGetByServiceGroupByProviderIdFirstPageParams(tenantId: TenantId, serviceId: Long, limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_id", PrimitiveValue.int64(serviceId),
            "\$limit", PrimitiveValue.uint64(limit))

    private fun toGetByServiceGroupByProviderIdNextPageParams(tenantId: TenantId, serviceId: Long, limit: Long,
                                                              fromProviderId: String) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_id", PrimitiveValue.int64(serviceId),
            "\$limit", PrimitiveValue.uint64(limit),
            "\$from_provider_id", PrimitiveValue.utf8(fromProviderId))

    private fun toGetByServiceAndProviderGroupByResourceIdFirstPageParams(
        tenantId: TenantId, serviceId: Long, providerId: String, limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_id", PrimitiveValue.int64(serviceId),
            "\$provider_id", PrimitiveValue.utf8(providerId),
            "\$limit", PrimitiveValue.uint64(limit))

    private fun toGetByServiceAndProviderGroupByResourceIdNextPageParams(
        tenantId: TenantId, serviceId: Long, providerId: String, limit: Long, fromResourceId: String) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_id", PrimitiveValue.int64(serviceId),
            "\$provider_id", PrimitiveValue.utf8(providerId),
            "\$limit", PrimitiveValue.uint64(limit),
            "\$from_resource_id", PrimitiveValue.utf8(fromResourceId))

    /** Returns the service ids from the ascending [sortedServiceIds] that are strictly greater than [currentServiceId]. */
    private fun serviceIdsTail(sortedServiceIds: List<Long>, currentServiceId: Long) = sortedServiceIds
        .filter { it > currentServiceId }

    private fun toGetKeysForOlderEpochsFirstPageParams(tenantId: TenantId, providerId: String, resourceId: String,
                                                       currentEpoch: Long, limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$provider_id", PrimitiveValue.utf8(providerId),
            "\$resource_id", PrimitiveValue.utf8(resourceId),
            "\$current_epoch", PrimitiveValue.int64(currentEpoch),
            "\$limit", PrimitiveValue.uint64(limit)
        )

    private fun toGetKeysForOlderEpochsNextPageParams(from: ServiceAggregateKeyWithEpoch, limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(from.key.tenantId.id),
            "\$provider_id", PrimitiveValue.utf8(from.key.providerId),
            "\$resource_id", PrimitiveValue.utf8(from.key.resourceId),
            "\$from_epoch", PrimitiveValue.int64(from.epoch),
            "\$from_service_id", PrimitiveValue.int64(from.key.serviceId),
            "\$limit", PrimitiveValue.uint64(limit)
        )

    private fun toGetKeysForOlderEpochsMultiResourceFirstPageParams(tenantId: TenantId, providerId: String,
                                                                    resourceIds: Collection<String>, currentEpoch: Long,
                                                                    limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$provider_id", PrimitiveValue.utf8(providerId),
            "\$resource_ids", ListValue.of(*resourceIds.map { PrimitiveValue.utf8(it) }.toTypedArray()),
            "\$current_epoch", PrimitiveValue.int64(currentEpoch),
            "\$limit", PrimitiveValue.uint64(limit)
        )

    private fun toGetKeysForOlderEpochsMultiResourceNextPageParams(from: ServiceAggregateKeyWithEpoch,
                                                                   resourceIds: Collection<String>, currentEpoch: Long,
                                                                   limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(from.key.tenantId.id),
            "\$provider_id", PrimitiveValue.utf8(from.key.providerId),
            "\$resource_ids", ListValue.of(*resourceIds.map { PrimitiveValue.utf8(it) }.toTypedArray()),
            "\$from_resource_id", PrimitiveValue.utf8(from.key.resourceId),
            "\$current_epoch", PrimitiveValue.int64(currentEpoch),
            "\$from_epoch", PrimitiveValue.int64(from.epoch),
            "\$from_service_id", PrimitiveValue.int64(from.key.serviceId),
            "\$limit", PrimitiveValue.uint64(limit)
        )

    private fun toGetKeysForOlderEpochsMultiResourceLastPageParams(from: ServiceAggregateKeyWithEpoch,
                                                                   limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(from.key.tenantId.id),
            "\$provider_id", PrimitiveValue.utf8(from.key.providerId),
            "\$from_resource_id", PrimitiveValue.utf8(from.key.resourceId),
            "\$from_epoch", PrimitiveValue.int64(from.epoch),
            "\$from_service_id", PrimitiveValue.int64(from.key.serviceId),
            "\$limit", PrimitiveValue.uint64(limit)
        )

    private fun toGetByServiceAndProviderFirstPageParams(tenantId: TenantId, serviceId: Long, providerId: String,
                                                         limit: Long) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_id", PrimitiveValue.int64(serviceId),
            "\$provider_id", PrimitiveValue.utf8(providerId),
            "\$limit", PrimitiveValue.uint64(limit))

    private fun toGetByServiceAndProviderNextPageParams(tenantId: TenantId, serviceId: Long, providerId: String,
                                                        limit: Long, fromResourceId: String) = Params
        .of("\$tenant_id", PrimitiveValue.utf8(tenantId.id),
            "\$service_id", PrimitiveValue.int64(serviceId),
            "\$limit", PrimitiveValue.uint64(limit),
            "\$provider_id", PrimitiveValue.utf8(providerId),
            "\$from_resource_id", PrimitiveValue.utf8(fromResourceId))

    companion object {
        /**
         * Maximum number of keys per getByIds query (presumably chosen to stay within the YDB data query
         * result-set row limit — TODO confirm).
         */
        private const val GET_BY_IDS_CHUNK_SIZE = 1000
    }

}
