package ru.yandex.webmaster3.storage.notifications.dao;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.curator.shaded.com.google.common.collect.Lists;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.springframework.stereotype.Component;

import ru.yandex.webmaster3.storage.notifications.NotificationChannel;
import ru.yandex.webmaster3.storage.notifications.ProgressInfo;
import ru.yandex.webmaster3.storage.notifications.SendByChannelResult;
import ru.yandex.webmaster3.storage.util.ydb.AbstractYDao;
import ru.yandex.webmaster3.storage.util.ydb.querybuilder.Select;
import ru.yandex.webmaster3.storage.util.ydb.querybuilder.typesafe.DataMapper;
import ru.yandex.webmaster3.storage.util.ydb.querybuilder.typesafe.Field;
import ru.yandex.webmaster3.storage.util.ydb.querybuilder.typesafe.Fields;

/**
 * YDB DAO for the {@code notification_progress} table: stores, per notification, target and channel,
 * the {@link SendByChannelResult} of sending that notification, and reads those results back.
 *
 * @author kravchenko99
 * @date 11/3/21
 */
@Component
public class NotificationProgressYDao extends AbstractYDao {
    /**
     * Maximum number of target ids put into a single {@code IN (...)} read query.
     * Divided by the number of channels presumably because each target can have up to one row per
     * channel, so a batch stays around 1000 rows (NOTE(review): confirm against the YDB result-set
     * limit). Clamped to at least 1 so batching never receives a non-positive size.
     */
    private static final int READ_BATCH_SIZE = Math.max(1, 1000 / NotificationChannel.values().length);

    /** Rows whose {@code ttl_date} is older than {@code now - TTL} are ignored by reads. */
    private static final Duration TTL = Duration.standardDays(7);

    /** Maps the {@code channel} and {@code status} columns to a {@link ProgressInfo}. */
    private static final DataMapper<ProgressInfo> MAPPER = DataMapper.create(F.CHANNEL, F.STATUS, ProgressInfo::new);

    /** Maps a row to a (target id, progress info) pair. */
    private static final DataMapper<Pair<String, ProgressInfo>> TARGET_ID_PROGRESS_INFO_MAPPER =
            DataMapper.create(F.TARGET_ID, MAPPER, Pair::of);

    public NotificationProgressYDao() {
        super(PREFIX_NOTIFICATION, "notification_progress");
    }

    /**
     * Upserts the result of sending a notification to one target over one channel.
     * The row's {@code ttl_date} is stamped with the current time, which is what
     * {@link #getProgressInfo} compares against {@link #TTL}.
     *
     * @param notificationId notification the result belongs to
     * @param targetId       recipient the notification was sent to
     * @param channel        channel used for sending
     * @param result         outcome of sending over {@code channel}
     */
    public void saveResult(UUID notificationId, String targetId, NotificationChannel channel,
                           SendByChannelResult result) {
        upsert(
                F.NOTIFICATION_ID.value(notificationId),
                F.TARGET_ID.value(targetId),
                F.CHANNEL.value(channel),
                F.STATUS.value(result),
                F.TTL_DATE.value(DateTime.now())
        ).execute();
    }

    /**
     * Loads the stored per-channel results of {@code notificationId} for the given targets.
     * Only rows written within the last {@link #TTL} are returned. Targets are queried in batches of
     * {@link #READ_BATCH_SIZE}, and the batches are passed together to {@code asyncExecute}.
     *
     * @param notificationId notification to read progress for
     * @param targetIds      targets to look up; must not be {@code null}
     * @return target id to the list of its {@link ProgressInfo} rows; targets with no matching rows
     *         are absent from the map
     */
    public Map<String, List<ProgressInfo>> getProgressInfo(UUID notificationId, List<String> targetIds) {
        if (targetIds.isEmpty()) {
            // Nothing to look up: skip building and executing an empty set of queries.
            return new HashMap<>();
        }
        DateTime ttlThreshold = DateTime.now().minus(TTL);

        List<Select<Pair<String, ProgressInfo>>> statements = partition(targetIds, READ_BATCH_SIZE).stream()
                .map(batch -> select(TARGET_ID_PROGRESS_INFO_MAPPER)
                        .where(F.NOTIFICATION_ID.eq(notificationId))
                        .and(F.TARGET_ID.in(batch))
                        .and(F.TTL_DATE.gte(ttlThreshold))
                        .getStatement())
                .toList();

        return asyncExecute(statements, TARGET_ID_PROGRESS_INFO_MAPPER).stream()
                .collect(Collectors.groupingBy(
                        Pair::getLeft,
                        Collectors.mapping(Pair::getRight, Collectors.toList())));
    }

    /**
     * Splits {@code items} into consecutive, order-preserving sub-list views of at most {@code size}
     * elements each; the last batch may be shorter. Stdlib replacement for Guava's
     * {@code Lists.partition}, so this class does not depend on Curator's internally shaded Guava.
     *
     * @param items list to split; must not be structurally modified while the views are in use
     * @param size  maximum batch size, must be positive
     * @return list of sub-list views covering {@code items}
     */
    private static <T> List<List<T>> partition(List<T> items, int size) {
        List<List<T>> batches = new ArrayList<>();
        for (int from = 0; from < items.size(); from += size) {
            batches.add(items.subList(from, Math.min(from + size, items.size())));
        }
        return batches;
    }

    /** Column definitions of the {@code notification_progress} table. */
    private static final class F {
        static final Field<UUID> NOTIFICATION_ID = Fields.uuidField("notification_id");
        static final Field<String> TARGET_ID = Fields.stringField("target_id");
        static final Field<NotificationChannel> CHANNEL = Fields.intEnumField("channel", NotificationChannel.R);
        static final Field<SendByChannelResult> STATUS = Fields.intEnumField("status", SendByChannelResult.R);
        // Time of the last write (see saveResult); reads filter on it to apply TTL.
        static final Field<DateTime> TTL_DATE = Fields.jodaDateTimeField("ttl_date");
    }
}
