HTTP Servlet request loses params from POST body after reading it once

一个人的身影 2020-11-22 14:56

I'm trying to access two HTTP request parameters in a Java Servlet filter, nothing new here, but was surprised to find that the parameters have already been consumed! Be

13 answers
  •  抹茶落季
    2020-11-22 15:44

    So this is basically Lathy's answer, but updated for the newer ServletInputStream requirements (the non-blocking I/O methods added in Servlet 3.1).

    Namely, a ServletInputStream subclass has to implement:

    public abstract boolean isFinished();
    
    public abstract boolean isReady();
    
    public abstract void setReadListener(ReadListener readListener);
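    To see what these three methods mean for a body that is already fully in memory, here is a simplified sketch that runs without the servlet API. NonBlockingBody, InMemoryBody, NotReadyUntilFinishedBody and DrainDemo are hypothetical names that only mirror the contract; they are not servlet classes. The loop in drain() imitates what a ReadListener.onDataAvailable() callback typically does (keep reading while isReady() is true), which is why an in-memory stream should report isReady() as true instead of delegating to isFinished().

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;

// Hypothetical mirror of the three Servlet 3.1 methods, so this sketch runs without the servlet API.
interface NonBlockingBody {
    boolean isFinished(); // true once every byte has been read
    boolean isReady();    // true if read() can be called without blocking
    int read();           // next byte as 0..255, or -1 at end of stream
}

// Body already fully in memory: reading can never block, so isReady() is always true.
class InMemoryBody implements NonBlockingBody {
    private final byte[] bytes;
    private int pos = 0;

    InMemoryBody(byte[] bytes) { this.bytes = bytes; }

    public boolean isFinished() { return pos == bytes.length; }
    public boolean isReady() { return true; }
    public int read() { return isFinished() ? -1 : (bytes[pos++] & 0xFF); }
}

// Same body, but isReady() wrongly delegates to isFinished().
class NotReadyUntilFinishedBody extends InMemoryBody {
    NotReadyUntilFinishedBody(byte[] bytes) { super(bytes); }

    @Override
    public boolean isReady() { return isFinished(); }
}

public class DrainDemo {
    // Reads the way an onDataAvailable() callback typically does:
    // keep going while isReady() is true, stop once isFinished().
    static String drain(NonBlockingBody in) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        while (in.isReady() && !in.isFinished()) {
            out.write(in.read());
        }
        return new String(out.toByteArray(), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        byte[] body = "a=1&b=2".getBytes(StandardCharsets.UTF_8);
        System.out.println(drain(new InMemoryBody(body)));              // prints a=1&b=2
        System.out.println(drain(new NotReadyUntilFinishedBody(body))); // prints an empty line: nothing was read
    }
}
```

    With the "not ready until finished" variant, a caller that waits for isReady() never reads anything, so the body silently looks empty.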
    

    This is Lathy's class, edited:

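    import java.io.BufferedReader;
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.InputStreamReader;
    import java.nio.charset.Charset;
    import java.nio.charset.StandardCharsets;
    import javax.servlet.ServletInputStream;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletRequestWrapper;
    
    public class RequestWrapper extends HttpServletRequestWrapper {
    
        // raw body bytes, read once from the original request and replayed on every call
        private final byte[] body;
    
        public RequestWrapper(HttpServletRequest request) throws IOException {
            super(request);
            // copy raw bytes (not readLine() output) so line breaks and the original encoding survive
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            InputStream in = request.getInputStream();
            byte[] chunk = new byte[4096];
            int n;
            while ((n = in.read(chunk)) != -1) {
                buffer.write(chunk, 0, n);
            }
            body = buffer.toByteArray();
        }
    
        @Override
        public ServletInputStream getInputStream() throws IOException {
            // a fresh stream over the cached bytes each time, so the body can be read more than once
            return new CustomServletInputStream(body);
        }
    
        @Override
        public BufferedReader getReader() throws IOException {
            // use the request's declared charset; ISO-8859-1 is the servlet default when none is declared
            String enc = getCharacterEncoding();
            Charset charset = (enc != null) ? Charset.forName(enc) : StandardCharsets.ISO_8859_1;
            return new BufferedReader(new InputStreamReader(getInputStream(), charset));
        }
    }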
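    The core idea of the wrapper (drain the one-shot stream once, keep the bytes, hand out a fresh stream per call) can be checked in plain Java without a servlet container. CachedBody, newStream() and readAll() are hypothetical names for this sketch. Reading raw bytes, rather than concatenating readLine() results, also keeps the line breaks.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

// Hypothetical, servlet-free illustration of the RequestWrapper idea.
public class CachedBody {
    private final byte[] bytes;

    // Reads the one-shot stream exactly once and keeps the raw bytes.
    public CachedBody(InputStream oneShot) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        byte[] chunk = new byte[4096];
        int n;
        while ((n = oneShot.read(chunk)) != -1) {
            buffer.write(chunk, 0, n);
        }
        this.bytes = buffer.toByteArray();
    }

    // Each call returns an independent stream positioned at the start of the body.
    public InputStream newStream() {
        return new ByteArrayInputStream(bytes);
    }

    static String readAll(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        int b;
        while ((b = in.read()) != -1) {
            out.write(b);
        }
        return new String(out.toByteArray(), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) throws IOException {
        InputStream original = new ByteArrayInputStream("line1\nline2".getBytes(StandardCharsets.UTF_8));
        CachedBody cached = new CachedBody(original);
        System.out.println(readAll(cached.newStream())); // first reader, e.g. a logging filter
        System.out.println(readAll(cached.newStream())); // second reader, e.g. the controller, sees the same body
        System.out.println(original.read());             // prints -1: the original stream is exhausted
    }
}
```

    One caveat worth checking in your container: a wrapper like this does not override getParameter(). For application/x-www-form-urlencoded POSTs, containers typically parse those parameters from the original input stream, which the constructor has already drained, so form parameters may come back empty.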
    

    And somewhere (I can't remember where) I found this, a standalone class that implements the "extra" methods:

    import java.io.IOException;
    import java.nio.charset.StandardCharsets;
    import javax.servlet.ReadListener;
    import javax.servlet.ServletInputStream;
    
    public class CustomServletInputStream extends ServletInputStream {
    
        private final byte[] myBytes;
    
        private int lastIndexRetrieved = -1;
        private ReadListener readListener = null;
    
        public CustomServletInputStream(String s) {
            this.myBytes = s.getBytes(StandardCharsets.UTF_8);
        }
    
        public CustomServletInputStream(byte[] inputBytes) {
            this.myBytes = inputBytes;
        }
    
        @Override
        public boolean isFinished() {
            return (lastIndexRetrieved == myBytes.length - 1);
        }
    
        @Override
        public boolean isReady() {
            // All bytes are already in memory, so read() never blocks: always ready.
            // (Returning isFinished() here would report "not ready" while data is still available.)
            return true;
        }
    
        @Override
        public void setReadListener(ReadListener readListener) {
            if (readListener == null) {
                throw new NullPointerException("readListener must not be null");
            }
            this.readListener = readListener;
            try {
                if (!isFinished()) {
                    readListener.onDataAvailable();
                } else {
                    readListener.onAllDataRead();
                }
            } catch (IOException e) {
                readListener.onError(e);
            }
        }
    
        @Override
        public int read() throws IOException {
            if (isFinished()) {
                return -1; // end of stream
            }
            // Mask with 0xFF: read() must return 0..255, and a raw negative byte such as (byte) 0xFF
            // would otherwise come back as -1 and look like end of stream.
            int i = myBytes[++lastIndexRetrieved] & 0xFF;
            if (isFinished() && (readListener != null)) {
                try {
                    readListener.onAllDataRead();
                } catch (IOException ex) {
                    readListener.onError(ex);
                    throw ex;
                }
            }
            return i;
        }
    }
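    Why an in-memory read() must mask each byte with 0xFF, shown with a plain java.io.InputStream so it runs without the servlet API. MaskedByteStream is a hypothetical name; it mirrors only the read() logic of an in-memory body stream.

```java
import java.io.IOException;
import java.io.InputStream;

// Minimal in-memory stream showing why read() must mask each byte with 0xFF.
public class MaskedByteStream extends InputStream {
    private final byte[] bytes;
    private int pos = 0;

    public MaskedByteStream(byte[] bytes) {
        this.bytes = bytes;
    }

    @Override
    public int read() {
        if (pos >= bytes.length) {
            return -1; // genuine end of stream
        }
        // Java bytes are signed: (byte) 0xFF is -1, which callers would mistake for EOF.
        // Masking widens it to the int 255, as the InputStream.read() contract requires (0..255 or -1).
        return bytes[pos++] & 0xFF;
    }

    public static void main(String[] args) throws IOException {
        // 0xC3 0xA9 is "é" in UTF-8, followed by a single 0xFF byte
        byte[] body = {(byte) 0xC3, (byte) 0xA9, (byte) 0xFF};
        InputStream in = new MaskedByteStream(body);
        int b;
        while ((b = in.read()) != -1) {
            System.out.println(b); // prints 195, 169, 255 on separate lines
        }
        // Without the mask the values would be -61, -87 and -1,
        // and the 0xFF byte would be read as -1, i.e. as end of stream.
    }
}
```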
    

    Ultimately, I was just trying to log the requests, and the pieces above, frankensteined together, helped me create the filter below.

    import java.io.IOException;
    import java.nio.charset.Charset;
    import java.nio.charset.StandardCharsets;
    import java.security.Principal;
    import java.util.Enumeration;
    import java.util.LinkedHashMap;
    import java.util.Map;
    
    import javax.servlet.FilterChain;
    import javax.servlet.ServletException;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletResponse;
    
    import org.apache.commons.io.IOUtils;
    
    //one or the other based on spring version
    //import org.springframework.boot.autoconfigure.web.ErrorAttributes;
    import org.springframework.boot.web.servlet.error.ErrorAttributes;
    
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.core.Ordered;
    import org.springframework.http.HttpStatus;
    import org.springframework.stereotype.Component;
    import org.springframework.web.context.request.ServletWebRequest;
    import org.springframework.web.filter.OncePerRequestFilter;
    
    
    /**
     * A filter which logs every web request (method, path, headers, status code and body).
     */
    @Component
    public class LogRequestFilter extends OncePerRequestFilter implements Ordered {
    
        // I tried apache.commons and slf4j loggers (one or the other in these next 2 lines of declaration)
        //private final static org.apache.commons.logging.Log logger = org.apache.commons.logging.LogFactory.getLog(LogRequestFilter.class);
        private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(LogRequestFilter.class);
    
        // place the filter near the end of the chain so it runs after the other filters
        private int order = Ordered.LOWEST_PRECEDENCE - 8;
    
        // optional: Spring Boot's error auto-configuration normally provides an ErrorAttributes bean
        @Autowired(required = false)
        private ErrorAttributes errorAttributes;
    
        @Override
        public int getOrder() {
            return order;
        }
    
        @Override
        protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
                throws ServletException, IOException {
    
            String temp = ""; /* for a breakpoint, remove for production/real code */
    
            /* change to true for easy way to comment out this code, remove this if-check for production/real code */
            if (false) {
                filterChain.doFilter(request, response);
                return;
            }
    
            /* make a "copy" to avoid body-can-only-be-read-once issues */
            RequestWrapper reqWrapper = new RequestWrapper(request);
    
            // stays 500 if the filter chain throws
            int status = HttpStatus.INTERNAL_SERVER_ERROR.value();
            try {
                // pass through filter chain to do the actual request handling
                filterChain.doFilter(reqWrapper, response);
                status = response.getStatus();
            } finally {
                // log in finally so requests whose handling threw are logged too
                try {
                    Map<String, Object> traceMap = getTrace(reqWrapper, status);
                    // the wrapper cached the body up front, so it can still be read here
                    this.getBodyFromTheRequestCopy(reqWrapper, traceMap);
    
                    /* now do something with all the pieces of information gathered */
                    this.logTrace(reqWrapper, traceMap);
                } catch (Exception ex) {
                    logger.error("LogRequestFilter FAILED: " + ex.getMessage(), ex);
                }
            }
        }
    
        private void getBodyFromTheRequestCopy(RequestWrapper rw, Map<String, Object> trace) {
            try {
                if (rw != null) {
                    byte[] buf = IOUtils.toByteArray(rw.getInputStream());
                    if (buf.length > 0) {
                        String payloadSlimmed;
                        try {
                            // getCharacterEncoding() is null when the client sent no charset;
                            // falling back to UTF-8 is an assumption (typical for JSON bodies)
                            String enc = rw.getCharacterEncoding();
                            Charset charset = (enc != null) ? Charset.forName(enc) : StandardCharsets.UTF_8;
                            String payload = new String(buf, charset);
                            payloadSlimmed = payload.trim().replaceAll(" +", " ");
                        } catch (IllegalArgumentException ex) { // illegal or unsupported charset name
                            payloadSlimmed = "[unknown]";
                        }
    
                        trace.put("body", payloadSlimmed);
                    }
                }
            } catch (IOException ioex) {
                trace.put("body", "EXCEPTION: " + ioex.getMessage());
            }
        }
    
        private void logTrace(HttpServletRequest request, Map<String, Object> trace) {
            Object method = trace.get("method");
            Object path = trace.get("path");
            Object statusCode = trace.get("statusCode");
    
            logger.info(String.format("%s %s produced status code '%s'. Trace: '%s'", method, path, statusCode,
                    trace));
        }
    
        protected Map<String, Object> getTrace(HttpServletRequest request, int status) {
            Throwable exception = (Throwable) request.getAttribute("javax.servlet.error.exception");
    
            Principal principal = request.getUserPrincipal();
    
            Map<String, Object> trace = new LinkedHashMap<>();
            trace.put("method", request.getMethod());
            trace.put("path", request.getRequestURI());
            if (null != principal) {
                trace.put("principal", principal.getName());
            }
            trace.put("query", request.getQueryString());
            trace.put("statusCode", status);
    
            Enumeration<String> headerNames = request.getHeaderNames();
            while (headerNames.hasMoreElements()) {
                String key = headerNames.nextElement();
                trace.put("header:" + key, request.getHeader(key));
            }
    
            if (exception != null && this.errorAttributes != null) {
                // ServletWebRequest implements WebRequest; a plain ServletRequestAttributes does not,
                // so casting one to WebRequest would throw ClassCastException at runtime.
                // This boolean overload is deprecated since Spring Boot 2.3 in favour of ErrorAttributeOptions.
                trace.put("error", this.errorAttributes
                        .getErrorAttributes(new ServletWebRequest(request), true));
            }
    
            return trace;
        }
    }
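    The body-decoding step is the easiest part to get wrong: request.getCharacterEncoding() returns null when the client sent no charset, and new String(bytes, 0, len, (String) null) throws a NullPointerException rather than an UnsupportedEncodingException. Here is that step isolated as a plain-Java sketch; BodyText and decodeForLog are hypothetical names, and the UTF-8 fallback is an assumption suited to typical JSON bodies, not something the servlet API does for you.

```java
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

public class BodyText {

    // Decode a cached request body for logging.
    // declaredEncoding is what request.getCharacterEncoding() returned; null means no charset was sent.
    // ASSUMPTION: fall back to UTF-8 when the charset is missing or unknown.
    static String decodeForLog(byte[] body, String declaredEncoding) {
        Charset charset = StandardCharsets.UTF_8;
        if (declaredEncoding != null) {
            try {
                charset = Charset.forName(declaredEncoding);
            } catch (IllegalArgumentException ex) { // IllegalCharsetNameException or UnsupportedCharsetException
                // keep the UTF-8 fallback
            }
        }
        // same "slimming" as the filter: trim, then collapse runs of spaces into one
        return new String(body, charset).trim().replaceAll(" +", " ");
    }

    public static void main(String[] args) {
        byte[] body = "  {\"name\":   \"x\"}  ".getBytes(StandardCharsets.UTF_8);
        System.out.println(decodeForLog(body, null));              // prints {"name": "x"}
        System.out.println(decodeForLog(body, "no-such-charset")); // same output, via the fallback
    }
}
```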
    

    Please take this code with a grain of salt.

    The MOST important "test" is whether a POST with a payload still works; that is what exposes "double read" issues.

    Pseudo example code:

    import org.springframework.web.bind.annotation.*;
    
    @RestController
    @RequestMapping("myroute")
    public class MyController {
        @RequestMapping(method = RequestMethod.POST, produces = "application/json")
        @ResponseBody
        public String getSomethingExample(@RequestBody MyCustomObject input) {
    
            String returnValue = "";
    
            return returnValue;
        }
    }
    

    You can replace "MyCustomObject" with plain ole "Object" if you just want to test.
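    To actually run that POST test, a small client is enough. This sketch assumes Java 11+ (java.net.http), an app listening on http://localhost:8080 (Spring Boot's default port) and the /myroute mapping from the controller above; PostSmokeTest and buildRequest are hypothetical names. The JSON payload is arbitrary and works when the controller takes a plain Object.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Sends a JSON POST so that both the logging filter and the @RequestBody controller must read the body.
public class PostSmokeTest {

    static HttpRequest buildRequest(String baseUrl, String json) {
        return HttpRequest.newBuilder(URI.create(baseUrl + "/myroute"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(json))
                .build();
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        // ASSUMPTION: the application is running locally on port 8080
        HttpRequest request = buildRequest("http://localhost:8080", "{\"name\":\"test\"}");
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        // A 400 such as "Required request body is missing" usually means something consumed the body first.
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```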

    This answer is frankensteined from several different SOF posts and examples, but it took a while to pull it all together, so I hope it helps a future reader.

    Please upvote Lathy's answer before mine. I could not have gotten this far without it.

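    Below is one of the exceptions I got while working this out:

    getReader() has already been called for this request

    (The servlet API throws an IllegalStateException when getInputStream() is called after getReader() on the same request, or vice versa.)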

    Looks like some of the places I "borrowed" from are here:

    http://slackspace.de/articles/log-request-body-with-spring-boot/

    https://github.com/c0nscience/spring-web-logging/blob/master/src/main/java/org/zalando/springframework/web/logging/LoggingFilter.java

    https://howtodoinjava.com/servlets/httpservletrequestwrapper-example-read-request-body/

    https://www.oodlestechnologies.com/blogs/How-to-create-duplicate-object-of-httpServletRequest-object
