package ru.yandex.direct.grid.processing.service.statistics.service

import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Service
import ru.yandex.direct.core.entity.metrika.utils.AttributionConverter
import ru.yandex.direct.currency.CurrencyCode
import ru.yandex.direct.dbutil.model.ClientId
import ru.yandex.direct.grid.model.campaign.GdCampaignAttributionModel
import ru.yandex.direct.grid.model.campaign.GdCampaignAttributionModel.toSource
import ru.yandex.direct.grid.processing.context.container.GridGraphQLContext
import ru.yandex.direct.grid.processing.model.client.GdClient
import ru.yandex.direct.grid.processing.model.statistics.GdCampaignStatisticsPeriod
import ru.yandex.direct.grid.processing.model.statistics.GdCampaignStatisticsValueHolder
import ru.yandex.direct.grid.processing.model.statistics.metrika.GdMetrikaStatisticsColumnValues
import ru.yandex.direct.grid.processing.model.statistics.metrika.GdMetrikaStatisticsContainer
import ru.yandex.direct.grid.processing.model.statistics.metrika.GdMetrikaStatisticsDimension
import ru.yandex.direct.grid.processing.model.statistics.metrika.GdMetrikaStatisticsItem
import ru.yandex.direct.grid.processing.model.statistics.metrika.GdMetrikaStatisticsMeta
import ru.yandex.direct.grid.processing.model.statistics.metrika.GdMetrikaStatisticsPayload
import ru.yandex.direct.grid.processing.service.statistics.validation.EndToEndAnalyticsValidationService
import ru.yandex.direct.grid.processing.util.StatHelper.getDayPeriod
import ru.yandex.direct.metrika.client.MetrikaClient
import ru.yandex.direct.metrika.client.internal.Attribution
import ru.yandex.direct.metrika.client.internal.Dimension
import ru.yandex.direct.metrika.client.internal.MetrikaByTimeStatisticsParams
import ru.yandex.direct.metrika.client.internal.MetrikaSourcesParams
import ru.yandex.direct.metrika.client.model.response.Counter
import ru.yandex.direct.metrika.client.model.response.CounterGoal
import ru.yandex.direct.metrika.client.model.response.statistics.StatisticsResponseRow
import java.math.BigDecimal
import java.time.Instant
import java.time.LocalDate

/**
 * Metrika counter feature required for end-to-end (expense) statistics. When a counter lacks it,
 * `EndToEndAnalyticsService.getEndToEndStatistics` logs a warning and returns an empty rowset.
 */
const val EXPENSES_FEATURE_NAME = "expenses"

/** Metrika counter feature whose presence makes by-time statistics requests include revenue. */
private const val ECOMMERCE_FEATURE_NAME = "ecommerce"

