package ru.yandex.direct.oneshot.oneshots.video

import okhttp3.OkHttpClient
import okhttp3.Request
import one.util.streamex.StreamEx
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Component
import ru.yandex.direct.canvas.client.CanvasClient
import ru.yandex.direct.canvas.client.model.exception.CanvasClientException
import ru.yandex.direct.canvas.client.model.video.AdditionResponse
import ru.yandex.direct.canvas.client.model.video.VideoUploadResponse
import ru.yandex.direct.core.entity.banner.model.Banner
import ru.yandex.direct.core.entity.banner.model.BannerWithBannerImage
import ru.yandex.direct.core.entity.banner.model.BannerWithCreative
import ru.yandex.direct.core.entity.banner.model.BannerWithSystemFields
import ru.yandex.direct.core.entity.banner.repository.BannerTypedRepository
import ru.yandex.direct.core.entity.banner.service.BannersUpdateOperation
import ru.yandex.direct.core.entity.banner.service.BannersUpdateOperationFactory
import ru.yandex.direct.core.entity.banner.service.moderation.ModerationMode
import ru.yandex.direct.core.entity.client.service.ClientService
import ru.yandex.direct.dbutil.model.ClientId
import ru.yandex.direct.dbutil.sharding.ShardHelper
import ru.yandex.direct.dbutil.sharding.ShardKey
import ru.yandex.direct.model.ModelChanges
import ru.yandex.direct.oneshot.worker.def.Approvers
import ru.yandex.direct.oneshot.worker.def.Multilaunch
import ru.yandex.direct.oneshot.worker.def.PausedStatusOnFail
import ru.yandex.direct.oneshot.worker.def.SimpleOneshot
import ru.yandex.direct.result.MassResult
import ru.yandex.direct.utils.JsonUtils
import ru.yandex.direct.validation.builder.Constraint
import ru.yandex.direct.validation.builder.ItemValidationBuilder
import ru.yandex.direct.validation.constraint.CommonConstraints
import ru.yandex.direct.validation.constraint.NumberConstraints
import ru.yandex.direct.validation.defect.CommonDefects
import ru.yandex.direct.validation.result.Defect
import ru.yandex.direct.validation.result.ValidationResult
import ru.yandex.direct.ytwrapper.client.YtProvider
import ru.yandex.direct.ytwrapper.model.YtCluster
import ru.yandex.direct.ytwrapper.model.YtTable
import ru.yandex.inside.yt.kosher.impl.ytree.`object`.annotation.YTreeField
import ru.yandex.inside.yt.kosher.impl.ytree.`object`.annotation.YTreeObject
import ru.yandex.inside.yt.kosher.tables.YTableEntryTypes
import java.io.IOException
import java.util.concurrent.TimeUnit
import java.util.function.Consumer
import java.util.function.Function


/**
 * Launch parameters of [Add3DVideoAdditionsOneshot].
 *
 * @property ytCluster YT cluster that holds the input table
 * @property tablePath path of the input table on [ytCluster]; validation requires it to exist
 * @property chunkSize number of input rows handled by one `execute` call; validation requires it to be > 0
 */
data class Param(val ytCluster: YtCluster, val tablePath: String, val chunkSize: Int)

/**
 * Progress of [Add3DVideoAdditionsOneshot] between iterations.
 *
 * @property lastRow row index at which the next iteration starts reading
 *   (the exclusive end of the previously processed range); 0 for a fresh run
 */
data class State(val lastRow: Long = 0L)

/**
 * One row of the input YT table; each property is bound to the column named in its [YTreeField] key.
 *
 * @property campaignId campaign of the banner
 * @property bannerId banner whose video addition should be replaced
 * @property imageHash hash of the banner image the 3D video was generated from
 * @property videoLink URL of the generated 3D video (downloaded by the oneshot itself)
 */
@YTreeObject
data class InputTableRow(
    @YTreeField(key = "CampaignId")
    val campaignId: Long,
    @YTreeField(key = "BannerId")
    val bannerId: Long,
    @YTreeField(key = "ImageHash")
    val imageHash: String,
    @YTreeField(key = "VideoLink")
    val videoLink: String
)

/**
 * In-memory copy of an [InputTableRow] used while processing a chunk.
 *
 * @property campaignId campaign of the banner
 * @property bannerId banner whose video addition should be replaced
 * @property imageHash hash of the image the 3D video was generated from
 * @property videoLink URL of the generated 3D video
 */
