package ru.yandex.chemodan.app.webdav.servlet;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.net.SocketException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.channels.ClosedChannelException;
import java.util.Set;

import javax.net.ssl.SSLException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.TransformerException;

import org.apache.jackrabbit.webdav.AbstractLocatorFactory;
import org.apache.jackrabbit.webdav.DavException;
import org.apache.jackrabbit.webdav.DavLocatorFactory;
import org.apache.jackrabbit.webdav.DavResource;
import org.apache.jackrabbit.webdav.DavResourceFactory;
import org.apache.jackrabbit.webdav.DavResourceLocator;
import org.apache.jackrabbit.webdav.DavSessionProvider;
import org.apache.jackrabbit.webdav.WebdavRequest;
import org.apache.jackrabbit.webdav.WebdavResponse;
import org.apache.jackrabbit.webdav.WebdavResponseImpl;
import org.apache.jackrabbit.webdav.server.AbstractWebdavServlet;
import org.apache.jackrabbit.webdav.xml.DomUtil;
import org.apache.jackrabbit.webdav.xml.XmlSerializable;
import org.eclipse.jetty.io.EofException;
import org.jetbrains.annotations.NotNull;
import org.w3c.dom.Document;
import org.xml.sax.SAXException;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.Try;
import ru.yandex.chemodan.app.webdav.auth.AuthInfo;
import ru.yandex.chemodan.app.webdav.filter.AuthenticationFilter;
import ru.yandex.chemodan.app.webdav.log.WebdavApiTskvLogger;
import ru.yandex.chemodan.app.webdav.repository.MpfsResource;
import ru.yandex.chemodan.app.webdav.repository.MpfsResourceManager;
import ru.yandex.chemodan.app.webdav.repository.SimpleLocatorFactory;
import ru.yandex.chemodan.log.DiskLog4jRequestLog;
import ru.yandex.chemodan.mpfs.MpfsResponseParserUtils;
import ru.yandex.chemodan.util.exception.PermanentHttpFailureException;
import ru.yandex.inside.mulca.MulcaException;
import ru.yandex.inside.passport.tvm2.TvmHeaders;
import ru.yandex.inside.passport.tvm2.UserTicketHolder;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.cache.tl.TlCache;
import ru.yandex.misc.io.IoFunction0V;
import ru.yandex.misc.io.RuntimeIoException;
import ru.yandex.misc.io.http.HttpException;
import ru.yandex.misc.io.http.HttpStatus;
import ru.yandex.misc.lang.StringUtils;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.thread.ThreadUtils;

/**
 * WebDAV servlet that resolves the requested {@link MpfsResource} and hands the request over to the first
 * matching {@link DavMethodHandler} registered for the HTTP method.
 * <p>
 * Besides dispatching, the servlet answers {@code GET /ping}, enforces authentication (see
 * {@link #checkAuth(HttpServletRequest)}) and translates the exceptions raised by handlers into HTTP responses.
 *
 * @author tolmalevø
 */
public class WebDavServlet extends AbstractWebdavServlet {
    private static final Logger logger = LoggerFactory.getLogger(WebDavServlet.class);

    /** Upstream status codes that are never relayed to the client; they are replaced with 500. */
    private static final Set<Integer> NOT_ALLOWED_STATUS_CODES = Cf.set(HttpStatus.SC_502_BAD_GATEWAY,
            HttpStatus.SC_503_SERVICE_UNAVAILABLE, HttpStatus.SC_504_GATEWAY_TIMEOUT);

    /** MPFS codes passed to {@link MpfsResponseParserUtils#getTitleFromMpfsResponse} as the whitelist. */
    private static final Set<Integer> WHITELIST_CODES_FROM_MPFS = Cf.set(
            280, //CHEMODAN-68401
            276, //CHEMODAN-66648
            253, //Request per second limit exceeded for user,
            72,  //FILE_EXISTS
            62   //COPY_PARENT_NOT_FOUND
    );

    private final MpfsResourceManager mpfsResourceManager;
    /** Handlers grouped by HTTP method name; each group is sorted by {@code order()} descending. */
    private final MapF<String, ListF<DavMethodHandler>> handlers;
    private final AbstractLocatorFactory locatorFactory = new SimpleLocatorFactory("");

    /**
     * @param mpfsResourceManager resource factory used to resolve request locators into resources
     * @param handlers all available method handlers; they are grouped by method and ordered by
     *                 {@link DavMethodHandler#order()} descending, so the highest order is tried first
     */
    public WebDavServlet(MpfsResourceManager mpfsResourceManager, ListF<DavMethodHandler> handlers) {
        this.mpfsResourceManager = mpfsResourceManager;
        this.handlers = handlers
                .groupBy(DavMethodHandler::method)
                .mapValues(list -> list.sortedByDesc(DavMethodHandler::order));
    }