@Service
class EndToEndAnalyticsService
@Autowired constructor(
    private val metrikaClient: MetrikaClient,
    private val endToEndAnalyticsValidationService: EndToEndAnalyticsValidationService,
) {
    /**
     * Fetches end-to-end analytics statistics from Metrika.
     *
     * The input is validated first. If the counter does not have the [EXPENSES_FEATURE_NAME] feature,
     * a warning is logged and an empty rowset (without meta) is returned. Otherwise the rows are
     * filtered by `input.filter.advChannelIds` (when it is set). The response currency is taken from
     * the Metrika response.
     */
    fun getEndToEndStatistics(
        input: GdMetrikaStatisticsContainer,
        context: GridGraphQLContext,
        client: GdClient
    ): GdMetrikaStatisticsPayload {
        validateStatisticsQuery(input, context)
        val counter = metrikaClient.getCounter(input.filter.counterId)
        if (!counter.features.contains(EXPENSES_FEATURE_NAME)) {
            logger.warn("Cannot load expense statistics for counter ${counter.id}: " +
                "feature $EXPENSES_FEATURE_NAME is not enabled.")
            return GdMetrikaStatisticsPayload().withRowset(emptyList())
        }
        val params = getMetrikaByTimeStatisticsParams(input, client, counter)
        val endToEndStatistics = metrikaClient.getEndToEndStatistics(params)
        val currency = CurrencyCode.parse(endToEndStatistics.currencyCode)
        val endToEndStatisticsRows = endToEndStatistics.rowset
            .filter { checkRowByAdvChannelIds(it, input.filter.advChannelIds) }
        return GdMetrikaStatisticsPayload()
            .withMeta(GdMetrikaStatisticsMeta().withCurrency(currency))
            .withRowset(endToEndStatisticsRows.map { convertToGdStatsItem(it) })
    }

    /**
     * Fetches statistics by advertising systems and traffic sources from Metrika.
     *
     * Requested row ids are:
     * - the available sources (category + channel);
     * - organic search;
     * - both Yandex.Direct channels;
     * - an empty id, which stands for the total.
     *
     * The resulting rowset contains:
     * - the returned source rows, excluding the individual Direct channel rows and the total rows;
     * - one aggregated Yandex.Direct row per period;
     * - one "other sources" row per period, computed as the total minus the sum of all source rows.
     *
     * Everything is then filtered by `input.filter.advChannelIds` when it is set.
     */
    fun getMetrikaMarketingStatistics(
        input: GdMetrikaStatisticsContainer,
        context: GridGraphQLContext,
        client: GdClient
    ): GdMetrikaStatisticsPayload {
        validateStatisticsQuery(input, context)
        val sources = metrikaClient.getAvailableSources(getMetrikaSourcesParams(input))
        val rowIds = sources.items
            .asSequence()
            .map { listOf(it.category.id, it.channel.id) }
            .plusElement(listOf(ORGANIC_CATEGORY.id))
            .plusElement(listOf(AD_CATEGORY.id, DIRECT_CHANNEL.id)) // explicitly add Yandex.Direct and
            .plusElement(listOf(AD_CATEGORY.id, DIRECT_UNDETERMINED_CHANNEL.id)) // Yandex.Direct: Undetermined
            .plusElement(emptyList()) // for total statistics
            .toSet() // to avoid duplicates
        val params = getMetrikaByTimeStatisticsParams(input, client, rowIds = rowIds.toList())
        val trafficSourceStatistics = metrikaClient.getTrafficSourceStatistics(params)
        val yandexDirectStatisticsRows = extractYandexDirectStatistics(trafficSourceStatistics.rowset)
        val otherTrafficStatisticsRows = extractOtherTrafficStatistics(trafficSourceStatistics.rowset)
        val currency = CurrencyCode.parse(trafficSourceStatistics.currencyCode)
        val marketingStatisticsRows = trafficSourceStatistics.rowset
            .asSequence()
            .filter { !checkRowByAdvChannelIds(it, DIRECT_CHANNELS_IDS) } // remove Direct statistics
            .filter { !checkRowHasTotalStats(it) } // remove total statistics
            .plus(yandexDirectStatisticsRows)
            .plus(otherTrafficStatisticsRows)
            .filter { checkRowByAdvChannelIds(it, input.filter.advChannelIds) }
            .toList()
        return GdMetrikaStatisticsPayload()
            .withMeta(GdMetrikaStatisticsMeta().withCurrency(currency))
            .withRowset(marketingStatisticsRows.map { convertToGdStatsItem(it) })
    }

    /**
     * Merges both Yandex.Direct channels ("Директ" and "Директ: Не определено") into a single row per period.
     * The merged row uses the [AD_CATEGORY] / [DIRECT_CHANNEL] dimensions.
     */
    private fun extractYandexDirectStatistics(
        rows: Collection<StatisticsResponseRow>
    ): Collection<StatisticsResponseRow> {
        return rows
            .filter { checkRowByAdvChannelIds(it, DIRECT_CHANNELS_IDS) }
            .groupBy { it.period }
            .map { (period, stats) -> sumMetrics(listOf(AD_CATEGORY, DIRECT_CHANNEL), period, stats) }
    }

    /**
     * Builds one [OTHER_CATEGORY] row per period.
     *
     * Each metric equals the period's total row (empty dimensions) minus the sum of all non-total
     * rows of that period. A metric is null when the period has no total value for it.
     */
    private fun extractOtherTrafficStatistics(
        rows: Collection<StatisticsResponseRow>
    ): Collection<StatisticsResponseRow> {
        return rows
            .groupBy { it.period }
            .map { (period, stats) ->
                val (totalRows, sourceRows) = stats.partition { checkRowHasTotalStats(it) }
                val sumTraffic = sumMetrics(dimensions = emptyList(), rows = sourceRows)
                val totalStats = sumMetrics(dimensions = emptyList(), rows = totalRows)
                val otherClicks = totalStats.clicks?.minus(sumTraffic.clicks ?: 0L)
                val otherGoalVisits = totalStats.goalVisits?.minus(sumTraffic.goalVisits ?: 0L)
                StatisticsResponseRow(
                    dimensions = listOf(OTHER_CATEGORY),
                    period = period,
                    clicks = otherClicks,
                    conversionRate = calculateConversionRate(otherClicks, otherGoalVisits),
                    goalVisits = otherGoalVisits,
                    expenses = totalStats.expenses?.minus(sumTraffic.expenses ?: BigDecimal.ZERO),
                    revenue = totalStats.revenue?.minus(sumTraffic.revenue ?: BigDecimal.ZERO)
                )
            }
    }

    /**
     * Sums clicks, goal visits, expenses and revenue over [rows] into a single row with the given
     * [dimensions] and [period].
     *
     * A summed metric is null only when it is null in every row (or [rows] is empty). The conversion
     * rate is recomputed from the summed clicks and goal visits.
     */
    private fun sumMetrics(
        dimensions: List<Dimension>,
        period: String? = null,
        rows: Collection<StatisticsResponseRow>
    ): StatisticsResponseRow {
        val clicks = rows.map { it.clicks }.sumNotNull()
        val goalVisits = rows.map { it.goalVisits }.sumNotNull()
        return StatisticsResponseRow(
            dimensions = dimensions,
            period = period,
            clicks = clicks,
            conversionRate = calculateConversionRate(clicks, goalVisits),
            goalVisits = goalVisits,
            expenses = rows.map { it.expenses }.sumNotNull(),
            revenue = rows.map { it.revenue }.sumNotNull()
        )
    }

    /** Sum of the non-null elements, or null if there are none. */
    private fun Iterable<Long?>.sumNotNull(): Long? {
        var sum: Long? = null
        for (num in this) {
            sum = sum?.plus(num ?: 0) ?: num
        }
        return sum
    }

    /** Sum of the non-null elements, or null if there are none. */
    private fun Iterable<BigDecimal?>.sumNotNull(): BigDecimal? {
        var sum: BigDecimal? = null
        for (num in this) {
            sum = sum?.plus(num ?: BigDecimal.ZERO) ?: num
        }
        return sum
    }

    /**
     * Conversion rate in percent (`goalVisits * 100 / clicks`).
     *
     * Returns null when either value is missing or when [clicks] is zero. Floating-point division
     * by zero would otherwise yield NaN/Infinity.
     */
    private fun calculateConversionRate(clicks: Long?, goalVisits: Long?): Double? =
        if (clicks != null && goalVisits != null && clicks != 0L) goalVisits * 100.0 / clicks else null

    /**
     * Validates the input and checks that the queried client has rights for the requested counter.
     * Both checks are delegated to [EndToEndAnalyticsValidationService].
     */
    private fun validateStatisticsQuery(
        input: GdMetrikaStatisticsContainer,
        context: GridGraphQLContext
    ) {
        endToEndAnalyticsValidationService.validateInput(input)
        val clientId = ClientId.fromLong(context.queriedClient.id)
        endToEndAnalyticsValidationService.validateClientRightForCounter(input, clientId)
    }

    /** Builds the request for available traffic sources (limited to [SOURCES_LIMIT] items). */
    private fun getMetrikaSourcesParams(input: GdMetrikaStatisticsContainer): MetrikaSourcesParams {
        val (dateFrom, dateTo) = input.filter.period.extractDates()
        return MetrikaSourcesParams(
            counterId = input.filter.counterId,
            attribution = input.attributionModel.toMetrikaAttribution(),
            categories = CATEGORIES_FOR_SEARCH,
            channels = CHANNELS_FOR_SEARCH,
            dateFrom = dateFrom,
            dateTo = dateTo,
            limit = SOURCES_LIMIT
        )
    }

    /**
     * Builds by-time statistics request parameters.
     *
     * @param counter an already-loaded counter; when null, the counter is fetched from Metrika
     * @param rowIds explicit row ids to request, or null
     *
     * If the requested goal is missing or hidden on the counter, goal data is skipped and no goal id
     * is sent. The conversion rate is requested only when no goal id is sent. Revenue is requested only
     * for counters with the [ECOMMERCE_FEATURE_NAME] feature. The currency is omitted for YND_FIXED clients.
     */
    private fun getMetrikaByTimeStatisticsParams(
        input: GdMetrikaStatisticsContainer,
        client: GdClient,
        counter: Counter? = null,
        rowIds: List<List<String>>? = null
    ): MetrikaByTimeStatisticsParams {
        val counterId = input.filter.counterId
        val (dateFrom, dateTo) = input.filter.period.extractDates()
        val attribution = input.attributionModel.toMetrikaAttribution()

        val requestedGoalId: Long? = input.filter.goalId
        // The requested goal is invalid: do not try to fetch statistics for it.
        val skipGoalData = requestedGoalId != null && !isGoalValidForStat(requestedGoalId, counterId)
        val goalId = if (skipGoalData) null else requestedGoalId

        val currencyCode = client.info.workCurrency
        val counterInfo = counter ?: metrikaClient.getCounter(counterId)
        val withRevenue = counterInfo.features.contains(ECOMMERCE_FEATURE_NAME)

        return MetrikaByTimeStatisticsParams(
            chiefLogin = client.chiefLogin,
            counterId = counterId,
            attribution = attribution,
            currencyCode = if (currencyCode != CurrencyCode.YND_FIXED) currencyCode else null,
            goalId = goalId,
            skipGoalData = skipGoalData,
            withRevenue = withRevenue,
            withConversionRate = goalId == null,
            dateFrom = dateFrom,
            dateTo = dateTo,
            rowIds = rowIds
        )
    }

    /**
     * Returns true when [goalId] exists on counter [counterId] and is not hidden.
     * Returns false for a null goal id.
     *
     * NOTE(review): both ids are narrowed with `toInt()`, so values outside the Int range would wrap.
     */
    private fun isGoalValidForStat(
        goalId: Long?,
        counterId: Long
    ): Boolean {
        if (goalId == null) {
            return false
        }
        // Check that the goal exists and is not hidden.
        // This logic is deliberately kept here rather than in core: core is already convoluted,
        // and knowledge about hidden goals is only needed here for now.
        val counterIdInt = counterId.toInt()
        val goals = metrikaClient.getMassCountersGoalsFromMetrika(setOf(counterIdInt))[counterIdInt] ?: emptyList()
        val foundGoal = goals.find { it.id == goalId.toInt() }

        if (foundGoal == null) {
            logger.info("Skip statistics for goal $goalId. It is not found on counter $counterId")
            return false
        }
        if (foundGoal.status == CounterGoal.Status.HIDDEN) {
            logger.info("Skip statistics for goal $goalId. It is hidden")
            return false
        }

        return true
    }

    /**
     * Resolves the period into a `(from, to)` date pair.
     * A preset, when set, takes precedence over explicit dates and is resolved relative to now.
     */
    private fun GdCampaignStatisticsPeriod.extractDates(): Pair<LocalDate, LocalDate> {
        if (this.preset != null) {
            val dayPeriod = getDayPeriod(this.preset, Instant.now(), null)
            return Pair(dayPeriod.left, dayPeriod.right)
        }

        return Pair(from, to)
    }

    /** Converts the GraphQL attribution model into Metrika's attribution via the core model. */
    private fun GdCampaignAttributionModel.toMetrikaAttribution(): Attribution? =
        AttributionConverter.coreToMetrika(toSource(this))

    companion object {
        val CATEGORIES_FOR_SEARCH = listOf("ad", "social", "messenger")
        val CHANNELS_FOR_SEARCH = listOf("recommend.zen_yandex")
        val SOURCES_LIMIT = 8L

        val ORGANIC_CATEGORY = Dimension("organic", "Переходы из поисковых систем")
        val OTHER_CATEGORY = Dimension("internal:other", "Другие источники")
        val AD_CATEGORY = Dimension("ad", "Переходы по рекламе")

        val DIRECT_CHANNEL = Dimension("ad.Яндекс: Директ", "Яндекс: Директ")
        val DIRECT_UNDETERMINED_CHANNEL = Dimension("ad.Яндекс.Директ: Не определено", "Яндекс.Директ: Не определено")
        val DIRECT_CHANNELS_IDS = setOf(DIRECT_CHANNEL.id, DIRECT_UNDETERMINED_CHANNEL.id)

        private val logger: Logger = LoggerFactory.getLogger(EndToEndAnalyticsService::class.java)

        /**
         * Converts a Metrika row into a GraphQL item.
         *
         * The category is the first dimension and the advertising channel is the last one. For
         * single-dimension rows, both are the same dimension. The row must have at least one
         * dimension; `first()` throws NoSuchElementException otherwise.
         *
         * NOTE(review): goalReaches is populated from goalVisits; confirm this is intended.
         */
        fun convertToGdStatsItem(row: StatisticsResponseRow) =
            GdMetrikaStatisticsItem()
                .withCategory(row.dimensions.first().toGdDimension())
                .withAdvChannel(row.dimensions.last().toGdDimension())
                .withPeriod(row.period)
                .withColumnValues(
                    GdMetrikaStatisticsColumnValues()
                        .withClicks(GdCampaignStatisticsValueHolder().withValue(row.clicks))
                        .withGoalVisits(GdCampaignStatisticsValueHolder().withValue(row.goalVisits))
                        .withGoalReaches(GdCampaignStatisticsValueHolder().withValue(row.goalVisits))
                        .withExpenses(GdCampaignStatisticsValueHolder().withValue(row.expenses))
                        .withIncome(GdCampaignStatisticsValueHolder().withValue(row.revenue))
                )

        /** Maps a Metrika dimension to its GraphQL counterpart (id and name only). */
        private fun Dimension.toGdDimension(): GdMetrikaStatisticsDimension =
            GdMetrikaStatisticsDimension().withId(this.id).withName(this.name)

        /** A row with no dimensions carries the total statistics. */
        private fun checkRowHasTotalStats(row: StatisticsResponseRow): Boolean {
            return row.dimensions.isEmpty()
        }

        /**
         * Returns true if [advChannelIds] is null (no filtering). Otherwise, returns true only if the
         * row has dimensions and its last dimension id is in [advChannelIds].
         */
        private fun checkRowByAdvChannelIds(
            row: StatisticsResponseRow,
            advChannelIds: Set<String>?
        ): Boolean = when {
            advChannelIds == null -> true
            row.dimensions.isEmpty() -> false
            else -> row.dimensions.last().id in advChannelIds
        }
    }
}