data class BannerInfo(val campaignId: Long, val bannerId: Long, val imageHash: String, val videoLink: String)

/**
 * Oneshot that replaces a banner's default video addition with a 3D video generated from the banner image.
 *
 * The input is a YT table whose rows are described by [InputTableRow]. The table is processed in chunks of
 * [Param.chunkSize] rows, and [State.lastRow] remembers where the next chunk starts. A banner is updated only
 * if it still has the default video addition (no creative) and still shows the image the 3D video was
 * generated for.
 *
 * Details: https://st.yandex-team.ru/DIRECT-151833
 */
@Component
@Multilaunch
@PausedStatusOnFail
@Approvers("buhter", "ssdmitriev")
class Add3DVideoAdditionsOneshot @Autowired constructor(
    private val ytProvider: YtProvider,
    private val bannerTypedRepository: BannerTypedRepository,
    private val shardHelper: ShardHelper,
    private val canvasClient: CanvasClient,
    private val bannersUpdateOperationFactory: BannersUpdateOperationFactory,
    private val clientService: ClientService
) : SimpleOneshot<Param, State?> {

    // Used to download the generated videos ourselves (see uploadVideoByUrl).
    private val okHttpClient = OkHttpClient.Builder()
        .connectTimeout(HTTP_TIMEOUT, TimeUnit.SECONDS)
        .readTimeout(HTTP_TIMEOUT, TimeUnit.SECONDS)
        .build()

    /**
     * Validates the launch parameters: the cluster must be set, the input table must exist on that cluster,
     * and the chunk size must be positive.
     */
    override fun validate(inputData: Param): ValidationResult<Param, Defect<*>>? {
        val vb = ItemValidationBuilder.of(inputData, Defect::class.java)

        vb.item(inputData.ytCluster, "ytCluster")
            .check(CommonConstraints.notNull())

        // the table existence check below needs a valid cluster
        if (vb.result.hasAnyErrors()) return vb.result

        vb.item(inputData.tablePath, "tablePath")
            .check(CommonConstraints.notNull())
            .check(
                Constraint.fromPredicate(
                    { tableName -> ytProvider.getOperator(inputData.ytCluster).exists(YtTable(tableName)) },
                    CommonDefects.objectNotFound()
                )
            )

        vb.item(inputData.chunkSize, "chunkSize")
            .check(CommonConstraints.notNull())
            .check(NumberConstraints.greaterThan(0))

        return vb.result
    }

    /**
     * Processes input rows `[prevState.lastRow, prevState.lastRow + chunkSize)`.
     *
     * @return the state for the next iteration, or `null` when the table is exhausted
     */
    override fun execute(inputData: Param, prevState: State?): State? {
        val startRow = prevState?.lastRow ?: 0L
        val lastRow = startRow + inputData.chunkSize
        logger.info("Start from row=$startRow, to row=$lastRow (excluded)")

        val banners = readInputTable(inputData, startRow, lastRow)
        if (banners.isEmpty()) {
            logger.info("Got empty banners, let's finish")
            return null
        }

        logger.info("banners from yt table: ${JsonUtils.toJson(banners)}")

        val groupedBannersByClientId = getBannerWithDefaultVideoAdditionGroupedByClientId(banners)
        logger.info("banners with default videoAdditions: ${JsonUtils.toJson(groupedBannersByClientId)}")

        groupedBannersByClientId.forEach { (clientId, clientBanners) ->
            processClientBanners(clientId, clientBanners)
        }

        // a short chunk means the end of the table has been reached
        if (banners.size < inputData.chunkSize) {
            logger.info("Last iteration finished")
            return null
        }
        return State(lastRow)
    }

    /**
     * Creates 3D video additions for the banners of one client and attaches them to the banners.
     * Validation errors of the update are logged, not thrown.
     */
    private fun processClientBanners(clientId: Long, clientBanners: List<BannerInfo>) {
        val creativeIdByBannerId = createBannersVideoAddition(clientId, clientBanners)
        logger.info("Created videoAdditions: ${JsonUtils.toJson(creativeIdByBannerId)}")

        if (creativeIdByBannerId.isEmpty()) {
            // nothing to attach: every upload or creative creation failed for this client
            logger.info("No videoAdditions created for clientId=$clientId, skip banners update")
            return
        }

        val result = createUpdateOperation(clientId, creativeIdByBannerId).prepareAndApply()
        val validationResult = result.validationResult
        if (validationResult != null && validationResult.hasAnyErrors()) {
            logger.error("banners validationErrors: ${validationResult.flattenErrors()}")
        }

        loggingUpdatedBanners(result, clientBanners)
    }

    /** Logs the banners that were successfully updated and the campaigns they belong to. */
    private fun loggingUpdatedBanners(result: MassResult<Long>, clientBanners: List<BannerInfo>) {
        val successfullyUpdatedIds = getSuccessfullyUpdatedIds(result)

        val updatedBanners = clientBanners
            .filter { successfullyUpdatedIds.contains(it.bannerId) }
        logger.info("updated banners: ${JsonUtils.toJson(updatedBanners)}")

        val campaignIds = updatedBanners
            .map { it.campaignId }
            .toSet()
        logger.info("campaignIds of updated banners: $campaignIds")
    }

    /**
     * Builds a partial update operation that sets the creative id of each banner on behalf of the client's chief.
     *
     * @throws IllegalStateException if the client cannot be found
     */
    private fun createUpdateOperation(
        clientId: Long,
        creativeIdByBannerId: Map<Long, Long>
    ): BannersUpdateOperation<BannerWithSystemFields> {
        val typedClientId = ClientId.fromLong(clientId)
        val client = checkNotNull(clientService.getClient(typedClientId)) { "Client not found: clientId=$clientId" }

        return bannersUpdateOperationFactory
            .createPartialUpdateOperation(
                ModerationMode.DEFAULT,
                getModelChanges(creativeIdByBannerId),
                client.chiefUid, typedClientId
            )
    }

    /**
     * Uploads the 3D video of each banner to Canvas and creates a video-addition creative from it.
     * Banners whose download, upload or creative creation failed are left out.
     *
     * @return creative id by banner id
     */
    private fun createBannersVideoAddition(clientId: Long, banners: List<BannerInfo>): Map<Long, Long> {
        return banners
            // a banner listed twice in the input must not get two uploads / creatives
            .distinctBy { it.bannerId }
            .mapNotNull { banner ->
                uploadVideoByUrl(clientId, banner.videoLink)
                    ?.let { video -> createCreative(clientId, video.presetId, video.id) }
                    ?.let { creative -> banner.bannerId to creative.creativeId }
            }
            .toMap()
    }

    /**
     * Keeps only banners that still have the default video addition and the expected image,
     * and groups them by the client id of their campaign. Banners of campaigns without a known client are dropped.
     */
    private fun getBannerWithDefaultVideoAdditionGroupedByClientId(banners: List<BannerInfo>): Map<Long, List<BannerInfo>> {
        val bannerIdsWithDefaultVideoAddition = getBannersIdsWithDefaultVideoAddition(banners)

        val campaignIds = banners.map(BannerInfo::campaignId).distinct()
        val clientIdByCampaignId = shardHelper.getClientIdsByCampaignIds(campaignIds)

        return StreamEx.of(banners)
            .mapToEntry(BannerInfo::campaignId, Function.identity())
            .filterKeys(clientIdByCampaignId::containsKey)
            .mapKeys { clientIdByCampaignId[it]!! } // safe: filtered by containsKey above
            .filterValues { bannerIdsWithDefaultVideoAddition.contains(it.bannerId) }
            .grouping()
    }

    /**
     * Downloads the video at [url] and uploads it to Canvas.
     *
     * We download the file ourselves because the video URLs point to the internal network,
     * while Canvas upload-by-URL only works for URLs from the external network.
     *
     * @return the Canvas upload response, or `null` if the download or the upload failed
     */
    private fun uploadVideoByUrl(clientId: Long, url: String): VideoUploadResponse? {
        val videoBytes = downloadVideo(url) ?: return null

        return try {
            val response = canvasClient.createVideoFromFile(
                clientId, videoBytes, "3d_video", null, null
            )
            logger.info("Uploaded video id: {}", response.id)
            response
        } catch (e: CanvasClientException) {
            logger.error("Failure to upload video from file, url = $url", e)
            null
        }
    }

    /**
     * Downloads the file at [url]. The HTTP response is always closed.
     *
     * @return the file content, or `null` on a malformed URL, an I/O error, a non-2xx status or an empty body
     */
    private fun downloadVideo(url: String): ByteArray? {
        return try {
            val request = Request.Builder().url(url).build()
            okHttpClient.newCall(request).execute().use { response ->
                val body = response.body()
                if (!response.isSuccessful || body == null) {
                    logger.info("Failed download video by url: $url")
                    null
                } else {
                    body.bytes()
                }
            }
        } catch (e: IOException) {
            logger.error("Failed download video by url: $url", e)
            null
        } catch (e: IllegalArgumentException) {
            // thrown by Request.Builder.url() for a malformed URL
            logger.error("Failed download video by url: $url", e)
            null
        }
    }

    /**
     * Creates the default video-addition creative for an uploaded video.
     *
     * @return the Canvas response, or `null` if Canvas reported an error
     */
    private fun createCreative(clientId: Long, presetId: Long, videoId: String): AdditionResponse? {
        return try {
            val response = canvasClient.createDefaultAddition(clientId, presetId, videoId)
            logger.info("Created creative id: {}", response.creativeId)
            response
        } catch (e: CanvasClientException) {
            logger.error("Failure to create creative for videoId=$videoId", e)
            null
        }
    }

    /** Loads the banners from their shards and returns the ids of those that should get a 3D video addition. */
    private fun getBannersIdsWithDefaultVideoAddition(banners: List<BannerInfo>): Set<Long> {
        val bannerIds = banners.map(BannerInfo::bannerId)
        val imageHashByBannerId = banners.associate { it.bannerId to it.imageHash }

        return shardHelper.groupByShard(bannerIds, ShardKey.BID)
            .stream()
            .mapKeyValue { shard, ids -> bannerTypedRepository.getTyped(shard, ids) }
            .flatMap { it.stream() }
            .filter { isBannerToUpdate(it, imageHashByBannerId[it.id]) }
            .map(Banner::getId)
            .toSet()
    }

    /**
     * Returns true when [banner] still has the default video addition (no creative) and shows the image
     * the 3D video was generated for. Logs the reason for every skipped banner.
     */
    private fun isBannerToUpdate(banner: Banner, expectedImageHash: String?): Boolean {
        val bannerId = banner.id
        if (banner !is BannerWithCreative || banner !is BannerWithBannerImage) {
            logger.warn("Banner with id=${bannerId} has unexpected type ${banner.javaClass}")
            return false
        }

        val hasDefaultVideoAddition = banner.creativeId == null
        if (!hasDefaultVideoAddition) {
            logger.info("Banner with id=${bannerId} has videoAddition=${banner.creativeId}")
        }

        val hasExpectedImage = banner.imageHash == expectedImageHash
        if (!hasExpectedImage) {
            logger.info("Banner with id=${bannerId} has another image=${banner.imageHash}, expected=${expectedImageHash}")
        }

        return hasDefaultVideoAddition && hasExpectedImage
    }

    /** Builds model changes that set [BannerWithCreative.CREATIVE_ID] for every banner. */
    private fun getModelChanges(creativeIdByBannerId: Map<Long, Long>): List<ModelChanges<BannerWithSystemFields>> {
        return creativeIdByBannerId
            .map { (bannerId, creativeId) ->
                ModelChanges(bannerId, BannerWithCreative::class.java)
                    .process(creativeId, BannerWithCreative.CREATIVE_ID)
                    .castModel(BannerWithSystemFields::class.java)
            }
    }

    /** Returns the ids from the successful items of [massResult]; empty if there is no per-item result. */
    private fun getSuccessfullyUpdatedIds(massResult: MassResult<Long>): Set<Long> {
        return if (massResult.result == null) {
            setOf()
        } else StreamEx.of(massResult.result)
            .nonNull()
            .filter { it.isSuccessful }
            .map { it.result }
            .toSet()
    }

    /** Reads rows `[startRow, lastRow)` of the input table. */
    private fun readInputTable(inputData: Param, startRow: Long, lastRow: Long): List<BannerInfo> {
        val entryType = YTableEntryTypes.yson(InputTableRow::class.java)
        val ytTable = YtTable(inputData.tablePath)

        val items = mutableListOf<BannerInfo>()
        ytProvider.get(inputData.ytCluster).tables()
            .read(ytTable.ypath().withRange(startRow, lastRow), entryType,
                Consumer { row ->
                    items.add(
                        BannerInfo(
                            campaignId = row.campaignId,
                            bannerId = row.bannerId,
                            imageHash = row.imageHash,
                            videoLink = row.videoLink
                        )
                    )
                })
        return items
    }

    companion object {
        private val logger = LoggerFactory.getLogger(Add3DVideoAdditionsOneshot::class.java)

        const val HTTP_TIMEOUT = 60L // sec
    }
}
