package ru.yandex.infra.auth;

import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;

import ru.yandex.infra.auth.staff.StaffApi;
import ru.yandex.infra.auth.staff.StaffGroup;
import ru.yandex.infra.auth.yp.YpAbcRoleGroup;
import ru.yandex.infra.auth.yp.YpClients;
import ru.yandex.infra.controller.RepeatedTask;
import ru.yandex.infra.controller.metrics.GaugeRegistry;
import ru.yandex.infra.controller.metrics.GolovanableGauge;
import ru.yandex.infra.controller.metrics.NamespacedGaugeRegistry;

import static org.slf4j.LoggerFactory.getLogger;

/**
 * Periodically refreshed cache of Staff groups, YP groups and YP users.
 *
 * <p>Each refresh cycle ({@link #mainLoop()}) loads in parallel:
 * <ul>
 *     <li>Staff department, ABC service and ABC service-role groups via {@link StaffApi};</li>
 *     <li>all YP group ids, YP ABC role-scope groups and all YP user ids via the master {@link YpClients}.</li>
 * </ul>
 * It then builds a map from Staff group ids (and department urls) to YP group ids.
 *
 * <p>Cycles are driven by a {@link RepeatedTask} on a dedicated single-thread executor whose thread is
 * named "staff". A zero update rate disables the cache: {@link #start()} only logs a warning.
 *
 * <p>Cached values are replaced by reference on each load and are {@code volatile}. This lets readers on
 * other threads observe the latest published reference. Every getter returns {@code null} until the
 * corresponding data has been loaded at least once.
 */
public class GroupsAndUsersCacheImpl implements GroupsAndUsersCache {
    private static final Logger LOG = getLogger(GroupsAndUsersCacheImpl.class);

    /** Prefix of YP group ids that correspond to Staff departments. */
    public static final String YP_DEPARTMENT_GROUP_PREFIX = "staff:department:";
    /** Prefix of YP group ids that correspond to ABC services. */
    public static final String YP_ABC_SERVICE_GROUP_PREFIX = "abc:service:";

    static final String METRIC_STAFF_GROUPS_LOAD_TIME_MS = "staff_groups_load_time_ms";
    static final String METRIC_USERS_COUNT = "users_count";
    static final String METRIC_ALL_GROUPS_COUNT = "all_groups_count";
    static final String METRIC_STAFF_GROUPS_COUNT = "staff_groups_count";

    // Values exported through gauges; written by loader callbacks, read by the metrics registry.
    private volatile Integer metricUsers;
    private volatile Integer metricAllGroups;
    private volatile Integer metricStaffGroups;
    private volatile Long metricLastStaffGroupsLoadTimeMilliseconds;

    private final StaffApi staffApi;
    private final YpClients masterYpClients;
    private final ScheduledExecutorService executor;
    private final Duration cycleInterval;
    private final RepeatedTask mainLoop;

    // Cached data. The fields are volatile because they are written from CompletableFuture callbacks,
    // which may run on client threads. They are read through the public getters from arbitrary threads.

    /** Wall-clock time (ms) at which the last fully successful cycle finished; null before the first one. */
    private volatile Long lastSnapshotTimestampMillis;
    /** All YP user ids. Published as soon as the YP users load succeeds, independently of the rest of the cycle. */
    private volatile Set<String> ypUsers;
    /** Staff group id / department url -> YP group id. Published only after a fully successful cycle. */
    private volatile Map<String, String> staffToYpGroup;
    /** Staff department url -> Staff group id. Published as soon as department groups are loaded. */
    private volatile Map<String, Long> departmentUrlToStaffGroup;

    /**
     * Creates the cache and registers its gauges. Loading starts only after {@link #start()}.
     *
     * @param staffApi        source of Staff groups
     * @param masterYpClients source of YP groups and users
     * @param updateRate      interval between refresh cycles; {@link Duration#ZERO} disables the cache
     * @param mainLoopTimeout timeout passed to {@link RepeatedTask} (presumably per cycle; verify in RepeatedTask)
     * @param gaugeRegistry   registry for the Staff gauges; YP gauges go under the "yp" namespace, and a
     *                        "staff"-namespaced registry is handed to the {@link RepeatedTask}
     */
    public GroupsAndUsersCacheImpl(StaffApi staffApi,
                                   YpClients masterYpClients,
                                   Duration updateRate,
                                   Duration mainLoopTimeout,
                                   GaugeRegistry gaugeRegistry) {
        this.staffApi = staffApi;
        this.masterYpClients = masterYpClients;
        this.cycleInterval = updateRate;
        this.executor = Executors.newSingleThreadScheduledExecutor(runnable -> new Thread(runnable, "staff"));

        gaugeRegistry.add(METRIC_STAFF_GROUPS_LOAD_TIME_MS, new GolovanableGauge<>(() -> metricLastStaffGroupsLoadTimeMilliseconds, "axxx"));
        gaugeRegistry.add(METRIC_STAFF_GROUPS_COUNT, new GolovanableGauge<>(() -> metricStaffGroups, "axxx"));

        GaugeRegistry ypRegistry = new NamespacedGaugeRegistry(gaugeRegistry, "yp");
        ypRegistry.add(METRIC_USERS_COUNT, new GolovanableGauge<>(() -> metricUsers, "axxx"));
        ypRegistry.add(METRIC_ALL_GROUPS_COUNT, new GolovanableGauge<>(() -> metricAllGroups, "axxx"));

        NamespacedGaugeRegistry registry = new NamespacedGaugeRegistry(gaugeRegistry, "staff");

        mainLoop = new RepeatedTask(this::mainLoop,
                cycleInterval,
                mainLoopTimeout,
                executor,
                Optional.of(registry),
                LOG,
                false);
    }