    @Override
    public DavResourceFactory getResourceFactory() {
        return mpfsResourceManager;
    }

    /**
     * Runs {@link #doService} and treats an EOF / closed channel anywhere in the cause chain as a client
     * disconnect: the request is marked with {@link DiskLog4jRequestLog#CLIENT_DISCONNECTED} and the
     * exception is swallowed. Any other failure is rethrown.
     */
    @Override
    protected void service(HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        try {
            doService(request, response);
        } catch (Throwable e) {
            ExceptionUtils.throwIfUnrecoverable(e);
            if (hasEofInCause(e)) {
                request.setAttribute(DiskLog4jRequestLog.CLIENT_DISCONNECTED, "true");
                completeIfAsync(request);
            } else {
                throw e;
            }
        }
    }

    /**
     * Wraps the servlet request into a {@link WebdavRequest} whose destination locator is built from the
     * {@code Destination} header with the user-info part removed. The header is parsed with {@link URI}
     * first and with {@code UriParser.parseLikeAJetty} as a fallback; if both fail, the raw header value is used.
     */
    WebdavRequest toWebdavRequest(HttpServletRequest request) {
        return new WebDavRequestSupport(request, getLocatorFactory(), isCreateAbsoluteURI()) {
            @Override
            public DavResourceLocator getDestinationLocator() throws DavException {
                String destination = request.getHeader(HEADER_DESTINATION);

                Try<URI> parsed = Try.tryCatchException(() -> new URI(destination))
                        .recoverCatchException(e -> UriParser.parseLikeAJetty(destination));

                return getHrefLocator(parsed
                        .mapCatchException(uri -> excludeUserInfo(uri))
                        .mapCatchException(URI::toString).getOrElse(destination));
            }
        };
    }

    /**
     * Rebuilds the URI without its user-info component; returns the original URI if rebuilding fails.
     */
    // NOTE(review): the URI is rebuilt from decoded components (getPath/getQuery/getFragment), so
    // percent-escapes in the original may not round-trip byte-for-byte -- confirm this is intended.
    private URI excludeUserInfo(URI uri) {
        try {
            return new URI(uri.getScheme(), null, uri.getHost(), uri.getPort(),
                    uri.getPath(), uri.getQuery(), uri.getFragment());
        } catch (URISyntaxException ignore) {
            return uri;
        }
    }

