package ru.yandex.qe.dispenser.ws.quota.request.unbalanced.formula;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import com.google.common.collect.ImmutableMap;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.qe.dispenser.api.util.EnumUtils;
import ru.yandex.qe.dispenser.api.v1.DiAmount;
import ru.yandex.qe.dispenser.api.v1.DiUnit;
import ru.yandex.qe.dispenser.domain.dao.segment.SegmentUtils;
import ru.yandex.qe.dispenser.ws.bot.Provider;
import ru.yandex.qe.dispenser.ws.quota.request.unbalanced.QuotaChangeRequestUnbalancedContext;
import ru.yandex.qe.dispenser.ws.quota.request.unbalanced.QuotaChangeRequestUnbalancedResult;
import ru.yandex.qe.dispenser.ws.quota.request.unbalanced.SegmentKey;

import static ru.yandex.bolts.collection.Tuple2.tuple;

/**
 * YP formula for calculating whether a quota request is unbalanced.
 * See DISPENSER-4564.
 * <p>
 * Changes are grouped by order and then by location (data center). Within each group, the requested amount of
 * every resource listed in {@code RESOURCES_BY_CPU} is compared with the maximum that the requested CPU allows;
 * the group is unbalanced if any of them exceeds that maximum. The result carries a recommended maximum for each
 * such resource and a recommended minimum for CPU.
 *
 * @author Ruslan Kadriev <aqru@yandex-team.ru>
 */
@Component
public class YpUnbalancedFormula implements UnbalancedFormula {
    private static final Logger LOG = LoggerFactory.getLogger(YpUnbalancedFormula.class);

    private static final Provider PROVIDER = Provider.YP;

    /**
     * Amount of each non-CPU resource that corresponds to one CPU core, paired with the unit that amount is
     * expressed in. See https://st.yandex-team.ru/DISPENSER-4564#61a4b944b401a1031867e67b
     */
    private static final Map<Resource, Tuple2<BigInteger, DiUnit>> RESOURCES_BY_CPU = new EnumMap<>(Map.of(
            Resource.RAM, tuple(BigInteger.valueOf(4473924266L), DiUnit.BYTE),
            Resource.SSD, tuple(BigInteger.valueOf(50394282939L), DiUnit.BYTE),
            Resource.HDD, tuple(BigInteger.valueOf(91625968981L), DiUnit.BYTE),
            Resource.IO_HDD, tuple(BigInteger.valueOf(1310720L), DiUnit.BINARY_BPS),
            Resource.IO_SSD, tuple(BigInteger.valueOf(26214400L), DiUnit.BINARY_BPS)
    ));

    /**
     * Unit in which the recommended maximum of each resource is reported and compared with the requested amount.
     */
    private static final Map<Resource, DiUnit> MIN_UNIT_BY_RESOURCE = new EnumMap<>(Map.of(
            Resource.RAM, DiUnit.GIBIBYTE,
            Resource.SSD, DiUnit.GIBIBYTE,
            Resource.HDD, DiUnit.GIBIBYTE,
            Resource.IO_HDD, DiUnit.MIBPS,
            Resource.IO_SSD, DiUnit.MIBPS
    ));

    private final String locationSegmentationKey;

    public YpUnbalancedFormula(@Value("${dispenser.location.segmentation.key}") String locationSegmentationKey) {
        this.locationSegmentationKey = locationSegmentationKey;
    }

