package ru.yandex.calendar.logic.staff.dao;

import java.sql.PreparedStatement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.SneakyThrows;
import lombok.Value;
import one.util.streamex.StreamEx;
import org.intellij.lang.annotations.Language;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.calendar.micro.yt.entity.UserIdWIthDepartmentSlugs;
import ru.yandex.calendar.micro.yt.entity.YtUser;
import ru.yandex.calendar.micro.yt.entity.YtUserWithDepartmentIds;
import ru.yandex.inside.passport.PassportUid;
import ru.yandex.mail.cerberus.GroupId;
import ru.yandex.mail.cerberus.Uid;
import ru.yandex.mail.cerberus.yt.data.YtUserInfo;

/**
 * DAO for the {@code staff_users} table.
 *
 * <p>Each row stores a user's {@code uid}, {@code login}, the serialized user info (JSONB column
 * {@code info}) and the user's department group ids (BIGINT array column {@code groups}).
 */
public class UsersDao extends StaffDao<YtUserWithDepartmentIds> {

    /**
     * Upserts the given users into {@code staff_users}.
     *
     * <p>New uids are inserted. For uids that already exist, {@code login}, {@code info} and
     * {@code groups} are overwritten with the supplied values.
     *
     * @param users users to write; may be empty
     * @return the per-row update counts reported by the JDBC batch, or an empty array when
     *         {@code users} is empty
     */
    @Override
    protected int[] insert(List<YtUserWithDepartmentIds> users) {
        if (users.isEmpty()) {
            // Nothing to write: avoid acquiring a connection and preparing a statement.
            return new int[0];
        }
        return getJdbcTemplate().batchUpdate(
                "INSERT INTO staff_users (uid, login, info, groups) VALUES (?, ?, (?)::JSONB, ?) " +
                        "ON CONFLICT (uid) DO UPDATE SET " +
                        "login=excluded.login, info=excluded.info, groups=excluded.groups",
                new BatchPreparedStatementSetter() {

                    // @SneakyThrows lets objectMapper's checked JsonProcessingException escape,
                    // since setValues may only declare SQLException.
                    @Override
                    @SneakyThrows
                    public void setValues(PreparedStatement ps, int i) {
                        var entry = users.get(i);
                        var user = entry.getUser();
                        ps.setLong(1, user.getUid().getValue());
                        ps.setString(2, user.getLogin());
                        ps.setString(3, objectMapper.writeValueAsString(user.getInfo()));
                        // Department ids are bound as a BIGINT array for the "groups" column.
                        ps.setArray(4, ps.getConnection().createArrayOf("BIGINT",
                                entry.getDepartmentIds().stream().map(GroupId::getValue).toArray()));
                    }

                    @Override
                    public int getBatchSize() {
                        return users.size();
                    }
                });
    }

    /**
     * Writes all given users, splitting them into uids already present in {@code staff_users}
     * and uids that are not.
     *
     * <p>Both groups are written through {@link #insert(List)}, which is an upsert.
     *
     * @param entities users to write
     * @return the sum of the batch update counts for the users whose uids were not present
     *         before the call; counts for already-present users are not included
     */
    @Override
    public int update(List<YtUserWithDepartmentIds> entities) {
        // A Set makes each membership probe O(1) instead of scanning a list per entity.
        var existing = StreamEx.of(existingIds(StreamEx.of(entities).map(UsersDao::uidOf).toImmutableList()))
                .toSet();
        insert(StreamEx.of(entities).filter(user -> existing.contains(uidOf(user))).toImmutableList());
        int[] insertedCounts = insert(
                StreamEx.of(entities).filter(user -> !existing.contains(uidOf(user))).toImmutableList());
        return Arrays.stream(insertedCounts).sum();
    }

    /** Extracts the numeric uid used as the primary key of {@code staff_users}. */
    private static long uidOf(YtUserWithDepartmentIds user) {
        return user.getUser().getUid().getValue();
    }