    /**
     * Handles a single request: ping, authentication, resource lookup, handler dispatch and the
     * translation of handler failures into HTTP responses. On the paths that do not return early the
     * request is finally reported to {@link WebdavApiTskvLogger}.
     */
    private void doService(HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        //TODO: "smart logging". "Yandex-Cloud-Activity". "Yandex-Cloud-Mobile-Activity"
        WebdavRequest webdavRequest = toWebdavRequest(request);

        // hack to return application/xml instead of text/xml
        // TODO: remove after tests fix
        WebdavResponse webdavResponse = toWebdavResponse(request, response);

        addSecurityResponseHeaders(webdavResponse);

        try {
            if (isPing(request)) {
                handlePing(request, response, webdavResponse);
                return;
            }

            checkAuth(request);

            // The user ticket may arrive either as a header or as a request attribute; header wins.
            Option<String> tvmUserTicketO = Cf.list(
                    webdavRequest.getHeader(TvmHeaders.USER_TICKET),
                    (String) webdavRequest.getAttribute(TvmHeaders.USER_TICKET)
            ).flatMap(Option::ofNullable).firstO();

            MpfsResource resource = getResource(webdavRequest, webdavResponse);

            //TODO: maybe we don't need it
            if (!isPreconditionValid(webdavRequest, resource)) {
                // The error response is final: the handler must not run on top of it.
                webdavResponse.sendError(HttpServletResponse.SC_PRECONDITION_FAILED);
            } else {
                Option<DavMethodHandler> handler = findHandler(webdavRequest, resource);
                if (handler.isPresent()) {
                    TlCache.Handle handle = TlCache.push();
                    try {
                        IoFunction0V handleAction =
                                () -> handler.get().handle(webdavRequest, webdavResponse, resource);

                        UserTicketHolder.withUserTicketO(tvmUserTicketO, handleAction);

                    } catch (RuntimeException e) {
                        // Unwrap a DavException that was wrapped on its way out of the handler.
                        if (e.getCause() instanceof DavException) {
                            throw (DavException) e.getCause();
                        } else {
                            throw e;
                        }
                    } finally {
                        handle.popSafely();
                    }
                } else {
                    webdavResponse.sendError(HttpStatus.SC_405_METHOD_NOT_ALLOWED);
                }
            }

        } catch (HttpException e) {
            if (e.getStatusCode().isPresent() && isBad5xx(e.getStatusCode().get())) {
                logger.error("Caught HttpException error: {}", e);
            } else {
                logger.debug("Caught HttpException error: {}", e);
            }

            if (e.getStatusCode().isSome(499)) {
                request.setAttribute(DiskLog4jRequestLog.CLIENT_DISCONNECTED, "true");
                completeIfAsync(request);
            } else {
                int statusCode = fetchStatusCode(e.getStatusCode());
                if (writeMpfsErrorResponse(request, webdavResponse, statusCode, e.getResponseBody())) {
                    return;
                }
            }
        } catch (MulcaException e) {
            logger.error("Caught mulca error: {}", e);
            webdavResponse.sendError(e.getStatusCode());
            return;
        } catch (RuntimeIoException e) {
            if (e.getCause() instanceof EofException
                    || e.getCause() instanceof ClosedChannelException)
            {
                logger.debug("Caught eof exception: {}", e);
                request.setAttribute(DiskLog4jRequestLog.CLIENT_DISCONNECTED, "true");
                completeIfAsync(request);
            } else if (e.getCause() instanceof SSLException
                    || e.getCause() instanceof SocketException)
            {
                logger.error("Caught client socket exception: {}", e);
                webdavResponse.sendError(HttpStatus.SC_499_CLIENT_CLOSED_REQUEST);
                return;
            } else {
                logger.error("Caught RuntimeIoException error: {}", e);
                webdavResponse.sendError(HttpStatus.SC_500_INTERNAL_SERVER_ERROR);
            }
        } catch (UnsupportedOperationException e) {
            logger.error("Caught UnsupportedOperationException error: {}", e);
            webdavResponse.sendError(HttpStatus.SC_405_METHOD_NOT_ALLOWED);
        } catch (DavException e) {
            if (isBad5xx(e.getErrorCode())) {
                logger.error("Caught DavException error: {}", e);
            } else {
                logger.debug("Caught DavException error: {}", e);
            }

            if (e.getErrorCode() == HttpServletResponse.SC_UNAUTHORIZED) {
                sendUnauthorized(webdavRequest, webdavResponse, e);
            } else {
                webdavResponse.sendError(e);
            }
        } catch (PermanentHttpFailureException e) {
            logger.debug("Caught PermanentHttpFailureException error: {}", e);
            int statusCode = fetchStatusCode(e.getStatusCode());
            if (writeMpfsErrorResponse(request, webdavResponse, statusCode, e.responseBody)) {
                return;
            }
        } catch (Throwable e) {
            ExceptionUtils.throwIfUnrecoverable(e);
            if (hasEofInCause(e)) {
                request.setAttribute(DiskLog4jRequestLog.CLIENT_DISCONNECTED, "true");
                completeIfAsync(request);
            } else {
                throw e;
            }
        }

        try {
            WebdavApiTskvLogger.log(request, response);
        } catch (Exception e) {
            logger.error("Failed to log event: {}", e);
        }
    }

    /**
     * Answers {@code GET /ping} with status 200 and the body {@code pong}. A positive numeric
     * {@code delay} request parameter makes the servlet sleep before answering; a malformed value is
     * logged and ignored instead of failing the ping.
     */
    // NOTE(review): the delay is taken from an unauthenticated request and is not capped -- consider
    // limiting it so a ping cannot hold a worker thread for an arbitrary time.
    private void handlePing(HttpServletRequest request, HttpServletResponse response,
            WebdavResponse webdavResponse) throws IOException
    {
        String timeout = request.getParameter("delay");
        if (StringUtils.isNotEmpty(timeout)) {
            logger.warn("get ping with delay {}", timeout);
            Integer to;
            try {
                to = Integer.valueOf(timeout);
            } catch (NumberFormatException e) {
                logger.warn("ignoring malformed ping delay {}", timeout);
                to = 0;
            }
            if (to > 0) {
                ThreadUtils.sleep(to);
            }
        }
        webdavResponse.setStatus(HttpStatus.SC_200_OK);
        // Explicit charset: the bytes must not depend on the platform default.
        response.getOutputStream().write("pong".getBytes("UTF-8"));
    }