    /**
     * Calculates the unbalance result for a whole quota request.
     * <p>
     * Changes whose resource key is not one of {@link Resource} are ignored. The remaining changes are grouped by
     * order; the request is unbalanced if any order/location group is unbalanced.
     *
     * @param quotaChangeRequestUnbalancedContext request changes to evaluate
     * @return the combined result with recommendations for every order/location group
     */
    @Override
    public @NotNull QuotaChangeRequestUnbalancedResult calculate(
            @NotNull QuotaChangeRequestUnbalancedContext quotaChangeRequestUnbalancedContext) {
        Map<Long, List<QuotaChangeRequestUnbalancedContext.Change>> changesByOrderId =
                quotaChangeRequestUnbalancedContext.getChanges().stream()
                        .filter(c -> Resource.byKey(c.getResourceKey()).isPresent())
                        .collect(Collectors.groupingBy(QuotaChangeRequestUnbalancedContext.Change::getOrderId));

        // Plain accumulation rather than Stream.reduce: a mutable builder is not a valid reduce identity.
        Result.Builder accumulator = Result.builder();
        changesByOrderId.forEach((orderId, orderChanges) ->
                accumulator.add(calculateForOrder(orderChanges, orderId)));
        Result result = accumulator.build();

        return QuotaChangeRequestUnbalancedResult.builder()
                .requestId(quotaChangeRequestUnbalancedContext.getRequestId())
                .providerKey(quotaChangeRequestUnbalancedContext.getProviderKey())
                .unbalanced(result.isUnbalanced())
                .changes(result.getChanges())
                .build();
    }

    /**
     * Calculates the result for the changes of a single order, grouped by location.
     * Changes without a recognized location segment are ignored.
     *
     * @param changes changes belonging to {@code orderId}
     * @param orderId order the changes belong to
     * @return combined result of all locations of the order
     */
    private Result calculateForOrder(
            List<QuotaChangeRequestUnbalancedContext.Change> changes, long orderId) {
        // EnumMap keeps locations in a stable (declaration) order, so the produced changes are ordered
        // deterministically across JVM runs.
        Map<Location, List<QuotaChangeRequestUnbalancedContext.Change>> changesByLocation =
                new EnumMap<>(Location.class);
        for (QuotaChangeRequestUnbalancedContext.Change change : changes) {
            getLocationFromChange(change).ifPresent(location ->
                    changesByLocation.computeIfAbsent(location, l -> new ArrayList<>()).add(change));
        }

        Result.Builder accumulator = Result.builder();
        changesByLocation.forEach((location, locationChanges) ->
                accumulator.add(calculateForDC(locationChanges, orderId, location)));
        return accumulator.build();
    }