    /**
     * Returns the subset of {@code ids} that already exist as {@code uid} in {@code staff_users}.
     *
     * <p>NOTE(review): one bind parameter is generated per id, so very large inputs may exceed the
     * JDBC driver's bind-parameter limit. Confirm the expected batch sizes.
     *
     * @param ids candidate uids; may be empty
     * @return the uids found, or an empty list when {@code ids} is empty
     */
    private List<Long> existingIds(List<Long> ids) {
        if (ids.isEmpty()) {
            // "IN ()" is not valid SQL, so short-circuit.
            return Cf.list();
        }
        String placeholders = String.join(",", Collections.nCopies(ids.size(), "?"));
        return getJdbcTemplate().query(
                String.format("SELECT uid FROM staff_users WHERE uid IN (%s)", placeholders),
                (rs, rowNum) -> rs.getLong("uid"),
                ids.toArray()
        );
    }


    /**
     * Reads one page of users ordered by uid.
     *
     * @param limit  maximum number of rows to return
     * @param offset number of rows to skip
     * @return the users on the requested page; a NULL {@code groups} column yields an empty
     *         department set
     * @throws JsonProcessingException if a stored {@code info} value cannot be parsed as
     *         {@link YtUserInfo}
     */
    @Override
    public List<YtUserWithDepartmentIds> getAll(int limit, int offset) throws JsonProcessingException {
        // Raw column values. JSON parsing is done after the query because readValue throws a
        // checked JsonProcessingException, which a RowMapper cannot propagate.
        @Value
        class UserArgs {
            long uid;
            String login;
            String info;
            Long[] groups;
        }
        var args = getJdbcTemplate().query(
                "SELECT * FROM staff_users ORDER BY uid LIMIT (?) OFFSET (?)",
                (rs, rowNum) -> {
                    long uid = rs.getLong("uid");
                    String login = rs.getString("login");
                    String info = rs.getString("info");
                    var groupsArray = rs.getArray("groups");
                    // Treat a NULL groups column as "no departments" instead of failing with an NPE.
                    var groups = groupsArray == null ? new Long[0] : (Long[]) groupsArray.getArray();
                    return new UserArgs(uid, login, info, groups);
                },
                limit, offset
        );
        List<YtUserWithDepartmentIds> users = new ArrayList<>(args.size());
        for (var arg : args) {
            YtUserInfo info = objectMapper.readValue(arg.info, YtUserInfo.class);
            var groups = Arrays.stream(arg.groups)
                    .map(GroupId::new).collect(Collectors.toUnmodifiableSet());
            users.add(new YtUserWithDepartmentIds(new YtUser(new Uid(arg.uid), arg.login, info), groups));
        }
        return users;
    }


    /**
     * Returns the total number of rows in {@code staff_users}.
     */
    @Override
    protected int count() {
        // queryForObject reads the single column by position, so this does not depend on the
        // database's implicit label for the COUNT(*) column.
        Integer total = getJdbcTemplate().queryForObject("SELECT COUNT(*) FROM staff_users", Integer.class);
        return total == null ? 0 : total;
    }

    /**
     * For each given uid, returns the {@code url} values from the {@code info} JSON of the
     * {@code staff_groups} rows whose ids appear in that user's {@code groups} array.
     *
     * <p>Because the join is an inner join, uids with no matching group rows, or unknown uids,
     * are absent from the result. The order of the result is not specified (no ORDER BY).
     *
     * <p>NOTE(review): the {@link PassportUid} objects are bound as query arguments as-is. This
     * relies on the template returned by {@code getJdbcTemplate()} knowing how to bind that type.
     * Confirm.
     *
     * @param uids uids to look up; may be empty
     * @return one entry per matched uid, or an empty list when {@code uids} is empty
     */
    public List<UserIdWIthDepartmentSlugs> findUserDepartmentsByUids(List<PassportUid> uids) {
        if (uids.isEmpty()) {
            // "IN ()" is a syntax error in PostgreSQL, so short-circuit.
            return List.of();
        }
        String inSql = String.join(",", Collections.nCopies(uids.size(), "?"));
        @Language("SQL")
        String queryTemplate = "SELECT su.uid as uid, array_agg(sg.info->>'url') as group_slugs FROM staff_users su " +
                "JOIN staff_groups sg ON sg.id = ANY(su.groups) " +
                "WHERE su.uid IN (%s) GROUP BY su.uid";
        String query = String.format(queryTemplate, inSql);

        return getJdbcTemplate().query(query,
                (rs, rowNum) -> {
                    long uid = rs.getLong("uid");
                    String[] groups = (String[]) rs.getArray("group_slugs").getArray();
                    return new UserIdWIthDepartmentSlugs(uid, List.of(groups));
                }, uids.toArray());
    }
}