    /**
     * Sets the given status on the response and, if a title can be extracted from the MPFS response
     * body, writes that title as the response content. The async context, if any, is completed in both cases.
     *
     * @return {@code true} if a title was written to the response
     */
    private boolean writeMpfsErrorResponse(HttpServletRequest request, WebdavResponse webdavResponse,
            int statusCode, Option<String> responseBody) throws IOException
    {
        webdavResponse.setStatus(statusCode);
        Option<String> title = getTitleFromMpfsResponse(statusCode, responseBody);
        if (title.isPresent()) {
            logger.debug("Write title to output: {}", title.get());
            writeContentWithLength(webdavResponse, title.get());
        } else {
            logger.debug("No title found in json. Write nothing");
        }
        completeIfAsync(request);
        return title.isPresent();
    }

    /**
     * Delegates to {@link MpfsResponseParserUtils#getTitleFromMpfsResponse} with this servlet's whitelist
     * of MPFS codes.
     */
    static Option<String> getTitleFromMpfsResponse(int status, Option<String> responseBody) {
        return MpfsResponseParserUtils.getTitleFromMpfsResponse(status, responseBody, WHITELIST_CODES_FROM_MPFS);
    }

    /** @return {@code true} for any 5xx status except 507 */
    private boolean isBad5xx(int statusCode) {
        return HttpStatus.is5xx(statusCode) && statusCode != 507;
    }

    /**
     * Writes {@code content} through the response writer and announces its length beforehand.
     * The length is measured in bytes of the response character encoding, because Content-Length counts
     * bytes while {@link String#length()} counts UTF-16 code units.
     */
    private void writeContentWithLength(WebdavResponse webdavResponse, String content) throws IOException {
        String encoding = webdavResponse.getCharacterEncoding();
        int contentLength = encoding == null ? content.length() : content.getBytes(encoding).length;
        webdavResponse.setContentLength(contentLength);
        PrintWriter writer = webdavResponse.getWriter();
        writer.write(content);
        writer.flush();
    }

    /**
     * Wraps the servlet response into a {@link WebdavResponseImpl} that serializes XML bodies as
     * {@code application/xml; charset=UTF-8} with an explicit content length, and completes the async
     * context (if one was started) after every {@code sendError} call.
     */
    @NotNull
    WebdavResponseImpl toWebdavResponse(HttpServletRequest request, HttpServletResponse response) {
        return new WebdavResponseImpl(response, false) {
            @Override
            public void sendXmlResponse(XmlSerializable serializable, int status) throws IOException {
                response.setStatus(status);

                if (serializable != null) {
                    // Serialize into memory first so that the content length is known before writing.
                    ByteArrayOutputStream out = new ByteArrayOutputStream();
                    try {
                        Document doc = DomUtil.createDocument();
                        doc.appendChild(serializable.toXml(doc));

                        // JCR-2636: Need to use an explicit OutputStreamWriter
                        // instead of relying on the built-in UTF-8 serialization
                        // to avoid problems with surrogate pairs on Sun JRE 1.5.
                        Writer writer = new OutputStreamWriter(out, "UTF-8");
                        DomUtil.transformDocument(doc, writer);
                        writer.flush();

                        response.setContentType("application/xml; charset=UTF-8");
                        response.setContentLength(out.size());
                        out.writeTo(response.getOutputStream());
                    } catch (ParserConfigurationException | TransformerException | SAXException e) {
                        logger.error(e);
                        throw ExceptionUtils.translate(e);
                    }
                }
            }

            @Override
            public void sendError(int i) throws IOException {
                super.sendError(i);
                completeIfAsync(request);
            }

            @Override
            public void sendError(int i, String s) throws IOException {
                super.sendError(i, s);
                completeIfAsync(request);
            }

            @Override
            public void sendError(DavException exception) throws IOException {
                super.sendError(exception);
                completeIfAsync(request);
            }
        };
    }

    /** Resolves the request locator into a resource using {@link #getResourceFactory()}. */
    MpfsResource getResource(WebdavRequest request, WebdavResponse response) throws DavException {
        return (MpfsResource) getResourceFactory().createResource(request.getRequestLocator(), request, response);
    }

    /**
     * @return the given status code, or 500 if it is absent or listed in {@link #NOT_ALLOWED_STATUS_CODES}
     */
    private int fetchStatusCode(Option<Integer> statusCode) {
        return statusCode
                .filter(code -> !NOT_ALLOWED_STATUS_CODES.contains(code))
                .getOrElse(HttpStatus.SC_500_INTERNAL_SERVER_ERROR);
    }

    /**
     * @return {@code true} if the throwable or any of its causes is an {@link EOFException} or a
     *         {@link ClosedChannelException}
     */
    private static boolean hasEofInCause(Throwable throwable) {
        for (Throwable current = throwable; current != null; current = current.getCause()) {
            if (current instanceof EOFException || current instanceof ClosedChannelException) {
                return true;
            }
        }
        return false;
    }

