From d491227d87c82212f85e03968712ddd5593178c0 Mon Sep 17 00:00:00 2001 From: Arturo Bernal Date: Sun, 12 Apr 2026 13:49:32 +0200 Subject: [PATCH] Enforce max message size during permessage-deflate inflation on the server path to prevent decompression-bomb DoS --- .../websocket/PerMessageDeflateExtension.java | 19 +++++++++ .../core5/websocket/WebSocketExtension.java | 16 ++++++++ .../core5/websocket/WebSocketFrameReader.java | 9 +++-- .../WebSocketH2ServerExchangeHandler.java | 8 ++++ .../PerMessageDeflateExtensionTest.java | 40 +++++++++++++++++++ 5 files changed, 88 insertions(+), 4 deletions(-) diff --git a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/PerMessageDeflateExtension.java b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/PerMessageDeflateExtension.java index d3d000673f..6c1068f6bf 100644 --- a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/PerMessageDeflateExtension.java +++ b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/PerMessageDeflateExtension.java @@ -33,6 +33,8 @@ import java.util.zip.Deflater; import java.util.zip.Inflater; +import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException; + public final class PerMessageDeflateExtension implements WebSocketExtension { private static final byte[] TAIL = new byte[]{0x00, 0x00, (byte) 0xFF, (byte) 0xFF}; @@ -76,6 +78,14 @@ public boolean usesRsv1() { @Override public ByteBuffer decode(final WebSocketFrameType type, final boolean fin, final ByteBuffer payload) throws WebSocketException { + return decode(type, fin, payload, 0L); + } + + @Override + public ByteBuffer decode(final WebSocketFrameType type, + final boolean fin, + final ByteBuffer payload, + final long maxOutputSize) throws WebSocketException { if (!isDataFrame(type) && type != WebSocketFrameType.CONTINUATION) { throw new WebSocketException("Unsupported frame type for permessage-deflate: " + type); } @@ -94,14 +104,23 @@ public ByteBuffer decode(final WebSocketFrameType type, final boolean fin, final inflater.setInput(withTail); final ByteArrayOutputStream out = new ByteArrayOutputStream(Math.max(128, input.length)); final byte[] buffer = new byte[Math.min(16384, Math.max(1024, input.length * 2))]; + long produced = 0L; try { while (!inflater.needsInput()) { final int count = inflater.inflate(buffer); if (count == 0 && inflater.needsInput()) { break; } + // Enforce the decoded size cap during inflation, not after, so a small + // compressed payload cannot expand into a huge buffer before we react. + if (maxOutputSize > 0L && produced + count > maxOutputSize) { + throw new WebSocketProtocolException(1009, "Message too big"); + } out.write(buffer, 0, count); + produced += count; } + } catch (final WebSocketProtocolException wspe) { + throw wspe; } catch (final Exception ex) { throw new WebSocketException("Unable to inflate payload", ex); } diff --git a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketExtension.java b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketExtension.java index 73988651cd..4364a5008f 100644 --- a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketExtension.java +++ b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketExtension.java @@ -51,6 +51,22 @@ default ByteBuffer decode( return payload; } + /** + * Decode a frame payload, aborting as soon as the produced output exceeds + * {@code maxOutputSize}. A non-positive limit means no limit. Implementations + * that may expand input (e.g. permessage-deflate) MUST honour the limit during + * the expansion step, not only after it, to prevent decompression-bomb attacks. + * + * @since 5.7 + */ + default ByteBuffer decode( + final WebSocketFrameType type, + final boolean fin, + final ByteBuffer payload, + final long maxOutputSize) throws WebSocketException { + return decode(type, fin, payload); + } + default ByteBuffer encode( final WebSocketFrameType type, final boolean fin, diff --git a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketFrameReader.java b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketFrameReader.java index af14ff2518..702cd74cce 100644 --- a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketFrameReader.java +++ b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketFrameReader.java @@ -112,11 +112,12 @@ WebSocketFrame readFrame() throws IOException { payload[i] = (byte) (payload[i] ^ maskKey[i % 4]); } ByteBuffer data = ByteBuffer.wrap(payload); + final long maxOutputSize = config.getMaxMessageSize(); if (rsv1 && rsv1Extension != null) { - data = rsv1Extension.decode(type, fin, data); + data = rsv1Extension.decode(type, fin, data, maxOutputSize); continuationCompressed = !fin && (type == WebSocketFrameType.TEXT || type == WebSocketFrameType.BINARY); } else if (type == WebSocketFrameType.CONTINUATION && continuationCompressed && rsv1Extension != null) { - data = rsv1Extension.decode(type, fin, data); + data = rsv1Extension.decode(type, fin, data, maxOutputSize); if (fin) { continuationCompressed = false; } @@ -124,10 +125,10 @@ WebSocketFrame readFrame() throws IOException { continuationCompressed = false; } if (rsv2 && rsv2Extension != null) { - data = rsv2Extension.decode(type, fin, data); + data = rsv2Extension.decode(type, fin, data, maxOutputSize); } if (rsv3 && rsv3Extension != null) { - data = rsv3Extension.decode(type, fin, data); + data = rsv3Extension.decode(type, fin, data, maxOutputSize); } return new WebSocketFrame(fin, false, false, false, type, data); } diff --git a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/server/WebSocketH2ServerExchangeHandler.java b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/server/WebSocketH2ServerExchangeHandler.java index dad8631510..64cb417143 100644 --- a/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/server/WebSocketH2ServerExchangeHandler.java +++ b/httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/server/WebSocketH2ServerExchangeHandler.java @@ -61,6 +61,7 @@ import org.apache.hc.core5.websocket.WebSocketHandler; import org.apache.hc.core5.websocket.WebSocketHandshake; import org.apache.hc.core5.websocket.WebSocketSession; +import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException; final class WebSocketH2ServerExchangeHandler implements AsyncServerExchangeHandler { @@ -160,6 +161,13 @@ public void handleRequest( try { handler.onOpen(session); new WebSocketServerProcessor(session, handler, config.getMaxMessageSize()).process(); + } catch (final WebSocketProtocolException ex) { + handler.onError(session, ex); + try { + session.close(ex.closeCode, ex.getMessage()); + } catch (final IOException ignore) { + // ignore + } } catch (final Exception ex) { handler.onError(session, ex); try { diff --git a/httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/PerMessageDeflateExtensionTest.java b/httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/PerMessageDeflateExtensionTest.java index 3e1fcc3997..20fd8b019b 100644 --- a/httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/PerMessageDeflateExtensionTest.java +++ b/httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/PerMessageDeflateExtensionTest.java @@ -26,13 +26,17 @@ */ package org.apache.hc.core5.websocket; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.zip.Deflater; +import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException; import org.junit.jupiter.api.Test; class PerMessageDeflateExtensionTest { @@ -56,6 +60,42 @@ void decodesFragmentedMessage() throws Exception { assertEquals("fragmented message", WebSocketSession.decodeText(ByteBuffer.wrap(joined.toByteArray()))); } + @Test + void decodeWithinLimitSucceeds() throws Exception { + final byte[] plain = "hello world hello world hello world".getBytes(StandardCharsets.UTF_8); + final byte[] compressed = deflateWithSyncFlush(plain); + + final PerMessageDeflateExtension ext = new PerMessageDeflateExtension(); + final ByteBuffer out = ext.decode(WebSocketFrameType.TEXT, true, ByteBuffer.wrap(compressed), plain.length + 16L); + + assertArrayEquals(plain, toBytes(out)); + } + + @Test + void decodeInflationBombIsRejectedDuringInflate() { + final byte[] plain = new byte[64 * 1024]; + Arrays.fill(plain, (byte) 'A'); + final byte[] compressed = deflateWithSyncFlush(plain); + + final PerMessageDeflateExtension ext = new PerMessageDeflateExtension(); + final WebSocketProtocolException ex = assertThrows(WebSocketProtocolException.class, + () -> ext.decode(WebSocketFrameType.BINARY, true, ByteBuffer.wrap(compressed), 1024L)); + assertEquals(1009, ex.closeCode); + assertEquals("Message too big", ex.getMessage()); + } + + @Test + void decodeZeroLimitMeansUnlimited() throws Exception { + final byte[] plain = new byte[8 * 1024]; + Arrays.fill(plain, (byte) 'B'); + final byte[] compressed = deflateWithSyncFlush(plain); + + final PerMessageDeflateExtension ext = new PerMessageDeflateExtension(); + final ByteBuffer out = ext.decode(WebSocketFrameType.BINARY, true, ByteBuffer.wrap(compressed), 0L); + + assertArrayEquals(plain, toBytes(out)); + } + private static byte[] deflateWithSyncFlush(final byte[] input) { final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true); deflater.setInput(input);