    /**
     * Calculates the result for the changes of one order in one location.
     * <p>
     * For each non-CPU resource, the recommended maximum is (requested cores x per-core amount), converted to the
     * resource's reporting unit and floored. The group is unbalanced if any requested amount exceeds its maximum.
     * The recommended CPU minimum is the largest of ceil(requested amount / per-core amount) over the resources.
     *
     * @param changes  changes of a single order in a single location
     * @param orderId  order the changes belong to
     * @param location location the changes belong to
     * @return recommendations for this order/location and its unbalanced flag
     */
    private Result calculateForDC(
            List<QuotaChangeRequestUnbalancedContext.Change> changes, long orderId, Location location) {
        // NOTE(review): convert(...) and MATH_CONTEXT are not declared in this class; presumably inherited from
        // UnbalancedFormula.
        // Requested CPU is summed in cores; other resources in the base unit declared in RESOURCES_BY_CPU.
        Map<Resource, BigInteger> sumByResourceKey = new HashMap<>();
        BigInteger sumCpu = BigInteger.ZERO;
        for (QuotaChangeRequestUnbalancedContext.Change change : changes) {
            Resource resource = Resource.byKey(change.getResourceKey()).orElseThrow();
            if (resource == Resource.CPU) {
                sumCpu = sumCpu.add(convert(change, DiUnit.CORES, LOG).toBigIntegerExact());
            } else if (RESOURCES_BY_CPU.containsKey(resource)) {
                BigInteger amount = convert(change, RESOURCES_BY_CPU.get(resource).get2(), LOG)
                        .toBigIntegerExact();
                sumByResourceKey.merge(resource, amount, BigInteger::add);
            }
        }

        // EnumMap gives the produced recommendations a stable order across JVM runs.
        Map<Resource, BigInteger> maxByResourceKey = new EnumMap<>(Resource.class);
        BigInteger minCpu = BigInteger.ZERO;
        boolean unbalanced = false;
        for (Map.Entry<Resource, BigInteger> entry : sumByResourceKey.entrySet()) {
            Resource resource = entry.getKey();
            BigInteger requested = entry.getValue();
            BigInteger perCpu = RESOURCES_BY_CPU.get(resource).get1();
            DiUnit baseUnit = RESOURCES_BY_CPU.get(resource).get2();
            DiUnit minUnit = MIN_UNIT_BY_RESOURCE.get(resource);

            BigInteger max = convert(sumCpu.multiply(perCpu), baseUnit, minUnit, LOG)
                    .setScale(0, RoundingMode.FLOOR).toBigIntegerExact();
            maxByResourceKey.put(resource, max);

            minCpu = minCpu.max(new BigDecimal(requested).divide(new BigDecimal(perCpu), MATH_CONTEXT)
                    .setScale(0, RoundingMode.UP).toBigIntegerExact());

            // Compare as decimals: the converted amount may be fractional (e.g. 1.5 GiB), on which
            // toBigIntegerExact() would throw. For whole amounts this is the same comparison as before.
            BigDecimal requestedInMinUnit = convert(requested, baseUnit, minUnit, LOG);
            if (requestedInMinUnit.compareTo(new BigDecimal(max)) > 0) {
                unbalanced = true;
            }
        }

        Set<SegmentKey> segmentKeys = Set.of(SegmentKey.builder()
                .segmentKey(location.getKey())
                .segmentationKey(locationSegmentationKey)
                .build());

        List<QuotaChangeRequestUnbalancedResult.Change> resultChanges = new ArrayList<>();
        maxByResourceKey.forEach((resource, max) ->
                resultChanges.add(QuotaChangeRequestUnbalancedResult.Change.builder()
                        .resourceKey(QuotaChangeRequestUnbalancedResult.ResourceKey.builder()
                                .orderId(orderId)
                                .resourceKey(resource.getKey())
                                .segmentKeys(segmentKeys)
                                .build())
                        .recommendedMax(DiAmount.of(max.longValue(), MIN_UNIT_BY_RESOURCE.get(resource)))
                        .build()));
        resultChanges.add(QuotaChangeRequestUnbalancedResult.Change.builder()
                .resourceKey(QuotaChangeRequestUnbalancedResult.ResourceKey.builder()
                        .orderId(orderId)
                        .resourceKey(Resource.CPU.getKey())
                        .segmentKeys(segmentKeys)
                        .build())
                .recommendedMin(DiAmount.of(minCpu.longValue(), DiUnit.CORES))
                .build());

        return Result.builder()
                .changes(resultChanges)
                .unbalanced(unbalanced)
                .build();
    }

    @Override
    public @NotNull String getProviderKey() {
        return PROVIDER.getServiceKey();
    }

    /**
     * Resources this formula takes into account, identified by their resource keys.
     */
    private enum Resource implements EnumUtils.StringKey {
        SSD("ssd_segmented"),
        HDD("hdd_segmented"),
        CPU("cpu_segmented"),
        RAM("ram_segmented"),
        IO_HDD("io_hdd"),
        IO_SSD("io_ssd"),
        ;

        // Built eagerly after the constants are constructed, instead of an unsynchronized lazy check-then-set
        // on a shared static field.
        private static final Map<String, Resource> RESOURCE_BY_KEY =
                ImmutableMap.copyOf(EnumUtils.prepareKeysMap(Resource.values()));

        private final String key;

        Resource(String key) {
            this.key = key;
        }

        public String getKey() {
            return key;
        }

        /**
         * Looks up a resource by its key.
         *
         * @param key resource key
         * @return the matching resource, or empty if the key is unknown
         */
        public static Optional<Resource> byKey(String key) {
            return Optional.ofNullable(RESOURCE_BY_KEY.get(key));
        }
    }