    /** Completes the async context of the request if asynchronous processing was started. */
    private void completeIfAsync(HttpServletRequest request) {
        if (request.isAsyncStarted()) {
            request.getAsyncContext().complete();
        }
    }

    /**
     * @return the first handler registered for the request method that matches the request and
     *         resource, in the order established by the constructor
     */
    Option<DavMethodHandler> findHandler(WebdavRequest request, MpfsResource resource) {
        return handlers
                .getO(request.getMethod())
                .flatMapO(handlersSet -> handlersSet.find(h -> h.matches(request, resource)));
    }

    private void addSecurityResponseHeaders(HttpServletResponse response) {
        response.addHeader("X-Frame-Options", "SAMEORIGIN");
        response.addHeader("X-XSS-Protection", "1; mode=block");
        response.addHeader("X-Content-Type-Options", "nosniff");
    }

    /**
     * Verifies authentication for requests that require it (see {@link #isAuthNeeded}).
     *
     * @throws DavException 401 if there is no auth info, 402 if the auth type is {@code BANNED}, or the
     *                      exception built by the auth info if the request is not authorized
     */
    private void checkAuth(HttpServletRequest req) throws DavException {
        String queryString = req.getQueryString() == null ? "" : req.getQueryString();
        if (isAuthNeeded(req, req.getMethod(), req.getPathInfo(), queryString)) {
            AuthInfo authInfo = getAuthInfo(req);
            if (authInfo == null) {
                throw new DavException(HttpStatus.SC_401_UNAUTHORIZED);
            }
            if (authInfo.authType == AuthInfo.AuthType.BANNED) {
                throw new DavException(HttpStatus.SC_402_PAYMENT_REQUIRED);
            }
            if (!authInfo.isAuthorized()) {
                throw authInfo.buildUnauthorizedException();
            }
        }
    }

    private boolean isOurClient(HttpServletRequest req) {
        return Option.ofNullable(getAuthInfo(req)).filter(AuthInfo::isOurClient).isPresent();
    }

    /**
     * @return {@code true} for {@code GET /ping}; a request without path info is not a ping
     */
    private boolean isPing(HttpServletRequest req) {
        // Constant-first equals: getPathInfo() returns null when there is no extra path information.
        return "GET".equals(req.getMethod()) && "/ping".equals(req.getPathInfo());
    }

    /**
     * Decides whether the request must be authenticated.
     * <ul>
     * <li>{@code GET}: no authentication for the root path, for the {@code userinfo} query string and
     * for paths under {@code share/};</li>
     * <li>other methods: authentication is skipped only for {@code PROPFIND} under {@code share/} issued
     * by a client for which {@link #isOurClient} holds.</li>
     * </ul>
     */
    // NOTE(review): path comes from getPathInfo() and may be null; behaviour then depends on
    // StringUtils.removeStart -- confirm null handling.
    private boolean isAuthNeeded(HttpServletRequest req, String method, String path, String queryString) {
        path = StringUtils.removeStart(path, "/");

        if (method.equals("GET")) {
            return !path.equals("") && !queryString.equals("userinfo") && !path.startsWith("share/");
        } else {
            return !method.equals("PROPFIND") || !path.startsWith("share/") || !isOurClient(req);
        }
    }

    @Override
    protected boolean isPreconditionValid(WebdavRequest request, DavResource resource) {
        //TODO: implement
        return true;
    }

    @Override
    public DavLocatorFactory getLocatorFactory() {
        return locatorFactory;
    }

    @Override
    public String getAuthenticateHeaderValue() {
        return "Basic realm=\"Yandex.Disk\"";
    }

    private AuthInfo getAuthInfo(HttpServletRequest request) {
        return AuthenticationFilter.getAuthInfo(request);
    }

    /** Not supported: this servlet does not use DAV sessions. */
    @Override
    public DavSessionProvider getDavSessionProvider() {
        throw new UnsupportedOperationException();
    }

    /** Not supported: this servlet does not use DAV sessions. */
    @Override
    public void setDavSessionProvider(DavSessionProvider davSessionProvider) {
        throw new UnsupportedOperationException();
    }

    /** Not supported: the resource factory is fixed at construction time. */
    @Override
    public void setResourceFactory(DavResourceFactory resourceFactory) {
        throw new UnsupportedOperationException();
    }

    /** Not supported: the locator factory is fixed. */
    @Override
    public void setLocatorFactory(DavLocatorFactory locatorFactory) {
        throw new UnsupportedOperationException();
    }
}
