package ru.yandex.intranet.d.web.errors;

import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.reactive.error.DefaultErrorAttributes;
import org.springframework.context.MessageSource;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.validation.BindException;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebInputException;

import ru.yandex.intranet.d.datasource.model.BadSessionException;
import ru.yandex.intranet.d.datasource.model.SessionPoolDepletedException;
import ru.yandex.intranet.d.datasource.model.TransactionLocksInvalidatedException;
import ru.yandex.intranet.d.datasource.model.TransportUnavailableException;
import ru.yandex.intranet.d.i18n.Locales;

/**
 * Error attributes provider.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
public class YaErrorAttributes extends DefaultErrorAttributes {

    private final MessageSource messages;

    public YaErrorAttributes(@Qualifier("messageSource") MessageSource messages) {
        this.messages = messages;
    }

    @Override
    public Map<String, Object> getErrorAttributes(ServerRequest request, ErrorAttributeOptions options) {
        Map<String, Object> errorAttributes = new LinkedHashMap<>();
        Throwable error = getError(request);
        MergedAnnotation<ResponseStatus> responseStatusAnnotation = MergedAnnotations
                .from(error.getClass(), MergedAnnotations.SearchStrategy.TYPE_HIERARCHY).get(ResponseStatus.class);
        HttpStatus errorStatus = httpStatusFromThrowable(error, responseStatusAnnotation);
        errorAttributes.put("status", errorStatus.value());
        errorAttributes.put("error", errorStatus.getReasonPhrase());
        errorAttributes.put("message", messageFromThrowable(error, responseStatusAnnotation, request));
        return errorAttributes;
    }

    private HttpStatus httpStatusFromThrowable(Throwable error, MergedAnnotation<ResponseStatus> annotation) {
        if (error instanceof ResponseStatusException) {
            return ((ResponseStatusException) error).getStatus();
        }
        Optional<HttpStatus> statusO = annotation.getValue("code", HttpStatus.class);
        if (statusO.isPresent()) {
            return statusO.get();
        }
        if (TransactionLocksInvalidatedException.isTransactionLocksInvalidated(error)) {
            return HttpStatus.TOO_MANY_REQUESTS;
        }
        if (SessionPoolDepletedException.isSessionPoolDepleted(error)) {
            return HttpStatus.TOO_MANY_REQUESTS;
        }
        if (TransportUnavailableException.isTransportUnavailable(error)) {
            return HttpStatus.SERVICE_UNAVAILABLE;
        }
        if (BadSessionException.isBadSession(error)) {
            return HttpStatus.SERVICE_UNAVAILABLE;
        }
        return HttpStatus.INTERNAL_SERVER_ERROR;
    }

    private String messageFromThrowable(Throwable throwable, MergedAnnotation<ResponseStatus> annotation,
                                        ServerRequest request) {
        Locale locale = Optional.ofNullable(request.exchange().getLocaleContext().getLocale())
                .orElse(Locales.ENGLISH);
        if (throwable instanceof ServerWebInputException || throwable instanceof BindException) {
            return messages.getMessage("errors.bad.request", null, locale);
        }
        if (throwable instanceof ResponseStatusException) {
            return ((ResponseStatusException) throwable).getReason();
        }
        String reasonAnnotation = annotation.getValue("reason", String.class).orElse("");
        if (StringUtils.hasText(reasonAnnotation)) {
            return reasonAnnotation;
        }
        if (TransactionLocksInvalidatedException.isTransactionLocksInvalidated(throwable)) {
            return messages.getMessage("errors.too.many.requests.transaction.locks.invalidated", null, locale);
        }
        if (SessionPoolDepletedException.isSessionPoolDepleted(throwable)) {
            return messages.getMessage("errors.too.many.requests", null, locale);
        }
        if (TransportUnavailableException.isTransportUnavailable(throwable)) {
            return messages.getMessage("errors.unavailable", null, locale);
        }
        if (BadSessionException.isBadSession(throwable)) {
            return messages.getMessage("errors.unavailable", null, locale);
        }
        return messages.getMessage("errors.unexpected.service.error", null, locale);
    }

}