    /**
     * Locations (data centers) this formula takes into account, identified by segment public key.
     */
    private enum Location implements EnumUtils.StringKey {
        VLA("VLA"),
        SAS("SAS"),
        MAN("MAN"),
        MYT("MYT"),
        IVA("IVA"),
        ;

        // Built eagerly after the constants are constructed, instead of an unsynchronized lazy check-then-set
        // on a shared static field.
        private static final Map<String, Location> LOCATION_BY_KEY =
                ImmutableMap.copyOf(EnumUtils.prepareKeysMap(Location.values()));

        private final String key;

        Location(String key) {
            this.key = key;
        }

        public String getKey() {
            return key;
        }

        /**
         * Looks up a location by its key.
         *
         * @param key location key
         * @return the matching location, or empty if the key is unknown
         */
        public static Optional<Location> byKey(String key) {
            return Optional.ofNullable(LOCATION_BY_KEY.get(key));
        }
    }

    /**
     * Resolves the location of a change from its location segment's public key.
     *
     * @return the location, or empty if the change has no location segment or the segment is not a known location
     */
    private Optional<Location> getLocationFromChange(QuotaChangeRequestUnbalancedContext.Change change) {
        return getLocationSegment(change)
                .flatMap(segment -> Location.byKey(segment.getPublicKey()));
    }

    /**
     * Finds the change's segment that belongs to the configured location segmentation.
     */
    private Optional<ru.yandex.qe.dispenser.domain.Segment> getLocationSegment(QuotaChangeRequestUnbalancedContext.Change change) {
        return SegmentUtils.getSegmentBySegmentationKey(change.getSegments(), locationSegmentationKey);
    }

    /**
     * Partial calculation result: recommendation changes collected so far and whether any group was unbalanced.
     */
    private static final class Result {
        private final List<QuotaChangeRequestUnbalancedResult.Change> changes;
        private final boolean unbalanced;

        private Result(List<QuotaChangeRequestUnbalancedResult.Change> changes, boolean unbalanced) {
            this.changes = changes;
            this.unbalanced = unbalanced;
        }

        public List<QuotaChangeRequestUnbalancedResult.Change> getChanges() {
            return changes;
        }

        public boolean isUnbalanced() {
            return unbalanced;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Result result = (Result) o;
            return unbalanced == result.unbalanced && Objects.equals(changes, result.changes);
        }

        @Override
        public int hashCode() {
            return Objects.hash(changes, unbalanced);
        }

        @Override
        public String toString() {
            return "Result{" +
                    "changes=" + changes +
                    ", unbalanced=" + unbalanced +
                    '}';
        }

        private static Builder builder() {
            return new Builder();
        }

        /**
         * Mutable accumulator for {@link Result}. The unbalanced flag is sticky: once set to true it stays true.
         */
        private static final class Builder {
            private List<QuotaChangeRequestUnbalancedResult.Change> changes = new ArrayList<>();
            private boolean unbalanced = false;

            /**
             * Replaces the collected changes with a copy of {@code changes}, so later additions never mutate the
             * caller's list.
             */
            private Builder changes(List<QuotaChangeRequestUnbalancedResult.Change> changes) {
                this.changes = new ArrayList<>(changes);
                return this;
            }

            private Builder addChanges(List<QuotaChangeRequestUnbalancedResult.Change> changes) {
                this.changes.addAll(changes);
                return this;
            }

            /**
             * ORs {@code unbalanced} into the accumulated flag.
             */
            private Builder unbalanced(boolean unbalanced) {
                this.unbalanced = this.unbalanced || unbalanced;
                return this;
            }

            /**
             * Merges another result's changes and unbalanced flag into this builder.
             */
            private Builder add(Result result) {
                return addChanges(result.getChanges()).unbalanced(result.isUnbalanced());
            }

            private Result build() {
                return new Result(changes, unbalanced);
            }
        }
    }
}
