package ru.yandex.travel.api.endpoints.avia_country_restrictions;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import ru.yandex.travel.api.endpoints.avia_country_restrictions.exceptions.PointInfoNotFoundException;
import ru.yandex.travel.api.endpoints.avia_country_restrictions.req_rsp.v1.CountryRestrictionsByListReq;
import ru.yandex.travel.api.endpoints.avia_country_restrictions.req_rsp.v1.CountryRestrictionsByListRsp;
import ru.yandex.travel.api.endpoints.avia_country_restrictions.req_rsp.v1.CountryRestrictionsReq;
import ru.yandex.travel.api.endpoints.avia_country_restrictions.req_rsp.v1.CountryRestrictionsRsp;
import ru.yandex.travel.api.endpoints.avia_country_restrictions.req_rsp.v1.ResultsForDestination;
import ru.yandex.travel.api.models.avia.country_restrictions.v1.ExtendedMetricsAnswer;
import ru.yandex.travel.api.models.avia.country_restrictions.v1.Metric;
import ru.yandex.travel.api.services.avia.country_restrictions.CountryRestrictionsExtendedServiceV1;
import ru.yandex.travel.api.services.avia.country_restrictions.CountryRestrictionsServiceV1;
import ru.yandex.travel.api.services.geo.PointInfoHierarchyProvider;
import ru.yandex.travel.api.services.geo.model.PointInfo;
import ru.yandex.travel.api.services.geo.model.PointType;

/**
 * Serves avia country-restriction requests.
 *
 * <p>Each request resolves a destination (point key or geo id) into a list of {@link PointInfo}
 * through {@link PointInfoHierarchyProvider}. Points of type {@link PointType#UNKNOWN} are dropped.
 * The remaining list is passed to the country-restriction services.
 *
 * <p>All work is done synchronously on the calling thread. The returned futures are already
 * completed, and failures are thrown directly rather than delivered as failed futures.
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class AviaCountryRestrictionsImpl {
    /** Locale passed to every point-hierarchy lookup. */
    private static final String LOCALE = "ru";
    /** Domain passed to every point-hierarchy lookup. */
    private static final String DOMAIN = "ru";

    // Declaration order of these final fields defines the Lombok-generated constructor signature.
    private final CountryRestrictionsServiceV1 countryRestrictionsServiceV1;
    private final CountryRestrictionsExtendedServiceV1 countryRestrictionsExtendedServiceV1;
    private final PointInfoHierarchyProvider pointInfoHierarchyProvider;

    /**
     * Returns restriction metrics for a single destination.
     *
     * @param req request carrying either a destination point key or a geo id (point key wins)
     * @return an already-completed future with the restrictions for the destination
     * @throws PointInfoNotFoundException if the restrictions service returns {@code null}
     * @throws IllegalArgumentException if the request has neither a point key nor a geo id,
     *                                  or the geo id does not fit into an {@code int}
     */
    public CompletableFuture<CountryRestrictionsRsp> getDataV1(CountryRestrictionsReq req) {
        Map<String, Metric> restrictions = getRestrictionsOrThrow(getFilteredPoints(req));
        return CompletableFuture.completedFuture(new CountryRestrictionsRsp(new ResultsForDestination(restrictions)));
    }

    /**
     * Returns the extended restriction answer for a single destination.
     *
     * @param req request carrying either a destination point key or a geo id (point key wins)
     * @return an already-completed future with the extended answer
     * @throws PointInfoNotFoundException if the extended service returns {@code null}
     * @throws IllegalArgumentException if the request has neither a point key nor a geo id,
     *                                  or the geo id does not fit into an {@code int}
     */
    public CompletableFuture<ExtendedMetricsAnswer> getExtendedV1(CountryRestrictionsReq req) {
        ExtendedMetricsAnswer answer = countryRestrictionsExtendedServiceV1.getByPointInfoHierarchy(
                getFilteredPoints(req)
        );
        if (answer == null) {
            throw new PointInfoNotFoundException();
        }

        return CompletableFuture.completedFuture(answer);
    }

    /**
     * Returns restriction metrics for each of several destination point keys.
     *
     * <p>Results are keyed by point key in a {@link TreeMap}, so iteration order is the natural
     * string order of the keys. The call fails as a whole if any destination has no restrictions.
     *
     * @param req request carrying the destination point keys
     * @return an already-completed future with per-destination results
     * @throws PointInfoNotFoundException if the restrictions service returns {@code null} for
     *                                    any of the keys
     */
    public CompletableFuture<CountryRestrictionsByListRsp> getDataByListV1(CountryRestrictionsByListReq req) {
        Map<String, ResultsForDestination> results = new TreeMap<>();
        for (String toPointKey : req.getToPointKeys()) {
            Map<String, Metric> restrictions = getRestrictionsOrThrow(getFilteredPoints(toPointKey));
            results.put(toPointKey, new ResultsForDestination(restrictions));
        }

        return CompletableFuture.completedFuture(new CountryRestrictionsByListRsp(results));
    }

    /** Queries the restrictions service, translating a {@code null} answer into "not found". */
    private Map<String, Metric> getRestrictionsOrThrow(List<PointInfo> points) {
        Map<String, Metric> restrictions = countryRestrictionsServiceV1.getByPointInfoHierarchy(points);
        if (restrictions == null) {
            throw new PointInfoNotFoundException();
        }
        return restrictions;
    }

    /** Resolves the request's destination (point key or geo id) and drops UNKNOWN-typed points. */
    private List<PointInfo> getFilteredPoints(CountryRestrictionsReq req) {
        return withoutUnknownPoints(getPointInfoListByAnyInput(req.getToPointKey(), req.getToGeoId()));
    }

    /** Resolves a destination point key and drops UNKNOWN-typed points. */
    private List<PointInfo> getFilteredPoints(String pointKey) {
        return withoutUnknownPoints(pointInfoHierarchyProvider.getByPointKey(pointKey, LOCALE, DOMAIN));
    }

    /** Returns a new mutable list with every point whose type is {@link PointType#UNKNOWN} removed. */
    private static List<PointInfo> withoutUnknownPoints(List<PointInfo> points) {
        List<PointInfo> filtered = new ArrayList<>(points);
        filtered.removeIf(point -> point.getType() == PointType.UNKNOWN);
        return filtered;
    }

    /**
     * Looks up points by point key if one is given, otherwise by geo id.
     *
     * @throws IllegalArgumentException if both inputs are {@code null}, or the geo id does not
     *                                  fit into an {@code int}
     */
    private List<PointInfo> getPointInfoListByAnyInput(String pointKey, Long geoId) {
        if (pointKey != null) {
            return pointInfoHierarchyProvider.getByPointKey(pointKey, LOCALE, DOMAIN);
        }

        if (geoId != null) {
            return pointInfoHierarchyProvider.getByGeoId(toIntGeoId(geoId), LOCALE, DOMAIN);
        }

        throw new IllegalArgumentException("Either pointKey or geoId must be provided");
    }

    /**
     * Narrows a geo id to {@code int}. It rejects out-of-range values, where a plain
     * {@code intValue()} would silently truncate them and look up an unrelated geo id.
     */
    private static int toIntGeoId(long geoId) {
        try {
            return Math.toIntExact(geoId);
        } catch (ArithmeticException e) {
            throw new IllegalArgumentException("geoId is out of int range: " + geoId, e);
        }
    }
}