    /**
     * Starts periodic refreshing, unless the configured update rate is zero.
     * In that case only a warning is logged and the caches stay {@code null}.
     */
    public void start() {
        if (cycleInterval.isZero()) {
            LOG.warn("Staff groups cache and YP users cache is disabled by config option 'main.groups_and_users_cache_rate' = 0");
        } else {
            mainLoop.start();
        }
    }

    /** Stops the repeated task and shuts down its executor. */
    public void shutdown() {
        mainLoop.stop();
        executor.shutdown();
    }

    /**
     * Mutable holder for the raw data of one refresh cycle.
     * Each field is filled by one loader future and is read only after all loaders complete.
     */
    static class Groups {
        List<StaffGroup> staffDepartmentGroups;
        List<StaffGroup> staffAbcGroups;
        List<StaffGroup> staffAbcRoleGroups;
        /** Ids of all YP groups. */
        Set<String> ypAllGroupIds;
        /** {@link #getAbcRoleGroupKey(String, String)} key -> YP ABC role-scope group id. */
        Map<String, String> ypAbcScopeGroups;
    }

    /**
     * Loads the three kinds of Staff groups in parallel into {@code snapshot}.
     *
     * <p>On success, it records the load-time and group-count gauges and logs a summary. As a side effect,
     * {@link #departmentUrlToStaffGroup} is published as soon as the department groups arrive, even if the
     * other loads later fail.
     *
     * @param snapshot holder that receives the loaded lists
     * @return a future that completes when all three loads complete, and fails if any of them fails
     */
    private CompletableFuture<?> loadStaffGroups(Groups snapshot) {
        long startTimeMillis = System.currentTimeMillis();

        return CompletableFuture.allOf(
                staffApi.getAllDepartmentGroups()
                        .thenAccept(result -> {
                            snapshot.staffDepartmentGroups = result;
                            // NOTE: Collectors.toMap fails the cycle on a duplicate url or a null id.
                            departmentUrlToStaffGroup = result.stream()
                                    .collect(Collectors.toMap(StaffGroup::getUrl, StaffGroup::getId));
                        }),
                staffApi.getAllAbcServiceGroups()
                        .thenAccept(result -> snapshot.staffAbcGroups = result),
                staffApi.getAllAbcServiceRoleGroups()
                        .thenAccept(result -> snapshot.staffAbcRoleGroups = result)
        ).thenAccept(x -> {
            metricLastStaffGroupsLoadTimeMilliseconds = System.currentTimeMillis() - startTimeMillis;
            metricStaffGroups = snapshot.staffDepartmentGroups.size() +
                    snapshot.staffAbcGroups.size() +
                    snapshot.staffAbcRoleGroups.size();
            LOG.info("Loaded {} department groups, {} abc groups, {} abc role groups in {} ms",
                    snapshot.staffDepartmentGroups.size(),
                    snapshot.staffAbcGroups.size(),
                    snapshot.staffAbcRoleGroups.size(),
                    metricLastStaffGroupsLoadTimeMilliseconds);
        });
    }

    /**
     * Runs one refresh cycle: loads Staff groups, YP group ids, YP ABC role-scope groups and YP user ids in
     * parallel, then rebuilds {@link #staffToYpGroup} and updates {@link #lastSnapshotTimestampMillis}.
     *
     * <p>{@link #ypUsers} is published as soon as its own load succeeds. The group mapping and the
     * timestamp are published only when every load succeeds.
     *
     * @return a future that completes when the cycle finishes, and fails if any load fails
     */
    @VisibleForTesting
    CompletableFuture<?> mainLoop() {
        Groups snapshot = new Groups();
        return CompletableFuture.allOf(
                loadStaffGroups(snapshot),
                loadCollectionWithTimeMeasure(masterYpClients.getGroupsClient().listAllIds(), "YP groups", Set::size)
                        .thenAccept(result -> {
                            metricAllGroups = result.size();
                            snapshot.ypAllGroupIds = result;
                        }),
                loadCollectionWithTimeMeasure(masterYpClients.getGroupsClient().getAbcRoleGroups(), "abc-scope groups", List::size)
                        .thenAccept(result -> {
                            // NOTE: Collectors.toMap fails the cycle if two YP groups share a service/role scope.
                            snapshot.ypAbcScopeGroups = result.stream()
                                    .collect(Collectors.toMap(g -> getAbcRoleGroupKey(Long.toString(g.getServiceId()), g.getRoleScope()),
                                            YpAbcRoleGroup::getId));
                        }),
                loadCollectionWithTimeMeasure(masterYpClients.getUserRepository().listAllIds(), "YP users", Set::size)
                        .thenAccept(result -> {
                            metricUsers = result.size();
                            ypUsers = result;
                        })
        ).thenAccept(x -> {
            staffToYpGroup = mapStaffToYpGroups(snapshot);
            lastSnapshotTimestampMillis = System.currentTimeMillis();
        });
    }

