diff --git a/docs/changelog/128099.yaml b/docs/changelog/128099.yaml new file mode 100644 index 0000000000000..1f26cb00bd75d --- /dev/null +++ b/docs/changelog/128099.yaml @@ -0,0 +1,5 @@ +pr: 128099 +summary: Remove first `FlowControlHandler` from HTTP pipeline +area: Network +type: enhancement +issues: [] diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java index 668780fc90665..1ba47d21c8627 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java @@ -15,19 +15,23 @@ import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpRequest; +import io.netty.util.ReferenceCounted; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ContextPreservingActionListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.core.Nullable; import org.elasticsearch.http.netty4.internal.HttpValidator; import org.elasticsearch.transport.Transports; +import java.util.ArrayDeque; + public class Netty4HttpHeaderValidator extends ChannelDuplexHandler { private final HttpValidator validator; private final ThreadContext threadContext; - private State state; + private State state = State.PASSING; + private final ArrayDeque buffer = new ArrayDeque<>(); public Netty4HttpHeaderValidator(HttpValidator validator, ThreadContext threadContext) { this.validator = validator; @@ -36,80 +40,125 @@ public Netty4HttpHeaderValidator(HttpValidator validator, ThreadContext threadCo @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (state == State.VALIDATING || buffer.size() > 0) { + // there's already some buffered messages that need to be processed before this one, so queue this one up behind them + buffer.offerLast(msg); + return; + } + assert msg instanceof HttpObject; - var httpObject = (HttpObject) msg; + final var httpObject = (HttpObject) msg; if (httpObject.decoderResult().isFailure()) { ctx.fireChannelRead(httpObject); // pass-through for decoding failures + } else if (msg instanceof HttpRequest httpRequest) { + validate(ctx, httpRequest); + } else if (state == State.PASSING) { + assert msg instanceof HttpContent; + ctx.fireChannelRead(msg); } else { - if (msg instanceof HttpRequest request) { - validate(ctx, request); - } else { - assert msg instanceof HttpContent; - var content = (HttpContent) msg; - if (state == State.DROPPING) { - content.release(); - ctx.read(); - } else { - assert state == State.PASSING : "unexpected content before validation completed"; - ctx.fireChannelRead(content); - } - } + assert state == State.DROPPING : state; + assert msg instanceof HttpContent; + final var httpContent = (HttpContent) msg; + httpContent.release(); + ctx.read(); } } @Override - public void read(ChannelHandlerContext ctx) throws Exception { - // until validation is completed we can ignore read calls, - // once validation is finished HttpRequest will be fired and downstream can read from there - if (state != State.VALIDATING) { - ctx.read(); - } + public void channelReadComplete(ChannelHandlerContext ctx) { + if (buffer.size() == 0) { + ctx.fireChannelReadComplete(); + } // else we're buffering messages so will manage the read-complete messages ourselves } - void validate(ChannelHandlerContext ctx, HttpRequest request) { - assert Transports.assertDefaultThreadContext(threadContext); - state = State.VALIDATING; - ActionListener.run( - // this prevents thread-context changes to propagate to the validation listener - // atm, the validation listener submits to the event loop executor, which doesn't know about the ES thread-context, - // so this is just a defensive play, in case the code inside the listener changes to not use the event loop executor - ActionListener.assertOnce( - new ContextPreservingActionListener( - threadContext.wrapRestorable(threadContext.newStoredContext()), - new ActionListener<>() { - @Override - public void onResponse(Void unused) { - handleValidationResult(ctx, request, null); - } - - @Override - public void onFailure(Exception e) { - handleValidationResult(ctx, request, e); - } + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + assert ctx.channel().eventLoop().inEventLoop(); + if (state != State.VALIDATING) { + if (buffer.size() > 0) { + final var message = buffer.pollFirst(); + if (message instanceof HttpRequest httpRequest) { + if (httpRequest.decoderResult().isFailure()) { + ctx.fireChannelRead(message); // pass-through for decoding failures + ctx.fireChannelReadComplete(); // downstream will have to call read() again when it's ready + } else { + validate(ctx, httpRequest); } - ) - ), - listener -> { - // this prevents thread-context changes to propagate beyond the validation, as netty worker threads are reused - try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { - validator.validate(request, ctx.channel(), listener); + } else { + assert message instanceof HttpContent; + assert state == State.PASSING : state; // DROPPING releases any buffered chunks up-front + ctx.fireChannelRead(message); + ctx.fireChannelReadComplete(); // downstream will have to call read() again when it's ready } + } else { + ctx.read(); } - ); + } } - void handleValidationResult(ChannelHandlerContext ctx, HttpRequest request, @Nullable Exception validationError) { - assert Transports.assertDefaultThreadContext(threadContext); - // Always explicitly dispatch back to the event loop to prevent reentrancy concerns if we are still on event loop - ctx.channel().eventLoop().execute(() -> { - if (validationError != null) { - request.setDecoderResult(DecoderResult.failure(validationError)); - state = State.DROPPING; - } else { - state = State.PASSING; + void validate(ChannelHandlerContext ctx, HttpRequest httpRequest) { + final var validationResultListener = new ValidationResultListener(ctx, httpRequest); + SubscribableListener.newForked(validationResultListener::doValidate) + .addListener( + validationResultListener, + // dispatch back to event loop unless validation completed already in which case we can just continue on this thread + // straight away, avoiding the need to buffer any subsequent messages + ctx.channel().eventLoop(), + null + ); + } + + private class ValidationResultListener implements ActionListener { + + private final ChannelHandlerContext ctx; + private final HttpRequest httpRequest; + + ValidationResultListener(ChannelHandlerContext ctx, HttpRequest httpRequest) { + this.ctx = ctx; + this.httpRequest = httpRequest; + } + + void doValidate(ActionListener listener) { + assert Transports.assertDefaultThreadContext(threadContext); + assert ctx.channel().eventLoop().inEventLoop(); + assert state == State.PASSING || state == State.DROPPING : state; + state = State.VALIDATING; + try (var ignore = threadContext.newEmptyContext()) { + validator.validate( + httpRequest, + ctx.channel(), + new ContextPreservingActionListener<>(threadContext::newEmptyContext, listener) + ); } - ctx.fireChannelRead(request); - }); + } + + @Override + public void onResponse(Void unused) { + assert Transports.assertDefaultThreadContext(threadContext); + assert ctx.channel().eventLoop().inEventLoop(); + assert state == State.VALIDATING : state; + state = State.PASSING; + fireChannelRead(); + } + + @Override + public void onFailure(Exception e) { + assert Transports.assertDefaultThreadContext(threadContext); + assert ctx.channel().eventLoop().inEventLoop(); + assert state == State.VALIDATING : state; + httpRequest.setDecoderResult(DecoderResult.failure(e)); + state = State.DROPPING; + while (buffer.isEmpty() == false && buffer.peekFirst() instanceof HttpRequest == false) { + assert buffer.peekFirst() instanceof HttpContent; + ((ReferenceCounted) buffer.pollFirst()).release(); + } + fireChannelRead(); + } + + private void fireChannelRead() { + ctx.fireChannelRead(httpRequest); + ctx.fireChannelReadComplete(); // downstream needs to read() again + } } private enum State { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index c8f2d75d18a6f..254576d225ce4 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -371,7 +371,6 @@ protected HttpMessage createMessage(String[] initialLine) throws Exception { ch.pipeline().addLast("decoder", decoder); // parses the HTTP bytes request into HTTP message pieces // from this point in pipeline every handler must call ctx or channel #read() when ready to process next HTTP part - ch.pipeline().addLast(new FlowControlHandler()); if (Assertions.ENABLED) { // missing reads are hard to catch, but we can detect absence of reads within interval long missingReadIntervalMs = 10_000; diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java index d29894a149a4f..1803549b19305 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java @@ -23,20 +23,27 @@ import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.LastHttpContent; -import io.netty.handler.flow.FlowControlHandler; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.http.netty4.internal.HttpValidator; import org.elasticsearch.test.ESTestCase; +import java.util.ArrayDeque; +import java.util.Objects; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import static org.hamcrest.Matchers.instanceOf; + public class Netty4HttpHeaderValidatorTests extends ESTestCase { private EmbeddedChannel channel; private BlockingQueue validatorRequestQueue; + private HttpValidator httpValidator = (httpRequest, channel, listener) -> validatorRequestQueue.add( + new ValidationRequest(httpRequest, channel, listener) + ); @Override public void setUp() throws Exception { @@ -44,7 +51,7 @@ public void setUp() throws Exception { validatorRequestQueue = new LinkedBlockingQueue<>(); channel = new EmbeddedChannel( new Netty4HttpHeaderValidator( - (httpRequest, channel, listener) -> validatorRequestQueue.add(new ValidationRequest(httpRequest, channel, listener)), + (httpRequest, channel, listener) -> httpValidator.validate(httpRequest, channel, listener), new ThreadContext(Settings.EMPTY) ) ); @@ -70,12 +77,42 @@ public void testValidatorReceiveHttpRequest() { } public void testDecoderFailurePassThrough() { - for (var i = 0; i < 1000; i++) { - var httpRequest = newHttpRequest(); - httpRequest.setDecoderResult(DecoderResult.failure(new Exception("bad"))); - channel.writeInbound(httpRequest); - assertEquals(httpRequest, channel.readInbound()); + // send a valid request so that the buffer is nonempty + final var validRequest = newHttpRequest(); + channel.writeInbound(validRequest); + channel.writeInbound(newLastHttpContent()); + + // follow it with an invalid request which should be buffered + final var invalidHttpRequest1 = newHttpRequest(); + invalidHttpRequest1.setDecoderResult(DecoderResult.failure(new Exception("simulated decoder failure 1"))); + channel.writeInbound(invalidHttpRequest1); + + // handle the first request + if (randomBoolean()) { + Objects.requireNonNull(validatorRequestQueue.poll()).listener().onResponse(null); + channel.runPendingTasks(); + assertSame(validRequest, channel.readInbound()); + channel.read(); + asInstanceOf(LastHttpContent.class, channel.readInbound()).release(); + } else { + Objects.requireNonNull(validatorRequestQueue.poll()).listener().onFailure(new Exception("simulated validation failure")); + channel.runPendingTasks(); + assertSame(validRequest, channel.readInbound()); } + + // handle the second request, which is read from the buffer and passed on without validation + assertNull(channel.readInbound()); + channel.read(); + assertSame(invalidHttpRequest1, channel.readInbound()); + + // send another invalid request which is passed straight through + final var invalidHttpRequest2 = newHttpRequest(); + invalidHttpRequest2.setDecoderResult(DecoderResult.failure(new Exception("simulated decoder failure 2"))); + channel.writeInbound(invalidHttpRequest2); + if (randomBoolean()) { + channel.read(); // optional read + } + assertSame(invalidHttpRequest2, channel.readInbound()); } /** @@ -121,10 +158,8 @@ public void testMixedValidationResults() { } public void testIgnoreReadWhenValidating() { - channel.pipeline().addFirst(new FlowControlHandler()); // catch all inbound messages - channel.writeInbound(newHttpRequest()); - channel.writeInbound(newLastHttpContent()); // should hold by flow-control-handler + channel.writeInbound(newLastHttpContent()); assertNull("nothing should pass yet", channel.readInbound()); channel.read(); @@ -143,8 +178,7 @@ public void testIgnoreReadWhenValidating() { asInstanceOf(LastHttpContent.class, channel.readInbound()).release(); } - public void testWithFlowControlAndAggregator() { - channel.pipeline().addFirst(new FlowControlHandler()); + public void testWithAggregator() { channel.pipeline().addLast(new Netty4HttpAggregator(8192, (req) -> true, new HttpRequestDecoder())); channel.writeInbound(newHttpRequest()); @@ -162,5 +196,134 @@ public void testWithFlowControlAndAggregator() { asInstanceOf(FullHttpRequest.class, channel.readInbound()).release(); } + public void testBufferPipelinedRequestsWhenValidating() { + final var expectedChunks = new ArrayDeque(); + expectedChunks.addLast(newHttpContent()); + + // write one full request and one incomplete request received all at once + channel.writeInbound(newHttpRequest()); + channel.writeInbound(newLastHttpContent()); + channel.writeInbound(newHttpRequest()); + channel.writeInbound(expectedChunks.peekLast()); + assertNull("nothing should pass yet", channel.readInbound()); + + if (randomBoolean()) { + channel.read(); + } + var validationRequest = validatorRequestQueue.poll(); + assertNotNull(validationRequest); + + channel.read(); + assertNull("should ignore read while validating", channel.readInbound()); + + validationRequest.listener().onResponse(null); + channel.runPendingTasks(); + assertTrue("http request should pass", channel.readInbound() instanceof HttpRequest); + assertNull("content should not pass yet, need explicit read", channel.readInbound()); + + channel.read(); + asInstanceOf(LastHttpContent.class, channel.readInbound()).release(); + + // should have started to validate the next request + channel.read(); + assertNull("should ignore read while validating", channel.readInbound()); + Objects.requireNonNull(validatorRequestQueue.poll()).listener().onResponse(null); + + channel.runPendingTasks(); + assertThat("next http request should pass", channel.readInbound(), instanceOf(HttpRequest.class)); + + // another chunk received and is buffered, nothing is sent downstream + expectedChunks.addLast(newHttpContent()); + channel.writeInbound(expectedChunks.peekLast()); + assertNull(channel.readInbound()); + assertFalse(channel.hasPendingTasks()); + + // the first chunk is now emitted on request + channel.read(); + var nextChunk = asInstanceOf(HttpContent.class, channel.readInbound()); + assertSame(nextChunk, expectedChunks.pollFirst()); + nextChunk.release(); + assertNull(channel.readInbound()); + assertFalse(channel.hasPendingTasks()); + + // and the second chunk + channel.read(); + nextChunk = asInstanceOf(HttpContent.class, channel.readInbound()); + assertSame(nextChunk, expectedChunks.pollFirst()); + nextChunk.release(); + assertNull(channel.readInbound()); + assertFalse(channel.hasPendingTasks()); + + // buffer is now drained, no more chunks available + if (randomBoolean()) { + channel.read(); // optional read + } + assertNull(channel.readInbound()); + assertTrue(expectedChunks.isEmpty()); + assertFalse(channel.hasPendingTasks()); + + // subsequent chunks are passed straight through without another read() + expectedChunks.addLast(newHttpContent()); + channel.writeInbound(expectedChunks.peekLast()); + nextChunk = asInstanceOf(HttpContent.class, channel.readInbound()); + assertSame(nextChunk, expectedChunks.pollFirst()); + nextChunk.release(); + assertNull(channel.readInbound()); + assertFalse(channel.hasPendingTasks()); + } + + public void testDropChunksOnValidationFailure() { + // write an incomplete request which will be marked as invalid + channel.writeInbound(newHttpRequest()); + channel.writeInbound(newHttpContent()); + assertNull("nothing should pass yet", channel.readInbound()); + + var validationRequest = validatorRequestQueue.poll(); + assertNotNull(validationRequest); + validationRequest.listener().onFailure(new Exception("simulated validation failure")); + + // failed request is passed downstream + channel.runPendingTasks(); + var inboundRequest = asInstanceOf(HttpRequest.class, channel.readInbound()); + assertTrue(inboundRequest.decoderResult().isFailure()); + assertEquals("simulated validation failure", inboundRequest.decoderResult().cause().getMessage()); + + // chunk is not emitted (the buffer is now drained) + assertNull(channel.readInbound()); + if (randomBoolean()) { + channel.read(); + assertNull(channel.readInbound()); + } + + // next chunk is also not emitted (it is released on receipt, not buffered) + channel.writeInbound(newLastHttpContent()); + assertNull(channel.readInbound()); + if (randomBoolean()) { + channel.read(); + assertNull(channel.readInbound()); + } + assertFalse(channel.hasPendingTasks()); + + // next request triggers validation again + final var nextRequest = newHttpRequest(); + channel.writeInbound(nextRequest); + Objects.requireNonNull(validatorRequestQueue.poll()).listener().onResponse(null); + channel.runPendingTasks(); + + if (randomBoolean()) { + channel.read(); // optional read + } + assertSame(nextRequest, channel.readInbound()); + assertFalse(channel.hasPendingTasks()); + } + + public void testInlineValidationDoesNotFork() { + httpValidator = (httpRequest, channel, listener) -> listener.onResponse(null); + final var httpRequest = newHttpRequest(); + channel.writeInbound(httpRequest); + assertFalse(channel.hasPendingTasks()); + assertSame(httpRequest, channel.readInbound()); + } + record ValidationRequest(HttpRequest request, Channel channel, ActionListener listener) {} } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index 222d88642050a..4c4d66753f362 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -63,6 +63,7 @@ import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; @@ -120,6 +121,7 @@ import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.iterableWithSize; @@ -976,7 +978,7 @@ public void testMultipleValidationsOnTheSameChannel() throws InterruptedExceptio final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { @Override public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { - assertThat(okURIs.contains(request.uri()), is(true)); + assertThat(request.uri(), in(okURIs)); // assert validated request is dispatched okURIs.remove(request.uri()); channel.sendResponse(new RestResponse(OK, RestResponse.TEXT_CONTENT_TYPE, new BytesArray("dispatch OK"))); @@ -985,7 +987,7 @@ public void dispatchRequest(final RestRequest request, final RestChannel channel @Override public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) { // assert unvalidated request is NOT dispatched - assertThat(nokURIs.contains(channel.request().uri()), is(true)); + assertThat(channel.request().uri(), in(nokURIs)); nokURIs.remove(channel.request().uri()); try { channel.sendResponse(new RestResponse(channel, (Exception) ((ElasticsearchWrapperException) cause).getCause())); @@ -1000,9 +1002,11 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th assertThat(channelSetOnce.get(), is(channel)); // some requests are validated while others are not if (httpPreRequest.uri().contains("X-Auth=OK")) { - validationListener.onResponse(null); + randomFrom(EsExecutors.DIRECT_EXECUTOR_SERVICE, channel.eventLoop()).execute(() -> validationListener.onResponse(null)); } else if (httpPreRequest.uri().contains("X-Auth=NOK")) { - validationListener.onFailure(new ElasticsearchSecurityException("Boom", UNAUTHORIZED)); + randomFrom(EsExecutors.DIRECT_EXECUTOR_SERVICE, channel.eventLoop()).execute( + () -> validationListener.onFailure(new ElasticsearchSecurityException("Boom", UNAUTHORIZED)) + ); } else { throw new AssertionError("Unrecognized URI"); }