    /**
     * Builds the Staff-to-YP group mapping from a fully loaded snapshot.
     *
     * <ul>
     *     <li>Department group: both its id and its url map to {@code "staff:department:<departmentId>"},
     *     provided that YP group exists.</li>
     *     <li>ABC service group: its id maps to {@code "abc:service:<abcServiceId>"}, provided that YP group
     *     exists.</li>
     *     <li>ABC service-role group: its id maps to the YP role-scope group matching its parent's ABC service
     *     id and its own role scope. Groups without a parent are skipped with a warning. Groups whose parent
     *     has no ABC service, or which have no matching YP group, are skipped silently.</li>
     * </ul>
     *
     * @param snapshot data loaded in one cycle; every field must be non-null
     * @return a new mutable map from Staff group id (or department url) to YP group id
     */
    @VisibleForTesting
    static Map<String, String> mapStaffToYpGroups(Groups snapshot) {
        Map<String, String> result = new HashMap<>();

        snapshot.staffDepartmentGroups.forEach(group -> {
            String ypGroupName = YP_DEPARTMENT_GROUP_PREFIX + group.getDepartment().getId();
            if (snapshot.ypAllGroupIds.contains(ypGroupName)) {
                result.put(group.getId().toString(), ypGroupName);
                result.put(group.getUrl(), ypGroupName);
            }
        });

        snapshot.staffAbcGroups.forEach(group -> {
            String ypGroupName = YP_ABC_SERVICE_GROUP_PREFIX + group.getAbcService().getId();
            if (snapshot.ypAllGroupIds.contains(ypGroupName)) {
                result.put(group.getId().toString(), ypGroupName);
            }
        });

        snapshot.staffAbcRoleGroups.forEach(group -> {
            StaffGroup parent = group.getParent();
            if (parent != null) {
                StaffGroup.ObjectWithId abcService = parent.getAbcService();
                if (abcService != null) {
                    String key = getAbcRoleGroupKey(abcService.getId(), group.getRoleScope());
                    String ypGroupName = snapshot.ypAbcScopeGroups.get(key);
                    if (ypGroupName != null) {
                        result.put(group.getId().toString(), ypGroupName);
                    }
                }
            } else {
                LOG.warn("Group {} was skipped due to a missed 'parent' field", group);
            }
        });

        return result;
    }

    /**
     * Builds the lookup key that joins Staff ABC role groups with YP role-scope groups.
     *
     * @return {@code serviceId + ":" + roleScope}
     */
    @VisibleForTesting
    static String getAbcRoleGroupKey(String serviceId, String roleScope) {
        return serviceId + ":" + roleScope;
    }

    /**
     * Attaches logging to {@code future}. It logs the element count and the elapsed time on success, and
     * logs the failure (including its stack trace) on error.
     * The elapsed time is measured from the call of this method, not from the start of the underlying
     * request.
     *
     * @param future                   pending load
     * @param objectType               human-readable name used in log messages
     * @param getResultCollectionCount extracts the number of loaded objects from the result
     * @return a future with the same result or exception as {@code future}
     */
    public static <T> CompletableFuture<T> loadCollectionWithTimeMeasure(CompletableFuture<T> future,
                                                                         String objectType,
                                                                         Function<T, Integer> getResultCollectionCount) {
        long startTimeMillis = System.currentTimeMillis();
        return future.whenComplete((objects, error) -> {
            if (error != null) {
                // The trailing throwable (not consumed by a placeholder) makes SLF4J log its stack trace.
                LOG.error("Failed to load {}: {}", objectType, error.getMessage(), error);
            } else {
                LOG.info("Loaded {} {} in {} ms", getResultCollectionCount.apply(objects), objectType, System.currentTimeMillis() - startTimeMillis);
            }
        });
    }

    /** @return time (ms since epoch) of the last fully successful cycle, or {@code null} if none yet */
    @Override
    public Long getLastSnapshotTimestampMillis() {
        return lastSnapshotTimestampMillis;
    }

    /** @return all YP user ids from the last successful users load, or {@code null} if none yet */
    @Override
    public Set<String> getYpUsers() {
        return ypUsers;
    }

    /** @return Staff group id / department url to YP group id map, or {@code null} before the first full cycle */
    @Override
    public Map<String, String> getStaffToYpGroupMap() {
        return staffToYpGroup;
    }

    /** @return Staff department url to Staff group id map, or {@code null} if department groups were never loaded */
    @Override
    public Map<String, Long> getDepartmentUrlToStaffGroupIdMap() {
        return departmentUrlToStaffGroup;
    }
}
