diff --git a/.bob/skills/merge-upstream-tag/SKILL.md b/.bob/skills/merge-upstream-tag/SKILL.md new file mode 100644 index 00000000000..3286cc6bcd9 --- /dev/null +++ b/.bob/skills/merge-upstream-tag/SKILL.md @@ -0,0 +1,180 @@ +--- +name: merge-upstream-tag +description: Use when the user wants to merge an upstream version tag into the current branch-specific fork branch (e.g. "merge netty-4.1.133.Final into dse-netty-4.1.133"), preserving existing branch-specific changes. +--- + +# Merge Upstream Version Tag into Fork Branch + +This skill forward-ports a DSE netty fork branch to a new upstream release +tag. It fetches all tags from the upstream remote, then merges the target +version tag into the current branch while preserving all branch-specific +commits. + +## Prerequisites + +Confirm with the user before starting: +1. The **upstream remote name** — check with `git remote -v`. In this repo + it is `upstream` (pointing to `github.com/netty/netty`). +2. The **version tag** to merge in (e.g. `netty-4.1.133.Final`). +3. The **current branch** name — use `git branch --show-current`. + +## Steps + +### 1. Verify the current state + +```bash +git status +git branch --show-current +git remote -v +``` + +Make sure the working tree is clean (no uncommitted changes). If it is not, +ask the user to stash or commit before continuing. + +### 2. Fetch all tags from the upstream remote + +```bash +git fetch upstream --tags +``` + +This pulls every tag from the upstream remote (e.g. `github.com/netty/netty`) +without modifying any local branches. + +### 3. Confirm the tag exists + +```bash +git tag -l "" +``` + +Replace `` with the target tag (e.g. `netty-4.1.133.Final`). If it is +not listed, the fetch in step 2 may have failed or the tag name is wrong — +verify and retry. + +### 4. Merge the tag into the current branch + +```bash +git merge "" --no-ff -m "Merge tag '' into + +This forward-ports the DSE netty fork to the release of netty. + +[maven-release-plugin] copy for tag " +``` + +- `` — the upstream version tag (e.g. `netty-4.1.133.Final`). +- `` — the current branch name (e.g. `dse-netty-4.1.133`). +- `--no-ff` preserves the merge commit so the upstream history remains + traceable. + +The commit message format must follow the established convention used in +previous merges on this repository (e.g. commit `f22f5ae6aa` for +`netty-4.1.132.Final`). + +**Expect the merge to stop with conflicts every time.** This is normal — every +merge of a new upstream tag into a DSE fork branch will produce conflicts +because of DSE-specific patches and version numbers. Proceed directly to +step 5. + +### 5. Resolve all conflicts + +Run the following to see every conflicting file: + +```bash +git diff --name-only --diff-filter=U +``` + +Work through each file. Do **not** ask the user — resolve them autonomously +using the rules below. + +#### 5a. `pom.xml` conflicts — always keep the DSE version + +Every `pom.xml` file will conflict on the `` element because upstream +uses a plain `.Final` version while the DSE branch uses the `.1.dse` suffix. + +**Rule:** In every `pom.xml` conflict, always accept the **HEAD (DSE) side** +for the `` element. The resolved value must be the DSE version string +(e.g. `4.1.133.1.dse`), **never** the bare upstream value (e.g. `4.1.133.Final`). + +For all other content in the same `pom.xml` (dependency versions, plugin +config, etc.) apply the general rules from §5b below. + +After resolving each `pom.xml`, stage it: + +```bash +git add +``` + +#### 5b. Code file conflicts — always preserve DSE changes + +For every non-`pom.xml` conflict: + +1. **Read the conflicting file** using `read_file` to understand both sides + (`<<<<<<< HEAD` = local/DSE side, `>>>>>>> ` = upstream side). +2. Apply the following priority rules in order: + - **Keep the DSE (HEAD) change** whenever the conflict region contains code + that was introduced or modified by a DSE-specific patch (i.e. it does not + exist in the upstream tag at all, or it is a deliberate override of + upstream behaviour). DSE changes must never be silently dropped. + - **Keep the upstream change** only when the local side is byte-for-byte + identical to the previous upstream base and carries no DSE-specific intent + (e.g. an import that was moved, a purely mechanical refactor, or a + whitespace-only difference that the DSE branch never intentionally touched). + - **Merge both sides** when the upstream adds new logic _and_ the local side + also adds independent logic to the same region — integrate them carefully + so neither is lost. +3. After resolving, write the file back with `apply_diff` or `write_file` so + it contains no conflict markers, then stage it: + +```bash +git add +``` + +#### 5c. Complete the merge + +Once every conflicting file has been staged, finalize the merge commit: + +```bash +git merge --continue +``` + +If `git merge --continue` opens an editor, the pre-filled message is already +correct (it was passed in step 4); just accept it as-is. + +### 6. Verify pom.xml versions are consistent + +After the merge commit is created, double-check that all `pom.xml` files carry +the correct DSE version and not the upstream version: + +```bash +grep -r "" --include="pom.xml" . | grep -v "\.1\.dse" | grep -v "target/" +``` + +If any `pom.xml` still contains the bare upstream version (e.g. `4.1.133.Final` +instead of `4.1.133.1.dse`), amend them and stage + commit the fix: + +```bash +# Fix remaining pom.xml files that slipped through conflict resolution +mvn versions:set -DnewVersion= -DgenerateBackupPoms=false +git add -u +git commit -m "Update version to " +``` + +- `` — the full DSE version string, e.g. `4.1.133.1.dse`. +- `-DgenerateBackupPoms=false` avoids leaving `pom.xml.versionsBackup` files + behind. + +### 7. Verify the result + +```bash +git log --oneline --graph -10 +``` + +Confirm the merge commit is at the HEAD and that both parent lines of the +merge are visible. + +### 8. Push + +```bash +git push origin +``` + +Ask the user for confirmation before pushing. diff --git a/.github/workflows/README-build-and-publish.md b/.github/workflows/README-build-and-publish.md new file mode 100644 index 00000000000..4440763b116 --- /dev/null +++ b/.github/workflows/README-build-and-publish.md @@ -0,0 +1,233 @@ +# Build and Publish Workflow + +## Overview + +This GitHub Actions workflow (`build-and-publish.yml`) builds the Netty library across multiple platforms and publishes the artifacts to GitHub Packages. + +## Workflow Architecture + +The workflow consists of 4 stages that run in sequence: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Stage 1: Linux x86_64 Full Build │ +│ - Builds all Netty modules │ +│ - Uses Docker with CentOS 6 for compatibility │ +│ - Produces complete JAR artifacts │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Stage 2: macOS Intel x86_64 Native Libraries │ +│ - Builds native modules only: │ +│ • resolver-dns-native-macos │ +│ • transport-native-unix-common │ +│ • transport-native-kqueue │ +│ - Runs on GitHub-hosted Intel Mac │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Stage 3: macOS ARM aarch64 Native Libraries │ +│ - Builds same native modules as Stage 2 │ +│ - Runs on GitHub-hosted Apple Silicon Mac │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Stage 4: Merge and Publish │ +│ - Downloads all artifacts from previous stages │ +│ - Merges staging repositories │ +│ - Generates netty-all module │ +│ - Publishes to GitHub Packages │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Triggers + +The workflow can be triggered in two ways: + +1. **Manual Trigger**: Via the GitHub Actions UI (workflow_dispatch) +2. **Tag Push**: Automatically when pushing version tags: + - Tags starting with `v*` (e.g., `v4.1.100`) + - Tags starting with `netty-*` (e.g., `netty-4.1.100.Final`) + +## Prerequisites + +### Repository Configuration + +1. **GitHub Packages**: Ensure GitHub Packages is enabled for your repository +2. **Permissions**: The workflow requires the following permissions: + - `contents: read` - To checkout the repository + - `packages: write` - To publish to GitHub Packages + +### Required Files + +The workflow depends on these existing files: +- `docker/Dockerfile.centos6` - Docker image for Linux builds +- `.github/scripts/local_staging_install_release.sh` - Script to merge staging artifacts +- `.github/actions/thread-dump-jvms/action.yml` - Action for debugging cancelled jobs +- `Brewfile` - macOS dependencies (optional, continues on error) + +### Secrets + +The workflow uses the built-in `GITHUB_TOKEN` secret, which is automatically provided by GitHub Actions. No additional secrets need to be configured. + +## Usage + +### Manual Trigger + +1. Go to the **Actions** tab in your GitHub repository +2. Select **Build and Publish to GitHub Packages** workflow +3. Click **Run workflow** +4. Select the branch to run from +5. Click **Run workflow** button + +### Tag-based Trigger + +```bash +# Create and push a version tag +git tag v4.1.100.Final +git push origin v4.1.100.Final +``` + +The workflow will automatically start building and publishing. + +## Artifacts + +### Intermediate Artifacts + +Each build stage uploads its artifacts to GitHub Actions: +- `linux-x86_64-local-staging` - Linux build artifacts +- `macos-x86_64-local-staging` - Intel Mac native libraries +- `macos-aarch64-local-staging` - ARM Mac native libraries +- `merged-local-staging` - Final merged artifacts (for debugging) + +These artifacts are retained for 90 days (GitHub default) and can be downloaded from the workflow run page. + +### Published Artifacts + +Final artifacts are published to GitHub Packages Maven registry at: +``` +https://maven.pkg.github.com/OWNER/REPOSITORY +``` + +## Consuming Published Artifacts + +To use the published artifacts in your Maven project: + +### 1. Configure Maven Settings + +Add to your `~/.m2/settings.xml`: + +```xml + + + + github + YOUR_GITHUB_USERNAME + YOUR_GITHUB_TOKEN + + + +``` + +### 2. Add Repository to pom.xml + +```xml + + + github + https://maven.pkg.github.com/OWNER/REPOSITORY + + +``` + +### 3. Add Netty Dependencies + +```xml + + io.netty + netty-all + YOUR_VERSION + +``` + +## Build Times + +Approximate build times (may vary): +- **Stage 1 (Linux)**: 15-25 minutes +- **Stage 2 (macOS Intel)**: 10-15 minutes +- **Stage 3 (macOS ARM)**: 10-15 minutes +- **Stage 4 (Merge & Publish)**: 5-10 minutes +- **Total**: ~40-65 minutes + +## Troubleshooting + +### Build Failures + +1. **Check the logs**: Click on the failed job to see detailed logs +2. **Download artifacts**: Failed builds may still produce partial artifacts for debugging +3. **Thread dumps**: If a job is cancelled, thread dumps are automatically captured + +### Common Issues + +**Docker build fails on Linux**: +- Check if `docker/Dockerfile.centos6` exists and is valid +- Verify Docker daemon is accessible + +**macOS native build fails**: +- Check if Brewfile dependencies are correct +- Verify JDK 8 is properly installed +- Check native compilation toolchain (Xcode Command Line Tools) + +**Publishing fails**: +- Verify `GITHUB_TOKEN` has `packages:write` permission +- Check if GitHub Packages is enabled for the repository +- Ensure the repository URL in the workflow matches your repository + +### Re-running Failed Jobs + +You can re-run individual failed jobs without re-running the entire workflow: +1. Go to the workflow run page +2. Click on the failed job +3. Click **Re-run jobs** → **Re-run failed jobs** + +## Caching + +The workflow uses Maven repository caching to speed up builds: +- Linux builds cache: `~/.m2/repository` +- macOS Intel builds cache: `~/.m2/repository` (separate cache key) +- macOS ARM builds cache: `~/.m2/repository` (separate cache key) + +Caches are automatically invalidated when `pom.xml` files change. + +## Maintenance + +### Updating Dependencies + +- **Java Version**: Modify the `java-version` in the `setup-java` steps +- **Docker Image**: Update `docker/Dockerfile.centos6` +- **macOS Dependencies**: Update `Brewfile` + +### Adding New Platforms + +To add support for additional platforms: +1. Add a new job in the workflow (e.g., `build-windows-x64`) +2. Configure the appropriate runner (e.g., `runs-on: windows-latest`) +3. Add the job to the `needs` array in `publish-to-github-packages` +4. Update the artifact download and merge steps + +## Security Considerations + +- The workflow uses minimal permissions (read contents, write packages) +- Secrets are not exposed in logs +- Docker containers run with volume mounts but no privileged access +- All dependencies are cached and verified via checksums + +## Support + +For issues with this workflow: +1. Check the [GitHub Actions documentation](https://docs.github.com/en/actions) +2. Review the [Netty build documentation](BUILD-DATASTAX.md) +3. Open an issue in the repository with workflow run logs \ No newline at end of file diff --git a/.github/workflows/autoport-41.yml b/.github/workflows/autoport-41.yml new file mode 100644 index 00000000000..682ac53c370 --- /dev/null +++ b/.github/workflows/autoport-41.yml @@ -0,0 +1,130 @@ +name: Auto-port to 4.1 +on: + pull_request_target: + types: + - closed + - labeled + branches: + - '4.2' + - '5.0' + +jobs: + autoport: + name: "Auto-porting to 4.1" + concurrency: + group: port-41-${{ github.event.pull_request.number }} + cancel-in-progress: true + if: github.event.pull_request.merged && contains(github.event.pull_request.labels.*.name, 'needs-cherry-pick-4.1') + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ssh-key: ${{ secrets.SSH_PRIVATE_KEY_PEM }} + ssh-known-hosts: ${{ secrets.SSH_KNOWN_HOSTS }} + fetch-depth: '0' # Cherry-pick needs full history + + - name: Setup git configuration + run: | + git config --global user.email "netty-project-bot@users.noreply.github.com" + git config --global user.name "Netty Project Bot" + + - name: Create auto-port PR branch and cherry-pick + id: cherry-pick + run: | + MERGE_COMMIT="${{ github.event.pull_request.merge_commit_sha }}" + echo "Auto-porting commit: $MERGE_COMMIT" + + PORT_BRANCH="auto-port-pr-${{ github.event.pull_request.number }}-to-4.1" + if [[ $(git branch --show-current) != '4.1' ]]; then + git fetch origin 4.1:4.1 + fi + git checkout -b "$PORT_BRANCH" 4.1 + + if git cherry-pick -x "$MERGE_COMMIT"; then + echo "Cherry-pick successful" + else + echo "Cherry-pick failed - conflicts detected" + git cherry-pick --abort + exit 1 + fi + echo "branch=$PORT_BRANCH" >> "$GITHUB_OUTPUT" + + - name: Push auto-port branch + id: push + if: steps.cherry-pick.outcome == 'success' + run: | + if ! git push origin "${{ steps.cherry-pick.outputs.branch }}"; then + echo "Auto-port branch push failed" + exit 1 + fi + + - name: Create pull request + id: create-pr + if: steps.cherry-pick.outcome == 'success' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + const { data: pr } = await github.rest.pulls.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title: `Auto-port 4.1: ${context.payload.pull_request.title}`, + head: '${{ steps.cherry-pick.outputs.branch }}', + base: '4.1', + body: `Auto-port of #${context.payload.pull_request.number} to 4.1\n` + + `Cherry-picked commit: ${context.payload.pull_request.merge_commit_sha}\n\n---\n` + + `${context.payload.pull_request.body || ''}` + }); + console.log(`Created auto-port PR: ${pr.html_url}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Auto-port PR for 4.1: #${pr.number}` + }); + + # Important: This script MUST run with the default GITHUB_TOKEN to avoid triggering other actions. + - name: Remove triggering label + if: steps.create-pr.outcome == 'success' + uses: actions/github-script@v8 + with: + script: | + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + name: 'needs-cherry-pick-4.1' + }); + + - name: Report cherry-pick conflicts + if: failure() && steps.cherry-pick.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\nGot conflicts when cherry-picking onto 4.1.` + }); + + - name: Report auto-port branch push failure + if: failure() && steps.push.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\n`+ + `I could cherry-pick onto 4.1 just fine, but pushing the new branch failed.` + }); + + - name: Remove branch on PR create failure + if: failure() && steps.cherry-pick.outputs.branch + run: | + git push -d origin "${{ steps.cherry-pick.outputs.branch }}" diff --git a/.github/workflows/autoport-42.yml b/.github/workflows/autoport-42.yml new file mode 100644 index 00000000000..15b27eafe67 --- /dev/null +++ b/.github/workflows/autoport-42.yml @@ -0,0 +1,130 @@ +name: Auto-port to 4.2 +on: + pull_request_target: + types: + - closed + - labeled + branches: + - '4.1' + - '5.0' + +jobs: + autoport: + name: "Auto-porting to 4.2" + concurrency: + group: port-42-${{ github.event.pull_request.number }} + cancel-in-progress: true + if: github.event.pull_request.merged && contains(github.event.pull_request.labels.*.name, 'needs-cherry-pick-4.2') + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ssh-key: ${{ secrets.SSH_PRIVATE_KEY_PEM }} + ssh-known-hosts: ${{ secrets.SSH_KNOWN_HOSTS }} + fetch-depth: '0' # Cherry-pick needs full history + + - name: Setup git configuration + run: | + git config --global user.email "netty-project-bot@users.noreply.github.com" + git config --global user.name "Netty Project Bot" + + - name: Create auto-port PR branch and cherry-pick + id: cherry-pick + run: | + MERGE_COMMIT="${{ github.event.pull_request.merge_commit_sha }}" + echo "Auto-porting commit: $MERGE_COMMIT" + + PORT_BRANCH="auto-port-pr-${{ github.event.pull_request.number }}-to-4.2" + if [[ $(git branch --show-current) != '4.2' ]]; then + git fetch origin 4.2:4.2 + fi + git checkout -b "$PORT_BRANCH" 4.2 + + if git cherry-pick -x "$MERGE_COMMIT"; then + echo "Cherry-pick successful" + else + echo "Cherry-pick failed - conflicts detected" + git cherry-pick --abort + exit 1 + fi + echo "branch=$PORT_BRANCH" >> "$GITHUB_OUTPUT" + + - name: Push auto-port branch + id: push + if: steps.cherry-pick.outcome == 'success' + run: | + if ! git push origin "${{ steps.cherry-pick.outputs.branch }}"; then + echo "Auto-port branch push failed" + exit 1 + fi + + - name: Create pull request + id: create-pr + if: steps.cherry-pick.outcome == 'success' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + const { data: pr } = await github.rest.pulls.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title: `Auto-port 4.2: ${context.payload.pull_request.title}`, + head: '${{ steps.cherry-pick.outputs.branch }}', + base: '4.2', + body: `Auto-port of #${context.payload.pull_request.number} to 4.2\n` + + `Cherry-picked commit: ${context.payload.pull_request.merge_commit_sha}\n\n---\n` + + `${context.payload.pull_request.body || ''}` + }); + console.log(`Created auto-port PR: ${pr.html_url}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Auto-port PR for 4.2: #${pr.number}` + }); + + # Important: This script MUST run with the default GITHUB_TOKEN to avoid triggering other actions. + - name: Remove triggering label + if: steps.create-pr.outcome == 'success' + uses: actions/github-script@v8 + with: + script: | + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + name: 'needs-cherry-pick-4.2' + }); + + - name: Report cherry-pick conflicts + if: failure() && steps.cherry-pick.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\nGot conflicts when cherry-picking onto 4.2.` + }); + + - name: Report auto-port branch push failure + if: failure() && steps.push.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\n`+ + `I could cherry-pick onto 4.2 just fine, but pushing the new branch failed.` + }); + + - name: Remove branch on PR create failure + if: failure() && steps.cherry-pick.outputs.branch + run: | + git push -d origin "${{ steps.cherry-pick.outputs.branch }}" diff --git a/.github/workflows/autoport-50.yml b/.github/workflows/autoport-50.yml new file mode 100644 index 00000000000..2899d56e209 --- /dev/null +++ b/.github/workflows/autoport-50.yml @@ -0,0 +1,130 @@ +name: Auto-port to 5.0 +on: + pull_request_target: + types: + - closed + - labeled + branches: + - '4.1' + - '4.2' + +jobs: + autoport: + name: "Auto-porting to 5.0" + concurrency: + group: port-50-${{ github.event.pull_request.number }} + cancel-in-progress: true + if: github.event.pull_request.merged && contains(github.event.pull_request.labels.*.name, 'needs-cherry-pick-5.0') + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ssh-key: ${{ secrets.SSH_PRIVATE_KEY_PEM }} + ssh-known-hosts: ${{ secrets.SSH_KNOWN_HOSTS }} + fetch-depth: '0' # Cherry-pick needs full history + + - name: Setup git configuration + run: | + git config --global user.email "netty-project-bot@users.noreply.github.com" + git config --global user.name "Netty Project Bot" + + - name: Create auto-port PR branch and cherry-pick + id: cherry-pick + run: | + MERGE_COMMIT="${{ github.event.pull_request.merge_commit_sha }}" + echo "Auto-porting commit: $MERGE_COMMIT" + + PORT_BRANCH="auto-port-pr-${{ github.event.pull_request.number }}-to-5.0" + if [[ $(git branch --show-current) != '5.0' ]]; then + git fetch origin 5.0:5.0 + fi + git checkout -b "$PORT_BRANCH" 5.0 + + if git cherry-pick -x "$MERGE_COMMIT"; then + echo "Cherry-pick successful" + else + echo "Cherry-pick failed - conflicts detected" + git cherry-pick --abort + exit 1 + fi + echo "branch=$PORT_BRANCH" >> "$GITHUB_OUTPUT" + + - name: Push auto-port branch + id: push + if: steps.cherry-pick.outcome == 'success' + run: | + if ! git push origin "${{ steps.cherry-pick.outputs.branch }}"; then + echo "Auto-port branch push failed" + exit 1 + fi + + - name: Create pull request + id: create-pr + if: steps.cherry-pick.outcome == 'success' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + const { data: pr } = await github.rest.pulls.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title: `Auto-port 5.0: ${context.payload.pull_request.title}`, + head: '${{ steps.cherry-pick.outputs.branch }}', + base: '5.0', + body: `Auto-port of #${context.payload.pull_request.number} to 5.0\n` + + `Cherry-picked commit: ${context.payload.pull_request.merge_commit_sha}\n\n---\n` + + `${context.payload.pull_request.body || ''}` + }); + console.log(`Created auto-port PR: ${pr.html_url}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Auto-port PR for 5.0: #${pr.number}` + }); + + # Important: This script MUST run with the default GITHUB_TOKEN to avoid triggering other actions. + - name: Remove triggering label + if: steps.create-pr.outcome == 'success' + uses: actions/github-script@v8 + with: + script: | + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + name: 'needs-cherry-pick-5.0' + }); + + - name: Report cherry-pick conflicts + if: failure() && steps.cherry-pick.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\nGot conflicts when cherry-picking onto 5.0.` + }); + + - name: Report auto-port branch push failure + if: failure() && steps.push.outcome == 'failure' + uses: actions/github-script@v8 + with: + github-token: '${{ secrets.PAT_TOKEN_READ_WRITE_PR }}' + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: `Could not create auto-port PR.\n`+ + `I could cherry-pick onto 5.0 just fine, but pushing the new branch failed.` + }); + + - name: Remove branch on PR create failure + if: failure() && steps.cherry-pick.outputs.branch + run: | + git push -d origin "${{ steps.cherry-pick.outputs.branch }}" diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml new file mode 100644 index 00000000000..946bd3990e2 --- /dev/null +++ b/.github/workflows/build-and-publish.yml @@ -0,0 +1,306 @@ +name: Build and Publish DataStax Netty to GitHub Packages + +on: + # Allows manual trigger from the Actions tab + workflow_dispatch: + + # Trigger on version tags + push: + branches: + - dse-netty-4.1.135 + tags: + - '*.dse' + - 'dse-netty-*' + +permissions: + contents: read + packages: write + +env: + MAVEN_OPTS: -Dhttp.keepAlive=false -Dmaven.wagon.http.pool=false -Dmaven.wagon.http.retryhandler.count=5 -Dmaven.wagon.httpconnectionManager.ttlSeconds=240 + +# Cancel running jobs when a new push happens to the same branch/tag +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + # Stage 1: Build full Netty library on Linux x64 + build-linux-x64: + runs-on: ubuntu-latest + name: Build Linux x86_64 (Full) + + steps: + - uses: actions/checkout@v4 + + # Cache .m2/repository + - name: Cache local Maven repository + uses: actions/cache@v4 + continue-on-error: true + with: + path: ~/.m2/repository + key: cache-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + cache-maven-${{ hashFiles('**/pom.xml') }} + cache-maven- + + - name: Configure Maven settings for Docker + run: | + mkdir -p ~/.m2 + cat > ~/.m2/settings.xml << 'EOF' + + + + github + ${env.GITHUB_ACTOR} + ${env.GITHUB_TOKEN} + + + + EOF + env: + GITHUB_ACTOR: ${{ github.actor }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Create local staging directory + run: mkdir -p ~/local-staging + + - name: Build docker image + run: docker build -f docker/Dockerfile-netty-centos6 -t netty-centos6 . + + - name: Build and stage artifacts + run: | + docker run -t \ + -v ~/.m2:/root/.m2:Z \ + -v ~/local-staging:/root/local-staging:Z \ + -v $(pwd):/code:Z \ + -w /code \ + --entrypoint="" \ + netty-centos6 \ + bash -ic "./mvnw -B clean install -DskipTests=true ; ./mvnw -B deploy -DaltDeploymentRepository=local-staging::default::file:///root/local-staging -DskipTests=true" + + - name: Upload local staging directory + uses: actions/upload-artifact@v4 + with: + name: linux-x86_64-local-staging + path: ~/local-staging + if-no-files-found: error + include-hidden-files: true + + # Stage 2: Build macOS Intel x86_64 native libraries + build-macos-intel: + runs-on: macos-15-intel + name: Build macOS x86_64 (Native Libraries) + needs: [build-linux-x64] + + steps: + - uses: actions/checkout@v4 + + - name: Set up JDK 8 + uses: actions/setup-java@v4 + with: + distribution: 'zulu' + java-version: '8' + + # Cache .m2/repository + - name: Cache local Maven repository + uses: actions/cache@v4 + continue-on-error: true + with: + path: ~/.m2/repository + key: cache-maven-macos-intel-${{ hashFiles('**/pom.xml') }} + restore-keys: | + cache-maven-macos-intel-${{ hashFiles('**/pom.xml') }} + cache-maven- + + - name: Install tools via brew + run: brew bundle + continue-on-error: true + + - name: Create local staging directory + run: mkdir -p ~/local-staging + + - name: Build and stage native libraries + run: | + echo "$(pwd)" + ./mvnw -B -U \ + -pl resolver-dns-native-macos,transport-native-unix-common,transport-native-kqueue \ + deploy \ + -DskipTests \ + -DaltDeploymentRepository=local-staging::default::file:///$(pwd)/local-staging + find $(pwd)/local-staging + find ~/local-staging + + - name: Upload local staging directory + uses: actions/upload-artifact@v4 + with: + name: macos-x86_64-local-staging + path: ${{ github.workspace }}/local-staging + if-no-files-found: error + include-hidden-files: true + + # Stage 3: Build macOS ARM aarch64 native libraries + build-macos-arm: + runs-on: macos-15 + name: Build macOS aarch64 (Native Libraries) + needs: [build-linux-x64] + + steps: + - uses: actions/checkout@v4 + + - name: Set up JDK 8 + uses: actions/setup-java@v4 + with: + distribution: 'zulu' + java-version: '8' + + # Cache .m2/repository + - name: Cache local Maven repository + uses: actions/cache@v4 + continue-on-error: true + with: + path: ~/.m2/repository + key: cache-maven-macos-arm-${{ hashFiles('**/pom.xml') }} + restore-keys: | + cache-maven-macos-arm-${{ hashFiles('**/pom.xml') }} + cache-maven- + + - name: Install tools via brew + run: brew bundle + continue-on-error: true + + - name: Create local staging directory + run: mkdir -p ~/local-staging + + - name: Build and stage native libraries + run: | + echo "$(pwd)" + ./mvnw -B -Pmac-m1-cross-compile deploy \ + -pl resolver-dns-native-macos,transport-native-unix-common,transport-native-kqueue \ + -DskipTests \ + -DaltDeploymentRepository=local-staging::default::file:///$(pwd)/local-staging + find $(pwd)/local-staging + find ~/local-staging + + - name: Upload local staging directory + uses: actions/upload-artifact@v4 + with: + name: macos-aarch64-local-staging + path: ${{ github.workspace }}/local-staging + if-no-files-found: error + include-hidden-files: true + + # Stage 4: Merge artifacts and publish to GitHub Packages + publish-to-github-packages: + runs-on: ubuntu-latest + name: Merge and Publish to GitHub Packages + needs: [build-linux-x64, build-macos-intel, build-macos-arm] + + steps: + - uses: actions/checkout@v4 + + - name: Set up JDK 8 + uses: actions/setup-java@v4 + with: + distribution: 'zulu' + java-version: '8' + + # Cache .m2/repository + - name: Cache local Maven repository + uses: actions/cache@v4 + continue-on-error: true + with: + path: ~/.m2/repository + key: cache-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + cache-maven-${{ hashFiles('**/pom.xml') }} + cache-maven- + + # Configure Maven settings for GitHub Packages + - name: Configure Maven settings + uses: s4u/maven-settings-action@v3.0.0 + with: + servers: | + [{ + "id": "github", + "username": "${{ github.actor }}", + "password": "${{ secrets.GITHUB_TOKEN }}" + }, + { + "id": "central-portal-snapshots", + "username": "${{ github.actor }}", + "password": "${{ secrets.GITHUB_TOKEN }}" + }] + + # Setup environment variables + - name: Prepare environment variables + run: | + echo "LOCAL_STAGING_DIR=$HOME/local-staging" >> $GITHUB_ENV + + # Download all staging artifacts + - name: Download Linux x86_64 staging directory + uses: actions/download-artifact@v4 + with: + name: linux-x86_64-local-staging + path: ~/linux-x86_64-local-staging + + - name: Download macOS x86_64 staging directory + uses: actions/download-artifact@v4 + with: + name: macos-x86_64-local-staging + path: ~/macos-x86_64-local-staging + + - name: Download macOS aarch64 staging directory + uses: actions/download-artifact@v4 + with: + name: macos-aarch64-local-staging + path: ~/macos-aarch64-local-staging + + # Install artifacts to local Maven repository + - name: Copy build artifacts to local maven repository + run: | + bash ./.github/scripts/local_staging_install_release.sh \ + ~/.m2/repository \ + ~/linux-x86_64-local-staging \ + ~/macos-x86_64-local-staging \ + ~/macos-aarch64-local-staging + + # Generate netty-all and install to local Maven repository + - name: Generate netty-all + run: | + ./mvnw -B --file pom.xml -pl all \ + clean install \ + -DskipTests=true + + # Merge all staging repositories + - name: Merge staging repositories + run: | + bash ./.github/scripts/local_staging_install_release.sh \ + ~/local-staging \ + ~/linux-x86_64-local-staging \ + ~/macos-x86_64-local-staging \ + ~/macos-aarch64-local-staging + + # Copy netty-all from local Maven repository + if [ -d "$HOME/.m2/repository/io/netty/netty-all" ]; then + cp -r $HOME/.m2/repository/io/netty/netty-all $HOME/local-staging/io/netty/ + fi + + # Deploy to GitHub Packages + - name: Deploy to GitHub Packages + run: | + ./mvnw -B --file pom.xml \ + org.sonatype.plugins:nexus-staging-maven-plugin:deploy-staged \ + -DaltStagingDirectory=$HOME/local-staging \ + -DserverId=github \ + -DnexusUrl=https://maven.pkg.github.com/${{ github.repository }} \ + -DrepositoryId=github + + - name: Upload merged staging directory (for debugging) + uses: actions/upload-artifact@v4 + if: always() + with: + name: merged-local-staging + path: ~/local-staging + if-no-files-found: warn + include-hidden-files: true diff --git a/.github/workflows/ci-deploy.yml b/.github/workflows/ci-deploy.yml index 3f41bd26508..7a8e2abf467 100644 --- a/.github/workflows/ci-deploy.yml +++ b/.github/workflows/ci-deploy.yml @@ -90,7 +90,7 @@ jobs: matrix: include: - setup: macos-x86_64-java8 - os: macos-13 + os: macos-15-intel - setup: macos-aarch64-java8 os: macos-15 diff --git a/.github/workflows/ci-pr.yml b/.github/workflows/ci-pr.yml index 506ac787e17..f67ed7ed0fd 100644 --- a/.github/workflows/ci-pr.yml +++ b/.github/workflows/ci-pr.yml @@ -201,16 +201,35 @@ jobs: - setup: linux-x86_64-java11-adaptive docker-compose-build: "-f docker/docker-compose.yaml -f docker/docker-compose.centos-6.111.yaml build" docker-compose-run: "-f docker/docker-compose.yaml -f docker/docker-compose.centos-6.111.yaml run build-leak-adaptive" + - setup: linux-x86_64-java11-awslc + docker-compose-build: "-f docker/docker-compose.yaml -f docker/docker-compose.al2023.yaml build" + docker-compose-install-tcnative: "-f docker/docker-compose.yaml -f docker/docker-compose.al2023.yaml run install-tcnative" + docker-compose-update-tcnative-version: "-f docker/docker-compose.yaml -f docker/docker-compose.al2023.yaml run update-tcnative-version" + docker-compose-run: "-f docker/docker-compose.yaml -f docker/docker-compose.al2023.yaml run build" name: ${{ matrix.setup }} build needs: verify-pr + defaults: + run: + working-directory: netty steps: - uses: actions/checkout@v4 + with: + path: netty + + - uses: actions/checkout@v4 + if: ${{ endsWith(matrix.setup, '-awslc') }} + with: + repository: netty/netty-tcnative + ref: main + path: netty-tcnative + fetch-depth: 0 # Cache .m2/repository - name: Cache local Maven repository uses: actions/cache@v4 continue-on-error: true + if: ${{ !endsWith(matrix.setup, '-awslc') }} with: path: ~/.m2/repository key: cache-maven-${{ hashFiles('**/pom.xml') }} @@ -218,9 +237,28 @@ jobs: cache-maven-${{ hashFiles('**/pom.xml') }} cache-maven- + - name: Cache local Maven repository + uses: actions/cache@v4 + continue-on-error: true + if: ${{ endsWith(matrix.setup, '-awslc') }} + with: + path: ~/.m2-al2023/repository + key: cache-maven-al2023-${{ hashFiles('**/pom.xml') }} + restore-keys: | + cache-maven-al2023-${{ hashFiles('**/pom.xml') }} + cache-maven-al2023- + - name: Build docker image run: docker compose ${{ matrix.docker-compose-build }} + - name: Install custom netty-tcnative + if: ${{ endsWith(matrix.setup, '-awslc') }} + run: docker compose ${{ matrix.docker-compose-install-tcnative }} + + - name: Update netty-tcnative version + if: ${{ endsWith(matrix.setup, '-awslc') }} + run: docker compose ${{ matrix.docker-compose-update-tcnative-version }} + - name: Build project with leak detection run: docker compose ${{ matrix.docker-compose-run }} | tee build-leak.output @@ -231,7 +269,7 @@ jobs: run: ./.github/scripts/check_leak.sh build-leak.output - name: print JVM thread dumps when cancelled - uses: ./.github/actions/thread-dump-jvms + uses: ./netty/.github/actions/thread-dump-jvms if: ${{ cancelled() }} - name: Upload Test Results @@ -239,17 +277,17 @@ jobs: uses: actions/upload-artifact@v4 with: name: test-results-${{ matrix.setup }} - path: '**/target/surefire-reports/TEST-*.xml' + path: 'netty/**/target/surefire-reports/TEST-*.xml' - uses: actions/upload-artifact@v4 if: ${{ failure() || cancelled() }} with: name: build-${{ matrix.setup }}-target path: | - **/target/surefire-reports/ - **/target/autobahntestsuite-reports/ - **/hs_err*.log - **/core.* + netty/**/target/surefire-reports/ + netty/**/target/autobahntestsuite-reports/ + netty/**/hs_err*.log + netty/**/core.* build-pr-macos: strategy: @@ -257,7 +295,7 @@ jobs: matrix: include: - setup: macos-x86_64-java8-boringssl - os: macos-13 + os: macos-15-intel - setup: macos-aarch64-java8-boringssl os: macos-15 diff --git a/.github/workflows/ci-release-4.2.yml b/.github/workflows/ci-release-4.2.yml index 619fb0ac819..51709846987 100644 --- a/.github/workflows/ci-release-4.2.yml +++ b/.github/workflows/ci-release-4.2.yml @@ -185,7 +185,7 @@ jobs: matrix: include: - setup: macos-x86_64-java11 - os: macos-13 + os: macos-15-intel - setup: macos-aarch64-java11 os: macos-15 diff --git a/.github/workflows/ci-release.yml b/.github/workflows/ci-release.yml index 9f60f3b80e2..c797a64f127 100644 --- a/.github/workflows/ci-release.yml +++ b/.github/workflows/ci-release.yml @@ -185,7 +185,7 @@ jobs: matrix: include: - setup: macos-x86_64-java8 - os: macos-13 + os: macos-15-intel - setup: macos-aarch64-java8 os: macos-15 runs-on: ${{ matrix.os }} diff --git a/all/pom.xml b/all/pom.xml index d0a186dc417..defae37f331 100644 --- a/all/pom.xml +++ b/all/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-all diff --git a/bom/pom.xml b/bom/pom.xml index f74949ba8aa..92fe25caf21 100644 --- a/bom/pom.xml +++ b/bom/pom.xml @@ -25,7 +25,7 @@ io.netty netty-bom - 4.1.128.1.dse + 4.1.135.1.dse pom Netty/BOM @@ -49,7 +49,7 @@ https://github.com/netty/netty scm:git:git://github.com/netty/netty.git scm:git:ssh://git@github.com/netty/netty.git - netty-4.1.128.Final + netty-4.1.135.Final @@ -73,7 +73,7 @@ - 2.0.74.Final + 2.0.77.Final diff --git a/buffer/pom.xml b/buffer/pom.xml index 21dd16f77d9..ff905463a7b 100644 --- a/buffer/pom.xml +++ b/buffer/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-buffer diff --git a/buffer/src/main/java/io/netty/buffer/AdaptivePoolingAllocator.java b/buffer/src/main/java/io/netty/buffer/AdaptivePoolingAllocator.java index d4fba097831..7bae63d8a3d 100644 --- a/buffer/src/main/java/io/netty/buffer/AdaptivePoolingAllocator.java +++ b/buffer/src/main/java/io/netty/buffer/AdaptivePoolingAllocator.java @@ -18,22 +18,24 @@ import io.netty.util.ByteProcessor; import io.netty.util.CharsetUtil; import io.netty.util.IllegalReferenceCountException; -import io.netty.util.IntSupplier; +import io.netty.util.IntConsumer; import io.netty.util.NettyRuntime; +import io.netty.util.Recycler; import io.netty.util.Recycler.EnhancedHandle; import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.ConcurrentSkipListIntObjMultimap; +import io.netty.util.concurrent.ConcurrentSkipListIntObjMultimap.IntEntry; import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.FastThreadLocalThread; import io.netty.util.concurrent.MpscAtomicIntegerArrayQueue; import io.netty.util.concurrent.MpscIntQueue; -import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.MathUtil; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.ReferenceCountUpdater; import io.netty.util.internal.SuppressJava6Requirement; import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.ThreadExecutorMap; -import io.netty.util.internal.ThreadLocalRandom; import io.netty.util.internal.UnstableApi; import java.io.IOException; @@ -47,8 +49,10 @@ import java.nio.channels.ScatteringByteChannel; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Iterator; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.concurrent.atomic.LongAdder; @@ -83,6 +87,16 @@ @SuppressJava6Requirement(reason = "Guarded by version check") @UnstableApi final class AdaptivePoolingAllocator implements AdaptiveByteBufAllocator.AdaptiveAllocatorApi { + private static final int LOW_MEM_THRESHOLD = 512 * 1024 * 1024; + private static final boolean IS_LOW_MEM = Runtime.getRuntime().maxMemory() <= LOW_MEM_THRESHOLD; + + /** + * Whether the IS_LOW_MEM setting should disable thread-local magazines. + * This can have fairly high performance overhead. + */ + private static final boolean DISABLE_THREAD_LOCAL_MAGAZINES_ON_LOW_MEM = SystemPropertyUtil.getBoolean( + "io.netty.allocator.disableThreadLocalMagazinesOnLowMemory", true); + /** * The 128 KiB minimum chunk size is chosen to encourage the system allocator to delegate to mmap for chunk * allocations. For instance, glibc will do this. @@ -90,11 +104,11 @@ final class AdaptivePoolingAllocator implements AdaptiveByteBufAllocator.Adaptiv * which is a much, much larger space. Chunks are also allocated in whole multiples of the minimum * chunk size, which itself is a whole multiple of popular page sizes like 4 KiB, 16 KiB, and 64 KiB. */ - private static final int MIN_CHUNK_SIZE = 128 * 1024; + static final int MIN_CHUNK_SIZE = 128 * 1024; private static final int EXPANSION_ATTEMPTS = 3; private static final int INITIAL_MAGAZINES = 1; private static final int RETIRE_CAPACITY = 256; - private static final int MAX_STRIPES = NettyRuntime.availableProcessors() * 2; + private static final int MAX_STRIPES = IS_LOW_MEM ? 1 : NettyRuntime.availableProcessors() * 2; private static final int BUFS_PER_CHUNK = 8; // For large buffers, aim to have about this many buffers per chunk. /** @@ -102,7 +116,9 @@ final class AdaptivePoolingAllocator implements AdaptiveByteBufAllocator.Adaptiv *

* This number is 8 MiB, and is derived from the limitations of internal histograms. */ - private static final int MAX_CHUNK_SIZE = 8 * 1024 * 1024; // 8 MiB. + private static final int MAX_CHUNK_SIZE = IS_LOW_MEM ? + 2 * 1024 * 1024 : // 2 MiB for systems with small heaps. + 8 * 1024 * 1024; // 8 MiB. private static final int MAX_POOLED_BUF_SIZE = MAX_CHUNK_SIZE / BUFS_PER_CHUNK; /** @@ -150,21 +166,9 @@ final class AdaptivePoolingAllocator implements AdaptiveByteBufAllocator.Adaptiv 16384, 16896, // 16384 + 512 }; - private static final ChunkReleasePredicate CHUNK_RELEASE_ALWAYS = new ChunkReleasePredicate() { - @Override - public boolean shouldReleaseChunk(int chunkSize) { - return true; - } - }; - private static final ChunkReleasePredicate CHUNK_RELEASE_NEVER = new ChunkReleasePredicate() { - @Override - public boolean shouldReleaseChunk(int chunkSize) { - return false; - } - }; private static final int SIZE_CLASSES_COUNT = SIZE_CLASSES.length; - private static final byte[] SIZE_INDEXES = new byte[(SIZE_CLASSES[SIZE_CLASSES_COUNT - 1] / 32) + 1]; + private static final byte[] SIZE_INDEXES = new byte[SIZE_CLASSES[SIZE_CLASSES_COUNT - 1] / 32 + 1]; static { if (MAGAZINE_BUFFER_QUEUE_CAPACITY < 2) { @@ -175,7 +179,7 @@ public boolean shouldReleaseChunk(int chunkSize) { for (int i = 0; i < SIZE_CLASSES_COUNT; i++) { int sizeClass = SIZE_CLASSES[i]; //noinspection ConstantValue - assert (sizeClass & 5) == 0 : "Size class must be a multiple of 32"; + assert (sizeClass & 31) == 0 : "Size class must be a multiple of 32"; int sizeIndex = sizeIndexOf(sizeClass); Arrays.fill(SIZE_INDEXES, lastIndex + 1, sizeIndex + 1, (byte) i); lastIndex = sizeIndex; @@ -193,8 +197,10 @@ public boolean shouldReleaseChunk(int chunkSize) { chunkRegistry = new ChunkRegistry(); sizeClassedMagazineGroups = createMagazineGroupSizeClasses(this, false); largeBufferMagazineGroup = new MagazineGroup( - this, chunkAllocator, new HistogramChunkControllerFactory(true), false); - threadLocalGroup = new FastThreadLocal() { + this, chunkAllocator, new BuddyChunkManagementStrategy(), false); + + boolean disableThreadLocalGroups = IS_LOW_MEM && DISABLE_THREAD_LOCAL_MAGAZINES_ON_LOW_MEM; + threadLocalGroup = disableThreadLocalGroups ? null : new FastThreadLocal() { @Override protected MagazineGroup[] initialValue() { if (useCacheForNonEventLoopThreads || ThreadExecutorMap.currentExecutor() != null) { @@ -220,7 +226,7 @@ private static MagazineGroup[] createMagazineGroupSizeClasses( for (int i = 0; i < SIZE_CLASSES.length; i++) { int segmentSize = SIZE_CLASSES[i]; groups[i] = new MagazineGroup(allocator, allocator.chunkAllocator, - new SizeClassChunkControllerFactory(segmentSize), isThreadLocal); + new SizeClassChunkManagementStrategy(segmentSize), isThreadLocal); } return groups; } @@ -245,7 +251,7 @@ private static MagazineGroup[] createMagazineGroupSizeClasses( * * @return A new multi-producer, multi-consumer queue. */ - private static Queue createSharedChunkQueue() { + private static Queue createSharedChunkQueue() { return PlatformDependent.newFixedMpmcQueue(CHUNK_REUSE_QUEUE); } @@ -259,13 +265,14 @@ private AdaptiveByteBuf allocate(int size, int maxCapacity, Thread currentThread if (size <= MAX_POOLED_BUF_SIZE) { final int index = sizeClassIndexOf(size); MagazineGroup[] magazineGroups; - if (!FastThreadLocalThread.willCleanupFastThreadLocals(currentThread) || + if (!FastThreadLocalThread.willCleanupFastThreadLocals(Thread.currentThread()) || + IS_LOW_MEM || (magazineGroups = threadLocalGroup.get()) == null) { magazineGroups = sizeClassedMagazineGroups; } if (index < magazineGroups.length) { allocated = magazineGroups[index].allocate(size, maxCapacity, currentThread, buf); - } else { + } else if (!IS_LOW_MEM) { allocated = largeBufferMagazineGroup.allocate(size, maxCapacity, currentThread, buf); } } @@ -292,8 +299,7 @@ static int[] getSizeClasses() { return SIZE_CLASSES.clone(); } - private AdaptiveByteBuf allocateFallback(int size, int maxCapacity, Thread currentThread, - AdaptiveByteBuf buf) { + private AdaptiveByteBuf allocateFallback(int size, int maxCapacity, Thread currentThread, AdaptiveByteBuf buf) { // If we don't already have a buffer, obtain one from the most conveniently available magazine. Magazine magazine; if (buf != null) { @@ -307,10 +313,11 @@ private AdaptiveByteBuf allocateFallback(int size, int maxCapacity, Thread curre } // Create a one-off chunk for this allocation. AbstractByteBuf innerChunk = chunkAllocator.allocate(size, maxCapacity); - Chunk chunk = new Chunk(innerChunk, magazine, false, CHUNK_RELEASE_ALWAYS); + Chunk chunk = new Chunk(innerChunk, magazine); chunkRegistry.add(chunk); try { - chunk.readInitInto(buf, size, size, maxCapacity); + boolean success = chunk.readInitInto(buf, size, size, maxCapacity); + assert success: "Failed to initialize ByteBuf with dedicated chunk"; } finally { // As the chunk is an one-off we need to always call release explicitly as readInitInto(...) // will take care of retain once when successful. Once The AdaptiveByteBuf is released it will @@ -355,38 +362,37 @@ private void free() { largeBufferMagazineGroup.free(); } - static int sizeToBucket(int size) { - return HistogramChunkController.sizeToBucket(size); - } - @SuppressJava6Requirement(reason = "Guarded by version check") private static final class MagazineGroup { private final AdaptivePoolingAllocator allocator; private final ChunkAllocator chunkAllocator; - private final ChunkControllerFactory chunkControllerFactory; - private final Queue chunkReuseQueue; + private final ChunkManagementStrategy chunkManagementStrategy; + private final ChunkCache chunkCache; private final StampedLock magazineExpandLock; private final Magazine threadLocalMagazine; + private Thread ownerThread; private volatile Magazine[] magazines; private volatile boolean freed; MagazineGroup(AdaptivePoolingAllocator allocator, ChunkAllocator chunkAllocator, - ChunkControllerFactory chunkControllerFactory, + ChunkManagementStrategy chunkManagementStrategy, boolean isThreadLocal) { this.allocator = allocator; this.chunkAllocator = chunkAllocator; - this.chunkControllerFactory = chunkControllerFactory; - chunkReuseQueue = createSharedChunkQueue(); + this.chunkManagementStrategy = chunkManagementStrategy; + chunkCache = chunkManagementStrategy.createChunkCache(isThreadLocal); if (isThreadLocal) { + ownerThread = Thread.currentThread(); magazineExpandLock = null; - threadLocalMagazine = new Magazine(this, false, chunkReuseQueue, chunkControllerFactory.create(this)); + threadLocalMagazine = new Magazine(this, false, chunkManagementStrategy.createController(this)); } else { + ownerThread = null; magazineExpandLock = new StampedLock(); threadLocalMagazine = null; Magazine[] mags = new Magazine[INITIAL_MAGAZINES]; for (int i = 0; i < mags.length; i++) { - mags[i] = new Magazine(this, true, chunkReuseQueue, chunkControllerFactory.create(this)); + mags[i] = new Magazine(this, true, chunkManagementStrategy.createController(this)); } magazines = mags; } @@ -446,12 +452,9 @@ private boolean tryExpandMagazines(int currentLength) { if (mags.length >= MAX_STRIPES || mags.length > currentLength || freed) { return true; } - Magazine firstMagazine = mags[0]; Magazine[] expanded = new Magazine[mags.length * 2]; for (int i = 0, l = expanded.length; i < l; i++) { - Magazine m = new Magazine(this, true, chunkReuseQueue, chunkControllerFactory.create(this)); - firstMagazine.initializeSharedStateIn(m); - expanded[i] = m; + expanded[i] = new Magazine(this, true, chunkManagementStrategy.createController(this)); } magazines = expanded; } finally { @@ -464,22 +467,32 @@ private boolean tryExpandMagazines(int currentLength) { return true; } - boolean offerToQueue(Chunk buffer) { + Chunk pollChunk(int size) { + return chunkCache.pollChunk(size); + } + + boolean offerChunk(Chunk chunk) { if (freed) { return false; } - boolean isAdded = chunkReuseQueue.offer(buffer); + if (chunk.hasUnprocessedFreelistEntries()) { + chunk.processFreelistEntries(); + } + boolean isAdded = chunkCache.offerChunk(chunk); + if (freed && isAdded) { // Help to free the reuse queue. - freeChunkReuseQueue(); + freeChunkReuseQueue(ownerThread); } return isAdded; } private void free() { freed = true; + Thread ownerThread = this.ownerThread; if (threadLocalMagazine != null) { + this.ownerThread = null; threadLocalMagazine.free(); } else { long stamp = magazineExpandLock.writeLock(); @@ -492,22 +505,153 @@ private void free() { magazineExpandLock.unlockWrite(stamp); } } - freeChunkReuseQueue(); + freeChunkReuseQueue(ownerThread); } - private void freeChunkReuseQueue() { - for (;;) { - Chunk chunk = chunkReuseQueue.poll(); + private void freeChunkReuseQueue(Thread ownerThread) { + Chunk chunk; + while ((chunk = chunkCache.pollChunk(0)) != null) { + if (ownerThread != null && chunk instanceof SizeClassedChunk) { + SizeClassedChunk threadLocalChunk = (SizeClassedChunk) chunk; + assert ownerThread == threadLocalChunk.ownerThread; + // no release segment can ever happen from the owner Thread since it's not running anymore + // This is required to let the ownerThread to be GC'ed despite there are AdaptiveByteBuf + // that reference some thread local chunk + threadLocalChunk.ownerThread = null; + } + chunk.markToDeallocate(); + } + } + } + + private interface ChunkCache { + Chunk pollChunk(int size); + boolean offerChunk(Chunk chunk); + } + + private static final class ConcurrentQueueChunkCache implements ChunkCache { + private final Queue queue; + + private ConcurrentQueueChunkCache() { + queue = createSharedChunkQueue(); + } + + @Override + public SizeClassedChunk pollChunk(int size) { + // we really don't care about size here since the sized class chunk q + // just care about segments of fixed size! + Queue queue = this.queue; + for (int i = 0; i < CHUNK_REUSE_QUEUE; i++) { + SizeClassedChunk chunk = queue.poll(); if (chunk == null) { + return null; + } + if (chunk.hasRemainingCapacity()) { + return chunk; + } + queue.offer(chunk); + } + return null; + } + + @Override + public boolean offerChunk(Chunk chunk) { + return queue.offer((SizeClassedChunk) chunk); + } + } + + private static final class ConcurrentSkipListChunkCache implements ChunkCache { + private final ConcurrentSkipListIntObjMultimap chunks; + + private ConcurrentSkipListChunkCache() { + chunks = new ConcurrentSkipListIntObjMultimap(-1); + } + + @Override + public Chunk pollChunk(int size) { + if (chunks.isEmpty()) { + return null; + } + IntEntry entry = chunks.pollCeilingEntry(size); + if (entry != null) { + Chunk chunk = entry.getValue(); + if (chunk.hasUnprocessedFreelistEntries()) { + chunk.processFreelistEntries(); + } + return chunk; + } + + Chunk bestChunk = null; + int bestRemainingCapacity = 0; + Iterator> itr = chunks.iterator(); + while (itr.hasNext()) { + entry = itr.next(); + final Chunk chunk; + if (entry != null && (chunk = entry.getValue()).hasUnprocessedFreelistEntries()) { + if (!chunks.remove(entry.getKey(), entry.getValue())) { + continue; + } + chunk.processFreelistEntries(); + int remainingCapacity = chunk.remainingCapacity(); + if (remainingCapacity >= size && + (bestChunk == null || remainingCapacity > bestRemainingCapacity)) { + if (bestChunk != null) { + chunks.put(bestRemainingCapacity, bestChunk); + } + bestChunk = chunk; + bestRemainingCapacity = remainingCapacity; + } else { + chunks.put(remainingCapacity, chunk); + } + } + } + + return bestChunk; + } + + @Override + public boolean offerChunk(Chunk chunk) { + chunks.put(chunk.remainingCapacity(), chunk); + + int size = chunks.size(); + while (size > CHUNK_REUSE_QUEUE) { + // Deallocate the chunk with the fewest incoming references. + int key = -1; + Chunk toDeallocate = null; + for (IntEntry entry : chunks) { + Chunk candidate = entry.getValue(); + if (candidate != null) { + if (toDeallocate == null) { + toDeallocate = candidate; + key = entry.getKey(); + } else { + int candidateRefCnt = candidate.refCnt(); + int toDeallocateRefCnt = toDeallocate.refCnt(); + if (candidateRefCnt < toDeallocateRefCnt || + candidateRefCnt == toDeallocateRefCnt && + candidate.capacity() < toDeallocate.capacity()) { + toDeallocate = candidate; + key = entry.getKey(); + } + } + } + } + if (toDeallocate == null) { break; } - chunk.release(); + if (chunks.remove(key, toDeallocate)) { + toDeallocate.markToDeallocate(); + } + size = chunks.size(); } + return true; } } - private interface ChunkControllerFactory { - ChunkController create(MagazineGroup group); + private interface ChunkManagementStrategy { + ChunkController createController(MagazineGroup group); + + ChunkCache createChunkCache(boolean isThreadLocal); } private interface ChunkController { @@ -516,66 +660,75 @@ private interface ChunkController { */ int computeBufferCapacity(int requestedSize, int maxCapacity, boolean isReallocation); - /** - * Initialize the given chunk factory with shared statistics state (if any) from this factory. - */ - void initializeSharedStateIn(ChunkController chunkController); - /** * Allocate a new {@link Chunk} for the given {@link Magazine}. */ Chunk newChunkAllocation(int promptingSize, Magazine magazine); } - private interface ChunkReleasePredicate { - boolean shouldReleaseChunk(int chunkSize); - } - - private static final class SizeClassChunkControllerFactory implements ChunkControllerFactory { + private static final class SizeClassChunkManagementStrategy implements ChunkManagementStrategy { // To amortize activation/deactivation of chunks, we should have a minimum number of segments per chunk. // We choose 32 because it seems neither too small nor too big. // For segments of 16 KiB, the chunks will be half a megabyte. private static final int MIN_SEGMENTS_PER_CHUNK = 32; private final int segmentSize; private final int chunkSize; - private final int[] segmentOffsets; - private SizeClassChunkControllerFactory(int segmentSize) { + private SizeClassChunkManagementStrategy(int segmentSize) { this.segmentSize = ObjectUtil.checkPositive(segmentSize, "segmentSize"); chunkSize = Math.max(MIN_CHUNK_SIZE, segmentSize * MIN_SEGMENTS_PER_CHUNK); - int segmentsCount = chunkSize / segmentSize; - segmentOffsets = new int[segmentsCount]; - for (int i = 0; i < segmentsCount; i++) { - segmentOffsets[i] = i * segmentSize; - } } @Override - public ChunkController create(MagazineGroup group) { - return new SizeClassChunkController(group, segmentSize, chunkSize, segmentOffsets); + public ChunkController createController(MagazineGroup group) { + return new SizeClassChunkController(group, segmentSize, chunkSize); + } + + @Override + public ChunkCache createChunkCache(boolean isThreadLocal) { + return new ConcurrentQueueChunkCache(); } } private static final class SizeClassChunkController implements ChunkController { - private static final ChunkReleasePredicate FALSE_PREDICATE = new ChunkReleasePredicate() { - @Override - public boolean shouldReleaseChunk(int chunkSize) { - return false; - } - }; private final ChunkAllocator chunkAllocator; private final int segmentSize; private final int chunkSize; private final ChunkRegistry chunkRegistry; - private final int[] segmentOffsets; - private SizeClassChunkController(MagazineGroup group, int segmentSize, int chunkSize, int[] segmentOffsets) { + private SizeClassChunkController(MagazineGroup group, int segmentSize, int chunkSize) { chunkAllocator = group.chunkAllocator; this.segmentSize = segmentSize; this.chunkSize = chunkSize; chunkRegistry = group.allocator.chunkRegistry; - this.segmentOffsets = segmentOffsets; + } + + private MpscAtomicIntegerArrayQueue createEmptyFreeList() { + return new MpscAtomicIntegerArrayQueue(chunkSize / segmentSize, SizeClassedChunk.FREE_LIST_EMPTY); + } + + private MpscAtomicIntegerArrayQueue createFreeList() { + final int segmentsCount = chunkSize / segmentSize; + final MpscAtomicIntegerArrayQueue freeList = new MpscAtomicIntegerArrayQueue( + segmentsCount, SizeClassedChunk.FREE_LIST_EMPTY); + int segmentOffset = 0; + for (int i = 0; i < segmentsCount; i++) { + freeList.offer(segmentOffset); + segmentOffset += segmentSize; + } + return freeList; + } + + private IntStack createLocalFreeList() { + final int segmentsCount = chunkSize / segmentSize; + int segmentOffset = chunkSize; + int[] offsets = new int[segmentsCount]; + for (int i = 0; i < segmentsCount; i++) { + segmentOffset -= segmentSize; + offsets[i] = segmentOffset; + } + return new IntStack(offsets); } @Override @@ -584,235 +737,59 @@ public int computeBufferCapacity( return Math.min(segmentSize, maxCapacity); } - @Override - public void initializeSharedStateIn(ChunkController chunkController) { - // NOOP - } - @Override public Chunk newChunkAllocation(int promptingSize, Magazine magazine) { AbstractByteBuf chunkBuffer = chunkAllocator.allocate(chunkSize, chunkSize); assert chunkBuffer.capacity() == chunkSize; - SizeClassedChunk chunk = new SizeClassedChunk(chunkBuffer, magazine, true, - segmentSize, segmentOffsets, FALSE_PREDICATE); + SizeClassedChunk chunk = new SizeClassedChunk(chunkBuffer, magazine, this); chunkRegistry.add(chunk); return chunk; } } - private static final class HistogramChunkControllerFactory implements ChunkControllerFactory { - private final boolean shareable; + private static final class BuddyChunkManagementStrategy implements ChunkManagementStrategy { + private final AtomicInteger maxChunkSize = new AtomicInteger(); - private HistogramChunkControllerFactory(boolean shareable) { - this.shareable = shareable; + @Override + public ChunkController createController(MagazineGroup group) { + return new BuddyChunkController(group, maxChunkSize); } @Override - public ChunkController create(MagazineGroup group) { - return new HistogramChunkController(group, shareable); + public ChunkCache createChunkCache(boolean isThreadLocal) { + return new ConcurrentSkipListChunkCache(); } } - private static final class HistogramChunkController implements ChunkController, ChunkReleasePredicate { - private static final int MIN_DATUM_TARGET = 1024; - private static final int MAX_DATUM_TARGET = 65534; - private static final int INIT_DATUM_TARGET = 9; - private static final int HISTO_BUCKET_COUNT = 16; - private static final int[] HISTO_BUCKETS = { - 16 * 1024, - 24 * 1024, - 32 * 1024, - 48 * 1024, - 64 * 1024, - 96 * 1024, - 128 * 1024, - 192 * 1024, - 256 * 1024, - 384 * 1024, - 512 * 1024, - 768 * 1024, - 1024 * 1024, - 1792 * 1024, - 2048 * 1024, - 3072 * 1024 - }; - - private final MagazineGroup group; - private final boolean shareable; - private final short[][] histos = { - new short[HISTO_BUCKET_COUNT], new short[HISTO_BUCKET_COUNT], - new short[HISTO_BUCKET_COUNT], new short[HISTO_BUCKET_COUNT], - }; + private static final class BuddyChunkController implements ChunkController { + private final ChunkAllocator chunkAllocator; private final ChunkRegistry chunkRegistry; - private short[] histo = histos[0]; - private final int[] sums = new int[HISTO_BUCKET_COUNT]; - - private int histoIndex; - private int datumCount; - private int datumTarget = INIT_DATUM_TARGET; - private boolean hasHadRotation; - private volatile int sharedPrefChunkSize = MIN_CHUNK_SIZE; - private volatile int localPrefChunkSize = MIN_CHUNK_SIZE; - private volatile int localUpperBufSize; - - private HistogramChunkController(MagazineGroup group, boolean shareable) { - this.group = group; - this.shareable = shareable; - chunkRegistry = group.allocator.chunkRegistry; - } - - @Override - public int computeBufferCapacity( - int requestedSize, int maxCapacity, boolean isReallocation) { - if (!isReallocation) { - // Only record allocation size if it's not caused by a reallocation that was triggered by capacity - // change of the buffer. - recordAllocationSize(requestedSize); - } + private final AtomicInteger maxChunkSize; - // Predict starting capacity from localUpperBufSize, but place limits on the max starting capacity - // based on the requested size, because localUpperBufSize can potentially be quite large. - int startCapLimits; - if (requestedSize <= 32768) { // Less than or equal to 32 KiB. - startCapLimits = 65536; // Use at most 64 KiB, which is also the AdaptiveRecvByteBufAllocator max. - } else { - startCapLimits = requestedSize * 2; // Otherwise use at most twice the requested memory. - } - int startingCapacity = Math.min(startCapLimits, localUpperBufSize); - startingCapacity = Math.max(requestedSize, Math.min(maxCapacity, startingCapacity)); - return startingCapacity; - } - - private void recordAllocationSize(int bufferSizeToRecord) { - // Use the preserved size from the reused AdaptiveByteBuf, if available. - // Otherwise, use the requested buffer size. - // This way, we better take into account - if (bufferSizeToRecord == 0) { - return; - } - int bucket = sizeToBucket(bufferSizeToRecord); - histo[bucket]++; - if (datumCount++ == datumTarget) { - rotateHistograms(); - } - } - - static int sizeToBucket(int size) { - int index = binarySearchInsertionPoint(Arrays.binarySearch(HISTO_BUCKETS, size)); - return index >= HISTO_BUCKETS.length ? HISTO_BUCKETS.length - 1 : index; - } - - private static int binarySearchInsertionPoint(int index) { - if (index < 0) { - index = -(index + 1); - } - return index; - } - - static int bucketToSize(int sizeBucket) { - return HISTO_BUCKETS[sizeBucket]; - } - - private void rotateHistograms() { - short[][] hs = histos; - for (int i = 0; i < HISTO_BUCKET_COUNT; i++) { - sums[i] = (hs[0][i] & 0xFFFF) + (hs[1][i] & 0xFFFF) + (hs[2][i] & 0xFFFF) + (hs[3][i] & 0xFFFF); - } - int sum = 0; - for (int count : sums) { - sum += count; - } - int targetPercentile = (int) (sum * 0.99); - int sizeBucket = 0; - for (; sizeBucket < sums.length; sizeBucket++) { - if (sums[sizeBucket] > targetPercentile) { - break; - } - targetPercentile -= sums[sizeBucket]; - } - hasHadRotation = true; - int percentileSize = bucketToSize(sizeBucket); - int prefChunkSize = Math.max(percentileSize * BUFS_PER_CHUNK, MIN_CHUNK_SIZE); - localUpperBufSize = percentileSize; - localPrefChunkSize = prefChunkSize; - if (shareable) { - for (Magazine mag : group.magazines) { - HistogramChunkController statistics = (HistogramChunkController) mag.chunkController; - prefChunkSize = Math.max(prefChunkSize, statistics.localPrefChunkSize); - } - } - if (sharedPrefChunkSize != prefChunkSize) { - // Preferred chunk size changed. Increase check frequency. - datumTarget = Math.max(datumTarget >> 1, MIN_DATUM_TARGET); - sharedPrefChunkSize = prefChunkSize; - } else { - // Preferred chunk size did not change. Check less often. - datumTarget = Math.min(datumTarget << 1, MAX_DATUM_TARGET); - } - - histoIndex = histoIndex + 1 & 3; - histo = histos[histoIndex]; - datumCount = 0; - Arrays.fill(histo, (short) 0); - } - - /** - * Get the preferred chunk size, based on statistics from the {@linkplain #recordAllocationSize(int) recorded} - * allocation sizes. - *

- * This method must be thread-safe. - * - * @return The currently preferred chunk allocation size. - */ - int preferredChunkSize() { - return sharedPrefChunkSize; + BuddyChunkController(MagazineGroup group, AtomicInteger maxChunkSize) { + chunkAllocator = group.chunkAllocator; + chunkRegistry = group.allocator.chunkRegistry; + this.maxChunkSize = maxChunkSize; } @Override - public void initializeSharedStateIn(ChunkController chunkController) { - HistogramChunkController statistics = (HistogramChunkController) chunkController; - int sharedPrefChunkSize = this.sharedPrefChunkSize; - statistics.localPrefChunkSize = sharedPrefChunkSize; - statistics.sharedPrefChunkSize = sharedPrefChunkSize; + public int computeBufferCapacity(int requestedSize, int maxCapacity, boolean isReallocation) { + return MathUtil.safeFindNextPositivePowerOfTwo(requestedSize); } @Override public Chunk newChunkAllocation(int promptingSize, Magazine magazine) { - int size = Math.max(promptingSize * BUFS_PER_CHUNK, preferredChunkSize()); - int minChunks = size / MIN_CHUNK_SIZE; - if (MIN_CHUNK_SIZE * minChunks < size) { - // Round up to nearest whole MIN_CHUNK_SIZE unit. The MIN_CHUNK_SIZE is an even multiple of many - // popular small page sizes, like 4k, 16k, and 64k, which makes it easier for the system allocator - // to manage the memory in terms of whole pages. This reduces memory fragmentation, - // but without the potentially high overhead that power-of-2 chunk sizes would bring. - size = MIN_CHUNK_SIZE * (1 + minChunks); - } - - // Limit chunks to the max size, even if the histogram suggests to go above it. - size = Math.min(size, MAX_CHUNK_SIZE); - - // If we haven't rotated the histogram yet, optimisticly record this chunk size as our preferred. - if (!hasHadRotation && sharedPrefChunkSize == MIN_CHUNK_SIZE) { - sharedPrefChunkSize = size; - } - - ChunkAllocator chunkAllocator = group.chunkAllocator; - Chunk chunk = new Chunk(chunkAllocator.allocate(size, size), magazine, true, this); + int maxChunkSize = this.maxChunkSize.get(); + int proposedChunkSize = MathUtil.safeFindNextPositivePowerOfTwo(BUFS_PER_CHUNK * promptingSize); + int chunkSize = Math.min(MAX_CHUNK_SIZE, Math.max(maxChunkSize, proposedChunkSize)); + if (chunkSize > maxChunkSize) { + // Update our stored max chunk size. It's fine that this is racy. + this.maxChunkSize.set(chunkSize); + } + BuddyChunk chunk = new BuddyChunk(chunkAllocator.allocate(chunkSize, chunkSize), magazine); chunkRegistry.add(chunk); return chunk; } - - @Override - public boolean shouldReleaseChunk(int chunkSize) { - int preferredSize = preferredChunkSize(); - int givenChunks = chunkSize / MIN_CHUNK_SIZE; - int preferredChunks = preferredSize / MIN_CHUNK_SIZE; - int deviation = Math.abs(givenChunks - preferredChunks); - - // Retire chunks with a 5% probability per unit of MIN_CHUNK_SIZE deviation from preference. - return deviation != 0 && - ThreadLocalRandom.current().nextDouble() * 20.0 < deviation; - } } @SuppressJava6Requirement(reason = "Guarded by version check") @@ -823,13 +800,31 @@ private static final class Magazine { } private static final Chunk MAGAZINE_FREED = new Chunk(); - private static final ObjectPool EVENT_LOOP_LOCAL_BUFFER_POOL = ObjectPool.newPool( - new ObjectPool.ObjectCreator() { - @Override - public AdaptiveByteBuf newObject(ObjectPool.Handle handle) { - return new AdaptiveByteBuf(handle); - } - }); + private static final class AdaptiveRecycler extends Recycler { + + private AdaptiveRecycler() { + } + + private AdaptiveRecycler(int maxCapacity) { + // doesn't use fast thread local, shared + super(maxCapacity); + } + + @Override + protected AdaptiveByteBuf newObject(final Handle handle) { + return new AdaptiveByteBuf((EnhancedHandle) handle); + } + + public static AdaptiveRecycler threadLocal() { + return new AdaptiveRecycler(); + } + + public static AdaptiveRecycler sharedWith(int maxCapacity) { + return new AdaptiveRecycler(maxCapacity); + } + } + + private static final AdaptiveRecycler EVENT_LOOP_LOCAL_BUFFER_POOL = AdaptiveRecycler.threadLocal(); private Chunk current; @SuppressWarnings("unused") // updated via NEXT_IN_LINE @@ -837,31 +832,20 @@ public AdaptiveByteBuf newObject(ObjectPool.Handle handle) { private final MagazineGroup group; private final ChunkController chunkController; private final StampedLock allocationLock; - private final Queue bufferQueue; - private final ObjectPool.Handle handle; - private final Queue sharedChunkQueue; + private final AdaptiveRecycler recycler; - Magazine(MagazineGroup group, boolean shareable, Queue sharedChunkQueue, - ChunkController chunkController) { + Magazine(MagazineGroup group, boolean shareable, ChunkController chunkController) { this.group = group; this.chunkController = chunkController; if (shareable) { // We only need the StampedLock if this Magazine will be shared across threads. allocationLock = new StampedLock(); - bufferQueue = PlatformDependent.newFixedMpmcQueue(MAGAZINE_BUFFER_QUEUE_CAPACITY); - handle = new ObjectPool.Handle() { - @Override - public void recycle(AdaptiveByteBuf self) { - bufferQueue.offer(self); - } - }; + recycler = AdaptiveRecycler.sharedWith(MAGAZINE_BUFFER_QUEUE_CAPACITY); } else { allocationLock = null; - bufferQueue = null; - handle = null; + recycler = null; } - this.sharedChunkQueue = sharedChunkQueue; } public boolean tryAllocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean reallocate) { @@ -890,7 +874,7 @@ private boolean allocateWithoutLock(int size, int maxCapacity, AdaptiveByteBuf b return false; } if (curr == null) { - curr = sharedChunkQueue.poll(); + curr = group.pollChunk(size); if (curr == null) { return false; } @@ -900,9 +884,10 @@ private boolean allocateWithoutLock(int size, int maxCapacity, AdaptiveByteBuf b int remainingCapacity = curr.remainingCapacity(); int startingCapacity = chunkController.computeBufferCapacity( size, maxCapacity, true /* never update stats as we don't hold the magazine lock */); - if (remainingCapacity >= size) { - curr.readInitInto(buf, size, Math.min(remainingCapacity, startingCapacity), maxCapacity); + if (remainingCapacity >= size && + curr.readInitInto(buf, size, Math.min(remainingCapacity, startingCapacity), maxCapacity)) { allocated = true; + remainingCapacity = curr.remainingCapacity(); } try { if (remainingCapacity >= RETIRE_CAPACITY) { @@ -921,33 +906,17 @@ private boolean allocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean int startingCapacity = chunkController.computeBufferCapacity(size, maxCapacity, reallocate); Chunk curr = current; if (curr != null) { - // We have a Chunk that has some space left. + boolean success = curr.readInitInto(buf, size, startingCapacity, maxCapacity); int remainingCapacity = curr.remainingCapacity(); - if (remainingCapacity > startingCapacity) { - curr.readInitInto(buf, size, startingCapacity, maxCapacity); - // We still have some bytes left that we can use for the next allocation, just early return. - return true; - } - - // At this point we know that this will be the last time current will be used, so directly set it to - // null and release it once we are done. - current = null; - if (remainingCapacity >= size) { - try { - curr.readInitInto(buf, size, remainingCapacity, maxCapacity); - return true; - } finally { - curr.releaseFromMagazine(); - } - } - - // Check if we either retain the chunk in the nextInLine cache or releasing it. - if (remainingCapacity < RETIRE_CAPACITY) { - curr.releaseFromMagazine(); - } else { - // See if it makes sense to transfer the Chunk to the nextInLine cache for later usage. - // This method will release curr if this is not the case + if (!success && remainingCapacity > 0) { + current = null; transferToNextInLineOrRelease(curr); + } else if (remainingCapacity == 0) { + current = null; + curr.releaseFromMagazine(); + } + if (success) { + return true; } } @@ -969,32 +938,28 @@ private boolean allocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean } int remainingCapacity = curr.remainingCapacity(); - if (remainingCapacity > startingCapacity) { + if (remainingCapacity > startingCapacity && + curr.readInitInto(buf, size, startingCapacity, maxCapacity)) { // We have a Chunk that has some space left. - curr.readInitInto(buf, size, startingCapacity, maxCapacity); current = curr; return true; } - if (remainingCapacity >= size) { - // At this point we know that this will be the last time curr will be used, so directly set it to - // null and release it once we are done. - try { - curr.readInitInto(buf, size, remainingCapacity, maxCapacity); - return true; - } finally { - // Release in a finally block so even if readInitInto(...) would throw we would still correctly - // release the current chunk before null it out. - curr.releaseFromMagazine(); + try { + if (remainingCapacity >= size) { + // At this point we know that this will be the last time curr will be used, so directly set it + // to null and release it once we are done. + return curr.readInitInto(buf, size, remainingCapacity, maxCapacity); } - } else { - // Release it as it's too small. + } finally { + // Release in a finally block so even if readInitInto(...) would throw we would still correctly + // release the current chunk before null it out. curr.releaseFromMagazine(); } } // Now try to poll from the central queue first - curr = sharedChunkQueue.poll(); + curr = group.pollChunk(size); if (curr == null) { curr = chunkController.newChunkAllocation(size, this); } else { @@ -1015,14 +980,15 @@ private boolean allocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean } current = curr; + boolean success; try { int remainingCapacity = curr.remainingCapacity(); assert remainingCapacity >= size; if (remainingCapacity > startingCapacity) { - curr.readInitInto(buf, size, startingCapacity, maxCapacity); + success = curr.readInitInto(buf, size, startingCapacity, maxCapacity); curr = null; } else { - curr.readInitInto(buf, size, remainingCapacity, maxCapacity); + success = curr.readInitInto(buf, size, remainingCapacity, maxCapacity); } } finally { if (curr != null) { @@ -1032,7 +998,7 @@ private boolean allocate(int size, int maxCapacity, AdaptiveByteBuf buf, boolean current = null; } } - return true; + return success; } private void restoreMagazineFreed() { @@ -1063,10 +1029,6 @@ private void transferToNextInLineOrRelease(Chunk chunk) { chunk.releaseFromMagazine(); } - boolean trySetNextInLine(Chunk chunk) { - return NEXT_IN_LINE.compareAndSet(this, null, chunk); - } - void free() { // Release the current Chunk and the next that was stored for later usage. restoreMagazineFreed(); @@ -1084,26 +1046,15 @@ void free() { } public AdaptiveByteBuf newBuffer() { - AdaptiveByteBuf buf; - if (handle == null) { - buf = EVENT_LOOP_LOCAL_BUFFER_POOL.get(); - } else { - buf = bufferQueue.poll(); - if (buf == null) { - buf = new AdaptiveByteBuf(handle); - } - } + AdaptiveRecycler recycler = this.recycler; + AdaptiveByteBuf buf = recycler == null? EVENT_LOOP_LOCAL_BUFFER_POOL.get() : recycler.get(); buf.resetRefCnt(); buf.discardMarks(); return buf; } boolean offerToQueue(Chunk chunk) { - return group.offerToQueue(chunk); - } - - public void initializeSharedStateIn(Magazine other) { - chunkController.initializeSharedStateIn(other.chunkController); + return group.offerChunk(chunk); } } @@ -1133,9 +1084,7 @@ private static class Chunk implements ReferenceCounted { protected final AbstractByteBuf delegate; protected Magazine magazine; private final AdaptivePoolingAllocator allocator; - private final ChunkReleasePredicate chunkReleasePredicate; private final int capacity; - private final boolean pooled; protected int allocatedBytes; private static final ReferenceCountUpdater updater = @@ -1161,23 +1110,17 @@ protected long unsafeOffset() { delegate = null; magazine = null; allocator = null; - chunkReleasePredicate = null; capacity = 0; - pooled = false; } - Chunk(AbstractByteBuf delegate, Magazine magazine, boolean pooled, - ChunkReleasePredicate chunkReleasePredicate) { + Chunk(AbstractByteBuf delegate, Magazine magazine) { this.delegate = delegate; - this.pooled = pooled; capacity = delegate.capacity(); updater.setInitialValue(this); attachToMagazine(magazine); // We need the top-level allocator so ByteBuf.capacity(int) can call reallocate() allocator = magazine.group.allocator; - - this.chunkReleasePredicate = chunkReleasePredicate; } Magazine currentMagazine() { @@ -1241,46 +1184,33 @@ public boolean release(int decrement) { /** * Called when a magazine is done using this chunk, probably because it was emptied. */ - boolean releaseFromMagazine() { - return release(); + void releaseFromMagazine() { + // Chunks can be reused before they become empty. + // We can therefor put them in the shared queue as soon as the magazine is done with this chunk. + Magazine mag = magazine; + detachFromMagazine(); + if (!mag.offerToQueue(this)) { + markToDeallocate(); + } } /** * Called when a ByteBuf is done using its allocation in this chunk. */ - boolean releaseSegment(int ignoredSegmentId) { - return release(); + void releaseSegment(int ignoredSegmentId, int size) { + release(); } - private void deallocate() { - Magazine mag = magazine; - int chunkSize = delegate.capacity(); - if (!pooled || chunkReleasePredicate.shouldReleaseChunk(chunkSize) || mag == null) { - // Drop the chunk if the parent allocator is closed, - // or if the chunk deviates too much from the preferred chunk size. - detachFromMagazine(); - allocator.chunkRegistry.remove(this); - delegate.release(); - } else { - updater.resetRefCnt(this); - delegate.setIndex(0, 0); - allocatedBytes = 0; - if (!mag.trySetNextInLine(this)) { - // As this Chunk does not belong to the mag anymore we need to decrease the used memory . - detachFromMagazine(); - if (!mag.offerToQueue(this)) { - // The central queue is full. Ensure we release again as we previously did use resetRefCnt() - // which did increase the reference count by 1. - boolean released = updater.release(this); - allocator.chunkRegistry.remove(this); - delegate.release(); - assert released; - } - } - } + void markToDeallocate() { + release(); } - public void readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { + protected void deallocate() { + allocator.chunkRegistry.remove(this); + delegate.release(); + } + + public boolean readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { int startIndex = allocatedBytes; allocatedBytes = startIndex + startingCapacity; Chunk chunk = this; @@ -1297,101 +1227,423 @@ public void readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, in chunk.release(); } } + return true; } public int remainingCapacity() { return capacity - allocatedBytes; } + public boolean hasUnprocessedFreelistEntries() { + return false; + } + + public void processFreelistEntries() { + } + public int capacity() { return capacity; } } + private static final class IntStack { + + private final int[] stack; + private int top; + + IntStack(int[] initialValues) { + stack = initialValues; + top = initialValues.length - 1; + } + + public boolean isEmpty() { + return top == -1; + } + + public int pop() { + final int last = stack[top]; + top--; + return last; + } + + public void push(int value) { + stack[top + 1] = value; + top++; + } + + public int size() { + return top + 1; + } + } + + /** + * Removes per-allocation retain()/release() atomic ops from the hot path by replacing ref counting + * with a segment-count state machine. Atomics are only needed on the cold deallocation path + * ({@link #markToDeallocate()}), which is rare for long-lived chunks that cycle segments many times. + * The tradeoff is a {@link MpscIntQueue#size()} call (volatile reads, no RMW) per remaining segment + * return after mark — acceptable since it avoids atomic RMWs entirely. + *

+ * State transitions: + *

    + *
  • {@link #AVAILABLE} (-1): chunk is in use, no deallocation tracking needed
  • + *
  • 0..N: local free list size at the time {@link #markToDeallocate()} was called; + * used to track when all segments have been returned
  • + *
  • {@link #DEALLOCATED} (Integer.MIN_VALUE): all segments returned, chunk deallocated
  • + *
+ *

+ * Ordering: external {@link #releaseSegment} pushes to the MPSC queue (which has an implicit + * StoreLoad barrier via its {@code offer()}), then reads {@code state} — this guarantees + * visibility of any preceding {@link #markToDeallocate()} write. + */ private static final class SizeClassedChunk extends Chunk { private static final int FREE_LIST_EMPTY = -1; + private static final int AVAILABLE = -1; + // Integer.MIN_VALUE so that `DEALLOCATED + externalFreeList.size()` can never equal `segments`, + // making late-arriving releaseSegment calls on external threads arithmetically harmless. + private static final int DEALLOCATED = Integer.MIN_VALUE; + private static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(SizeClassedChunk.class, "state"); + private volatile int state; + private final int segments; private final int segmentSize; - private final MpscIntQueue freeList; - - SizeClassedChunk(AbstractByteBuf delegate, Magazine magazine, boolean pooled, int segmentSize, - final int[] segmentOffsets, ChunkReleasePredicate shouldReleaseChunk) { - super(delegate, magazine, pooled, shouldReleaseChunk); - this.segmentSize = segmentSize; - int segmentCount = segmentOffsets.length; - assert delegate.capacity() / segmentSize == segmentCount; - assert segmentCount > 0: "Chunk must have a positive number of segments"; - freeList = new MpscAtomicIntegerArrayQueue(segmentCount, FREE_LIST_EMPTY); - freeList.fill(segmentCount, new IntSupplier() { - int counter; - @Override - public int get() { - return segmentOffsets[counter++]; - } - }); + private final MpscIntQueue externalFreeList; + private final IntStack localFreeList; + private Thread ownerThread; + + SizeClassedChunk(AbstractByteBuf delegate, Magazine magazine, + SizeClassChunkController controller) { + super(delegate, magazine); + segmentSize = controller.segmentSize; + segments = controller.chunkSize / segmentSize; + STATE.lazySet(this, AVAILABLE); + ownerThread = magazine.group.ownerThread; + if (ownerThread == null) { + externalFreeList = controller.createFreeList(); + localFreeList = null; + } else { + externalFreeList = controller.createEmptyFreeList(); + localFreeList = controller.createLocalFreeList(); + } } @Override - public void readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { - int startIndex = freeList.poll(); + public boolean readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { + assert state == AVAILABLE; + final int startIndex = nextAvailableSegmentOffset(); if (startIndex == FREE_LIST_EMPTY) { - throw new IllegalStateException("Free list is empty"); + return false; } allocatedBytes += segmentSize; + try { + buf.init(delegate, this, 0, 0, startIndex, size, startingCapacity, maxCapacity); + } catch (Throwable t) { + allocatedBytes -= segmentSize; + releaseSegmentOffsetIntoFreeList(startIndex); + PlatformDependent.throwException(t); + } + return true; + } + + private int nextAvailableSegmentOffset() { + final int startIndex; + IntStack localFreeList = this.localFreeList; + if (localFreeList != null) { + assert Thread.currentThread() == ownerThread; + if (localFreeList.isEmpty()) { + startIndex = externalFreeList.poll(); + } else { + startIndex = localFreeList.pop(); + } + } else { + startIndex = externalFreeList.poll(); + } + return startIndex; + } + + // this can be used by the ConcurrentQueueChunkCache to find the first buffer to use: + // it doesn't update the remaining capacity and it's not consider a single segmentSize + // case as not suitable to be reused + public boolean hasRemainingCapacity() { + int remaining = super.remainingCapacity(); + if (remaining > 0) { + return true; + } + if (localFreeList != null) { + return !localFreeList.isEmpty(); + } + return !externalFreeList.isEmpty(); + } + + @Override + public int remainingCapacity() { + int remaining = super.remainingCapacity(); + return remaining > segmentSize ? remaining : updateRemainingCapacity(remaining); + } + + private int updateRemainingCapacity(int snapshotted) { + int freeSegments = externalFreeList.size(); + IntStack localFreeList = this.localFreeList; + if (localFreeList != null) { + freeSegments += localFreeList.size(); + } + int updated = freeSegments * segmentSize; + if (updated != snapshotted) { + allocatedBytes = capacity() - updated; + } + return updated; + } + + private void releaseSegmentOffsetIntoFreeList(int startIndex) { + IntStack localFreeList = this.localFreeList; + if (localFreeList != null && Thread.currentThread() == ownerThread) { + localFreeList.push(startIndex); + } else { + boolean segmentReturned = externalFreeList.offer(startIndex); + assert segmentReturned : "Unable to return segment " + startIndex + " to free list"; + } + } + + @Override + void releaseSegment(int startIndex, int size) { + IntStack localFreeList = this.localFreeList; + if (localFreeList != null && Thread.currentThread() == ownerThread) { + localFreeList.push(startIndex); + int state = this.state; + if (state != AVAILABLE) { + updateStateOnLocalReleaseSegment(state, localFreeList); + } + } else { + boolean segmentReturned = externalFreeList.offer(startIndex); + assert segmentReturned; + // implicit StoreLoad barrier from MPSC offer() + int state = this.state; + if (state != AVAILABLE) { + deallocateIfNeeded(state); + } + } + } + + private void updateStateOnLocalReleaseSegment(int previousLocalSize, IntStack localFreeList) { + int newLocalSize = localFreeList.size(); + boolean alwaysTrue = STATE.compareAndSet(this, previousLocalSize, newLocalSize); + assert alwaysTrue : "this shouldn't happen unless double release in the local free list"; + deallocateIfNeeded(newLocalSize); + } + + private void deallocateIfNeeded(int localSize) { + // Check if all segments have been returned. + int totalFreeSegments = localSize + externalFreeList.size(); + if (totalFreeSegments == segments && STATE.compareAndSet(this, localSize, DEALLOCATED)) { + deallocate(); + } + } + + @Override + void markToDeallocate() { + IntStack localFreeList = this.localFreeList; + int localSize = localFreeList != null ? localFreeList.size() : 0; + STATE.set(this, localSize); + deallocateIfNeeded(localSize); + } + } + + private static final class BuddyChunk extends Chunk implements IntConsumer { + private static final int MIN_BUDDY_SIZE = 32768; + private static final byte IS_CLAIMED = (byte) (1 << 7); + private static final byte HAS_CLAIMED_CHILDREN = 1 << 6; + private static final byte SHIFT_MASK = ~(IS_CLAIMED | HAS_CLAIMED_CHILDREN); + private static final int PACK_OFFSET_MASK = 0xFFFF; + private static final int PACK_SIZE_SHIFT = Integer.SIZE - Integer.numberOfLeadingZeros(PACK_OFFSET_MASK); + + private final MpscAtomicIntegerArrayQueue freeList; + // The bits of each buddy: [1: is claimed][1: has claimed children][30: MIN_BUDDY_SIZE shift to get size] + private final byte[] buddies; + private final int freeListCapacity; + + BuddyChunk(AbstractByteBuf delegate, Magazine magazine) { + super(delegate, magazine); + freeListCapacity = delegate.capacity() / MIN_BUDDY_SIZE; + int maxShift = Integer.numberOfTrailingZeros(freeListCapacity); + assert maxShift <= 30; // The top 2 bits are used for marking. + // At most half of tree (all leaf nodes) can be freed. + freeList = new MpscAtomicIntegerArrayQueue(freeListCapacity, -1); + buddies = new byte[freeListCapacity << 1]; + + // Generate the buddies entries. + int index = 1; + int runLength = 1; + int currentRun = 0; + while (maxShift > 0) { + buddies[index++] = (byte) maxShift; + if (++currentRun == runLength) { + currentRun = 0; + runLength <<= 1; + maxShift--; + } + } + } + + @Override + public boolean readInitInto(AdaptiveByteBuf buf, int size, int startingCapacity, int maxCapacity) { + if (!freeList.isEmpty()) { + freeList.drain(freeListCapacity, this); + } + int startIndex = chooseFirstFreeBuddy(1, startingCapacity, 0); + if (startIndex == -1) { + return false; + } Chunk chunk = this; chunk.retain(); try { - buf.init(delegate, chunk, 0, 0, startIndex, size, startingCapacity, maxCapacity); + buf.init(delegate, this, 0, 0, startIndex, size, startingCapacity, maxCapacity); + allocatedBytes += startingCapacity; chunk = null; } finally { if (chunk != null) { + unreserveMatchingBuddy(1, startingCapacity, startIndex, 0); // If chunk is not null we know that buf.init(...) failed and so we need to manually release - // the chunk again as we retained it before calling buf.init(...). Beside this we also need to - // restore the old allocatedBytes value. - allocatedBytes -= segmentSize; - chunk.releaseSegment(startIndex); + // the chunk again as we retained it before calling buf.init(...). + chunk.release(); } } + return true; + } + + @Override + public void accept(int packed) { + // Called by allocating thread when draining freeList. + int size = unpackSize(packed); + int offset = unpackOffset(packed); + unreserveMatchingBuddy(1, size, offset, 0); + allocatedBytes -= size; + } + + private static int unpackSize(int packed) { + return MIN_BUDDY_SIZE << (packed >> PACK_SIZE_SHIFT); + } + + private static int unpackOffset(int packed) { + return (packed & PACK_OFFSET_MASK) * MIN_BUDDY_SIZE; + } + + @Override + void releaseSegment(int startingIndex, int size) { + int packedOffset = startingIndex / MIN_BUDDY_SIZE; + int packedSize = Integer.numberOfTrailingZeros(size / MIN_BUDDY_SIZE) << PACK_SIZE_SHIFT; + int packed = packedOffset | packedSize; + freeList.offer(packed); + release(); } @Override public int remainingCapacity() { - int remainingCapacity = super.remainingCapacity(); - if (remainingCapacity > segmentSize) { - return remainingCapacity; + int capacityInFreeList = 0; + if (!freeList.isEmpty()) { + capacityInFreeList = freeList.weakPeekReduce(freeListCapacity, 0, + new MpscAtomicIntegerArrayQueue.IntBinaryOperator() { + @Override + public int applyAsInt(int sum, int entry) { + return sum + unpackSize(entry); + } + }); } - int updatedRemainingCapacity = freeList.size() * segmentSize; - if (updatedRemainingCapacity == remainingCapacity) { - return remainingCapacity; - } - // update allocatedBytes based on what's available in the free list - allocatedBytes = capacity() - updatedRemainingCapacity; - return updatedRemainingCapacity; + return super.remainingCapacity() + capacityInFreeList; } @Override - boolean releaseFromMagazine() { - // Size-classed chunks can be reused before they become empty. - // We can therefor put them in the shared queue as soon as the magazine is done with this chunk. - Magazine mag = magazine; - detachFromMagazine(); - if (!mag.offerToQueue(this)) { - return super.releaseFromMagazine(); + public boolean hasUnprocessedFreelistEntries() { + return !freeList.isEmpty(); + } + + @Override + public void processFreelistEntries() { + freeList.drain(freeListCapacity, this); + } + + /** + * Claim a suitable buddy and return its start offset into the delegate chunk, or return -1 if nothing claimed. + */ + private int chooseFirstFreeBuddy(int index, int size, int currOffset) { + byte[] buddies = this.buddies; + while (index < buddies.length) { + byte buddy = buddies[index]; + int currValue = MIN_BUDDY_SIZE << (buddy & SHIFT_MASK); + if (currValue < size || (buddy & IS_CLAIMED) == IS_CLAIMED) { + return -1; + } + if (currValue == size && (buddy & HAS_CLAIMED_CHILDREN) == 0) { + buddies[index] |= IS_CLAIMED; + return currOffset; + } + int found = chooseFirstFreeBuddy(index << 1, size, currOffset); + if (found != -1) { + buddies[index] |= HAS_CLAIMED_CHILDREN; + return found; + } + index = (index << 1) + 1; + currOffset += currValue >> 1; // Bump offset to skip first half of this layer. } - return false; + return -1; + } + + /** + * Un-reserve the matching buddy and return whether there are any other child or sibling reservations. + */ + private boolean unreserveMatchingBuddy(int index, int size, int offset, int currOffset) { + byte[] buddies = this.buddies; + if (buddies.length <= index) { + return false; + } + byte buddy = buddies[index]; + int currSize = MIN_BUDDY_SIZE << (buddy & SHIFT_MASK); + + if (currSize == size) { + // We're at the right size level. + if (currOffset == offset) { + buddies[index] &= SHIFT_MASK; + return false; + } + throw new IllegalStateException("The intended segment was not found at index " + + index + ", for size " + size + " and offset " + offset); + } + + // We're at a parent size level. Use the target offset to guide our drill-down path. + boolean claims; + int siblingIndex; + if (offset < currOffset + (currSize >> 1)) { + // Must be down the left path. + claims = unreserveMatchingBuddy(index << 1, size, offset, currOffset); + siblingIndex = (index << 1) + 1; + } else { + // Must be down the rigth path. + claims = unreserveMatchingBuddy((index << 1) + 1, size, offset, currOffset + (currSize >> 1)); + siblingIndex = index << 1; + } + if (!claims) { + // No other claims down the path we took. Check if the sibling has claims. + byte sibling = buddies[siblingIndex]; + if ((sibling & SHIFT_MASK) == sibling) { + // No claims in the sibling. We can clear this level as well. + buddies[index] &= SHIFT_MASK; + return false; + } + } + return true; } @Override - boolean releaseSegment(int startIndex) { - boolean released = release(); - boolean segmentReturned = freeList.offer(startIndex); - assert segmentReturned: "Unable to return segment " + startIndex + " to free list"; - return released; + public String toString() { + int capacity = delegate.capacity(); + int remaining = capacity - allocatedBytes; + return "BuddyChunk[capacity: " + capacity + + ", remaining: " + remaining + + ", free list: " + freeList.size() + ']'; } } static final class AdaptiveByteBuf extends AbstractReferenceCountedByteBuf { - private final ObjectPool.Handle handle; + private final EnhancedHandle handle; // this both act as adjustment and the start index for a free list segment allocation private int startIndex; @@ -1403,7 +1655,7 @@ static final class AdaptiveByteBuf extends AbstractReferenceCountedByteBuf { private boolean hasArray; private boolean hasMemoryAddress; - AdaptiveByteBuf(ObjectPool.Handle recyclerHandle) { + AdaptiveByteBuf(EnhancedHandle recyclerHandle) { super(0); handle = ObjectUtil.checkNotNull(recyclerHandle, "recyclerHandle"); } @@ -1442,12 +1694,11 @@ public int maxFastWritableBytes() { @Override public ByteBuf capacity(int newCapacity) { + checkNewCapacity(newCapacity); if (length <= newCapacity && newCapacity <= maxFastCapacity) { - ensureAccessible(); length = newCapacity; return this; } - checkNewCapacity(newCapacity); if (newCapacity < capacity()) { length = newCapacity; trimIndicesToCapacity(newCapacity); @@ -1460,11 +1711,14 @@ public ByteBuf capacity(int newCapacity) { int readerIndex = this.readerIndex; int writerIndex = this.writerIndex; int baseOldRootIndex = startIndex; - int oldCapacity = length; + int oldLength = length; + int oldCapacity = maxFastCapacity; AbstractByteBuf oldRoot = rootParent(); allocator.reallocate(newCapacity, maxCapacity(), this); - oldRoot.getBytes(baseOldRootIndex, this, 0, oldCapacity); - chunk.releaseSegment(baseOldRootIndex); + oldRoot.getBytes(baseOldRootIndex, this, 0, oldLength); + chunk.releaseSegment(baseOldRootIndex, oldCapacity); + assert oldCapacity < maxFastCapacity && newCapacity <= maxFastCapacity: + "Capacity increase failed"; this.readerIndex = readerIndex; this.writerIndex = writerIndex; return this; @@ -1475,6 +1729,7 @@ public ByteBufAllocator alloc() { return rootParent().alloc(); } + @SuppressWarnings("deprecation") @Override public ByteOrder order() { return rootParent().order(); @@ -1841,17 +2096,12 @@ private int idx(int index) { @Override protected void deallocate() { if (chunk != null) { - chunk.releaseSegment(startIndex); + chunk.releaseSegment(startIndex, maxFastCapacity); } tmpNioBuf = null; chunk = null; rootParent = null; - if (handle instanceof EnhancedHandle) { - EnhancedHandle enhancedHandle = (EnhancedHandle) handle; - enhancedHandle.unguardedRecycle(this); - } else { - handle.recycle(this); - } + handle.unguardedRecycle(this); } } diff --git a/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java b/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java index 4ad86136888..4786724dc0b 100644 --- a/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java @@ -2360,4 +2360,17 @@ private void shiftComps(int i, int count) { } componentCount = newSize; } + + /** + * Decreases the reference count by the specified {@code decrement} and deallocates this object if the reference + * count reaches at {@code 0}. At this point it will also decrement the reference count of each internal + * component by {@code 1}. + * + * @param decrement the number by which the reference count should be decreased + * @return {@code true} if and only if the reference count became {@code 0} and this object has been deallocated + */ + @Override + public boolean release(final int decrement) { + return super.release(decrement); + } } diff --git a/buffer/src/main/java/io/netty/buffer/SizeClasses.java b/buffer/src/main/java/io/netty/buffer/SizeClasses.java index b42d455d5e6..d1fa1389855 100644 --- a/buffer/src/main/java/io/netty/buffer/SizeClasses.java +++ b/buffer/src/main/java/io/netty/buffer/SizeClasses.java @@ -107,7 +107,7 @@ final class SizeClasses implements SizeClassesMetric { private final int[] pageIdx2sizeTab; - // lookup table for sizeIdx <= smallMaxSizeIdx + // lookup table for sizeIdx < nSizes private final int[] sizeIdx2sizeTab; // lookup table used for size <= lookupMaxClass diff --git a/buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java index c32183fa707..a5f3675ba66 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java @@ -17,6 +17,7 @@ import io.netty.util.internal.PlatformDependent; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; import java.lang.management.ManagementFactory; import java.lang.management.ThreadMXBean; @@ -26,6 +27,7 @@ import static org.assertj.core.api.Assumptions.assumeThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.abort; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -196,6 +198,27 @@ public void shouldReuseChunks() throws Exception { .isLessThan(8 * 1024 * 1024); } + @Test + public void testCapacityNotGreaterThanMaxCapacity() { + testCapacityNotGreaterThanMaxCapacity(true); + testCapacityNotGreaterThanMaxCapacity(false); + } + + private void testCapacityNotGreaterThanMaxCapacity(boolean preferDirect) { + final int maxSize = 100000; + final ByteBuf buf = newAllocator(preferDirect).newDirectBuffer(maxSize, maxSize); + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + buf.capacity(maxSize + 1); + } + }); + } finally { + buf.release(); + } + } + protected long expectedUsedMemory(T allocator, int capacity) { return capacity; } diff --git a/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java b/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java index 58a4ae82e75..a6656116a70 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java @@ -57,6 +57,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -74,7 +75,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -2290,7 +2290,7 @@ public void testToString() { } @Test - @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + @Timeout(30) public void testToStringMultipleThreads() throws Throwable { buffer.clear(); buffer.writeBytes("Hello, World!".getBytes(CharsetUtil.ISO_8859_1)); @@ -2300,7 +2300,7 @@ public void testToStringMultipleThreads() throws Throwable { static void testToStringMultipleThreads0(final ByteBuf buffer) throws Throwable { final String expected = buffer.toString(CharsetUtil.ISO_8859_1); - final AtomicInteger counter = new AtomicInteger(30000); + final CyclicBarrier startBarrier = new CyclicBarrier(10); final AtomicReference errorRef = new AtomicReference(); List threads = new ArrayList(); for (int i = 0; i < 10; i++) { @@ -2308,11 +2308,15 @@ static void testToStringMultipleThreads0(final ByteBuf buffer) throws Throwable @Override public void run() { try { - while (errorRef.get() == null && counter.decrementAndGet() > 0) { + startBarrier.await(10, TimeUnit.SECONDS); + int counter = 3000; + while (errorRef.get() == null && counter-- > 0) { assertEquals(expected, buffer.toString(CharsetUtil.ISO_8859_1)); } } catch (Throwable cause) { - errorRef.compareAndSet(null, cause); + if (!errorRef.compareAndSet(null, cause)) { + ThrowableUtil.addSuppressed(errorRef.get(), cause); + } } } }); @@ -2322,13 +2326,27 @@ public void run() { thread.start(); } - for (Thread thread : threads) { - thread.join(); - } + joinAllAndReportErrors(threads, errorRef); + } + + private static void joinAllAndReportErrors(List threads, AtomicReference errorRef) + throws Throwable { + try { + for (Thread thread : threads) { + thread.join(); + } - Throwable error = errorRef.get(); - if (error != null) { - throw error; + Throwable error = errorRef.get(); + if (error != null) { + throw error; + } + } catch (Throwable e) { + for (Thread thread : threads) { + if (thread.isAlive()) { + ThrowableUtil.interruptAndAttachAsyncStackTrace(thread, e); + } + } + throw e; } } @@ -2345,7 +2363,7 @@ public void testCopyMultipleThreads0() throws Throwable { static void testCopyMultipleThreads0(final ByteBuf buffer) throws Throwable { final ByteBuf expected = buffer.copy(); try { - final AtomicInteger counter = new AtomicInteger(30000); + final CyclicBarrier startBarrier = new CyclicBarrier(10); final AtomicReference errorRef = new AtomicReference(); List threads = new ArrayList(); for (int i = 0; i < 10; i++) { @@ -2353,7 +2371,9 @@ static void testCopyMultipleThreads0(final ByteBuf buffer) throws Throwable { @Override public void run() { try { - while (errorRef.get() == null && counter.decrementAndGet() > 0) { + startBarrier.await(10, TimeUnit.SECONDS); + int counter = 3000; + while (errorRef.get() == null && counter-- > 0) { ByteBuf copy = buffer.copy(); try { assertEquals(expected, copy); @@ -2372,14 +2392,7 @@ public void run() { thread.start(); } - for (Thread thread : threads) { - thread.join(); - } - - Throwable error = errorRef.get(); - if (error != null) { - throw error; - } + joinAllAndReportErrors(threads, errorRef); } finally { expected.release(); } @@ -2713,6 +2726,7 @@ private void testInternalNioBuffer(int a) { } @Test + @Timeout(30) public void testDuplicateReadGatheringByteChannelMultipleThreads() throws Exception { final byte[] bytes = new byte[8]; random.nextBytes(bytes); @@ -2727,6 +2741,7 @@ public void testDuplicateReadGatheringByteChannelMultipleThreads() throws Except } @Test + @Timeout(30) public void testSliceReadGatheringByteChannelMultipleThreads() throws Exception { final byte[] bytes = new byte[8]; random.nextBytes(bytes); @@ -2744,44 +2759,59 @@ static void testReadGatheringByteChannelMultipleThreads( final ByteBuf buffer, final byte[] expectedBytes, final boolean slice) throws Exception { assertEquals(buffer.readableBytes(), expectedBytes.length); final CountDownLatch latch = new CountDownLatch(60000); + final AtomicReference innerThrowable = new AtomicReference(); final CyclicBarrier barrier = new CyclicBarrier(11); for (int i = 0; i < 10; i++) { new Thread(new Runnable() { @Override public void run() { - while (latch.getCount() > 0) { - ByteBuf buf; - if (slice) { - buf = buffer.slice(); - } else { - buf = buffer.duplicate(); - } - TestGatheringByteChannel channel = new TestGatheringByteChannel(); - - while (buf.isReadable()) { - try { - buf.readBytes(channel, buf.readableBytes()); - } catch (IOException e) { - // Never happens - return; + try { + while (latch.getCount() > 0) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); + } + TestGatheringByteChannel channel = new TestGatheringByteChannel(); + + while (buf.isReadable()) { + try { + buf.readBytes(channel, buf.readableBytes()); + } catch (IOException e) { + // Never happens + return; + } } + assertArrayEquals(expectedBytes, channel.writtenBytes()); + latch.countDown(); + } + } catch (Throwable e) { + innerThrowable.compareAndSet(null, e); + } finally { + try { + barrier.await(); + } catch (Exception e) { + // ignore } - assertArrayEquals(expectedBytes, channel.writtenBytes()); - latch.countDown(); - } - try { - barrier.await(); - } catch (Exception e) { - // ignore } } }).start(); } - latch.await(10, TimeUnit.SECONDS); - barrier.await(5, TimeUnit.SECONDS); + try { + latch.await(); + barrier.await(5, TimeUnit.SECONDS); + } catch (Exception e) { + Throwable inner = innerThrowable.get(); + if (inner != null) { + e.addSuppressed(inner); + } + throw e; + } } @Test + @Timeout(30) public void testDuplicateReadOutputStreamMultipleThreads() throws Exception { final byte[] bytes = new byte[8]; random.nextBytes(bytes); @@ -2796,6 +2826,7 @@ public void testDuplicateReadOutputStreamMultipleThreads() throws Exception { } @Test + @Timeout(30) public void testSliceReadOutputStreamMultipleThreads() throws Exception { final byte[] bytes = new byte[8]; random.nextBytes(bytes); @@ -2812,41 +2843,55 @@ public void testSliceReadOutputStreamMultipleThreads() throws Exception { static void testReadOutputStreamMultipleThreads( final ByteBuf buffer, final byte[] expectedBytes, final boolean slice) throws Exception { final CountDownLatch latch = new CountDownLatch(60000); + final AtomicReference innerThrowable = new AtomicReference(); final CyclicBarrier barrier = new CyclicBarrier(11); for (int i = 0; i < 10; i++) { new Thread(new Runnable() { @Override public void run() { - while (latch.getCount() > 0) { - ByteBuf buf; - if (slice) { - buf = buffer.slice(); - } else { - buf = buffer.duplicate(); - } - ByteArrayOutputStream out = new ByteArrayOutputStream(); - - while (buf.isReadable()) { - try { - buf.readBytes(out, buf.readableBytes()); - } catch (IOException e) { - // Never happens - return; + try { + while (latch.getCount() > 0) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); } + ByteArrayOutputStream out = new ByteArrayOutputStream(); + + while (buf.isReadable()) { + try { + buf.readBytes(out, buf.readableBytes()); + } catch (IOException e) { + // Never happens + return; + } + } + assertArrayEquals(expectedBytes, out.toByteArray()); + latch.countDown(); + } + } catch (Throwable e) { + innerThrowable.compareAndSet(null, e); + } finally { + try { + barrier.await(); + } catch (Exception e) { + // ignore } - assertArrayEquals(expectedBytes, out.toByteArray()); - latch.countDown(); - } - try { - barrier.await(); - } catch (Exception e) { - // ignore } } }).start(); } - latch.await(10, TimeUnit.SECONDS); - barrier.await(5, TimeUnit.SECONDS); + try { + latch.await(); + barrier.await(5, TimeUnit.SECONDS); + } catch (Exception e) { + Throwable inner = innerThrowable.get(); + if (inner != null) { + e.addSuppressed(inner); + } + throw e; + } } @Test @@ -2879,14 +2924,14 @@ public void testSliceBytesInArrayMultipleThreads() throws Exception { static void testBytesInArrayMultipleThreads( final ByteBuf buffer, final byte[] expectedBytes, final boolean slice) throws Exception { - final AtomicReference cause = new AtomicReference(); - final CountDownLatch latch = new CountDownLatch(60000); - final CyclicBarrier barrier = new CyclicBarrier(11); - for (int i = 0; i < 10; i++) { - new Thread(new Runnable() { - @Override - public void run() { - while (cause.get() == null && latch.getCount() > 0) { + final CyclicBarrier startBarrier = new CyclicBarrier(10); + final CyclicBarrier endBarrier = new CyclicBarrier(11); + Callable callable = new Callable() { + @Override + public Void call() throws Exception { + startBarrier.await(); + try { + for (int i = 0; i < 6000; i++) { ByteBuf buf; if (slice) { buf = buffer.slice(); @@ -2902,20 +2947,34 @@ public void run() { Arrays.fill(array, (byte) 0); buf.getBytes(0, array); assertArrayEquals(expectedBytes, array); - - latch.countDown(); - } - try { - barrier.await(); - } catch (Exception e) { - // ignore } + } finally { + endBarrier.await(); } - }).start(); + return null; + } + }; + List> tasks = new ArrayList>(); + for (int i = 0; i < 10; i++) { + FutureTask task = new FutureTask(callable); + new Thread(task).start(); + tasks.add(task); + } + try { + endBarrier.await(30, TimeUnit.SECONDS); + } catch (Exception e) { + for (FutureTask task : tasks) { + try { + task.get(100, TimeUnit.MILLISECONDS); + } catch (Exception ex) { + e.addSuppressed(ex); + } + } + throw e; + } + for (FutureTask task : tasks) { + task.get(1, TimeUnit.SECONDS); } - latch.await(10, TimeUnit.SECONDS); - barrier.await(5, TimeUnit.SECONDS); - assertNull(cause.get()); } public static Object[][] setCharSequenceCombinations() { @@ -5954,6 +6013,7 @@ private void testRefCnt0(final boolean parameter) throws Exception { final ByteBuf buffer = newBuffer(4); assertEquals(1, buffer.refCnt()); + final AtomicReference innerThrowable = new AtomicReference(); final AtomicInteger cnt = new AtomicInteger(Integer.MAX_VALUE); Thread t1 = new Thread(new Runnable() { @Override @@ -5964,7 +6024,11 @@ public void run() { } else { released = buffer.release(); } - assertTrue(released); + if (!released) { + innerThrowable.set(new AssertionError("buffer was not released: " + buffer)); + latch.countDown(); + return; + } Thread t2 = new Thread(new Runnable() { @Override public void run() { @@ -5984,6 +6048,10 @@ public void run() { t1.start(); latch.await(); + Throwable inner = innerThrowable.get(); + if (inner != null) { + fail(inner); + } assertEquals(0, cnt.get()); innerLatch.countDown(); } diff --git a/buffer/src/test/java/io/netty/buffer/AdaptiveByteBufAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/AdaptiveByteBufAllocatorTest.java index 448930a3189..4c212410d88 100644 --- a/buffer/src/test/java/io/netty/buffer/AdaptiveByteBufAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/AdaptiveByteBufAllocatorTest.java @@ -16,10 +16,17 @@ package io.netty.buffer; import io.netty.util.NettyRuntime; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.RepetitionInfo; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; + +import java.lang.reflect.Array; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.SplittableRandom; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicReference; @@ -111,24 +118,29 @@ public void testUsedHeapMemory() { @Test void adaptiveChunkMustDeallocateOrReuseWthBufferRelease() throws Exception { AdaptiveByteBufAllocator allocator = newAllocator(false); - ByteBuf a = allocator.heapBuffer(28 * 1024); - assertEquals(262144, allocator.usedHeapMemory()); - ByteBuf b = allocator.heapBuffer(100 * 1024); - assertEquals(262144, allocator.usedHeapMemory()); - b.release(); - a.release(); - assertEquals(262144, allocator.usedHeapMemory()); - a = allocator.heapBuffer(28 * 1024); - assertEquals(262144, allocator.usedHeapMemory()); - b = allocator.heapBuffer(100 * 1024); - assertEquals(262144, allocator.usedHeapMemory()); - a.release(); - ByteBuf c = allocator.heapBuffer(28 * 1024); - assertEquals(2 * 262144, allocator.usedHeapMemory()); - c.release(); - assertEquals(2 * 262144, allocator.usedHeapMemory()); - b.release(); - assertEquals(2 * 262144, allocator.usedHeapMemory()); + Deque bufs = new ArrayDeque(); + assertEquals(0, allocator.usedHeapMemory()); + assertEquals(0, allocator.usedHeapMemory()); + bufs.add(allocator.heapBuffer(256)); + long usedHeapMemory = allocator.usedHeapMemory(); + int buffersPerChunk = Math.toIntExact(usedHeapMemory / 256); + for (int i = 0; i < buffersPerChunk; i++) { + bufs.add(allocator.heapBuffer(256)); + } + assertEquals(2 * usedHeapMemory, allocator.usedHeapMemory()); + bufs.pop().release(); + assertEquals(2 * usedHeapMemory, allocator.usedHeapMemory()); + while (!bufs.isEmpty()) { + bufs.pop().release(); + } + assertEquals(2 * usedHeapMemory, allocator.usedHeapMemory()); + for (int i = 0; i < 2 * buffersPerChunk; i++) { + bufs.add(allocator.heapBuffer(256)); + } + assertEquals(2 * usedHeapMemory, allocator.usedHeapMemory()); + while (!bufs.isEmpty()) { + bufs.pop().release(); + } } @ParameterizedTest @@ -198,4 +210,71 @@ public void run() { fail("Expected no exception, but got", throwable); } } + + @RepeatedTest(100) + void buddyAllocationConsistency(RepetitionInfo info) { + SplittableRandom rng = new SplittableRandom(info.getCurrentRepetition()); + AdaptiveByteBufAllocator allocator = newAllocator(true); + int small = 32768; + int large = 2 * small; + int xlarge = 2 * large; + + int[] allocationSizes = { + small, small, small, small, small, small, small, small, + large, large, large, large, + xlarge, xlarge, + }; + + shuffle(rng, allocationSizes); + + ByteBuf[] bufs = new ByteBuf[allocationSizes.length]; + for (int i = 0; i < bufs.length; i++) { + bufs[i] = allocator.buffer(allocationSizes[i], allocationSizes[i]); + } + + shuffle(rng, bufs); + + int[] reallocations = new int[bufs.length / 2]; + for (int i = 0; i < reallocations.length; i++) { + reallocations[i] = bufs[i].capacity(); + bufs[i].release(); + bufs[i] = null; + } + for (int i = 0; i < reallocations.length; i++) { + assertNull(bufs[i]); + bufs[i] = allocator.buffer(reallocations[i], reallocations[i]); + } + + for (int i = 0; i < bufs.length; i++) { + while (bufs[i].isWritable()) { + bufs[i].writeByte(i + 1); + } + } + try { + for (int i = 0; i < bufs.length; i++) { + while (bufs[i].isReadable()) { + int b = Byte.toUnsignedInt(bufs[i].readByte()); + if (b != i + 1) { + fail("Expected byte " + (i + 1) + + " at index " + (bufs[i].readerIndex() - 1) + + " but got " + b); + } + } + } + } finally { + for (ByteBuf buf : bufs) { + buf.release(); + } + } + } + + private static void shuffle(SplittableRandom rng, Object array) { + int len = Array.getLength(array); + for (int i = 0; i < len; i++) { + int n = rng.nextInt(i, len); + Object value = Array.get(array, i); + Array.set(array, i, Array.get(array, n)); + Array.set(array, n, value); + } + } } diff --git a/buffer/src/test/java/io/netty/buffer/AdaptivePoolingAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/AdaptivePoolingAllocatorTest.java index ab47050c641..4a4c28deebf 100644 --- a/buffer/src/test/java/io/netty/buffer/AdaptivePoolingAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/AdaptivePoolingAllocatorTest.java @@ -15,52 +15,11 @@ */ package io.netty.buffer; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import java.util.function.Supplier; - import static org.junit.jupiter.api.Assertions.assertEquals; -class AdaptivePoolingAllocatorTest implements Supplier { - private int i; - - @BeforeEach - void setUp() { - i = 0; - } - - @Override - public String get() { - return "i = " + i; - } - - @Test - void sizeBucketComputations() throws Exception { - assertSizeBucket(0, 16 * 1024); - assertSizeBucket(1, 24 * 1024); - assertSizeBucket(2, 32 * 1024); - assertSizeBucket(3, 48 * 1024); - assertSizeBucket(4, 64 * 1024); - assertSizeBucket(5, 96 * 1024); - assertSizeBucket(6, 128 * 1024); - assertSizeBucket(7, 192 * 1024); - assertSizeBucket(8, 256 * 1024); - assertSizeBucket(9, 384 * 1024); - assertSizeBucket(10, 512 * 1024); - assertSizeBucket(11, 768 * 1024); - assertSizeBucket(12, 1024 * 1024); - assertSizeBucket(13, 1792 * 1024); - assertSizeBucket(14, 2048 * 1024); - assertSizeBucket(15, 3072 * 1024); - // The sizeBucket function will be used for sizes up to 8 MiB - assertSizeBucket(15, 4 * 1024 * 1024); - assertSizeBucket(15, 5 * 1024 * 1024); - assertSizeBucket(15, 6 * 1024 * 1024); - assertSizeBucket(15, 7 * 1024 * 1024); - assertSizeBucket(15, 8 * 1024 * 1024); - } - +class AdaptivePoolingAllocatorTest { @Test void sizeClassComputations() throws Exception { final int[] sizeClasses = AdaptivePoolingAllocator.getSizeClasses(); @@ -75,20 +34,7 @@ void sizeClassComputations() throws Exception { private static void assertSizeClassOf(int expectedSizeClass, int previousSizeIncluded, int maxSizeIncluded) { for (int size = previousSizeIncluded; size <= maxSizeIncluded; size++) { - final int sizeToTest = size; - Supplier messageSupplier = new Supplier() { - @Override - public String get() { - return "size = " + sizeToTest; - } - }; - assertEquals(expectedSizeClass, AdaptivePoolingAllocator.sizeClassIndexOf(size), messageSupplier); - } - } - - private void assertSizeBucket(int expectedSizeBucket, int maxSizeIncluded) { - for (; i <= maxSizeIncluded; i++) { - assertEquals(expectedSizeBucket, AdaptivePoolingAllocator.sizeToBucket(i), this); + assertEquals(expectedSizeClass, AdaptivePoolingAllocator.sizeClassIndexOf(size), "size = " + size); } } } diff --git a/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java index ecc01065210..64638f8e1cb 100644 --- a/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java @@ -20,6 +20,7 @@ import io.netty.util.concurrent.FastThreadLocalThread; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.ThrowableUtil; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -30,6 +31,7 @@ import java.util.Random; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.FutureTask; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -349,13 +351,13 @@ public void testAllocateSmallOffset() { } @Test - @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) + @Timeout(value = 20, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public void testThreadCacheDestroyedByThreadCleaner() throws InterruptedException { testThreadCacheDestroyed(false); } @Test - @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) + @Timeout(value = 20, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public void testThreadCacheDestroyedAfterExitRun() throws InterruptedException { testThreadCacheDestroyed(true); } @@ -408,7 +410,6 @@ public void run() { while (allocator.metric().numThreadLocalCaches() > 0) { // Signal we want to have a GC run to ensure we can process our ThreadCleanerReference System.gc(); - System.runFinalization(); LockSupport.parkNanos(MILLISECONDS.toNanos(100)); } @@ -416,8 +417,8 @@ public void run() { } @Test - @Timeout(value = 3000, unit = MILLISECONDS) - public void testNumThreadCachesWithNoDirectArenas() throws InterruptedException { + @Timeout(10) + public void testNumThreadCachesWithNoDirectArenas() throws Exception { int numHeapArenas = 1; final PooledByteBufAllocator allocator = new PooledByteBufAllocator(numHeapArenas, 0, 8192, 1); @@ -436,11 +437,11 @@ public void testNumThreadCachesWithNoDirectArenas() throws InterruptedException } @Test - @Timeout(value = 3000, unit = MILLISECONDS) - public void testNumThreadCachesAccountForDirectAndHeapArenas() throws InterruptedException { - int numHeapArenas = 1; + @Timeout(10) + public void testNumThreadCachesAccountForDirectAndHeapArenas() throws Exception { + int numArenas = 1; final PooledByteBufAllocator allocator = - new PooledByteBufAllocator(numHeapArenas, 0, 8192, 1); + new PooledByteBufAllocator(numArenas, numArenas, 8192, 1); ThreadCache tcache0 = createNewThreadCache(allocator, false); assertEquals(1, allocator.metric().numThreadLocalCaches()); @@ -456,8 +457,8 @@ public void testNumThreadCachesAccountForDirectAndHeapArenas() throws Interrupte } @Test - @Timeout(value = 3000, unit = MILLISECONDS) - public void testThreadCacheToArenaMappings() throws InterruptedException { + @Timeout(10) + public void testThreadCacheToArenaMappings() throws Exception { int numArenas = 2; final PooledByteBufAllocator allocator = new PooledByteBufAllocator(numArenas, numArenas, 8192, 1); @@ -500,8 +501,7 @@ private static ThreadCache createNewThreadCache(final PooledByteBufAllocator all throws InterruptedException { final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch cacheLatch = new CountDownLatch(1); - final Thread t = new FastThreadLocalThread(new Runnable() { - + final FutureTask task = new FutureTask(new Runnable() { @Override public void run() { final ByteBuf buf; @@ -527,23 +527,35 @@ public void run() { FastThreadLocal.removeAll(); } - }); + }, null); + final Thread t = new FastThreadLocalThread(task); t.start(); // Wait until we allocated a buffer and so be sure the thread was started and the cache exists. - cacheLatch.await(); + try { + cacheLatch.await(); + } catch (InterruptedException e) { + ThrowableUtil.interruptAndAttachAsyncStackTrace(t, e); + throw e; + } return new ThreadCache() { @Override - public void destroy() throws InterruptedException { + public void destroy() throws Exception { latch.countDown(); - t.join(); + try { + task.get(); + t.join(); + } catch (InterruptedException e) { + ThrowableUtil.interruptAndAttachAsyncStackTrace(t, e); + throw e; + } } }; } private interface ThreadCache { - void destroy() throws InterruptedException; + void destroy() throws Exception; } @Test diff --git a/buffer/src/test/java/io/netty/buffer/ReadOnlyDirectByteBufferBufTest.java b/buffer/src/test/java/io/netty/buffer/ReadOnlyDirectByteBufferBufTest.java index 7e9fd0b019d..bf395843614 100644 --- a/buffer/src/test/java/io/netty/buffer/ReadOnlyDirectByteBufferBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/ReadOnlyDirectByteBufferBufTest.java @@ -490,6 +490,7 @@ void testIsWritable(DerivedParam param) { } @Test + @Timeout(30) public void testDuplicateReadGatheringByteChannelMultipleThreads() throws Exception { final byte[] bytes = new byte[8]; PlatformDependent.threadLocalRandom().nextBytes(bytes); @@ -505,6 +506,7 @@ public void testDuplicateReadGatheringByteChannelMultipleThreads() throws Except } @Test + @Timeout(30) public void testSliceReadGatheringByteChannelMultipleThreads() throws Exception { final byte[] bytes = new byte[8]; PlatformDependent.threadLocalRandom().nextBytes(bytes); @@ -520,6 +522,7 @@ public void testSliceReadGatheringByteChannelMultipleThreads() throws Exception } @Test + @Timeout(30) public void testDuplicateReadOutputStreamMultipleThreads() throws Exception { final byte[] bytes = new byte[8]; PlatformDependent.threadLocalRandom().nextBytes(bytes); @@ -535,6 +538,7 @@ public void testDuplicateReadOutputStreamMultipleThreads() throws Exception { } @Test + @Timeout(30) public void testSliceReadOutputStreamMultipleThreads() throws Exception { final byte[] bytes = new byte[8]; PlatformDependent.threadLocalRandom().nextBytes(bytes); @@ -580,7 +584,7 @@ public void testSliceBytesInArrayMultipleThreads() throws Exception { } @Test - @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + @Timeout(30) public void testToStringMultipleThreads1() throws Throwable { String expected = "Hello, World!"; byte[] bytes = expected.getBytes(CharsetUtil.ISO_8859_1); diff --git a/buffer/src/test/java/io/netty/buffer/UnpooledTest.java b/buffer/src/test/java/io/netty/buffer/UnpooledTest.java index efc1dafd1ed..fd705f1bed0 100644 --- a/buffer/src/test/java/io/netty/buffer/UnpooledTest.java +++ b/buffer/src/test/java/io/netty/buffer/UnpooledTest.java @@ -476,7 +476,7 @@ public void testUnmodifiableBuffer() throws Exception { } catch (UnsupportedOperationException e) { // Expected } - Mockito.verifyZeroInteractions(inputStream); + Mockito.verifyNoInteractions(inputStream); ScatteringByteChannel scatteringByteChannel = Mockito.mock(ScatteringByteChannel.class); try { @@ -485,7 +485,7 @@ public void testUnmodifiableBuffer() throws Exception { } catch (UnsupportedOperationException e) { // Expected } - Mockito.verifyZeroInteractions(scatteringByteChannel); + Mockito.verifyNoInteractions(scatteringByteChannel); buf.release(); } diff --git a/codec-dns/pom.xml b/codec-dns/pom.xml index 39846136332..807b5b7860b 100644 --- a/codec-dns/pom.xml +++ b/codec-dns/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-dns diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java index 2aea39159fe..80cf862ab6a 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java @@ -16,6 +16,8 @@ package io.netty.handler.codec.dns; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.CorruptedFrameException; /** * The default {@link DnsRecordDecoder} implementation. @@ -99,6 +101,30 @@ protected DnsRecord decodeRecord( DnsCodecUtil.decompressDomainName( in.duplicate().setIndex(offset, offset + length))); } + if (type == DnsRecordType.MX) { + // MX RDATA: 16-bit preference + exchange (domain name, possibly compressed) + if (length < 3) { + throw new CorruptedFrameException("MX record RDATA is too short: " + length); + } + final int pref = in.getUnsignedShort(offset); + ByteBuf exchange = null; + try { + exchange = DnsCodecUtil.decompressDomainName( + in.duplicate().setIndex(offset + 2, offset + length)); + + // Build decompressed RDATA = [preference][expanded exchange name] + final ByteBuf out = in.alloc().buffer(2 + exchange.readableBytes()); + out.writeShort(pref); + out.writeBytes(exchange); + + return new DefaultDnsRawRecord(name, type, dnsClass, timeToLive, out); + } finally { + if (exchange != null) { + exchange.release(); + } + } + } + return new DefaultDnsRawRecord( name, type, dnsClass, timeToLive, in.retainedDuplicate().setIndex(offset, offset + length)); } diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsCodecUtil.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsCodecUtil.java index a702771df86..3e1d6b1a868 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsCodecUtil.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsCodecUtil.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.codec.TooLongFrameException; import io.netty.util.CharsetUtil; import static io.netty.handler.codec.dns.DefaultDnsRecordDecoder.*; @@ -35,14 +36,33 @@ static void encodeDomainName(String name, ByteBuf buf) { return; } + int totalLength = 0; final String[] labels = name.split("\\."); - for (String label : labels) { + for (int i = 0; i < labels.length; i++) { + String label = labels[i]; final int labelLen = label.length(); if (labelLen == 0) { - // zero-length label means the end of the name. - break; + if (i == labels.length - 1) { + // zero-length label at the end means the end of the name. + break; + } else { + throw new IllegalArgumentException("DNS name contains empty label: " + name); + } + } + if (labelLen > 63) { + throw new IllegalArgumentException( + "DNS label length " + labelLen + " exceeds maximum of 63: " + name); + } + int idx = label.indexOf('\0'); + if (idx != -1) { + throw new IllegalArgumentException( + "DNS label contains null byte at index " + idx); + } + totalLength += 1 + labelLen; + if (totalLength > 255) { + throw new IllegalArgumentException( + "DNS name exceeds maximum length of 255: " + name); } - buf.writeByte(labelLen); ByteBufUtil.writeAscii(buf, label); } @@ -95,8 +115,16 @@ static String decodeDomainName(ByteBuf in) { if (!in.isReadable(len)) { throw new CorruptedFrameException("truncated label in a name"); } + // See https://datatracker.ietf.org/doc/html/rfc1035#section-2.3.4 + if (len > 63) { + throw new TooLongFrameException("label must be <= 63 but was " + len); + } name.append(in.toString(in.readerIndex(), len, CharsetUtil.UTF_8)).append('.'); in.skipBytes(len); + // See https://datatracker.ietf.org/doc/html/rfc1035#section-2.3.4 + if (name.length() > 255) { + throw new TooLongFrameException("domain name must be <= 255 but was " + name.length()); + } } else { // len == 0 break; } diff --git a/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java b/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java index a8379f6d8d7..d66b994b604 100644 --- a/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java +++ b/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java @@ -166,6 +166,51 @@ public void testDecodeCompressionRDataPointer() throws Exception { } } + @Test + public void testDecodeCompressionRDataPointerMX() throws Exception { + DefaultDnsRecordDecoder decoder = new DefaultDnsRecordDecoder(); + byte[] compressionPointer = { + 5, 'n', 'e', 't', 't', 'y', 2, 'i', 'o', 0, + 0, 10, // preference = 10 + (byte) 0xC0, 0 // record is a pointer to netty.io + }; + + byte[] expected = { + 0, 10, // pref = 10 + 5, 'n', 'e', 't', 't', 'y', 2, 'i', 'o', 0 + }; + ByteBuf buffer = Unpooled.wrappedBuffer(compressionPointer); + DefaultDnsRawRecord mxRecord = null; + ByteBuf expectedBuf = null; + try { + mxRecord = (DefaultDnsRawRecord) decoder.decodeRecord( + "mail.example.com", + DnsRecordType.MX, + DnsRecord.CLASS_IN, + 60, + buffer, + 10, + 4); + + expectedBuf = Unpooled.wrappedBuffer(expected); + + assertEquals(0, ByteBufUtil.compare(expectedBuf, mxRecord.content()), + "The rdata of MX-type record should be decompressed in advance"); + assertEquals(10, mxRecord.content().getUnsignedShort(0)); + + ByteBuf exchangerName = mxRecord.content().duplicate().setIndex(2, mxRecord.content().writerIndex()); + assertEquals("netty.io.", DnsCodecUtil.decodeDomainName(exchangerName)); + } finally { + buffer.release(); + if (expectedBuf != null) { + expectedBuf.release(); + } + if (mxRecord != null) { + mxRecord.release(); + } + } + } + @Test public void testDecodeMessageCompression() throws Exception { // See https://www.ietf.org/rfc/rfc1035 [4.1.4. Message compression] diff --git a/codec-dns/src/test/java/io/netty/handler/codec/dns/DnsCodecUtilTest.java b/codec-dns/src/test/java/io/netty/handler/codec/dns/DnsCodecUtilTest.java new file mode 100644 index 00000000000..5d7ddcc107b --- /dev/null +++ b/codec-dns/src/test/java/io/netty/handler/codec/dns/DnsCodecUtilTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.dns; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.TooLongFrameException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DnsCodecUtilTest { + + @Test + void rejectTooLongLabelWhileDecoding() { + final ByteBuf buf = Unpooled.buffer(256); + // 63 is the maximum label length + writeLabel(buf, 64); + writeLabel(buf, 3); + buf.writeByte(0); + + assertThrows(TooLongFrameException.class, new Executable() { + @Override + public void execute() throws Throwable { + DnsCodecUtil.decodeDomainName(buf); + } + }); + buf.release(); + } + + @Test + void rejectTooLongDomainNameWhileDecoding() { + // 255 is the maximum domain name + final ByteBuf buf = Unpooled.buffer(512); + writeLabel(buf, 50); + writeLabel(buf, 50); + writeLabel(buf, 50); + writeLabel(buf, 50); + writeLabel(buf, 56); + buf.writeByte(0); + + assertThrows(TooLongFrameException.class, new Executable() { + @Override + public void execute() throws Throwable { + DnsCodecUtil.decodeDomainName(buf); + } + }); + buf.release(); + } + + @Test + void rejectTooLongLabelWhileEncoding() { + final ByteBuf buf = Unpooled.buffer(256); + // 63 is the maximum label length + final StringBuilder sb = new StringBuilder(); + appendLabel(sb, 64); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + DnsCodecUtil.encodeDomainName(sb.toString(), buf); + } + }); + buf.release(); + } + + @Test + void rejectEmptyLabelWhileEncoding() { + final ByteBuf buf = Unpooled.buffer(256); + // 63 is the maximum label length + final StringBuilder sb = new StringBuilder(); + appendLabel(sb, 5); + appendLabel(sb, 0); + appendLabel(sb, 5); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + DnsCodecUtil.encodeDomainName(sb.toString(), buf); + } + }); + buf.release(); + } + + @Test + void rejectTooLongDomainNameWhileEncoding() { + final ByteBuf buf = Unpooled.buffer(256); + // 255 is the maximum domain name + final StringBuilder sb = new StringBuilder(); + appendLabel(sb, 50); + appendLabel(sb, 50); + appendLabel(sb, 50); + appendLabel(sb, 50); + appendLabel(sb, 56); + + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + DnsCodecUtil.encodeDomainName(sb.toString(), buf); + } + }); + buf.release(); + } + + private static void writeLabel(ByteBuf buf, int length) { + buf.writeByte(length); + for (int i = 1; i <= length; i++) { + buf.writeByte(i); + } + } + + private static void appendLabel(StringBuilder sb, int length) { + for (int i = 0; i < length; i++) { + sb.append('a'); + } + sb.append('.'); + } +} diff --git a/codec-haproxy/pom.xml b/codec-haproxy/pom.xml index b7a0e7202a1..0ebb62f2d79 100644 --- a/codec-haproxy/pom.xml +++ b/codec-haproxy/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-haproxy diff --git a/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessage.java b/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessage.java index 2bb319ff201..b6c47663fbb 100644 --- a/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessage.java +++ b/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessage.java @@ -18,13 +18,13 @@ import io.netty.buffer.ByteBuf; import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.AddressFamily; import io.netty.util.AbstractReferenceCounted; -import io.netty.util.ByteProcessor; import io.netty.util.CharsetUtil; import io.netty.util.NetUtil; import io.netty.util.ResourceLeakDetector; import io.netty.util.ResourceLeakDetectorFactory; import io.netty.util.ResourceLeakTracker; import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import java.util.ArrayList; @@ -251,15 +251,50 @@ private static List readTlvs(final ByteBuf header) { // In most cases there are less than 4 TLVs available List haProxyTLVs = new ArrayList(4); - do { - haProxyTLVs.add(haProxyTLV); - if (haProxyTLV instanceof HAProxySSLTLV) { - haProxyTLVs.addAll(((HAProxySSLTLV) haProxyTLV).encapsulatedTLVs()); - } - } while ((haProxyTLV = readNextTLV(header, 0)) != null); + try { + do { + haProxyTLVs.add(haProxyTLV); + if (haProxyTLV instanceof HAProxySSLTLV) { + haProxyTLVs.addAll(((HAProxySSLTLV) haProxyTLV).encapsulatedTLVs()); + } + } while ((haProxyTLV = readNextTLV(header, 0)) != null); + } catch (Throwable t) { + // Release all previously read TLVs before rethrowing as otherwise we would leak. + releaseTlvs(haProxyTLVs); + PlatformDependent.throwException(t); + } return haProxyTLVs; } + private static void releaseDeep(List children) { + for (HAProxyTLV child : children) { + child.release(); + if (child instanceof HAProxySSLTLV) { + releaseDeep(((HAProxySSLTLV) child).encapsulatedTLVs()); + } + } + } + + private static void releaseTlvs(List tlvs) { + int skip = 0; + for (HAProxyTLV tlv : tlvs) { + if (skip > 0) { + skip--; + // This TLV is a flattened depth-1 child. If it encapsulates anything (depth-2+), + // those deeper children were NOT flattened, so we must release them recursively. + if (tlv instanceof HAProxySSLTLV) { + releaseDeep(((HAProxySSLTLV) tlv).encapsulatedTLVs()); + } + } else if (tlv instanceof HAProxySSLTLV) { + // This is a top-level (depth-0) SSL TLV. + // Its immediate children (depth-1) were flattened into this list, + // so we must skip them in the outer loop to avoid treating them as top-level TLVs. + skip = ((HAProxySSLTLV) tlv).encapsulatedTLVs().size(); + } + tlv.release(); + } + } + private static HAProxyTLV readNextTLV(final ByteBuf header, int nestingLevel) { if (nestingLevel > MAX_NESTING_LEVEL) { throw new HAProxyProtocolException( @@ -276,7 +311,16 @@ private static HAProxyTLV readNextTLV(final ByteBuf header, int nestingLevel) { final int length = header.readUnsignedShort(); switch (type) { case PP2_TYPE_SSL: - final ByteBuf rawContent = header.retainedSlice(header.readerIndex(), length); + if (length < 5) { + throw new HAProxyProtocolException("TLV length must be at least 5 but was: " + length); + } + if (length > header.readableBytes()) { + throw new HAProxyProtocolException("TLV length must be smaller or equal the readable bytes (" + + header.readableBytes() + ") but was: " + length); + } + // Slice the rawContent but only retain it if we didn't see an error as otherwise we might + // leak. + final ByteBuf rawContent = header.slice(header.readerIndex(), length); final ByteBuf byteBuf = header.readSlice(length); final byte client = byteBuf.readByte(); final int verify = byteBuf.readInt(); @@ -284,17 +328,22 @@ private static HAProxyTLV readNextTLV(final ByteBuf header, int nestingLevel) { if (byteBuf.readableBytes() >= 4) { final List encapsulatedTlvs = new ArrayList(4); - do { - final HAProxyTLV haProxyTLV = readNextTLV(byteBuf, nestingLevel + 1); - if (haProxyTLV == null) { - break; - } - encapsulatedTlvs.add(haProxyTLV); - } while (byteBuf.readableBytes() >= 4); - - return new HAProxySSLTLV(verify, client, encapsulatedTlvs, rawContent); + try { + do { + final HAProxyTLV haProxyTLV = readNextTLV(byteBuf, nestingLevel + 1); + if (haProxyTLV == null) { + break; + } + encapsulatedTlvs.add(haProxyTLV); + } while (byteBuf.readableBytes() >= 4); + } catch (Throwable t) { + releaseTlvs(encapsulatedTlvs); + PlatformDependent.throwException(t); + } + + return new HAProxySSLTLV(verify, client, encapsulatedTlvs, rawContent.retain()); } - return new HAProxySSLTLV(verify, client, Collections.emptyList(), rawContent); + return new HAProxySSLTLV(verify, client, Collections.emptyList(), rawContent.retain()); // If we're not dealing with an SSL Type, we can use the same mechanism case PP2_TYPE_ALPN: case PP2_TYPE_AUTHORITY: @@ -599,9 +648,7 @@ private void tryRecord() { @Override protected void deallocate() { try { - for (HAProxyTLV tlv : tlvs) { - tlv.release(); - } + releaseTlvs(tlvs); } finally { final ResourceLeakTracker leak = this.leak; if (leak != null) { diff --git a/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java b/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java index b00a53e2588..fd47d416b53 100644 --- a/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java +++ b/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java @@ -27,6 +27,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; @@ -764,6 +766,99 @@ public void testV2WithSslTLVs() { assertFalse(ch.finish()); } + @Test + public void testV2WithNestedSslTLVs() { + ch = new EmbeddedChannel(new HAProxyMessageDecoder()); + + // Outer SSL TLV (type=0x20, content=28): + // client(1)=0x05 verify(4)=0 + // Inner SSL TLV (type=0x20, content=13): <-- depth-1 nested SSL + // client(1)=0x01 verify(4)=0 + // PP2_TYPE_SSL_VERSION (type=0x21, len=5): "TLSv1" <-- depth-2 leaf + // PP2_TYPE_SSL_CN (type=0x22, len=4): "LEAF" <-- depth-1 leaf + final byte[] bytes = { + 13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, // v2 signature + 33, 17, // v2|PROXY, TCP4 + 0, 43, // remaining: 12 + 31 + 127, 0, 0, 1, 127, 0, 0, 1, -55, -90, 7, 89, // addresses + ports + 32, 0, 28, // outer SSL: type=0x20, len=28 + 5, 0, 0, 0, 0, // outer: client=0x05, verify=0 + 32, 0, 13, // inner SSL: type=0x20, len=13 + 1, 0, 0, 0, 0, // inner: client=0x01, verify=0 + 33, 0, 5, 84, 76, 83, 118, 49, // SSL_VERSION: "TLSv1" + 34, 0, 4, 76, 69, 65, 70 // SSL_CN: "LEAF" + }; + + int startChannels = ch.pipeline().names().size(); + assertTrue(ch.writeInbound(copiedBuffer(bytes))); + Object msgObj = ch.readInbound(); + assertEquals(startChannels - 1, ch.pipeline().names().size()); + HAProxyMessage msg = (HAProxyMessage) msgObj; + + assertEquals(HAProxyProtocolVersion.V2, msg.protocolVersion()); + assertEquals(HAProxyCommand.PROXY, msg.command()); + assertEquals(HAProxyProxiedProtocol.TCP4, msg.proxiedProtocol()); + assertEquals("127.0.0.1", msg.sourceAddress()); + assertEquals("127.0.0.1", msg.destinationAddress()); + assertEquals(51622, msg.sourcePort()); + assertEquals(1881, msg.destinationPort()); + final List tlvs = msg.tlvs(); + + // Flattened list: [outerSSL, innerSSL, SSL_CN] + // SSL_CN is a direct child of outer, so it is flattened. + // innerSSL is also a direct child of outer, so it is flattened. + // But "TLSv1" (SSL_VERSION) is a child of innerSSL (depth 2) — NOT flattened. + assertEquals(3, tlvs.size()); + final HAProxyTLV firstTlv = tlvs.get(0); + assertEquals(HAProxyTLV.Type.PP2_TYPE_SSL, firstTlv.type()); + final HAProxySSLTLV sslTlv = (HAProxySSLTLV) firstTlv; + assertEquals(0, sslTlv.verify()); + assertTrue(sslTlv.isPP2ClientSSL()); + assertTrue(sslTlv.isPP2ClientCertSess()); + assertFalse(sslTlv.isPP2ClientCertConn()); + + final HAProxyTLV secondTlv = tlvs.get(1); + + assertEquals(HAProxyTLV.Type.PP2_TYPE_SSL, secondTlv.type()); + final HAProxySSLTLV innerSslTlv = (HAProxySSLTLV) secondTlv; + + // The depth-2 leaf: SSL_VERSION "TLSv1" lives inside innerSslTlv + assertEquals(1, innerSslTlv.encapsulatedTLVs().size()); + final HAProxyTLV depth2Leaf = innerSslTlv.encapsulatedTLVs().get(0); + assertEquals(HAProxyTLV.Type.PP2_TYPE_SSL_VERSION, depth2Leaf.type()); + ByteBuf versionBuf = depth2Leaf.content(); + byte[] versionContent = new byte[versionBuf.readableBytes()]; + versionBuf.readBytes(versionContent); + assertArrayEquals("TLSv1".getBytes(CharsetUtil.US_ASCII), versionContent); + + final HAProxyTLV thirdTLV = tlvs.get(2); + assertEquals(HAProxyTLV.Type.PP2_TYPE_SSL_CN, thirdTLV.type()); + ByteBuf thirdContentBuf = thirdTLV.content(); + byte[] thirdContent = new byte[thirdContentBuf.readableBytes()]; + thirdContentBuf.readBytes(thirdContent); + assertArrayEquals("LEAF".getBytes(CharsetUtil.US_ASCII), thirdContent); + + assertTrue(sslTlv.encapsulatedTLVs().contains(secondTlv)); + assertTrue(sslTlv.encapsulatedTLVs().contains(thirdTLV)); + + assertTrue(0 < firstTlv.refCnt()); + assertTrue(0 < secondTlv.refCnt()); + assertTrue(0 < thirdTLV.refCnt()); + assertTrue(0 < depth2Leaf.refCnt()); + assertTrue(msg.release()); + + // The depth-2 leaf TLV must be fully released after message.release(). + // It is a child of the inner SSL TLV (depth 1), but readTlvs() only flattens + // one level of encapsulated TLVs. + assertEquals(0, depth2Leaf.refCnt(), "Depth-2 leaf TLV leaked"); + assertEquals(0, firstTlv.refCnt()); + assertEquals(0, secondTlv.refCnt()); + assertEquals(0, thirdTLV.refCnt()); + + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + @Test public void testReleaseHAProxyMessage() { ch = new EmbeddedChannel(new HAProxyMessageDecoder()); @@ -1256,4 +1351,123 @@ public void execute() { } }); } + + @ParameterizedTest + @ValueSource(shorts = { + 4, // Use a length which is < 5. + Short.MAX_VALUE // Use a length which is > readable bytes. + }) + public void testInvalidTLVLengthCorrectlyHandled(short length) throws Exception { + ByteArrayOutputStream headerWriter = new ByteArrayOutputStream(); + //src_ip = "AAAA", dst_ip = "BBBB", src_port = "CC", dst_port = "DD" + headerWriter.write(new byte[] {'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'C', 'C', 'D', 'D'}); + //write TLV + ByteBuffer tlvLengthBuf = ByteBuffer.allocate(2); + tlvLengthBuf.order(ByteOrder.BIG_ENDIAN); + //write PP2_TYPE_SSL TLV + headerWriter.write(0x20); //PP2_TYPE_SSL + //notice that the TLV length cannot be bigger than 0xffff + tlvLengthBuf.clear(); + tlvLengthBuf.putShort(length); + //add to the header + headerWriter.write(tlvLengthBuf.array()); + //write client field + headerWriter.write(1); + //write verify field + headerWriter.write(new byte[] {'V', 'V', 'V', 'V'}); + //subtract the client and verify fields + + byte[] header = headerWriter.toByteArray(); + ByteBuffer numsWrite = ByteBuffer.allocate(2); + numsWrite.order(ByteOrder.BIG_ENDIAN); + numsWrite.putShort((short) header.length); + + final ByteBuf data = Unpooled.buffer(); + data.writeBytes(new byte[] { + (byte) 0x0D, + (byte) 0x0A, + (byte) 0x0D, + (byte) 0x0A, + (byte) 0x00, + (byte) 0x0D, + (byte) 0x0A, + (byte) 0x51, + (byte) 0x55, + (byte) 0x49, + (byte) 0x54, + (byte) 0x0A + }); + //verCmd = 32 + byte versionCmd = 0x20 | 1; //V2 | ProxyCmd + data.writeByte(versionCmd); + data.writeByte(17); //TPAF_TCP4_BYTE + data.writeBytes(numsWrite.array()); + data.writeBytes(header); + + assertThrows(HAProxyProtocolException.class, new Executable() { + @Override + public void execute() { + ch.writeInbound(data); + } + }); + } + + @Test + public void testReadTlvsLeaksRetainedBufferWhenSecondSSLTLVIsMalformed() { + final ByteBuf data = Unpooled.buffer(); + data.writeBytes(new byte[] { + 13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, // v2 signature + 33, 17, // V2|PROXY, TCP4 + 0, 26, // remaining = 26 (12 addr + 8 TLV#1 + 6 TLV#2) + 65, 65, 65, 65, 66, 66, 66, 66, 67, 67, 68, 68, // addr + ports + 32, 0, 5, 1, 0, 0, 0, 0, // TLV #1: PP2_TYPE_SSL len=5, client=1, verify=0 + 32, 0, 3, 65, 66, 67 // TLV #2: PP2_TYPE_SSL len=3 (MALFORMED: len < 5) + }); + + assertEquals(1, data.refCnt()); + assertThrows(HAProxyProtocolException.class, new Executable() { + @Override + public void execute() throws Throwable { + HAProxyMessage.decodeHeader(data); + } + }); + + try { + assertEquals(1, data.refCnt(), + "TLV #1 rawContent leaked in readTlvs() - expected refCnt=1, got " + data.refCnt()); + } finally { + data.release(); + } + } + + @Test + public void testEncapsulatedTLVsLeakWhenInnerSSLTLVIsMalformed() { + final ByteBuf data = Unpooled.buffer(); + data.writeBytes(new byte[] { + 13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, // v2 signature + 33, 17, // V2|PROXY, TCP4 + 0, 34, // remaining = 34 (12 addr + 22 outer SSL TLV) + 65, 65, 65, 65, 66, 66, 66, 66, 67, 67, 68, 68, // addr + ports + 32, 0, 19, // outer SSL: type=0x20, len=19 + 5, 0, 0, 0, 0, // outer: client=0x05, verify=0 + 33, 0, 5, 84, 76, 83, 118, 49, // inner SSL_VERSION: "TLSv1" (readRetainedSlice) + 32, 0, 3, 65, 66, 67 // inner SSL: len=3 (MALFORMED) → throws + }); + + assertEquals(1, data.refCnt()); + assertThrows(HAProxyProtocolException.class, new Executable() { + @Override + public void execute() throws Throwable { + HAProxyMessage.decodeHeader(data); + } + }); + + try { + assertEquals(1, data.refCnt(), + "Inner PP2_TYPE_SSL_VERSION buffer leaked in encapsulated TLV loop - " + + "expected refCnt=1, got " + data.refCnt()); + } finally { + data.release(); + } + } } diff --git a/codec-http/pom.xml b/codec-http/pom.xml index 9d37046e74e..14fd514d8ee 100644 --- a/codec-http/pom.xml +++ b/codec-http/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-http diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/ContentLengthNotAllowedException.java b/codec-http/src/main/java/io/netty/handler/codec/http/ContentLengthNotAllowedException.java new file mode 100644 index 00000000000..a63952ee53a --- /dev/null +++ b/codec-http/src/main/java/io/netty/handler/codec/http/ContentLengthNotAllowedException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.DecoderException; + +/** + * Thrown by {@link HttpObjectDecoder#handleTransferEncodingChunkedWithContentLength(HttpMessage)} by default. + *

+ * The HTTP/1.1 specification, RFC 9112, disallow senders from including both {@code Tranfer-Encoding} and + * {@code Content-Length headers in the same message, and permits servers to reject such requests. + */ +public final class ContentLengthNotAllowedException extends DecoderException { + /** + * Create a new instance with the given message. + * @param message The exception message. + */ + public ContentLengthNotAllowedException(String message) { + super(message); + } +} diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java index 3cd8d0c6985..a4762516846 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java @@ -92,7 +92,15 @@ public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String */ public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, ByteBuf content, HttpHeaders headers, HttpHeaders trailingHeader) { - super(httpVersion, method, uri, headers); + this(httpVersion, method, uri, content, headers, trailingHeader, true); + } + + /** + * Create a full HTTP response with the given HTTP version, method, URI, contents, and header and trailer objects. + */ + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, + ByteBuf content, HttpHeaders headers, HttpHeaders trailingHeader, boolean validateRequestLine) { + super(httpVersion, method, uri, headers, validateRequestLine); this.content = checkNotNull(content, "content"); this.trailingHeader = checkNotNull(trailingHeader, "trailingHeader"); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java index 271b6069a02..32c241f2810 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java @@ -25,6 +25,7 @@ public class DefaultHttpRequest extends DefaultHttpMessage implements HttpReques private static final int HASH_CODE_PRIME = 31; private HttpMethod method; private String uri; + private final boolean validateRequestLine; /** * Creates a new instance. @@ -75,9 +76,26 @@ public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri * @param headers the Headers for this Request */ public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, HttpHeaders headers) { + this(httpVersion, method, uri, headers, true); + } + + /** + * Creates a new instance. + * + * @param httpVersion the HTTP version of the request + * @param method the HTTP method of the request + * @param uri the URI or path of the request + * @param headers the Headers for this Request + */ + public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, HttpHeaders headers, + boolean validateRequestLine) { super(httpVersion, headers); this.method = checkNotNull(method, "method"); this.uri = checkNotNull(uri, "uri"); + this.validateRequestLine = validateRequestLine; + if (validateRequestLine) { + HttpUtil.validateRequestLineTokens(method, uri); + } } @Override @@ -104,13 +122,21 @@ public String uri() { @Override public HttpRequest setMethod(HttpMethod method) { - this.method = checkNotNull(method, "method"); + checkNotNull(method, "method"); + if (validateRequestLine) { + HttpUtil.validateRequestLineTokens(method, uri); + } + this.method = method; return this; } @Override public HttpRequest setUri(String uri) { - this.uri = checkNotNull(uri, "uri"); + checkNotNull(uri, "uri"); + if (validateRequestLine) { + HttpUtil.validateRequestLineTokens(method, uri); + } + this.uri = uri; return this; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkLineValidatingByteProcessor.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkLineValidatingByteProcessor.java new file mode 100644 index 00000000000..ddd5b71ea19 --- /dev/null +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkLineValidatingByteProcessor.java @@ -0,0 +1,178 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.ByteProcessor; + +import java.util.BitSet; + +/** + * Validates the chunk start line. That is, the chunk size and chunk extensions, until the CR LF pair. + * See RFC 9112 section 7.1. + * + *

{@code
+ *   chunked-body   = *chunk
+ *                    last-chunk
+ *                    trailer-section
+ *                    CRLF
+ *
+ *   chunk          = chunk-size [ chunk-ext ] CRLF
+ *                    chunk-data CRLF
+ *   chunk-size     = 1*HEXDIG
+ *   last-chunk     = 1*("0") [ chunk-ext ] CRLF
+ *
+ *   chunk-data     = 1*OCTET ; a sequence of chunk-size octets
+ *   chunk-ext      = *( BWS ";" BWS chunk-ext-name
+ *                       [ BWS "=" BWS chunk-ext-val ] )
+ *
+ *   chunk-ext-name = token
+ *   chunk-ext-val  = token / quoted-string
+ *   quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
+ *   qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
+ *   quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
+ *   obs-text       = %x80-FF
+ *   OWS            = *( SP / HTAB )
+ *                  ; optional whitespace
+ *   BWS            = OWS
+ *                  ; "bad" whitespace
+ *   VCHAR          =  %x21-7E
+ *                  ; visible (printing) characters
+ * }
+ */ +final class HttpChunkLineValidatingByteProcessor implements ByteProcessor { + private static final int SIZE = 0; + private static final int CHUNK_EXT_NAME = 1; + private static final int CHUNK_EXT_VAL_START = 2; + private static final int CHUNK_EXT_VAL_QUOTED = 3; + private static final int CHUNK_EXT_VAL_QUOTED_ESCAPE = 4; + private static final int CHUNK_EXT_VAL_QUOTED_END = 5; + private static final int CHUNK_EXT_VAL_TOKEN = 6; + + static final class Match extends BitSet { + private static final long serialVersionUID = 49522994383099834L; + private final int then; + + Match(int then) { + super(256); + this.then = then; + } + + Match chars(String chars) { + return chars(chars, true); + } + + Match chars(String chars, boolean value) { + for (int i = 0, len = chars.length(); i < len; i++) { + set(chars.charAt(i), value); + } + return this; + } + + Match range(int from, int to) { + return range(from, to, true); + } + + Match range(int from, int to, boolean value) { + for (int i = from; i <= to; i++) { + set(i, value); + } + return this; + } + } + + private enum State { + Size( + new Match(SIZE).chars("0123456789abcdefABCDEF \t"), + new Match(CHUNK_EXT_NAME).chars(";")), + ChunkExtName( + new Match(CHUNK_EXT_NAME) + .range(0x21, 0x7E) + .chars(" \t") + .chars("(),/:<=>?@[\\]{}", false), + new Match(CHUNK_EXT_VAL_START).chars("=")), + ChunkExtValStart( + new Match(CHUNK_EXT_VAL_START).chars(" \t"), + new Match(CHUNK_EXT_VAL_QUOTED).chars("\""), + new Match(CHUNK_EXT_VAL_TOKEN) + .range(0x21, 0x7E) + .chars("(),/:<=>?@[\\]{}\"", false)), + ChunkExtValQuoted( + new Match(CHUNK_EXT_VAL_QUOTED_ESCAPE).chars("\\"), + new Match(CHUNK_EXT_VAL_QUOTED_END).chars("\""), + new Match(CHUNK_EXT_VAL_QUOTED) + .chars("\t !") + .range(0x23, 0x5B) + .range(0x5D, 0x7E) + .range(0x80, 0xFF)), + ChunkExtValQuotedEscape( + new Match(CHUNK_EXT_VAL_QUOTED) + .chars("\t ") + .range(0x21, 0x7E) + .range(0x80, 0xFF)), + ChunkExtValQuotedEnd( + new Match(CHUNK_EXT_VAL_QUOTED_END).chars("\t "), + new Match(CHUNK_EXT_NAME).chars(";")), + ChunkExtValToken( + new Match(CHUNK_EXT_VAL_TOKEN) + .range(0x21, 0x7E, true) + .chars("(),/:<=>?@[\\]{};", false), + new Match(CHUNK_EXT_NAME).chars(";")), + ; + + private final Match[] matches; + + State(Match... matches) { + this.matches = matches; + } + + State match(byte value) { + for (Match match : matches) { + if (match.get(value)) { + return STATES_BY_ORDINAL[match.then]; + } + } + if (this == Size) { + throw new NumberFormatException("Invalid chunk size"); + } else { + throw new InvalidChunkExtensionException("Invalid chunk extension"); + } + } + } + + private static final State[] STATES_BY_ORDINAL = State.values(); + + private State state = State.Size; + + @Override + public boolean process(byte value) { + state = state.match(value); + return true; + } + + public void finish() { + switch (state) { + case ChunkExtValQuoted: + case ChunkExtValQuotedEscape: + case ChunkExtValStart: + throw new InvalidChunkExtensionException("Invalid chunk extension"); + } + // Exhaustiveness check + assert state == State.Size || + state == State.ChunkExtName || + state == State.ChunkExtValQuotedEnd || + state == State.ChunkExtValToken; + } +} diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpClientCodec.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpClientCodec.java index b34122779b0..89db71f8aef 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpClientCodec.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpClientCodec.java @@ -340,13 +340,6 @@ private void decrement(Object msg) { @Override protected boolean isContentAlwaysEmpty(HttpMessage msg) { - // Get the method of the HTTP request that corresponds to the - // current response. - // - // Even if we do not use the method to compare we still need to poll it to ensure we keep - // request / response pairs in sync. - HttpMethod method = queue.poll(); - final HttpResponseStatus status = ((HttpResponse) msg).status(); final HttpStatusClass statusClass = status.codeClass(); final int statusCode = status.code(); @@ -356,6 +349,10 @@ protected boolean isContentAlwaysEmpty(HttpMessage msg) { return super.isContentAlwaysEmpty(msg); } + // Get the method of the HTTP request that corresponds to the + // current response. + HttpMethod method = queue.poll(); + // If the remote peer did for example send multiple responses for one request (which is not allowed per // spec but may still be possible) method will be null so guard against it. if (method != null) { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecompressor.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecompressor.java index 44e6195332d..f6fde488627 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecompressor.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecompressor.java @@ -104,7 +104,7 @@ protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Excep } if (Brotli.isAvailable() && BR.contentEqualsIgnoreCase(contentEncoding)) { return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), - ctx.channel().config(), new BrotliDecoder()); + ctx.channel().config(), new BrotliDecoder(maxAllocation)); } if (SNAPPY.contentEqualsIgnoreCase(contentEncoding)) { @@ -114,7 +114,7 @@ protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Excep if (Zstd.isAvailable() && ZSTD.contentEqualsIgnoreCase(contentEncoding)) { return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), - ctx.channel().config(), new ZstdDecoder()); + ctx.channel().config(), new ZstdDecoder(maxAllocation)); } // 'identity' or unsupported diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpDecoderConfig.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpDecoderConfig.java index 4d80801482e..26574352681 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpDecoderConfig.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpDecoderConfig.java @@ -35,6 +35,7 @@ public final class HttpDecoderConfig implements Cloneable { private int maxHeaderSize = HttpObjectDecoder.DEFAULT_MAX_HEADER_SIZE; private int initialBufferSize = HttpObjectDecoder.DEFAULT_INITIAL_BUFFER_SIZE; private boolean strictLineParsing = HttpObjectDecoder.DEFAULT_STRICT_LINE_PARSING; + private boolean useRfc9112TransferEncoding = HttpObjectDecoder.RFC9112_TRANSFER_ENCODING; public int getInitialBufferSize() { return initialBufferSize; @@ -231,13 +232,16 @@ public boolean isStrictLineParsing() { * security vulnerabilities, when multiple systems disagree on the meaning of leniently parsed messages. *

* When strict line parsing is enabled ({@code true}), then Netty will enforce that start- and header - * field-lines MUST be separated by a CR LF octet pair, and will produce messagas with failed + * field-lines MUST be separated by a CR LF octet pair, and will produce messages with failed * {@link io.netty.handler.codec.DecoderResult}s. + * Additionally, Netty will enforce that only CR LF characters precede the initial line, if any. *

* When strict line parsing is disabled ({@code false}), then Netty will accept lone LF octets as line - * seperators for the start- and header field-lines. + * separators for the start- and header field-lines. + * Additionally, Netty will ignore any ISO control and line separator characters prior to the initial line. *

- * See RFC 9112 Section 2.1. + * See RFC 9112 Section 2.1 and + * RFC 9112 Section 2.2. * @param strictLineParsing Whether strict line parsing should be enabled ({@code true}), * or not ({@code false}). * @return This decoder config. @@ -247,6 +251,28 @@ public HttpDecoderConfig setStrictLineParsing(boolean strictLineParsing) { return this; } + public boolean isUseRfc9112TransferEncoding() { + return useRfc9112TransferEncoding; + } + + /** + * The RFC 9112 specification is more strict than RFC 7230 with regards to having {@code Transfer-Encoding} and + * {@code Content-Length} headers in the same HTTP message. Senders are now forbidden from including both headers + * in the same message, while servers may reject such requests. When this setting is set to {@code true}, which + * is the default, then such messages will be rejected. + *

+ * When this setting is set to {@code false}, it restores the RFC 7230 behavior of instead removing any + * {@code Content-Length} headers when {@code Transfer-Encoding} headers are present. + * @param useRfc9112TransferEncoding Whether to reject messages with both {@code Transfer-Encoding} and + * {@code Content-Length} headers. + * @return This decoder config. + * @see HttpObjectDecoder#handleTransferEncodingChunkedWithContentLength(HttpMessage) + */ + public HttpDecoderConfig setUseRfc9112TransferEncoding(boolean useRfc9112TransferEncoding) { + this.useRfc9112TransferEncoding = useRfc9112TransferEncoding; + return this; + } + @Override public HttpDecoderConfig clone() { try { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java index 6b09c1614b3..70a52848c2d 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java @@ -27,14 +27,34 @@ public final class HttpHeaderValues { */ public static final AsciiString APPLICATION_JSON = AsciiString.cached("application/json"); /** - * {@code "application/x-www-form-urlencoded"} + * {@code "application/manifest+json"} */ - public static final AsciiString APPLICATION_X_WWW_FORM_URLENCODED = - AsciiString.cached("application/x-www-form-urlencoded"); + public static final AsciiString APPLICATION_MANIFEST_JSON = AsciiString.cached("application/manifest+json"); /** * {@code "application/octet-stream"} */ public static final AsciiString APPLICATION_OCTET_STREAM = AsciiString.cached("application/octet-stream"); + /** + * {@code "application/ogg"} + */ + public static final AsciiString APPLICATION_OGG = AsciiString.cached("application/ogg"); + /** + * {@code "application/pdf"} + */ + public static final AsciiString APPLICATION_PDF = AsciiString.cached("application/pdf"); + /** + * {@code "application/rtf"} + */ + public static final AsciiString APPLICATION_RTF = AsciiString.cached("application/rtf"); + /** + * {@code "application/wasm"} + */ + public static final AsciiString APPLICATION_WASM = AsciiString.cached("application/wasm"); + /** + * {@code "application/x-www-form-urlencoded"} + */ + public static final AsciiString APPLICATION_X_WWW_FORM_URLENCODED = + AsciiString.cached("application/x-www-form-urlencoded"); /** * {@code "application/xhtml+xml"} */ @@ -52,6 +72,34 @@ public final class HttpHeaderValues { * See {@link HttpHeaderNames#CONTENT_DISPOSITION} */ public static final AsciiString ATTACHMENT = AsciiString.cached("attachment"); + /** + * {@code "audio/aac"} + */ + public static final AsciiString AUDIO_AAC = AsciiString.cached("audio/aac"); + /** + * {@code "audio/midi"} + */ + public static final AsciiString AUDIO_MIDI = AsciiString.cached("audio/midi"); + /** + * {@code "audio/x-midi"} + */ + public static final AsciiString AUDIO_X_MIDI = AsciiString.cached("audio/x-midi"); + /** + * {@code "audio/mpeg"} + */ + public static final AsciiString AUDIO_MPEG = AsciiString.cached("audio/mpeg"); + /** + * {@code "audio/ogg"} + */ + public static final AsciiString AUDIO_OGG = AsciiString.cached("audio/ogg"); + /** + * {@code "audio/wav"} + */ + public static final AsciiString AUDIO_WAV = AsciiString.cached("audio/wav"); + /** + * {@code "audio/webm"} + */ + public static final AsciiString AUDIO_WEBM = AsciiString.cached("audio/webm"); /** * {@code "base64"} */ @@ -106,6 +154,22 @@ public final class HttpHeaderValues { * See {@link HttpHeaderNames#CONTENT_DISPOSITION} */ public static final AsciiString FILENAME = AsciiString.cached("filename"); + /** + * {@code "font/otf"} + */ + public static final AsciiString FONT_OTF = AsciiString.cached("font/otf"); + /** + * {@code "font/ttf"} + */ + public static final AsciiString FONT_TTF = AsciiString.cached("font/ttf"); + /** + * {@code "font/woff"} + */ + public static final AsciiString FONT_WOFF = AsciiString.cached("font/woff"); + /** + * {@code "font/woff2"} + */ + public static final AsciiString FONT_WOFF2 = AsciiString.cached("font/woff2"); /** * {@code "form-data"} * See {@link HttpHeaderNames#CONTENT_DISPOSITION} @@ -141,6 +205,34 @@ public final class HttpHeaderValues { * {@code "identity"} */ public static final AsciiString IDENTITY = AsciiString.cached("identity"); + /** + * {@code "image/avif"} + */ + public static final AsciiString IMAGE_AVIF = AsciiString.cached("image/avif"); + /** + * {@code "image/bmp"} + */ + public static final AsciiString IMAGE_BMP = AsciiString.cached("image/bmp"); + /** + * {@code "image/jpeg"} + */ + public static final AsciiString IMAGE_JPEG = AsciiString.cached("image/jpeg"); + /** + * {@code "image/png"} + */ + public static final AsciiString IMAGE_PNG = AsciiString.cached("image/png"); + /** + * {@code "image/svg+xml"} + */ + public static final AsciiString IMAGE_SVG_XML = AsciiString.cached("image/svg+xml"); + /** + * {@code "image/tiff"} + */ + public static final AsciiString IMAGE_TIFF = AsciiString.cached("image/tiff"); + /** + * {@code "image/webp"} + */ + public static final AsciiString IMAGE_WEBP = AsciiString.cached("image/webp"); /** * {@code "keep-alive"} */ @@ -222,10 +314,22 @@ public final class HttpHeaderValues { * {@code "text/css"} */ public static final AsciiString TEXT_CSS = AsciiString.cached("text/css"); + /** + * {@code "text/csv"} + */ + public static final AsciiString TEXT_CSV = AsciiString.cached("text/csv"); /** * {@code "text/html"} */ public static final AsciiString TEXT_HTML = AsciiString.cached("text/html"); + /** + * {@code "text/javascript"} + */ + public static final AsciiString TEXT_JAVASCRIPT = AsciiString.cached("text/javascript"); + /** + * {@code "text/markdown"} + */ + public static final AsciiString TEXT_MARKDOWN = AsciiString.cached("text/markdown"); /** * {@code "text/event-stream"} */ @@ -242,6 +346,22 @@ public final class HttpHeaderValues { * {@code "upgrade"} */ public static final AsciiString UPGRADE = AsciiString.cached("upgrade"); + /** + * {@code "video/mp4"} + */ + public static final AsciiString VIDEO_MP4 = AsciiString.cached("video/mp4"); + /** + * {@code "video/mpeg"} + */ + public static final AsciiString VIDEO_MPEG = AsciiString.cached("video/mpeg"); + /** + * {@code "video/ogg"} + */ + public static final AsciiString VIDEO_OGG = AsciiString.cached("video/ogg"); + /** + * {@code "video/webm"} + */ + public static final AsciiString VIDEO_WEBM = AsciiString.cached("video/webm"); /** * {@code "websocket"} */ diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java index 1efd2c58b77..8aa4e5cea31 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java @@ -158,14 +158,14 @@ protected boolean isContentLengthInvalid(HttpMessage start, int maxContentLength } } - private static Object continueResponse(HttpMessage start, int maxContentLength, ChannelPipeline pipeline) { + private Object continueResponse(HttpMessage start, int maxContentLength, ChannelPipeline pipeline) { if (HttpUtil.isUnsupportedExpectation(start)) { // if the request contains an unsupported expectation, we return 417 pipeline.fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE); return EXPECTATION_FAILED.retainedDuplicate(); } else if (HttpUtil.is100ContinueExpected(start)) { // if the request contains 100-continue but the content-length is too large, we return 413 - if (getContentLength(start, -1L) <= maxContentLength) { + if (!isContentLengthInvalid(start, maxContentLength)) { return CONTINUE.retainedDuplicate(); } pipeline.fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE); @@ -247,7 +247,8 @@ protected void handleOversizedMessage(final ChannelHandlerContext ctx, HttpMessa // If the client started to send data already, close because it's impossible to recover. // If keep-alive is off and 'Expect: 100-continue' is missing, no need to leave the connection open. - if (oversized instanceof FullHttpMessage || + // If auto read is false the channel must be closed or it will be stuck without a call to read() + if (oversized instanceof FullHttpMessage || !ctx.channel().config().isAutoRead() || !HttpUtil.is100ContinueExpected(oversized) && !HttpUtil.isKeepAlive(oversized)) { ChannelFuture future = ctx.writeAndFlush(TOO_LARGE_CLOSE.retainedDuplicate()); future.addListener(new ChannelFutureListener() { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java index 2f0d6c4fd72..64018fa71f4 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java @@ -27,7 +27,9 @@ import io.netty.util.ByteProcessor; import io.netty.util.internal.StringUtil; import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.ThrowableUtil; +import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; @@ -154,6 +156,9 @@ public abstract class HttpObjectDecoder extends ByteToMessageDecoder { public static final boolean DEFAULT_ALLOW_DUPLICATE_CONTENT_LENGTHS = false; public static final boolean DEFAULT_STRICT_LINE_PARSING = SystemPropertyUtil.getBoolean("io.netty.handler.codec.http.defaultStrictLineParsing", true); + public static final String PROP_RFC9112_TRANSFER_ENCODING = "io.netty.handler.codec.http.rfc9112TransferEncoding"; + public static final boolean RFC9112_TRANSFER_ENCODING = + SystemPropertyUtil.getBoolean(PROP_RFC9112_TRANSFER_ENCODING, true); private static final Runnable THROW_INVALID_CHUNK_EXTENSION = new Runnable() { @Override @@ -168,6 +173,12 @@ public void run() { throw new InvalidLineSeparatorException(); } }; + private static final TransferEncodingNotAllowedException TRANSFER_ENCODING_NOT_ALLOWED = + ThrowableUtil.unknownStackTrace( + new TransferEncodingNotAllowedException( + "The Transfer-Encoding header is only allowed in HTTP/1.1 or newer"), + HttpObjectDecoder.class, + "readHeaders(ByteBuf)"); private final int maxChunkSize; private final boolean chunkedSupported; @@ -180,6 +191,7 @@ public void run() { protected final HttpHeadersFactory headersFactory; protected final HttpHeadersFactory trailersFactory; private final boolean allowDuplicateContentLengths; + private final boolean useRfc9112TransferEncoding; private final ByteBuf parserScratchBuffer; private final Runnable defaultStrictCRLFCheck; private final HeaderParser headerParser; @@ -212,6 +224,7 @@ protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { * Internal use only. */ private enum State { + SKIP_INITIAL_LINE_CHARS, SKIP_CONTROL_CHARS, READ_INITIAL, READ_HEADER, @@ -225,7 +238,7 @@ private enum State { UPGRADED } - private State currentState = State.SKIP_CONTROL_CHARS; + private State currentState = State.SKIP_INITIAL_LINE_CHARS; /** * Creates a new instance with the default @@ -344,6 +357,7 @@ protected HttpObjectDecoder(HttpDecoderConfig config) { validateHeaders = isValidating(headersFactory); allowDuplicateContentLengths = config.isAllowDuplicateContentLengths(); allowPartialChunks = config.isAllowPartialChunks(); + useRfc9112TransferEncoding = config.isUseRfc9112TransferEncoding(); } protected boolean isValidating(HttpHeadersFactory headersFactory) { @@ -361,7 +375,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List ou } switch (currentState) { - case SKIP_CONTROL_CHARS: + case SKIP_INITIAL_LINE_CHARS: // Fall-through case READ_INITIAL: try { ByteBuf line = lineParser.parse(buffer, defaultStrictCRLFCheck); @@ -477,6 +491,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List ou if (line == null) { return; } + checkChunkExtensions(line); int chunkSize = getChunkSize(line.array(), line.arrayOffset() + line.readerIndex(), line.readableBytes()); this.chunkSize = chunkSize; if (chunkSize == 0) { @@ -604,6 +619,7 @@ protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List ou resetNow(); return; case SKIP_CONTROL_CHARS: // fall-trough + case SKIP_INITIAL_LINE_CHARS: // fall-trough case READ_INITIAL:// fall-trough case BAD_MESSAGE: // fall-trough case UPGRADED: // fall-trough @@ -691,7 +707,7 @@ private void resetNow() { message = null; name = null; value = null; - contentLength = Long.MIN_VALUE; + clearContentLength(); chunked = false; lineParser.reset(); headerParser.reset(); @@ -703,7 +719,7 @@ private void resetNow() { } resetRequested.lazySet(false); - currentState = State.SKIP_CONTROL_CHARS; + currentState = State.SKIP_INITIAL_LINE_CHARS; } private HttpMessage invalidMessage(HttpMessage current, ByteBuf in, Exception cause) { @@ -723,6 +739,16 @@ private HttpMessage invalidMessage(HttpMessage current, ByteBuf in, Exception ca return current; } + private static void checkChunkExtensions(ByteBuf line) { + int extensionsStart = line.bytesBefore((byte) ';'); + if (extensionsStart == -1) { + return; + } + HttpChunkLineValidatingByteProcessor processor = new HttpChunkLineValidatingByteProcessor(); + line.forEachByte(processor); + processor.finish(); + } + private HttpContent invalidChunk(ByteBuf in, Exception cause) { currentState = State.BAD_MESSAGE; message = null; @@ -814,10 +840,36 @@ private State readHeaders(ByteBuf buffer) { HttpUtil.setTransferEncodingChunked(message, false); return State.SKIP_CONTROL_CHARS; } + if (message.headers().contains(HttpHeaderNames.TRANSFER_ENCODING) && + message.protocolVersion() != HttpVersion.HTTP_1_1 && + useRfc9112TransferEncoding) { + // The Transfer-Encoding header is not permitted at all with HTTP protocols older than 1.1, + // and such requests must be rejected. + throw TRANSFER_ENCODING_NOT_ALLOWED; + } if (HttpUtil.isTransferEncodingChunked(message)) { this.chunked = true; - if (!contentLengthFields.isEmpty() && message.protocolVersion() == HttpVersion.HTTP_1_1) { - handleTransferEncodingChunkedWithContentLength(message); + if (message.protocolVersion() == HttpVersion.HTTP_1_1) { + Iterator encodingIt = + message.headers().valueCharSequenceIterator(HttpHeaderNames.TRANSFER_ENCODING); + // Validate that chunked is the last encoding. + // See https://datatracker.ietf.org/doc/html/rfc9112#name-message-body-length + CharSequence v = null; + while (encodingIt.hasNext()) { + v = encodingIt.next(); + } + final int vLen = v.length(); + final int chunkedValueLength = HttpHeaderValues.CHUNKED.length(); + // We only need to validate if we have more then the chunked value length contained as otherwise + // we know it is only chunked. + if (vLen > chunkedValueLength && !AsciiString.regionMatches(v, true, vLen - chunkedValueLength, + HttpHeaderValues.CHUNKED, 0, chunkedValueLength)) { + throw new IllegalArgumentException( + "chunked must be the last encoding present in the Transfer-Encoding header"); + } + if (!contentLengthFields.isEmpty()) { + handleTransferEncodingChunkedWithContentLength(message); + } } return State.READ_CHUNK_SIZE; } @@ -829,27 +881,61 @@ private State readHeaders(ByteBuf buffer) { /** * Invoked when a message with both a "Transfer-Encoding: chunked" and a "Content-Length" header field is detected. - * The default behavior is to remove the Content-Length field, but this method could be overridden - * to change the behavior (to, e.g., throw an exception and produce an invalid message). + * The default behavior is to throw a {@link ContentLengthNotAllowedException} exception, but this method could + * be overridden to change the behavior (to, e.g., remove the {@code Content-Length} header value. *

- * See: https://tools.ietf.org/html/rfc7230#section-3.3.3 + * See: RFC 9112, Section 6.1-15. *

-     *     If a message is received with both a Transfer-Encoding and a
-     *     Content-Length header field, the Transfer-Encoding overrides the
-     *     Content-Length.  Such a message might indicate an attempt to
-     *     perform request smuggling (Section 9.5) or response splitting
-     *     (Section 9.4) and ought to be handled as an error.  A sender MUST
-     *     remove the received Content-Length field prior to forwarding such
-     *     a message downstream.
+     *     A server MAY reject a request that contains both Content-Length and Transfer-Encoding
+     *     or process such a request in accordance with the Transfer-Encoding alone.
+     *     Regardless, the server MUST close the connection after responding to such a request
+     *     to avoid the potential attacks.
      * 
- * Also see: - * https://github.com/apache/tomcat/blob/b693d7c1981fa7f51e58bc8c8e72e3fe80b7b773/ - * java/org/apache/coyote/http11/Http11Processor.java#L747-L755 - * https://github.com/nginx/nginx/blob/0ad4393e30c119d250415cb769e3d8bc8dce5186/ - * src/http/ngx_http_request.c#L1946-L1953 + * Since Netty itself cannot track the request/response pairing, it cannot guarantee that the connection is closed + * immediately after the response is sent. As such, it is safer to immediately reject the request. + *

+ * Note: RFC 7230 (the previous HTTP/1.1 RFC) allowed the {@code Content-Length} header to simply + * be ignored, in the presence of a {@code Transfer-Encoding} header, but this practice is now obsolete + * and considered unsafe. + * The RFC 7230 behavior can be restored in the following ways: + *

    + *
  • + * Process-wide, by setting the {@value PROP_RFC9112_TRANSFER_ENCODING} system property to {@code false}. + *
  • + *
  • + * Configured for a specific decoder, by setting + * {@link HttpDecoderConfig#setUseRfc9112TransferEncoding(boolean)} to {@code false}. + *
  • + *
  • + * Hard-coded for a specific decoder, by overriding this method with an implementation like the following: + *
    {@code
    +     * @Override
    +     * protected void handleTransferEncodingChunkedWithContentLength(HttpMessage message) {
    +     *     clearContentLength();
    +     *     message.headers().remove(HttpHeaderNames.CONTENT_LENGTH);
    +     * }
    +     *         }
    + *
  • + *
+ *

+ * Note: This method is only called for {@code HTTP/1.1} requests. Earlier HTTP protocol versions + * do not support the {@code Transfer-Encoding} header, and will reject requests that include it. */ + @SuppressWarnings("unused") protected void handleTransferEncodingChunkedWithContentLength(HttpMessage message) { - message.headers().remove(HttpHeaderNames.CONTENT_LENGTH); + clearContentLength(); + if (useRfc9112TransferEncoding) { + throw new ContentLengthNotAllowedException( + "Content-Length are not allowed in HTTP/1.1 messages that contains a Transfer-Encoding header."); + } else { + message.headers().remove(HttpHeaderNames.CONTENT_LENGTH); + if (isDecodingRequest()) { + HttpUtil.setKeepAlive(message, false); + } + } + } + + protected final void clearContentLength() { contentLength = Long.MIN_VALUE; } @@ -867,7 +953,6 @@ private LastHttpContent readTrailingHeaders(ByteBuf buffer) { return LastHttpContent.EMPTY_LAST_CONTENT; } - CharSequence lastHeader = null; if (trailer == null) { trailer = this.trailer = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, trailersFactory); } @@ -875,29 +960,19 @@ private LastHttpContent readTrailingHeaders(ByteBuf buffer) { final byte[] lineContent = line.array(); final int startLine = line.arrayOffset() + line.readerIndex(); final byte firstChar = lineContent[startLine]; - if (lastHeader != null && (firstChar == ' ' || firstChar == '\t')) { - List current = trailer.trailingHeaders().getAll(lastHeader); - if (!current.isEmpty()) { - int lastPos = current.size() - 1; - //please do not make one line from below code - //as it breaks +XX:OptimizeStringConcat optimization - String lineTrimmed = langAsciiString(lineContent, startLine, line.readableBytes()).trim(); - String currentLastPos = current.get(lastPos); - current.set(lastPos, currentLastPos + lineTrimmed); - } + if (name != null && (firstChar == ' ' || firstChar == '\t')) { + //please do not make one line from below code + //as it breaks +XX:OptimizeStringConcat optimization + String trimmedLine = langAsciiString(lineContent, startLine, lineLength).trim(); + String valueStr = value; + value = valueStr + ' ' + trimmedLine; } else { - splitHeader(lineContent, startLine, lineLength); - AsciiString headerName = name; - if (!HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(headerName) && - !HttpHeaderNames.TRANSFER_ENCODING.contentEqualsIgnoreCase(headerName) && - !HttpHeaderNames.TRAILER.contentEqualsIgnoreCase(headerName)) { - trailer.trailingHeaders().add(headerName, value); + if (name != null && isPermittedTrailingHeader(name)) { + trailer.trailingHeaders().add(name, value); } - lastHeader = name; - // reset name and value fields - name = null; - value = null; + splitHeader(lineContent, startLine, lineLength); } + line = headerParser.parse(buffer, defaultStrictCRLFCheck); if (line == null) { return null; @@ -905,10 +980,28 @@ private LastHttpContent readTrailingHeaders(ByteBuf buffer) { lineLength = line.readableBytes(); } + // Add the last trailer + if (name != null && isPermittedTrailingHeader(name)) { + trailer.trailingHeaders().add(name, value); + } + + // reset name and value fields + name = null; + value = null; + this.trailer = null; return trailer; } + /** + * Checks whether the given trailer field name is permitted per RFC 9110 section 6.5 + */ + private static boolean isPermittedTrailingHeader(final AsciiString name) { + return !HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(name) && + !HttpHeaderNames.TRANSFER_ENCODING.contentEqualsIgnoreCase(name) && + !HttpHeaderNames.TRAILER.contentEqualsIgnoreCase(name); + } + protected abstract boolean isDecodingRequest(); protected abstract HttpMessage createMessage(String[] initialLine) throws Exception; protected abstract HttpMessage createInvalidMessage(); @@ -926,7 +1019,7 @@ private static int skipWhiteSpaces(byte[] hex, int start, int length) { } private static int getChunkSize(byte[] hex, int start, int length) { - // trim the leading bytes if white spaces, if any + // trim the leading bytes of white spaces, if any final int skipped = skipWhiteSpaces(hex, start, length); if (skipped == length) { // empty case @@ -934,7 +1027,7 @@ private static int getChunkSize(byte[] hex, int start, int length) { } start += skipped; length -= skipped; - int result = 0; + long result = 0; for (int i = 0; i < length; i++) { final int digit = StringUtil.decodeHexNibble(hex[start + i]); if (digit == -1) { @@ -945,18 +1038,18 @@ private static int getChunkSize(byte[] hex, int start, int length) { // empty case throw new NumberFormatException("Empty chunk size"); } - return result; + return (int) result; } // non-hex char fail-fast path throw new NumberFormatException("Invalid character in chunk size"); } result *= 16; result += digit; - if (result < 0) { + if (result > Integer.MAX_VALUE) { throw new NumberFormatException("Chunk size overflow: " + result); } } - return result; + return (int) result; } private String[] splitInitialLine(ByteBuf asciiBuffer) { @@ -1239,26 +1332,33 @@ public ByteBuf parse(ByteBuf buffer, Runnable strictCRLFCheck) { if (readableBytes == 0) { return null; } - if (currentState == State.SKIP_CONTROL_CHARS && - skipControlChars(buffer, readableBytes, buffer.readerIndex())) { + if (currentState == State.SKIP_INITIAL_LINE_CHARS && + skipLineChars(buffer, readableBytes, buffer.readerIndex(), strictCRLFCheck)) { return null; } return super.parse(buffer, strictCRLFCheck); } - private boolean skipControlChars(ByteBuf buffer, int readableBytes, int readerIndex) { - assert currentState == State.SKIP_CONTROL_CHARS; + private boolean skipLineChars(ByteBuf buffer, int readableBytes, int readerIndex, Runnable strictCRLFCheck) { + assert currentState == State.SKIP_INITIAL_LINE_CHARS; final int maxToSkip = Math.min(maxLength, readableBytes); - final int firstNonControlIndex = buffer.forEachByte(readerIndex, maxToSkip, SKIP_CONTROL_CHARS_BYTES); - if (firstNonControlIndex == -1) { + final int firstNonLineIndex = buffer.forEachByte(readerIndex, maxToSkip, + strictCRLFCheck == null ? SKIP_CONTROL_CHARS_BYTES : ByteProcessor.FIND_NON_CRLF); + if (firstNonLineIndex == -1) { buffer.skipBytes(maxToSkip); if (readableBytes > maxLength) { throw newException(maxLength); } return true; } + if (strictCRLFCheck != null) { + final int b = buffer.getByte(firstNonLineIndex) & 0xFF; + if (Character.isISOControl(b)) { + strictCRLFCheck.run(); + } + } // from now on we don't care about control chars - buffer.readerIndex(firstNonControlIndex); + buffer.readerIndex(firstNonLineIndex); currentState = State.READ_INITIAL; return false; } @@ -1279,7 +1379,6 @@ protected TooLongFrameException newException(int maxLength) { } private static final ByteProcessor SKIP_CONTROL_CHARS_BYTES = new ByteProcessor() { - @Override public boolean process(byte value) { return ISO_CONTROL_OR_WHITESPACE[128 + value]; diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java index 55a4e3dae23..9e1526d0a96 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java @@ -16,7 +16,9 @@ package io.netty.handler.codec.http; import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; import io.netty.channel.CombinedChannelDuplexHandler; import java.util.ArrayDeque; @@ -52,6 +54,11 @@ public final class HttpServerCodec extends CombinedChannelDuplexHandler queue = new ArrayDeque(); + /** + * When set, the connection will be closed after the next response is written. + */ + private boolean mustCloseAfterResponse; + /** * Creates a new instance with the default decoder options * ({@code maxInitialLineLength (4096)}, {@code maxHeaderSize (8192)}, and @@ -173,12 +180,27 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List ou } } } + + @Override + protected void handleTransferEncodingChunkedWithContentLength(HttpMessage message) { + super.handleTransferEncodingChunkedWithContentLength(message); + mustCloseAfterResponse = true; + } } private final class HttpServerResponseEncoder extends HttpResponseEncoder { private HttpMethod method; + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (mustCloseAfterResponse && msg instanceof LastHttpContent) { + mustCloseAfterResponse = false; + promise = promise.unvoid().addListener(ChannelFutureListener.CLOSE); + } + super.write(ctx, msg, promise); + } + @Override protected void sanitizeHeadersBeforeEncode(HttpResponse msg, boolean isAlwaysEmpty) { if (!isAlwaysEmpty && HttpMethod.CONNECT.equals(method) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java index 409718628b4..643f79a9757 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java @@ -45,7 +45,7 @@ private HttpUtil() { } /** * Determine if a uri is in origin-form according to - * rfc7230, 5.3. + * RFC 9112, 3.2.1. */ public static boolean isOriginForm(URI uri) { return isOriginForm(uri.toString()); @@ -53,7 +53,7 @@ public static boolean isOriginForm(URI uri) { /** * Determine if a string uri is in origin-form according to - * rfc7230, 5.3. + * RFC 9112, 3.2.1. */ public static boolean isOriginForm(String uri) { return uri.startsWith("/"); @@ -61,7 +61,7 @@ public static boolean isOriginForm(String uri) { /** * Determine if a uri is in asterisk-form according to - * rfc7230, 5.3. + * RFC 9112, 3.2.4. */ public static boolean isAsteriskForm(URI uri) { return isAsteriskForm(uri.toString()); @@ -69,16 +69,59 @@ public static boolean isAsteriskForm(URI uri) { /** * Determine if a string uri is in asterisk-form according to - * rfc7230, 5.3. + * RFC 9112, 3.2.4. */ public static boolean isAsteriskForm(String uri) { return "*".equals(uri); } + static void validateRequestLineTokens(HttpMethod method, String uri) { + // The HttpVersion class does its own validation, and it's not possible for subclasses to circumvent it. + // The HttpMethod class does its own validation, but subclasses might circumvent it. + if (method.getClass() != HttpMethod.class) { + if (!isEncodingSafeStartLineToken(method.asciiName())) { + throw new IllegalArgumentException( + "The HTTP method name contain illegal characters: " + method.asciiName()); + } + } + + if (!isEncodingSafeStartLineToken(uri)) { + throw new IllegalArgumentException("The URI contain illegal characters: " + uri); + } + } + /** - * Returns {@code true} if and only if the connection can remain open and - * thus 'kept alive'. This methods respects the value of the. + * Validate that the given request line token is safe for verbatim encoding to the network. + * This does not fully check that the token – HTTP method, version, or URI – is valid and formatted correctly. + * Only that the token does not contain characters that would break or + * desynchronize HTTP message parsing of the start line wherein the token would be included. + *

+ * See RFC 9112, 3. * + * @param token The token to check. + * @return {@code true} if the token is safe to encode verbatim into the HTTP message output stream, + * otherwise {@code false}. + */ + public static boolean isEncodingSafeStartLineToken(CharSequence token) { + int lenBytes = token.length(); + for (int i = 0; i < lenBytes; i++) { + char ch = token.charAt(i); + // this is to help AOT compiled code which cannot profile the switch + if (ch <= ' ') { + switch (ch) { + case '\n': + case '\r': + case ' ': + return false; + } + } + } + return true; + } + + /** + * Returns {@code true} if and only if the connection can remain open and + * thus 'kept alive'. This method respects the value of the * {@code "Connection"} header first and then the return value of * {@link HttpVersion#isKeepAliveDefault()}. */ @@ -676,8 +719,10 @@ private static int validateAsciiStringToken(AsciiString token) { */ private static int validateCharSequenceToken(CharSequence token) { for (int i = 0, len = token.length(); i < len; i++) { - byte value = (byte) token.charAt(i); - if (!isValidTokenChar(value)) { + int value = token.charAt(i); + // 1. Check for truncation (anything above 255) + // 2. Check against the BitSet (isValidTokenChar handles 128-255 via bit < 0) + if (value > 0xFF || !isValidTokenChar((byte) value)) { return i; } } @@ -761,18 +806,17 @@ private static int validateCharSequenceToken(CharSequence token) { // .bits('-', '.', '_', '~') // Unreserved characters. // .bits('!', '#', '$', '%', '&', '\'', '*', '+', '^', '`', '|'); // Token special characters. - //this constants calculated by the above code + // This constants calculated by the above code private static final long TOKEN_CHARS_HIGH = 0x57ffffffc7fffffeL; private static final long TOKEN_CHARS_LOW = 0x3ff6cfa00000000L; - private static boolean isValidTokenChar(byte bit) { - if (bit < 0) { + static boolean isValidTokenChar(byte octet) { + if (octet < 0) { return false; } - if (bit < 64) { - return 0 != (TOKEN_CHARS_LOW & 1L << bit); + if (octet < 64) { + return 0 != (TOKEN_CHARS_LOW & 1L << octet); } - return 0 != (TOKEN_CHARS_HIGH & 1L << bit - 64); + return 0 != (TOKEN_CHARS_HIGH & 1L << octet - 64); } - } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java index f2af9b2916a..aa41143b566 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java @@ -24,6 +24,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.Locale; /** * The version of HTTP or its derived protocols, such as @@ -124,7 +125,10 @@ public HttpVersion(String text, boolean keepAliveDefault) { } HttpVersion(String text, boolean strict, boolean keepAliveDefault) { - text = checkNonEmptyAfterTrim(text, "text").toUpperCase(); + // toUpperCase() without an explicit Locale uses the JVM default. In Turkish locale + // (tr_TR) 'i' uppercases to 'İ' (U+0130), which would corrupt protocol strings such + // as "icap/1.0" or any custom HTTP-derived scheme that contains a lowercase 'i'. + text = checkNonEmptyAfterTrim(text, "text").toUpperCase(Locale.US); if (strict) { // Only single digit major / minor version is allowed. @@ -181,7 +185,9 @@ public HttpVersion( private HttpVersion( String protocolName, int majorVersion, int minorVersion, boolean keepAliveDefault, boolean bytes) { - protocolName = checkNonEmptyAfterTrim(protocolName, "protocolName").toUpperCase(); + // See the comment in the (text, strict, keepAliveDefault) constructor for why this needs + // an explicit Locale.US: avoids the Turkish-locale 'i' -> 'İ' corruption. + protocolName = checkNonEmptyAfterTrim(protocolName, "protocolName").toUpperCase(Locale.US); for (int i = 0; i < protocolName.length(); i ++) { if (Character.isISOControl(protocolName.charAt(i)) || diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/TransferEncodingNotAllowedException.java b/codec-http/src/main/java/io/netty/handler/codec/http/TransferEncodingNotAllowedException.java new file mode 100644 index 00000000000..8ece189f7d5 --- /dev/null +++ b/codec-http/src/main/java/io/netty/handler/codec/http/TransferEncodingNotAllowedException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.DecoderException; + +/** + * Thrown by {@link HttpObjectDecoder} when an HTTP message uses a protocol version older than {@code HTTP/1.1} + * and includes an {@code Transfer-Encoding} header. + */ +public final class TransferEncodingNotAllowedException extends DecoderException { + /** + * Create a new instance with the given message. + * @param message The exception message. + */ + public TransferEncodingNotAllowedException(String message) { + super(message); + } +} diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java index 75e958c5ad7..45d5e0ffb13 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaders; @@ -29,6 +30,7 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpUtil; +import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -58,6 +60,7 @@ public class CorsHandler extends ChannelDuplexHandler { private HttpRequest request; private final List configList; private final boolean isShortCircuit; + private boolean consumeContent; /** * Creates a new instance with a single {@link CorsConfig}. @@ -87,13 +90,28 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) throw config = getForOrigin(origin); if (isPreflightRequest(request)) { handlePreflight(ctx, request); + // Enable consumeContent so that all following HttpContent + // for this request will be released and not propagated downstream. + consumeContent = true; return; } if (isShortCircuit && !(origin == null || config != null)) { forbidden(ctx, request); + consumeContent = true; return; } + + // This request is forwarded, stop discarding + consumeContent = false; + ctx.fireChannelRead(msg); + return; + } + + if (consumeContent && (msg instanceof HttpContent)) { + ReferenceCountUtil.release(msg); + return; } + ctx.fireChannelRead(msg); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java index e90d4068220..14962c7cd96 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java @@ -41,6 +41,7 @@ import java.nio.charset.UnsupportedCharsetException; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.TreeMap; @@ -919,7 +920,11 @@ protected InterfaceHttpData getFileUpload(String delimiter) { if (encoding != null) { String code; try { - code = encoding.getValue().toLowerCase(); + // RFC 2045 Content-Transfer-Encoding values are case-insensitive ASCII tokens. + // toLowerCase() without a Locale would corrupt them under Turkish (tr_TR) locale, + // where 'I' lowercases to 'ı' (U+0131) and "BINARY" becomes "bınary" - never + // matching the lowercase ASCII constants compared against below. + code = encoding.getValue().toLowerCase(Locale.US); } catch (IOException e) { throw new ErrorDataDecoderException(e); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoder.java index 8fcd4cdf225..28925b6d551 100755 --- a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoder.java @@ -48,6 +48,8 @@ import java.util.ListIterator; import java.util.Map; import java.util.regex.Pattern; +import java.util.Locale; +import java.util.concurrent.ThreadLocalRandom; import static io.netty.buffer.Unpooled.wrappedBuffer; import static io.netty.util.internal.ObjectUtil.checkNotNull; @@ -760,7 +762,12 @@ public HttpRequest finalizeRequest() throws ErrorDataEncoderException { headers.remove(HttpHeaderNames.CONTENT_TYPE); for (String contentType : contentTypes) { // "multipart/form-data; boundary=--89421926422648" - String lowercased = contentType.toLowerCase(); + // toLowerCase() without a Locale would corrupt the comparison under Turkish + // (tr_TR) locale, where 'I' -> 'ı' (U+0131): a request that sets + // Content-Type: MULTIPART/form-data would lowercase to "multıpart/form-data" and + // miss the prefix check, leaving the original header in place alongside the + // multipart one this encoder is about to add. + String lowercased = contentType.toLowerCase(Locale.US); if (lowercased.startsWith(HttpHeaderValues.MULTIPART_FORM_DATA.toString()) || lowercased.startsWith(HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED.toString())) { // ignore diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java index 01f1c0036c1..a898439fe22 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java @@ -21,7 +21,7 @@ import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -53,7 +53,7 @@ public static List extractExtensions(String extensionHea String name = extensionParameters[0].trim(); Map parameters; if (extensionParameters.length > 1) { - parameters = new HashMap(extensionParameters.length - 1); + parameters = new LinkedHashMap(extensionParameters.length - 1); for (int i = 1; i < extensionParameters.length; i++) { String parameter = extensionParameters[i].trim(); Matcher parameterMatcher = PARAMETER.matcher(parameter); @@ -93,7 +93,7 @@ static String computeMergeExtensionsHeaderValue(String userDefinedHeaderValue, extraExtensions.add(userDefined); } else { // merge with higher precedence to user defined parameters - Map mergedParameters = new HashMap(matchingExtra.parameters()); + Map mergedParameters = new LinkedHashMap(matchingExtra.parameters()); mergedParameters.putAll(userDefined.parameters()); extraExtensions.set(i, new WebSocketExtensionData(matchingExtra.name(), mergedParameters)); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java index 944f36e50b4..972b4c9e7ef 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java @@ -232,10 +232,16 @@ public WebSocketClientExtension handshakeExtension(WebSocketExtensionData extens if (CLIENT_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) { // allowed client_window_size_bits if (allowClientWindowSize) { - clientWindowSize = Integer.parseInt(parameter.getValue()); - if (clientWindowSize > MAX_WINDOW_SIZE || clientWindowSize < MIN_WINDOW_SIZE) { - succeed = false; + // RFC 7692: client_max_window_bits may have a value or no value + String value = parameter.getValue(); + if (value != null) { + // Let NumberFormatException bubble up if value is invalid + clientWindowSize = Integer.parseInt(value); + if (clientWindowSize > MAX_WINDOW_SIZE || clientWindowSize < MIN_WINDOW_SIZE) { + succeed = false; + } } + // If value is null, keep MAX_WINDOW_SIZE (default) } else { succeed = false; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java index ce19476f403..9aeb219b142 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java @@ -220,8 +220,18 @@ public WebSocketServerExtension handshakeExtension(WebSocketExtensionData extens Entry parameter = parametersIterator.next(); if (CLIENT_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) { - // use preferred clientWindowSize because client is compatible with customization - clientWindowSize = preferredClientWindowSize; + // RFC 7692: client_max_window_bits may have a value or no value + String value = parameter.getValue(); + if (value != null) { + // Let NumberFormatException bubble up if value is invalid + clientWindowSize = Integer.parseInt(value); + if (clientWindowSize > MAX_WINDOW_SIZE || clientWindowSize < MIN_WINDOW_SIZE) { + deflateEnabled = false; + } + } else { + // No value specified, use preferred client window size + clientWindowSize = preferredClientWindowSize; + } } else if (SERVER_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) { // use provided windowSize if it is allowed if (allowServerWindowSize) { diff --git a/codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspMethods.java b/codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspMethods.java index a0dca7618f8..a546b36225f 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspMethods.java +++ b/codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspMethods.java @@ -20,6 +20,7 @@ import io.netty.handler.codec.http.HttpMethod; import java.util.HashMap; +import java.util.Locale; import java.util.Map; /** @@ -119,7 +120,10 @@ public final class RtspMethods { * will be returned. Otherwise, a new instance will be returned. */ public static HttpMethod valueOf(String name) { - name = checkNonEmptyAfterTrim(name, "name").toUpperCase(); + // RFC 2326 RTSP method names are ASCII tokens. toUpperCase() without an explicit Locale + // uses the JVM default, which in Turkish (tr_TR) maps 'i' to 'İ' (U+0130) and breaks the + // lookup of methods such as "describe" or "redirect" against the cached uppercase keys. + name = checkNonEmptyAfterTrim(name, "name").toUpperCase(Locale.US); HttpMethod result = methodMap.get(name); if (result != null) { return result; diff --git a/codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspVersions.java b/codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspVersions.java index 92831fb6b1f..9789ac310a2 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspVersions.java +++ b/codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspVersions.java @@ -18,6 +18,8 @@ import io.netty.handler.codec.http.HttpVersion; import io.netty.util.internal.ObjectUtil; +import java.util.Locale; + /** * The version of RTSP. */ @@ -37,7 +39,9 @@ public final class RtspVersions { public static HttpVersion valueOf(String text) { ObjectUtil.checkNotNull(text, "text"); - text = text.trim().toUpperCase(); + // toUpperCase() must specify Locale.US so the comparison against "RTSP/1.0" is not + // affected by the JVM default locale (e.g. Turkish, where 'i' uppercases to 'İ'). + text = text.trim().toUpperCase(Locale.US); if ("RTSP/1.0".equals(text)) { return RTSP_1_0; } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java index 9ddb597ae9c..e47ce808644 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java @@ -17,12 +17,267 @@ import io.netty.util.AsciiString; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.Arrays; +import java.util.List; +import java.util.SplittableRandom; +import java.util.function.LongFunction; +import java.util.stream.Stream; import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class DefaultHttpRequestTest { + public static List invalidUris() { + return Arrays.asList( + "http://localhost/\r\n", + "/r\r\n?q=1", + "http://localhost/\r\n?q=1", + "/r\r\n/?q=1", + "http://localhost/\r\n/?q=1", + "/r\r\n", + "http://localhost/ HTTP/1.1\r\n\r\nPOST /p HTTP/1.1\r\n\r\n", + "/r HTTP/1.1\r\n\r\nPOST /p HTTP/1.1\r\n\r\n", + "/ path", + "/path ", + " /path", + "http://localhost/ ", + " http://localhost/", + "http://local host/" + ); + } + + public static List invalidMethods() { + return Arrays.asList( + "GET ", + " GET", + "G ET", + " GET ", + "GET\r", + "GET\n", + "GET\r\n", + "GE\rT", + "GE\nT", + "GE\r\nT", + "\rGET", + "\nGET", + "\r\nGET", + " \r\nGET", + "\r \nGET", + "\r\n GET", + "\r\nGET ", + "\nGET ", + "\rGET ", + "\r GET", + " \rGET", + "\nGET ", + "\n GET", + " \nGET", + "GET \n", + "GET \r", + " GET\r", + " GET\r", + "GET \n", + " GET\n", + " GET\n", + "GE\nT ", + "GE\rT ", + " GE\rT", + " GE\rT", + "GE\nT ", + " GE\nT", + " GE\nT" + ); + } + + public static Stream validUris() { + final String pdigit = "123456789"; + final String digit = '0' + pdigit; + final String digitcolon = digit + ':'; + final String alpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + final String alphanum = alpha + digit; + final String alphanumdot = alphanum + '.'; + final String unreserved = alphanumdot + "-_~"; + final String subdelims = "$&%=!+,;'()"; + final String userinfochars = unreserved + subdelims + ':'; + final String pathchars = unreserved + '/'; + final String querychars = pathchars + subdelims + '?'; + return new SplittableRandom().longs(1000) + .mapToObj(new LongFunction() { + @Override + public String apply(long seed) { + SplittableRandom rng = new SplittableRandom(seed); + String start; + String path; + String query; + String fragment; + if (rng.nextBoolean()) { + String scheme = rng.nextBoolean() ? "http://" : "HTTP://"; + String userinfo = rng.nextBoolean() ? "" : pick(rng, userinfochars, 1, 8) + '@'; + String host; + String port; + switch (rng.nextInt(3)) { + case 0: + host = pick(rng, alphanum, 1, 1) + pick(rng, alphanumdot, 1, 5); + break; + case 1: + host = pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 2) + '.' + + pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 2) + '.' + + pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 2) + '.' + + pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 2); + break; + default: + host = '[' + pick(rng, digitcolon, 1, 8) + ']'; + break; + } + if (rng.nextBoolean()) { + port = ':' + pick(rng, pdigit, 1, 1) + pick(rng, digit, 0, 4); + } else { + port = ""; + } + start = scheme + userinfo + host + port; + } else { + start = ""; + } + path = '/' + pick(rng, pathchars, 0, 8); + if (rng.nextBoolean()) { + query = '?' + pick(rng, querychars, 0, 8); + } else { + query = ""; + } + if (rng.nextBoolean()) { + fragment = '#' + pick(rng, querychars, 0, 8); + } else { + fragment = ""; + } + return start + path + query + fragment; + } + }); + } + + public static List validMethods() { + return Arrays.asList("GET", + "POST", + "PUT", + "HEAD", + "DELETE", + "OPTIONS", + "CONNECT", + "TRACE", + "PATCH", + "QUERY"); + } + + private static String pick(SplittableRandom rng, String cs, int lowerBound, int upperBound) { + int length = rng.nextInt(lowerBound, upperBound + 1); + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append(cs.charAt(rng.nextInt(cs.length()))); + } + return sb.toString(); + } + + @ParameterizedTest + @MethodSource("invalidUris") + void constructorMustRejectIllegalUrisByDefault(final String uri) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + } + }); + } + + @ParameterizedTest + @MethodSource("invalidUris") + void setUriMustRejectIllegalUrisByDefault(final String uri) { + final DefaultHttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + request.setUri(uri); + } + }); + } + + @ParameterizedTest + @MethodSource("validUris") + void constructorMustAcceptValidUris(String uri) { + new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + } + + @ParameterizedTest + @MethodSource("validUris") + void setUriMustAcceptValidUris(String uri) { + new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/").setUri(uri); + } + + @ParameterizedTest + @MethodSource("invalidMethods") + void constructorMustRejectIllegalHttpMethodByDefault(final String method) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + new DefaultHttpRequest(HttpVersion.HTTP_1_0, + new HttpMethod("GET") { + @Override + public AsciiString asciiName() { + return new AsciiString(method); + } + }, "/"); + } + }); + } + + @ParameterizedTest + @MethodSource("invalidMethods") + void setMethodMustRejectIllegalHttpMethodByDefault(final String method) { + final DefaultHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + request.setMethod(new HttpMethod("GET") { + @Override + public AsciiString asciiName() { + return new AsciiString(method); + } + }); + } + }); + } + + @ParameterizedTest + @MethodSource("validMethods") + void constructorMustAcceptAllHttpMethods(final String method) { + new DefaultHttpRequest(HttpVersion.HTTP_1_0, new HttpMethod("GET") { + @Override + public AsciiString asciiName() { + return new AsciiString(method); + } + }, "/"); + + new DefaultHttpRequest(HttpVersion.HTTP_1_0, new HttpMethod(method), "/"); + } + + @ParameterizedTest + @MethodSource("validMethods") + void setMethodMustAcceptAllHttpMethods(final String method) { + DefaultHttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + + request.setMethod(new HttpMethod("GET") { + @Override + public AsciiString asciiName() { + return new AsciiString(method); + } + }); + + request.setMethod(new HttpMethod(method)); + } @Test public void testHeaderRemoval() { diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpClientCodecTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpClientCodecTest.java index fd0037111fd..da674ceef20 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpClientCodecTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpClientCodecTest.java @@ -360,19 +360,22 @@ public void testWebDavResponse() { @Test public void testInformationalResponseKeepsPairsInSync() { - byte[] data = ("HTTP/1.1 102 Processing\r\n" + + String data = "HTTP/1.1 102 Processing\r\n" + "Status-URI: Status-URI:http://status.com; 404\r\n" + - "\r\n").getBytes(); - byte[] data2 = ("HTTP/1.1 200 OK\r\n" + + "\r\n" + + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 5\r\n" + + "\r\n"; // No contents; we're responding to a HEAD request. + String data2 = "HTTP/1.1 200 OK\r\n" + "Content-Length: 8\r\n" + "\r\n" + - "12345678").getBytes(); + "12345678"; EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec()); assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.HEAD, "/"))); ByteBuf buffer = ch.readOutbound(); buffer.release(); assertNull(ch.readOutbound()); - assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data))); + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data.getBytes(CharsetUtil.ISO_8859_1)))); HttpResponse res = ch.readInbound(); assertSame(HttpVersion.HTTP_1_1, res.protocolVersion()); assertEquals(HttpResponseStatus.PROCESSING, res.status()); @@ -381,12 +384,21 @@ public void testInformationalResponseKeepsPairsInSync() { assertEquals(0, content.content().readableBytes()); assertInstanceOf(LastHttpContent.class, content); content.release(); + res = ch.readInbound(); + assertSame(HttpVersion.HTTP_1_1, res.protocolVersion()); + assertEquals(HttpResponseStatus.OK, res.status()); + // If it had not been a HEAD request, server *would* have sent 5 bytes of contents... + assertEquals(5, res.headers().getInt(HttpHeaderNames.CONTENT_LENGTH)); + content = ch.readInbound(); + // ... but it is a HEAD request, so we get zero bytes. + assertEquals(0, content.content().readableBytes()); + content.release(); assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"))); buffer = ch.readOutbound(); buffer.release(); assertNull(ch.readOutbound()); - assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data2))); + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data2.getBytes(CharsetUtil.ISO_8859_1)))); res = ch.readInbound(); assertSame(HttpVersion.HTTP_1_1, res.protocolVersion()); @@ -400,6 +412,51 @@ public void testInformationalResponseKeepsPairsInSync() { assertFalse(ch.finish()); } + @Test + public void testInformationalFollowedByResponse() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpClientCodec()); + + assertTrue(channel.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/1"))); + ByteBuf request = channel.readOutbound(); + request.release(); + assertNull(channel.readOutbound()); + + assertTrue(channel.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.HEAD, "/2"))); + request = channel.readOutbound(); + request.release(); + assertNull(channel.readOutbound()); + + String responseStr = + "HTTP/1.1 103 Early Hints\r\n\r\n" + // Early response to first GET request + "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello" + // Actual response to first GET request + "HTTP/1.1 200 OK\r\n\r\n"; // Body-less response to second HEAD request + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + + // Response 1: Early hints to first GET request. + HttpResponse response = channel.readInbound(); + assertEquals(HttpResponseStatus.EARLY_HINTS, response.status()); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + + // Response 2: Actual response, with contents, to first GET request. + response = channel.readInbound(); + assertEquals(HttpResponseStatus.OK, response.status()); + assertEquals(5, response.headers().getInt(HttpHeaderNames.CONTENT_LENGTH)); + last = channel.readInbound(); + assertEquals(5, last.content().readableBytes()); + last.release(); + + // Response 3: Actual response, with no contents, to second HEAD request. + response = channel.readInbound(); + assertEquals(HttpResponseStatus.OK, response.status()); + last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + + assertFalse(channel.finish()); + } + @Test public void testMultipleResponses() { String response = "HTTP/1.1 200 OK\r\n" + diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpInvalidMessageTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpInvalidMessageTest.java index d5e96eefa21..bc83c231b1a 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpInvalidMessageTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpInvalidMessageTest.java @@ -91,7 +91,7 @@ public void testResponseWithBadHeader() throws Exception { @Test public void testBadChunk() throws Exception { EmbeddedChannel ch = new EmbeddedChannel(new HttpRequestDecoder()); - ch.writeInbound(Unpooled.copiedBuffer("GET / HTTP/1.0\r\n", CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer("GET / HTTP/1.1\r\n", CharsetUtil.UTF_8)); ch.writeInbound(Unpooled.copiedBuffer("Transfer-Encoding: chunked\r\n\r\n", CharsetUtil.UTF_8)); ch.writeInbound(Unpooled.copiedBuffer("BAD_LENGTH\r\n", CharsetUtil.UTF_8)); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java index 0bb21521387..4fb28861ad0 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java @@ -26,7 +26,6 @@ import io.netty.util.AsciiString; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; - import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; import org.mockito.Mockito; @@ -758,4 +757,34 @@ public void execute() { } }); } + + @Test + public void invalidContinueLength() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpServerCodec(), new HttpObjectAggregator(1024)); + + channel.writeInbound(Unpooled.copiedBuffer("POST / HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length:\r\n" + + "\r\n\r\n", CharsetUtil.US_ASCII)); + assertTrue(channel.finishAndReleaseAll()); + } + + @Test + public void testOversizedRequestWithAutoReadFalse() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(4)); + embedder.config().setAutoRead(false); + assertFalse(embedder.writeInbound(Unpooled.copiedBuffer( + "PUT /upload HTTP/1.1\r\n" + + "Content-Length: 5\r\n\r\n", CharsetUtil.US_ASCII))); + + assertNull(embedder.readInbound()); + + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + ReferenceCountUtil.release(response); + + assertFalse(embedder.isOpen()); + assertFalse(embedder.finish()); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java index 0ce7d196fad..8d411015890 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; -import static io.netty.handler.codec.http.HttpHeaderNames.*; +import static io.netty.handler.codec.http.HttpHeaderNames.HOST; import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -269,6 +269,92 @@ public void testEmptyHeaderValue() { assertEquals("", req.headers().get(of("EmptyHeader"))); } + @Test + public void testSingleTrailingHeader() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String request = "POST / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "5\r\n" + + "hello\r\n" + + "0\r\n" + + "X-Checksum: abc123\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII))); + HttpRequest req = channel.readInbound(); + assertFalse(req.decoderResult().isFailure()); + HttpContent body = channel.readInbound(); + body.release(); + LastHttpContent last = channel.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertEquals("abc123", last.trailingHeaders().get(of("X-Checksum"))); + last.release(); + assertFalse(channel.finish()); + } + + @Test + public void testMultiLineTrailingHeader() { + // Regression: folded trailer values previously threw UnsupportedOperationException + // because trailingHeaders().getAll() returns an AbstractList that does not implement set(). + // Note: obs-fold in trailers is permitted as trailers are field-lines per + // https://www.rfc-editor.org/rfc/rfc9112#section-5.2 + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String request = "POST / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "5\r\n" + + "hello\r\n" + + "0\r\n" + + "X-Long: part1\r\n" + + " part2\r\n" + + "\t\t\t part3\r\n" + + "X-Short: value\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII))); + HttpRequest req = channel.readInbound(); + assertFalse(req.decoderResult().isFailure()); + HttpContent body = channel.readInbound(); + body.release(); + LastHttpContent last = channel.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertEquals("part1 part2 part3", last.trailingHeaders().get(of("X-Long"))); + assertEquals("value", last.trailingHeaders().get(of("X-Short"))); + last.release(); + assertFalse(channel.finish()); + } + + @Test + public void testForbiddenTrailingHeadersAreDropped() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String request = "POST / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "5\r\n" + + "hello\r\n" + + "0\r\n" + + HttpHeaderNames.CONTENT_LENGTH + ": 5\r\n" + + HttpHeaderNames.TRANSFER_ENCODING + ": chunked\r\n" + + "X-Custom: keep\r\n" + + HttpHeaderNames.TRAILER + ": X-Checksum\r\n" + // covering post-loop flush path + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII))); + HttpRequest req = channel.readInbound(); + assertFalse(req.decoderResult().isFailure()); + HttpContent body = channel.readInbound(); + body.release(); + LastHttpContent last = channel.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertNull(last.trailingHeaders().get(HttpHeaderNames.CONTENT_LENGTH)); + assertNull(last.trailingHeaders().get(HttpHeaderNames.TRANSFER_ENCODING)); + assertNull(last.trailingHeaders().get(HttpHeaderNames.TRAILER)); + assertEquals("keep", last.trailingHeaders().get(of("X-Custom"))); + last.release(); + assertFalse(channel.finish()); + } + @Test public void test100Continue() { HttpRequestDecoder decoder = new HttpRequestDecoder(); @@ -403,6 +489,31 @@ public void testInitialLineWithLeadingControlChars() { assertTrue(channel.finishAndReleaseAll()); } + @Test + public void testNonCrlfControlBytesPrecedingRequestLineAreRejected() { + // RFC 9112 §2.2: servers SHOULD ignore "at least one empty line (CRLF)" before the + // request-line. Non-CRLF control bytes are not part of this robustness allowance + // and must not be silently swallowed. + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + + ByteBuf buf = Unpooled.buffer(); + buf.writeByte(0x00); // NUL — not an empty CRLF line + buf.writeByte(0x01); // SOH — not an empty CRLF line + buf.writeCharSequence( + "GET / HTTP/1.1\r\nHost: example.com\r\n\r\n", + CharsetUtil.US_ASCII); + + channel.writeInbound(buf); + HttpRequest req = channel.readInbound(); + + DecoderResult decoderResult = req.decoderResult(); + assertTrue(decoderResult.isFailure(), + "Non-CRLF control bytes before the request-line must not be silently skipped"); + assertThat(decoderResult.cause()).isInstanceOf(InvalidLineSeparatorException.class); + + assertFalse(channel.finish()); + } + @Test public void testTooLargeHeaders() { EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder(1024, 10, 1024)); @@ -591,6 +702,16 @@ public void testMultipleContentLengthHeadersWithFolding() { testInvalidHeaders0(requestStr); } + @Test + public void testChunkedNotLastInTransferEncoding() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Transfer-Encoding: chunked, identity\r\n" + + "Content-Length: 1\r\n" + + "Host: netty.io\r\n\r\n" + + "a"; + testInvalidHeaders0(requestStr); + } + @Test public void testContentLengthAndTransferEncodingHeadersWithVerticalTab() { testContentLengthAndTransferEncodingHeadersWithInvalidSeparator((char) 0x0b, false); @@ -616,7 +737,7 @@ private static void testContentLengthAndTransferEncodingHeadersWithInvalidSepara } @Test - public void testContentLengthHeaderAndChunked() { + public void testContentLengthHeaderAndChunkedHttp11() { String requestStr = "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "Connection: close\r\n" + @@ -626,15 +747,48 @@ public void testContentLengthHeaderAndChunked() { EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isFailure()); + assertThat(request.decoderResult().cause()).isInstanceOf(ContentLengthNotAllowedException.class); + assertFalse(channel.finish()); + } + + @Test + public void testContentLengthHeaderAndChunkedHttp11RFC7230() { + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder( + new HttpDecoderConfig().setUseRfc9112TransferEncoding(false))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); assertFalse(request.decoderResult().isFailure()); assertTrue(request.headers().names().contains("Transfer-Encoding")); assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); assertFalse(request.headers().contains("Content-Length")); + assertEquals("close", request.headers().get("Connection")); LastHttpContent c = channel.readInbound(); c.release(); assertFalse(channel.finish()); } + @Test + public void testContentLengthHeaderAndChunkedHttp10() { + String requestStr = "POST / HTTP/1.0\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isFailure()); + assertThat(request.decoderResult().cause()).isInstanceOf(TransferEncodingNotAllowedException.class); + assertFalse(channel.finish()); + } + @Test void mustRejectImproperlyTerminatedChunkExtensions() throws Exception { // See full explanation: https://w4ke.info/2025/06/18/funky-chunks.html @@ -695,6 +849,243 @@ void mustRejectImproperlyTerminatedChunkBodies() throws Exception { assertFalse(channel.finish()); } + @Test + void mustParsedChunkExtensionsWithQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\" ;\t\"\r\n" + // chunk extension quote end + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkExtensionsWithLineBreaksInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\"\r\n" + // chunk extension quote start + "X\r\n" + + "0\r\n\r\n" + + "GET /two HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "\"\r\n" + // chunk extension quote end + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain line breaks. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustParseChunkExtensionsWithQuotedStringsAndEscapes() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\" \\\";\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustParseMultipleChunkExtensionsWithTokenValues() throws Exception { + // Regression: the old Match-based state machine had ';' (0x3B) missing from the + // exclusion set in ChunkExtValToken, so ';' was treated as a token character + // instead of starting a new extension. This caused valid multi-extension lines + // like ";name1=val1;name2=val2" to be rejected with InvalidChunkExtensionException. + String requestStr = "GET / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;name1=val1;name2=val2\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); + HttpContent content = channel.readInbound(); + if (content.decoderResult().isFailure()) { + content.decoderResult().cause().printStackTrace(); + } + assertFalse(content.decoderResult().isFailure()); // Must accept valid multi-extension token values. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); + } + + @Test + void mustRejectChunkExtensionsWithEscapedLineBreakInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\" \\\n;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain line breaks. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkExtensionsWithEscapedCarriageReturnInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String requestStr = "GET /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "1;a=\" \\\r;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain carraige return. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void lineLengthRestrictionMustNotApplyToChunkContents() throws Exception { + char[] chars = new char[10000]; + Arrays.fill(chars, 'a'); + String requestContent = new String(chars); + String requestStr = "POST /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + Integer.toHexString(chars.length) + "\r\n" + + requestContent + "\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + int contentLength = 0; + HttpContent content; + do { + content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + if (decoderResult.cause() != null) { + throw new Exception(decoderResult.cause()); + } + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + contentLength += content.content().readableBytes(); + content.release(); + } while (!(content instanceof LastHttpContent)); + assertEquals(chars.length, contentLength); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkSizeWithNonHexadecimalCharacters() throws Exception { + String requestStr = "POST /one HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "test\r\n\r\n" + // chunk extension quote start + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); // We parse the headers + HttpContent content = channel.readInbound(); + assertTrue(content.decoderResult().isFailure()); + assertThat(content.decoderResult().cause()).isInstanceOf(NumberFormatException.class); + assertFalse(channel.finish()); + } + + @Test + public void mustRejectChunkSizeThatWouldCauseOverflow() { + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "100000004\r\n" + + "test\r\n" + + "0\r\n" + + "\r\n" + + "GET /smuggled HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Content-Length: 0\r\n" + + "\r\n"; + + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + + // Request 1 + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isSuccess()); + HttpContent content = channel.readInbound(); + assertFalse(content.decoderResult().isSuccess()); + assertThat(content.decoderResult().cause()).hasMessageContaining("Chunk size overflow"); + content.release(); + assertFalse(channel.finish()); + } + @Test public void testOrderOfHeadersWithContentLength() { String requestStr = "GET /some/path HTTP/1.1\r\n" + diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java index 2c0ffd7d942..cd822063dab 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java @@ -37,8 +37,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -/** - */ public class HttpRequestEncoderTest { @SuppressWarnings("deprecation") diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java index d38e6169d0c..2b9254ca7c9 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java @@ -26,11 +26,12 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import java.util.Arrays; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Random; + import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertArrayEquals; @@ -206,21 +207,6 @@ public void testResponseChunkedWithValidUncommonPatterns() { assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); - // leading whitespace, trailing control char - - assertFalse(ch.writeInbound(Unpooled.copiedBuffer(" " + Integer.toHexString(data.length) + "\0\r\n", - CharsetUtil.US_ASCII))); - assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data))); - content = ch.readInbound(); - assertEquals(data.length, content.content().readableBytes()); - - decodedData = new byte[data.length]; - content.content().readBytes(decodedData); - assertArrayEquals(data, decodedData); - content.release(); - - assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); - // leading whitespace, trailing semicolon assertFalse(ch.writeInbound(Unpooled.copiedBuffer(" " + Integer.toHexString(data.length) + ";\r\n", @@ -665,6 +651,64 @@ private static void testLastResponseWithTrailingHeaderFragmented(byte[] content, assertNull(ch.readInbound()); } + @Test + public void testMultiLineTrailingHeader() { + // Regression: folded trailer values previously threw UnsupportedOperationException + // because trailingHeaders().getAll() returns an AbstractList that does not implement set(). + // Note: obs-fold in trailers is permitted as trailers are field-lines per + // https://www.rfc-editor.org/rfc/rfc9112#section-5.2 + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + String response = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + "X-Long: part1\r\n" + + " part2\r\n" + + "\t\t\t part3\r\n" + + "X-Short: value\r\n" + + "\r\n"; + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.US_ASCII))); + HttpResponse res = ch.readInbound(); + assertFalse(res.decoderResult().isFailure()); + assertSame(HttpVersion.HTTP_1_1, res.protocolVersion()); + assertEquals(HttpResponseStatus.OK, res.status()); + + LastHttpContent last = ch.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertEquals("part1 part2 part3", last.trailingHeaders().get(of("X-Long"))); + assertEquals("value", last.trailingHeaders().get(of("X-Short"))); + last.release(); + assertFalse(ch.finish()); + } + + @Test + public void testForbiddenTrailingHeadersAreDropped() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + String response = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + HttpHeaderNames.CONTENT_LENGTH + ": 5\r\n" + + HttpHeaderNames.TRANSFER_ENCODING + ": chunked\r\n" + + "X-Custom: keep\r\n" + + HttpHeaderNames.TRAILER + ": X-Checksum\r\n" + // covering post-loop flush path + "\r\n"; + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.US_ASCII))); + HttpResponse res = ch.readInbound(); + assertFalse(res.decoderResult().isFailure()); + assertSame(HttpVersion.HTTP_1_1, res.protocolVersion()); + assertEquals(HttpResponseStatus.OK, res.status()); + + LastHttpContent last = ch.readInbound(); + assertFalse(last.decoderResult().isFailure()); + assertNull(last.trailingHeaders().get(HttpHeaderNames.CONTENT_LENGTH)); + assertNull(last.trailingHeaders().get(HttpHeaderNames.TRANSFER_ENCODING)); + assertNull(last.trailingHeaders().get(HttpHeaderNames.TRAILER)); + assertEquals("keep", last.trailingHeaders().get(of("X-Custom"))); + last.release(); + assertFalse(ch.finish()); + } + @Test public void testResponseWithContentLength() { EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); @@ -698,6 +742,54 @@ public void testResponseWithContentLength() { assertNull(ch.readInbound()); } + @Test + public void testContentLengthHeaderAndChunkedHttp11() { + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Connection: close\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertTrue(response.decoderResult().isFailure()); + assertThat(response.decoderResult().cause()).isInstanceOf(ContentLengthNotAllowedException.class); + assertFalse(channel.finish()); + } + + @Test + public void testContentLengthHeaderAndChunkedHttp11RFC7230() { + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder( + new HttpDecoderConfig().setUseRfc9112TransferEncoding(false))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + assertFalse(response.headers().contains("Content-Length")); + LastHttpContent c = channel.readInbound(); + c.release(); + assertFalse(channel.finish()); + } + + @Test + public void testContentLengthHeaderAndChunkedHttp10() { + String responseStr = "HTTP/1.0 200 OK\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertTrue(response.decoderResult().isFailure()); + assertThat(response.decoderResult().cause()).isInstanceOf(TransferEncodingNotAllowedException.class); + assertFalse(channel.finish()); + } + @Test public void testResponseWithContentLengthFragmented() { byte[] data = ("HTTP/1.1 200 OK\r\n" + @@ -1000,7 +1092,7 @@ public void testGarbageChunkAfterWhiteSpaces() { @Test void mustRejectImproperlyTerminatedChunkExtensions() throws Exception { // See full explanation: https://w4ke.info/2025/06/18/funky-chunks.html - String requestStr = "HTTP/1.1 200 OK\r\n" + + String responseStr = "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "2;\n" + // Chunk size followed by illegal single newline (not preceded by carraige return) @@ -1011,7 +1103,7 @@ void mustRejectImproperlyTerminatedChunkExtensions() throws Exception { "Transfer-Encoding: chunked\r\n\r\n" + "0\r\n\r\n"; EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); - assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); HttpResponse response = channel.readInbound(); assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. assertTrue(response.headers().names().contains("Transfer-Encoding")); @@ -1027,7 +1119,7 @@ void mustRejectImproperlyTerminatedChunkExtensions() throws Exception { @Test void mustRejectImproperlyTerminatedChunkBodies() throws Exception { // See full explanation: https://w4ke.info/2025/06/18/funky-chunks.html - String requestStr = "HTTP/1.1 200 OK\r\n" + + String responseStr = "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + "5\r\n" + "AAAAXX" + // Chunk body contains extra (XX) bytes, and no CRLF terminator. @@ -1037,7 +1129,7 @@ void mustRejectImproperlyTerminatedChunkBodies() throws Exception { "Transfer-Encoding: chunked\r\n\r\n" + "0\r\n\r\n"; EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); - assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); HttpResponse response = channel.readInbound(); assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. assertTrue(response.headers().names().contains("Transfer-Encoding")); @@ -1054,6 +1146,237 @@ void mustRejectImproperlyTerminatedChunkBodies() throws Exception { assertFalse(channel.finish()); } + @Test + void mustParsedChunkExtensionsWithQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\" ;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkExtensionsWithLineBreaksInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\"\r\n" + // chunk extension quote start + "X\r\n" + + "0\r\n\r\n" + + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "\"\r\n" + // chunk extension quote end + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain line breaks. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustParsedChunkExtensionsWithQuotedStringsAndEscapes() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\" \\\";\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustParseMultipleChunkExtensionsWithTokenValues() throws Exception { + // Regression: the old Match-based state machine had ';' (0x3B) missing from the + // exclusion set in ChunkExtValToken, so ';' was treated as a token character + // instead of starting a new extension. This caused valid multi-extension lines + // like ";name1=val1;name2=val2" to be rejected with InvalidChunkExtensionException. + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;name1=val1;name2=val2\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); + HttpContent content = channel.readInbound(); + assertFalse(content.decoderResult().isFailure()); // Must accept valid multi-extension token values. + content.release(); + LastHttpContent last = channel.readInbound(); + assertEquals(0, last.content().readableBytes()); + last.release(); + assertFalse(channel.finish()); + } + + @Test + void mustRejectChunkExtensionsWithEscapedLineBreakInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\" \\\n;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain line breaks. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkExtensionsWithEscapedCarraigeReturnInQuotedStrings() throws Exception { + // See full explanation: https://w4ke.info/2025/10/29/funky-chunks-2.html + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "1;a=\" \\\r;\t\"\r\n" + + "Y\r\n" + + "0\r\n" + + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + HttpContent content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + assertTrue(decoderResult.isFailure()); // Chunk extension is not allowed to contain carriage returns. + assertThat(decoderResult.cause()).isInstanceOf(InvalidChunkExtensionException.class); + content.release(); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void lineLengthRestrictionMustNotApplyToChunkContents() throws Exception { + char[] chars = new char[10000]; + Arrays.fill(chars, 'a'); + String requestContent = new String(chars); + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Host: localhost\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + Integer.toHexString(chars.length) + "\r\n" + + requestContent + "\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers just fine. + assertTrue(response.headers().names().contains("Transfer-Encoding")); + assertTrue(response.headers().contains("Transfer-Encoding", "chunked", false)); + int contentLength = 0; + HttpContent content; + do { + content = channel.readInbound(); + DecoderResult decoderResult = content.decoderResult(); + if (decoderResult.cause() != null) { + throw new Exception(decoderResult.cause()); + } + assertFalse(decoderResult.isFailure()); // And we parse the chunk. + contentLength += content.content().readableBytes(); + content.release(); + } while (!(content instanceof LastHttpContent)); + assertEquals(chars.length, contentLength); + assertFalse(channel.finish()); // And there are no other chunks parsed. + } + + @Test + void mustRejectChunkSizeWithNonHexadecimalCharacters() throws Exception { + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "test\r\n\r\n" + // chunk extension quote start + "\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); // We parse the headers + HttpContent content = channel.readInbound(); + assertTrue(content.decoderResult().isFailure()); + assertThat(content.decoderResult().cause()).isInstanceOf(NumberFormatException.class); + assertFalse(channel.finish()); + } + + @Test + public void mustRejectChunkSizeThatWouldCauseOverflow() { + String requestStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "100000004\r\n" + + "test\r\n" + + "0\r\n" + + "\r\n" + + "GET /smuggled HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Content-Length: 0\r\n" + + "\r\n"; + + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + + // Request 1 + HttpResponse response = channel.readInbound(); + assertTrue(response.decoderResult().isSuccess()); + HttpContent content = channel.readInbound(); + assertFalse(content.decoderResult().isSuccess()); + assertThat(content.decoderResult().cause()).hasMessageContaining("Chunk size overflow"); + content.release(); + assertFalse(channel.finish()); + } + @Test public void testConnectionClosedBeforeHeadersReceived() { EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpServerCodecTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpServerCodecTest.java index eb60569eb0f..c58dddc81b3 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpServerCodecTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpServerCodecTest.java @@ -19,8 +19,11 @@ import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -174,6 +177,56 @@ public void testChunkedHeadFullHttpResponse() { assertFalse(ch.finishAndReleaseAll()); } + @Test + public void testConnectionClosedAfterResponseWhenBothTransferEncodingAndContentLengthRfc9112() { + // We reject these requests by default. + EmbeddedChannel ch = new EmbeddedChannel(new HttpServerCodec()); + + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + + HttpRequest request = ch.readInbound(); + assertTrue(request.decoderResult().isFailure()); + assertThat(request.decoderResult().cause()).isInstanceOf(ContentLengthNotAllowedException.class); + assertFalse(ch.finishAndReleaseAll()); + } + + @Test + public void testConnectionClosedAfterResponseWhenBothTransferEncodingAndContentLengthRfc7230() { + // Leniency, or "RFC 7230" mode, can be configured but the connection is then closed after. + EmbeddedChannel ch = new EmbeddedChannel(new HttpServerCodec( + new HttpDecoderConfig().setUseRfc9112TransferEncoding(false))); + + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + + HttpRequest request = ch.readInbound(); + assertFalse(request.decoderResult().isFailure()); + assertFalse(HttpUtil.isKeepAlive(request)); + LastHttpContent content = ch.readInbound(); + ReferenceCountUtil.release(content); + + FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpUtil.setContentLength(response, 0); + + assertTrue(ch.writeOutbound(response)); + // Channel should be closed after the response is written + assertFalse(ch.isOpen()); + + ReferenceCountUtil.release(ch.readOutbound()); + assertFalse(ch.finishAndReleaseAll()); + } + private static ByteBuf prepareDataChunk(int size) { StringBuilder sb = new StringBuilder(); for (int i = 0; i < size; ++i) { diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpServerKeepAliveHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpServerKeepAliveHandlerTest.java index 0ef78e0d5bf..1617921466a 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpServerKeepAliveHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpServerKeepAliveHandlerTest.java @@ -15,10 +15,13 @@ */ package io.netty.handler.codec.http; +import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; @@ -212,6 +215,37 @@ public void testPipelineKeepAlive(boolean isKeepAliveResponseExpected, HttpVersi assertFalse(channel.finishAndReleaseAll()); } + @Test + public void testConnectionClosedWhenBothTransferEncodingAndContentLengthRfc7230() { + EmbeddedChannel ch = new EmbeddedChannel( + new HttpRequestDecoder(new HttpDecoderConfig().setUseRfc9112TransferEncoding(false)), + new HttpServerKeepAliveHandler()); + + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + + HttpRequest request = ch.readInbound(); + assertFalse(HttpUtil.isKeepAlive(request)); + LastHttpContent content = ch.readInbound(); + ReferenceCountUtil.release(content); + + FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + setContentLength(response, 0); + + ch.writeAndFlush(response); + HttpResponse writtenResponse = ch.readOutbound(); + + assertFalse(isKeepAlive(writtenResponse)); + assertFalse(ch.isOpen()); + ReferenceCountUtil.release(writtenResponse); + assertFalse(ch.finishAndReleaseAll()); + } + private static void setupMessageLength(HttpResponse response, int setSelfDefinedMessageLength) { switch (setSelfDefinedMessageLength) { case NOT_SELF_DEFINED_MSG_LENGTH: diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java index 05dd678564d..f77d4e8c297 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java @@ -56,7 +56,8 @@ public void testRecognizesOriginForm() { assertFalse(HttpUtil.isOriginForm(URI.create("*"))); } - @Test public void testRecognizesAsteriskForm() { + @Test + public void testRecognizesAsteriskForm() { // Asterisk form: https://tools.ietf.org/html/rfc7230#section-5.3.4 assertTrue(HttpUtil.isAsteriskForm(URI.create("*"))); // Origin form: https://tools.ietf.org/html/rfc7230#section-5.3.1 @@ -67,6 +68,26 @@ public void testRecognizesOriginForm() { assertFalse(HttpUtil.isAsteriskForm(URI.create("www.example.com:80"))); } + @ParameterizedTest + @ValueSource(strings = { + "http://localhost/\r\n", + "/r\r\n?q=1", + "http://localhost/\r\n?q=1", + "/r\r\n/?q=1", + "http://localhost/\r\n/?q=1", + "/r\r\n", + "http://localhost/ HTTP/1.1\r\n\r\nPOST /p HTTP/1.1\r\n\r\n", + "/r HTTP/1.1\r\n\r\nPOST /p HTTP/1.1\r\n\r\n", + "GET ", + " GET", + "HTTP/ 1.1", + "HTTP/\r0.9", + "HTTP/\n1.1", + }) + public void requestLineTokenValidationMustRejectInvalidTokens(String token) throws Exception { + assertFalse(HttpUtil.isEncodingSafeStartLineToken(token)); + } + @Test public void testRemoveTransferEncodingIgnoreCase() { HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); @@ -514,4 +535,22 @@ public void testInvalidTokenChars(char invalidChar) { assertEquals(2, validateToken(asciiStringToken)); assertEquals(2, validateToken(token)); } + + @ParameterizedTest + @ValueSource(chars = { + // High-bit Truncation Candidates (verifying > 0xFF check) + // These characters are chosen because their lower 8 bits + // alias to valid US-ASCII 'tchar' values. + '\u0161', // 0x0161 truncates to 0x61 ('a') + '\u0121', // 0x0121 truncates to 0x21 ('!') + '\u0231', // 0x0231 truncates to 0x31 ('1') + '\u0361' // 0x0361 truncates to 0x61 ('a') + }) + public void testInvalidTokenCharsOutsideAsciiRange(char invalidChar) { + // We use a String here because AsciiString would truncate + // the char to a byte during construction. + String token = "GE" + invalidChar + 'T'; + assertEquals(2, validateToken(token), + String.format("Character U+%04X should be invalid", (int) invalidChar)); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionLocaleTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionLocaleTest.java new file mode 100644 index 00000000000..86c79868535 --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionLocaleTest.java @@ -0,0 +1,69 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Isolated; + +import java.util.Locale; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests that exercise {@link HttpVersion} parsing with the JVM-default Locale flipped to a value + * that exposes the Turkish dotted-I problem (U+0130). {@link Locale#setDefault} is process-global + * mutable state, so this class is marked {@link Isolated} to keep it from leaking into the rest + * of the codec-http suite, which runs in concurrent mode + * ({@code junit.jupiter.execution.parallel.mode.default = concurrent}). + */ +@Isolated("Mutates Locale.getDefault() which is JVM-global state.") +class HttpVersionLocaleTest { + + @Test + void testLowercaseIotaProtocolNameUnderTurkishLocale() { + // Turkish locale maps 'i' -> 'İ' (U+0130) under the JVM-default toUpperCase(). + // The constructor must use Locale.US so an HTTP-derived protocol name like "icap" + // round-trips to the ASCII "ICAP" instead of being corrupted to "İCAP". + Locale original = Locale.getDefault(); + try { + Locale.setDefault(new Locale("tr", "TR")); + HttpVersion version = HttpVersion.valueOf("icap/1.0"); + assertEquals("ICAP", version.protocolName()); + for (int i = 0; i < version.protocolName().length(); i++) { + assertTrue(version.protocolName().charAt(i) < 0x80, + "protocolName must remain ASCII regardless of JVM default locale"); + } + assertEquals("ICAP/1.0", version.text()); + } finally { + Locale.setDefault(original); + } + } + + @Test + void testProtocolNameConstructorUnderTurkishLocale() { + // Same Locale.US guarantee for the (protocolName, major, minor, ...) constructor. + Locale original = Locale.getDefault(); + try { + Locale.setDefault(new Locale("tr", "TR")); + HttpVersion version = new HttpVersion("icap", 1, 0, true); + assertEquals("ICAP", version.protocolName()); + assertEquals("ICAP/1.0", version.text()); + } finally { + Locale.setDefault(original); + } + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionParsingTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionParsingTest.java new file mode 100644 index 00000000000..d2971ac4726 --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpVersionParsingTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2025 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpVersionParsingTest { + + @Test + void testStandardVersions() { + HttpVersion v10 = HttpVersion.valueOf("HTTP/1.0"); + HttpVersion v11 = HttpVersion.valueOf("HTTP/1.1"); + + assertSame(HttpVersion.HTTP_1_0, v10); + assertSame(HttpVersion.HTTP_1_1, v11); + + assertEquals("HTTP", v10.protocolName()); + assertEquals(1, v10.majorVersion()); + assertEquals(0, v10.minorVersion()); + + assertEquals("HTTP", v11.protocolName()); + assertEquals(1, v11.majorVersion()); + assertEquals(1, v11.minorVersion()); + } + + @Test + void testLowerCaseProtocolNameNonStrict() { + HttpVersion version = HttpVersion.valueOf("http/1.1"); + assertEquals("HTTP", version.protocolName()); + assertEquals(1, version.majorVersion()); + assertEquals(1, version.minorVersion()); + assertEquals("HTTP/1.1", version.text()); + } + + @Test + void testMixedCaseProtocolNameNonStrict() { + HttpVersion version = HttpVersion.valueOf("hTtP/1.0"); + assertEquals("HTTP", version.protocolName()); + assertEquals(1, version.majorVersion()); + assertEquals(0, version.minorVersion()); + assertEquals("HTTP/1.0", version.text()); + } + + @Test + void testCustomLowerCaseProtocolNonStrict() { + HttpVersion version = HttpVersion.valueOf("mqtt/5.0"); + assertEquals("MQTT", version.protocolName()); + assertEquals(5, version.majorVersion()); + assertEquals(0, version.minorVersion()); + assertEquals("MQTT/5.0", version.text()); + } + + @Test + void testCustomVersionNonStrict() { + HttpVersion version = HttpVersion.valueOf("MyProto/2.3"); + assertEquals("MYPROTO", version.protocolName()); // uppercased + assertEquals(2, version.majorVersion()); + assertEquals(3, version.minorVersion()); + assertEquals("MYPROTO/2.3", version.text()); + } + + @Test + void testCustomVersionStrict() { + HttpVersion version = new HttpVersion("HTTP/1.1", true, true); + assertEquals("HTTP", version.protocolName()); + assertEquals(1, version.majorVersion()); + assertEquals(1, version.minorVersion()); + } + + @Test + void testCustomVersionStrictFailsOnLongVersion() { + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + new HttpVersion("HTTP/10.1", true, true); + } + }); + assertTrue(ex.getMessage().contains("invalid version format")); + } + + @Test + void testInvalidFormatMissingSlash() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + HttpVersion.valueOf("HTTP1.1"); + } + }); + } + + @Test + void testInvalidFormatWhitespaceInProtocol() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + HttpVersion.valueOf("HT TP/1.1"); + } + }); + } + + @ParameterizedTest + @ValueSource(strings = { + "HTTP ", + " HTTP", + "H TTP", + " HTTP ", + "HTTP\r", + "HTTP\n", + "HTTP\r\n", + "HTT\rP", + "HTT\nP", + "HTT\r\nP", + "\rHTTP", + "\nHTTP", + "\r\nHTTP", + " \r\nHTTP", + "\r \nHTTP", + "\r\n HTTP", + "\r\nHTTP ", + "\nHTTP ", + "\rHTTP ", + "\r HTTP", + " \rHTTP", + "\nHTTP ", + "\n HTTP", + " \nHTTP", + "HTTP \n", + "HTTP \r", + " HTTP\r", + " HTTP\r", + "HTTP \n", + " HTTP\n", + " HTTP\n", + "HTT\nTP", + "HTT\rTP", + " HTT\rP", + " HTT\rP", + "HTT\nTP", + " HTT\nP", + " HTT\nP", + }) + void httpVersionMustRejectIllegalTokens(String protocol) { + try { + HttpVersion httpVersion = new HttpVersion(protocol, 1, 0, true); + // If no exception is thrown, then the version must have been sanitized and made safe. + assertTrue(HttpUtil.isEncodingSafeStartLineToken(httpVersion.text())); + } catch (IllegalArgumentException ignore) { + // Throwing is good. + } + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java index 8b3065fbb89..d76f8d3f04f 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java @@ -21,12 +21,17 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultLastHttpContent; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.DefaultHttpHeadersFactory; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpUtil; import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.jupiter.api.Test; @@ -512,6 +517,154 @@ public void simpleRequestDoNotAllowPrivateNetwork() { assertTrue(ReferenceCountUtil.release(response)); } + @Test + public void preflightEmptyLastDiscarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/test"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + + assertFalse(ch.writeInbound(preflight)); + + Object outbound = ch.readOutbound(); + assertNotNull(outbound); // preflight response + + LastHttpContent lastHttpContent = LastHttpContent.EMPTY_LAST_CONTENT; + assertFalse(ch.writeInbound(lastHttpContent)); + + // Nothing should have been forwarded + assertNull(ch.readInbound()); + + assertFalse(ch.finish()); + } + + @Test + public void preflightSecondEmptyLastForwardedAfterFirstDiscard() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/test"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + + assertFalse(ch.writeInbound(preflight)); + ReferenceCountUtil.release(ch.readOutbound()); + + LastHttpContent first = LastHttpContent.EMPTY_LAST_CONTENT; + LastHttpContent second = LastHttpContent.EMPTY_LAST_CONTENT; + + assertFalse(ch.writeInbound(first)); + + assertFalse(ch.writeInbound(second)); + + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + + @Test + public void preflightSecondNonEmptyLastDiscarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/test"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + + assertFalse(ch.writeInbound(preflight)); + ReferenceCountUtil.release(ch.readOutbound()); + + LastHttpContent first = LastHttpContent.EMPTY_LAST_CONTENT; + LastHttpContent second = new DefaultLastHttpContent( + Unpooled.copiedBuffer("test message", CharsetUtil.UTF_8)); + + assertFalse(ch.writeInbound(first)); + assertFalse(ch.writeInbound(second)); + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + + @Test + public void preflightNonEmptyLastForwarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/x"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + + assertFalse(ch.writeInbound(preflight)); + Object outbound = ch.releaseOutbound(); + assertNotNull(outbound); + + LastHttpContent nonEmpty = new DefaultLastHttpContent(Unpooled.copiedBuffer("x", CharsetUtil.UTF_8)); + assertFalse(ch.writeInbound(nonEmpty)); + + Object inbound = ch.readInbound(); + assertNull(inbound); + + assertFalse(ch.finish()); + } + + @Test + public void testNormalRequestForwarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, GET, "/test"); + req.headers().set(ORIGIN, "http://allowed"); + + assertTrue(ch.writeInbound(req)); + + LastHttpContent last = LastHttpContent.EMPTY_LAST_CONTENT; + assertTrue(ch.writeInbound(last)); + + Object firstInbound = ch.readInbound(); + Object secondInbound = ch.readInbound(); + + assertNotNull(firstInbound); + assertNotNull(secondInbound); + + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + + @Test + public void preflightEmptyLastDiscardedThenNewRequestForwarded() { + CorsConfig config = forOrigin("http://allowed").build(); + EmbeddedChannel ch = new EmbeddedChannel(new CorsHandler(config)); + + // Preflight request + FullHttpRequest preflight = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "/pre"); + preflight.headers().set(ORIGIN, "http://allowed"); + preflight.headers().set(ACCESS_CONTROL_REQUEST_METHOD, "GET"); + assertFalse(ch.writeInbound(preflight)); + Object preflightResp = ch.readOutbound(); + assertNotNull(preflightResp); + ReferenceCountUtil.release(preflightResp); + + // Empty last content should be discarded + assertFalse(ch.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT)); + assertNull(ch.readInbound()); + + // New request should be forwarded + FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, GET, "/next"); + req.headers().set(ORIGIN, "http://allowed"); + assertTrue(ch.writeInbound(req)); + + Object firstInbound = ch.readInbound(); + assertNotNull(firstInbound); + + HttpContent content = new DefaultHttpContent(Unpooled.copiedBuffer("test message", CharsetUtil.UTF_8)); + assertTrue(ch.writeInbound(content)); + Object secondInbound = ch.readInbound(); + assertNotNull(secondInbound); + + assertNull(ch.readInbound()); + assertFalse(ch.finish()); + } + private static HttpResponse simpleRequest(final CorsConfig config, final String origin) { return simpleRequest(config, origin, null); } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostMultipartLocaleDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostMultipartLocaleDecoderTest.java new file mode 100644 index 00000000000..e1504c6d2e5 --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostMultipartLocaleDecoderTest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Isolated; + +import java.util.Locale; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** + * Decoder-side test that flips {@link Locale#setDefault(Locale)} to Turkish to exercise the + * {@code Content-Transfer-Encoding} normalization path. Marked {@link Isolated} because the JVM + * default Locale is process-global state and codec-http runs tests with + * {@code junit.jupiter.execution.parallel.mode.default = concurrent}. + */ +@Isolated("Mutates Locale.getDefault() which is JVM-global state.") +class HttpPostMultipartLocaleDecoderTest { + + @Test + void testUppercaseBinaryTransferEncodingUnderTurkishLocale() { + // Repro: a part declares an uppercase Content-Transfer-Encoding (RFC 2045 mechanism + // tokens are case-insensitive). On a Turkish-locale JVM the decoder used to call + // toLowerCase() without a Locale and produced "bınary" (U+0131) which then failed the + // compare against the lowercase ASCII constants and threw + // "TransferEncoding Unknown: bınary". + Locale original = Locale.getDefault(); + try { + Locale.setDefault(new Locale("tr", "TR")); + String content = "\n--861fbeab-cd20-470c-9609-d40a0f704466\r\n" + + "content-disposition: form-data; " + + "name=\"file\"; filename=\"myfile.ogg\"\r\n" + + "content-type: audio/ogg; codecs=opus; charset=UTF8\r\n" + + "Content-Transfer-Encoding: BINARY\r\n" + + "\r\n\r\n--861fbeab-cd20-470c-9609-d40a0f704466--\r\n"; + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload", + Unpooled.copiedBuffer(content, CharsetUtil.US_ASCII)); + req.headers().set("content-type", "multipart/form-data; boundary=861fbeab-cd20-470c-9609-d40a0f704466"); + req.headers().set("content-length", content.length()); + + HttpPostMultipartRequestDecoder decoder = new HttpPostMultipartRequestDecoder(req); + FileUpload httpData = (FileUpload) decoder.getBodyHttpDatas("file").get(0); + assertNotNull(httpData); + assertEquals("audio/ogg", httpData.getContentType()); + decoder.destroy(); + } finally { + Locale.setDefault(original); + } + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostMultipartLocaleEncoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostMultipartLocaleEncoderTest.java new file mode 100644 index 00000000000..1862a826d1d --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostMultipartLocaleEncoderTest.java @@ -0,0 +1,66 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Isolated; + +import java.util.List; +import java.util.Locale; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; +import static io.netty.handler.codec.http.HttpMethod.POST; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Encoder-side test that flips {@link Locale#setDefault(Locale)} to Turkish to exercise the + * {@code Content-Type} de-duplication path inside + * {@link HttpPostRequestEncoder#finalizeRequest()}. Marked {@link Isolated} because the JVM + * default Locale is process-global state and codec-http runs tests with + * {@code junit.jupiter.execution.parallel.mode.default = concurrent}. + */ +@Isolated("Mutates Locale.getDefault() which is JVM-global state.") +class HttpPostMultipartLocaleEncoderTest { + + @Test + void testFinalizeRemovesPreexistingMultipartContentTypeUnderTurkishLocale() throws Exception { + // Repro: a caller pre-sets `Content-Type: MULTIPART/form-data` (uppercase, RFC-allowed, + // case-insensitive). When the JVM default Locale is Turkish, the encoder's + // contentType.toLowerCase() used to map 'I' -> 'ı' (U+0131), the resulting + // "multıpart/form-data" missed the prefix check, and the original mixed-case header was + // left alongside the new multipart Content-Type the encoder is about to set - producing + // two Content-Type headers on the request. + Locale original = Locale.getDefault(); + try { + Locale.setDefault(new Locale("tr", "TR")); + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "http://localhost"); + request.headers().add(CONTENT_TYPE, "MULTIPART/form-data; boundary=preexisting"); + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(request, true); + encoder.finalizeRequest(); + List contentTypes = request.headers().getAll(CONTENT_TYPE); + assertEquals(1, contentTypes.size(), + "encoder must drop the pre-existing multipart Content-Type regardless of JVM locale"); + assertTrue(contentTypes.get(0).startsWith("multipart/form-data"), + "the surviving Content-Type must be the encoder's freshly-built multipart header"); + request.release(); + } finally { + Locale.setDefault(original); + } + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java index 786891d604d..5e46bfe7e71 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java @@ -23,6 +23,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; @@ -36,6 +37,7 @@ import java.util.Map; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; public class PerMessageDeflateClientExtensionHandshakerTest { @@ -243,4 +245,44 @@ public void testDecoderNoClientContext() { assertFalse(decoderChannel.finish()); } + + @Test + public void testClientMaxWindowWithNoValue() { + // Test that client handles client_max_window_bits with no value (null) + // RFC 7692: client_max_window_bits may have no value + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, 15, true, false, 0); + + Map parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, null); // No value specified + + // Should not throw NumberFormatException + WebSocketClientExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // Handshake should succeed, using MAX_WINDOW_SIZE (15) as default + assertNotNull(extension); + assertEquals(RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + } + + @Test + public void testClientMaxWindowWithInvalidValue() { + // Test that client throws NumberFormatException for invalid client_max_window_bits value + final PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, 15, true, false, 0); + + final Map parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "invalid"); + + // Should throw NumberFormatException + assertThrows(NumberFormatException.class, new Executable() { + @Override + public void execute() throws Throwable { + handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + } + }); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java index e661e05a1a0..efaa7f88679 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java @@ -173,4 +173,40 @@ public void testCustomHandshake() { assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); assertTrue(data.parameters().isEmpty()); } + + @Test + public void testClientMaxWindowWithValue() { + PerMessageDeflateServerExtensionHandshaker handshaker = + new PerMessageDeflateServerExtensionHandshaker(6, true, 10, true, true, 0); + + Map parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "12"); + + WebSocketServerExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + assertNotNull(extension); + assertEquals(WebSocketServerExtension.RSV1, extension.rsv()); + + WebSocketExtensionData data = extension.newReponseData(); + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + // Server should use the client's requested value (12) not the preferred (10) + assertTrue(data.parameters().containsKey(CLIENT_MAX_WINDOW)); + assertEquals("12", data.parameters().get(CLIENT_MAX_WINDOW)); + } + + @Test + public void testClientMaxWindowWithInvalidValue() { + PerMessageDeflateServerExtensionHandshaker handshaker = + new PerMessageDeflateServerExtensionHandshaker(6, true, 10, true, true, 0); + + Map parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "7"); // Below MIN_WINDOW_SIZE (8) + + WebSocketServerExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // Handshake should fail when client_max_window_bits is out of range + assertNull(extension); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspMethodsTest.java b/codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspMethodsTest.java new file mode 100644 index 00000000000..9c428652664 --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspMethodsTest.java @@ -0,0 +1,67 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +import io.netty.handler.codec.http.HttpMethod; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Isolated; + +import java.util.Locale; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; + +@Isolated("valueOfNormalizesLowercaseInputUnderTurkishLocale flips the JVM-default Locale, " + + "which is process-global state, so the class must not run alongside the rest of the " + + "codec-http suite (junit.jupiter.execution.parallel.mode.default = concurrent).") +class RtspMethodsTest { + + @Test + void valueOfReturnsCachedInstanceForUppercaseName() { + assertSame(RtspMethods.DESCRIBE, RtspMethods.valueOf("DESCRIBE")); + assertSame(RtspMethods.SETUP, RtspMethods.valueOf("SETUP")); + assertSame(RtspMethods.GET_PARAMETER, RtspMethods.valueOf("GET_PARAMETER")); + } + + @Test + void valueOfNormalizesLowercaseInputUnderUsLocale() { + assertSame(RtspMethods.DESCRIBE, RtspMethods.valueOf("describe")); + assertSame(RtspMethods.PLAY, RtspMethods.valueOf("play")); + assertSame(RtspMethods.REDIRECT, RtspMethods.valueOf("redirect")); + } + + @Test + void valueOfNormalizesLowercaseInputUnderTurkishLocale() { + // In Turkish locale (tr_TR), 'i' uppercases to 'İ' (U+0130) under the JVM default. + // RtspMethods.valueOf must pin Locale.US so RTSP method names that contain 'i' such as + // "describe" or "redirect" continue to resolve to the cached uppercase entries. + Locale original = Locale.getDefault(); + try { + Locale.setDefault(new Locale("tr", "TR")); + HttpMethod describe = RtspMethods.valueOf("describe"); + HttpMethod redirect = RtspMethods.valueOf("redirect"); + assertSame(RtspMethods.DESCRIBE, describe); + assertSame(RtspMethods.REDIRECT, redirect); + // Sanity-check: with the locale-default toUpperCase the names would have contained + // U+0130 instead of plain 'I'. Asserting the resolved names round-trip to ASCII + // pins down that the fix is taking effect. + assertEquals("DESCRIBE", describe.name()); + assertEquals("REDIRECT", redirect.name()); + } finally { + Locale.setDefault(original); + } + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java index a0c2cd132a4..e496e5ef136 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java @@ -31,7 +31,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; public class SpdyFrameDecoderTest { @@ -841,7 +841,7 @@ public void testDiscardUnknownFrame() throws Exception { buf.writeLong(RANDOM.nextLong()); decoder.decode(buf); - verifyZeroInteractions(delegate); + verifyNoInteractions(delegate); assertFalse(buf.isReadable()); buf.release(); } @@ -856,7 +856,7 @@ public void testDiscardUnknownEmptyFrame() throws Exception { encodeControlFrameHeader(buf, type, flags, length); decoder.decode(buf); - verifyZeroInteractions(delegate); + verifyNoInteractions(delegate); assertFalse(buf.isReadable()); buf.release(); } @@ -878,7 +878,7 @@ public void testProgressivelyDiscardUnknownEmptyFrame() throws Exception { decoder.decode(header); decoder.decode(segment1); decoder.decode(segment2); - verifyZeroInteractions(delegate); + verifyNoInteractions(delegate); assertFalse(header.isReadable()); assertFalse(segment1.isReadable()); assertFalse(segment2.isReadable()); diff --git a/codec-http2/pom.xml b/codec-http2/pom.xml index f363a73fbdb..a47a3551106 100644 --- a/codec-http2/pom.xml +++ b/codec-http2/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-http2 @@ -163,6 +163,29 @@ junit-jupiter-params test + + com.code-intelligence + jazzer-junit + test + + + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-params + + + org.junit.platform + junit-platform-commons + + + org.junit.platform + junit-platform-launcher + + + org.assertj assertj-core diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java index 7747e4fa458..66c4f92d35c 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java @@ -13,7 +13,6 @@ * License for the specific language governing permissions and limitations * under the License. */ - package io.netty.handler.codec.http2; import io.netty.channel.Channel; @@ -113,6 +112,8 @@ public abstract class AbstractHttp2ConnectionHandlerBuilder 0) { decoder = new Http2EmptyDataFrameConnectionDecoder(decoder, maxConsecutiveEmptyDataFrames); @@ -655,6 +688,13 @@ private T buildFromCodec(Http2ConnectionDecoder decoder, Http2ConnectionEncoder return handler; } + private static void enforceMaxActiveStreams(Http2Connection connection, Http2Settings initialSettings) { + Long maxConcurrentStreams = initialSettings.maxConcurrentStreams(); + if (maxConcurrentStreams != null) { + connection.remote().maxActiveStreams((int) Math.min(maxConcurrentStreams, Integer.MAX_VALUE)); + } + } + /** * Implement this method to create a new {@link Http2ConnectionHandler} or its subtype instance. *

diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java index 564584a545f..f1d90ad951c 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java @@ -185,7 +185,7 @@ private enum ReadStatus { private final Http2StreamChannelConfig config = new Http2StreamChannelConfig(this); private final Http2ChannelUnsafe unsafe = new Http2ChannelUnsafe(); - private final ChannelId channelId; + private final Http2StreamChannelId channelId; private final ChannelPipeline pipeline; private final DefaultHttp2FrameStream stream; private final ChannelPromise closePromise; @@ -584,7 +584,7 @@ public int compareTo(Channel o) { @Override public String toString() { - return parent().toString() + "(H2 - " + stream + ')'; + return parent().toString() + '/' + channelId.getSequenceId() + " (H2 - " + stream + ')'; } /** diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java index 90daeabc0f2..d82224d9fd8 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java @@ -100,7 +100,9 @@ public void remoteSettings(Http2Settings settings) throws Http2Exception { } Long maxHeaderListSize = settings.maxHeaderListSize(); - if (maxHeaderListSize != null) { + if (maxHeaderListSize != null && !connection.isServer()) { + // Servers ignore the MAX_HEADER_LIST_SIZE setting from clients. + // It's advisory in spec (RFC 9113 §6.5.2) and best praxis is to ignore it. outboundHeaderConfig.maxHeaderListSize(maxHeaderListSize); } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java index b67eec80316..31e95afda4c 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java @@ -18,12 +18,14 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http2.Http2FrameReader.Configuration; +import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND; import static io.netty.handler.codec.http2.Http2CodecUtil.PING_FRAME_PAYLOAD_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_INITIAL_WINDOW_SIZE; @@ -31,6 +33,7 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.headerListSizeExceeded; import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid; import static io.netty.handler.codec.http2.Http2CodecUtil.readUnsignedInt; +import static io.netty.handler.codec.http2.Http2Error.ENHANCE_YOUR_CALM; import static io.netty.handler.codec.http2.Http2Error.FLOW_CONTROL_ERROR; import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; @@ -51,6 +54,7 @@ * A {@link Http2FrameReader} that supports all frame types defined by the HTTP/2 specification. */ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSizePolicy, Configuration { + private static final int FRAGMENT_THRESHOLD = MAX_FRAME_SIZE_LOWER_BOUND / 2; private final Http2HeadersDecoder headersDecoder; /** @@ -67,7 +71,8 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize private Http2Flags flags; private int payloadLength; private HeadersContinuation headersContinuation; - private int maxFrameSize; + private int maxFrameSize = DEFAULT_MAX_FRAME_SIZE; + private final int maxSmallContinuationFrames; /** * Create a new instance. @@ -88,8 +93,13 @@ public DefaultHttp2FrameReader(boolean validateHeaders) { } public DefaultHttp2FrameReader(Http2HeadersDecoder headersDecoder) { - this.headersDecoder = headersDecoder; - maxFrameSize = DEFAULT_MAX_FRAME_SIZE; + this(headersDecoder, Http2CodecUtil.DEFAULT_MAX_SMALL_CONTINUATION_FRAME); + } + + public DefaultHttp2FrameReader(Http2HeadersDecoder headersDecoder, int maxSmallContinuationFrames) { + this.headersDecoder = ObjectUtil.checkNotNull(headersDecoder, "headersDecoder"); + this.maxSmallContinuationFrames = ObjectUtil.checkPositiveOrZero( + maxSmallContinuationFrames, "maxSmallContinuationFrames"); } @Override @@ -390,6 +400,12 @@ private void verifyContinuationFrame() throws Http2Exception { throw connectionError(PROTOCOL_ERROR, "Continuation stream ID does not match pending headers. " + "Expected %d, but received %d.", headersContinuation.getStreamId(), streamId); } + + if (headersContinuation.numSmallFragments() >= maxSmallContinuationFrames) { + throw connectionError(ENHANCE_YOUR_CALM, + "Number of small consecutive continuations frames %d exceeds maximum: %d", + headersContinuation.numSmallFragments(), maxSmallContinuationFrames); + } } private void verifyUnknownFrame() throws Http2Exception { @@ -399,7 +415,6 @@ private void verifyUnknownFrame() throws Http2Exception { private void readDataFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener) throws Http2Exception { int padding = readPadding(payload); - verifyPadding(padding); // Determine how much data there is to read by removing the trailing // padding. @@ -414,7 +429,6 @@ private void readHeadersFrame(final ChannelHandlerContext ctx, ByteBuf payload, final int headersStreamId = streamId; final Http2Flags headersFlags = flags; final int padding = readPadding(payload); - verifyPadding(padding); // The callback that is invoked is different depending on whether priority information // is present in the headers frame. @@ -536,7 +550,6 @@ private void readPushPromiseFrame(final ChannelHandlerContext ctx, ByteBuf paylo Http2FrameListener listener) throws Http2Exception { final int pushPromiseStreamId = streamId; final int padding = readPadding(payload); - verifyPadding(padding); final int promisedStreamId = readUnsignedInt(payload); // Create a handler that invokes the listener when the header block is complete. @@ -620,21 +633,19 @@ private int readPadding(ByteBuf payload) { return payload.readUnsignedByte() + 1; } - private void verifyPadding(int padding) throws Http2Exception { - int len = lengthWithoutTrailingPadding(payloadLength, padding); - if (len < 0) { - throw connectionError(PROTOCOL_ERROR, "Frame payload too small for padding."); - } - } - /** * The padding parameter consists of the 1 byte pad length field and the trailing padding bytes. This method * returns the number of readable bytes without the trailing padding. */ - private static int lengthWithoutTrailingPadding(int readableBytes, int padding) { - return padding == 0 - ? readableBytes - : readableBytes - (padding - 1); + private static int lengthWithoutTrailingPadding(int readableBytes, int padding) throws Http2Exception { + if (padding == 0) { + return readableBytes; + } + int n = readableBytes - (padding - 1); + if (n < 0) { + throw connectionError(PROTOCOL_ERROR, "Frame payload too small for padding."); + } + return n; } /** @@ -650,6 +661,15 @@ private abstract class HeadersContinuation { */ abstract int getStreamId(); + /** + * Return the number of fragments that were used so far. + * + * @return the number of fragments + */ + final int numSmallFragments() { + return builder.numSmallFragments(); + } + /** * Processes the next fragment for the current header block. * @@ -678,6 +698,7 @@ final void close() { */ protected class HeadersBlockBuilder { private ByteBuf headerBlock; + private int numSmallFragments; /** * The local header size maximum has been exceeded while accumulating bytes. @@ -688,6 +709,15 @@ private void headerSizeExceeded() throws Http2Exception { headerListSizeExceeded(headersDecoder.configuration().maxHeaderListSizeGoAway()); } + /** + * Return the number of fragments that was used so far. + * + * @return number of fragments. + */ + int numSmallFragments() { + return numSmallFragments; + } + /** * Adds a fragment to the block. * @@ -699,6 +729,11 @@ private void headerSizeExceeded() throws Http2Exception { */ final void addFragment(ByteBuf fragment, int len, ByteBufAllocator alloc, boolean endOfHeaders) throws Http2Exception { + if (maxSmallContinuationFrames > 0 && !endOfHeaders && len < FRAGMENT_THRESHOLD) { + // Only count of the fragment is not the end of header and if its < 8kb. + numSmallFragments++; + } + if (headerBlock == null) { if (len > headersDecoder.configuration().maxHeaderListSizeGoAway()) { headerSizeExceeded(); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java index 73e497ccb8c..c14502b94f9 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java @@ -76,6 +76,8 @@ public DelegatingDecompressorFrameListener(Http2Connection connection, Http2Fram * @param listener the delegate listener used by {@link Http2FrameListenerDecorator} * @param maxAllocation maximum size of the decompression buffer. Must be >= 0. * If zero, maximum size is not limited by decoder. + * Some compression codecs will output buffers up to 64 KiB in size, + * even if {@code maxAllocation} is configured lower. */ public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener, int maxAllocation) { @@ -108,6 +110,8 @@ public DelegatingDecompressorFrameListener(Http2Connection connection, Http2Fram * otherwise the decoder can fallback to {@link ZlibWrapper#NONE} * @param maxAllocation maximum size of the decompression buffer. Must be >= 0. * If zero, maximum size is not limited by decoder. + * Some compression codecs will output buffers up to 64 KiB in size, + * even if {@code maxAllocation} is configured lower. */ public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener, boolean strict, int maxAllocation) { @@ -177,7 +181,7 @@ protected EmbeddedChannel newContentDecompressor(final ChannelHandlerContext ctx } if (Brotli.isAvailable() && BR.contentEqualsIgnoreCase(contentEncoding)) { return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), - ctx.channel().config(), new BrotliDecoder()); + ctx.channel().config(), new BrotliDecoder(maxAllocation)); } if (SNAPPY.contentEqualsIgnoreCase(contentEncoding)) { return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), @@ -185,7 +189,7 @@ protected EmbeddedChannel newContentDecompressor(final ChannelHandlerContext ctx } if (Zstd.isAvailable() && ZSTD.contentEqualsIgnoreCase(contentEncoding)) { return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), - ctx.channel().config(), new ZstdDecoder()); + ctx.channel().config(), new ZstdDecoder(maxAllocation)); } // 'identity' or unsupported return null; @@ -361,16 +365,22 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception buf.release(); return; } - incrementDecompressedBytes(buf.readableBytes()); - // Immediately return the bytes back to the flow controller. ConsumedBytesConverter will convert - // from the decompressed amount which the user knows about to the compressed amount which flow - // control knows about. - connection.local().flowController().consumeBytes(stream, - listener.onDataRead(targetCtx, stream.id(), buf, padding, false)); - padding = 0; // Padding is only communicated once on the first iteration. - buf.release(); - - dataDecompressed = true; + try { + // Also take padding into account. + incrementDecompressedBytes(padding); + + incrementDecompressedBytes(buf.readableBytes()); + // Immediately return the bytes back to the flow controller. ConsumedBytesConverter will convert + // from the decompressed amount which the user knows about to the compressed amount which flow + // control knows about. + connection.local().flowController().consumeBytes(stream, + listener.onDataRead(targetCtx, stream.id(), buf, padding, false)); + padding = 0; // Padding is only communicated once on the first iteration. + + dataDecompressed = true; + } finally { + buf.release(); + } } @Override diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java index f68ad765d84..0c8c400a861 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java @@ -111,11 +111,19 @@ public final class Http2CodecUtil { public static final int DEFAULT_MAX_FRAME_SIZE = MAX_FRAME_SIZE_LOWER_BOUND; /** * The assumed minimum value for {@code SETTINGS_MAX_CONCURRENT_STREAMS} as - * recommended by the HTTP/2 spec. + * recommended by the HTTP/2 spec. */ public static final int SMALLEST_MAX_CONCURRENT_STREAMS = 100; static final int DEFAULT_MAX_RESERVED_STREAMS = SMALLEST_MAX_CONCURRENT_STREAMS; static final int DEFAULT_MIN_ALLOCATION_CHUNK = 1024; + static final int DEFAULT_MAX_SMALL_CONTINUATION_FRAME = 16; + + /** + * While the RFC only specified a minimum we should still pick a default which is good enough that most people + * no need to adjust it but still be somewhat protected. Let's use the minimum + * defined by the HTTP/2 spec. + */ + static final int DEFAULT_MAX_CONCURRENT_STREAMS = SMALLEST_MAX_CONCURRENT_STREAMS; /** * Calculate the threshold in bytes which should trigger a {@code GO_AWAY} if a set of headers exceeds this amount. diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java index 61e9cd1213b..ca494b8de24 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java @@ -221,6 +221,16 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { public boolean prefaceSent() { return true; } + + /** + * Send the preface if needed. + * + * @param ctx the {@link ChannelHandlerContext} to use. + * @throws Exception thrown on error. + */ + public void sendPrefaceIfNeeded(ChannelHandlerContext ctx) throws Exception { + // Noop by default. + } } private final class PrefaceDecoder extends BaseDecoder { @@ -231,7 +241,7 @@ private final class PrefaceDecoder extends BaseDecoder { clientPrefaceString = clientPrefaceString(encoder.connection()); // This handler was just added to the context. In case it was handled after // the connection became active, send the connection preface now. - sendPreface(ctx); + sendPrefaceIfNeeded(ctx); } @Override @@ -248,6 +258,10 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) thro byteDecoder.decode(ctx, in, out); } } catch (Throwable e) { + if (byteDecoder != null) { + // Skip all bytes before we report the exception as + in.skipBytes(in.readableBytes()); + } onError(ctx, false, e); } } @@ -255,14 +269,7 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) thro @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { // The channel just became active - send the connection preface to the remote endpoint. - sendPreface(ctx); - - if (flushPreface) { - // As we don't know if any channelReadComplete() events will be triggered at all we need to ensure we - // also flush. Otherwise the remote peer might never see the preface / settings frame. - // See https://github.com/netty/netty/issues/12089 - ctx.flush(); - } + sendPrefaceIfNeeded(ctx); } @Override @@ -346,19 +353,25 @@ private boolean verifyFirstFrameIsSettings(ByteBuf in) throws Http2Exception { } short frameType = in.getUnsignedByte(in.readerIndex() + 3); - short flags = in.getUnsignedByte(in.readerIndex() + 4); - if (frameType != SETTINGS || (flags & Http2Flags.ACK) != 0) { + if (frameType != SETTINGS) { throw connectionError(PROTOCOL_ERROR, "First received frame was not SETTINGS. " + "Hex dump for first 5 bytes: %s", hexDump(in, in.readerIndex(), 5)); } + short flags = in.getUnsignedByte(in.readerIndex() + 4); + if ((flags & Http2Flags.ACK) != 0) { + throw connectionError(PROTOCOL_ERROR, "First received frame was SETTINGS frame but had ACK flag set. " + + "Hex dump for first 5 bytes: %s", + hexDump(in, in.readerIndex(), 5)); + } return true; } /** * Sends the HTTP/2 connection preface upon establishment of the connection, if not already sent. */ - private void sendPreface(ChannelHandlerContext ctx) throws Exception { + @Override + public void sendPrefaceIfNeeded(ChannelHandlerContext ctx) throws Exception { if (prefaceSent || !ctx.channel().isActive()) { return; } @@ -375,11 +388,20 @@ private void sendPreface(ChannelHandlerContext ctx) throws Exception { encoder.writeSettings(ctx, initialSettings, ctx.newPromise()).addListener( ChannelFutureListener.CLOSE_ON_FAILURE); - if (isClient) { - // If this handler is extended by the user and we directly fire the userEvent from this context then - // the user will not see the event. We should fire the event starting with this handler so this class - // (and extending classes) have a chance to process the event. - userEventTriggered(ctx, Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE); + try { + if (isClient) { + // If this handler is extended by the user and we directly fire the userEvent from this context then + // the user will not see the event. We should fire the event starting with this handler so this + // class (and extending classes) have a chance to process the event. + userEventTriggered(ctx, Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE); + } + } finally { + if (flushPreface) { + // As we don't know if any channelReadComplete() events will be triggered at all we need to ensure + // we also flush. Otherwise the remote peer might never see the preface / settings frame. + // See https://github.com/netty/netty/issues/12089 + ctx.flush(); + } } } } @@ -453,13 +475,19 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t @Override public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { - ctx.bind(localAddress, promise); + // Ensure we send the preface before we notify the bind promise as the user might try to write + // directly in the listener attached to the promise and we need to ensure the preface is always the first + // thing that is written. + ctx.bind(localAddress, ctx.newPromise()).addListener(new PrefaceSendListener(ctx, promise)); } @Override public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) throws Exception { - ctx.connect(remoteAddress, localAddress, promise); + // Ensure we send the preface before we notify the connect promise as the user might try to write + // directly in the listener attached to the promise and we need to ensure the preface is always the first + // thing that is written. + ctx.connect(remoteAddress, localAddress, ctx.newPromise()).addListener(new PrefaceSendListener(ctx, promise)); } @Override @@ -1004,4 +1032,31 @@ private void doClose() { } } } + + private final class PrefaceSendListener implements ChannelFutureListener { + private final ChannelHandlerContext ctx; + private final ChannelPromise promise; + + PrefaceSendListener(ChannelHandlerContext ctx, ChannelPromise promise) { + this.ctx = ctx; + this.promise = promise; + } + + @Override + public void operationComplete(ChannelFuture f) { + if (f.isSuccess()) { + try { + if (byteDecoder != null) { + byteDecoder.sendPrefaceIfNeeded(ctx); + } + } catch (Throwable e) { + promise.setFailure(e); + return; + } + promise.setSuccess(); + } else { + promise.setFailure(f.cause()); + } + } + } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java index d4bd2fe5a3a..2a4a1320d0b 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java @@ -203,6 +203,17 @@ public Http2FrameCodecBuilder encoderEnforceMaxRstFramesPerWindow( return super.encoderEnforceMaxRstFramesPerWindow(maxRstFramesPerWindow, secondsPerWindow); } + @Override + public int decoderEnforceMaxSmallContinuationFrames() { + return super.decoderEnforceMaxSmallContinuationFrames(); + } + + @Override + public Http2FrameCodecBuilder decoderEnforceMaxSmallContinuationFrames( + int maxConsecutiveContinuationsFrames) { + return super.decoderEnforceMaxSmallContinuationFrames(maxConsecutiveContinuationsFrames); + } + /** * Build a {@link Http2FrameCodec} object. */ @@ -216,7 +227,8 @@ public Http2FrameCodec build() { Long maxHeaderListSize = initialSettings().maxHeaderListSize(); Http2FrameReader frameReader = new DefaultHttp2FrameReader(maxHeaderListSize == null ? new DefaultHttp2HeadersDecoder(isValidateHeaders()) : - new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize)); + new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize), + decoderEnforceMaxSmallContinuationFrames()); if (frameLogger() != null) { frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger()); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java index 65a1f471555..945c232b7a1 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java @@ -221,6 +221,17 @@ public Http2MultiplexCodecBuilder encoderEnforceMaxRstFramesPerWindow( return super.encoderEnforceMaxRstFramesPerWindow(maxRstFramesPerWindow, secondsPerWindow); } + @Override + public int decoderEnforceMaxSmallContinuationFrames() { + return super.decoderEnforceMaxSmallContinuationFrames(); + } + + @Override + public Http2MultiplexCodecBuilder decoderEnforceMaxSmallContinuationFrames( + int maxConsecutiveContinuationsFrames) { + return super.decoderEnforceMaxSmallContinuationFrames(maxConsecutiveContinuationsFrames); + } + @Override public Http2MultiplexCodec build() { Http2FrameWriter frameWriter = this.frameWriter; @@ -231,7 +242,8 @@ public Http2MultiplexCodec build() { Long maxHeaderListSize = initialSettings().maxHeaderListSize(); Http2FrameReader frameReader = new DefaultHttp2FrameReader(maxHeaderListSize == null ? new DefaultHttp2HeadersDecoder(isValidateHeaders()) : - new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize)); + new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize), + decoderEnforceMaxSmallContinuationFrames()); if (frameLogger() != null) { frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger()); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Settings.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Settings.java index 11be4139ede..ddc696ad8a7 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Settings.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Settings.java @@ -18,6 +18,7 @@ import io.netty.util.collection.CharObjectHashMap; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_CONCURRENT_STREAMS; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_CONCURRENT_STREAMS; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_FRAME_SIZE_UPPER_BOUND; @@ -303,6 +304,7 @@ protected String keyToString(char key) { } public static Http2Settings defaultSettings() { - return new Http2Settings().maxHeaderListSize(DEFAULT_HEADER_LIST_SIZE); + return new Http2Settings().maxHeaderListSize(DEFAULT_HEADER_LIST_SIZE) + .maxConcurrentStreams(DEFAULT_MAX_CONCURRENT_STREAMS); } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java index e50038a051b..651cce73723 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java @@ -69,6 +69,10 @@ public boolean equals(Object obj) { return id == otherId.id && parentId.equals(otherId.parentId); } + public int getSequenceId() { + return id; + } + @Override public String toString() { return asShortText(); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpConversionUtil.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpConversionUtil.java index 147fb4947d1..cbf66603d31 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpConversionUtil.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpConversionUtil.java @@ -35,6 +35,7 @@ import io.netty.handler.codec.http.HttpVersion; import io.netty.util.AsciiString; import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.StringUtil; import java.net.URI; import java.util.Iterator; @@ -61,13 +62,15 @@ import static io.netty.util.ByteProcessor.FIND_SEMI_COLON; import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.StringUtil.isNullOrEmpty; -import static io.netty.util.internal.StringUtil.length; import static io.netty.util.internal.StringUtil.unescapeCsvFields; /** * Provides utility methods and constants for the HTTP/2 to HTTP conversion */ public final class HttpConversionUtil { + // Parsing logic adapted from Vert.x HttpUtils.parsePath/parseQuery: + // https://github.com/eclipse-vertx/vert.x/blob/98a8ef6c8b408009ff86eb8277fd0bbb2b866857/ + // vertx-core/src/main/java/io/vertx/core/http/impl/HttpUtils.java#L279-L319 /** * The set of headers that should not be directly copied when converting headers from HTTP to HTTP/2. */ @@ -438,11 +441,21 @@ public static Http2Headers toHttp2Headers(HttpMessage in, boolean validateHeader out.path(new AsciiString(request.uri())); setHttp2Scheme(inHeaders, out); } else { - URI requestTargetUri = URI.create(request.uri()); - out.path(toHttp2Path(requestTargetUri)); - // Take from the request-line if HOST header was empty - host = isNullOrEmpty(host) ? requestTargetUri.getAuthority() : host; - setHttp2Scheme(inHeaders, requestTargetUri, out); + String requestTarget = request.uri(); + out.path(toHttp2Path(requestTarget)); + if (hasSchemeAndAuthority(requestTarget)) { + URI requestTargetUri = URI.create(http2PathlessRequestTarget(requestTarget)); + // Take from the request-line if HOST header was empty + host = isNullOrEmpty(host) ? requestTargetUri.getAuthority() : host; + setHttp2Scheme(inHeaders, requestTargetUri, out); + } else { + int schemeEnd = schemeEnd(requestTarget); + if (schemeEnd != -1) { + setHttp2Scheme(inHeaders, requestTarget.substring(0, schemeEnd), -1, out); + } else { + setHttp2Scheme(inHeaders, out); + } + } } setHttp2Authority(host, out); out.method(request.method().asciiName()); @@ -592,25 +605,143 @@ private static void splitValidCookieHeader(Http2Headers out, CharSequence valueC } /** - * Generate an HTTP/2 {code :path} from a URI in accordance with + * Generate an HTTP/2 {code :path} from a request-target in accordance with * rfc7230, 5.3. */ - private static AsciiString toHttp2Path(URI uri) { - StringBuilder pathBuilder = new StringBuilder(length(uri.getRawPath()) + - length(uri.getRawQuery()) + length(uri.getRawFragment()) + 2); - if (!isNullOrEmpty(uri.getRawPath())) { - pathBuilder.append(uri.getRawPath()); + private static AsciiString toHttp2Path(String uri) { + String path = dropEmptyFragment(parsePath(uri)); + String query = parseQuery(uri); + if (isNullOrEmpty(query)) { + return path.isEmpty() ? EMPTY_REQUEST_PATH : new AsciiString(path); + } + StringBuilder pathBuilder = new StringBuilder(path.length() + query.length() + 1); + pathBuilder.append(path); + appendQuery(pathBuilder, query); + return new AsciiString(pathBuilder.toString()); + } + + /** + * Extract the path out of the request-target. Based on Vert.x' HttpUtils.parsePath logic. + */ + private static String parsePath(String uri) { + if (uri.isEmpty()) { + return StringUtil.EMPTY_STRING; + } + int i; + if (uri.charAt(0) == '/') { + i = 0; + } else { + i = uri.indexOf("://"); + // Netty change: validate the scheme before treating :// as authority syntax. + if (!isValidScheme(uri, i)) { + i = 0; + } else { + int authorityStart = i + 3; + // Netty change: only accept '/' before query/fragment as path start. + int queryOrFragmentStart = queryOrFragmentStart(uri, authorityStart); + i = uri.indexOf('/', authorityStart); + if (i == -1 || (queryOrFragmentStart != -1 && queryOrFragmentStart < i)) { + // contains no / + return "/"; + } + } } - if (!isNullOrEmpty(uri.getRawQuery())) { - pathBuilder.append('?'); - pathBuilder.append(uri.getRawQuery()); + + int queryStart = uri.indexOf('?', i); + if (queryStart == -1) { + queryStart = uri.length(); + if (i == 0) { + return uri; + } } - if (!isNullOrEmpty(uri.getRawFragment())) { - pathBuilder.append('#'); - pathBuilder.append(uri.getRawFragment()); + return uri.substring(i, queryStart); + } + + /** + * Extract the query out of a request-target or returns {@code null} if no query was found. + */ + private static String parseQuery(String uri) { + int i = uri.indexOf('?'); + if (i == -1) { + return null; + } else { + return uri.substring(i + 1); } - String path = pathBuilder.toString(); - return path.isEmpty() ? EMPTY_REQUEST_PATH : new AsciiString(path); + } + + private static String dropEmptyFragment(String path) { + // Netty change: old URI-based conversion dropped an empty fragment delimiter. + return path.endsWith("#") ? path.substring(0, path.length() - 1) : path; + } + + private static void appendQuery(StringBuilder pathBuilder, String query) { + int fragmentStart = query.indexOf('#'); + if (fragmentStart == 0) { + // Netty change: old URI-based conversion skipped an empty query before a fragment. + pathBuilder.append(query); + } else if (fragmentStart == query.length() - 1) { + // Netty change: old URI-based conversion dropped an empty fragment delimiter after a query. + pathBuilder.append('?').append(query, 0, fragmentStart); + } else { + pathBuilder.append('?').append(query); + } + } + + static int queryOrFragmentStart(String uri, int searchStart) { + int queryStart = uri.indexOf('?', searchStart); + int fragmentStart = uri.indexOf('#', searchStart); + return queryStart == -1 ? fragmentStart : + fragmentStart == -1 ? queryStart : Math.min(queryStart, fragmentStart); + } + + // Netty addition: detect authority for HTTP/2 :scheme/:authority extraction. + static boolean hasSchemeAndAuthority(String requestTarget) { + int schemeEnd = requestTarget.indexOf("://"); + return isValidScheme(requestTarget, schemeEnd); + } + + private static int schemeEnd(String requestTarget) { + int schemeEnd = requestTarget.indexOf(':'); + return isValidScheme(requestTarget, schemeEnd) ? schemeEnd : -1; + } + + // Netty addition: prepare only scheme://authority for URI validation. + private static String http2PathlessRequestTarget(String requestTarget) { + int schemeEnd = requestTarget.indexOf("://"); + int authorityStart = schemeEnd + 3; + // Netty addition: strip before path/query/fragment; Vert.x parsePath does not validate authority. + int pathStart = requestTarget.indexOf('/', authorityStart); + int delimiter = queryOrFragmentStart(requestTarget, authorityStart); + if (pathStart != -1 && (delimiter == -1 || pathStart < delimiter)) { + delimiter = pathStart; + } + if (delimiter == -1) { + return requestTarget; + } + return delimiter == authorityStart ? requestTarget.substring(0, delimiter + 1) : + requestTarget.substring(0, delimiter); + } + + // Netty addition: validate the text before :// as a scheme. + static boolean isValidScheme(String uri, int schemeEnd) { + if (schemeEnd <= 0) { + return false; + } + char first = uri.charAt(0); + if (!isAlpha(first)) { + return false; + } + for (int i = 1; i < schemeEnd; ++i) { + char c = uri.charAt(i); + if (!isAlpha(c) && (c < '0' || c > '9') && c != '+' && c != '-' && c != '.') { + return false; + } + } + return true; + } + + private static boolean isAlpha(char c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); } // package-private for testing only @@ -635,9 +766,12 @@ private static void setHttp2Scheme(HttpHeaders in, Http2Headers out) { } private static void setHttp2Scheme(HttpHeaders in, URI uri, Http2Headers out) { - String value = uri.getScheme(); - if (!isNullOrEmpty(value)) { - out.scheme(new AsciiString(value)); + setHttp2Scheme(in, uri.getScheme(), uri.getPort(), out); + } + + private static void setHttp2Scheme(HttpHeaders in, String scheme, int port, Http2Headers out) { + if (!isNullOrEmpty(scheme)) { + out.scheme(new AsciiString(scheme)); return; } @@ -648,9 +782,9 @@ private static void setHttp2Scheme(HttpHeaders in, URI uri, Http2Headers out) { return; } - if (uri.getPort() == HTTPS.port()) { + if (port == HTTPS.port()) { out.scheme(HTTPS.name()); - } else if (uri.getPort() == HTTP.port()) { + } else if (port == HTTP.port()) { out.scheme(HTTP.name()); } else { throw new IllegalArgumentException(":scheme must be specified. " + diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java index 638b115949e..f6cf383318c 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java @@ -24,9 +24,10 @@ import io.netty.handler.codec.http.HttpStatusClass; import io.netty.handler.codec.http.HttpUtil; -import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.ENHANCE_YOUR_CALM; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkPositive; @@ -233,7 +234,7 @@ public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int ByteBuf content = msg.content(); final int dataReadableBytes = data.readableBytes(); if (content.readableBytes() > maxContentLength - dataReadableBytes) { - throw connectionError(INTERNAL_ERROR, + throw streamError(streamId, ENHANCE_YOUR_CALM, "Content length exceeded max of %d for stream id %d", maxContentLength, streamId); } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java index 9bece58a3ba..fa11aaead85 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java @@ -29,6 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.compression.Brotli; +import io.netty.handler.codec.compression.DecompressionException; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable; @@ -36,10 +37,14 @@ import io.netty.util.CharsetUtil; import io.netty.util.NetUtil; import io.netty.util.concurrent.Future; +import io.netty.util.internal.PlatformDependent; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; @@ -48,15 +53,25 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; import java.util.Random; +import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.stream.Stream; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeFalse; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyInt; @@ -89,6 +104,9 @@ public class DataCompressionHttp2Test { private Http2Connection clientConnection; private Http2ConnectionHandler clientHandler; private ByteArrayOutputStream serverOut; + private int maxServerOutBufferSize; + private int maxAllocation; + private final AtomicReference serverException = new AtomicReference(); @BeforeAll public static void beforeAllTests() throws Throwable { @@ -97,6 +115,7 @@ public static void beforeAllTests() throws Throwable { @BeforeEach public void setup() throws InterruptedException, Http2Exception { + maxAllocation = 0; MockitoAnnotations.initMocks(this); doAnswer(new Answer() { @Override @@ -122,7 +141,9 @@ public Void answer(InvocationOnMock invocation) throws Throwable { @AfterEach public void cleanup() throws IOException { - serverOut.close(); + if (serverOut != null) { + serverOut.close(); + } } @AfterEach @@ -140,16 +161,21 @@ public void teardown() throws InterruptedException { serverConnectedChannel.close().sync(); this.serverConnectedChannel = null; } - Future serverGroup = sb.config().group().shutdownGracefully(0, 0, MILLISECONDS); - Future serverChildGroup = sb.config().childGroup().shutdownGracefully(0, 0, MILLISECONDS); - Future clientGroup = cb.config().group().shutdownGracefully(0, 0, MILLISECONDS); - serverGroup.sync(); - serverChildGroup.sync(); - clientGroup.sync(); + if (sb != null) { + Future serverGroup = sb.config().group().shutdownGracefully(0, 0, MILLISECONDS); + Future serverChildGroup = sb.config().childGroup().shutdownGracefully(0, 0, MILLISECONDS); + serverGroup.sync(); + serverChildGroup.sync(); + } + if (cb != null) { + Future clientGroup = cb.config().group().shutdownGracefully(0, 0, MILLISECONDS); + clientGroup.sync(); + } } - @Test - public void justHeadersNoData() throws Exception { + @ParameterizedTest + @ValueSource(ints = { 0, 10 }) + public void justHeadersNoData(final int padding) throws Exception { bootstrapEnv(0); final Http2Headers headers = new DefaultHttp2Headers().method(GET).path(PATH) .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); @@ -157,65 +183,49 @@ public void justHeadersNoData() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); awaitServer(); verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), - eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true)); + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(padding), eq(true)); } - @Test - public void gzipEncodingSingleEmptyMessage() throws Exception { - final String text = ""; - final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); - bootstrapEnv(data.readableBytes()); - try { - final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); - - runInChannel(clientChannel, new Http2Runnable() { - @Override - public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); - clientHandler.flush(ctxClient()); - } - }); - awaitServer(); - assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); - } finally { - data.release(); + public static List paddingAndCompression() { + List arguments = new ArrayList(); + for (int padding : new int[]{0, 10}) { + for (AsciiString compression : new AsciiString[]{ + HttpHeaderValues.GZIP, HttpHeaderValues.BR, HttpHeaderValues.ZSTD, HttpHeaderValues.SNAPPY}) { + final Object[] args = {padding, compression}; + arguments.add(new Arguments() { + @Override + public Object[] get() { + return args; + } + }); + } } + return arguments; } - @Test - public void gzipEncodingSingleMessage() throws Exception { - final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; - final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); - bootstrapEnv(data.readableBytes()); - try { - final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); + @ParameterizedTest + @MethodSource("paddingAndCompression") + public void encodingSingleEmptyMessage(final int padding, AsciiString compressionAlgorithm) throws Exception { + final String text = ""; + testEncodingMessage(padding, text, compressionAlgorithm); + } - runInChannel(clientChannel, new Http2Runnable() { - @Override - public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); - clientHandler.flush(ctxClient()); - } - }); - awaitServer(); - assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); - } finally { - data.release(); - } + @ParameterizedTest + @MethodSource("paddingAndCompression") + public void encodingSingleMessage(final int padding, AsciiString compressionAlgorithm) throws Exception { + final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; + testEncodingMessage(padding, text, compressionAlgorithm); } - @Test - public void gzipEncodingMultipleMessages() throws Exception { + @ParameterizedTest + @MethodSource("paddingAndCompression") + public void encodingMultipleMessages(final int padding, AsciiString compressionAlgorithm) throws Exception { final String text1 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; final String text2 = "dddddddddddddddddddeeeeeeeeeeeeeeeeeeeffffffffffffffffffff"; final ByteBuf data1 = Unpooled.copiedBuffer(text1.getBytes()); @@ -223,171 +233,103 @@ public void gzipEncodingMultipleMessages() throws Exception { bootstrapEnv(data1.readableBytes() + data2.readableBytes()); try { final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); + .set(HttpHeaderNames.CONTENT_ENCODING, compressionAlgorithm); runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data1.retain(), 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data2.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data1.retain(), padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data2.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); awaitServer(); - assertEquals(text1 + text2, serverOut.toString(CharsetUtil.UTF_8.name())); + assertEquals(text1 + text2, serverOut.toString(CharsetUtil.ISO_8859_1.name())); } finally { data1.release(); data2.release(); } } - @Test - public void brotliEncodingSingleEmptyMessage() throws Exception { - final String text = ""; - final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); - bootstrapEnv(data.readableBytes()); - try { - final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.BR); - - runInChannel(clientChannel, new Http2Runnable() { - @Override - public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); - clientHandler.flush(ctxClient()); - } - }); - awaitServer(); - assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); - } finally { - data.release(); - } - } - - @Test - public void brotliEncodingSingleMessage() throws Exception { - final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; - final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.UTF_8.name())); - bootstrapEnv(data.readableBytes()); - try { - final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.BR); + @ParameterizedTest + @MethodSource("paddingAndCompression") + public void encodingTooBigMessage(final int padding, AsciiString compressionAlgorithm) throws Exception { + // Make the compressed message produce half a megabyte of text, then limit the output buffer size to 64 KiB. + byte[] text = PlatformDependent.allocateUninitializedArray(524288); + final int inputLength = text.length; + maxAllocation = inputLength / 8; - runInChannel(clientChannel, new Http2Runnable() { - @Override - public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); - clientHandler.flush(ctxClient()); - } - }); - awaitServer(); - assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); - } finally { - data.release(); - } - } - - @Test - public void zstdEncodingSingleEmptyMessage() throws Exception { - final String text = ""; - final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); - bootstrapEnv(data.readableBytes()); - try { - final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.ZSTD); - - runInChannel(clientChannel, new Http2Runnable() { - @Override - public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); - clientHandler.flush(ctxClient()); + testEncodingMessage(padding, text, compressionAlgorithm, new Callable() { + @Override + public Void call() throws Exception { + assertTrue(serverLatch.await(5, SECONDS)); + serverOut.flush(); + Throwable cause = serverException.get(); + if (cause == null) { + // Compression codec must have mitigations + assertThat(maxServerOutBufferSize) + .as("check that the original string of size %s, " + + "got compressed and decompressed into max %s sized buffers", + inputLength, maxAllocation) + .isLessThanOrEqualTo(maxAllocation); + } else { + // Compression codec must reject + assertThat(cause) + .isInstanceOf(Http2Exception.StreamException.class) + .rootCause() + .isInstanceOf(DecompressionException.class) + .hasMessageContaining("maximum size"); } - }); - awaitServer(); - assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); - } finally { - data.release(); - } + return null; + } + }); } - @Test - public void zstdEncodingSingleMessage() throws Exception { - final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; - final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.UTF_8.name())); - bootstrapEnv(data.readableBytes()); - try { - final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.ZSTD); - - runInChannel(clientChannel, new Http2Runnable() { - @Override - public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); - clientHandler.flush(ctxClient()); - } - }); - awaitServer(); - assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); - } finally { - data.release(); - } + private void testEncodingMessage(final int padding, final String text, AsciiString compressionAlgorithmName) + throws Exception { + testEncodingMessage(padding, text, compressionAlgorithmName, new Callable() { + @Override + public Void call() throws Exception { + awaitServer(); + assertEquals(text, serverOut.toString(CharsetUtil.ISO_8859_1.name())); + return null; + } + }); } - @Test - public void snappyEncodingSingleEmptyMessage() throws Exception { - final String text = ""; - final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.US_ASCII)); - bootstrapEnv(data.readableBytes()); - try { - final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.SNAPPY); - - runInChannel(clientChannel, new Http2Runnable() { - @Override - public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); - clientHandler.flush(ctxClient()); - } - }); - awaitServer(); - assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); - } finally { - data.release(); - } + private void testEncodingMessage(int padding, String text, AsciiString compressionAlgorithmName, + Callable assertions) throws Exception { + testEncodingMessage(padding, text.getBytes(CharsetUtil.ISO_8859_1), compressionAlgorithmName, assertions); } - @Test - public void snappyEncodingSingleMessage() throws Exception { - final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; - final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.US_ASCII)); + private void testEncodingMessage(final int padding, + final byte[] text, + final AsciiString compressionAlgorithmName, + final Callable assertions) throws Exception { + final ByteBuf data = Unpooled.copiedBuffer(text); bootstrapEnv(data.readableBytes()); try { final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) - .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.SNAPPY); + .set(HttpHeaderNames.CONTENT_ENCODING, compressionAlgorithmName); runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); - awaitServer(); - assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); + assertions.call(); } finally { data.release(); } } - @Test - public void deflateEncodingWriteLargeMessage() throws Exception { + @ParameterizedTest + @MethodSource("paddingAndCompression") + public void deflateEncodingWriteLargeMessage(final int padding) throws Exception { final int BUFFER_SIZE = 1 << 12; final byte[] bytes = new byte[BUFFER_SIZE]; new Random().nextBytes(bytes); @@ -400,14 +342,14 @@ public void deflateEncodingWriteLargeMessage() throws Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); - clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientEncoder.writeHeaders(ctxClient(), 3, headers, padding, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), padding, true, newPromiseClient()); clientHandler.flush(ctxClient()); } }); awaitServer(); - assertEquals(data.resetReaderIndex().toString(CharsetUtil.UTF_8), - serverOut.toString(CharsetUtil.UTF_8.name())); + assertEquals(data.resetReaderIndex().toString(CharsetUtil.ISO_8859_1), + serverOut.toString(CharsetUtil.ISO_8859_1.name())); } finally { data.release(); } @@ -417,6 +359,7 @@ private void bootstrapEnv(int serverOutSize) throws Exception { final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1); serverOut = new ByteArrayOutputStream(serverOutSize); serverLatch = new CountDownLatch(1); + serverException.set(null); sb = new ServerBootstrap(); cb = new Bootstrap(); @@ -438,10 +381,14 @@ public Integer answer(InvocationOnMock in) throws Throwable { int padding = (Integer) in.getArguments()[3]; int processedBytes = buf.readableBytes() + padding; + maxServerOutBufferSize = Math.max(maxServerOutBufferSize, buf.readableBytes()); buf.readBytes(serverOut, buf.readableBytes()); if (in.getArgument(4)) { - serverConnection.stream((Integer) in.getArgument(1)).close(); + Http2Stream stream = serverConnection.stream((Integer) in.getArgument(1)); + if (stream != null) { + stream.close(); + } } return processedBytes; } @@ -466,7 +413,19 @@ protected void initChannel(Channel ch) throws Exception { Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(serverConnection, encoder, new DefaultHttp2FrameReader()); Http2ConnectionHandler connectionHandler = new Http2ConnectionHandlerBuilder() - .frameListener(new DelegatingDecompressorFrameListener(serverConnection, serverListener, 0)) + .frameListener(new DelegatingDecompressorFrameListener(serverConnection, serverListener, + maxAllocation) { + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, + int padding, boolean endOfStream) throws Http2Exception { + try { + return super.onDataRead(ctx, streamId, data, padding, endOfStream); + } catch (Http2Exception e) { + serverException.set(e); + throw e; + } + } + }) .codec(decoder, encoder).build(); p.addLast(connectionHandler); serverChannelLatch.countDown(); @@ -521,6 +480,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc private void awaitServer() throws Exception { assertTrue(serverLatch.await(5, SECONDS)); serverOut.flush(); + Throwable cause = serverException.get(); + if (cause != null) { + throw new AssertionError("Server-side decompression error", cause); + } } private ChannelHandlerContext ctxClient() { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java index 4b562221b12..b2039a0846b 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java @@ -486,7 +486,7 @@ public Void answer(InvocationOnMock in) throws Throwable { @Override public Integer answer(InvocationOnMock in) throws Throwable { localFlow.consumeBytes(stream, 4); - throw new RuntimeException("Fake Exception"); + throw Http2TestUtil.FAKE_EXCEPTION; } }).when(listener).onDataRead(eq(ctx), eq(STREAM_ID), any(ByteBuf.class), eq(10), eq(true)); try { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java index a3d3ea07896..60ba05c0693 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java @@ -352,7 +352,7 @@ public void emptyFrameShouldSplitPadding() throws Exception { @Test public void writeHeadersUsingVoidPromise() throws Exception { - final Throwable cause = new RuntimeException("fake exception"); + final Throwable cause = Http2TestUtil.FAKE_EXCEPTION; when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), anyInt(), anyBoolean(), any(ChannelPromise.class))) .then(new Answer() { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java index 1574d16828d..2e3ab733e0d 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java @@ -18,6 +18,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http2.Http2Connection.Endpoint; import io.netty.handler.codec.http2.Http2Stream.State; import io.netty.util.concurrent.Future; @@ -34,11 +35,13 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import javax.annotation.Nonnull; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import static java.lang.Integer.MAX_VALUE; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -709,6 +712,183 @@ public void execute() throws Throwable { }); } + @Test + public void defaultSettingsShouldEnforceMaxConcurrentStreamsOnRemoteEndpoint() throws Exception { + // Build a server handler using default settings (no explicit maxConcurrentStreams override). + final Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() + .frameListener(new Http2FrameAdapter()) + .build(); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + try { + // Feed the client connection preface. + assertFalse(channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf())); + + ByteBuf clientSettings = clientSettingsWithoutMaxConcurrentStreams(); + assertFalse(channel.writeInbound(clientSettings)); + + Http2Connection connection = handler.connection(); + + // The server's default (SMALLEST_MAX_CONCURRENT_STREAMS = 100) must be + // enforced on the remote endpoint — the one that tracks client-initiated streams. + assertEquals(Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS, + connection.remote().maxActiveStreams()); + + // Create exactly the maximum allowed client-initiated (odd-numbered) streams. + for (int id = 1; id < Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS * 2; id += 2) { + connection.remote().createStream(id, true); + } + assertEquals(Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS, + connection.numActiveStreams()); + + // The next stream must be refused. + final int nextStreamId = Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS * 2 + 1; + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + handler.connection().remote().createStream(nextStreamId, true); + } + }); + assertEquals(Http2Error.REFUSED_STREAM, e.error()); + } finally { + channel.finishAndReleaseAll(); + } + } + + @Test + public void customMaxConcurrentStreamsShouldBeEnforcedOnRemoteEndpoint() throws Exception { + final int maxConcurrentStreams = 150; + + final Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() + .frameListener(new Http2FrameAdapter()) + .initialSettings(new Http2Settings().maxConcurrentStreams(maxConcurrentStreams)) + .build(); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + try { + assertFalse(channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf())); + + ByteBuf clientSettings = clientSettingsWithoutMaxConcurrentStreams(); + assertFalse(channel.writeInbound(clientSettings)); + + Http2Connection connection = handler.connection(); + + assertEquals(maxConcurrentStreams, connection.remote().maxActiveStreams()); + + // Create exactly the configured limit of client-initiated streams. + for (int id = 1; id < maxConcurrentStreams * 2; id += 2) { + connection.remote().createStream(id, true); + } + assertEquals(maxConcurrentStreams, connection.numActiveStreams()); + + // The next stream must be refused. + final int nextStreamId = maxConcurrentStreams * 2 + 1; + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + handler.connection().remote().createStream(nextStreamId, true); + } + }); + assertEquals(Http2Error.REFUSED_STREAM, e.error()); + } finally { + channel.finishAndReleaseAll(); + } + } + + @Test + public void defaultSettingsShouldEnforceMaxConcurrentStreamsOnRemoteEndpointWithCodec() + throws Exception { + final DefaultHttp2Connection connection = new DefaultHttp2Connection(true); + DefaultHttp2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); + Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); + Http2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(connection, encoder, new DefaultHttp2FrameReader()); + + Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() + .frameListener(new Http2FrameAdapter()) + .codec(decoder, encoder) + .build(); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + try { + assertFalse(channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf())); + assertFalse(channel.writeInbound(clientSettingsWithoutMaxConcurrentStreams())); + + assertEquals(Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS, + connection.remote().maxActiveStreams()); + + for (int id = 1; id < Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS * 2; id += 2) { + connection.remote().createStream(id, true); + } + assertEquals(Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS, connection.numActiveStreams()); + + final int nextStreamId = Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS * 2 + 1; + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + connection.remote().createStream(nextStreamId, true); + } + }); + assertEquals(Http2Error.REFUSED_STREAM, e.error()); + } finally { + channel.finishAndReleaseAll(); + } + } + + @Test + public void customMaxConcurrentStreamsShouldBeEnforcedOnRemoteEndpointWithCodec() + throws Exception { + final int maxConcurrentStreams = 150; + + final DefaultHttp2Connection connection = new DefaultHttp2Connection(true); + DefaultHttp2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); + Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); + Http2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(connection, encoder, new DefaultHttp2FrameReader()); + + Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() + .frameListener(new Http2FrameAdapter()) + .initialSettings(new Http2Settings().maxConcurrentStreams(maxConcurrentStreams)) + .codec(decoder, encoder) + .build(); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + try { + assertFalse(channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf())); + assertFalse(channel.writeInbound(clientSettingsWithoutMaxConcurrentStreams())); + + assertEquals(maxConcurrentStreams, connection.remote().maxActiveStreams()); + + for (int id = 1; id < maxConcurrentStreams * 2; id += 2) { + connection.remote().createStream(id, true); + } + assertEquals(maxConcurrentStreams, connection.numActiveStreams()); + + final int nextStreamId = maxConcurrentStreams * 2 + 1; + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + connection.remote().createStream(nextStreamId, true); + } + }); + assertEquals(Http2Error.REFUSED_STREAM, e.error()); + } finally { + channel.finishAndReleaseAll(); + } + } + + @Nonnull + private static ByteBuf clientSettingsWithoutMaxConcurrentStreams() { + ByteBuf clientSettings = Unpooled.buffer(); + clientSettings.writeMedium(6); // Payload length: one 6-byte setting + clientSettings.writeByte(0x4); // Frame type: SETTINGS + clientSettings.writeByte(0x0); // Flags + clientSettings.writeInt(0x0); // Stream 0 + clientSettings.writeShort(0x4); // SETTINGS_INITIAL_WINDOW_SIZE + clientSettings.writeInt(65535); + return clientSettings; + } + private static void incrementAndGetStreamShouldSucceed(Endpoint endpoint) throws Http2Exception { Http2Stream streamA = endpoint.createStream(endpoint.incrementAndGetNextStreamId(), true); Http2Stream streamB = endpoint.createStream(streamA.id() + 2, true); @@ -718,7 +898,6 @@ private static void incrementAndGetStreamShouldSucceed(Endpoint endpoint) thr } private static final class ListenerExceptionThrower implements Answer { - private static final RuntimeException FAKE_EXCEPTION = new RuntimeException("Fake Exception"); private final boolean[] array; private final int index; @@ -730,7 +909,7 @@ private static final class ListenerExceptionThrower implements Answer { @Override public Void answer(InvocationOnMock invocation) throws Throwable { array[index] = true; - throw FAKE_EXCEPTION; + throw Http2TestUtil.FAKE_EXCEPTION; } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java index 35863d6c06e..97117dac95e 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java @@ -109,6 +109,59 @@ public void readHeaderFrameAndContinuationFrame() throws Http2Exception { } } + @Test + public void readHeaderFrameAndContinuationFrameExceedMax() throws Http2Exception { + frameReader = new DefaultHttp2FrameReader(new DefaultHttp2HeadersDecoder(true), 2); + final int streamId = 1; + + final ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + writeHeaderFrame(input, streamId, headers, + new Http2Flags().endOfHeaders(false).endOfStream(true)); + writeContinuationFrame(input, streamId, new DefaultHttp2Headers().add("foo", "bar"), + new Http2Flags().endOfHeaders(false)); + writeContinuationFrame(input, streamId, new DefaultHttp2Headers().add("foo2", "bar2"), + new Http2Flags().endOfHeaders(false)); + + Http2Exception ex = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + assertEquals(Http2Error.ENHANCE_YOUR_CALM, ex.error()); + } finally { + input.release(); + } + } + + @Test + public void readHeaderFrameAndContinuationFrameDontExceedMax() throws Http2Exception { + frameReader = new DefaultHttp2FrameReader(new DefaultHttp2HeadersDecoder(true), 2); + final int streamId = 1; + + final ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + writeHeaderFrame(input, streamId, headers, + new Http2Flags().endOfHeaders(false).endOfStream(true)); + writeContinuationFrame(input, streamId, new DefaultHttp2Headers().add("foo", "bar"), + new Http2Flags().endOfHeaders(false)); + frameReader.readFrame(ctx, input, listener); + } finally { + input.release(); + } + } + @Test public void readUnknownFrame() throws Http2Exception { ByteBuf input = Unpooled.buffer(); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java index 04acf60d55f..23bc51786ce 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java @@ -29,6 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -66,7 +67,7 @@ protected void initChannel(SocketChannel ch) { } }); - ChannelFuture channelFuture = serverBootstrap.bind(0).sync(); + ChannelFuture channelFuture = serverBootstrap.bind(NetUtil.LOCALHOST, 0).sync(); final Bootstrap bootstrap = new Bootstrap() .group(eventLoopGroup) diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java index 0ded0e1ef38..1b77956a82c 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java @@ -53,7 +53,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; /** @@ -140,7 +140,7 @@ public void windowUpdateShouldChangeConnectionWindow() throws Http2Exception { assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -151,7 +151,7 @@ public void windowUpdateShouldChangeStreamWindow() throws Http2Exception { assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -159,10 +159,10 @@ public void payloadSmallerThanWindowShouldBeWrittenImmediately() throws Http2Exc FakeFlowControlled data = new FakeFlowControlled(5); sendData(STREAM_A, data); data.assertNotWritten(); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); controller.writePendingBytes(); data.assertFullyWritten(); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -172,7 +172,7 @@ public void emptyPayloadShouldBeWrittenImmediately() throws Http2Exception { data.assertNotWritten(); controller.writePendingBytes(); data.assertFullyWritten(); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -238,7 +238,7 @@ public void stalledStreamShouldQueuePayloads() throws Http2Exception { sendData(STREAM_A, moreData); controller.writePendingBytes(); moreData.assertNotWritten(); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -260,7 +260,7 @@ public void queuedPayloadsReceiveErrorOnStreamClose() throws Http2Exception { connection.stream(STREAM_A).close(); data.assertError(Http2Error.STREAM_CLOSED); moreData.assertError(Http2Error.STREAM_CLOSED); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test @@ -724,11 +724,10 @@ public Void answer(InvocationOnMock invocationOnMock) { public void flowControlledWriteAndErrorThrowAnException() throws Exception { final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite(); final Http2Stream stream = stream(STREAM_A); - final RuntimeException fakeException = new RuntimeException("error failed"); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocationOnMock) { - throw fakeException; + throw Http2TestUtil.FAKE_EXCEPTION; } }).when(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); @@ -741,14 +740,14 @@ public void execute() throws Throwable { controller.writePendingBytes(); } }); - assertSame(fakeException, e.getCause()); + assertSame(Http2TestUtil.FAKE_EXCEPTION, e.getCause()); verify(flowControlled, atLeastOnce()).write(any(ChannelHandlerContext.class), anyInt()); verify(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); verify(flowControlled, never()).writeComplete(); assertEquals(90, windowBefore - window(STREAM_A)); - verifyZeroInteractions(listener); + verifyNoInteractions(listener); } @Test diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/H2PrefaceTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/H2PrefaceTest.java new file mode 100644 index 00000000000..b3165ea0fc9 --- /dev/null +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/H2PrefaceTest.java @@ -0,0 +1,196 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.NetUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class H2PrefaceTest { + + enum OpenMode { + Blocked, + Listener, + SubmitInListener + } + + @ParameterizedTest + @EnumSource(OpenMode.class) + void openStreamAfterBlockingConnect(OpenMode mode) throws Exception { + final StreamRequestResponseListener streamRequestResponseListener = new StreamRequestResponseListener(); + EventLoopGroup eventLoopGroup = new NioEventLoopGroup(); + + Channel backend = new ServerBootstrap() + .group(eventLoopGroup) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(final SocketChannel ch) { + ch.pipeline().addLast(Http2FrameCodecBuilder.forServer().build()); + ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInitializer() { + @Override + protected void initChannel(final Http2StreamChannel ch) { + ch.pipeline().addLast(new H2ServerHandler()); + } + })); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + streamRequestResponseListener.responseHeaders.completeExceptionally(cause); + } + }); + } + }) + .bind(NetUtil.LOCALHOST, 0) + .sync() + .channel(); + + ChannelFuture cf = new Bootstrap() + .group(eventLoopGroup) + .channel(NioSocketChannel.class) + .remoteAddress(backend.localAddress()) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(final SocketChannel ch) { + ch.pipeline().addLast(Http2FrameCodecBuilder.forClient() + .initialSettings(Http2Settings.defaultSettings()) + .build()); + ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInboundHandlerAdapter())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + streamRequestResponseListener.responseHeaders.completeExceptionally(cause); + } + }); + } + }).connect(); + final Channel channel = cf.channel(); + try { + final Http2StreamChannelBootstrap streamChannelBootstrap = new Http2StreamChannelBootstrap(channel); + + switch (mode) { + case Blocked: + cf.syncUninterruptibly(); + streamChannelBootstrap.open().addListener(streamRequestResponseListener); + break; + case Listener: + cf.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + streamChannelBootstrap.open().addListener(streamRequestResponseListener); + } + }); + + break; + case SubmitInListener: + cf.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + channel.eventLoop().submit(new Runnable() { + @Override + public void run() { + streamChannelBootstrap.open().addListener(streamRequestResponseListener); + } + }); + } + }); + break; + default: + throw new AssertionError(); + } + + assertEquals("200", streamRequestResponseListener.responseHeaders.get( + 5, TimeUnit.SECONDS).headers().status().toString()); + } finally { + channel.close().sync(); + backend.close().sync(); + eventLoopGroup.shutdownGracefully().sync(); + } + } + + private static class H2ServerHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + if (msg instanceof Http2HeadersFrame) { + final Http2Headers responseHeaders = new DefaultHttp2Headers().status("200"); + ctx.write(new DefaultHttp2HeadersFrame(responseHeaders, false)); + ctx.writeAndFlush( + new DefaultHttp2DataFrame(Unpooled.copiedBuffer("hello world", StandardCharsets.UTF_8), true)); + } + } + } + + /// Send a request and wait for a response once an Http2StreamChannel is established + private static final class StreamRequestResponseListener implements + GenericFutureListener> { + private final CompletableFuture responseHeaders = + new CompletableFuture(); + + @Override + public void operationComplete(final Future future) { + final Http2StreamChannel streamChannel = future.getNow(); + streamChannel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof Http2HeadersFrame) { + responseHeaders.complete((Http2HeadersFrame) msg); + } + ctx.fireChannelRead(msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + responseHeaders.completeExceptionally(cause); + } + }); + final Http2Headers headers = new DefaultHttp2Headers() + .method("GET") + .path("/test") + .scheme("http"); + final Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame(headers, true); + streamChannel.writeAndFlush(headersFrame).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture f) throws Exception { + if (!f.isSuccess()) { + responseHeaders.completeExceptionally(f.cause()); + } + } + }); + } + } +} diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java index 4c48e2780dc..db57f276cfd 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java @@ -53,10 +53,13 @@ import java.util.concurrent.atomic.AtomicBoolean; import static io.netty.buffer.Unpooled.copiedBuffer; +import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; +import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeaderInternal; import static io.netty.handler.codec.http2.Http2Error.CANCEL; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED; +import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS; import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED; import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; import static io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise; @@ -78,7 +81,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; /** @@ -154,7 +157,7 @@ public void setup() throws Exception { DefaultChannelConfig config = new DefaultChannelConfig(channel); when(channel.config()).thenReturn(config); - Throwable fakeException = new RuntimeException("Fake exception"); + Throwable fakeException = Http2TestUtil.FAKE_EXCEPTION; when(encoder.connection()).thenReturn(connection); when(decoder.connection()).thenReturn(connection); when(encoder.frameWriter()).thenReturn(frameWriter); @@ -303,9 +306,11 @@ public void clientShouldSendClientPrefaceStringWhenActive() throws Exception { when(connection.isServer()).thenReturn(false); when(channel.isActive()).thenReturn(false); handler = newHandler(); + verify(ctx, never()).flush(); when(channel.isActive()).thenReturn(true); handler.channelActive(ctx); verify(ctx).write(eq(connectionPrefaceBuf())); + verify(ctx).flush(); } @Test @@ -313,9 +318,29 @@ public void serverShouldNotSendClientPrefaceStringWhenActive() throws Exception when(connection.isServer()).thenReturn(true); when(channel.isActive()).thenReturn(false); handler = newHandler(); + verify(ctx, never()).flush(); when(channel.isActive()).thenReturn(true); handler.channelActive(ctx); verify(ctx, never()).write(eq(connectionPrefaceBuf())); + verify(ctx).flush(); + } + + @Test + public void clientShouldSendClientPrefaceStringWhenAddedAfterActive() throws Exception { + when(connection.isServer()).thenReturn(false); + when(channel.isActive()).thenReturn(true); + handler = newHandler(); + verify(ctx).write(eq(connectionPrefaceBuf())); + verify(ctx).flush(); + } + + @Test + public void serverShouldNotSendClientPrefaceStringWhenAddedAfterActive() throws Exception { + when(connection.isServer()).thenReturn(true); + when(channel.isActive()).thenReturn(true); + handler = newHandler(); + verify(ctx, never()).write(eq(connectionPrefaceBuf())); + verify(ctx).flush(); } @Test @@ -329,6 +354,20 @@ public void serverReceivingInvalidClientPrefaceStringShouldHandleException() thr assertEquals(0, captor.getValue().refCnt()); } + @Test + public void serverReceivingInvalidClientSettingsAfterPrefaceShouldHandleException() throws Exception { + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, 0, SETTINGS, new Http2Flags().ack(true), 0); + + when(connection.isServer()).thenReturn(true); + handler = newHandler(); + handler.channelRead(ctx, Unpooled.wrappedBuffer(connectionPrefaceBuf(), buf)); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(frameWriter).writeGoAway(eq(ctx), eq(Integer.MAX_VALUE), eq(PROTOCOL_ERROR.code()), + captor.capture(), eq(promise)); + assertEquals(0, captor.getValue().refCnt()); + } + @Test public void serverReceivingHttp1ClientPrefaceStringShouldIncludePreface() throws Exception { when(connection.isServer()).thenReturn(true); @@ -687,7 +726,6 @@ public void canSendGoAwayUsingVoidPromise() throws Exception { ByteBuf data = dummyData(); long errorCode = Http2Error.INTERNAL_ERROR.code(); handler = newHandler(); - final Throwable cause = new RuntimeException("fake exception"); doAnswer(new Answer() { @Override public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { @@ -698,12 +736,12 @@ public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { new SimpleChannelPromiseAggregator(promise, channel, ImmediateEventExecutor.INSTANCE); aggregatedPromise.newPromise(); aggregatedPromise.doneAllocatingPromises(); - return aggregatedPromise.setFailure(cause); + return aggregatedPromise.setFailure(Http2TestUtil.FAKE_EXCEPTION); } }).when(frameWriter).writeGoAway( any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), any(ChannelPromise.class)); handler.goAway(ctx, STREAM_ID, errorCode, data, newVoidPromise(channel)); - verify(pipeline).fireExceptionCaught(cause); + verify(pipeline).fireExceptionCaught(Http2TestUtil.FAKE_EXCEPTION); } @Test @@ -716,7 +754,8 @@ public void canCloseStreamWithVoidPromise() throws Exception { @Test public void channelReadCompleteTriggersFlush() throws Exception { - handler = newHandler(); + // Create the handler in a way that it will flush the preface by itself + handler = newHandler(false); handler.channelReadComplete(ctx); verify(ctx, times(1)).flush(); } @@ -748,7 +787,7 @@ public void clientChannelClosedDoesNotSendGoAwayBeforePreface() throws Exception handler = newHandler(); when(channel.isActive()).thenReturn(true); handler.close(ctx, promise); - verifyZeroInteractions(frameWriter); + verifyNoInteractions(frameWriter); } @Test @@ -854,7 +893,6 @@ public void operationComplete(ChannelFuture future) { private void writeRstStreamUsingVoidPromise(int streamId) throws Exception { handler = newHandler(); - final Throwable cause = new RuntimeException("fake exception"); when(stream.id()).thenReturn(STREAM_ID); when(frameWriter.writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class))) .then(new Answer() { @@ -862,12 +900,12 @@ private void writeRstStreamUsingVoidPromise(int streamId) throws Exception { public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable { ChannelPromise promise = invocationOnMock.getArgument(3); assertFalse(promise.isVoid()); - return promise.setFailure(cause); + return promise.setFailure(Http2TestUtil.FAKE_EXCEPTION); } }); handler.resetStream(ctx, streamId, STREAM_CLOSED.code(), newVoidPromise(channel)); verify(frameWriter).writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class)); - verify(pipeline).fireExceptionCaught(cause); + verify(pipeline).fireExceptionCaught(Http2TestUtil.FAKE_EXCEPTION); } private static ByteBuf dummyData() { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java index ba35086c735..34d8766f293 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java @@ -269,7 +269,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { }).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(5), eq(headers), anyInt(), anyShort(), anyBoolean(), eq(0), eq(true)); - bootstrapEnv(1, 2, 2, 0, 0); + bootstrapEnv(1, 2, 2, 0, 0, -1); // Set the maxHeaderListSize to 100 so we may be able to write some headers, but not all. We want to verify // that we don't corrupt state if some can be written but not all. @@ -624,7 +624,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { @Test public void listenerExceptionShouldCloseConnection() throws Exception { final Http2Headers headers = dummyHeaders(); - doThrow(new RuntimeException("Fake Exception")).when(serverListener).onHeadersRead( + doThrow(Http2TestUtil.FAKE_EXCEPTION).when(serverListener).onHeadersRead( any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(false)); @@ -817,7 +817,7 @@ public void run() throws Http2Exception { clientChannel.pipeline().addFirst(new ChannelHandlerAdapter() { @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - throw new RuntimeException("Fake Exception"); + throw Http2TestUtil.FAKE_EXCEPTION; } }); @@ -831,7 +831,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { @Test public void noMoreStreamIdsShouldSendGoAway() throws Exception { - bootstrapEnv(1, 1, 4, 1, 1); + bootstrapEnv(1, 1, 4, 1, 1, -1); // Don't wait for the server to close streams setClientGracefulShutdownTime(0); @@ -874,7 +874,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { } }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); - bootstrapEnv(1, 1, 2, 1, 1); + bootstrapEnv(1, 1, 2, 1, 1, -1); // We want both sides to do graceful shutdown during the test. setClientGracefulShutdownTime(10000); @@ -959,7 +959,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { } }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); - bootstrapEnv(1, 1, 3, 1, 1); + bootstrapEnv(1, 1, 3, 1, 1, -1); // We want both sides to do graceful shutdown during the test. setClientGracefulShutdownTime(10000); @@ -1133,7 +1133,7 @@ public Integer answer(InvocationOnMock in) throws Throwable { }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean()); try { - bootstrapEnv(numStreams * length, 1, numStreams * 4 + 1 , numStreams); + bootstrapEnv(numStreams * length, 1, numStreams * 4 + 1 , numStreams, -1, numStreams); runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { @@ -1177,13 +1177,82 @@ public void run() throws Http2Exception { } } + @Test + public void serverShouldNotEnforceClientAdvertisedMaxHeaderListSize() throws Exception { + // Verifies that SETTINGS_MAX_HEADER_LIST_SIZE sent by a client is treated as advisory + // (per RFC 9113 §6.5.2) and does not prevent the server from encoding response headers. + final CountDownLatch clientSettingsAckLatch = new CountDownLatch(2); + final CountDownLatch responseLatch = new CountDownLatch(1); + final AtomicReference serverWriteError = new AtomicReference(); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + final ChannelHandlerContext sCtx = serverCtx(); + final int streamId = (Integer) invocationOnMock.getArgument(1); + Http2Headers responseHeaders = new DefaultHttp2Headers().status("200"); + http2Server.encoder().writeHeaders(sCtx, streamId, responseHeaders, 0, true, sCtx.newPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + serverWriteError.set(future.cause()); + responseLatch.countDown(); + } + }); + http2Server.flush(sCtx); + return null; + } + }).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), + anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + clientSettingsAckLatch.countDown(); + return null; + } + }).when(clientListener).onSettingsAckRead(any(ChannelHandlerContext.class)); + + bootstrapEnv(0, 1, 2, 0); + + // Client advertises a tiny MAX_HEADER_LIST_SIZE (2 bytes) to the server. + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeSettings(ctx(), + new Http2Settings().maxHeaderListSize(2), + newPromise()); + http2Client.flush(ctx()); + } + }); + + // Wait for the server to acknowledge both the initial settings and our custom settings. + assertTrue(clientSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Send a request; the server will attempt to respond with headers far exceeding 2 bytes. + final short weight = 16; + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, dummyHeaders(), 0, weight, false, 0, true, + newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(responseLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertNull(serverWriteError.get(), + "Server must succeed writing response headers regardless of client's SETTINGS_MAX_HEADER_LIST_SIZE"); + } + private void bootstrapEnv(int dataCountDown, int settingsAckCount, int requestCountDown, int trailersCountDown) throws Exception { - bootstrapEnv(dataCountDown, settingsAckCount, requestCountDown, trailersCountDown, -1); + bootstrapEnv(dataCountDown, settingsAckCount, requestCountDown, trailersCountDown, -1, -1); } private void bootstrapEnv(int dataCountDown, int settingsAckCount, - int requestCountDown, int trailersCountDown, int goAwayCountDown) throws Exception { + int requestCountDown, int trailersCountDown, int goAwayCountDown, final long maxConcurrentStreams) + throws Exception { final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1); requestLatch = new CountDownLatch(requestCountDown); serverSettingsAckLatch = new CountDownLatch(settingsAckCount); @@ -1205,11 +1274,14 @@ protected void initChannel(Channel ch) throws Exception { serverFrameCountDown = new FrameCountDown(serverListener, serverSettingsAckLatch, requestLatch, dataLatch, trailersLatch, goAwayLatch); - serverHandlerRef.set(new Http2ConnectionHandlerBuilder() + Http2ConnectionHandlerBuilder builder = new Http2ConnectionHandlerBuilder() .server(true) .frameListener(serverFrameCountDown) - .validateHeaders(false) - .build()); + .validateHeaders(false); + if (maxConcurrentStreams != -1) { + builder.initialSettings(Http2Settings.defaultSettings().maxConcurrentStreams(maxConcurrentStreams)); + } + serverHandlerRef.set(builder.build()); p.addLast(serverHandlerRef.get()); serverInitLatch.countDown(); } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java index c16eba07673..485499dbb72 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java @@ -1016,4 +1016,23 @@ private void assertInboundStreamFrame(int expectedId, Http2StreamFrame streamFra private ChannelHandlerContext eqFrameCodecCtx() { return eq(frameCodec.ctx); } + + @Test + public void invalidPayloadLength() throws Exception { + frameInboundWriter.writeInboundSettings(new Http2Settings()); + channel.writeInbound(Unpooled.wrappedBuffer(new byte[]{ + 0, 0, 4, // length + 0, // type: DATA + 9, // flags: PADDED, END_STREAM + 1, 0, 0, 0, // stream id + 4, // pad length + 0, 0, 0 // not enough space for padding + })); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java index 2356ca33a8b..b77c1a2bc63 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java @@ -57,6 +57,19 @@ * Utilities for the integration tests. */ public final class Http2TestUtil { + /** + * A fake exception that can be used in tests to simulate errors. The stack trace is not filled in to avoid + * unnecessary overhead. + */ + static final RuntimeException FAKE_EXCEPTION = new RuntimeException("Fake exception") { + private static final long serialVersionUID = -8316972447187527869L; + + @Override + public Throwable fillInStackTrace() { + return this; + } + }; + /** * Interface that allows for running a operation that throws a {@link Http2Exception}. */ diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilFuzzTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilFuzzTest.java new file mode 100644 index 00000000000..9e6b59c4739 --- /dev/null +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilFuzzTest.java @@ -0,0 +1,177 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import com.code_intelligence.jazzer.api.FuzzedDataProvider; +import com.code_intelligence.jazzer.junit.FuzzTest; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.AsciiString; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.net.URI; + +import static io.netty.util.internal.StringUtil.isNullOrEmpty; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +// Netty 4.1 CI still uses old Linux images whose glibc is too old for Jazzer's native driver. +@EnabledIfEnvironmentVariable(named = "JAZZER_FUZZ", matches = "1") +public class HttpConversionUtilFuzzTest { + + @FuzzTest(maxDuration = "30s") + public void currentConversionMatchesOldUriBasedConversion(final FuzzedDataProvider data) { + String requestTarget = data.consumeString(128); + HttpRequest msg = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, + HttpMethod.GET, + requestTarget, + new DefaultHttpHeaders(), + false); + msg.headers().set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), HttpScheme.HTTP.name()); + + Http2Headers oldHeaders; + try { + oldHeaders = oldToHttp2Headers(msg); + } catch (IllegalArgumentException e) { + return; + } + + Http2Headers newHeaders = HttpConversionUtil.toHttp2Headers(msg, false); + if (!oldHeaders.path().equals(newHeaders.path())) { + assertTrue(isKnownPathCompatibilityException(requestTarget), requestTarget); + } + assertEquals(oldHeaders.scheme(), newHeaders.scheme()); + assertEquals(oldHeaders.authority(), newHeaders.authority()); + assertEquals(oldHeaders.method(), newHeaders.method()); + } + + private static boolean isKnownPathCompatibilityException(final String requestTarget) { + // The old URI-based oracle only diverges on a few legacy RFC 2396-style forms: + // 1) Opaque scheme-specific targets like "x:y" or "x:foo" where URI path becomes "/". + // 2) Absolute-path targets with a scheme but no authority like "x:/path", where the old + // URI oracle strips the scheme while the new path logic preserves the raw path. + // 3) Absolute-form targets with no path slash, where Vert.x parsePath returns "/" before + // parseQuery is appended. + // 4) Malformed fragment-before-query targets like "#?", where Vert.x-shaped parsePath keeps + // '#' in the path but URI treats the following '?' as fragment data. + // 5) Empty query/fragment delimiters like "?#", which URI drops while Vert.x-shaped parsing + // keeps the delimiters as raw path/query syntax. + return isOpaqueSchemeSpecificPart(requestTarget) || isSchemeOnlyAbsolutePath(requestTarget) + || isAbsoluteFormWithoutPathSlash(requestTarget) || hasFragmentBeforeQuery(requestTarget) + || hasEmptyQueryAndFragmentDelimiters(requestTarget); + } + + private static boolean isOpaqueSchemeSpecificPart(final String requestTarget) { + int schemeEnd = requestTarget.indexOf(':'); + return HttpConversionUtil.isValidScheme(requestTarget, schemeEnd) && schemeEnd + 1 < requestTarget.length() + && requestTarget.charAt(schemeEnd + 1) != '/'; + } + + private static boolean isSchemeOnlyAbsolutePath(final String requestTarget) { + int schemeEnd = requestTarget.indexOf(':'); + return HttpConversionUtil.isValidScheme(requestTarget, schemeEnd) && schemeEnd + 1 < requestTarget.length() + && requestTarget.charAt(schemeEnd + 1) == '/' + && !HttpConversionUtil.hasSchemeAndAuthority(requestTarget) + && (schemeEnd + 2 >= requestTarget.length() || requestTarget.charAt(schemeEnd + 2) != '/'); + } + + private static boolean isAbsoluteFormWithoutPathSlash(final String requestTarget) { + int schemeEnd = requestTarget.indexOf("://"); + if (!HttpConversionUtil.hasSchemeAndAuthority(requestTarget)) { + return false; + } + int authorityStart = schemeEnd + 3; + int pathStart = requestTarget.indexOf('/', authorityStart); + int delimiter = HttpConversionUtil.queryOrFragmentStart(requestTarget, authorityStart); + return pathStart == -1 || (delimiter != -1 && delimiter < pathStart); + } + + private static boolean hasFragmentBeforeQuery(final String requestTarget) { + int fragmentStart = requestTarget.indexOf('#'); + int queryStart = requestTarget.indexOf('?'); + return fragmentStart != -1 && queryStart != -1 && fragmentStart < queryStart; + } + + private static boolean hasEmptyQueryAndFragmentDelimiters(final String requestTarget) { + return requestTarget.endsWith("?#"); + } + + private static Http2Headers oldToHttp2Headers(final HttpRequest request) { + HttpHeaders inHeaders = request.headers(); + Http2Headers out = new DefaultHttp2Headers(false, inHeaders.size()); + String host = inHeaders.getAsString(HttpHeaderNames.HOST); + if (HttpUtil.isOriginForm(request.uri()) || HttpUtil.isAsteriskForm(request.uri())) { + out.path(new AsciiString(request.uri())); + oldSetHttp2Scheme(inHeaders, URI.create(""), out); + } else { + URI requestTargetUri = URI.create(request.uri()); + out.path(oldToHttp2Path(requestTargetUri)); + host = isNullOrEmpty(host) ? requestTargetUri.getAuthority() : host; + oldSetHttp2Scheme(inHeaders, requestTargetUri, out); + } + HttpConversionUtil.setHttp2Authority(host, out); + out.method(request.method().asciiName()); + return out; + } + + private static AsciiString oldToHttp2Path(final URI uri) { + StringBuilder pathBuilder = new StringBuilder(); + if (!isNullOrEmpty(uri.getRawPath())) { + pathBuilder.append(uri.getRawPath()); + } + if (!isNullOrEmpty(uri.getRawQuery())) { + pathBuilder.append('?'); + pathBuilder.append(uri.getRawQuery()); + } + if (!isNullOrEmpty(uri.getRawFragment())) { + pathBuilder.append('#'); + pathBuilder.append(uri.getRawFragment()); + } + String path = pathBuilder.toString(); + return path.isEmpty() ? new AsciiString("/") : new AsciiString(path); + } + + private static void oldSetHttp2Scheme(final HttpHeaders in, final URI uri, final Http2Headers out) { + String value = uri.getScheme(); + if (!isNullOrEmpty(value)) { + out.scheme(new AsciiString(value)); + return; + } + + CharSequence cValue = in.get(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text()); + if (cValue != null) { + out.scheme(AsciiString.of(cValue)); + return; + } + + if (uri.getPort() == HttpScheme.HTTPS.port()) { + out.scheme(HttpScheme.HTTPS.name()); + } else if (uri.getPort() == HttpScheme.HTTP.port()) { + out.scheme(HttpScheme.HTTP.name()); + } else { + throw new IllegalArgumentException( + ":scheme must be specified. see https://tools.ietf.org/html/rfc7540#section-8.1.2.3"); + } + } +} diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilTest.java index cfd6abfe795..0619801945a 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilTest.java @@ -247,6 +247,126 @@ public void handlesRequestWithDoubleSlashPath() throws Exception { assertEquals(HttpMethod.GET.asciiName(), out.method()); } + @Test + public void handlesAbsoluteRequestWhoseQueryHasUrlCharactersRejectedByJavaNetUri() { + HttpRequest msg = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "https://bh.contextweb.com/bh/rtset?pid=558355&ev=1&us_privacy=${us_privacy}", true); + + Http2Headers out = HttpConversionUtil.toHttp2Headers(msg, false); + + assertEquals(new AsciiString("/bh/rtset?pid=558355&ev=1&us_privacy=${us_privacy}"), out.path()); + assertEquals(new AsciiString("https"), out.scheme()); + assertEquals(new AsciiString("bh.contextweb.com"), out.authority()); + assertEquals(HttpMethod.GET.asciiName(), out.method()); + } + + @Test + public void handlesAbsoluteRequestWhosePathHasUrlCharactersRejectedByJavaNetUri() { + HttpRequest msg = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "http://example.com/orders/{id}/items|details?expand={details}#section", true); + + Http2Headers out = HttpConversionUtil.toHttp2Headers(msg, false); + + assertEquals(new AsciiString("/orders/{id}/items|details?expand={details}#section"), out.path()); + assertEquals(new AsciiString("http"), out.scheme()); + assertEquals(new AsciiString("example.com"), out.authority()); + } + + @Test + public void handlesAbsoluteRequestWithoutPathUsingVertxCompatiblePath() { + HttpRequest msg = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com?x=1#frag", true); + + Http2Headers out = HttpConversionUtil.toHttp2Headers(msg, true); + + assertEquals(new AsciiString("/?x=1#frag"), out.path()); + assertEquals(new AsciiString("http"), out.scheme()); + assertEquals(new AsciiString("example.com"), out.authority()); + } + + @Test + public void handlesAbsoluteRequestWithAuthorityOnlyUsingVertxCompatiblePath() { + HttpRequest msg = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com", true); + + Http2Headers out = HttpConversionUtil.toHttp2Headers(msg, true); + + assertEquals(new AsciiString("/"), out.path()); + assertEquals(new AsciiString("http"), out.scheme()); + assertEquals(new AsciiString("example.com"), out.authority()); + } + + @Test + public void handlesAbsoluteRequestWithoutPathWhoseQueryOrFragmentContainsSlash() { + HttpRequest querySlash = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com?next=/home", true); + HttpRequest fragmentSlash = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com#/home", true); + + assertEquals(new AsciiString("/?next=/home"), + HttpConversionUtil.toHttp2Headers(querySlash, true).path()); + assertEquals(new AsciiString("/"), HttpConversionUtil.toHttp2Headers(fragmentSlash, true).path()); + } + + @Test + public void handlesEmptyRequestTargetUsingLegacyEmptyPathFallback() { + HttpRequest msg = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "", true); + msg.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + + Http2Headers out = HttpConversionUtil.toHttp2Headers(msg, true); + + assertEquals(new AsciiString("/"), out.path()); + assertEquals(new AsciiString("http"), out.scheme()); + assertNull(out.authority()); + } + + @Test + public void handlesAbsoluteRequestWithMissingAuthorityUsingUriAuthority() { + HttpRequest msg = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://?x=1#frag", true); + + Http2Headers out = HttpConversionUtil.toHttp2Headers(msg, true); + + assertEquals(new AsciiString("/?x=1#frag"), out.path()); + assertEquals(new AsciiString("http"), out.scheme()); + assertNull(out.authority()); + } + + @Test + public void handlesAbsoluteRequestWithEmptyQueryOrFragmentUsingVertxCompatiblePath() { + HttpRequest emptyQuery = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com/path?", true); + HttpRequest emptyFragment = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com/path#", true); + HttpRequest emptyQueryWithFragment = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com/path?#frag", true); + HttpRequest queryWithEmptyFragment = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com/path?x#", true); + + assertEquals(new AsciiString("/path"), HttpConversionUtil.toHttp2Headers(emptyQuery, true).path()); + assertEquals(new AsciiString("/path"), HttpConversionUtil.toHttp2Headers(emptyFragment, true).path()); + assertEquals(new AsciiString("/path#frag"), + HttpConversionUtil.toHttp2Headers(emptyQueryWithFragment, true).path()); + assertEquals(new AsciiString("/path?x"), + HttpConversionUtil.toHttp2Headers(queryWithEmptyFragment, true).path()); + } + + @Test + public void rejectsAbsoluteRequestWithMalformedAuthority() { + final HttpRequest msg = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, + "http://[bad host]/p?q={x}", + new DefaultHttpHeaders(), + false); + + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + HttpConversionUtil.toHttp2Headers(msg, false); + } + }); + } + @Test public void addHttp2ToHttpHeadersCombinesCookies() throws Http2Exception { Http2Headers inHeaders = new DefaultHttp2Headers(); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterTest.java index 092f2c20e10..dc9626c3c7b 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterTest.java @@ -290,6 +290,74 @@ public void run() throws Http2Exception { } } + @Test + public void exceedMaxContentLengthShouldCauseStreamErrorNotConnectionError() throws Exception { + // Verify that exceeding maxContentLength causes a stream error (RST_STREAM) + // not a connection error (GOAWAY), so other streams can continue. + // This is the fix for https://github.com/netty/netty/issues/11994 + // + // clientLatch=1: RST_STREAM for stream 3 triggers onRstStreamRead on the client, + // which fires exceptionCaught and counts down clientLatch. + // With a connection error (GOAWAY), no RST_STREAM is sent, so clientLatch + // never counts down and awaitResponses() times out — failing the test. + // serverLatch=1: stream 5 request should be delivered normally. + boostrapEnv(1, 1, 1); + final byte[] oversizedData = new byte[maxContentLength + 1]; + final ByteBuf oversizedContent = Unpooled.wrappedBuffer(oversizedData); + final String normalText = "hello"; + final ByteBuf normalContent = Unpooled.copiedBuffer(normalText.getBytes()); + final FullHttpRequest expectedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "/normal/path", normalContent.copy(), true); + try { + HttpHeaders httpHeaders = expectedRequest.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, normalText.length()); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + final Http2Headers oversizedHeaders = new DefaultHttp2Headers().method(new AsciiString("POST")).path( + new AsciiString("/oversized/path")); + final Http2Headers normalHeaders = new DefaultHttp2Headers().method(new AsciiString("GET")).path( + new AsciiString("/normal/path")); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + // Stream 3: send data exceeding maxContentLength - should cause stream error + clientHandler.encoder().writeHeaders(ctxClient(), 3, oversizedHeaders, 0, false, + newPromiseClient()); + clientHandler.encoder().writeData(ctxClient(), 3, oversizedContent, 0, true, + newPromiseClient()); + // Stream 5: send a normal request - should succeed if connection is still alive + clientHandler.encoder().writeHeaders(ctxClient(), 5, normalHeaders, 0, false, + newPromiseClient()); + clientHandler.encoder().writeData(ctxClient(), 5, normalContent, 0, true, + newPromiseClient()); + clientChannel.flush(); + } + }); + + // Verify stream 5 is delivered successfully on the server + awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(expectedRequest, capturedRequests.get(0)); + + // Verify the client received RST_STREAM (not GOAWAY) for the oversized stream. + // The server's onStreamError sends RST_STREAM, which the client's + // InboundHttp2ToHttpAdapter.onRstStreamRead translates into an exceptionCaught + // event carrying a StreamException — this counts down clientLatch. + // With a connection error, the server sends GOAWAY instead — no RST_STREAM is + // received, exceptionCaught never fires, and awaitResponses() times out. + awaitResponses(); + assertNotNull(clientException); + assertTrue(isStreamError(clientException)); + Http2Exception.StreamException streamEx = (Http2Exception.StreamException) clientException; + assertEquals(3, streamEx.streamId()); + assertEquals(Http2Error.ENHANCE_YOUR_CALM, streamEx.error()); + } finally { + expectedRequest.release(); + } + } + @Test public void clientRequestMultipleDataFrames() throws Exception { boostrapEnv(1, 1, 1); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java index 9dec606d712..0b0c34385f1 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java @@ -109,7 +109,9 @@ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exceptio @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - queue.add(msg); + synchronized (queue) { + queue.add(msg); + } } @Override @@ -119,7 +121,9 @@ public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - queue.add(new UserEvent(evt)); + synchronized (queue) { + queue.add(new UserEvent(evt)); + } } @Override @@ -142,11 +146,13 @@ public void checkException() throws Exception { @SuppressWarnings("unchecked") public T readInbound() { - for (int i = 0; i < queue.size(); i++) { - Object o = queue.get(i); - if (!(o instanceof UserEvent)) { - queue.remove(i); - return (T) o; + synchronized (queue) { + for (int i = 0; i < queue.size(); i++) { + Object o = queue.get(i); + if (!(o instanceof UserEvent)) { + queue.remove(i); + return (T) o; + } } } @@ -163,11 +169,13 @@ public T blockingReadInbound() { @SuppressWarnings("unchecked") public T readUserEvent() { - for (int i = 0; i < queue.size(); i++) { - Object o = queue.get(i); - if (o instanceof UserEvent) { - queue.remove(i); - return (T) ((UserEvent) o).evt; + synchronized (queue) { + for (int i = 0; i < queue.size(); i++) { + Object o = queue.get(i); + if (o instanceof UserEvent) { + queue.remove(i); + return (T) ((UserEvent) o).evt; + } } } @@ -179,14 +187,16 @@ public T readUserEvent() { */ @SuppressWarnings("unchecked") public T readInboundMessageOrUserEvent() { - if (queue.isEmpty()) { - return null; - } - Object o = queue.remove(0); - if (o instanceof UserEvent) { - return (T) ((UserEvent) o).evt; + synchronized (queue) { + if (queue.isEmpty()) { + return null; + } + Object o = queue.remove(0); + if (o instanceof UserEvent) { + return (T) ((UserEvent) o).evt; + } + return (T) o; } - return (T) o; } public void writeOutbound(Object... msgs) throws Exception { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorTest.java index f1e22948b18..27f78d5851c 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorTest.java @@ -128,8 +128,7 @@ public void connectionErrorForWriterException() throws Http2Exception { initState(STREAM_C, 3, true); initState(STREAM_D, 4, true); - Exception fakeException = new RuntimeException("Fake exception"); - doThrow(fakeException).when(writer).write(same(stream(STREAM_C)), eq(3)); + doThrow(Http2TestUtil.FAKE_EXCEPTION).when(writer).write(same(stream(STREAM_C)), eq(3)); Http2Exception e = assertThrows(Http2Exception.class, new Executable() { @Override @@ -139,7 +138,7 @@ public void execute() throws Throwable { }); assertFalse(Http2Exception.isStreamError(e)); assertEquals(Http2Error.INTERNAL_ERROR, e.error()); - assertSame(fakeException, e.getCause()); + assertSame(Http2TestUtil.FAKE_EXCEPTION, e.getCause()); verifyWrite(atMost(1), STREAM_A, 1); verifyWrite(atMost(1), STREAM_B, 2); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorTest.java index c082fba8983..7aea0071acf 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorTest.java @@ -150,8 +150,7 @@ public void connectionErrorForWriterException() throws Http2Exception { initState(STREAM_C, 3, true); initState(STREAM_D, 4, true); - Exception fakeException = new RuntimeException("Fake exception"); - doThrow(fakeException).when(writer).write(same(stream(STREAM_C)), eq(3)); + doThrow(Http2TestUtil.FAKE_EXCEPTION).when(writer).write(same(stream(STREAM_C)), eq(3)); Http2Exception e = assertThrows(Http2Exception.class, new Executable() { @Override @@ -161,7 +160,7 @@ public void execute() throws Throwable { }); assertFalse(Http2Exception.isStreamError(e)); assertEquals(Http2Error.INTERNAL_ERROR, e.error()); - assertSame(fakeException, e.getCause()); + assertSame(Http2TestUtil.FAKE_EXCEPTION, e.getCause()); verifyWrite(atMost(1), STREAM_A, 1); verifyWrite(atMost(1), STREAM_B, 2); diff --git a/codec-memcache/pom.xml b/codec-memcache/pom.xml index f5b785a1331..afde7417da1 100644 --- a/codec-memcache/pom.xml +++ b/codec-memcache/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-memcache diff --git a/codec-mqtt/pom.xml b/codec-mqtt/pom.xml index 1a6aaabe7a1..dba529a4e19 100644 --- a/codec-mqtt/pom.xml +++ b/codec-mqtt/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-mqtt diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java index 49d7c9b8b1f..86bd8b04535 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java @@ -24,6 +24,7 @@ import io.netty.handler.codec.mqtt.MqttDecoder.DecoderState; import io.netty.handler.codec.mqtt.MqttProperties.IntegerProperty; import io.netty.util.CharsetUtil; +import io.netty.util.Signal; import io.netty.util.internal.ObjectUtil; import java.util.ArrayList; @@ -96,8 +97,21 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List ou case READ_VARIABLE_HEADER: try { int bytesRemainingBeforeVariableHeader = bytesRemainingInVariablePart; - variableHeader = decodeVariableHeader(ctx, buffer, mqttFixedHeader); - if (bytesRemainingBeforeVariableHeader > maxBytesInMessage) { + boolean bailOut = false; + try { + variableHeader = decodeVariableHeader(ctx, buffer, mqttFixedHeader); + } catch (Signal signal) { + if (bytesRemainingBeforeVariableHeader > maxBytesInMessage) { + // We couldn't parse the complete message, and it's already too large. + // Swallow the Signal (we don't need more data) and instead bail out + // and throw the TooLongFrameException below. + bailOut = true; + } else { + // Ask for REPLAY if the current message is within maxBytesInMessage. + throw signal; + } + } + if (bailOut || bytesRemainingBeforeVariableHeader > maxBytesInMessage) { buffer.skipBytes(actualReadableBytes()); throw new TooLongFrameException("message length exceeds " + maxBytesInMessage + ": " + bytesRemainingBeforeVariableHeader); @@ -494,7 +508,11 @@ private Object decodePayload( return decodePublishPayload(buffer); default: - // unknown payload , no byte consumed + // No payload for this message type. If the fixed header's Remaining Length + // claimed bytes beyond what the variable header consumed (e.g. a PINGREQ + // with non-zero Remaining Length), the frame is malformed. + // See https://github.com/netty/netty/issues/16851 + validateNoBytesRemain(0); return null; } } @@ -731,6 +749,10 @@ private static Result decodeProperties(ByteBuf buffer) { final long propertiesLength = decodeVariableByteInteger(buffer); int totalPropertiesLength = unpackA(propertiesLength); int numberOfBytesConsumed = unpackB(propertiesLength); + if (buffer.readableBytes() < totalPropertiesLength) { + // Force an early REPLAY to avoid repeatedly parsing the properties. + buffer.readSlice(totalPropertiesLength); + } MqttProperties decodedProperties = new MqttProperties(); while (numberOfBytesConsumed < totalPropertiesLength) { diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java index 9a601ead1c2..729efcf3dc0 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java @@ -118,8 +118,9 @@ private static ByteBuf encodeConnectMessage( (byte) variableHeader.version()); setMqttVersion(ctx, mqttVersion); - // as MQTT 3.1 & 3.1.1 spec, If the User Name Flag is set to 0, the Password Flag MUST be set to 0 - if (!variableHeader.hasUserName() && variableHeader.hasPassword()) { + // MQTT 3.1 and 3.1.1 require the Password Flag to be 0 when the User Name Flag is 0. + if ((mqttVersion == MqttVersion.MQTT_3_1 || mqttVersion == MqttVersion.MQTT_3_1_1) && + !variableHeader.hasUserName() && variableHeader.hasPassword()) { throw new EncoderException("Without a username, the password MUST be not set"); } @@ -287,7 +288,7 @@ private static ByteBuf encodeSubscribeMessage( // Payload for (MqttTopicSubscription topic : payload.topicSubscriptions()) { - writeUnsafeUTF8String(buf, topic.topicName()); + writeEagerUTF8String(buf, topic.topicName()); if (mqttVersion == MqttVersion.MQTT_3_1_1 || mqttVersion == MqttVersion.MQTT_3_1) { buf.writeByte(topic.qualityOfService().value()); } else { @@ -347,7 +348,7 @@ private static ByteBuf encodeUnsubscribeMessage( // Payload for (String topicName : payload.topics()) { - writeUnsafeUTF8String(buf, topicName); + writeEagerUTF8String(buf, topicName); } return buf; @@ -720,15 +721,6 @@ private static void writeEagerUTF8String(ByteBuf buf, String s) { buf.setShort(writerIndex, utf8Length); } - private static void writeUnsafeUTF8String(ByteBuf buf, String s) { - final int writerIndex = buf.writerIndex(); - final int startUtf8String = writerIndex + 2; - // no need to reserve any capacity here, already done earlier: that's why is Unsafe - buf.writerIndex(startUtf8String); - final int utf8Length = s != null? reserveAndWriteUtf8(buf, s, 0) : 0; - buf.setShort(writerIndex, utf8Length); - } - private static int getVariableLengthInt(int num) { int count = 0; do { diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttProperties.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttProperties.java index 04a52525d15..e65b8b15599 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttProperties.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttProperties.java @@ -361,7 +361,9 @@ public Collection listAll() { public boolean isEmpty() { IntObjectHashMap props = this.props; - return props == null || props.isEmpty(); + return (props == null || props.isEmpty()) && + (userProperties == null || userProperties.isEmpty()) && + (subscriptionIds == null || subscriptionIds.isEmpty()); } /** diff --git a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java index 11fc238c595..1862db6202c 100644 --- a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java +++ b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java @@ -21,6 +21,7 @@ import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.DecoderException; import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.TooLongFrameException; @@ -266,7 +267,24 @@ private void checkForSingleDecoderException(final List out) { } @Test - public void testConnectMessageNoPassword() throws Exception { + public void testConnectMessagePasswordOnlyForMqtt31() throws Exception { + final MqttConnectMessage message = createConnectMessage( + MqttVersion.MQTT_3_1, + null, + PASSWORD, + MqttProperties.NO_PROPERTIES, + MqttProperties.NO_PROPERTIES); + + assertThrows(EncoderException.class, new Executable() { + @Override + public void execute() { + MqttEncoder.doEncode(ctx, message); + } + }); + } + + @Test + public void testConnectMessagePasswordOnlyForMqtt311() throws Exception { final MqttConnectMessage message = createConnectMessage( MqttVersion.MQTT_3_1_1, null, @@ -282,6 +300,31 @@ public void execute() { }); } + @Test + public void testConnectMessagePasswordOnlyForMqtt5() throws Exception { + final MqttConnectMessage message = createConnectMessage( + MqttVersion.MQTT_5, + null, + PASSWORD, + MqttProperties.NO_PROPERTIES, + MqttProperties.NO_PROPERTIES); + + assertFalse(message.variableHeader().hasUserName()); + assertTrue(message.variableHeader().hasPassword()); + + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + + mqttDecoder.channelRead(ctx, byteBuf); + + assertEquals(1, out.size()); + + final MqttConnectMessage decodedMessage = (MqttConnectMessage) out.get(0); + + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validateConnectVariableHeader(message.variableHeader(), decodedMessage.variableHeader()); + validateConnectPayload(message.payload(), decodedMessage.payload()); + } + @Test public void testConnAckMessage() throws Exception { final MqttConnAckMessage message = createConnAckMessage(); @@ -311,6 +354,78 @@ public void testPublishMessage() throws Exception { validatePublishPayload(message.payload(), decodedMessage.payload()); } + @Test + public void testPublishMessageIncompleteVariableHeaderDoesNotUseCumulationSizeForTooLongCheck() throws Exception { + // The leading PUBLISH is hand-crafted rather than going through MqttEncoder because the + // bug under test only triggers when variable-header decoding asks for REPLAY mid-message, + // which in turn requires a deliberately malformed packet (topic-name length prefix larger + // than the bytes we actually supply). MqttEncoder only produces well-formed messages. + final int maxBytesInMessage = 16; + // bytes after the fixed header; < 128 so it fits in a 1-byte Variable Byte Integer. + final int currentPacketRemainingLength = 10; + // > the bytes we write below, so the decoder must REPLAY mid-variable-header. + final int claimedTopicNameLength = 32; + final int followingPingReqPackets = 3; + EmbeddedChannel channel = new EmbeddedChannel(new MqttDecoder(maxBytesInMessage)); + ByteBuf byteBuf = ALLOCATOR.buffer(); + // Leading PUBLISH packet (incomplete - missing most of the topic name): + // Fixed header byte 1: PUBLISH (type 3), DUP=0, QoS=0, RETAIN=0. + byteBuf.writeByte(0x30); + // Fixed header remaining-length, encoded as a Variable Byte Integer (single byte for values < 128). + byteBuf.writeByte(currentPacketRemainingLength); + // Variable header: 2-byte topic-name length prefix. + byteBuf.writeShort(claimedTopicNameLength); + // ... + only 8 of the 32 topic-name bytes the prefix claims (so the decoder will ask for REPLAY). + byteBuf.writeZero(currentPacketRemainingLength - 2); + // Trailing PINGREQ packets - the cumulation bytes that the buggy size check used to look at. + // Each PINGREQ is just a 2-byte fixed header: 0xC0 (type 12, flags 0) and remaining-length 0. + for (int i = 0; i < followingPingReqPackets; i++) { + byteBuf.writeByte(0xC0); + byteBuf.writeByte(0); + } + + try { + assertFalse(channel.writeInbound(byteBuf)); + assertNull(channel.readInbound()); + } finally { + channel.finishAndReleaseAll(); + } + } + + @Test + public void testPublishMessageIncompleteVariableHeaderStillFailsWhenCurrentPacketTooLarge() throws Exception { + // Same hand-crafting rationale as the test above: a malformed (incomplete-topic) PUBLISH + // is needed so variable-header decoding asks for REPLAY, which is the code path under test. + final int maxBytesInMessage = 16; + // Declared packet size already exceeds the limit; the in-flight check must still flag it. + final int currentPacketRemainingLength = maxBytesInMessage + 1; + // > the bytes we write below, so the decoder still asks for REPLAY mid-variable-header. + final int claimedTopicNameLength = 32; + EmbeddedChannel channel = new EmbeddedChannel(new MqttDecoder(maxBytesInMessage)); + ByteBuf byteBuf = ALLOCATOR.buffer(); + // Fixed header byte 1: PUBLISH (type 3), all flags 0. + byteBuf.writeByte(0x30); + // Fixed header remaining-length Variable Byte Integer: 17 (still a single byte since < 128). + byteBuf.writeByte(currentPacketRemainingLength); + // Variable header: 2-byte topic-name length prefix claiming 32 bytes. + byteBuf.writeShort(claimedTopicNameLength); + // ... + 14 zero bytes - fewer than the 32 claimed, so the decoder will ask for REPLAY. + byteBuf.writeZero(maxBytesInMessage - 2); + + try { + assertTrue(channel.writeInbound(byteBuf)); + MqttMessage decodedMessage = channel.readInbound(); + try { + assertTrue(decodedMessage.decoderResult().isFailure()); + assertInstanceOf(TooLongFrameException.class, decodedMessage.decoderResult().cause()); + } finally { + ReferenceCountUtil.release(decodedMessage); + } + } finally { + channel.finishAndReleaseAll(); + } + } + @Test public void testPubAckMessage() throws Exception { testMessageWithOnlyFixedHeaderAndMessageIdVariableHeader(MqttMessageType.PUBACK); @@ -419,6 +534,41 @@ public void testDisconnectMessage() throws Exception { testMessageWithOnlyFixedHeader(MqttMessage.DISCONNECT); } + @Test + public void testPingReqWithNonZeroRemainingLengthIsRejected() throws Exception { + // Regression for https://github.com/netty/netty/issues/16851: PINGREQ is a 2-byte + // fixed-header-only packet (0xC0 0x00). The bytes below claim Remaining Length 2, + // which makes the trailing 0xD0 0x00 part of the same (malformed) PINGREQ frame + // rather than a separate PINGRESP. The decoder must reject this as one invalid + // message rather than silently accept two. + EmbeddedChannel channel = new EmbeddedChannel(new MqttDecoder()); + ByteBuf byteBuf = channel.alloc().buffer(); + // Fixed header byte 1: PINGREQ (type 12), all flags 0. + byteBuf.writeByte(0xC0); + // Remaining Length 2 - invalid per MQTT 3.1.1 / 5.0 spec (PINGREQ has no variable + // header or payload, so Remaining Length must be 0). + byteBuf.writeByte(0x02); + // Two leftover bytes still inside the malformed packet's frame: + byteBuf.writeByte(0xD0); + byteBuf.writeByte(0x00); + + try { + assertTrue(channel.writeInbound(byteBuf)); + MqttMessage first = channel.readInbound(); + try { + assertTrue(first.decoderResult().isFailure(), + "expected a failed message for the malformed PINGREQ"); + assertInstanceOf(DecoderException.class, first.decoderResult().cause()); + } finally { + ReferenceCountUtil.release(first); + } + // No second message: the trailing bytes belong to the malformed frame. + assertNull(channel.readInbound()); + } finally { + assertFalse(channel.finishAndReleaseAll()); + } + } + //All 0..F message type codes are valid in MQTT 5 @Test public void testUnknownMessageType() throws Exception { @@ -670,6 +820,25 @@ public void testPubAckMessageSkipCodeForMqtt5() throws Exception { (MqttPubReplyMessageVariableHeader) decodedMessage.variableHeader()); } + @Test + public void testPubAckMessageWithUserPropertyAndSuccessForMqtt5() throws Exception { + when(versionAttrMock.get()).thenReturn(MqttVersion.MQTT_5); + + MqttProperties props = new MqttProperties(); + props.add(new MqttProperties.UserProperty("traceId", "abc")); + final MqttMessage message = createPubAckMessage((byte) 0, props); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + + mqttDecoder.channelRead(ctx, byteBuf); + + assertEquals(1, out.size()); + + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validatePubReplyVariableHeader((MqttPubReplyMessageVariableHeader) message.variableHeader(), + (MqttPubReplyMessageVariableHeader) decodedMessage.variableHeader()); + } + @Test public void testSubAckMessageForMqtt5() throws Exception { MqttProperties props = new MqttProperties(); @@ -822,6 +991,28 @@ public void testDisconnectMessageSkipCodeForMqtt5() throws Exception { (MqttReasonCodeAndPropertiesVariableHeader) decodedMessage.variableHeader()); } + @Test + public void testDisconnectMessageWithUserPropertyAndSuccessForMqtt5() throws Exception { + when(versionAttrMock.get()).thenReturn(MqttVersion.MQTT_5); + + MqttProperties props = new MqttProperties(); + props.add(new MqttProperties.UserProperty("traceId", "abc")); + final MqttMessage message = MqttMessageBuilders.disconnect() + .reasonCode((byte) 0) + .properties(props) + .build(); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + + mqttDecoder.channelRead(ctx, byteBuf); + + assertEquals(1, out.size()); + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validateReasonCodeAndPropertiesVariableHeader( + (MqttReasonCodeAndPropertiesVariableHeader) message.variableHeader(), + (MqttReasonCodeAndPropertiesVariableHeader) decodedMessage.variableHeader()); + } + @Test public void testAuthMessageForMqtt5() throws Exception { when(versionAttrMock.get()).thenReturn(MqttVersion.MQTT_5); diff --git a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttPropertiesTest.java b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttPropertiesTest.java index 580056a1587..85a5e235e3c 100644 --- a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttPropertiesTest.java +++ b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttPropertiesTest.java @@ -28,6 +28,8 @@ import static io.netty.handler.codec.mqtt.MqttProperties.MqttPropertyType.SUBSCRIPTION_IDENTIFIER; import static io.netty.handler.codec.mqtt.MqttProperties.MqttPropertyType.USER_PROPERTY; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; public class MqttPropertiesTest { @@ -108,4 +110,26 @@ public void testListAll() { assertEquals(expectedProperties, props.listAll()); } + @Test + public void testIsEmptyWithOnlyUserProperties() { + MqttProperties props = new MqttProperties(); + + assertTrue(props.isEmpty()); + + props.add(new MqttProperties.UserProperty("tag", "firstTag")); + + assertFalse(props.isEmpty()); + } + + @Test + public void testIsEmptyWithOnlySubscriptionIdentifiers() { + MqttProperties props = new MqttProperties(); + + assertTrue(props.isEmpty()); + + props.add(new MqttProperties.IntegerProperty(SUBSCRIPTION_IDENTIFIER.value(), 10)); + + assertFalse(props.isEmpty()); + } + } diff --git a/codec-redis/pom.xml b/codec-redis/pom.xml index 38b73c9326b..c7b70b4fcbf 100644 --- a/codec-redis/pom.xml +++ b/codec-redis/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-redis diff --git a/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisArrayAggregator.java b/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisArrayAggregator.java index ed7bd120e78..25df754d962 100644 --- a/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisArrayAggregator.java +++ b/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisArrayAggregator.java @@ -18,7 +18,9 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.CodecException; import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.handler.codec.PrematureChannelClosureException; import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.UnstableApi; import java.util.ArrayDeque; @@ -33,7 +35,40 @@ @UnstableApi public final class RedisArrayAggregator extends MessageToMessageDecoder { + private static final int DEFAULT_MAX_ARRAY_LENGTH = RedisConstants.REDIS_MAX_ARRAY_LENGTH; + private final int maxNestedArrayDepth; private final Deque depths = new ArrayDeque(4); + private final int maxElements; + + /** + * Create a new instance that will aggregate an {@link ArrayHeaderRedisMessage} + * and its subsequent elements into an {@link ArrayRedisMessage}. + *

+ * This constructor specifies a maximum number of elements of 1.000.000, + * but this default can be increased with the {@value RedisConstants#PROP_REDIS_MAX_ARRAY_LENGTH} system property. + * + * @deprecated Use {@link #RedisArrayAggregator(int, int)} instead to define a max size of the array to aggregate. + */ + @Deprecated + public RedisArrayAggregator() { + // Let's impose some limit at least by default. + this(DEFAULT_MAX_ARRAY_LENGTH, 1024); + } + + /** + * Create a new instance that will aggregate an {@link ArrayHeaderRedisMessage} + * and its subsequent elements into an {@link ArrayRedisMessage}. + *

+ * A {@link CodecException} will be thrown if the array header specify a length greater than + * the given number of max elements. + * @param maxElements The maximum number of elements to aggregate in a single message. + * @param maxNestedArrayDepth the maximum depth of the nested array before an exception will be thrown + */ + public RedisArrayAggregator(int maxElements, int maxNestedArrayDepth) { + super(RedisMessage.class); + this.maxElements = ObjectUtil.checkPositive(maxElements, "maxElements"); + this.maxNestedArrayDepth = ObjectUtil.checkPositive(maxNestedArrayDepth, "maxNestedArrayDepth"); + } @Override protected void decode(ChannelHandlerContext ctx, RedisMessage msg, List out) throws Exception { @@ -70,10 +105,14 @@ private RedisMessage decodeRedisArrayHeader(ArrayHeaderRedisMessage header) { return ArrayRedisMessage.EMPTY_INSTANCE; } else if (header.length() > 0L) { // Currently, this codec doesn't support `long` length for arrays because Java's List.size() is int. - if (header.length() > Integer.MAX_VALUE) { - throw new CodecException("this codec doesn't support longer length than " + Integer.MAX_VALUE); + if (header.length() > maxElements) { + throw new CodecException("this codec doesn't support longer length than " + maxElements); } + if (depths.size() >= maxNestedArrayDepth) { + releaseAndClearDepths(); + throw new CodecException("max nested array depth exceeded: " + maxNestedArrayDepth); + } // start aggregating array depths.push(new AggregateState((int) header.length())); return null; @@ -90,4 +129,30 @@ private static final class AggregateState { this.children = new ArrayList(length); } } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved(ctx); + releaseAndClearDepths(); + } + + private void releaseAndClearDepths() { + for (AggregateState state : depths) { + for (RedisMessage message : state.children) { + ReferenceCountUtil.safeRelease(message); + } + } + depths.clear(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + super.channelInactive(ctx); + + if (!depths.isEmpty()) { + ctx.fireExceptionCaught(new PrematureChannelClosureException( + "channel gone inactive with " + depths.size() + + " messages still incomplete")); + } + } } diff --git a/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisConstants.java b/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisConstants.java index bed626c3cc4..361f9aec6fa 100644 --- a/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisConstants.java +++ b/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisConstants.java @@ -15,6 +15,8 @@ package io.netty.handler.codec.redis; +import io.netty.util.internal.SystemPropertyUtil; + /** * Constant values for Redis encoder/decoder. */ @@ -43,4 +45,7 @@ private RedisConstants() { static final short NULL_SHORT = RedisCodecUtil.makeShort('-', '1'); static final short EOL_SHORT = RedisCodecUtil.makeShort('\r', '\n'); + + static final String PROP_REDIS_MAX_ARRAY_LENGTH = "io.netty.handler.codec.redis.maxArrayLength"; + static final int REDIS_MAX_ARRAY_LENGTH = SystemPropertyUtil.getInt(PROP_REDIS_MAX_ARRAY_LENGTH, 1000000); } diff --git a/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisDecoder.java b/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisDecoder.java index 13c0f3eac1e..7d728cec727 100644 --- a/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisDecoder.java +++ b/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisDecoder.java @@ -129,9 +129,13 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } } } catch (RedisCodecException e) { + // Let's discard everything + in.skipBytes(in.readableBytes()); resetDecoder(); throw e; } catch (Exception e) { + // Let's discard everything + in.skipBytes(in.readableBytes()); resetDecoder(); throw new RedisCodecException(e); } @@ -169,6 +173,16 @@ private boolean decodeInline(ByteBuf in, List out) throws Exception { private boolean decodeLength(ByteBuf in, List out) throws Exception { ByteBuf lineByteBuf = readLine(in); if (lineByteBuf == null) { + int readableBytes = in.readableBytes(); + if (readableBytes <= RedisConstants.POSITIVE_LONG_MAX_LENGTH) { + // fast-path + return false; + } + boolean isNegative = in.getByte(in.readerIndex()) == '-'; + int capacity = RedisConstants.POSITIVE_LONG_MAX_LENGTH + (isNegative ? 1 : 0) + 1; + if (readableBytes > capacity) { + throw new RedisCodecException("too many characters to be a valid RESP Integer: " + readableBytes); + } return false; } final long length = parseRedisNumber(lineByteBuf); diff --git a/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisEncoder.java b/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisEncoder.java index 70422f39434..ae6b947952e 100644 --- a/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisEncoder.java +++ b/codec-redis/src/main/java/io/netty/handler/codec/redis/RedisEncoder.java @@ -101,6 +101,12 @@ private static void writeErrorMessage(ByteBufAllocator allocator, ErrorRedisMess private static void writeString(ByteBufAllocator allocator, RedisMessageType type, String content, List out) { + if (type.isInline()) { + // Inline, or "simple" messages do not permit CRLF bytes in their contents. + if (content.indexOf('\r') != -1 || content.indexOf('\n') != -1) { + throw new CodecException("Line breaks are not permitted in 'simple' messages"); + } + } ByteBuf buf = allocator.ioBuffer(type.length() + ByteBufUtil.utf8MaxBytes(content) + RedisConstants.EOL_LENGTH); type.writeTo(buf); diff --git a/codec-redis/src/test/java/io/netty/handler/codec/redis/RedisArrayAggregatorTest.java b/codec-redis/src/test/java/io/netty/handler/codec/redis/RedisArrayAggregatorTest.java new file mode 100644 index 00000000000..707af400454 --- /dev/null +++ b/codec-redis/src/test/java/io/netty/handler/codec/redis/RedisArrayAggregatorTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.redis; + +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.PrematureChannelClosureException; +import io.netty.handler.codec.CodecException; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class RedisArrayAggregatorTest { + + @Test + void testDoesNotLeakOnClose() { + final EmbeddedChannel ch = new EmbeddedChannel(new RedisArrayAggregator()); + assertFalse(ch.writeInbound(new ArrayHeaderRedisMessage(2))); + + FullBulkStringRedisMessage redisMessage = new FullBulkStringRedisMessage(Unpooled.buffer()); + assertEquals(1, redisMessage.refCnt()); + assertFalse(ch.writeInbound(redisMessage)); + assertEquals(1, redisMessage.refCnt()); + + assertThrows(PrematureChannelClosureException.class, new Executable() { + @Override + public void execute() throws Throwable { + ch.finish(); + } + }); + assertEquals(0, redisMessage.refCnt()); + } + + @Test + void testDoesNotLeakOnRemoval() { + EmbeddedChannel ch = new EmbeddedChannel(new RedisArrayAggregator()); + assertFalse(ch.writeInbound(new ArrayHeaderRedisMessage(2))); + + FullBulkStringRedisMessage redisMessage = new FullBulkStringRedisMessage(Unpooled.buffer()); + assertEquals(1, redisMessage.refCnt()); + assertFalse(ch.writeInbound(redisMessage)); + assertEquals(1, redisMessage.refCnt()); + ch.pipeline().remove(RedisArrayAggregator.class); + assertEquals(0, redisMessage.refCnt()); + assertFalse(ch.finish()); + } + + @Test + public void testLimitNested() { + final byte[] arrayHeader = "*1\r\n".getBytes(CharsetUtil.US_ASCII); + int maxNestedDepth = 100; + final EmbeddedChannel channel = new EmbeddedChannel(new RedisDecoder(), + new RedisArrayAggregator(RedisConstants.REDIS_MAX_ARRAY_LENGTH, maxNestedDepth)); + for (int i = 0; i < maxNestedDepth; i++) { + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(arrayHeader))); + } + + // Next write should trigger an exception. + assertThrows(CodecException.class, new Executable() { + @Override + public void execute() throws Throwable { + channel.writeInbound(Unpooled.wrappedBuffer(arrayHeader)); + } + }); + assertFalse(channel.finishAndReleaseAll()); + } +} diff --git a/codec-redis/src/test/java/io/netty/handler/codec/redis/RedisDecoderTest.java b/codec-redis/src/test/java/io/netty/handler/codec/redis/RedisDecoderTest.java index 962238d98f6..d0b03be5530 100644 --- a/codec-redis/src/test/java/io/netty/handler/codec/redis/RedisDecoderTest.java +++ b/codec-redis/src/test/java/io/netty/handler/codec/redis/RedisDecoderTest.java @@ -19,9 +19,11 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CodecException; import io.netty.handler.codec.DecoderException; import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import org.assertj.core.api.ThrowableAssert; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -30,6 +32,7 @@ import java.util.List; import static io.netty.handler.codec.redis.RedisCodecTestUtil.*; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -55,7 +58,7 @@ private static EmbeddedChannel newChannel(boolean decodeInlineCommands) { return new EmbeddedChannel( new RedisDecoder(decodeInlineCommands), new RedisBulkStringAggregator(), - new RedisArrayAggregator()); + new RedisArrayAggregator(100, 1024)); } @AfterEach @@ -275,6 +278,20 @@ public void shouldDecodeNestedArray() throws Exception { ReferenceCountUtil.release(msg); } + @Test + public void shouldErrorOnTooLargeArray() { + // We defined the max aggregate array size to be 100 + assertThatThrownBy(new ThrowableAssert.ThrowingCallable() { + @Override + public void call() throws Throwable { + channel.writeInbound(byteBufOf("*101\r\n")); + } + }).isInstanceOf(DecoderException.class) + .rootCause() + .isInstanceOf(CodecException.class) + .hasMessageContaining("100"); + } + @Test public void shouldErrorOnDoubleReleaseArrayReferenceCounted() { ByteBuf buf = Unpooled.buffer(); @@ -341,4 +358,67 @@ public void testPredefinedMessagesNotEqual() { assertNotEquals(FullBulkStringRedisMessage.EMPTY_INSTANCE, FullBulkStringRedisMessage.NULL_INSTANCE); assertNotEquals(FullBulkStringRedisMessage.NULL_INSTANCE, FullBulkStringRedisMessage.EMPTY_INSTANCE); } + + @Test + public void shouldLimitIntegerTo64IntSigned() { + ByteBuf buf = Unpooled.buffer(); + buf.writeByte('$'); + for (int i = 0; i <= RedisConstants.POSITIVE_LONG_MAX_LENGTH; i++) { + buf.writeByte('0'); + } + assertFalse(channel.writeInbound(buf)); + + assertThrows(DecoderException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(byteBufOf("1")); + } + }); + } + + @Test + public void testPositiveLongWithCr() { + EmbeddedChannel channel = new EmbeddedChannel(new RedisDecoder()); + ByteBuf buf = Unpooled.buffer(); + buf.writeByte('$'); + for (int i = 0; i < RedisConstants.POSITIVE_LONG_MAX_LENGTH; i++) { + buf.writeByte('0'); + } + buf.writeByte('\r'); + + // 19 digits + \r = 20 bytes. + // It's a valid incomplete RESP number waiting for \n. + assertFalse(channel.writeInbound(buf)); + + ByteBuf buf2 = Unpooled.buffer(); + buf2.writeByte('\n'); + assertFalse(channel.writeInbound(buf2)); + assertFalse(channel.finish()); + } + + @Test + public void testGiantPayloadEndingWithCrBypassesLengthCheck() { + final EmbeddedChannel channel = new EmbeddedChannel(new RedisDecoder()); + ByteBuf buf = Unpooled.buffer(); + buf.writeByte('$'); + for (int i = 0; i < RedisConstants.POSITIVE_LONG_MAX_LENGTH; i++) { + buf.writeByte('1'); + } + assertFalse(channel.writeInbound(buf)); + + // We expect that sending 1000 more bytes will exceed the maximum valid length capacity, + // regardless of whether the final byte is '\r'. It should throw a DecoderException. + assertThrows(DecoderException.class, new Executable() { + @Override + public void execute() { + for (int i = 0; i < 1000; i++) { + ByteBuf chunk = Unpooled.buffer(); + chunk.writeByte('a'); + chunk.writeByte('\r'); + channel.writeInbound(chunk); + } + } + }); + assertFalse(channel.finish()); + } } diff --git a/codec-smtp/pom.xml b/codec-smtp/pom.xml index b0896da7ebc..8bccb5bd57c 100644 --- a/codec-smtp/pom.xml +++ b/codec-smtp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-smtp diff --git a/codec-socks/pom.xml b/codec-socks/pom.xml index 2397bb2e3fa..e0801bbae44 100644 --- a/codec-socks/pom.xml +++ b/codec-socks/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-socks diff --git a/codec-stomp/pom.xml b/codec-stomp/pom.xml index 285373a69b3..43dc637be77 100644 --- a/codec-stomp/pom.xml +++ b/codec-stomp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-stomp diff --git a/codec-stomp/src/main/java/io/netty/handler/codec/stomp/StompSubframeDecoder.java b/codec-stomp/src/main/java/io/netty/handler/codec/stomp/StompSubframeDecoder.java index 8c88e67fee5..5c3ec0d78cb 100644 --- a/codec-stomp/src/main/java/io/netty/handler/codec/stomp/StompSubframeDecoder.java +++ b/codec-stomp/src/main/java/io/netty/handler/codec/stomp/StompSubframeDecoder.java @@ -238,6 +238,9 @@ private static void skipNullCharacter(ByteBuf buffer) { private static void skipControlCharacters(ByteBuf buffer) { byte b; for (;;) { + if (!buffer.isReadable()) { + return; + } b = buffer.readByte(); if (b != StompConstants.CR && b != StompConstants.LF) { buffer.readerIndex(buffer.readerIndex() - 1); diff --git a/codec-stomp/src/test/java/io/netty/handler/codec/stomp/StompSubframeDecoderTest.java b/codec-stomp/src/test/java/io/netty/handler/codec/stomp/StompSubframeDecoderTest.java index 4e2329495ff..7c01faf3b10 100644 --- a/codec-stomp/src/test/java/io/netty/handler/codec/stomp/StompSubframeDecoderTest.java +++ b/codec-stomp/src/test/java/io/netty/handler/codec/stomp/StompSubframeDecoderTest.java @@ -317,6 +317,118 @@ void testNotUnescapeHeadersForConnectedCommand() { assertNull(obj); } + @Test + public void testHeartbeatOnlyDoesNotThrowException() { + // STOMP heartbeat is just a LF byte - should not cause IndexOutOfBoundsException + ByteBuf heartbeat = Unpooled.buffer(); + heartbeat.writeByte('\n'); + channel.writeInbound(heartbeat); + + // Heartbeat should be consumed silently, no output produced + Object result = channel.readInbound(); + assertNull(result); + } + + @Test + public void testMultipleHeartbeatsDoNotThrowException() { + // Multiple consecutive heartbeats + ByteBuf heartbeats = Unpooled.buffer(); + heartbeats.writeByte('\n'); + heartbeats.writeByte('\n'); + heartbeats.writeByte('\n'); + channel.writeInbound(heartbeats); + + Object result = channel.readInbound(); + assertNull(result); + } + + @Test + public void testCarriageReturnLineFeedHeartbeat() { + // CR+LF heartbeat + ByteBuf heartbeat = Unpooled.buffer(); + heartbeat.writeByte('\r'); + heartbeat.writeByte('\n'); + channel.writeInbound(heartbeat); + + Object result = channel.readInbound(); + assertNull(result); + } + + @Test + public void testHeartbeatFollowedByFrame() { + // Heartbeat bytes followed by a real STOMP frame should decode correctly + ByteBuf incoming = Unpooled.buffer(); + incoming.writeByte('\n'); + incoming.writeByte('\n'); + incoming.writeBytes(StompTestConstants.CONNECT_FRAME.getBytes()); + channel.writeInbound(incoming); + + StompHeadersSubframe frame = channel.readInbound(); + assertNotNull(frame); + assertEquals(StompCommand.CONNECT, frame.command()); + + StompContentSubframe content = channel.readInbound(); + assertSame(LastStompContentSubframe.EMPTY_LAST_CONTENT, content); + content.release(); + + assertNull(channel.readInbound()); + } + + @Test + public void testHeartbeatBetweenFrames() { + // Heartbeat bytes between two STOMP frames + ByteBuf incoming = Unpooled.buffer(); + incoming.writeBytes(StompTestConstants.CONNECT_FRAME.getBytes()); + incoming.writeByte('\n'); + incoming.writeByte('\n'); + incoming.writeBytes(StompTestConstants.CONNECTED_FRAME.getBytes()); + channel.writeInbound(incoming); + + StompHeadersSubframe frame1 = channel.readInbound(); + assertNotNull(frame1); + assertEquals(StompCommand.CONNECT, frame1.command()); + + StompContentSubframe content1 = channel.readInbound(); + assertSame(LastStompContentSubframe.EMPTY_LAST_CONTENT, content1); + content1.release(); + + StompHeadersSubframe frame2 = channel.readInbound(); + assertNotNull(frame2); + assertEquals(StompCommand.CONNECTED, frame2.command()); + + StompContentSubframe content2 = channel.readInbound(); + assertSame(LastStompContentSubframe.EMPTY_LAST_CONTENT, content2); + content2.release(); + + assertNull(channel.readInbound()); + } + + @Test + public void testHeartbeatSentSeparatelyThenFrame() { + // Simulate heartbeat arriving in a separate TCP segment, then a frame later + ByteBuf heartbeat = Unpooled.buffer(); + heartbeat.writeByte('\n'); + channel.writeInbound(heartbeat); + + // No output from heartbeat + assertNull(channel.readInbound()); + + // Now send a real frame + ByteBuf frame = Unpooled.buffer(); + frame.writeBytes(StompTestConstants.CONNECT_FRAME.getBytes()); + channel.writeInbound(frame); + + StompHeadersSubframe headersSubframe = channel.readInbound(); + assertNotNull(headersSubframe); + assertEquals(StompCommand.CONNECT, headersSubframe.command()); + + StompContentSubframe content = channel.readInbound(); + assertSame(LastStompContentSubframe.EMPTY_LAST_CONTENT, content); + content.release(); + + assertNull(channel.readInbound()); + } + @Test void testInvalidEscapeHeadersSequence() { channel = new EmbeddedChannel(new StompSubframeDecoder(true)); diff --git a/codec-xml/pom.xml b/codec-xml/pom.xml index 8b7b69babfb..2e7914c20b0 100644 --- a/codec-xml/pom.xml +++ b/codec-xml/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec-xml diff --git a/codec/pom.xml b/codec/pom.xml index d121be92666..a4dfabcd16d 100644 --- a/codec/pom.xml +++ b/codec/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-codec @@ -74,7 +74,7 @@ true - org.lz4 + at.yawk.lz4 lz4-java true diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index 3c341543438..ba28386b4a3 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -27,15 +27,17 @@ import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.StringUtil; +import java.util.ArrayDeque; import java.util.List; +import java.util.Queue; +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; import static io.netty.util.internal.ObjectUtil.checkPositive; -import static java.lang.Integer.MAX_VALUE; /** - * {@link ChannelInboundHandlerAdapter} which decodes bytes in a stream-like fashion from one {@link ByteBuf} to an - * other Message type. - * + * {@link ChannelInboundHandlerAdapter} which decodes bytes in a stream-like fashion from one {@link ByteBuf} to + * another Message type. + *

* For example here is an implementation which reads all readable bytes from * the input {@link ByteBuf} and create a new {@link ByteBuf}. * @@ -66,7 +68,7 @@ * is not always the case. Use in.getInt(in.readerIndex()) instead. *

Pitfalls

*

- * Be aware that sub-classes of {@link ByteToMessageDecoder} MUST NOT + * Be aware that subclasses of {@link ByteToMessageDecoder} MUST NOT * annotated with {@link @Sharable}. *

* Some methods such as {@link ByteBuf#readBytes(int)} will cause a memory leak if the returned buffer @@ -162,6 +164,8 @@ public ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) private static final byte STATE_CALLING_CHILD_DECODE = 1; private static final byte STATE_HANDLER_REMOVED_PENDING = 2; + // Used to guard the inputs for reentrant channelRead calls + private Queue inputMessages; ByteBuf cumulation; private Cumulator cumulator = MERGE_CUMULATOR; private boolean singleDecode; @@ -279,49 +283,60 @@ public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (msg instanceof ByteBuf) { - selfFiredChannelRead = true; - CodecOutputList out = CodecOutputList.newInstance(); - try { - first = cumulation == null; - cumulation = cumulator.cumulate(ctx.alloc(), - first ? Unpooled.EMPTY_BUFFER : cumulation, (ByteBuf) msg); - callDecode(ctx, cumulation, out); - } catch (DecoderException e) { - throw e; - } catch (Exception e) { - throw new DecoderException(e); - } finally { - try { - if (cumulation != null && !cumulation.isReadable()) { - numReads = 0; + public void channelRead(ChannelHandlerContext ctx, Object input) throws Exception { + if (decodeState == STATE_INIT) { + do { + if (input instanceof ByteBuf) { + selfFiredChannelRead = true; + CodecOutputList out = CodecOutputList.newInstance(); + try { + first = cumulation == null; + cumulation = cumulator.cumulate(ctx.alloc(), + first ? EMPTY_BUFFER : cumulation, (ByteBuf) input); + callDecode(ctx, cumulation, out); + } catch (DecoderException e) { + throw e; + } catch (Exception e) { + throw new DecoderException(e); + } finally { try { - cumulation.release(); - } catch (IllegalReferenceCountException e) { - //noinspection ThrowFromFinallyBlock - throw new IllegalReferenceCountException( - getClass().getSimpleName() + "#decode() might have released its input buffer, " + - "or passed it down the pipeline without a retain() call, " + - "which is not allowed.", e); + if (cumulation != null && !cumulation.isReadable()) { + numReads = 0; + try { + cumulation.release(); + } catch (IllegalReferenceCountException e) { + //noinspection ThrowFromFinallyBlock + throw new IllegalReferenceCountException( + getClass().getSimpleName() + + "#decode() might have released its input buffer, " + + "or passed it down the pipeline without a retain() call, " + + "which is not allowed.", e); + } + cumulation = null; + } else if (++numReads >= discardAfterReads) { + // We did enough reads already try to discard some bytes, so we not risk to see a OOME. + // See https://github.com/netty/netty/issues/4275 + numReads = 0; + discardSomeReadBytes(); + } + + int size = out.size(); + firedChannelRead |= out.insertSinceRecycled(); + fireChannelRead(ctx, out, size); + } finally { + out.recycle(); } - cumulation = null; - } else if (++numReads >= discardAfterReads) { - // We did enough reads already try to discard some bytes, so we not risk to see a OOME. - // See https://github.com/netty/netty/issues/4275 - numReads = 0; - discardSomeReadBytes(); } - - int size = out.size(); - firedChannelRead |= out.insertSinceRecycled(); - fireChannelRead(ctx, out, size); - } finally { - out.recycle(); + } else { + ctx.fireChannelRead(input); } - } + } while (inputMessages != null && (input = inputMessages.poll()) != null); } else { - ctx.fireChannelRead(msg); + // Reentrant call. Bail out here and let original call process our message. + if (inputMessages == null) { + inputMessages = new ArrayDeque(2); + } + inputMessages.offer(input); } } @@ -529,12 +544,14 @@ final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in, try { decode(ctx, in, out); } finally { - boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING; - decodeState = STATE_INIT; - if (removePending) { - fireChannelRead(ctx, out, out.size()); - out.clear(); - handlerRemoved(ctx); + if (inputMessages == null || inputMessages.isEmpty()) { + boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING; + decodeState = STATE_INIT; + if (removePending) { + fireChannelRead(ctx, out, out.size()); + out.clear(); + handlerRemoved(ctx); + } } } } @@ -558,7 +575,7 @@ static ByteBuf expandCumulation(ByteBufAllocator alloc, ByteBuf oldCumulation, B int oldBytes = oldCumulation.readableBytes(); int newBytes = in.readableBytes(); int totalBytes = oldBytes + newBytes; - ByteBuf newCumulation = alloc.buffer(alloc.calculateNewCapacity(totalBytes, MAX_VALUE)); + ByteBuf newCumulation = alloc.buffer(alloc.calculateNewCapacity(totalBytes, Integer.MAX_VALUE)); ByteBuf toRelease = newCumulation; try { // This avoids redundant checks and stack depth compared to calling writeBytes(...) diff --git a/codec/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java index 4a38db51be3..b4df8233256 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java @@ -27,11 +27,14 @@ /** * Decompresses a {@link ByteBuf} encoded with the brotli format. - * + *

* See brotli. */ public final class BrotliDecoder extends ByteToMessageDecoder { + private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES; + private static final int DEFAULT_INPUT_BUFFER_SIZE = 8 * 1024; + private enum State { DONE, NEEDS_MORE_INPUT, ERROR } @@ -45,15 +48,17 @@ private enum State { } private final int inputBufferSize; + private final int outputBufferSize; private DecoderJNI.Wrapper decoder; private boolean destroyed; private boolean needsRead; + private ByteBuf accumBuffer; /** * Creates a new BrotliDecoder with a default 8kB input buffer */ public BrotliDecoder() { - this(8 * 1024); + this(DEFAULT_INPUT_BUFFER_SIZE); } /** @@ -61,16 +66,42 @@ public BrotliDecoder() { * @param inputBufferSize desired size of the input buffer in bytes */ public BrotliDecoder(int inputBufferSize) { + this(inputBufferSize == 0 ? DEFAULT_INPUT_BUFFER_SIZE : inputBufferSize, DEFAULT_MAX_FORWARD_BYTES); + } + + /** + * Creates a new BrotliDecoder + * @param inputBufferSize desired size of the input buffer in bytes + * @param outputBufferSize desired max size of the output buffer in bytes + * (produce multiple output buffers if exceeded) + */ + public BrotliDecoder(int inputBufferSize, int outputBufferSize) { this.inputBufferSize = ObjectUtil.checkPositive(inputBufferSize, "inputBufferSize"); + this.outputBufferSize = ObjectUtil.checkPositive(outputBufferSize, "outputBufferSize"); } private void forwardOutput(ChannelHandlerContext ctx) { - ByteBuffer nativeBuffer = decoder.pull(); + ByteBuffer nativeBuffer = decoder.pull(outputBufferSize); // nativeBuffer actually wraps brotli's internal buffer so we need to copy its content - ByteBuf copy = ctx.alloc().buffer(nativeBuffer.remaining()); - copy.writeBytes(nativeBuffer); + int remaining = nativeBuffer.remaining(); + if (accumBuffer == null) { + accumBuffer = ctx.alloc().buffer(remaining); + } + accumBuffer.writeBytes(nativeBuffer); needsRead = false; - ctx.fireChannelRead(copy); + if (accumBuffer.readableBytes() >= outputBufferSize) { + ctx.fireChannelRead(accumBuffer); + accumBuffer = null; + } + } + + private void flushAccumBuffer(ChannelHandlerContext ctx) { + if (accumBuffer != null && accumBuffer.isReadable()) { + ctx.fireChannelRead(accumBuffer); + } else if (accumBuffer != null) { + accumBuffer.release(); + } + accumBuffer = null; } private State decompress(ChannelHandlerContext ctx, ByteBuf input) { @@ -84,7 +115,7 @@ private State decompress(ChannelHandlerContext ctx, ByteBuf input) { break; case NEEDS_MORE_INPUT: - if (decoder.hasOutput()) { + while (decoder.hasOutput()) { forwardOutput(ctx); } @@ -145,6 +176,8 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } catch (Exception e) { destroy(); throw e; + } finally { + flushAccumBuffer(ctx); } } diff --git a/codec/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java b/codec/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java index d2a06f95287..833b2f8f7cc 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java @@ -16,11 +16,15 @@ package io.netty.handler.codec.compression; import io.netty.buffer.ByteBuf; +import io.netty.util.internal.SystemPropertyUtil; import java.nio.ByteBuffer; final class CompressionUtil { + static final int DEFAULT_MAX_FORWARD_BYTES = SystemPropertyUtil.getInt( + "io.netty.compression.defaultMaxForwardBytes", 64 * 1024); + private CompressionUtil() { } static void checkChecksum(ByteBufChecksum checksum, ByteBuf uncompressed, int currentChecksum) { diff --git a/codec/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java index 51bdd670aa8..81f259f0a0d 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java @@ -28,6 +28,8 @@ public class JZlibDecoder extends ZlibDecoder { private final Inflater z = new Inflater(); private byte[] dictionary; + private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES; + private final int maxForwardBytes; private boolean needsRead; private volatile boolean finished; @@ -78,6 +80,7 @@ public JZlibDecoder(ZlibWrapper wrapper) { */ public JZlibDecoder(ZlibWrapper wrapper, int maxAllocation) { super(maxAllocation); + this.maxForwardBytes = maxAllocation > 0 ? maxAllocation : DEFAULT_MAX_FORWARD_BYTES; ObjectUtil.checkNotNull(wrapper, "wrapper"); @@ -113,6 +116,7 @@ public JZlibDecoder(byte[] dictionary) { */ public JZlibDecoder(byte[] dictionary, int maxAllocation) { super(maxAllocation); + this.maxForwardBytes = maxAllocation > 0 ? maxAllocation : DEFAULT_MAX_FORWARD_BYTES; this.dictionary = ObjectUtil.checkNotNull(dictionary, "dictionary"); int resultCode; resultCode = z.inflateInit(JZlib.W_ZLIB); @@ -174,7 +178,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t int outputLength = z.next_out_index - oldNextOutIndex; if (outputLength > 0) { decompressed.writerIndex(decompressed.writerIndex() + outputLength); - if (maxAllocation == 0) { + if (maxAllocation == 0 && decompressed.readableBytes() >= maxForwardBytes) { // If we don't limit the maximum allocations we should just // forward the buffer directly. ByteBuf buffer = decompressed; diff --git a/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java index 0ef03a217b7..ac2b75c8077 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java @@ -59,6 +59,9 @@ private enum GzipState { private int xlen = -1; private boolean needsRead; + private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES; + private final int maxForwardBytes; + private volatile boolean finished; private boolean decideZlibOrNone; @@ -161,6 +164,7 @@ public JdkZlibDecoder(boolean decompressConcatenated, int maxAllocation) { private JdkZlibDecoder(ZlibWrapper wrapper, byte[] dictionary, boolean decompressConcatenated, int maxAllocation) { super(maxAllocation); + this.maxForwardBytes = maxAllocation > 0 ? maxAllocation : DEFAULT_MAX_FORWARD_BYTES; ObjectUtil.checkNotNull(wrapper, "wrapper"); @@ -265,9 +269,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t if (crc != null) { crc.update(outArray, outIndex, outputLength); } - if (maxAllocation == 0) { - // If we don't limit the maximum allocations we should just - // forward the buffer directly. + if (maxAllocation == 0 && decompressed.readableBytes() >= maxForwardBytes) { + // Forward the buffer once it exceeds the threshold to bound memory + // while avoiding excessive fireChannelRead calls. ByteBuf buffer = decompressed; decompressed = null; needsRead = false; diff --git a/codec/src/main/java/io/netty/handler/codec/compression/Lz4FrameDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/Lz4FrameDecoder.java index 338046f44f0..c1d39b3cc4f 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/Lz4FrameDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/Lz4FrameDecoder.java @@ -51,6 +51,7 @@ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ public class Lz4FrameDecoder extends ByteToMessageDecoder { + private final int maxDecompressedLength; /** * Current state of stream. */ @@ -117,6 +118,21 @@ public Lz4FrameDecoder(boolean validateChecksums) { this(LZ4Factory.fastestInstance(), validateChecksums); } + /** + * Creates a LZ4 decoder with fastest decoder instance available on your machine. + * + * @param validateChecksums if {@code true}, the checksum field will be validated against the actual + * uncompressed data, and if the checksums do not match, a suitable + * {@link DecompressionException} will be thrown + * @param maxDecompressedLength + * maximum length of the decompressed block. If {@code 0} is given it uses {@code 32MB} + * by default. + */ + public Lz4FrameDecoder(boolean validateChecksums, int maxDecompressedLength) { + this(LZ4Factory.fastestInstance(), validateChecksums ? new Lz4XXHash32(DEFAULT_SEED) : null, + maxDecompressedLength); + } + /** * Creates a new LZ4 decoder with customizable implementation. * @@ -143,8 +159,25 @@ public Lz4FrameDecoder(LZ4Factory factory, boolean validateChecksums) { * You may set {@code null} if you do not want to validate checksum of each block */ public Lz4FrameDecoder(LZ4Factory factory, Checksum checksum) { + this(factory, checksum, MAX_BLOCK_SIZE); + } + + /** + * Creates a new customizable LZ4 decoder. + * + * @param factory user customizable {@link LZ4Factory} instance + * which may be JNI bindings to the original C implementation, a pure Java implementation + * or a Java implementation that uses the {@link sun.misc.Unsafe} + * @param checksum the {@link Checksum} instance to use to check data for integrity. + * You may set {@code null} if you do not want to validate checksum of each block + * @param maxDecompressedLength + * maximum length of the decompressed block. If {@code 0} is given it uses {@code 32MB} by default. + */ + public Lz4FrameDecoder(LZ4Factory factory, Checksum checksum, int maxDecompressedLength) { decompressor = ObjectUtil.checkNotNull(factory, "factory").fastDecompressor(); this.checksum = checksum == null ? null : ByteBufChecksum.wrapChecksum(checksum); + this.maxDecompressedLength = maxDecompressedLength == 0 ? MAX_BLOCK_SIZE : + ObjectUtil.checkInRange(maxDecompressedLength, 0, MAX_BLOCK_SIZE, "maxDecompressedLength"); } @Override @@ -172,12 +205,18 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } int decompressedLength = Integer.reverseBytes(in.readInt()); - final int maxDecompressedLength = 1 << compressionLevel; - if (decompressedLength < 0 || decompressedLength > maxDecompressedLength) { + if (decompressedLength > maxDecompressedLength) { throw new DecompressionException(String.format( - "invalid decompressedLength: %d (expected: 0-%d)", + "decompressedLength too large: %d (expected: 0-%d)", decompressedLength, maxDecompressedLength)); } + + final int maxLocalDecompressedLength = 1 << compressionLevel; + if (decompressedLength < 0 || decompressedLength > maxLocalDecompressedLength) { + throw new DecompressionException(String.format( + "invalid decompressedLength: %d (expected: 0-%d)", + decompressedLength, maxLocalDecompressedLength)); + } if (decompressedLength == 0 && compressedLength != 0 || decompressedLength != 0 && compressedLength == 0 || blockType == BLOCK_TYPE_NON_COMPRESSED && decompressedLength != compressedLength) { diff --git a/codec/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java index 05a35b14b92..3be2dcf0939 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java @@ -202,7 +202,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t boolean success = false; try { - decoder.decodeChunk(inputArray, inPos, outputArray, outPos, outPos + originalLength); + decoder.decodeChunk( + inputArray, inPos, inPos + chunkLength, + outputArray, outPos, outPos + originalLength); if (uncompressed.hasArray()) { uncompressed.writerIndex(uncompressed.writerIndex() + originalLength); } else { diff --git a/codec/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java b/codec/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java index 38793a97e6f..1397e123080 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java @@ -71,7 +71,7 @@ public static BrotliOptions brotli(int quality, int window, BrotliMode mode) { /** * Default implementation of {@link ZstdOptions} with{compressionLevel(int)} set to * {@link ZstdConstants#DEFAULT_COMPRESSION_LEVEL},{@link ZstdConstants#DEFAULT_BLOCK_SIZE}, - * {@link ZstdConstants#MAX_BLOCK_SIZE} + * {@link ZstdConstants#DEFAULT_MAX_ENCODE_SIZE} */ public static ZstdOptions zstd() { return ZstdOptions.DEFAULT; diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java index 111372c3ede..b9a5aca6514 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java @@ -35,9 +35,9 @@ final class ZstdConstants { static final int MAX_COMPRESSION_LEVEL = Zstd.maxCompressionLevel(); /** - * Max block size + * Max encode size */ - static final int MAX_BLOCK_SIZE = 1 << (DEFAULT_COMPRESSION_LEVEL + 7) + 0x0F; // 32 M + static final int DEFAULT_MAX_ENCODE_SIZE = Integer.MAX_VALUE; /** * Default block size */ diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java index ef0bf1371d8..d1ca998b8b5 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java @@ -15,7 +15,6 @@ */ package io.netty.handler.codec.compression; -import com.github.luben.zstd.ZstdIOException; import com.github.luben.zstd.ZstdInputStreamNoFinalizer; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; @@ -41,7 +40,21 @@ public final class ZstdDecoder extends ByteToMessageDecoder { } } + private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES; + /** + * Default maximum size of a single output buffer, in bytes (4 MiB). + */ + public static final int DEFAULT_MAXIMUM_ALLOCATION_SIZE = 4 * 1024 * 1024; + /** + * Default upper bound on the {@code Window_Log} accepted by the decoder. + * {@code 27} corresponds to a 128 MiB decompression window. + */ + public static final int DEFAULT_MAX_WINDOW_LOG = 27; + private static final int MIN_WINDOW_LOG = 10; + private static final int MAX_WINDOW_LOG = 31; private final int maximumAllocationSize; + private final int maxForwardBytes; + private final int maxWindowLog; private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream(); private ZstdInputStreamNoFinalizer zstdIs; @@ -56,12 +69,44 @@ private enum State { CORRUPTED } + /** + * Creates a new decoder with the {@link #DEFAULT_MAXIMUM_ALLOCATION_SIZE}, + * and the {@link #DEFAULT_MAX_WINDOW_LOG} window log size. + *

+ * The window log size bounds the memory usage of the sliding window for ZSTD frame decompression. + * Frames declaring a larger window will be rejected to bound the memory the decoder may allocate per stream. + * + */ public ZstdDecoder() { - this(4 * 1024 * 1024); + this(DEFAULT_MAXIMUM_ALLOCATION_SIZE, DEFAULT_MAX_WINDOW_LOG); } + /** + * Creates a new decoder with the given maximum allocation size, + * and the {@link #DEFAULT_MAX_WINDOW_LOG} window log size. + *

+ * The window log size bounds the memory usage of the sliding window for ZSTD frame decompression. + * Frames declaring a larger window will be rejected to bound the memory the decoder may allocate per stream. + * + * @param maximumAllocationSize maximum size of a single output buffer. + */ public ZstdDecoder(int maximumAllocationSize) { + this(maximumAllocationSize, DEFAULT_MAX_WINDOW_LOG); + } + + /** + * Creates a new decoder with an explicit upper bound on the accepted {@code Window_Log}. + * + * @param maximumAllocationSize maximum size of a single output buffer. + * @param maxWindowLog upper bound on the {@code Window_Log} field of incoming + * frames; must be in {@code [10, 31]}. Frames declaring a + * larger window will be rejected to bound the memory the + * decoder may allocate per stream. + */ + public ZstdDecoder(int maximumAllocationSize, int maxWindowLog) { this.maximumAllocationSize = ObjectUtil.checkPositiveOrZero(maximumAllocationSize, "maximumAllocationSize"); + this.maxForwardBytes = maximumAllocationSize > 0 ? maximumAllocationSize : DEFAULT_MAX_FORWARD_BYTES; + this.maxWindowLog = ObjectUtil.checkInRange(maxWindowLog, MIN_WINDOW_LOG, MAX_WINDOW_LOG, "maxWindowLog"); } @Override @@ -101,13 +146,18 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } do { w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes()); - } while (w != -1 && outBuffer.isWritable()); - if (outBuffer.isReadable()) { + } while (w > 0 && outBuffer.isWritable()); + if (!outBuffer.isWritable() || outBuffer.readableBytes() >= maxForwardBytes) { needsRead = false; ctx.fireChannelRead(outBuffer); outBuffer = null; } - } while (w != -1); + } while (w > 0); + if (outBuffer != null && outBuffer.isReadable()) { + needsRead = false; + ctx.fireChannelRead(outBuffer); + outBuffer = null; + } } finally { if (outBuffer != null) { outBuffer.release(); @@ -137,6 +187,9 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { super.handlerAdded(ctx); zstdIs = new ZstdInputStreamNoFinalizer(inputStream); zstdIs.setContinuous(true); + // Bound the decompression window to mitigate memory amplification from frames that + // declare an oversized Window_Size. + zstdIs.setLongMax(maxWindowLog); } @Override diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java index 7ece3c2a643..36e8f364f75 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java @@ -28,7 +28,7 @@ import static io.netty.handler.codec.compression.ZstdConstants.MIN_COMPRESSION_LEVEL; import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL; import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE; -import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_MAX_ENCODE_SIZE; /** * Compresses a {@link ByteBuf} using the Zstandard algorithm. @@ -56,7 +56,7 @@ public final class ZstdEncoder extends MessageToByteEncoder { * please use {@link ZstdEncoder(int,int)} constructor */ public ZstdEncoder() { - this(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + this(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE); } /** @@ -65,7 +65,7 @@ public ZstdEncoder() { * specifies the level of the compression */ public ZstdEncoder(int compressionLevel) { - this(compressionLevel, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + this(compressionLevel, DEFAULT_BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE); } /** @@ -113,7 +113,9 @@ protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean while (remaining > 0) { int curSize = Math.min(blockSize, remaining); remaining -= curSize; - bufferSize += Zstd.compressBound(curSize); + // calculate the max compressed size with Zstd.compressBound since + // it returns the maximum size of the compressed data + bufferSize = Math.max(bufferSize, Zstd.compressBound(curSize)); } if (bufferSize > maxEncodeSize || 0 > bufferSize) { @@ -141,6 +143,11 @@ protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) { flushBufferedData(out); } } + // return the remaining data in the buffer + // when buffer size is smaller than the block size + if (buffer.isReadable()) { + flushBufferedData(out); + } } private void flushBufferedData(ByteBuf out) { diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java index 8b6ce3c5550..583151aa040 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java @@ -21,7 +21,7 @@ import static io.netty.handler.codec.compression.ZstdConstants.MIN_COMPRESSION_LEVEL; import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL; import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE; -import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_MAX_ENCODE_SIZE; /** * {@link ZstdOptions} holds compressionLevel for @@ -36,9 +36,10 @@ public class ZstdOptions implements CompressionOptions { /** * Default implementation of {@link ZstdOptions} with{compressionLevel(int)} set to * {@link ZstdConstants#DEFAULT_COMPRESSION_LEVEL},{@link ZstdConstants#DEFAULT_BLOCK_SIZE}, - * {@link ZstdConstants#MAX_BLOCK_SIZE} + * {@link ZstdConstants#DEFAULT_MAX_ENCODE_SIZE} */ - static final ZstdOptions DEFAULT = new ZstdOptions(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + static final ZstdOptions DEFAULT = new ZstdOptions(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, + DEFAULT_MAX_ENCODE_SIZE); /** * Create a new {@link ZstdOptions} diff --git a/codec/src/main/java/io/netty/handler/codec/marshalling/DefaultUnmarshallerProvider.java b/codec/src/main/java/io/netty/handler/codec/marshalling/DefaultUnmarshallerProvider.java index 57fa2dbda24..3ed598f7b8e 100644 --- a/codec/src/main/java/io/netty/handler/codec/marshalling/DefaultUnmarshallerProvider.java +++ b/codec/src/main/java/io/netty/handler/codec/marshalling/DefaultUnmarshallerProvider.java @@ -25,6 +25,10 @@ * Default implementation of {@link UnmarshallerProvider} which will just create a new {@link Unmarshaller} * on every call to {@link #getUnmarshaller(ChannelHandlerContext)} * + * Security: serialization can be a security liability, + * and should not be used without defining a list of classes that are + * allowed to be deserialized. This explicitly needs to be done via {@link MarshallingConfiguration}, + * missing to do so is a security risk. */ public class DefaultUnmarshallerProvider implements UnmarshallerProvider { diff --git a/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoder.java b/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoder.java index cb87c6219f1..0109cd6228d 100644 --- a/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoder.java @@ -21,9 +21,12 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.codec.TooLongFrameException; import java.util.List; +import static io.netty.util.internal.ObjectUtil.checkPositive; + /** * A decoder that splits the received {@link ByteBuf}s dynamically by the * value of the Google Protocol Buffers @@ -42,12 +45,37 @@ */ public class ProtobufVarint32FrameDecoder extends ByteToMessageDecoder { - // TODO maxFrameLength + safe skip + fail-fast option - // (just like LengthFieldBasedFrameDecoder) + private final int maxFrameLength; + private long bytesToDiscard; + + /** + * Creates a new instance with no frame length limit. + */ + public ProtobufVarint32FrameDecoder() { + this(Integer.MAX_VALUE); + } + + /** + * Creates a new instance with the specified maximum frame length. + * + * @param maxFrameLength the maximum length of the frame. + * If the length exceeds this value, + * {@link TooLongFrameException} will be thrown. + */ + public ProtobufVarint32FrameDecoder(int maxFrameLength) { + this.maxFrameLength = checkPositive(maxFrameLength, "maxFrameLength"); + } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (bytesToDiscard > 0) { + int localBytesToDiscard = (int) Math.min(bytesToDiscard, in.readableBytes()); + in.skipBytes(localBytesToDiscard); + bytesToDiscard -= localBytesToDiscard; + return; + } + in.markReaderIndex(); int preIndex = in.readerIndex(); int length = readRawVarint32(in); @@ -58,6 +86,19 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throw new CorruptedFrameException("negative length: " + length); } + if (length > maxFrameLength) { + long discard = length - in.readableBytes(); + if (discard <= 0) { + in.skipBytes(length); + } else { + bytesToDiscard = discard; + in.skipBytes(in.readableBytes()); + } + throw new TooLongFrameException( + "Frame length exceeds " + maxFrameLength + + ": " + length); + } + if (in.readableBytes() < length) { in.resetReaderIndex(); } else { diff --git a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java index 84f8c755559..e7069a542e7 100644 --- a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -684,4 +684,80 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { assertEquals(0, buffer.refCnt(), "Buffer should be released"); assertFalse(channel.finish()); } + + @Test + void reentrantReadSafety() throws Exception { + final EmbeddedChannel channel = new EmbeddedChannel(); + ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + int reentrancy; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + reentrancy++; + if (reentrancy == 1) { + ByteBuf buf2 = channel.alloc().buffer(); + buf2.writeLong(42); // Adding 8 bytes. + assertFalse(channel.writeInbound(buf2)); // Reentrant call back into ByteToMessageDecoder + ctx.read(); + } + int bytes = in.readableBytes(); + out.add(bytes); + in.skipBytes(bytes); + } + }; + channel.pipeline().addLast(decoder); + ByteBuf buf1 = channel.alloc().buffer(); + buf1.writeInt(42); // Adding 4 bytes. + assertTrue(channel.writeInbound(buf1)); + Integer first = channel.readInbound(); + Integer second = channel.readInbound(); + assertEquals(4, first); + assertEquals(8, second); + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + void reentrantReadThenRemoveSafety() throws Exception { + final EmbeddedChannel channel = new EmbeddedChannel(); + ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + boolean removed; + int reentrancy; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(removed); + reentrancy++; + if (reentrancy == 1) { + ByteBuf buf2 = channel.alloc().buffer(); + buf2.writeLong(42); // Adding 8 bytes. + assertFalse(channel.writeInbound(buf2)); // Reentrant call back into ByteToMessageDecoder + ByteBuf buf3 = channel.alloc().buffer(); + buf3.writeShort(42); // Adding 2 bytes. + assertFalse(channel.writeInbound(buf3)); // Reentrant call back into ByteToMessageDecoder + ctx.read(); + } else if (reentrancy == 2) { + ctx.pipeline().remove(this); + } + int bytes = in.readableBytes(); + out.add(bytes); + in.skipBytes(bytes); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + removed = true; + } + }; + channel.pipeline().addLast(decoder); + ByteBuf buf1 = channel.alloc().buffer(); + buf1.writeInt(42); // Adding 4 bytes. + assertTrue(channel.writeInbound(buf1)); + Integer first = channel.readInbound(); + Integer second = channel.readInbound(); + Integer third = channel.readInbound(); + assertEquals(4, first); + assertEquals(8, second); + assertEquals(2, third); + assertFalse(channel.finishAndReleaseAll()); + } } diff --git a/codec/src/test/java/io/netty/handler/codec/compression/Lz4FrameDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/compression/Lz4FrameDecoderTest.java index 01338f6e008..ec6b4ed6f57 100644 --- a/codec/src/test/java/io/netty/handler/codec/compression/Lz4FrameDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/compression/Lz4FrameDecoderTest.java @@ -45,7 +45,8 @@ public Lz4FrameDecoderTest() throws Exception { @Override protected EmbeddedChannel createChannel() { - return new EmbeddedChannel(new Lz4FrameDecoder(true)); + // Use max limit of 31 MB as we want to test that we reject 32 MB in one of the tests + return new EmbeddedChannel(new Lz4FrameDecoder(true, 31 * 1024 * 1024)); } @Test @@ -90,6 +91,24 @@ public void execute() { }, "invalid decompressedLength"); } + @Test + public void testTooLargeDecompressedLength() { + final ByteBuf buf = Unpooled.buffer(22, 22); + buf.writeLong(MAGIC_NUMBER); + buf.writeByte(BLOCK_TYPE_COMPRESSED | 0x0F); + buf.writeIntLE(1); + buf.writeIntLE(1 << 25); + buf.writeIntLE(0); + buf.writeByte(0); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(buf); + } + }); + } + @Test public void testDecompressedAndCompressedLengthMismatch() { final byte[] data = Arrays.copyOf(DATA, DATA.length); diff --git a/codec/src/test/java/io/netty/handler/codec/compression/ZstdDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/compression/ZstdDecoderTest.java index 0c28f3ebf82..5f8eb0b4a41 100644 --- a/codec/src/test/java/io/netty/handler/codec/compression/ZstdDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/compression/ZstdDecoderTest.java @@ -16,7 +16,18 @@ package io.netty.handler.codec.compression; import com.github.luben.zstd.Zstd; +import com.github.luben.zstd.ZstdCompressCtx; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; public class ZstdDecoderTest extends AbstractDecoderTest { @@ -33,4 +44,76 @@ public EmbeddedChannel createChannel() { protected byte[] compress(byte[] data) throws Exception { return Zstd.compress(data); } + + @Test + public void testFrameWithWindowLogAboveCapIsRejected() { + // Incompressible random data so libzstd actually has to use the declared window + // (highly compressible content lets libzstd shrink the effective window to the + // content size, making setLongMax ineffective for the test). + byte[] payload = new byte[256 * 1024]; + new Random(12345L).nextBytes(payload); + + // Compressed with windowLog = 21 (2 MiB window). + final byte[] compressed = compressWithWindowLog(payload, 21); + + // Decoder caps Window_Log at 15 (32 KiB) -> the frame must be rejected. + final EmbeddedChannel ch = new EmbeddedChannel(new ZstdDecoder(4 * 1024 * 1024, 15)); + try { + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + ch.writeInbound(Unpooled.wrappedBuffer(compressed)); + } + }); + } finally { + ch.finishAndReleaseAll(); + } + } + + @Test + public void testFrameWithWindowLogWithinCapIsAccepted() { + byte[] payload = new byte[256 * 1024]; + new Random(12345L).nextBytes(payload); + + byte[] compressed = compressWithWindowLog(payload, 18); // 256 KiB window + + EmbeddedChannel ch = new EmbeddedChannel(new ZstdDecoder(4 * 1024 * 1024, 20)); + try { + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(compressed))); + + ByteBuf acc = Unpooled.buffer(); + try { + ByteBuf b; + while ((b = ch.readInbound()) != null) { + try { + acc.writeBytes(b); + } finally { + b.release(); + } + } + byte[] actual = new byte[acc.readableBytes()]; + acc.readBytes(actual); + assertArrayEquals(payload, actual); + } finally { + acc.release(); + } + } finally { + ch.finishAndReleaseAll(); + } + } + + private static byte[] compressWithWindowLog(byte[] data, int windowLog) { + ZstdCompressCtx ctx = new ZstdCompressCtx(); + try { + ctx.setLevel(Zstd.defaultCompressionLevel()); + ctx.setWindowLog(windowLog); + byte[] dst = new byte[(int) Zstd.compressBound(data.length)]; + int written = ctx.compressByteArray(dst, 0, dst.length, data, 0, data.length); + byte[] out = new byte[written]; + System.arraycopy(dst, 0, out, 0, written); + return out; + } finally { + ctx.close(); + } + } } diff --git a/codec/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java b/codec/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java index 296e3dac7ea..250ed34587a 100644 --- a/codec/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java @@ -22,7 +22,9 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; @@ -30,9 +32,10 @@ import java.io.InputStream; - import static org.mockito.Mockito.when; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; public class ZstdEncoderTest extends AbstractEncoderTest { @@ -46,6 +49,14 @@ public void setup() { when(ctx.alloc()).thenReturn(ByteBufAllocator.DEFAULT); } + public static ByteBuf[] hugeData() { + final byte[] bytesHuge = new byte[36 * 1024 * 1024]; + ByteBuf heap = Unpooled.wrappedBuffer(bytesHuge); + ByteBuf direct = Unpooled.directBuffer(bytesHuge.length); + direct.writeBytes(bytesHuge); + return new ByteBuf[] {heap, direct}; + } + @Override public EmbeddedChannel createChannel() { return new EmbeddedChannel(new ZstdEncoder()); @@ -54,6 +65,16 @@ public EmbeddedChannel createChannel() { @ParameterizedTest @MethodSource("largeData") public void testCompressionOfLargeBatchedFlow(final ByteBuf data) throws Exception { + testCompressionOfLargeDataBatchedFlow(data); + } + + @ParameterizedTest + @MethodSource("hugeData") + public void testCompressionOfHugeBatchedFlow(final ByteBuf data) throws Exception { + testCompressionOfLargeDataBatchedFlow(data); + } + + private void testCompressionOfLargeDataBatchedFlow(final ByteBuf data) throws Exception { final int dataLength = data.readableBytes(); int written = 0; @@ -78,6 +99,18 @@ public void testCompressionOfSmallBatchedFlow(final ByteBuf data) throws Excepti testCompressionOfBatchedFlow(data); } + @Test + public void testCompressionOfTinyData() throws Exception { + ByteBuf data = Unpooled.copiedBuffer("Hello, World", CharsetUtil.UTF_8); + assertTrue(channel.writeOutbound(data)); + assertTrue(channel.finish()); + + ByteBuf out = channel.readOutbound(); + assertThat(out.readableBytes()).isPositive(); + out.release(); + assertNull(channel.readOutbound()); + } + @Override protected ByteBuf decompress(ByteBuf compressed, int originalLength) throws Exception { InputStream is = new ByteBufInputStream(compressed, true); diff --git a/codec/src/test/java/io/netty/handler/codec/compression/ZstdIntegrationTest.java b/codec/src/test/java/io/netty/handler/codec/compression/ZstdIntegrationTest.java index 575dcddb993..620875ac9ab 100644 --- a/codec/src/test/java/io/netty/handler/codec/compression/ZstdIntegrationTest.java +++ b/codec/src/test/java/io/netty/handler/codec/compression/ZstdIntegrationTest.java @@ -17,7 +17,7 @@ import io.netty.channel.embedded.EmbeddedChannel; -import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_MAX_ENCODE_SIZE; public class ZstdIntegrationTest extends AbstractIntegrationTest { @@ -25,7 +25,7 @@ public class ZstdIntegrationTest extends AbstractIntegrationTest { @Override protected EmbeddedChannel createEncoder() { - return new EmbeddedChannel(new ZstdEncoder(BLOCK_SIZE, MAX_BLOCK_SIZE)); + return new EmbeddedChannel(new ZstdEncoder(BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE)); } @Override diff --git a/codec/src/test/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoderTest.java index 432fd4452fc..e7aae2fdc59 100644 --- a/codec/src/test/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoderTest.java @@ -17,14 +17,17 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.TooLongFrameException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; import static io.netty.buffer.Unpooled.*; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class ProtobufVarint32FrameDecoderTest { @@ -79,4 +82,82 @@ public void testRegularDecode() { expected.release(); actual.release(); } + + @Test + public void testFrameWithinMaxFrameLength() { + EmbeddedChannel channel = new EmbeddedChannel(new ProtobufVarint32FrameDecoder(10)); + byte[] b = { 4, 1, 1, 1, 1 }; + assertTrue(channel.writeInbound(wrappedBuffer(b))); + + ByteBuf expected = wrappedBuffer(new byte[] { 1, 1, 1, 1 }); + ByteBuf actual = channel.readInbound(); + assertEquals(expected, actual); + assertFalse(channel.finish()); + + expected.release(); + actual.release(); + } + + @Test + public void testFrameExceedingMaxFrameLength() { + final EmbeddedChannel channel = new EmbeddedChannel(new ProtobufVarint32FrameDecoder(3)); + final byte[] b = { 4, 1, 1, 1, 1 }; + assertThrows(TooLongFrameException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(wrappedBuffer(b)); + } + }); + assertNull(channel.readInbound()); + assertFalse(channel.finish()); + } + + @Test + public void testOversizedFramePartialDiscard() { + final EmbeddedChannel channel = new EmbeddedChannel(new ProtobufVarint32FrameDecoder(3)); + + // Frame with length=10, only send length byte + 5 data bytes + final byte[] partial = { 10, 1, 2, 3, 4, 5 }; + assertThrows(TooLongFrameException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(wrappedBuffer(partial)); + } + }); + + // Send remaining 5 bytes — should be silently discarded + byte[] remaining = { 6, 7, 8, 9, 10 }; + assertFalse(channel.writeInbound(wrappedBuffer(remaining))); + assertNull(channel.readInbound()); + assertFalse(channel.finish()); + } + + @Test + public void testValidFrameAfterOversized() { + final EmbeddedChannel channel = new EmbeddedChannel(new ProtobufVarint32FrameDecoder(5)); + + // Oversized frame: length=10, all data present + final byte[] oversized = new byte[11]; + oversized[0] = 10; + for (int i = 1; i <= 10; i++) { + oversized[i] = (byte) i; + } + assertThrows(TooLongFrameException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(wrappedBuffer(oversized)); + } + }); + + // Valid frame after recovery + byte[] valid = { 3, 10, 20, 30 }; + assertTrue(channel.writeInbound(wrappedBuffer(valid))); + ByteBuf expected = wrappedBuffer(new byte[] { 10, 20, 30 }); + ByteBuf actual = channel.readInbound(); + assertEquals(expected, actual); + assertFalse(channel.finish()); + + expected.release(); + actual.release(); + } } diff --git a/common/pom.xml b/common/pom.xml index 81587e4e9f0..318fe48a8b3 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-common diff --git a/common/src/main/java/io/netty/util/DefaultAttributeMap.java b/common/src/main/java/io/netty/util/DefaultAttributeMap.java index a39bb5b996f..39a2a28df09 100644 --- a/common/src/main/java/io/netty/util/DefaultAttributeMap.java +++ b/common/src/main/java/io/netty/util/DefaultAttributeMap.java @@ -68,11 +68,12 @@ private static void orderedCopyOnInsert(DefaultAttribute[] sortedSrc, int srcLen int i; for (i = srcLength - 1; i >= 0; i--) { DefaultAttribute attribute = sortedSrc[i]; - assert attribute.key.id() != id; - if (attribute.key.id() < id) { + int attributeKeyId = attribute.key.id(); + assert attributeKeyId != id; + if (attributeKeyId < id) { break; } - copy[i + 1] = sortedSrc[i]; + copy[i + 1] = attribute; } copy[i + 1] = toInsert; final int toCopy = i + 1; @@ -153,7 +154,6 @@ private void removeAttributeIfMatch(AttributeKey key, DefaultAttribute } } - @SuppressWarnings("serial") private static final class DefaultAttribute extends AtomicReference implements Attribute { private static final AtomicReferenceFieldUpdater MAP_UPDATER = diff --git a/common/src/main/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimap.java b/common/src/main/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimap.java new file mode 100644 index 00000000000..b6a36770704 --- /dev/null +++ b/common/src/main/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimap.java @@ -0,0 +1,1550 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/* + * Written by Doug Lea with assistance from members of JCP JSR-166 + * Expert Group and released to the public domain, as explained at + * https://creativecommons.org/publicdomain/zero/1.0/ + * + * With substantial modifications by The Netty Project team. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.LongCounter; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.ThreadLocalRandom; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A scalable concurrent multimap implementation. + * The map is sorted according to the natural ordering of its {@code int} keys. + * + *

This class implements a concurrent variant of SkipLists + * providing expected average log(n) time cost for the + * {@code containsKey}, {@code get}, {@code put} and + * {@code remove} operations and their variants. Insertion, removal, + * update, and access operations safely execute concurrently by + * multiple threads. + * + *

This class is a multimap, which means the same key can be associated with + * multiple values. Each such instance will be represented by a separate + * {@code IntEntry}. There is no defined ordering for the values mapped to + * the same key. + * + *

As a multimap, certain atomic operations like {@code putIfPresent}, + * {@code compute}, or {@code computeIfPresent}, cannot be supported. + * Likewise, some get-like operations cannot be supported. + * + *

Iterators and spliterators are + * weakly consistent. + * + *

All {@code IntEntry} pairs returned by methods in this class + * represent snapshots of mappings at the time they were + * produced. They do not support the {@code Entry.setValue} + * method. (Note however that it is possible to change mappings in the + * associated map using {@code put}, {@code putIfAbsent}, or + * {@code replace}, depending on exactly which effect you need.) + * + *

Beware that bulk operations {@code putAll}, {@code equals}, + * {@code toArray}, {@code containsValue}, and {@code clear} are + * not guaranteed to be performed atomically. For example, an + * iterator operating concurrently with a {@code putAll} operation + * might view only some of the added elements. + * + *

This class does not permit the use of {@code null} values + * because some null return values cannot be reliably distinguished from + * the absence of elements. + * + * @param the type of mapped values + */ +public class ConcurrentSkipListIntObjMultimap implements Iterable> { + /* + * This class implements a tree-like two-dimensionally linked skip + * list in which the index levels are represented in separate + * nodes from the base nodes holding data. There are two reasons + * for taking this approach instead of the usual array-based + * structure: 1) Array based implementations seem to encounter + * more complexity and overhead 2) We can use cheaper algorithms + * for the heavily-traversed index lists than can be used for the + * base lists. Here's a picture of some of the basics for a + * possible list with 2 levels of index: + * + * Head nodes Index nodes + * +-+ right +-+ +-+ + * |2|---------------->| |--------------------->| |->null + * +-+ +-+ +-+ + * | down | | + * v v v + * +-+ +-+ +-+ +-+ +-+ +-+ + * |1|----------->| |->| |------>| |----------->| |------>| |->null + * +-+ +-+ +-+ +-+ +-+ +-+ + * v | | | | | + * Nodes next v v v v v + * +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ + * | |->|A|->|B|->|C|->|D|->|E|->|F|->|G|->|H|->|I|->|J|->|K|->null + * +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ +-+ + * + * The base lists use a variant of the HM linked ordered set + * algorithm. See Tim Harris, "A pragmatic implementation of + * non-blocking linked lists" + * https://www.cl.cam.ac.uk/~tlh20/publications.html and Maged + * Michael "High Performance Dynamic Lock-Free Hash Tables and + * List-Based Sets" + * https://www.research.ibm.com/people/m/michael/pubs.htm. The + * basic idea in these lists is to mark the "next" pointers of + * deleted nodes when deleting to avoid conflicts with concurrent + * insertions, and when traversing to keep track of triples + * (predecessor, node, successor) in order to detect when and how + * to unlink these deleted nodes. + * + * Rather than using mark-bits to mark list deletions (which can + * be slow and space-intensive using AtomicMarkedReference), nodes + * use direct CAS'able next pointers. On deletion, instead of + * marking a pointer, they splice in another node that can be + * thought of as standing for a marked pointer (see method + * unlinkNode). Using plain nodes acts roughly like "boxed" + * implementations of marked pointers, but uses new nodes only + * when nodes are deleted, not for every link. This requires less + * space and supports faster traversal. Even if marked references + * were better supported by JVMs, traversal using this technique + * might still be faster because any search need only read ahead + * one more node than otherwise required (to check for trailing + * marker) rather than unmasking mark bits or whatever on each + * read. + * + * This approach maintains the essential property needed in the HM + * algorithm of changing the next-pointer of a deleted node so + * that any other CAS of it will fail, but implements the idea by + * changing the pointer to point to a different node (with + * otherwise illegal null fields), not by marking it. While it + * would be possible to further squeeze space by defining marker + * nodes not to have key/value fields, it isn't worth the extra + * type-testing overhead. The deletion markers are rarely + * encountered during traversal, are easily detected via null + * checks that are needed anyway, and are normally quickly garbage + * collected. (Note that this technique would not work well in + * systems without garbage collection.) + * + * In addition to using deletion markers, the lists also use + * nullness of value fields to indicate deletion, in a style + * similar to typical lazy-deletion schemes. If a node's value is + * null, then it is considered logically deleted and ignored even + * though it is still reachable. + * + * Here's the sequence of events for a deletion of node n with + * predecessor b and successor f, initially: + * + * +------+ +------+ +------+ + * ... | b |------>| n |----->| f | ... + * +------+ +------+ +------+ + * + * 1. CAS n's value field from non-null to null. + * Traversals encountering a node with null value ignore it. + * However, ongoing insertions and deletions might still modify + * n's next pointer. + * + * 2. CAS n's next pointer to point to a new marker node. + * From this point on, no other nodes can be appended to n. + * which avoids deletion errors in CAS-based linked lists. + * + * +------+ +------+ +------+ +------+ + * ... | b |------>| n |----->|marker|------>| f | ... + * +------+ +------+ +------+ +------+ + * + * 3. CAS b's next pointer over both n and its marker. + * From this point on, no new traversals will encounter n, + * and it can eventually be GCed. + * +------+ +------+ + * ... | b |----------------------------------->| f | ... + * +------+ +------+ + * + * A failure at step 1 leads to simple retry due to a lost race + * with another operation. Steps 2-3 can fail because some other + * thread noticed during a traversal a node with null value and + * helped out by marking and/or unlinking. This helping-out + * ensures that no thread can become stuck waiting for progress of + * the deleting thread. + * + * Skip lists add indexing to this scheme, so that the base-level + * traversals start close to the locations being found, inserted + * or deleted -- usually base level traversals only traverse a few + * nodes. This doesn't change the basic algorithm except for the + * need to make sure base traversals start at predecessors (here, + * b) that are not (structurally) deleted, otherwise retrying + * after processing the deletion. + * + * Index levels are maintained using CAS to link and unlink + * successors ("right" fields). Races are allowed in index-list + * operations that can (rarely) fail to link in a new index node. + * (We can't do this of course for data nodes.) However, even + * when this happens, the index lists correctly guide search. + * This can impact performance, but since skip lists are + * probabilistic anyway, the net result is that under contention, + * the effective "p" value may be lower than its nominal value. + * + * Index insertion and deletion sometimes require a separate + * traversal pass occurring after the base-level action, to add or + * remove index nodes. This adds to single-threaded overhead, but + * improves contended multithreaded performance by narrowing + * interference windows, and allows deletion to ensure that all + * index nodes will be made unreachable upon return from a public + * remove operation, thus avoiding unwanted garbage retention. + * + * Indexing uses skip list parameters that maintain good search + * performance while using sparser-than-usual indices: The + * hardwired parameters k=1, p=0.5 (see method doPut) mean that + * about one-quarter of the nodes have indices. Of those that do, + * half have one level, a quarter have two, and so on (see Pugh's + * Skip List Cookbook, sec 3.4), up to a maximum of 62 levels + * (appropriate for up to 2^63 elements). The expected total + * space requirement for a map is slightly less than for the + * current implementation of java.util.TreeMap. + * + * Changing the level of the index (i.e, the height of the + * tree-like structure) also uses CAS. Creation of an index with + * height greater than the current level adds a level to the head + * index by CAS'ing on a new top-most head. To maintain good + * performance after a lot of removals, deletion methods + * heuristically try to reduce the height if the topmost levels + * appear to be empty. This may encounter races in which it is + * possible (but rare) to reduce and "lose" a level just as it is + * about to contain an index (that will then never be + * encountered). This does no structural harm, and in practice + * appears to be a better option than allowing unrestrained growth + * of levels. + * + * This class provides concurrent-reader-style memory consistency, + * ensuring that read-only methods report status and/or values no + * staler than those holding at method entry. This is done by + * performing all publication and structural updates using + * (volatile) CAS, placing an acquireFence in a few access + * methods, and ensuring that linked objects are transitively + * acquired via dependent reads (normally once) unless performing + * a volatile-mode CAS operation (that also acts as an acquire and + * release). This form of fence-hoisting is similar to RCU and + * related techniques (see McKenney's online book + * https://www.kernel.org/pub/linux/kernel/people/paulmck/perfbook/perfbook.html) + * It minimizes overhead that may otherwise occur when using so + * many volatile-mode reads. Using explicit acquireFences is + * logistically easier than targeting particular fields to be read + * in acquire mode: fences are just hoisted up as far as possible, + * to the entry points or loop headers of a few methods. A + * potential disadvantage is that these few remaining fences are + * not easily optimized away by compilers under exclusively + * single-thread use. It requires some care to avoid volatile + * mode reads of other fields. (Note that the memory semantics of + * a reference dependently read in plain mode exactly once are + * equivalent to those for atomic opaque mode.) Iterators and + * other traversals encounter each node and value exactly once. + * Other operations locate an element (or position to insert an + * element) via a sequence of dereferences. This search is broken + * into two parts. Method findPredecessor (and its specialized + * embeddings) searches index nodes only, returning a base-level + * predecessor of the key. Callers carry out the base-level + * search, restarting if encountering a marker preventing link + * modification. In some cases, it is possible to encounter a + * node multiple times while descending levels. For mutative + * operations, the reported value is validated using CAS (else + * retrying), preserving linearizability with respect to each + * other. Others may return any (non-null) value holding in the + * course of the method call. (Search-based methods also include + * some useless-looking explicit null checks designed to allow + * more fields to be nulled out upon removal, to reduce floating + * garbage, but which is not currently done, pending discovery of + * a way to do this with less impact on other operations.) + * + * To produce random values without interference across threads, + * we use within-JDK thread local random support (via the + * "secondary seed", to avoid interference with user-level + * ThreadLocalRandom.) + * + * For explanation of algorithms sharing at least a couple of + * features with this one, see Mikhail Fomitchev's thesis + * (https://www.cs.yorku.ca/~mikhail/), Keir Fraser's thesis + * (https://www.cl.cam.ac.uk/users/kaf24/), and Hakan Sundell's + * thesis (https://www.cs.chalmers.se/~phs/). + * + * Notation guide for local variables + * Node: b, n, f, p for predecessor, node, successor, aux + * Index: q, r, d for index node, right, down. + * Head: h + * Keys: k, key + * Values: v, value + * Comparisons: c + */ + + /** No-key sentinel value */ + private final int noKey; + /** Lazily initialized topmost index of the skiplist. */ + private volatile /*XXX: Volatile only required for ARFU; remove if we can use VarHandle*/ Index head; + /** Element count */ + private final LongCounter adder; + + /** + * Nodes hold keys and values, and are singly linked in sorted + * order, possibly with some intervening marker nodes. The list is + * headed by a header node accessible as head.node. Headers and + * marker nodes have null keys. The val field (but currently not + * the key field) is nulled out upon deletion. + */ + static final class Node { + final int key; // currently, never detached + volatile /*XXX: Volatile only required for ARFU; remove if we can use VarHandle*/ V val; + volatile /*XXX: Volatile only required for ARFU; remove if we can use VarHandle*/ Node next; + Node(int key, V value, Node next) { + this.key = key; + val = value; + this.next = next; + } + } + + /** + * Index nodes represent the levels of the skip list. + */ + static final class Index { + final Node node; // currently, never detached + final Index down; + volatile /*XXX: Volatile only required for ARFU; remove if we can use VarHandle*/ Index right; + Index(Node node, Index down, Index right) { + this.node = node; + this.down = down; + this.right = right; + } + } + + /** + * The multimap entry type with primitive {@code int} keys. + */ + public static final class IntEntry implements Comparable> { + private final int key; + private final V value; + + public IntEntry(int key, V value) { + this.key = key; + this.value = value; + } + + /** + * Get the corresponding key. + */ + public int getKey() { + return key; + } + + /** + * Get the corresponding value. + */ + public V getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof IntEntry)) { + return false; + } + + IntEntry intEntry = (IntEntry) o; + return key == intEntry.key && (value == intEntry.value || (value != null && value.equals(intEntry.value))); + } + + @Override + public int hashCode() { + int result = key; + result = 31 * result + (value == null ? 0 : value.hashCode()); + return result; + } + + @Override + public String toString() { + return "IntEntry[" + key + " => " + value + ']'; + } + + @Override + public int compareTo(IntEntry o) { + return cpr(key, o.key); + } + } + + /* ---------------- Utilities -------------- */ + + /** + * Compares using comparator or natural ordering if null. + * Called only by methods that have performed required type checks. + */ + static int cpr(int x, int y) { + return (x < y) ? -1 : x == y ? 0 : 1; + } + + /** + * Returns the header for base node list, or null if uninitialized + */ + final Node baseHead() { + Index h; + acquireFence(); + return (h = head) == null ? null : h.node; + } + + /** + * Tries to unlink deleted node n from predecessor b (if both + * exist), by first splicing in a marker if not already present. + * Upon return, node n is sure to be unlinked from b, possibly + * via the actions of some other thread. + * + * @param b if nonnull, predecessor + * @param n if nonnull, node known to be deleted + */ + static void unlinkNode(Node b, Node n, int noKey) { + if (b != null && n != null) { + Node f, p; + for (;;) { + if ((f = n.next) != null && f.key == noKey) { + p = f.next; // already marked + break; + } else if (NEXT.compareAndSet(n, f, + new Node(noKey, null, f))) { + p = f; // add marker + break; + } + } + NEXT.compareAndSet(b, n, p); + } + } + + /** + * Adds to element count, initializing adder if necessary + * + * @param c count to add + */ + private void addCount(long c) { + adder.add(c); + } + + /** + * Returns element count, initializing adder if necessary. + */ + final long getAdderCount() { + long c; + return (c = adder.value()) <= 0L ? 0L : c; // ignore transient negatives + } + + /* ---------------- Traversal -------------- */ + + /** + * Returns an index node with key strictly less than given key. + * Also unlinks indexes to deleted nodes found along the way. + * Callers rely on this side-effect of clearing indices to deleted + * nodes. + * + * @param key if nonnull the key + * @return a predecessor node of key, or null if uninitialized or null key + */ + private Node findPredecessor(int key) { + Index q; + acquireFence(); + if ((q = head) == null || key == noKey) { + return null; + } else { + for (Index r, d;;) { + while ((r = q.right) != null) { + Node p; int k; + if ((p = r.node) == null || (k = p.key) == noKey || + p.val == null) { // unlink index to deleted node + RIGHT.compareAndSet(q, r, r.right); + } else if (cpr(key, k) > 0) { + q = r; + } else { + break; + } + } + if ((d = q.down) != null) { + q = d; + } else { + return q.node; + } + } + } + } + + /** + * Returns node holding key or null if no such, clearing out any + * deleted nodes seen along the way. Repeatedly traverses at + * base-level looking for key starting at predecessor returned + * from findPredecessor, processing base-level deletions as + * encountered. Restarts occur, at traversal step encountering + * node n, if n's key field is null, indicating it is a marker, so + * its predecessor is deleted before continuing, which we help do + * by re-finding a valid predecessor. The traversal loops in + * doPut, doRemove, and findNear all include the same checks. + * + * @param key the key + * @return node holding key, or null if no such + */ + private Node findNode(int key) { + if (key == noKey) { + throw new IllegalArgumentException(); // don't postpone errors + } + Node b; + outer: while ((b = findPredecessor(key)) != null) { + for (;;) { + Node n; int k; int c; + if ((n = b.next) == null) { + break outer; // empty + } else if ((k = n.key) == noKey) { + break; // b is deleted + } else if (n.val == null) { + unlinkNode(b, n, noKey); // n is deleted + } else if ((c = cpr(key, k)) > 0) { + b = n; + } else if (c == 0) { + return n; + } else { + break outer; + } + } + } + return null; + } + + /** + * Gets value for key. Same idea as findNode, except skips over + * deletions and markers, and returns first encountered value to + * avoid possibly inconsistent rereads. + * + * @param key the key + * @return the value, or null if absent + */ + private V doGet(int key) { + Index q; + acquireFence(); + if (key == noKey) { + throw new IllegalArgumentException(); + } + V result = null; + if ((q = head) != null) { + outer: for (Index r, d;;) { + while ((r = q.right) != null) { + Node p; int k; V v; int c; + if ((p = r.node) == null || (k = p.key) == noKey || + (v = p.val) == null) { + RIGHT.compareAndSet(q, r, r.right); + } else if ((c = cpr(key, k)) > 0) { + q = r; + } else if (c == 0) { + result = v; + break outer; + } else { + break; + } + } + if ((d = q.down) != null) { + q = d; + } else { + Node b, n; + if ((b = q.node) != null) { + while ((n = b.next) != null) { + V v; int c; + int k = n.key; + if ((v = n.val) == null || k == noKey || + (c = cpr(key, k)) > 0) { + b = n; + } else { + if (c == 0) { + result = v; + } + break; + } + } + } + break; + } + } + } + return result; + } + + /* ---------------- Insertion -------------- */ + + /** + * Main insertion method. Adds element if not present, or + * replaces value if present and onlyIfAbsent is false. + * + * @param key the key + * @param value the value that must be associated with key + * @param onlyIfAbsent if should not insert if already present + */ + private V doPut(int key, V value, boolean onlyIfAbsent) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + for (;;) { + Index h; Node b; + acquireFence(); + int levels = 0; // number of levels descended + if ((h = head) == null) { // try to initialize + Node base = new Node(noKey, null, null); + h = new Index(base, null, null); + b = HEAD.compareAndSet(this, null, h) ? base : null; + } else { + for (Index q = h, r, d;;) { // count while descending + while ((r = q.right) != null) { + Node p; int k; + if ((p = r.node) == null || (k = p.key) == noKey || + p.val == null) { + RIGHT.compareAndSet(q, r, r.right); + } else if (cpr(key, k) > 0) { + q = r; + } else { + break; + } + } + if ((d = q.down) != null) { + ++levels; + q = d; + } else { + b = q.node; + break; + } + } + } + if (b != null) { + Node z = null; // new node, if inserted + for (;;) { // find insertion point + Node n, p; int k; V v; int c; + if ((n = b.next) == null) { + if (b.key == noKey) { // if empty, type check key now TODO: remove? + cpr(key, key); + } + c = -1; + } else if ((k = n.key) == noKey) { + break; // can't append; restart + } else if ((v = n.val) == null) { + unlinkNode(b, n, noKey); + c = 1; + } else if ((c = cpr(key, k)) > 0) { + b = n; // Multimap +// } else if (c == 0 && +// (onlyIfAbsent || VAL.compareAndSet(n, v, value))) { +// return v; + } + + if (c <= 0 && + NEXT.compareAndSet(b, n, + p = new Node(key, value, n))) { + z = p; + break; + } + } + + if (z != null) { + int lr = ThreadLocalRandom.current().nextInt(); + if ((lr & 0x3) == 0) { // add indices with 1/4 prob + int hr = ThreadLocalRandom.current().nextInt(); + long rnd = ((long) hr << 32) | ((long) lr & 0xffffffffL); + int skips = levels; // levels to descend before add + Index x = null; + for (;;) { // create at most 62 indices + x = new Index(z, x, null); + if (rnd >= 0L || --skips < 0) { + break; + } else { + rnd <<= 1; + } + } + if (addIndices(h, skips, x, noKey) && skips < 0 && + head == h) { // try to add new level + Index hx = new Index(z, x, null); + Index nh = new Index(h.node, h, hx); + HEAD.compareAndSet(this, h, nh); + } + if (z.val == null) { // deleted while adding indices + findPredecessor(key); // clean + } + } + addCount(1L); + return null; + } + } + } + } + + /** + * Add indices after an insertion. Descends iteratively to the + * highest level of insertion, then recursively, to chain index + * nodes to lower ones. Returns null on (staleness) failure, + * disabling higher-level insertions. Recursion depths are + * exponentially less probable. + * + * @param q starting index for current level + * @param skips levels to skip before inserting + * @param x index for this insertion + */ + static boolean addIndices(Index q, int skips, Index x, int noKey) { + Node z; int key; + if (x != null && (z = x.node) != null && (key = z.key) != noKey && + q != null) { // hoist checks + boolean retrying = false; + for (;;) { // find splice point + Index r, d; int c; + if ((r = q.right) != null) { + Node p; int k; + if ((p = r.node) == null || (k = p.key) == noKey || + p.val == null) { + RIGHT.compareAndSet(q, r, r.right); + c = 0; + } else if ((c = cpr(key, k)) > 0) { + q = r; + } else if (c == 0) { + break; // stale + } + } else { + c = -1; + } + + if (c < 0) { + if ((d = q.down) != null && skips > 0) { + --skips; + q = d; + } else if (d != null && !retrying && + !addIndices(d, 0, x.down, noKey)) { + break; + } else { + x.right = r; + if (RIGHT.compareAndSet(q, r, x)) { + return true; + } else { + retrying = true; // re-find splice point + } + } + } + } + } + return false; + } + + /* ---------------- Deletion -------------- */ + + /** + * Main deletion method. Locates node, nulls value, appends a + * deletion marker, unlinks predecessor, removes associated index + * nodes, and possibly reduces head index level. + * + * @param key the key + * @param value if non-null, the value that must be + * associated with key + * @return the node, or null if not found + */ + final V doRemove(int key, Object value) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + V result = null; + Node b; + outer: while ((b = findPredecessor(key)) != null && + result == null) { + for (;;) { + Node n; int k; V v; int c; + if ((n = b.next) == null) { + break outer; + } else if ((k = n.key) == noKey) { + break; + } else if ((v = n.val) == null) { + unlinkNode(b, n, noKey); + } else if ((c = cpr(key, k)) > 0) { + b = n; + } else if (c < 0) { + break outer; + } else if (value != null && !value.equals(v)) { +// break outer; + b = n; // Multimap. + } else if (VAL.compareAndSet(n, v, null)) { + result = v; + unlinkNode(b, n, noKey); + break; // loop to clean up + } + } + } + if (result != null) { + tryReduceLevel(); + addCount(-1L); + } + return result; + } + + /** + * Possibly reduce head level if it has no nodes. This method can + * (rarely) make mistakes, in which case levels can disappear even + * though they are about to contain index nodes. This impacts + * performance, not correctness. To minimize mistakes as well as + * to reduce hysteresis, the level is reduced by one only if the + * topmost three levels look empty. Also, if the removed level + * looks non-empty after CAS, we try to change it back quick + * before anyone notices our mistake! (This trick works pretty + * well because this method will practically never make mistakes + * unless current thread stalls immediately before first CAS, in + * which case it is very unlikely to stall again immediately + * afterwards, so will recover.) + *

+ * We put up with all this rather than just let levels grow + * because otherwise, even a small map that has undergone a large + * number of insertions and removals will have a lot of levels, + * slowing down access more than would an occasional unwanted + * reduction. + */ + private void tryReduceLevel() { + Index h, d, e; + if ((h = head) != null && h.right == null && + (d = h.down) != null && d.right == null && + (e = d.down) != null && e.right == null && + HEAD.compareAndSet(this, h, d) && + h.right != null) { // recheck + HEAD.compareAndSet(this, d, h); // try to backout + } + } + + /* ---------------- Finding and removing first element -------------- */ + + /** + * Gets first valid node, unlinking deleted nodes if encountered. + * @return first node or null if empty + */ + final Node findFirst() { + Node b, n; + if ((b = baseHead()) != null) { + while ((n = b.next) != null) { + if (n.val == null) { + unlinkNode(b, n, noKey); + } else { + return n; + } + } + } + return null; + } + + /** + * Entry snapshot version of findFirst + */ + final IntEntry findFirstEntry() { + Node b, n; V v; + if ((b = baseHead()) != null) { + while ((n = b.next) != null) { + if ((v = n.val) == null) { + unlinkNode(b, n, noKey); + } else { + return new IntEntry(n.key, v); + } + } + } + return null; + } + + /** + * Removes first entry; returns its snapshot. + * @return null if empty, else snapshot of first entry + */ + private IntEntry doRemoveFirstEntry() { + Node b, n; V v; + if ((b = baseHead()) != null) { + while ((n = b.next) != null) { + if ((v = n.val) == null || VAL.compareAndSet(n, v, null)) { + int k = n.key; + unlinkNode(b, n, noKey); + if (v != null) { + tryReduceLevel(); + findPredecessor(k); // clean index + addCount(-1L); + return new IntEntry(k, v); + } + } + } + } + return null; + } + + /* ---------------- Finding and removing last element -------------- */ + + /** + * Specialized version of find to get last valid node. + * @return last node or null if empty + */ + final Node findLast() { + outer: for (;;) { + Index q; Node b; + acquireFence(); + if ((q = head) == null) { + break; + } + for (Index r, d;;) { + while ((r = q.right) != null) { + Node p; + if ((p = r.node) == null || p.val == null) { + RIGHT.compareAndSet(q, r, r.right); + } else { + q = r; + } + } + if ((d = q.down) != null) { + q = d; + } else { + b = q.node; + break; + } + } + if (b != null) { + for (;;) { + Node n; + if ((n = b.next) == null) { + if (b.key == noKey) { // empty + break outer; + } else { + return b; + } + } else if (n.key == noKey) { + break; + } else if (n.val == null) { + unlinkNode(b, n, noKey); + } else { + b = n; + } + } + } + } + return null; + } + + /** + * Entry version of findLast + * @return Entry for last node or null if empty + */ + final IntEntry findLastEntry() { + for (;;) { + Node n; V v; + if ((n = findLast()) == null) { + return null; + } + if ((v = n.val) != null) { + return new IntEntry(n.key, v); + } + } + } + + /** + * Removes last entry; returns its snapshot. + * Specialized variant of doRemove. + * @return null if empty, else snapshot of last entry + */ + private IntEntry doRemoveLastEntry() { + outer: for (;;) { + Index q; Node b; + acquireFence(); + if ((q = head) == null) { + break; + } + for (;;) { + Index d, r; Node p; + while ((r = q.right) != null) { + if ((p = r.node) == null || p.val == null) { + RIGHT.compareAndSet(q, r, r.right); + } else if (p.next != null) { + q = r; // continue only if a successor + } else { + break; + } + } + if ((d = q.down) != null) { + q = d; + } else { + b = q.node; + break; + } + } + if (b != null) { + for (;;) { + Node n; int k; V v; + if ((n = b.next) == null) { + if (b.key == noKey) { // empty + break outer; + } else { + break; // retry + } + } else if ((k = n.key) == noKey) { + break; + } else if ((v = n.val) == null) { + unlinkNode(b, n, noKey); + } else if (n.next != null) { + b = n; + } else if (VAL.compareAndSet(n, v, null)) { + unlinkNode(b, n, noKey); + tryReduceLevel(); + findPredecessor(k); // clean index + addCount(-1L); + return new IntEntry(k, v); + } + } + } + } + return null; + } + + /* ---------------- Relational operations -------------- */ + + // Control values OR'ed as arguments to findNear + + private static final int EQ = 1; + private static final int LT = 2; + private static final int GT = 0; // Actually checked as !LT + + /** + * Variant of findNear returning IntEntry + * @param key the key + * @param rel the relation -- OR'ed combination of EQ, LT, GT + * @return Entry fitting relation, or null if no such + */ + final IntEntry findNearEntry(int key, int rel) { + for (;;) { + Node n; V v; + if ((n = findNear(key, rel)) == null) { + return null; + } + if ((v = n.val) != null) { + return new IntEntry(n.key, v); + } + } + } + + /** + * Utility for ceiling, floor, lower, higher methods. + * @param key the key + * @param rel the relation -- OR'ed combination of EQ, LT, GT + * @return nearest node fitting relation, or null if no such + */ + final Node findNear(int key, int rel) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + Node result; + outer: for (Node b;;) { + if ((b = findPredecessor(key)) == null) { + result = null; + break; // empty + } + for (;;) { + Node n; int k; int c; + if ((n = b.next) == null) { + result = (rel & LT) != 0 && b.key != noKey ? b : null; + break outer; + } else if ((k = n.key) == noKey) { + break; + } else if (n.val == null) { + unlinkNode(b, n, noKey); + } else if (((c = cpr(key, k)) == 0 && (rel & EQ) != 0) || + (c < 0 && (rel & LT) == 0)) { + result = n; + break outer; + } else if (c <= 0 && (rel & LT) != 0) { + result = b.key != noKey ? b : null; + break outer; + } else { + b = n; + } + } + } + return result; + } + + /* ---------------- Constructors -------------- */ + + /** + * Constructs a new, empty map, sorted according to the + * {@linkplain Comparable natural ordering} of the keys. + * @param noKey The value to use as a sentinel for signaling the absence of a key. + */ + public ConcurrentSkipListIntObjMultimap(int noKey) { + this.noKey = noKey; + adder = PlatformDependent.newLongCounter(); + } + + /* ------ Map API methods ------ */ + + /** + * Returns {@code true} if this map contains a mapping for the specified + * key. + * + * @param key key whose presence in this map is to be tested + * @return {@code true} if this map contains a mapping for the specified key + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key is null + */ + public boolean containsKey(int key) { + return doGet(key) != null; + } + + /** + * Returns the value to which the specified key is mapped, + * or {@code null} if this map contains no mapping for the key. + * + *

More formally, if this map contains a mapping from a key + * {@code k} to a value {@code v} such that {@code key} compares + * equal to {@code k} according to the map's ordering, then this + * method returns {@code v}; otherwise it returns {@code null}. + * (There can be at most one such mapping.) + * + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key is null + */ + public V get(int key) { + return doGet(key); + } + + /** + * Returns the value to which the specified key is mapped, + * or the given defaultValue if this map contains no mapping for the key. + * + * @param key the key + * @param defaultValue the value to return if this map contains + * no mapping for the given key + * @return the mapping for the key, if present; else the defaultValue + * @throws NullPointerException if the specified key is null + * @since 1.8 + */ + public V getOrDefault(int key, V defaultValue) { + V v; + return (v = doGet(key)) == null ? defaultValue : v; + } + + /** + * Associates the specified value with the specified key in this map. + * If the map previously contained a mapping for the key, the old + * value is replaced. + * + * @param key key with which the specified value is to be associated + * @param value value to be associated with the specified key + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key or value is null + */ + public void put(int key, V value) { + checkNotNull(value, "value"); + doPut(key, value, false); + } + + /** + * Removes the mapping for the specified key from this map if present. + * + * @param key key for which mapping should be removed + * @return the previous value associated with the specified key, or + * {@code null} if there was no mapping for the key + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key is null + */ + public V remove(int key) { + return doRemove(key, null); + } + + /** + * Returns {@code true} if this map maps one or more keys to the + * specified value. This operation requires time linear in the + * map size. Additionally, it is possible for the map to change + * during execution of this method, in which case the returned + * result may be inaccurate. + * + * @param value value whose presence in this map is to be tested + * @return {@code true} if a mapping to {@code value} exists; + * {@code false} otherwise + * @throws NullPointerException if the specified value is null + */ + public boolean containsValue(Object value) { + checkNotNull(value, "value"); + Node b, n; V v; + if ((b = baseHead()) != null) { + while ((n = b.next) != null) { + if ((v = n.val) != null && value.equals(v)) { + return true; + } else { + b = n; + } + } + } + return false; + } + + /** + * Get the approximate size of the collection. + */ + public int size() { + long c; + return baseHead() == null ? 0 : + (c = getAdderCount()) >= Integer.MAX_VALUE ? + Integer.MAX_VALUE : (int) c; + } + + /** + * Check if the collection is empty. + */ + public boolean isEmpty() { + return findFirst() == null; + } + + /** + * Removes all of the mappings from this map. + */ + public void clear() { + Index h, r, d; Node b; + acquireFence(); + while ((h = head) != null) { + if ((r = h.right) != null) { // remove indices + RIGHT.compareAndSet(h, r, null); + } else if ((d = h.down) != null) { // remove levels + HEAD.compareAndSet(this, h, d); + } else { + long count = 0L; + if ((b = h.node) != null) { // remove nodes + Node n; V v; + while ((n = b.next) != null) { + if ((v = n.val) != null && + VAL.compareAndSet(n, v, null)) { + --count; + v = null; + } + if (v == null) { + unlinkNode(b, n, noKey); + } + } + } + if (count != 0L) { + addCount(count); + } else { + break; + } + } + } + } + + /* ------ ConcurrentMap API methods ------ */ + + /** + * Remove the specific entry with the given key and value, if it exist. + * + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if the specified key is null + */ + public boolean remove(int key, Object value) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + return value != null && doRemove(key, value) != null; + } + + /** + * Replace the specific entry with the given key and value, with the given replacement value, + * if such an entry exist. + * + * @throws ClassCastException if the specified key cannot be compared + * with the keys currently in the map + * @throws NullPointerException if any of the arguments are null + */ + public boolean replace(int key, V oldValue, V newValue) { + if (key == noKey) { + throw new IllegalArgumentException(); + } + checkNotNull(oldValue, "oldValue"); + checkNotNull(newValue, "newValue"); + for (;;) { + Node n; V v; + if ((n = findNode(key)) == null) { + return false; + } + if ((v = n.val) != null) { + if (!oldValue.equals(v)) { + return false; + } + if (VAL.compareAndSet(n, v, newValue)) { + return true; + } + } + } + } + + /* ------ SortedMap API methods ------ */ + + public int firstKey() { + Node n = findFirst(); + if (n == null) { + return noKey; + } + return n.key; + } + + public int lastKey() { + Node n = findLast(); + if (n == null) { + return noKey; + } + return n.key; + } + + /* ---------------- Relational operations -------------- */ + + /** + * Returns a key-value mapping associated with the greatest key + * strictly less than the given key, or {@code null} if there is + * no such key. The returned entry does not support the + * {@code Entry.setValue} method. + * + * @throws NullPointerException if the specified key is null + */ + public IntEntry lowerEntry(int key) { + return findNearEntry(key, LT); + } + + /** + * @throws NullPointerException if the specified key is null + */ + public int lowerKey(int key) { + Node n = findNear(key, LT); + return n == null ? noKey : n.key; + } + + /** + * Returns a key-value mapping associated with the greatest key + * less than or equal to the given key, or {@code null} if there + * is no such key. The returned entry does not support + * the {@code Entry.setValue} method. + * + * @param key the key + * @throws NullPointerException if the specified key is null + */ + public IntEntry floorEntry(int key) { + return findNearEntry(key, LT | EQ); + } + + /** + * @param key the key + * @throws NullPointerException if the specified key is null + */ + public int floorKey(int key) { + Node n = findNear(key, LT | EQ); + return n == null ? noKey : n.key; + } + + /** + * Returns a key-value mapping associated with the least key + * greater than or equal to the given key, or {@code null} if + * there is no such entry. The returned entry does not + * support the {@code Entry.setValue} method. + * + * @throws NullPointerException if the specified key is null + */ + public IntEntry ceilingEntry(int key) { + return findNearEntry(key, GT | EQ); + } + + /** + * @throws NullPointerException if the specified key is null + */ + public int ceilingKey(int key) { + Node n = findNear(key, GT | EQ); + return n == null ? noKey : n.key; + } + + /** + * Returns a key-value mapping associated with the least key + * strictly greater than the given key, or {@code null} if there + * is no such key. The returned entry does not support + * the {@code Entry.setValue} method. + * + * @param key the key + * @throws NullPointerException if the specified key is null + */ + public IntEntry higherEntry(int key) { + return findNearEntry(key, GT); + } + + /** + * @param key the key + * @throws NullPointerException if the specified key is null + */ + public int higherKey(int key) { + Node n = findNear(key, GT); + return n == null ? noKey : n.key; + } + + /** + * Returns a key-value mapping associated with the least + * key in this map, or {@code null} if the map is empty. + * The returned entry does not support + * the {@code Entry.setValue} method. + */ + public IntEntry firstEntry() { + return findFirstEntry(); + } + + /** + * Returns a key-value mapping associated with the greatest + * key in this map, or {@code null} if the map is empty. + * The returned entry does not support + * the {@code Entry.setValue} method. + */ + public IntEntry lastEntry() { + return findLastEntry(); + } + + /** + * Removes and returns a key-value mapping associated with + * the least key in this map, or {@code null} if the map is empty. + * The returned entry does not support + * the {@code Entry.setValue} method. + */ + public IntEntry pollFirstEntry() { + return doRemoveFirstEntry(); + } + + /** + * Removes and returns a key-value mapping associated with + * the greatest key in this map, or {@code null} if the map is empty. + * The returned entry does not support + * the {@code Entry.setValue} method. + */ + public IntEntry pollLastEntry() { + return doRemoveLastEntry(); + } + + public IntEntry pollCeilingEntry(int key) { + // TODO optimize this + Node node; + V val; + do { + node = findNear(key, GT | EQ); + if (node == null) { + return null; + } + val = node.val; + } while (val == null || !remove(node.key, val)); + return new IntEntry(node.key, val); + } + + /* ---------------- Iterators -------------- */ + + /** + * Base of iterator classes + */ + abstract class Iter implements Iterator { + /** the last node returned by next() */ + Node lastReturned; + /** the next node to return from next(); */ + Node next; + /** Cache of next value field to maintain weak consistency */ + V nextValue; + + /** Initializes ascending iterator for entire range. */ + Iter() { + advance(baseHead()); + } + + @Override + public final boolean hasNext() { + return next != null; + } + + /** Advances next to higher entry. */ + final void advance(Node b) { + Node n = null; + V v = null; + if ((lastReturned = b) != null) { + while ((n = b.next) != null && (v = n.val) == null) { + b = n; + } + } + nextValue = v; + next = n; + } + + @Override + public final void remove() { + Node n; int k; + if ((n = lastReturned) == null || (k = n.key) == noKey) { + throw new IllegalStateException(); + } + // It would not be worth all of the overhead to directly + // unlink from here. Using remove is fast enough. + ConcurrentSkipListIntObjMultimap.this.remove(k, n.val); // TODO: inline and optimize this + lastReturned = null; + } + } + + final class EntryIterator extends Iter> { + @Override + public IntEntry next() { + Node n; + if ((n = next) == null) { + throw new NoSuchElementException(); + } + int k = n.key; + V v = nextValue; + advance(n); + return new IntEntry(k, v); + } + } + + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + // VarHandle mechanics + private static final AtomicReferenceFieldUpdater, Index> HEAD; + private static final AtomicReferenceFieldUpdater, Node> NEXT; + private static final AtomicReferenceFieldUpdater, Object> VAL; + private static final AtomicReferenceFieldUpdater, Index> RIGHT; + private static volatile int acquireFenceVariable; + static { + Class> mapCls = cls(ConcurrentSkipListIntObjMultimap.class); + Class> indexCls = cls(Index.class); + Class> nodeCls = cls(Node.class); + + HEAD = AtomicReferenceFieldUpdater.newUpdater(mapCls, indexCls, "head"); + NEXT = AtomicReferenceFieldUpdater.newUpdater(nodeCls, nodeCls, "next"); + VAL = AtomicReferenceFieldUpdater.newUpdater(nodeCls, Object.class, "val"); + RIGHT = AtomicReferenceFieldUpdater.newUpdater(indexCls, indexCls, "right"); + } + + @SuppressWarnings("unchecked") + private static Class cls(Class cls) { + return (Class) cls; + } + + /** + * Orders LOADS before the fence, with LOADS and STORES after the fence. + */ + private static void acquireFence() { + // Volatile store prevent prior loads from ordering down. + acquireFenceVariable = 1; + // Volatile load prevent following loads and stores from ordering up. + int ignore = acquireFenceVariable; + // Note: Putting the volatile store before the volatile load ensures + // surrounding loads and stores don't order "into" the fence. + } +} diff --git a/common/src/main/java/io/netty/util/concurrent/MpscAtomicIntegerArrayQueue.java b/common/src/main/java/io/netty/util/concurrent/MpscAtomicIntegerArrayQueue.java index 4cf804888af..7384d37535a 100644 --- a/common/src/main/java/io/netty/util/concurrent/MpscAtomicIntegerArrayQueue.java +++ b/common/src/main/java/io/netty/util/concurrent/MpscAtomicIntegerArrayQueue.java @@ -56,7 +56,7 @@ public MpscAtomicIntegerArrayQueue(int capacity, int emptyValue) { super(MathUtil.safeFindNextPositivePowerOfTwo(capacity)); if (emptyValue != 0) { this.emptyValue = emptyValue; - int end = capacity - 1; + int end = length() - 1; for (int i = 0; i < end; i++) { lazySet(i, emptyValue); } @@ -199,6 +199,37 @@ public int fill(int limit, IntSupplier supplier) { return actualLimit; } + /** + * Peek at all available elements and compute a reduction. + * The elements are not removed, and the iteration is weakly consistent. + * @param limit The maximum number of elements to process. + * @param initial The initial value to the reduction operation. + * @param op The reduction operation, taking a prior result and an element, and producing a new result. + * @return The last result of the reduction operation. + */ + public int weakPeekReduce(int limit, int initial, IntBinaryOperator op) { + ObjectUtil.checkNotNull(op, "op"); + ObjectUtil.checkPositiveOrZero(limit, "limit"); + if (limit == 0) { + return 0; + } + int result = initial; + + final int mask = this.mask; + final long cIndex = consumerIndex; // Note: could be weakened to plain-load. + for (int i = 0; i < limit; i++) { + final long index = cIndex + i; + final int offset = (int) (index & mask); + final int value = get(offset); + if (emptyValue == value) { + return result; + } + // Do not remove the element or advance the consumer index. + result = op.applyAsInt(result, value); + } + return result; + } + @Override public boolean isEmpty() { // Load consumer index before producer index, so our check is conservative. @@ -223,4 +254,8 @@ public int size() { } return size < 0 ? 0 : size > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) size; } + + public interface IntBinaryOperator { + int applyAsInt(int a, int b); + } } diff --git a/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java b/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java index afdb4d5e7c6..b316ab621ed 100644 --- a/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java +++ b/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java @@ -258,6 +258,8 @@ public void run() { executor.execute(this); return; // done } catch (Throwable ignore) { + // Restore executingThread since we're continuing to execute tasks. + executingThread.set(current); // Reset the state back to running as we will keep on executing tasks. state.set(RUNNING); // if an error happened we should just ignore it and let the loop run again as there is not diff --git a/common/src/main/java/io/netty/util/internal/ObjectCleaner.java b/common/src/main/java/io/netty/util/internal/ObjectCleaner.java index 0eb7b3495f3..5a23bcd0539 100644 --- a/common/src/main/java/io/netty/util/internal/ObjectCleaner.java +++ b/common/src/main/java/io/netty/util/internal/ObjectCleaner.java @@ -30,7 +30,10 @@ /** * Allows a way to register some {@link Runnable} that will executed once there are no references to an {@link Object} * anymore. + * + * @deprecated The object cleaner is deprecated for removal. */ +@Deprecated public final class ObjectCleaner { private static final int REFERENCE_QUEUE_POLL_TIMEOUT_MS = max(500, getInt("io.netty.util.internal.ObjectCleaner.refQueuePollTimeout", 10000)); diff --git a/common/src/main/java/io/netty/util/internal/PlatformDependent.java b/common/src/main/java/io/netty/util/internal/PlatformDependent.java index 13421fdb240..32405d3f6ea 100644 --- a/common/src/main/java/io/netty/util/internal/PlatformDependent.java +++ b/common/src/main/java/io/netty/util/internal/PlatformDependent.java @@ -997,7 +997,7 @@ public static int equalsConstantTime(byte[] bytes1, int startPos1, byte[] bytes2 * The resulting hash code will be case insensitive. */ public static int hashCodeAscii(byte[] bytes, int startPos, int length) { - return !hasUnsafe() || !unalignedAccess() ? + return !hasUnsafe() || !unalignedAccess() || BIG_ENDIAN_NATIVE_ORDER ? hashCodeAsciiSafe(bytes, startPos, length) : PlatformDependent0.hashCodeAscii(bytes, startPos, length); } diff --git a/common/src/main/java/io/netty/util/internal/PlatformDependent0.java b/common/src/main/java/io/netty/util/internal/PlatformDependent0.java index 62a1ee0f539..950b93bc959 100644 --- a/common/src/main/java/io/netty/util/internal/PlatformDependent0.java +++ b/common/src/main/java/io/netty/util/internal/PlatformDependent0.java @@ -393,7 +393,7 @@ public Object run() { Class bitsClass = Class.forName("java.nio.Bits", false, getSystemClassLoader()); int version = javaVersion(); - if (unsafeStaticFieldOffsetSupported() && version >= 9) { + if (version >= 9) { // Java9/10 use all lowercase and later versions all uppercase. String fieldName = version >= 11? "MAX_MEMORY" : "maxMemory"; // On Java9 and later we try to directly access the field as we can do this without @@ -607,10 +607,6 @@ static boolean isVirtualThread(Thread thread) { } } - private static boolean unsafeStaticFieldOffsetSupported() { - return !RUNNING_IN_NATIVE_IMAGE; - } - static boolean isExplicitNoUnsafe() { return EXPLICIT_NO_UNSAFE_CAUSE != null; } diff --git a/common/src/main/java/io/netty/util/internal/ThrowableUtil.java b/common/src/main/java/io/netty/util/internal/ThrowableUtil.java index c33a19e5591..5af0c7ba883 100644 --- a/common/src/main/java/io/netty/util/internal/ThrowableUtil.java +++ b/common/src/main/java/io/netty/util/internal/ThrowableUtil.java @@ -84,4 +84,19 @@ public static Throwable[] getSuppressed(Throwable source) { } return source.getSuppressed(); } + + /** + * Capture the stack trace of the given thread, interrupt it, and attach the stack trace as a suppressed exception + * to the given cause. + * @param thread The thread to interrupt. + * @param cause The cause to attach a stack trace to. + */ + public static void interruptAndAttachAsyncStackTrace(Thread thread, Throwable cause) { + StackTraceElement[] stackTrace = thread.getStackTrace(); + InterruptedException asyncIE = new InterruptedException( + "Asynchronous interruption: " + thread); + thread.interrupt(); + asyncIE.setStackTrace(stackTrace); + addSuppressed(cause, asyncIE); + } } diff --git a/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java b/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java index 5445b83d581..63ff46d80a5 100644 --- a/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java +++ b/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java @@ -16,6 +16,7 @@ package io.netty.util; import io.netty.util.concurrent.FastThreadLocalThread; +import org.junit.jupiter.api.extension.DynamicTestInvocationContext; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.InvocationInterceptor; import org.junit.jupiter.api.extension.ReflectiveInvocationContext; @@ -37,6 +38,26 @@ public void interceptTestMethod( final Invocation invocation, final ReflectiveInvocationContext invocationContext, final ExtensionContext extensionContext) throws Throwable { + proceed(invocation); + } + + @Override + public void interceptTestTemplateMethod( + Invocation invocation, + ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { + proceed(invocation); + } + + @Override + public void interceptDynamicTest( + Invocation invocation, + DynamicTestInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { + proceed(invocation); + } + + private static void proceed(final Invocation invocation) throws Throwable { final AtomicReference throwable = new AtomicReference(); Thread thread = new FastThreadLocalThread(new Runnable() { @Override diff --git a/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtensionTest.java b/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtensionTest.java new file mode 100644 index 00000000000..a5f4b4fa9cb --- /dev/null +++ b/common/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtensionTest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.concurrent.FastThreadLocalThread; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@ExtendWith(RunInFastThreadLocalThreadExtension.class) +public class RunInFastThreadLocalThreadExtensionTest { + @Test + void normalTest() { + assertInstanceOf(FastThreadLocalThread.class, Thread.currentThread()); + } + + @RepeatedTest(1) + void repeatedTest() { + assertInstanceOf(FastThreadLocalThread.class, Thread.currentThread()); + } + + @ParameterizedTest + @ValueSource(ints = 1) + void parameterizedTest(int ignoreParameter) { + assertInstanceOf(FastThreadLocalThread.class, Thread.currentThread()); + } +} diff --git a/common/src/test/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimapTest.java b/common/src/test/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimapTest.java new file mode 100644 index 00000000000..e3ffb84f785 --- /dev/null +++ b/common/src/test/java/io/netty/util/concurrent/ConcurrentSkipListIntObjMultimapTest.java @@ -0,0 +1,442 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.concurrent.ConcurrentSkipListIntObjMultimap.IntEntry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.concurrent.ThreadLocalRandom; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ConcurrentSkipListIntObjMultimapTest { + private ConcurrentSkipListIntObjMultimap map; + private int noKey; + + @BeforeEach + void setUp() { + noKey = -1; + map = new ConcurrentSkipListIntObjMultimap(noKey); + } + + @Test + void addIterateAndRemoveEntries() throws Exception { + assertFalse(map.iterator().hasNext()); + map.put(1, "a"); + map.put(2, "b"); + assertFalse(map.isEmpty()); + assertEquals(2, map.size()); + IntEntry entry; + Iterator> itr = map.iterator(); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertEquals(new IntEntry(1, "a"), entry); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertEquals(new IntEntry(2, "b"), entry); + assertFalse(itr.hasNext()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + } + + @Test + void clearMustRemoveAllEntries() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(3, "c"); + assertEquals(3, map.size()); + map.clear(); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + assertTrue(map.isEmpty()); + } + + @Test + void pollingFirstEntryOfUniqueKeys() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(3, "c"); + assertEquals(new IntEntry(1, "a"), map.pollFirstEntry()); + assertEquals(new IntEntry(2, "b"), map.pollFirstEntry()); + assertEquals(new IntEntry(3, "c"), map.pollFirstEntry()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + } + + @Test + void pollingLastEntryOfUniqueKeys() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(3, "c"); + assertEquals(new IntEntry(3, "c"), map.pollLastEntry()); + assertEquals(new IntEntry(2, "b"), map.pollLastEntry()); + assertEquals(new IntEntry(1, "a"), map.pollLastEntry()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + } + + @Test + void addMultipleEntriesForSameKey() throws Exception { + map.put(2, "b1"); + map.put(1, "a"); + map.put(2, "b2"); // second entry for the 2 key + map.put(3, "c"); + assertEquals(4, map.size()); + + IntEntry entry; + Iterator> itr = map.iterator(); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertEquals(new IntEntry(1, "a"), entry); + assertTrue(itr.hasNext()); + entry = itr.next(); + IntEntry otherB = entry; + itr.remove(); + assertThat(entry).isIn(new IntEntry(2, "b1"), new IntEntry(2, "b2")); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertThat(entry).isIn(new IntEntry(2, "b1"), new IntEntry(2, "b2")); + assertNotEquals(otherB, entry); + assertTrue(itr.hasNext()); + entry = itr.next(); + itr.remove(); + assertEquals(new IntEntry(3, "c"), entry); + assertFalse(itr.hasNext()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void iteratorRemoveSecondOfMultiMappedEntry(boolean withPriorRemoval) throws Exception { + map.put(1, "a"); + map.put(1, "b"); + + Iterator> itr = map.iterator(); + itr.next(); + IntEntry entry = itr.next(); + if (withPriorRemoval) { + map.remove(entry.getKey(), entry.getValue()); + } + itr.remove(); + assertEquals(1, map.size()); + if (entry.equals(new IntEntry(1, "a"))) { + assertEquals(new IntEntry(1, "b"), map.pollFirstEntry()); + } else { + assertEquals(new IntEntry(1, "a"), map.pollFirstEntry()); + } + } + + @Test + void firstKeyOrEntry() throws Exception { + assertEquals(noKey, map.firstKey()); + assertNull(map.firstEntry()); + map.put(2, "b"); + assertEquals(2, map.firstKey()); + assertEquals(new IntEntry(2, "b"), map.firstEntry()); + map.put(3, "c"); + assertEquals(2, map.firstKey()); + assertEquals(new IntEntry(2, "b"), map.firstEntry()); + map.put(2, "b2"); + assertEquals(2, map.firstKey()); + assertThat(map.firstEntry()).isIn(new IntEntry(2, "b"), new IntEntry(2, "b2")); + map.put(1, "a"); + assertEquals(1, map.firstKey()); + assertEquals(new IntEntry(1, "a"), map.firstEntry()); + map.put(2, "b3"); + assertEquals(1, map.firstKey()); + assertEquals(new IntEntry(1, "a"), map.firstEntry()); + map.pollFirstEntry(); + assertEquals(2, map.firstKey()); + assertThat(map.firstEntry()).isIn( + new IntEntry(2, "b"), new IntEntry(2, "b2"), new IntEntry(2, "b3")); + } + + @Test + void lastKeyOrEntry() throws Exception { + assertEquals(noKey, map.lastKey()); + assertNull(map.lastEntry()); + map.put(2, "b"); + assertEquals(2, map.lastKey()); + assertEquals(new IntEntry(2, "b"), map.lastEntry()); + map.put(1, "a"); + assertEquals(2, map.lastKey()); + assertEquals(new IntEntry(2, "b"), map.lastEntry()); + map.put(2, "b2"); + assertEquals(2, map.lastKey()); + assertThat(map.lastEntry()).isIn(new IntEntry(2, "b"), new IntEntry(2, "b2")); + map.put(3, "c"); + assertEquals(3, map.lastKey()); + assertEquals(new IntEntry(3, "c"), map.lastEntry()); + map.put(2, "b3"); + assertEquals(3, map.lastKey()); + assertEquals(new IntEntry(3, "c"), map.lastEntry()); + map.pollLastEntry(); + assertEquals(2, map.lastKey()); + assertThat(map.lastEntry()).isIn( + new IntEntry(2, "b"), new IntEntry(2, "b2"), new IntEntry(2, "b3")); + } + + @RepeatedTest(100) + void firstLastKeyOrEntry() throws Exception { + int[] xs = new int[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + map.put(key, "a"); + xs[i] = key; + } + Arrays.sort(xs); + assertEquals(xs[0], map.firstKey()); + assertEquals(new IntEntry(xs[0], "a"), map.firstEntry()); + assertEquals(xs[xs.length - 1], map.lastKey()); + assertEquals(new IntEntry(xs[xs.length - 1], "a"), map.lastEntry()); + } + + @SuppressWarnings("unchecked") + @RepeatedTest(100) + void lowerEntryOrKey() { + IntEntry[] xs = new IntEntry[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + xs[i] = new IntEntry(key, String.valueOf(key)); + map.put(key, xs[i].getValue()); + } + Arrays.sort(xs); + for (int i = 0; i < 10; i++) { + IntEntry target = xs[ThreadLocalRandom.current().nextInt(xs.length)]; + IntEntry expected = null; + for (IntEntry x : xs) { + if (x.compareTo(target) < 0) { + expected = x; + } else { + break; + } + } + assertEquals(expected, map.lowerEntry(target.getKey())); + assertEquals(expected == null ? noKey : expected.getKey(), map.lowerKey(target.getKey())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void lowerEntryOrKeyMismatch(boolean multiMapped) throws Exception { + map.put(1, "a"); + map.put(3, "b"); + map.put(4, "c"); + if (multiMapped) { + map.put(1, "a"); + map.put(3, "b"); + map.put(4, "c"); + } + assertEquals(1, map.lowerKey(3)); + assertEquals(new IntEntry(1, "a"), map.lowerEntry(3)); + assertEquals(3, map.lowerKey(4)); + assertEquals(new IntEntry(3, "b"), map.lowerEntry(4)); + assertEquals(noKey, map.lowerKey(1)); + assertNull(map.lowerEntry(1)); + } + + @SuppressWarnings("unchecked") + @RepeatedTest(100) + void floorEntryOrKey() { + IntEntry[] xs = new IntEntry[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + xs[i] = new IntEntry(key, String.valueOf(key)); + map.put(key, xs[i].getValue()); + } + Arrays.sort(xs); + for (int i = 0; i < 10; i++) { + IntEntry target = xs[ThreadLocalRandom.current().nextInt(xs.length)]; + IntEntry expected = null; + for (IntEntry x : xs) { + if (x.compareTo(target) <= 0) { + expected = x; + } else { + break; + } + } + assertEquals(expected, map.floorEntry(target.getKey())); + assertEquals(expected == null ? noKey : expected.getKey(), map.floorKey(target.getKey())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void floorEntryOrKeyMismatch(boolean multiMapped) throws Exception { + map.put(1, "a"); + map.put(3, "b"); + map.put(4, "c"); + if (multiMapped) { + map.put(1, "a"); + map.put(3, "b"); + map.put(4, "c"); + } + assertEquals(1, map.floorKey(2)); + assertEquals(new IntEntry(1, "a"), map.floorEntry(2)); + assertEquals(3, map.floorKey(3)); + assertEquals(new IntEntry(3, "b"), map.floorEntry(3)); + } + + @SuppressWarnings("unchecked") + @RepeatedTest(100) + void ceilEntryOrKey() { + IntEntry[] xs = new IntEntry[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + xs[i] = new IntEntry(key, String.valueOf(key)); + map.put(key, xs[i].getValue()); + } + Arrays.sort(xs); + for (int i = 0; i < 10; i++) { + IntEntry target = xs[ThreadLocalRandom.current().nextInt(xs.length)]; + IntEntry expected = null; + for (IntEntry x : xs) { + if (x.compareTo(target) >= 0) { + expected = x; + break; + } + } + assertEquals(expected, map.ceilingEntry(target.getKey())); + assertEquals(expected == null ? noKey : expected.getKey(), map.ceilingKey(target.getKey())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void ceilEntryOrKeyMismatch(boolean multiMapped) throws Exception { + map.put(1, "a"); + map.put(2, "b"); + map.put(4, "c"); + if (multiMapped) { + map.put(1, "a"); + map.put(2, "b"); + map.put(4, "c"); + } + assertEquals(2, map.ceilingKey(2)); + assertEquals(new IntEntry(2, "b"), map.ceilingEntry(2)); + assertEquals(4, map.ceilingKey(3)); + assertEquals(new IntEntry(4, "c"), map.ceilingEntry(3)); + } + + @SuppressWarnings("unchecked") + @RepeatedTest(100) + void higherEntryOrKey() { + IntEntry[] xs = new IntEntry[50]; + for (int i = 0; i < xs.length; i++) { + int key = ThreadLocalRandom.current().nextInt(50); + xs[i] = new IntEntry(key, String.valueOf(key)); + map.put(key, xs[i].getValue()); + } + Arrays.sort(xs); + for (int i = 0; i < 10; i++) { + IntEntry target = xs[ThreadLocalRandom.current().nextInt(xs.length)]; + IntEntry expected = null; + for (IntEntry x : xs) { + if (x.compareTo(target) > 0) { + expected = x; + break; + } + } + assertEquals(expected, map.higherEntry(target.getKey())); + assertEquals(expected == null ? noKey : expected.getKey(), map.higherKey(target.getKey())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void higherEntryOrKeyMismatch(boolean multiMapped) throws Exception { + map.put(1, "a"); + map.put(2, "b"); + map.put(4, "c"); + if (multiMapped) { + map.put(1, "a"); + map.put(2, "b"); + map.put(4, "c"); + } + assertEquals(4, map.higherKey(2)); + assertEquals(new IntEntry(4, "c"), map.higherEntry(2)); + assertEquals(4, map.higherKey(3)); + assertEquals(new IntEntry(4, "c"), map.higherEntry(3)); + assertEquals(noKey, map.higherKey(4)); + assertNull(map.higherEntry(4)); + } + + @Test + void pollingFirstEntryOfMultiMappedKeys() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(2, "b"); + map.put(3, "c"); + assertEquals(new IntEntry(1, "a"), map.pollFirstEntry()); + assertEquals(new IntEntry(2, "b"), map.pollFirstEntry()); + assertEquals(new IntEntry(2, "b"), map.pollFirstEntry()); + assertEquals(new IntEntry(3, "c"), map.pollFirstEntry()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + } + + @Test + void pollingLastEntryOfMultiMappedKeys() throws Exception { + map.put(2, "b"); + map.put(1, "a"); + map.put(2, "b"); + map.put(3, "c"); + assertEquals(new IntEntry(3, "c"), map.pollLastEntry()); + assertEquals(new IntEntry(2, "b"), map.pollLastEntry()); + assertEquals(new IntEntry(2, "b"), map.pollLastEntry()); + assertEquals(new IntEntry(1, "a"), map.pollLastEntry()); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + assertFalse(map.iterator().hasNext()); + } + + @Test + void pollCeilingEntry() throws Exception { + map.put(1, "a"); + map.put(2, "b"); + map.put(2, "b"); + map.put(3, "c"); + map.put(4, "d"); + map.put(4, "d"); + assertEquals(new IntEntry(2, "b"), map.pollCeilingEntry(2)); + assertEquals(new IntEntry(2, "b"), map.pollCeilingEntry(2)); + assertEquals(new IntEntry(3, "c"), map.pollCeilingEntry(2)); + assertEquals(new IntEntry(4, "d"), map.pollCeilingEntry(2)); + assertEquals(new IntEntry(4, "d"), map.pollCeilingEntry(2)); + assertNull(map.pollCeilingEntry(2)); + assertFalse(map.isEmpty()); + assertEquals(1, map.size()); + } +} diff --git a/common/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java b/common/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java index 13327117e0d..35f9af75806 100644 --- a/common/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java +++ b/common/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java @@ -121,65 +121,6 @@ public void run() { } } - @Test - public void testMultipleSetRemove() throws Exception { - final FastThreadLocal threadLocal = new FastThreadLocal(); - final Runnable runnable = new Runnable() { - @Override - public void run() { - threadLocal.set("1"); - threadLocal.remove(); - threadLocal.set("2"); - threadLocal.remove(); - } - }; - - final int sizeWhenStart = ObjectCleaner.getLiveSetCount(); - Thread thread = new Thread(runnable); - thread.start(); - thread.join(); - - assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); - - Thread thread2 = new Thread(runnable); - thread2.start(); - thread2.join(); - - assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); - } - - @Test - public void testMultipleSetRemove_multipleThreadLocal() throws Exception { - final FastThreadLocal threadLocal = new FastThreadLocal(); - final FastThreadLocal threadLocal2 = new FastThreadLocal(); - final Runnable runnable = new Runnable() { - @Override - public void run() { - threadLocal.set("1"); - threadLocal.remove(); - threadLocal.set("2"); - threadLocal.remove(); - threadLocal2.set("1"); - threadLocal2.remove(); - threadLocal2.set("2"); - threadLocal2.remove(); - } - }; - - final int sizeWhenStart = ObjectCleaner.getLiveSetCount(); - Thread thread = new Thread(runnable); - thread.start(); - thread.join(); - - assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); - - Thread thread2 = new Thread(runnable); - thread2.start(); - thread2.join(); - - assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); - } - @Test @Timeout(value = 4000, unit = TimeUnit.MILLISECONDS) public void testOnRemoveCalledForFastThreadLocalGet() throws Exception { diff --git a/common/src/test/java/io/netty/util/concurrent/MpscIntQueueTest.java b/common/src/test/java/io/netty/util/concurrent/MpscIntQueueTest.java new file mode 100644 index 00000000000..f11003be929 --- /dev/null +++ b/common/src/test/java/io/netty/util/concurrent/MpscIntQueueTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2025 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.IntSupplier; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class MpscIntQueueTest { + @ParameterizedTest + @ValueSource(ints = {1, 7, 8, 15, 16, 17}) + void mustFillWithSpecifiedEmptyEntry(int size) throws Exception { + MpscIntQueue queue = new MpscAtomicIntegerArrayQueue(size, -1); + int filled = queue.fill(size, new IntSupplier() { + @Override + public int get() throws Exception { + return 42; + } + }); + assertEquals(size, filled); + for (int i = 0; i < size; i++) { + assertEquals(42, queue.poll()); + } + assertEquals(-1, queue.poll()); + assertTrue(queue.isEmpty()); + } +} diff --git a/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java b/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java index aedd77e9d25..7e3e4784920 100644 --- a/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java +++ b/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java @@ -24,13 +24,17 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -129,6 +133,153 @@ public void run() { } } + @Test + public void testInEventLoopAfterReschedulingFailure() throws Exception { + final UnorderedThreadPoolEventExecutor underlying = new UnorderedThreadPoolEventExecutor(1); + final AtomicInteger executeCount = new AtomicInteger(); + + final EventExecutorGroup wrapper = new AbstractEventExecutorGroup() { + @Override + public void shutdown() { + shutdownGracefully(); + } + + private final EventExecutor executor = new AbstractEventExecutor(this) { + @Override + public boolean inEventLoop(Thread thread) { + return underlying.inEventLoop(thread); + } + + @Override + public void shutdown() { + shutdownGracefully(); + } + + @Override + public void execute(Runnable command) { + // Reject the 2nd execute() call (the reschedule attempt) + // 1st call: initial task submission + // 2nd call: reschedule after maxTaskExecutePerRun + if (executeCount.incrementAndGet() == 2) { + throw new RejectedExecutionException("Simulated queue full"); + } + underlying.execute(command); + } + + @Override + public boolean isShuttingDown() { + return underlying.isShuttingDown(); + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return underlying.shutdownGracefully(quietPeriod, timeout, unit); + } + + @Override + public Future terminationFuture() { + return underlying.terminationFuture(); + } + + @Override + public boolean isShutdown() { + return underlying.isShutdown(); + } + + @Override + public boolean isTerminated() { + return underlying.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return underlying.awaitTermination(timeout, unit); + } + }; + + @Override + public EventExecutor next() { + return executor; + } + + @Override + public Iterator iterator() { + return Collections.singletonList(executor).iterator(); + } + + @Override + public boolean isShuttingDown() { + return underlying.isShuttingDown(); + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return underlying.shutdownGracefully(quietPeriod, timeout, unit); + } + + @Override + public Future terminationFuture() { + return underlying.terminationFuture(); + } + + @Override + public boolean isShutdown() { + return underlying.isShutdown(); + } + + @Override + public boolean isTerminated() { + return underlying.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return underlying.awaitTermination(timeout, unit); + } + }; + + // Use maxTaskExecutePerRun=1 so reschedule happens after first task + NonStickyEventExecutorGroup nonStickyGroup = new NonStickyEventExecutorGroup(wrapper, 1); + + try { + final EventExecutor executor = nonStickyGroup.next(); + + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference inEventLoopResult = new AtomicReference(); + + // Submit 2 tasks: + // Task 1: completes, triggers reschedule which will be rejected + // Task 2: verifies inEventLoop() still works after failed reschedule + executor.execute(new Runnable() { + @Override + public void run() { + // First task - will trigger reschedule attempt that fails + } + }); + + executor.execute(new Runnable() { + @Override + public void run() { + // This runs AFTER the failed rescheduling + // WITHOUT line 262 fix: executingThread is null, inEventLoop() returns false + // WITH line 262 fix: executingThread restored, inEventLoop() returns true + inEventLoopResult.set(executor.inEventLoop()); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS), "Tasks should complete"); + Boolean result = inEventLoopResult.get(); + assertNotNull(result, "inEventLoop() should have been called"); + assertTrue(result, + "inEventLoop() should return true even after failed reschedule attempt. " + + "This indicates executingThread was properly restored in the exception handler."); + } finally { + nonStickyGroup.shutdownGracefully(); + underlying.shutdownGracefully(); + } + } + private static void execute(EventExecutorGroup group, CountDownLatch startLatch) throws Throwable { final EventExecutor executor = group.next(); assertTrue(executor instanceof OrderedEventExecutor); diff --git a/dev-tools/pom.xml b/dev-tools/pom.xml index f587cce722c..aa6975bd069 100644 --- a/dev-tools/pom.xml +++ b/dev-tools/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-dev-tools diff --git a/docker-datastax-release.sh b/docker-datastax-release.sh index 77e8938c251..61cbbee8b5c 100755 --- a/docker-datastax-release.sh +++ b/docker-datastax-release.sh @@ -12,4 +12,4 @@ if ! which docker > /dev/null ; then fi sudo docker build -f docker/Dockerfile-netty-centos6 -t netty-centos6 . -sudo docker run -t --network host -v ~/.m2:/root/.m2:Z -v ~/.ssh:/root/.ssh:Z -v ~/.gnupg:/root/.gnupg:Z -v `pwd`:/code:Z -w /code --entrypoint="" netty-centos6 bash -ic "./mvnw -B clean deploy -Partifactory -DskipTests -DaltDeploymentRepository=\"artifactory::default::https://repo.aws.dsinternal.org/artifactory/datastax-releases-local\"" +sudo docker run -t --network host -v ~/.m2:/root/.m2:Z -v ~/.ssh:/root/.ssh:Z -v ~/.gnupg:/root/.gnupg:Z -v `pwd`:/code:Z -w /code --entrypoint="" netty-centos6 bash -ic "./mvnw -B clean deploy -Partifactory -DskipTests -DaltDeploymentRepository=\"artifactory::default::https://maven.pkg.github.com/riptano/netty\"" diff --git a/docker/Dockerfile-netty-centos6 b/docker/Dockerfile-netty-centos6 index e330266e5fc..cb02f13d282 100644 --- a/docker/Dockerfile-netty-centos6 +++ b/docker/Dockerfile-netty-centos6 @@ -26,7 +26,7 @@ RUN yum install -y \ # Downloading and installing SDKMAN! RUN curl -s "https://get.sdkman.io" | bash -ARG java_version="8.0.302-zulu" +ARG java_version="8.0.482-zulu" ENV JAVA_VERSION $java_version # Installing Java removing some unnecessary SDKMAN files diff --git a/docker/Dockerfile.al2023 b/docker/Dockerfile.al2023 new file mode 100644 index 00000000000..06c3eedb81f --- /dev/null +++ b/docker/Dockerfile.al2023 @@ -0,0 +1,70 @@ +FROM --platform=linux/amd64 amazonlinux:2023 + +ARG java_version=11.0.30-amzn +ARG aws_lc_version=v1.54.0 +ARG maven_version=3.9.10 +ENV JAVA_VERSION $java_version +ENV AWS_LC_VERSION $aws_lc_version +ENV MAVEN_VERSION $maven_version + +# install dependencies +RUN dnf install -y \ + apr-devel \ + autoconf \ + automake \ + bzip2 \ + cmake \ + gcc \ + gcc-c++ \ + git \ + glibc-devel \ + golang \ + libgcc \ + libstdc++ \ + libstdc++-devel \ + libstdc++-static \ + libtool \ + make \ + ninja-build \ + patch \ + perl \ + perl-parent \ + perl-devel \ + tar \ + unzip \ + wget \ + which \ + zip + +# Downloading and installing SDKMAN! +RUN curl -s "https://get.sdkman.io" | bash + +# Installing Java removing some unnecessary SDKMAN files +RUN bash -c "source $HOME/.sdkman/bin/sdkman-init.sh && \ + yes | sdk install java $JAVA_VERSION && \ + yes | sdk install maven $MAVEN_VERSION && \ + rm -rf $HOME/.sdkman/archives/* && \ + rm -rf $HOME/.sdkman/tmp/*" + +RUN echo 'export JAVA_HOME="/root/.sdkman/candidates/java/current"' >> ~/.bashrc +RUN echo 'export PATH=$JAVA_HOME/bin:$PATH' >> ~/.bashrc + +ENV PATH /root/.sdkman/candidates/java/current/bin:/root/.sdkman/candidates/maven/current/bin:$PATH +ENV JAVA_HOME=/root/.sdkman/candidates/java/current + +# install rust and setup PATH +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y +RUN echo 'PATH=$PATH:$HOME/.cargo/bin' >> ~/.bashrc + +RUN mkdir "$HOME/sources" && \ + git clone https://github.com/aws/aws-lc.git "$HOME/sources/aws-lc" && \ + cd "$HOME/sources/aws-lc" && \ + git checkout $AWS_LC_VERSION && \ + cmake -B build -S . -DCMAKE_INSTALL_PREFIX=/opt/aws-lc -DBUILD_SHARED_LIBS=1 -DBUILD_TESTING=0 && \ + cmake --build build -- -j && \ + cmake --install build + +# Cleanup +RUN dnf clean all && \ + rm -rf /var/cache/dnf && \ + rm -rf "$HOME/sources" diff --git a/docker/Dockerfile.cross_compile_aarch64 b/docker/Dockerfile.cross_compile_aarch64 index 8c1077c3f14..a5e21d20982 100644 --- a/docker/Dockerfile.cross_compile_aarch64 +++ b/docker/Dockerfile.cross_compile_aarch64 @@ -1,10 +1,9 @@ FROM --platform=linux/amd64 centos:7.6.1810 -ARG gcc_version=10.2-2020.11 +ARG gcc_version=10.3-2021.07 ENV GCC_VERSION $gcc_version ENV SOURCE_DIR /root/source - # Update to use the vault RUN sed -i -e 's/^mirrorlist/#mirrorlist/g' -e 's/^#baseurl=http:\/\/mirror.centos.org\/centos\/$releasever\//baseurl=https:\/\/linuxsoft.cern.ch\/centos-vault\/\/7.6.1810\//g' /etc/yum.repos.d/CentOS-Base.repo diff --git a/docker/docker-compose.al2023.yaml b/docker/docker-compose.al2023.yaml new file mode 100644 index 00000000000..ae8e9b4106d --- /dev/null +++ b/docker/docker-compose.al2023.yaml @@ -0,0 +1,65 @@ +services: + + runtime-setup: + image: netty-al2023:x86_64 + build: + context: ../ + dockerfile: docker/Dockerfile.al2023 + + common: &common + image: netty-al2023:x86_64 + depends_on: [runtime-setup] + environment: + LD_LIBRARY_PATH: /opt/aws-lc/lib64 + volumes: + # Use a separate directory for the AL2023 Maven repository + - ~/.m2-al2023:/root/.m2 + - ..:/netty + - ../../netty-tcnative:/netty-tcnative + working_dir: /netty + + common-tcnative: &common-tcnative + <<: *common + environment: + MAVEN_OPTS: + LD_LIBRARY_PATH: /opt/aws-lc/lib64 + LDFLAGS: -L/opt/aws-lc/lib64 -lssl -lcrypto + CFLAGS: -I/opt/aws-lc/include -DHAVE_OPENSSL -lssl -lcrypto + CXXFLAGS: -I/opt/aws-lc/include -DHAVE_OPENSSL -lssl -lcrypto + + install-tcnative: + <<: *common-tcnative + command: '/bin/bash -cl " + ./mvnw -am -pl openssl-dynamic clean install && + env -u LDFLAGS -u CFLAGS -u CXXFLAGS -u LD_LIBRARY_PATH ./mvnw -am -pl boringssl-static clean install + "' + working_dir: /netty-tcnative + + update-tcnative-version: + <<: *common + command: '/bin/bash -cl " + ./mvnw versions:update-property -Dproperty=tcnative.version -DnewVersion=$(cd /netty-tcnative && ./mvnw help:evaluate -Dexpression=project.version -q -DforceStdout) -DallowSnapshots=true -DprocessParent=true -DgenerateBackupPoms=false + "' + + build: + <<: *common + command: '/bin/bash -cl " + ./mvnw -B -ntp clean install -Dio.netty.testsuite.badHost=netty.io -Dtcnative.classifier=linux-x86_64-fedora -Drevapi.skip=true -Dcheckstyle.skip=true -Dforbiddenapis.skip=true + "' + + build-leak: + <<: *common + command: '/bin/bash -cl " + ./mvnw -B -ntp -Pleak clean install -Dio.netty.testsuite.badHost=netty.io -Dtcnative.classifier=linux-x86_64-fedora -Drevapi.skip=true -Dcheckstyle.skip=true -Dforbiddenapis.skip=true + "' + + shell: + <<: *common + volumes: + - ~/.m2-al2023:/root/.m2 + - ~/.gitconfig:/root/.gitconfig + - ~/.gitignore:/root/.gitignore + - ..:/netty + - ../../netty-tcnative:/netty-tcnative + working_dir: /netty + entrypoint: /bin/bash -l diff --git a/docker/docker-compose.centos-6.111.yaml b/docker/docker-compose.centos-6.111.yaml index 5ef7aecc48d..28b01f82ea5 100644 --- a/docker/docker-compose.centos-6.111.yaml +++ b/docker/docker-compose.centos-6.111.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-1.11 build: args: - java_version : "11.0.28-zulu" + java_version : "11.0.30-zulu" build: image: netty:centos-6-1.11 diff --git a/docker/docker-compose.centos-6.18.yaml b/docker/docker-compose.centos-6.18.yaml index ee132f5ca28..ecf9eaae22b 100644 --- a/docker/docker-compose.centos-6.18.yaml +++ b/docker/docker-compose.centos-6.18.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-1.8 build: args: - java_version : "8.0.462-zulu" + java_version : "8.0.482-zulu" build: image: netty:centos-6-1.8 diff --git a/docker/docker-compose.centos-6.21.yaml b/docker/docker-compose.centos-6.21.yaml index 35a8f3b7707..dc0bf62a6a4 100644 --- a/docker/docker-compose.centos-6.21.yaml +++ b/docker/docker-compose.centos-6.21.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-21 build: args: - java_version : "21.0.8-zulu" + java_version : "21.0.10-zulu" build: image: netty:centos-6-21 diff --git a/docker/docker-compose.centos-6.24.yaml b/docker/docker-compose.centos-6.24.yaml index 8646af72da3..de70f61a624 100644 --- a/docker/docker-compose.centos-6.24.yaml +++ b/docker/docker-compose.centos-6.24.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-24 build: args: - java_version : "24.0.1-zulu" + java_version : "24.0.2-zulu" build: image: netty:centos-6-24 diff --git a/docker/docker-compose.centos-6.25.yaml b/docker/docker-compose.centos-6.25.yaml index 07ee2ba8ed1..e7b2cff3cf4 100644 --- a/docker/docker-compose.centos-6.25.yaml +++ b/docker/docker-compose.centos-6.25.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-6-25 build: args: - java_version : "25-zulu" + java_version : "25.0.2-zulu" build: image: netty:centos-6-25 diff --git a/docker/docker-compose.centos-7.117.yaml b/docker/docker-compose.centos-7.117.yaml index 411ef802512..464e7082fb8 100644 --- a/docker/docker-compose.centos-7.117.yaml +++ b/docker/docker-compose.centos-7.117.yaml @@ -6,7 +6,7 @@ services: image: netty:centos-7-1.17 build: args: - java_version : "17.0.16-zulu" + java_version : "17.0.18-zulu" build: image: netty:centos-7-1.17 diff --git a/docker/docker-compose.centos-7.yaml b/docker/docker-compose.centos-7.yaml index 6c0facb0652..14437428a34 100644 --- a/docker/docker-compose.centos-7.yaml +++ b/docker/docker-compose.centos-7.yaml @@ -8,8 +8,8 @@ services: context: ../ dockerfile: docker/Dockerfile.cross_compile_aarch64 args: - gcc_version: "10.2-2020.11" - java_version: "8.0.462-zulu" + gcc_version: "10.3-2021.07" + java_version: "8.0.482-zulu" cross-compile-aarch64-common: &cross-compile-aarch64-common depends_on: [ cross-compile-aarch64-runtime-setup ] diff --git a/example/pom.xml b/example/pom.xml index 682c3cb3848..2602786bf1c 100644 --- a/example/pom.xml +++ b/example/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-example @@ -32,6 +32,7 @@ true io.netty.example + true diff --git a/example/src/main/java/io/netty/example/discard/DiscardClient.java b/example/src/main/java/io/netty/example/discard/DiscardClient.java index 32562f89ae5..160c317aa4e 100644 --- a/example/src/main/java/io/netty/example/discard/DiscardClient.java +++ b/example/src/main/java/io/netty/example/discard/DiscardClient.java @@ -23,21 +23,29 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.example.util.ServerUtil; import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; /** * Keeps sending random data to the specified address. */ public final class DiscardClient { + static final boolean SSL = System.getProperty("ssl") != null; static final String HOST = System.getProperty("host", "127.0.0.1"); static final int PORT = Integer.parseInt(System.getProperty("port", "8009")); static final int SIZE = Integer.parseInt(System.getProperty("size", "256")); public static void main(String[] args) throws Exception { // Configure SSL. - final SslContext sslCtx = ServerUtil.buildSslContext(); + final SslContext sslCtx; + if (SSL) { + sslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + } else { + sslCtx = null; + } EventLoopGroup group = new NioEventLoopGroup(); try { diff --git a/example/src/main/java/io/netty/example/redis/RedisClient.java b/example/src/main/java/io/netty/example/redis/RedisClient.java index 50718b11243..3a116778e01 100644 --- a/example/src/main/java/io/netty/example/redis/RedisClient.java +++ b/example/src/main/java/io/netty/example/redis/RedisClient.java @@ -52,7 +52,7 @@ protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline p = ch.pipeline(); p.addLast(new RedisDecoder()); p.addLast(new RedisBulkStringAggregator()); - p.addLast(new RedisArrayAggregator()); + p.addLast(new RedisArrayAggregator(1000000, 1024)); p.addLast(new RedisEncoder()); p.addLast(new RedisClientHandler()); } diff --git a/handler-proxy/pom.xml b/handler-proxy/pom.xml index ebc56a14add..5ae4d04da33 100644 --- a/handler-proxy/pom.xml +++ b/handler-proxy/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-handler-proxy diff --git a/handler-proxy/src/main/java/io/netty/handler/proxy/HttpProxyHandler.java b/handler-proxy/src/main/java/io/netty/handler/proxy/HttpProxyHandler.java index f22abe4232b..e36fe3d7142 100644 --- a/handler-proxy/src/main/java/io/netty/handler/proxy/HttpProxyHandler.java +++ b/handler-proxy/src/main/java/io/netty/handler/proxy/HttpProxyHandler.java @@ -70,6 +70,7 @@ public final class HttpProxyHandler extends ProxyHandler { private final CharSequence authorization; private final HttpHeaders outboundHeaders; private final boolean ignoreDefaultPortsInConnectHostHeader; + private final boolean validateInitialHeaders; private HttpResponseStatus status; private HttpHeaders inboundHeaders; @@ -84,12 +85,20 @@ public HttpProxyHandler(SocketAddress proxyAddress, HttpHeaders headers) { public HttpProxyHandler(SocketAddress proxyAddress, HttpHeaders headers, boolean ignoreDefaultPortsInConnectHostHeader) { + this(proxyAddress, headers, ignoreDefaultPortsInConnectHostHeader, true); + } + + public HttpProxyHandler(SocketAddress proxyAddress, + HttpHeaders headers, + boolean ignoreDefaultPortsInConnectHostHeader, + boolean validateInitialHeaders) { super(proxyAddress); username = null; password = null; authorization = null; this.outboundHeaders = headers; this.ignoreDefaultPortsInConnectHostHeader = ignoreDefaultPortsInConnectHostHeader; + this.validateInitialHeaders = validateInitialHeaders; } public HttpProxyHandler(SocketAddress proxyAddress, String username, String password) { @@ -98,7 +107,7 @@ public HttpProxyHandler(SocketAddress proxyAddress, String username, String pass public HttpProxyHandler(SocketAddress proxyAddress, String username, String password, HttpHeaders headers) { - this(proxyAddress, username, password, headers, false); + this(proxyAddress, username, password, headers, false, true); } public HttpProxyHandler(SocketAddress proxyAddress, @@ -106,6 +115,15 @@ public HttpProxyHandler(SocketAddress proxyAddress, String password, HttpHeaders headers, boolean ignoreDefaultPortsInConnectHostHeader) { + this(proxyAddress, username, password, headers, ignoreDefaultPortsInConnectHostHeader, true); + } + + public HttpProxyHandler(SocketAddress proxyAddress, + String username, + String password, + HttpHeaders headers, + boolean ignoreDefaultPortsInConnectHostHeader, + boolean validateInitialHeaders) { super(proxyAddress); this.username = ObjectUtil.checkNotNull(username, "username"); this.password = ObjectUtil.checkNotNull(password, "password"); @@ -125,6 +143,7 @@ public HttpProxyHandler(SocketAddress proxyAddress, this.outboundHeaders = headers; this.ignoreDefaultPortsInConnectHostHeader = ignoreDefaultPortsInConnectHostHeader; + this.validateInitialHeaders = validateInitialHeaders; } @Override @@ -173,7 +192,8 @@ protected Object newInitialMessage(ChannelHandlerContext ctx) throws Exception { hostString : url; - HttpHeadersFactory headersFactory = DefaultHttpHeadersFactory.headersFactory().withValidation(false); + HttpHeadersFactory headersFactory = DefaultHttpHeadersFactory.headersFactory() + .withValidation(validateInitialHeaders); FullHttpRequest req = new DefaultFullHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.CONNECT, url, diff --git a/handler-proxy/src/test/java/io/netty/handler/proxy/HttpProxyHandlerTest.java b/handler-proxy/src/test/java/io/netty/handler/proxy/HttpProxyHandlerTest.java index bb9571768fe..0346f1b53da 100644 --- a/handler-proxy/src/test/java/io/netty/handler/proxy/HttpProxyHandlerTest.java +++ b/handler-proxy/src/test/java/io/netty/handler/proxy/HttpProxyHandlerTest.java @@ -43,6 +43,8 @@ import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.net.InetAddress; import java.net.InetSocketAddress; @@ -51,6 +53,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.*; public class HttpProxyHandlerTest { @@ -175,6 +178,29 @@ public void testCustomHeaders() throws Exception { true); } + @ParameterizedTest + @ValueSource(booleans = { true, false }) + public void testInvalidHeaders(boolean validation) throws Exception { + InetSocketAddress socketAddress = InetSocketAddress.createUnresolved("10.0.0.1", 8080); + try { + testInitialMessage( + socketAddress, + "10.0.0.1:8080", + "10.0.0.1:8080", + new DefaultHttpHeaders(false) + .add("CUSTOM_HEADER", "CUSTOM_VALUE1\r\nInvalid: true") + .add("CUSTOM_HEADER", "CUSTOM_VALUE2"), + true, validation); + if (validation) { + fail("Validation should have failed for the provided headers"); + } + } catch (IllegalArgumentException e) { + if (!validation) { + throw e; + } + } + } + @Test public void testExceptionDuringConnect() throws Exception { EventLoopGroup group = null; @@ -239,6 +265,16 @@ private static void testInitialMessage(InetSocketAddress socketAddress, String expectedHostHeader, HttpHeaders headers, boolean ignoreDefaultPortsInConnectHostHeader) throws Exception { + testInitialMessage(socketAddress, expectedUrl, expectedHostHeader, headers, + ignoreDefaultPortsInConnectHostHeader, true); + } + + private static void testInitialMessage(InetSocketAddress socketAddress, + String expectedUrl, + String expectedHostHeader, + HttpHeaders headers, + boolean ignoreDefaultPortsInConnectHostHeader, + boolean validateInitialHeaders) throws Exception { InetSocketAddress proxyAddress = new InetSocketAddress(NetUtil.LOCALHOST, 8080); ChannelPromise promise = mock(ChannelPromise.class); @@ -250,7 +286,7 @@ private static void testInitialMessage(InetSocketAddress socketAddress, HttpProxyHandler handler = new HttpProxyHandler( new InetSocketAddress(NetUtil.LOCALHOST, 8080), headers, - ignoreDefaultPortsInConnectHostHeader); + ignoreDefaultPortsInConnectHostHeader, validateInitialHeaders); handler.connect(ctx, socketAddress, null, promise); FullHttpRequest request = (FullHttpRequest) handler.newInitialMessage(ctx); diff --git a/handler-ssl-ocsp/pom.xml b/handler-ssl-ocsp/pom.xml index d7e0056466e..49d5cd0556d 100644 --- a/handler-ssl-ocsp/pom.xml +++ b/handler-ssl-ocsp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-handler-ssl-ocsp diff --git a/handler/pom.xml b/handler/pom.xml index e0011af86f8..39beace8d23 100644 --- a/handler/pom.xml +++ b/handler/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-handler diff --git a/handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilter.java b/handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilter.java index 4a29abc6197..982fdf347db 100644 --- a/handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilter.java +++ b/handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilter.java @@ -21,6 +21,7 @@ import io.netty.util.internal.ObjectUtil; import java.net.Inet4Address; +import java.net.Inet6Address; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.ArrayList; @@ -175,7 +176,7 @@ protected boolean accept(ChannelHandlerContext ctx, InetSocketAddress remoteAddr return ipFilterRuleTypeIPv4 == IpFilterRuleType.ACCEPT; } } - } else if (ipv6Rules != null) { + } else if (ipv6Rules != null && remoteAddress.getAddress() instanceof Inet6Address) { int indexOf = Arrays.binarySearch(ipv6Rules, remoteAddress, IpSubnetFilterRuleComparator.INSTANCE); if (indexOf >= 0) { if (ipFilterRuleTypeIPv6 == null) { diff --git a/handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRule.java b/handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRule.java index 377222d8b70..428dec64149 100644 --- a/handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRule.java +++ b/handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRule.java @@ -149,7 +149,7 @@ int compareTo(InetSocketAddress inetSocketAddress) { Ip6SubnetFilterRule ip6SubnetFilterRule = (Ip6SubnetFilterRule) filterRule; return ip6SubnetFilterRule.networkAddress .compareTo(Ip6SubnetFilterRule.ipToInt((Inet6Address) inetSocketAddress.getAddress()) - .and(ip6SubnetFilterRule.networkAddress)); + .and(ip6SubnetFilterRule.subnetMask)); } } @@ -245,7 +245,7 @@ private static BigInteger ipToInt(Inet6Address ipAddress) { byte[] octets = ipAddress.getAddress(); assert octets.length == 16; - return new BigInteger(octets); + return new BigInteger(1, octets); } private static BigInteger prefixToSubnetMask(int cidrPrefix) { diff --git a/handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java b/handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java index 6265dc27fe1..32f624ccb62 100644 --- a/handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java +++ b/handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java @@ -28,6 +28,7 @@ import io.netty.channel.socket.ServerSocketChannel; import io.netty.channel.socket.SocketChannel; import io.netty.util.NetUtil; +import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -277,7 +278,12 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { // Initialize if needed if (state.get() == State.INIT) { - initializeIfNecessary(ctx); + try { + initializeIfNecessary(ctx); + } catch (Exception ex) { + ReferenceCountUtil.release(msg); + throw ex; + } } // Only write if State is STARTED @@ -297,7 +303,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { // Initialize if needed if (state.get() == State.INIT) { - initializeIfNecessary(ctx); + try { + initializeIfNecessary(ctx); + } catch (Exception ex) { + ReferenceCountUtil.release(msg); + promise.setFailure(ex); + return; + } } // Only write if State is STARTED diff --git a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java index 5abf8a1b107..3b8f0e0713d 100644 --- a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java @@ -121,6 +121,7 @@ private static String extractSniHostname(ByteBuf in) { return null; } + static final long DEFAULT_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(10); protected final long handshakeTimeoutMillis; private ScheduledFuture timeoutFuture; private String hostname; @@ -129,7 +130,7 @@ private static String extractSniHostname(ByteBuf in) { * @param handshakeTimeoutMillis the handshake timeout in milliseconds */ protected AbstractSniHandler(long handshakeTimeoutMillis) { - this(0, handshakeTimeoutMillis); + this(DEFAULT_MAX_CLIENT_HELLO_LENGTH, handshakeTimeoutMillis); } /** @@ -142,7 +143,7 @@ protected AbstractSniHandler(int maxClientHelloLength, long handshakeTimeoutMill } public AbstractSniHandler() { - this(0, 0L); + this(DEFAULT_MAX_CLIENT_HELLO_LENGTH, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java b/handler/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java index c2c3e9032a9..0807daea621 100644 --- a/handler/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java +++ b/handler/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java @@ -18,15 +18,22 @@ import io.netty.util.internal.SuppressJava6Requirement; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.X509ExtendedTrustManager; -import javax.net.ssl.X509TrustManager; import java.net.Socket; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Collection; import java.util.List; - +import javax.naming.ldap.LdapName; +import javax.naming.ldap.Rdn; +import javax.net.ssl.ExtendedSSLSession; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; /** * Wraps an existing {@link X509ExtendedTrustManager} and enhances the {@link CertificateException} that is thrown @@ -34,6 +41,13 @@ */ @SuppressJava6Requirement(reason = "Usage guarded by java version check") final class EnhancingX509ExtendedTrustManager extends X509ExtendedTrustManager { + + // Constants for subject alt names of type DNS and IP. See X509Certificate#getSubjectAlternativeNames() javadocs. + static final int ALTNAME_DNS = 2; + static final int ALTNAME_URI = 6; + static final int ALTNAME_IP = 7; + private static final String SEPARATOR = ", "; + private final X509ExtendedTrustManager wrapped; EnhancingX509ExtendedTrustManager(X509TrustManager wrapped) { @@ -52,7 +66,8 @@ public void checkServerTrusted(X509Certificate[] chain, String authType, Socket try { wrapped.checkServerTrusted(chain, authType, socket); } catch (CertificateException e) { - throwEnhancedCertificateException(chain, e); + throwEnhancedCertificateException(e, chain, + socket instanceof SSLSocket ? ((SSLSocket) socket).getHandshakeSession() : null); } } @@ -68,7 +83,7 @@ public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngi try { wrapped.checkServerTrusted(chain, authType, engine); } catch (CertificateException e) { - throwEnhancedCertificateException(chain, e); + throwEnhancedCertificateException(e, chain, engine != null ? engine.getHandshakeSession() : null); } } @@ -84,7 +99,7 @@ public void checkServerTrusted(X509Certificate[] chain, String authType) try { wrapped.checkServerTrusted(chain, authType); } catch (CertificateException e) { - throwEnhancedCertificateException(chain, e); + throwEnhancedCertificateException(e, chain, null); } } @@ -93,32 +108,91 @@ public X509Certificate[] getAcceptedIssuers() { return wrapped.getAcceptedIssuers(); } - private static void throwEnhancedCertificateException(X509Certificate[] chain, CertificateException e) - throws CertificateException { + private static void throwEnhancedCertificateException(CertificateException e, X509Certificate[] chain, + SSLSession session) throws CertificateException { // Matching the message is the best we can do sadly. String message = e.getMessage(); - if (message != null && e.getMessage().startsWith("No subject alternative DNS name matching")) { - StringBuilder names = new StringBuilder(64); + if (message != null && + (message.startsWith("No subject alternative") || message.startsWith("No name matching"))) { + StringBuilder sb = new StringBuilder(128); + sb.append(message); + // Some exception messages from sun.security.util.HostnameChecker may end with a dot that we don't need + if (message.charAt(message.length() - 1) == '.') { + sb.setLength(sb.length() - 1); + } + if (session != null) { + sb.append(" for SNIHostName=").append(getSNIHostName(session)) + .append(" and peerHost=").append(session.getPeerHost()); + } + sb.append(" in the chain of ").append(chain.length).append(" certificate(s):"); for (int i = 0; i < chain.length; i++) { X509Certificate cert = chain[i]; Collection> collection = cert.getSubjectAlternativeNames(); + sb.append(' ').append(i + 1).append(". subjectAlternativeNames=["); if (collection != null) { + boolean hasNames = false; for (List altNames : collection) { - // 2 is dNSName. See X509Certificate javadocs. - if (altNames.size() >= 2 && ((Integer) altNames.get(0)).intValue() == 2) { - names.append((String) altNames.get(1)).append(","); + if (altNames.size() < 2) { + // We expect at least a pair of 'nameType:value' in that list. + continue; + } + final int nameType = ((Integer) altNames.get(0)).intValue(); + if (nameType == ALTNAME_DNS) { + sb.append("DNS"); + } else if (nameType == ALTNAME_IP) { + sb.append("IP"); + } else if (nameType == ALTNAME_URI) { + // URI names are common in some environments with gRPC services that use SPIFFEs. + // Though the hostname matcher won't be looking at them, having them there can help + // debugging cases where hostname verification was enabled when it shouldn't be. + sb.append("URI"); + } else { + continue; } + sb.append(':').append((String) altNames.get(1)).append(SEPARATOR); + hasNames = true; + } + if (hasNames) { + // Strip of the last separator + sb.setLength(sb.length() - SEPARATOR.length()); } } + sb.append("], CN=").append(getCommonName(cert)).append('.'); } - if (names.length() != 0) { - // Strip of , - names.setLength(names.length() - 1); - throw new CertificateException(message + - " Subject alternative DNS names in the certificate chain of " + chain.length + - " certificate(s): " + names, e); - } + throw new CertificateException(sb.toString(), e); } throw e; } + + private static String getSNIHostName(SSLSession session) { + if (!(session instanceof ExtendedSSLSession)) { + return null; + } + List names = ((ExtendedSSLSession) session).getRequestedServerNames(); + for (SNIServerName sni : names) { + if (sni instanceof SNIHostName) { + SNIHostName hostName = (SNIHostName) sni; + return hostName.getAsciiName(); + } + } + return null; + } + + private static String getCommonName(X509Certificate cert) { + try { + // 1. Get the X500Principal (better than getSubjectDN which is implementation dependent and deprecated) + X500Principal principal = cert.getSubjectX500Principal(); + // 2. Parse the DN using LdapName + LdapName ldapName = new LdapName(principal.getName()); + // 3. Iterate over the Relative Distinguished Names (RDNs) to find CN + for (Rdn rdn : ldapName.getRdns()) { + if (rdn.getType().equalsIgnoreCase("CN")) { + return rdn.getValue().toString(); + } + } + } catch (Exception ignore) { + // ignore + } + return "null"; + } } diff --git a/handler/src/main/java/io/netty/handler/ssl/JdkSslClientContext.java b/handler/src/main/java/io/netty/handler/ssl/JdkSslClientContext.java index fde06a6023b..f389743ce6a 100644 --- a/handler/src/main/java/io/netty/handler/ssl/JdkSslClientContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/JdkSslClientContext.java @@ -318,10 +318,11 @@ private static SSLContext newSSLContext(Provider sslContextProvider, } private static TrustManager[] wrapIfNeeded(TrustManager[] tms, ResumptionController resumptionController) { - if (resumptionController != null) { - for (int i = 0; i < tms.length; i++) { - tms[i] = resumptionController.wrapIfNeeded(tms[i]); - } + if (tms == null || resumptionController == null) { + return tms; + } + for (int i = 0; i < tms.length; i++) { + tms[i] = resumptionController.wrapIfNeeded(tms[i]); } return tms; } diff --git a/handler/src/main/java/io/netty/handler/ssl/JdkSslServerContext.java b/handler/src/main/java/io/netty/handler/ssl/JdkSslServerContext.java index 9867581038e..f58bab3641b 100644 --- a/handler/src/main/java/io/netty/handler/ssl/JdkSslServerContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/JdkSslServerContext.java @@ -360,6 +360,9 @@ private static SSLContext newSSLContext(Provider sslContextProvider, X509Certifi @SuppressJava6Requirement(reason = "Guarded by java version check") private static TrustManager[] wrapTrustManagerIfNeeded( TrustManager[] trustManagers, ResumptionController resumptionController) { + if (trustManagers == null) { + return null; + } if (WRAP_TRUST_MANAGER && PlatformDependent.javaVersion() >= 7) { for (int i = 0; i < trustManagers.length; i++) { TrustManager tm = trustManagers[i]; diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java b/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java index e1eabf71a7e..af4bcf6779a 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java @@ -67,6 +67,7 @@ public final class OpenSsl { private static final boolean SUPPORTS_OCSP; private static final boolean TLSV13_SUPPORTED; private static final boolean IS_BORINGSSL; + private static final boolean IS_AWSLC; private static final Set CLIENT_DEFAULT_PROTOCOLS; private static final Set SERVER_DEFAULT_PROTOCOLS; static final Set SUPPORTED_PROTOCOLS_SET; @@ -161,6 +162,7 @@ public final class OpenSsl { } IS_BORINGSSL = "BoringSSL".equals(versionString()); + IS_AWSLC = versionString().startsWith("AWS-LC"); if (IS_BORINGSSL) { EXTRA_SUPPORTED_TLS_1_3_CIPHERS = new String [] { "TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384" , @@ -268,7 +270,7 @@ public final class OpenSsl { try { boolean propertySet = SystemPropertyUtil.contains( "io.netty.handler.ssl.openssl.useKeyManagerFactory"); - if (!IS_BORINGSSL) { + if (!(IS_BORINGSSL || IS_AWSLC)) { useKeyManagerFactory = SystemPropertyUtil.getBoolean( "io.netty.handler.ssl.openssl.useKeyManagerFactory", true); @@ -282,7 +284,7 @@ public final class OpenSsl { if (propertySet) { logger.info("System property " + "'io.netty.handler.ssl.openssl.useKeyManagerFactory'" + - " is deprecated and will be ignored when using BoringSSL"); + " is deprecated and will be ignored when using BoringSSL or AWS-LC"); } } } catch (Throwable ignore) { @@ -453,6 +455,7 @@ public final class OpenSsl { SUPPORTS_OCSP = false; TLSV13_SUPPORTED = false; IS_BORINGSSL = false; + IS_AWSLC = false; EXTRA_SUPPORTED_TLS_1_3_CIPHERS = EmptyArrays.EMPTY_STRINGS; EXTRA_SUPPORTED_TLS_1_3_CIPHERS_STRING = StringUtil.EMPTY_STRING; NAMED_GROUPS = DEFAULT_NAMED_GROUPS; @@ -738,7 +741,7 @@ static boolean isOptionSupported(SslContextOption option) { return true; } // Check for options that are only supported by BoringSSL atm. - if (isBoringSSL()) { + if (isBoringSSL() || isAWSLC()) { return option == OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD || option == OpenSslContextOption.PRIVATE_KEY_METHOD || option == OpenSslContextOption.CERTIFICATE_COMPRESSION_ALGORITHMS || @@ -779,4 +782,8 @@ static String[] defaultProtocols(boolean isClient) { static boolean isBoringSSL() { return IS_BORINGSSL; } + + static boolean isAWSLC() { + return IS_AWSLC; + } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java index a55007dfdb8..801478468eb 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java @@ -67,13 +67,17 @@ OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) t @Override void destroy() { - // Remove and release all entries. - do { - Iterator iterator = cache.values().iterator(); - while (iterator.hasNext()) { - iterator.next().release(); - iterator.remove(); - } - } while (!cache.isEmpty()); + try { + // Remove and release all entries. + do { + Iterator iterator = cache.values().iterator(); + while (iterator.hasNext()) { + iterator.next().release(); + iterator.remove(); + } + } while (!cache.isEmpty()); + } finally { + super.destroy(); + } } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java index adf545fb61c..f314b90a65c 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java @@ -18,11 +18,13 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.internal.tcnative.SSL; +import io.netty.util.IllegalReferenceCountException; import javax.net.ssl.SSLException; import javax.net.ssl.X509KeyManager; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.concurrent.atomic.AtomicReference; import static io.netty.handler.ssl.ReferenceCountedOpenSslContext.toBIO; @@ -30,13 +32,16 @@ * Provides {@link OpenSslKeyMaterial} for a given alias. */ class OpenSslKeyMaterialProvider { + private static final MaterialCache SENTINEL_DESTROYED = new MaterialCache(null, null, null); private final X509KeyManager keyManager; private final String password; + private final AtomicReference cache; OpenSslKeyMaterialProvider(X509KeyManager keyManager, String password) { this.keyManager = keyManager; this.password = password; + cache = new AtomicReference(); } static void validateKeyMaterialSupported(X509Certificate[] keyCertChain, PrivateKey key, String keyPassword) @@ -109,6 +114,36 @@ OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) t } PrivateKey key = keyManager.getPrivateKey(alias); + MaterialCache materialCache = cache.get(); + if (materialCache != null && materialCache != SENTINEL_DESTROYED && materialCache.retain()) { + if (materialCache.sameInstances(key, certificates)) { + return materialCache.material(); // We already called `retain()` + } else { + // No match on this cache. Release and build a new one from scratch. + materialCache.release(); + } + } + + OpenSslKeyMaterial keyMaterial = createKeyMaterial(allocator, certificates, key); + materialCache = new MaterialCache(key, certificates, keyMaterial); + + // Retain the new material to put in the cache, then replace and release the old material. + materialCache.retain(); + MaterialCache oldMaterial = cache.getAndSet(materialCache); + if (oldMaterial != null) { + if (oldMaterial == SENTINEL_DESTROYED) { + destroyCache(); // Call `destroyCache()` instead of `destroy()` to avoid duplicating other effects. + } else { + oldMaterial.release(); + } + } + + return keyMaterial; + } + + private OpenSslKeyMaterial createKeyMaterial( + ByteBufAllocator allocator, X509Certificate[] certificates, PrivateKey key) + throws Exception { PemEncoded encoded = PemX509Certificate.toPEM(allocator, true, certificates); long chainBio = 0; long pkeyBio = 0; @@ -149,6 +184,61 @@ OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) t * Will be invoked once the provider should be destroyed. */ void destroy() { - // NOOP. + destroyCache(); + } + + private void destroyCache() { + MaterialCache oldMaterial; + while ((oldMaterial = cache.getAndSet(SENTINEL_DESTROYED)) != SENTINEL_DESTROYED) { + if (oldMaterial != null) { + oldMaterial.release(); + } + } + } + + private static final class MaterialCache { + private final PrivateKey key; + private final X509Certificate[] certs; + private final OpenSslKeyMaterial material; + + private MaterialCache(PrivateKey key, X509Certificate[] certs, OpenSslKeyMaterial material) { + this.key = key; + this.certs = certs; + this.material = material; + } + + OpenSslKeyMaterial material() { + return material; + } + + boolean sameInstances(PrivateKey key, X509Certificate[] certs) { + X509Certificate[] existingCerts = this.certs; + int length = existingCerts.length; + if (this.key != key || length != certs.length) { + return false; + } + for (int i = 0; i < length; i++) { + if (certs[i] != existingCerts[i]) { + return false; + } + } + return true; + } + + boolean retain() { + if (material.refCnt() != 0) { + try { + material.retain(); + return true; + } catch (IllegalReferenceCountException ignore) { + // Fall through to the `return false` below. + } + } + return false; + } + + void release() { + material.release(); + } } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactory.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactory.java index df711a0bede..2cde2d9e471 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactory.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactory.java @@ -225,10 +225,14 @@ OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) t @Override void destroy() { - for (Object material: materialMap.values()) { - ReferenceCountUtil.release(material); + try { + for (Object material: materialMap.values()) { + ReferenceCountUtil.release(material); + } + materialMap.clear(); + } finally { + super.destroy(); } - materialMap.clear(); } } } diff --git a/handler/src/main/java/io/netty/handler/ssl/PemReader.java b/handler/src/main/java/io/netty/handler/ssl/PemReader.java index 33328ea672d..529b59d05a2 100644 --- a/handler/src/main/java/io/netty/handler/ssl/PemReader.java +++ b/handler/src/main/java/io/netty/handler/ssl/PemReader.java @@ -20,6 +20,7 @@ import io.netty.buffer.Unpooled; import io.netty.handler.codec.base64.Base64; import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -80,34 +81,44 @@ static ByteBuf[] readCertificates(InputStream in) throws CertificateException { List certs = new ArrayList(); Matcher m = CERT_HEADER.matcher(content); int start = 0; - for (;;) { - if (!m.find(start)) { - break; - } + try { + for (;;) { + if (!m.find(start)) { + break; + } - // Here and below it's necessary to save the position as it is reset - // after calling usePattern() on Android due to a bug. - // - // See https://issuetracker.google.com/issues/293206296 - start = m.end(); - m.usePattern(BODY); - if (!m.find(start)) { - break; - } + // Here and below it's necessary to save the position as it is reset + // after calling usePattern() on Android due to a bug. + // + // See https://issuetracker.google.com/issues/293206296 + start = m.end(); + m.usePattern(BODY); + if (!m.find(start)) { + break; + } - ByteBuf base64 = Unpooled.copiedBuffer(m.group(0), CharsetUtil.US_ASCII); - start = m.end(); - m.usePattern(CERT_FOOTER); - if (!m.find(start)) { - // Certificate is incomplete. - break; - } - ByteBuf der = Base64.decode(base64); - base64.release(); - certs.add(der); + ByteBuf base64 = Unpooled.copiedBuffer(m.group(0), CharsetUtil.US_ASCII); + try { + start = m.end(); + m.usePattern(CERT_FOOTER); + if (!m.find(start)) { + // Certificate is incomplete. + break; + } + ByteBuf der = Base64.decode(base64); + certs.add(der); + } finally { + base64.release(); + } - start = m.end(); - m.usePattern(CERT_HEADER); + start = m.end(); + m.usePattern(CERT_HEADER); + } + } catch (Throwable e) { + for (ByteBuf cert : certs) { + cert.release(); + } + PlatformDependent.throwException(e); } if (certs.isEmpty()) { @@ -150,15 +161,17 @@ static ByteBuf readPrivateKey(InputStream in) throws KeyException { } ByteBuf base64 = Unpooled.copiedBuffer(m.group(0), CharsetUtil.US_ASCII); - start = m.end(); - m.usePattern(KEY_FOOTER); - if (!m.find(start)) { - // Key is incomplete. - throw keyNotFoundException(); + try { + start = m.end(); + m.usePattern(KEY_FOOTER); + if (!m.find(start)) { + // Key is incomplete. + throw keyNotFoundException(); + } + return Base64.decode(base64); + } finally { + base64.release(); } - ByteBuf der = Base64.decode(base64); - base64.release(); - return der; } private static KeyException keyNotFoundException() { diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java index 4de373a6231..9886543f97e 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java @@ -146,10 +146,13 @@ public ReferenceCounted touch(Object hint) { @Override protected void deallocate() { - destroy(); - if (leak != null) { - boolean closed = leak.close(ReferenceCountedOpenSslContext.this); - assert closed; + try { + destroy(); + } finally { + if (leak != null) { + boolean closed = leak.close(ReferenceCountedOpenSslContext.this); + assert closed; + } } } }; @@ -326,7 +329,8 @@ public ApplicationProtocolConfig.SelectedListenerFailureBehavior selectedListene } } else { CipherSuiteConverter.convertToCipherStrings( - unmodifiableCiphers, cipherBuilder, cipherTLSv13Builder, OpenSsl.isBoringSSL()); + unmodifiableCiphers, cipherBuilder, cipherTLSv13Builder, + OpenSsl.isBoringSSL()); // Set non TLSv1.3 ciphers. SSLContext.setCipherSuite(ctx, cipherBuilder.toString(), false); diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index 8ed9324c0f1..faee3f098ab 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -386,9 +386,9 @@ public List getStatusResponses() { } } - if (OpenSsl.isBoringSSL() && clientMode) { - // If in client-mode and BoringSSL let's allow to renegotiate once as the server may use this - // for client auth. + if ((OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()) && clientMode) { + // If in client-mode and provider is BoringSSL or AWS-LC let's allow to renegotiate once as the + // server may use this for client auth. // // See https://github.com/netty/netty/issues/11529 SSL.setRenegotiateMode(ssl, SSL.SSL_RENEGOTIATE_ONCE); @@ -1704,7 +1704,8 @@ public final void setEnabledCipherSuites(String[] cipherSuites) { final StringBuilder buf = new StringBuilder(); final StringBuilder bufTLSv13 = new StringBuilder(); - CipherSuiteConverter.convertToCipherStrings(Arrays.asList(cipherSuites), buf, bufTLSv13, OpenSsl.isBoringSSL()); + CipherSuiteConverter.convertToCipherStrings(Arrays.asList(cipherSuites), buf, bufTLSv13, + OpenSsl.isBoringSSL()); final String cipherSuiteSpec = buf.toString(); final String cipherSuiteSpecTLSv13 = bufTLSv13.toString(); diff --git a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java index 0f1d069c62a..946ff91aeaf 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java @@ -82,7 +82,7 @@ public SniHandler(DomainNameMapping mapping) { */ @SuppressWarnings("unchecked") public SniHandler(AsyncMapping mapping) { - this(mapping, 0, 0L); + this(mapping, DEFAULT_MAX_CLIENT_HELLO_LENGTH, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); } /** @@ -119,7 +119,7 @@ public SniHandler(Mapping mapping, long ha * @param handshakeTimeoutMillis the handshake timeout in milliseconds */ public SniHandler(AsyncMapping mapping, long handshakeTimeoutMillis) { - this(mapping, 0, handshakeTimeoutMillis); + this(mapping, DEFAULT_MAX_CLIENT_HELLO_LENGTH, handshakeTimeoutMillis); } /** diff --git a/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java index 052beda8ae1..46eee9512bd 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java @@ -44,6 +44,10 @@ public abstract class SslClientHelloHandler extends ByteToMessageDecoder impl */ public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF; + // Let's use a default limit of 64kb which should be big enough for almost everything in practice but still + // small enough to not allocate to much memory. + static final int DEFAULT_MAX_CLIENT_HELLO_LENGTH = 64 * 1024; + private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslClientHelloHandler.class); @@ -54,7 +58,7 @@ public abstract class SslClientHelloHandler extends ByteToMessageDecoder impl private ByteBuf handshakeBuffer; public SslClientHelloHandler() { - this(MAX_CLIENT_HELLO_LENGTH); + this(DEFAULT_MAX_CLIENT_HELLO_LENGTH); } protected SslClientHelloHandler(int maxClientHelloLength) { @@ -219,7 +223,15 @@ private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws try { future = lookup(ctx, clientHello); if (future.isDone()) { - onLookupComplete(ctx, future); + try { + onLookupComplete(ctx, future); + } catch (DecoderException err) { + ctx.fireExceptionCaught(err); + } catch (Exception cause) { + ctx.fireExceptionCaught(new DecoderException(cause)); + } catch (Throwable cause) { + ctx.fireExceptionCaught(cause); + } } else { suppressRead = true; final ByteBuf finalClientHello = clientHello; diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index f80b3004a8a..8b8b88d4da7 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -1903,6 +1903,10 @@ private void resumeOnEventExecutor() { void runComplete() { EventExecutor executor = ctx.executor(); + if (executor.isShuttingDown()) { + // The executor is already shutting down, just return. + return; + } // Jump back on the EventExecutor. We do this even if we are already on the EventLoop to guard against // reentrancy issues. Failing to do so could lead to the situation of tryDecode(...) be called and so // channelRead(...) while still in the decode loop. In this case channelRead(...) might release the input diff --git a/handler/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java b/handler/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java index 6efe959cd7e..53b7e62f004 100644 --- a/handler/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java +++ b/handler/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java @@ -17,6 +17,8 @@ package io.netty.handler.ssl.util; import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -41,7 +43,8 @@ public final class InsecureTrustManagerFactory extends SimpleTrustManagerFactory public static final TrustManagerFactory INSTANCE = new InsecureTrustManagerFactory(); - private static final TrustManager tm = new X509TrustManager() { + private static final TrustManager tm = wrapIfNeeded(new X509TrustManager() { + @Override public void checkClientTrusted(X509Certificate[] chain, String s) { if (logger.isDebugEnabled()) { @@ -60,7 +63,18 @@ public void checkServerTrusted(X509Certificate[] chain, String s) { public X509Certificate[] getAcceptedIssuers() { return EmptyArrays.EMPTY_X509_CERTIFICATES; } - }; + }); + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + static X509TrustManager wrapIfNeeded(X509TrustManager tm) { + if (PlatformDependent.javaVersion() >= 7) { + // This needs to be X509ExtendedTrustManager so hostname verification is skipped as well. + // Otherwise the JDK will internally wrap it with AbstractTrustManagerWrapper and add hostname verification + // by itself. + return new X509TrustManagerWrapper(tm); + } + return tm; + } private InsecureTrustManagerFactory() { } diff --git a/handler/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java b/handler/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java index b502f8cc3e2..d34f038949e 100644 --- a/handler/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java +++ b/handler/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java @@ -15,6 +15,7 @@ */ package io.netty.handler.ssl.util; +import io.netty.util.Recycler; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.SuppressJava6Requirement; @@ -41,18 +42,37 @@ import java.util.Set; public final class LazyX509Certificate extends X509Certificate { + private static final Recycler CERT_FACTORIES = new Recycler() { + @Override + protected CertFactoryHandle newObject(Handle handle) { + try { + return new CertFactoryHandle(CertificateFactory.getInstance("X.509"), handle); + } catch (CertificateException e) { + throw new IllegalStateException(e); + } + } + }; + + private static final class CertFactoryHandle { + private final CertificateFactory factory; + private final Recycler.EnhancedHandle handle; + + private CertFactoryHandle(CertificateFactory factory, Recycler.Handle handle) { + this.factory = factory; + this.handle = (Recycler.EnhancedHandle) handle; + } + + public X509Certificate generateCertificate(byte[] bytes) throws CertificateException { + return (X509Certificate) factory.generateCertificate(new ByteArrayInputStream(bytes)); + } - static final CertificateFactory X509_CERT_FACTORY; - static { - try { - X509_CERT_FACTORY = CertificateFactory.getInstance("X.509"); - } catch (CertificateException e) { - throw new ExceptionInInitializerError(e); + public void recycle() { + handle.unguardedRecycle(this); } } private final byte[] bytes; - private X509Certificate wrapped; + private volatile X509Certificate wrapped; /** * Creates a new instance which will lazy parse the given bytes. Be aware that the bytes will not be cloned. @@ -230,11 +250,16 @@ public byte[] getExtensionValue(String oid) { private X509Certificate unwrap() { X509Certificate wrapped = this.wrapped; if (wrapped == null) { + CertFactoryHandle factory = null; try { - wrapped = this.wrapped = (X509Certificate) X509_CERT_FACTORY.generateCertificate( - new ByteArrayInputStream(bytes)); + factory = CERT_FACTORIES.get(); + wrapped = this.wrapped = factory.generateCertificate(bytes); } catch (CertificateException e) { throw new IllegalStateException(e); + } finally { + if (factory != null) { + factory.recycle(); + } } } return wrapped; diff --git a/handler/src/main/java/io/netty/handler/ssl/util/SimpleTrustManagerFactory.java b/handler/src/main/java/io/netty/handler/ssl/util/SimpleTrustManagerFactory.java index c6d4b8bc1c0..ed50ad69b9d 100644 --- a/handler/src/main/java/io/netty/handler/ssl/util/SimpleTrustManagerFactory.java +++ b/handler/src/main/java/io/netty/handler/ssl/util/SimpleTrustManagerFactory.java @@ -18,15 +18,11 @@ import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.internal.ObjectUtil; -import io.netty.util.internal.PlatformDependent; -import io.netty.util.internal.SuppressJava6Requirement; import javax.net.ssl.ManagerFactoryParameters; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.TrustManagerFactorySpi; -import javax.net.ssl.X509ExtendedTrustManager; -import javax.net.ssl.X509TrustManager; import java.security.InvalidAlgorithmParameterException; import java.security.KeyStore; import java.security.KeyStoreException; @@ -135,22 +131,9 @@ protected TrustManager[] engineGetTrustManagers() { TrustManager[] trustManagers = this.trustManagers; if (trustManagers == null) { trustManagers = parent.engineGetTrustManagers(); - if (PlatformDependent.javaVersion() >= 7) { - wrapIfNeeded(trustManagers); - } this.trustManagers = trustManagers; } return trustManagers.clone(); } - - @SuppressJava6Requirement(reason = "Usage guarded by java version check") - private static void wrapIfNeeded(TrustManager[] trustManagers) { - for (int i = 0; i < trustManagers.length; i++) { - final TrustManager tm = trustManagers[i]; - if (tm instanceof X509TrustManager && !(tm instanceof X509ExtendedTrustManager)) { - trustManagers[i] = new X509TrustManagerWrapper((X509TrustManager) tm); - } - } - } } } diff --git a/handler/src/test/java/io/netty/handler/ipfilter/IpSubnetFilterTest.java b/handler/src/test/java/io/netty/handler/ipfilter/IpSubnetFilterTest.java index 6566c493432..22cdbefe78e 100644 --- a/handler/src/test/java/io/netty/handler/ipfilter/IpSubnetFilterTest.java +++ b/handler/src/test/java/io/netty/handler/ipfilter/IpSubnetFilterTest.java @@ -24,9 +24,12 @@ import io.netty.util.internal.SocketUtils; import org.junit.jupiter.api.Test; +import java.net.Inet4Address; +import java.net.Inet6Address; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -36,6 +39,22 @@ public class IpSubnetFilterTest { + @Test + void noClassCastExceptionIpv4RuleOnly() { + IpSubnetFilterRule rule = new IpSubnetFilterRule("10.10.0.0/16", IpFilterRuleType.ACCEPT); + IpSubnetFilter filter = new IpSubnetFilter(false, rule); + assertFalse(filter.accept(null, new InetSocketAddress(Inet4Address.getLoopbackAddress(), 80))); + assertFalse(filter.accept(null, new InetSocketAddress(Inet6Address.getLoopbackAddress(), 80))); + } + + @Test + void noClassCastExceptionIpv6RuleOnly() { + IpSubnetFilterRule rule = new IpSubnetFilterRule("::1/16", IpFilterRuleType.ACCEPT); + IpSubnetFilter filter = new IpSubnetFilter(false, rule); + assertFalse(filter.accept(null, new InetSocketAddress(Inet4Address.getLoopbackAddress(), 80))); + assertFalse(filter.accept(null, new InetSocketAddress(Inet6Address.getLoopbackAddress(), 80))); + } + @Test public void testIpv4DefaultRoute() { IpSubnetFilterRule rule = new IpSubnetFilterRule("0.0.0.0", 0, IpFilterRuleType.ACCEPT); @@ -195,12 +214,29 @@ public void testBinarySearch() { assertTrue(ch6.close().isSuccess()); //2001:db8:abcd:0000::/52 - EmbeddedChannel ch7 = newEmbeddedInetChannel("2001:db8:abcd:1000::", + EmbeddedChannel ch7 = newEmbeddedInetChannel("2001:db8:abcd:0000::1", new IpSubnetFilter(ipSubnetFilterRuleList)); assertFalse(ch7.isActive()); assertTrue(ch7.close().isSuccess()); } + @Test + public void testIpv6MaskCorrectlyApplied() { + IpSubnetFilterRule rule = new IpSubnetFilterRule("2001:db8:abcd:0000::", 52, IpFilterRuleType.ACCEPT); + + EmbeddedChannel ch = newEmbeddedInetChannel("2001:db8:ffff:0000::", + new IpSubnetFilter(false, Collections.singletonList(rule))); + assertFalse(ch.isActive()); + assertTrue(ch.close().isSuccess()); + } + + @Test + public void testIpv6MatchesNoFalsePositiveForAllOnesNetworkBits() { + // FFFF:FFFF::1 is NOT in 2001:db8::/32, which will be the case if the comparison is made unsigned. + IpSubnetFilterRule rule = new IpSubnetFilterRule("2001:db8::", 32, IpFilterRuleType.ACCEPT); + assertFalse(rule.matches(newSockAddress("FFFF:FFFF::1"))); + } + private static IpSubnetFilterRule buildRejectIP(String ipAddress, int mask) { return new IpSubnetFilterRule(ipAddress, mask, IpFilterRuleType.REJECT); } diff --git a/handler/src/test/java/io/netty/handler/logging/LoggingHandlerTest.java b/handler/src/test/java/io/netty/handler/logging/LoggingHandlerTest.java index 5d1aa8f60db..78e24f46d9d 100644 --- a/handler/src/test/java/io/netty/handler/logging/LoggingHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/logging/LoggingHandlerTest.java @@ -33,6 +33,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.api.parallel.Isolated; import org.mockito.ArgumentMatcher; import org.slf4j.LoggerFactory; @@ -54,6 +55,7 @@ /** * Verifies the correct functionality of the {@link LoggingHandler}. */ +@Isolated public class LoggingHandlerTest { private static final String LOGGER_NAME = LoggingHandler.class.getName(); diff --git a/handler/src/test/java/io/netty/handler/ssl/AmazonCorrettoSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/AmazonCorrettoSslEngineTest.java index 498e734f0d2..25a5f4a563d 100644 --- a/handler/src/test/java/io/netty/handler/ssl/AmazonCorrettoSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/AmazonCorrettoSslEngineTest.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.condition.DisabledIf; +import org.junit.jupiter.api.parallel.Isolated; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; @@ -30,7 +31,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; - +@Isolated("Adds and removes Security providers") @DisabledIf("checkIfAccpIsDisabled") public class AmazonCorrettoSslEngineTest extends SSLEngineTest { diff --git a/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java b/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java index 9b7398ba5ef..56fa9c4074a 100644 --- a/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java @@ -223,16 +223,17 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) } } finally { server.close().sync(); + + if (executorService != null) { + executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); + } } } finally { ReferenceCountUtil.release(sslClientContext); } } finally { ReferenceCountUtil.release(sslServerContext); - - if (executorService != null) { - executorService.shutdown(); - } } } diff --git a/handler/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java b/handler/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java index 427b82d2953..d0ca2165fc2 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java @@ -68,8 +68,6 @@ public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEn throws Exception { } - @MethodSource("newTestParams") - @ParameterizedTest @Override protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) { // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. diff --git a/handler/src/test/java/io/netty/handler/ssl/DelayingExecutor.java b/handler/src/test/java/io/netty/handler/ssl/DelayingExecutor.java index e3c39cbc7d9..65cbc448aa8 100644 --- a/handler/src/test/java/io/netty/handler/ssl/DelayingExecutor.java +++ b/handler/src/test/java/io/netty/handler/ssl/DelayingExecutor.java @@ -42,7 +42,8 @@ public void execute(Runnable command) { PlatformDependent.threadLocalRandom().nextInt(100), TimeUnit.MILLISECONDS); } - void shutdown() { + boolean shutdownAndAwaitTermination(long timeout, TimeUnit unit) throws InterruptedException { service.shutdown(); + return service.awaitTermination(timeout, unit); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java b/handler/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java index 60976127579..9e9a689c878 100644 --- a/handler/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java @@ -17,13 +17,13 @@ package io.netty.handler.ssl; import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLSocket; -import javax.net.ssl.X509ExtendedTrustManager; import java.math.BigInteger; import java.net.Socket; import java.security.Principal; @@ -32,23 +32,48 @@ import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Date; import java.util.List; import java.util.Set; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.security.auth.x500.X500Principal; +import static io.netty.handler.ssl.EnhancingX509ExtendedTrustManager.ALTNAME_DNS; +import static io.netty.handler.ssl.EnhancingX509ExtendedTrustManager.ALTNAME_IP; +import static io.netty.handler.ssl.EnhancingX509ExtendedTrustManager.ALTNAME_URI; +import static io.netty.handler.ssl.SniClientJava8TestUtil.mockSSLSessionWithSNIHostNameAndPeerHost; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeTrue; public class EnhancedX509ExtendedTrustManagerTest { + private static final String HOSTNAME = "netty.io"; + private static final String SAN_ENTRY_DNS = "some.netty.io"; + private static final String SAN_ENTRY_IP = "127.0.0.1"; + private static final String SAN_ENTRY_URI = "URI:https://uri.netty.io/profile"; + private static final String SAN_ENTRY_RFC822 = "info@netty.io"; + private static final String COMMON_NAME = "leaf.netty.io"; + private static final X509Certificate TEST_CERT = new X509Certificate() { @Override public Collection> getSubjectAlternativeNames() { - return Arrays.asList(Arrays.asList(1, new Object()), Arrays.asList(2, "some.netty.io")); + return Arrays.asList(Arrays.asList(1, new Object()), + Arrays.asList(ALTNAME_DNS, SAN_ENTRY_DNS), Arrays.asList(ALTNAME_IP, SAN_ENTRY_IP), + Arrays.asList(ALTNAME_URI, SAN_ENTRY_URI), Arrays.asList(1 /* rfc822Name */, SAN_ENTRY_RFC822)); + } + + @Override + public X500Principal getSubjectX500Principal() { + return new X500Principal("CN=" + COMMON_NAME + ", O=Netty"); } @Override @@ -192,7 +217,7 @@ public void checkClientTrusted(X509Certificate[] chain, String authType, Socket @Override public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { - throw new CertificateException("No subject alternative DNS name matching netty.io."); + throw newCertificateExceptionWithMatchingMessage(); } @Override @@ -203,7 +228,7 @@ public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngi @Override public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) throws CertificateException { - throw new CertificateException("No subject alternative DNS name matching netty.io."); + throw newCertificateExceptionWithMatchingMessage(); } @Override @@ -214,16 +239,23 @@ public void checkClientTrusted(X509Certificate[] chain, String authType) { @Override public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { - throw new CertificateException("No subject alternative DNS name matching netty.io."); + throw newCertificateExceptionWithMatchingMessage(); } @Override public X509Certificate[] getAcceptedIssuers() { return new X509Certificate[0]; } + + private CertificateException newCertificateExceptionWithMatchingMessage() { + return new CertificateException("No subject alternative DNS name matching " + HOSTNAME + " found."); + } }); static List throwingMatchingExecutables() { + if (PlatformDependent.javaVersion() < 8) { + return Collections.emptyList(); + } return Arrays.asList(new Executable() { @Override public void execute() throws Throwable { @@ -232,12 +264,18 @@ public void execute() throws Throwable { }, new Executable() { @Override public void execute() throws Throwable { - MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, (SSLEngine) null); + SSLSession session = mockSSLSessionWithSNIHostNameAndPeerHost(HOSTNAME); + SSLEngine engine = Mockito.mock(SSLEngine.class); + Mockito.when(engine.getHandshakeSession()).thenReturn(session); + MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, engine); } }, new Executable() { @Override public void execute() throws Throwable { - MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, (SSLSocket) null); + SSLSession session = mockSSLSessionWithSNIHostNameAndPeerHost(HOSTNAME); + SSLSocket socket = Mockito.mock(SSLSocket.class); + Mockito.when(socket.getHandshakeSession()).thenReturn(session); + MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, socket); } }); } @@ -307,16 +345,28 @@ public void execute() throws Throwable { @ParameterizedTest @MethodSource("throwingMatchingExecutables") - void testEnhanceException(Executable executable) { + void testEnhanceException(Executable executable, TestInfo testInfo) { + assumeTrue(PlatformDependent.javaVersion() >= 8); CertificateException exception = assertThrows(CertificateException.class, executable); // We should wrap the original cause with our own. assertInstanceOf(CertificateException.class, exception.getCause()); - assertThat(exception.getMessage()).contains("some.netty.io"); + String message = exception.getMessage(); + if (testInfo.getDisplayName().contains("with")) { + // The following data can be extracted only when we run the test with SSLEngine or SSLSocket: + assertThat(message).contains("SNIHostName=" + HOSTNAME); + assertThat(message).contains("peerHost=" + HOSTNAME); + } + assertThat(message).contains("DNS:" + SAN_ENTRY_DNS); + assertThat(message).contains("IP:" + SAN_ENTRY_IP); + assertThat(message).contains("URI:" + SAN_ENTRY_URI); + assertThat(message).contains("CN=" + COMMON_NAME); + assertThat(message).doesNotContain(SAN_ENTRY_RFC822); } @ParameterizedTest @MethodSource("throwingNonMatchingExecutables") void testNotEnhanceException(Executable executable) { + assumeTrue(PlatformDependent.javaVersion() >= 8); CertificateException exception = assertThrows(CertificateException.class, executable); // We should not wrap the original cause with our own. assertNull(exception.getCause()); diff --git a/handler/src/test/java/io/netty/handler/ssl/JdkSslClientContextTest.java b/handler/src/test/java/io/netty/handler/ssl/JdkSslClientContextTest.java index e5e18c13872..65eb965ec8e 100644 --- a/handler/src/test/java/io/netty/handler/ssl/JdkSslClientContextTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/JdkSslClientContextTest.java @@ -16,9 +16,18 @@ package io.netty.handler.ssl; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import org.junit.jupiter.api.Test; +import javax.net.ssl.ManagerFactoryParameters; import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.TrustManagerFactorySpi; import java.io.File; +import java.security.KeyStore; +import java.security.Provider; + +import static org.junit.jupiter.api.Assertions.assertNotNull; public class JdkSslClientContextTest extends SslContextTest { @Override @@ -26,4 +35,54 @@ protected SslContext newSslContext(File crtFile, File keyFile, String pass) thro return new JdkSslClientContext(crtFile, InsecureTrustManagerFactory.INSTANCE, crtFile, keyFile, pass, null, null, IdentityCipherSuiteFilter.INSTANCE, ApplicationProtocolConfig.DISABLED, 0, 0); } + + // Reproduces https://github.com/netty/netty/issues/14488 + // Before the fix, wrapIfNeeded did tms.length without a null check, so a + // TrustManagerFactory whose getTrustManagers() returns null surfaced as an + // opaque NullPointerException wrapped in SSLException. After the fix the + // context is built successfully, matching pre-4.1.114 behavior that relied + // on SSLContext.init accepting a null TrustManager[] to use platform defaults. + @Test + public void testTrustManagerFactoryReturningNullDoesNotThrowNpe() throws Exception { + TrustManagerFactory tmf = new TrustManagerFactory( + new NullReturningTrustManagerFactorySpi(), NullReturningTrustManagerFactorySpi.PROVIDER, "null") { + }; + // TrustManagerFactory must be initialized before SslContextBuilder will accept it; + // without this call the builder throws before reaching the code path under test. + tmf.init((KeyStore) null); + + SslContext ctx = SslContextBuilder.forClient() + .sslProvider(SslProvider.JDK) + .trustManager(tmf) + .build(); + // Success is "did not throw". Before the fix this call produced + // SSLException("failed to initialize the client-side SSL context") + // caused by NullPointerException from wrapIfNeeded. + assertNotNull(ctx); + } + + private static final class NullReturningTrustManagerFactorySpi extends TrustManagerFactorySpi { + // The Provider(String, double, String) constructor is deprecated since JDK 9 in + // favor of (String, String, String), but the replacement is unavailable on the + // JDK 8 source level this module targets, so we suppress the deprecation warning. + @SuppressWarnings("deprecation") + static final Provider PROVIDER = new Provider("NullReturningProvider", 1.0, "test-only") { + private static final long serialVersionUID = 1L; + }; + + @Override + protected void engineInit(KeyStore ks) { + // no-op + } + + @Override + protected void engineInit(ManagerFactoryParameters spec) { + // no-op + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return null; + } + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/JdkSslServerContextTest.java b/handler/src/test/java/io/netty/handler/ssl/JdkSslServerContextTest.java index 10052770a8f..c2effab48f1 100644 --- a/handler/src/test/java/io/netty/handler/ssl/JdkSslServerContextTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/JdkSslServerContextTest.java @@ -15,12 +15,21 @@ */ package io.netty.handler.ssl; +import io.netty.handler.ssl.util.SelfSignedCertificate; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import javax.net.ssl.ManagerFactoryParameters; import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.TrustManagerFactorySpi; import java.io.File; +import java.security.KeyStore; +import java.security.Provider; + +import static org.junit.jupiter.api.Assertions.assertNotNull; public class JdkSslServerContextTest extends SslContextTest { @@ -38,4 +47,58 @@ public void execute() throws Throwable { } }); } + + // A TrustManagerFactory whose getTrustManagers() legitimately returns null + // (e.g., asking SSLContext.init to fall back to the JDK default trust store) + // previously NPE'd inside wrapTrustManagerIfNeeded. Verify the server context + // now builds without throwing. + @Test + void testTrustManagerFactoryReturningNullDoesNotThrowNpe() throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + try { + TrustManagerFactory tmf = new TrustManagerFactory( + new NullReturningTrustManagerFactorySpi(), + NullReturningTrustManagerFactorySpi.PROVIDER, "null") { + }; + // TrustManagerFactory must be initialized before SslContextBuilder will accept it; + // without this call the builder throws before reaching the code path under test. + tmf.init((KeyStore) null); + + SslContext ctx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(SslProvider.JDK) + .trustManager(tmf) + .build(); + // Success is "did not throw". Before the fix this call produced + // SSLException("failed to initialize the server-side SSL context") + // caused by NullPointerException from wrapTrustManagerIfNeeded. + assertNotNull(ctx); + } finally { + ssc.delete(); + } + } + + private static final class NullReturningTrustManagerFactorySpi extends TrustManagerFactorySpi { + // The Provider(String, double, String) constructor is deprecated since JDK 9 in + // favor of (String, String, String), but the replacement is unavailable on the + // JDK 8 source level this module targets, so we suppress the deprecation warning. + @SuppressWarnings("deprecation") + static final Provider PROVIDER = new Provider("NullReturningProvider", 1.0, "test-only") { + private static final long serialVersionUID = 1L; + }; + + @Override + protected void engineInit(KeyStore ks) { + // no-op + } + + @Override + protected void engineInit(ManagerFactoryParameters spec) { + // no-op + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return null; + } + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java index 4e916f4927f..f278ff8a760 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java @@ -40,11 +40,6 @@ protected OpenSslKeyMaterialProvider newMaterialProvider(KeyManagerFactory facto factory.getKeyManagers()), password, Integer.MAX_VALUE); } - @Override - protected void assertRelease(OpenSslKeyMaterial material) { - assertFalse(material.release()); - } - @Test public void testMaterialCached() throws Exception { OpenSslKeyMaterialProvider provider = newMaterialProvider(newKeyManagerFactory(), PASSWORD); @@ -53,14 +48,14 @@ public void testMaterialCached() throws Exception { assertNotNull(material); assertNotEquals(0, material.certificateChainAddress()); assertNotEquals(0, material.privateKeyAddress()); - assertEquals(2, material.refCnt()); + assertEquals(3, material.refCnt()); OpenSslKeyMaterial material2 = provider.chooseKeyMaterial(UnpooledByteBufAllocator.DEFAULT, EXISTING_ALIAS); assertNotNull(material2); assertEquals(material.certificateChainAddress(), material2.certificateChainAddress()); assertEquals(material.privateKeyAddress(), material2.privateKeyAddress()); - assertEquals(3, material.refCnt()); - assertEquals(3, material2.refCnt()); + assertEquals(4, material.refCnt()); + assertEquals(4, material2.refCnt()); assertFalse(material.release()); assertFalse(material2.release()); diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java index 9882c722aad..b5422d1b9ac 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java @@ -30,6 +30,8 @@ import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.internal.tcnative.CertificateCompressionAlgo; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Promise; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; @@ -71,7 +73,7 @@ public void refreshAlgos() { @Test public void testSimple() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = buildClientContext( OpenSslCertificateCompressionConfig.newBuilder() .addAlgorithm(testBrotliAlgoClient, @@ -92,7 +94,7 @@ public void testSimple() throws Throwable { @Test public void testServerPriority() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = buildClientContext( OpenSslCertificateCompressionConfig.newBuilder() .addAlgorithm(testBrotliAlgoClient, @@ -116,7 +118,7 @@ public void testServerPriority() throws Throwable { @Test public void testServerPriorityReverse() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = buildClientContext( OpenSslCertificateCompressionConfig.newBuilder() .addAlgorithm(testBrotliAlgoClient, @@ -141,7 +143,7 @@ public void testServerPriorityReverse() throws Throwable { @Test public void testFailedNegotiation() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = buildClientContext( OpenSslCertificateCompressionConfig.newBuilder() .addAlgorithm(testBrotliAlgoClient, @@ -162,7 +164,7 @@ public void testFailedNegotiation() throws Throwable { @Test public void testAlgoFailure() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); TestCertCompressionAlgo badZlibAlgoClient = new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZLIB) { @Override @@ -191,7 +193,7 @@ public void execute() throws Throwable { @Test public void testAlgoException() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); TestCertCompressionAlgo badZlibAlgoClient = new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZLIB) { @Override @@ -220,7 +222,7 @@ public void execute() throws Throwable { @Test public void testTlsLessThan13() throws Throwable { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); final SslContext clientSslContext = SslContextBuilder.forClient() .sslProvider(SslProvider.OPENSSL) .protocols(SslProtocols.TLS_v1_2) @@ -251,7 +253,7 @@ public void testTlsLessThan13() throws Throwable { @Test public void testDuplicateAdd() throws Throwable { // Fails with "Failed trying to add certificate compression algorithm" - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); Assertions.assertThrows(Exception.class, new Executable() { @Override public void execute() throws Throwable { @@ -283,7 +285,7 @@ public void execute() throws Throwable { @Test public void testNotBoringAdd() throws Throwable { // Fails with "TLS Cert Compression only supported by BoringSSL" - assumeTrue(!OpenSsl.isBoringSSL()); + assumeTrue(!OpenSsl.isBoringSSL() && !OpenSsl.isAWSLC()); Assertions.assertThrows(Exception.class, new Executable() { @Override public void execute() throws Throwable { @@ -333,7 +335,10 @@ public void runCertCompressionTest(SslContext clientSslContext, SslContext serve clientChannel.close().syncUninterruptibly(); serverChannel.close().syncUninterruptibly(); } finally { - group.shutdownGracefully(); + Future future = group.shutdownGracefully(0, 10, TimeUnit.SECONDS); + ReferenceCountUtil.release(clientSslContext); + ReferenceCountUtil.release(serverSslContext); + future.sync(); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java index 896a21583d0..defabca75d0 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java @@ -25,7 +25,7 @@ static void expandCombinations(SSLEngineTest.SSLEngineTestParam param, List output) { output.add(new OpenSslEngineTestParam(true, false, param)); output.add(new OpenSslEngineTestParam(false, false, param)); - if (OpenSsl.isBoringSSL()) { + if (OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()) { output.add(new OpenSslEngineTestParam(true, true, param)); output.add(new OpenSslEngineTestParam(false, true, param)); } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java index 09976ee4aa6..14d70cb2c71 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java @@ -32,6 +32,7 @@ import java.security.cert.X509Certificate; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; @@ -69,7 +70,7 @@ protected OpenSslKeyMaterialProvider newMaterialProvider(KeyManagerFactory facto } protected void assertRelease(OpenSslKeyMaterial material) { - assertTrue(material.release()); + assertFalse(material.release()); } @Test @@ -86,6 +87,8 @@ public void testChooseKeyMaterial() throws Exception { assertRelease(material); provider.destroy(); + + assertEquals(0, material.refCnt()); } /** @@ -164,13 +167,14 @@ public void testChooseOpenSslPrivateKeyMaterial() throws Exception { OpenSslKeyMaterial material = provider.chooseKeyMaterial(ByteBufAllocator.DEFAULT, keyAlias); assertNotNull(material); assertEquals(2, sslPrivateKey.refCnt()); - assertEquals(1, material.refCnt()); - assertTrue(material.release()); - assertEquals(1, sslPrivateKey.refCnt()); + assertEquals(2, material.refCnt()); + assertFalse(material.release()); + assertEquals(2, sslPrivateKey.refCnt()); // Can get material multiple times from the same key material = provider.chooseKeyMaterial(ByteBufAllocator.DEFAULT, keyAlias); assertNotNull(material); assertEquals(2, sslPrivateKey.refCnt()); + provider.destroy(); // Destroy single-entry cache. assertTrue(material.release()); assertTrue(sslPrivateKey.release()); assertEquals(0, sslPrivateKey.refCnt()); diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java index ef268a8a253..3a3f049bd0b 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java @@ -58,8 +58,6 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -93,7 +91,7 @@ static Collection parameters() { public static void init() throws Exception { checkShouldUseKeyManagerFactory(); - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); // Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API // implementations. assumeCipherAvailable(SslProvider.OPENSSL); @@ -110,11 +108,11 @@ public Thread newThread(Runnable r) { } @AfterAll - public static void destroy() { - if (OpenSsl.isBoringSSL()) { + public static void destroy() throws InterruptedException { + if (OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()) { GROUP.shutdownGracefully(); + assertTrue(EXECUTOR.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); CERT.delete(); - EXECUTOR.shutdown(); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java index 36ff3de72e3..573f74bb221 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java @@ -28,7 +28,7 @@ import org.mockito.MockitoAnnotations; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; public class OptionalSslHandlerTest { @@ -115,7 +115,7 @@ public void decodeBuffered() throws Exception { final ByteBuf payload = Unpooled.wrappedBuffer(new byte[] { 22, 3 }); try { handler.decode(context, payload, null); - verifyZeroInteractions(pipeline); + verifyNoInteractions(pipeline); } finally { payload.release(); } diff --git a/handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java b/handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java index 174243d9a01..1128d8e4cd3 100644 --- a/handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java @@ -101,7 +101,7 @@ public void operationComplete(Future future) throws Exception { }); } }); - Channel channel = sb.bind(new LocalAddress("RenegotiateTest")).syncUninterruptibly().channel(); + Channel channel = sb.bind(new LocalAddress(getClass())).syncUninterruptibly().channel(); final SslContext clientContext = SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java index da3ca84b444..2bb66673a48 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java @@ -50,7 +50,6 @@ import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; -import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import org.conscrypt.OpenSSLProvider; @@ -547,7 +546,7 @@ public void tearDown() throws InterruptedException { if (clientGroupShutdownFuture != null) { clientGroupShutdownFuture.sync(); } - delegatingExecutor.shutdown(); + assertTrue(delegatingExecutor.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); serverException = null; clientException = null; } @@ -629,11 +628,7 @@ public void testIncompatibleCiphers(final SSLEngineTestParam param) throws Excep serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); // Set the server to only support a single TLSv1.2 cipher - final String serverCipher = - // JDK24+ does not support TLS_RSA_* ciphers by default anymore: - // See https://www.java.com/en/configure_crypto.html - PlatformDependent.javaVersion() >= 24 ? "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" : - "TLS_RSA_WITH_AES_128_CBC_SHA"; + final String serverCipher = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; serverEngine.setEnabledCipherSuites(new String[] { serverCipher }); // Set the client to only support a single TLSv1.3 cipher @@ -1381,11 +1376,12 @@ public void testSessionInvalidate(SSLEngineTestParam param) throws Exception { clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); handshake(param.type(), param.delegate(), clientEngine, serverEngine); + pingPongPacketsUntilSessionAllocation(param, clientEngine, serverEngine); SSLSession session = serverEngine.getSession(); - assertTrue(session.isValid()); + assertTrue(session.isValid(), "session should be valid: " + session); session.invalidate(); - assertFalse(session.isValid()); + assertFalse(session.isValid(), "session should be invalid: " + session); } finally { cleanupClientSslEngine(clientEngine); cleanupServerSslEngine(serverEngine); @@ -1422,43 +1418,7 @@ public void testSSLSessionId(SSLEngineTestParam param) throws Exception { handshake(param.type(), param.delegate(), clientEngine, serverEngine); if (param.protocolCipherCombo == ProtocolCipherCombo.TLSV13) { - // Allocate something which is big enough for sure - ByteBuffer packetBuffer = allocateBuffer(param.type(), 32 * 1024); - ByteBuffer appBuffer = allocateBuffer(param.type(), 32 * 1024); - - appBuffer.clear().position(4).flip(); - packetBuffer.clear(); - - do { - SSLEngineResult result; - - do { - result = serverEngine.wrap(appBuffer, packetBuffer); - } while (appBuffer.hasRemaining() || result.bytesProduced() > 0); - - appBuffer.clear(); - packetBuffer.flip(); - do { - result = clientEngine.unwrap(packetBuffer, appBuffer); - } while (packetBuffer.hasRemaining() || result.bytesProduced() > 0); - - packetBuffer.clear(); - appBuffer.clear().position(4).flip(); - - do { - result = clientEngine.wrap(appBuffer, packetBuffer); - } while (appBuffer.hasRemaining() || result.bytesProduced() > 0); - - appBuffer.clear(); - packetBuffer.flip(); - - do { - result = serverEngine.unwrap(packetBuffer, appBuffer); - } while (packetBuffer.hasRemaining() || result.bytesProduced() > 0); - - packetBuffer.clear(); - appBuffer.clear().position(4).flip(); - } while (clientEngine.getSession().getId().length == 0); + pingPongPacketsUntilSessionAllocation(param, clientEngine, serverEngine); // With TLS1.3 we should see pseudo IDs and so these should never match. assertFalse(Arrays.equals(clientEngine.getSession().getId(), serverEngine.getSession().getId())); @@ -1477,6 +1437,47 @@ public void testSSLSessionId(SSLEngineTestParam param) throws Exception { } } + private void pingPongPacketsUntilSessionAllocation( + SSLEngineTestParam param, SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException { + // Allocate something which is big enough for sure + ByteBuffer packetBuffer = allocateBuffer(param.type(), 32 * 1024); + ByteBuffer appBuffer = allocateBuffer(param.type(), 32 * 1024); + + appBuffer.clear().position(4).flip(); + packetBuffer.clear(); + + do { + SSLEngineResult result; + + do { + result = serverEngine.wrap(appBuffer, packetBuffer); + } while (appBuffer.hasRemaining() || result.bytesProduced() > 0); + + appBuffer.clear(); + packetBuffer.flip(); + do { + result = clientEngine.unwrap(packetBuffer, appBuffer); + } while (packetBuffer.hasRemaining() || result.bytesProduced() > 0); + + packetBuffer.clear(); + appBuffer.clear().position(4).flip(); + + do { + result = clientEngine.wrap(appBuffer, packetBuffer); + } while (appBuffer.hasRemaining() || result.bytesProduced() > 0); + + appBuffer.clear(); + packetBuffer.flip(); + + do { + result = serverEngine.unwrap(packetBuffer, appBuffer); + } while (packetBuffer.hasRemaining() || result.bytesProduced() > 0); + + packetBuffer.clear(); + appBuffer.clear().position(4).flip(); + } while (clientEngine.getSession().getId().length == 0); + } + @MethodSource("newTestParams") @ParameterizedTest @Timeout(30) @@ -2247,11 +2248,7 @@ public void testHandshakeCompletesWithNonContiguousProtocolsTLSv1_2CipherOnly(SS SelfSignedCertificate ssc = CachedSelfSignedCertificate.getCachedCertificate(); // Select a mandatory cipher from the TLSv1.2 RFC https://www.ietf.org/rfc/rfc5246.txt so handshakes won't fail // due to no shared/supported cipher. - final String sharedCipher = - // JDK24+ does not support TLS_RSA_* ciphers by default anymore: - // See https://www.java.com/en/configure_crypto.html - PlatformDependent.javaVersion() >= 24 ? "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" : - "TLS_RSA_WITH_AES_128_CBC_SHA"; + final String sharedCipher = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; clientSslCtx = wrapContext(param, SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) .ciphers(Collections.singletonList(sharedCipher)) @@ -2284,11 +2281,7 @@ public void testHandshakeCompletesWithoutFilteringSupportedCipher(SSLEngineTestP SelfSignedCertificate ssc = CachedSelfSignedCertificate.getCachedCertificate(); // Select a mandatory cipher from the TLSv1.2 RFC https://www.ietf.org/rfc/rfc5246.txt so handshakes won't fail // due to no shared/supported cipher. - final String sharedCipher = - // JDK24+ does not support TLS_RSA_* ciphers by default anymore: - // See https://www.java.com/en/configure_crypto.html - PlatformDependent.javaVersion() >= 24 ? "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" : - "TLS_RSA_WITH_AES_128_CBC_SHA"; + final String sharedCipher = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; clientSslCtx = wrapContext(param, SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) .ciphers(Collections.singletonList(sharedCipher), SupportedCipherSuiteFilter.INSTANCE) @@ -4511,11 +4504,9 @@ public void testMasterKeyLogging(final SSLEngineTestParam param) throws Exceptio * The JDK SSL engine master key retrieval relies on being able to set field access to true. * That is not available in JDK9+ */ - assumeFalse(sslServerProvider() == SslProvider.JDK && PlatformDependent.javaVersion() > 8); - - String originalSystemPropertyValue = SystemPropertyUtil.get(SslMasterKeyHandler.SYSTEM_PROP_KEY); - System.setProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY, Boolean.TRUE.toString()); - + if (sslServerProvider() == SslProvider.JDK) { + assumeTrue(SslMasterKeyHandler.isSunSslEngineAvailable()); + } SelfSignedCertificate ssc = CachedSelfSignedCertificate.getCachedCertificate(); serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) .sslProvider(sslServerProvider()) @@ -4542,6 +4533,12 @@ protected void initChannel(Channel ch) { ch.pipeline().addLast(sslHandler); ch.pipeline().addLast(new SslMasterKeyHandler() { + + @Override + protected boolean masterKeyHandlerEnabled() { + return true; + } + @Override protected void accept(SecretKey masterKey, SSLSession session) { promise.setSuccess(masterKey); @@ -4565,11 +4562,6 @@ protected void accept(SecretKey masterKey, SSLSession session) { assertEquals(48, key.getEncoded().length, "AES secret key must be 48 bytes"); } finally { closeQuietly(socket); - if (originalSystemPropertyValue != null) { - System.setProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY, originalSystemPropertyValue); - } else { - System.clearProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY); - } } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java b/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java index 3554e5ae46b..2e67ac87279 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java @@ -35,6 +35,7 @@ import io.netty.util.concurrent.Promise; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.ThrowableUtil; +import org.mockito.Mockito; import javax.net.ssl.ExtendedSSLSession; import javax.net.ssl.KeyManager; @@ -64,6 +65,7 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -345,4 +347,12 @@ public String chooseEngineServerAlias(String s, Principal[] principals, }, factory.getProvider(), factory.getAlgorithm()); } } + + static SSLSession mockSSLSessionWithSNIHostNameAndPeerHost(String hostname) { + ExtendedSSLSession session = Mockito.mock(ExtendedSSLSession.class); + SNIServerName sniName = new SNIHostName(hostname); + Mockito.when(session.getRequestedServerNames()).thenReturn(Arrays.asList(sniName)); + Mockito.when(session.getPeerHost()).thenReturn(hostname); + return session; + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index bd0223cd1f6..3a85f24fdc6 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -31,7 +31,6 @@ import io.netty.handler.codec.TooLongFrameException; import io.netty.handler.ssl.util.CachedSelfSignedCertificate; import io.netty.util.concurrent.Future; - import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; @@ -70,14 +69,14 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.mockito.Mockito.mock; public class SniHandlerTest { @@ -265,8 +264,14 @@ public void testNonAsciiServerNameParsing(SslProvider provider) throws Exception .add("chat4.leancloud.cn", leanContext2) .build(); + final AtomicReference exceptionRef = new AtomicReference(); SniHandler handler = new SniHandler(mapping); - final EmbeddedChannel ch = new EmbeddedChannel(handler); + final EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + exceptionRef.compareAndSet(null, cause); + } + }); try { // hex dump of a client hello packet, which contains an invalid hostname "CHAT4。LEANCLOUD。CN" @@ -283,13 +288,11 @@ public void testNonAsciiServerNameParsing(SslProvider provider) throws Exception // Decode should fail because of the badly encoded "HostName" string in the SNI extension // that isn't ASCII as per RFC 6066 - https://tools.ietf.org/html/rfc6066#page-6 ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(tlsHandshakeMessageHex1))); + ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(tlsHandshakeMessageHex))); - assertThrows(DecoderException.class, new Executable() { - @Override - public void execute() throws Throwable { - ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(tlsHandshakeMessageHex))); - } - }); + Throwable cause = exceptionRef.get(); + assertNotNull(cause); + assertInstanceOf(DecoderException.class, cause); } finally { ch.finishAndReleaseAll(); } @@ -394,10 +397,19 @@ public void testMajorVersionNot3(SslProvider provider) throws Exception { @ParameterizedTest(name = "{index}: sslProvider={0}") @MethodSource("data") - public void testSniWithApnHandler(SslProvider provider) throws Exception { - SslContext nettyContext = makeSslContext(provider, true); - SslContext sniContext = makeSslContext(provider, true); - final SslContext clientContext = makeSslClientContext(provider, true); + public void testSniWithAlpnHandler(SslProvider provider) throws Exception { + SslContext nettyContext = null; + SslContext sniContext = null; + final SslContext clientContext; + try { + nettyContext = makeSslContext(provider, true); + sniContext = makeSslContext(provider, true); + clientContext = makeSslClientContext(provider, true); + } catch (Exception e) { + ReferenceCountUtil.safeRelease(nettyContext); + ReferenceCountUtil.safeRelease(sniContext); + throw e; + } try { final AtomicBoolean serverApnCtx = new AtomicBoolean(false); final AtomicBoolean clientApnCtx = new AtomicBoolean(false); @@ -455,8 +467,7 @@ protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); - ChannelFuture ccf = cb.connect(serverChannel.localAddress()); - assertTrue(ccf.awaitUninterruptibly().isSuccess()); + ChannelFuture ccf = cb.connect(serverChannel.localAddress()).sync(); clientChannel = ccf.channel(); assertTrue(serverApnDoneLatch.await(5, TimeUnit.SECONDS)); @@ -472,7 +483,7 @@ protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { if (clientChannel != null) { clientChannel.close().sync(); } - group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS); + group.shutdownGracefully(100, 5000, TimeUnit.MILLISECONDS).sync(); } } finally { releaseAll(clientContext, nettyContext, sniContext); diff --git a/handler/src/test/java/io/netty/handler/ssl/SslClientHelloHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslClientHelloHandlerTest.java new file mode 100644 index 00000000000..f638cd62fd5 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/SslClientHelloHandlerTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderException; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.StringUtil; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class SslClientHelloHandlerTest { + + // ClientHello carrying SNI hostname "chat4.leancloud.cn", borrowed from SniHandlerTest. + private static final String TLS_CLIENT_HELLO_HEX_PART1 = "16030100"; + private static final String TLS_CLIENT_HELLO_HEX_PART2 = + "c6010000c20303bb0855d66532c05a0ef784f7c384feeafa68b3" + + "b655ac7288650d5eed4aa3fb52000038c02cc030009fcca9cca8ccaac02b" + + "c02f009ec024c028006bc023c0270067c00ac0140039c009c0130033009d" + + "009c003d003c0035002f00ff010000610000001700150000124348415434" + + "2e4c45414e434c4f55442e434e000b000403000102000a000a0008001d00" + + "170019001800230000000d0020001e060106020603050105020503040104" + + "0204030301030203030201020202030016000000170000"; + + @Test + public void testSyncLookupCallbackExceptionFiredOnPipeline() { + final AtomicBoolean nullRetryOccurred = new AtomicBoolean(); + + AbstractSniHandler handler = new AbstractSniHandler() { + @Override + protected Future lookup(ChannelHandlerContext ctx, String hostname) { + if (hostname == null) { + nullRetryOccurred.set(true); + } + Promise promise = ImmediateEventExecutor.INSTANCE.newPromise(); + promise.setSuccess(new Object()); + return promise; + } + + @Override + protected void onLookupComplete(ChannelHandlerContext ctx, String hostname, + Future future) { + throw new RuntimeException("simulated user callback failure"); + } + }; + + final AtomicReference exceptionRef = new AtomicReference(); + EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + exceptionRef.compareAndSet(null, cause); + } + }); + + try { + ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(TLS_CLIENT_HELLO_HEX_PART1))); + ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(TLS_CLIENT_HELLO_HEX_PART2))); + } finally { + ch.finishAndReleaseAll(); + } + + Throwable cause = exceptionRef.get(); + assertNotNull(cause); + assertInstanceOf(DecoderException.class, cause); + assertFalse(nullRetryOccurred.get(), "Expected no select(ctx, null) retry"); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java b/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java index be0785195fa..17ee9686a6a 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java @@ -126,18 +126,18 @@ public void testContextFromManagersOpenssl() throws Exception { @Test public void testUnsupportedPrivateKeyFailsFastForServer() { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); testUnsupportedPrivateKeyFailsFast(true); } @Test public void testUnsupportedPrivateKeyFailsFastForClient() { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); testUnsupportedPrivateKeyFailsFast(false); } private static void testUnsupportedPrivateKeyFailsFast(boolean server) { - assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC()); String cert = "-----BEGIN CERTIFICATE-----\n" + "MIICODCCAY2gAwIBAgIEXKTrajAKBggqhkjOPQQDBDBUMQswCQYDVQQGEwJVUzEM\n" + "MAoGA1UECAwDTi9hMQwwCgYDVQQHDANOL2ExDDAKBgNVBAoMA04vYTEMMAoGA1UE\n" + diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 3a5d16d3a46..226bcebfaee 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -1000,7 +1000,7 @@ public void testHandshakeWithExecutorJDK() throws Throwable { try { testHandshakeWithExecutor(executorService, SslProvider.JDK, false); } finally { - executorService.shutdown(); + assertTrue(executorService.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); } } @@ -1029,7 +1029,7 @@ public void testHandshakeWithExecutorOpenSsl() throws Throwable { try { testHandshakeWithExecutor(executorService, SslProvider.OPENSSL, false); } finally { - executorService.shutdown(); + assertTrue(executorService.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); } } @@ -1054,7 +1054,7 @@ public void testHandshakeMTLSWithExecutorJDK() throws Throwable { try { testHandshakeWithExecutor(executorService, SslProvider.JDK, true); } finally { - executorService.shutdown(); + assertTrue(executorService.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); } } @@ -1083,7 +1083,7 @@ public void testHandshakeMTLSWithExecutorOpenSsl() throws Throwable { try { testHandshakeWithExecutor(executorService, SslProvider.OPENSSL, true); } finally { - executorService.shutdown(); + assertTrue(executorService.shutdownAndAwaitTermination(5, TimeUnit.SECONDS)); } } @@ -1574,7 +1574,8 @@ public void testHandshakeFailureCipherMissmatchTLSv12OpenSsl() throws Exception public void testHandshakeFailureCipherMissmatchTLSv13OpenSsl() throws Exception { OpenSsl.ensureAvailability(); assumeTrue(SslProvider.isTlsv13Supported(SslProvider.OPENSSL)); - assumeFalse(OpenSsl.isBoringSSL(), "BoringSSL does not support setting ciphers for TLSv1.3 explicit"); + assumeFalse(OpenSsl.isBoringSSL() || OpenSsl.isAWSLC(), + "Provider does not support setting ciphers for TLSv1.3 explicitly"); testHandshakeFailureCipherMissmatch(SslProvider.OPENSSL, true); } diff --git a/handler/src/test/java/io/netty/handler/ssl/util/LazyX509CertificateTest.java b/handler/src/test/java/io/netty/handler/ssl/util/LazyX509CertificateTest.java index ca678598577..59a28c69aba 100644 --- a/handler/src/test/java/io/netty/handler/ssl/util/LazyX509CertificateTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/util/LazyX509CertificateTest.java @@ -21,9 +21,15 @@ import java.io.ByteArrayInputStream; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.function.Supplier; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; public class LazyX509CertificateTest { @@ -79,7 +85,29 @@ public void testLazyX509Certificate() throws Exception { assertArrayEquals(x509Certificate.getKeyUsage(), lazyX509Certificate.getKeyUsage()); assertEquals(x509Certificate.getExtendedKeyUsage(), lazyX509Certificate.getExtendedKeyUsage()); assertEquals(x509Certificate.getBasicConstraints(), lazyX509Certificate.getBasicConstraints()); - assertEquals(x509Certificate.getSubjectAlternativeNames(), lazyX509Certificate.getSubjectAlternativeNames()); - assertEquals(x509Certificate.getIssuerAlternativeNames(), lazyX509Certificate.getIssuerAlternativeNames()); + assertEqualSans(x509Certificate.getSubjectAlternativeNames(), lazyX509Certificate.getSubjectAlternativeNames()); + assertEqualSans(x509Certificate.getIssuerAlternativeNames(), lazyX509Certificate.getIssuerAlternativeNames()); + } + + private static void assertEqualSans(Collection> expectedSans, Collection> actualSans) { + String errMsgSans = expectedSans + " != " + actualSans; + if (expectedSans == null) { + assertNull(actualSans, errMsgSans); + return; + } + assertEquals(expectedSans.size(), actualSans.size(), errMsgSans); + Iterator> expectItr = expectedSans.iterator(); + Iterator> actualItr = actualSans.iterator(); + while (expectItr.hasNext() && actualItr.hasNext()) { + List expectedSan = expectItr.next(); + List actualSan = actualItr.next(); + String errMsgSan = expectedSan + " != " + actualSan; + assertEquals(2, expectedSan.size(), errMsgSan); + assertEquals(2, actualSan.size(), errMsgSan); + assertEquals(expectedSan.get(0), actualSan.get(0), errMsgSan); + assertEquals(expectedSan.get(1), actualSan.get(1), errMsgSan); + } + assertFalse(expectItr.hasNext(), errMsgSans); + assertFalse(actualItr.hasNext(), errMsgSans); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/util/SimpleTrustManagerFactoryTest.java b/handler/src/test/java/io/netty/handler/ssl/util/SimpleTrustManagerFactoryTest.java new file mode 100644 index 00000000000..01999460904 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/util/SimpleTrustManagerFactoryTest.java @@ -0,0 +1,71 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl.util; + +import io.netty.util.internal.EmptyArrays; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; +import java.security.KeyStore; +import java.security.cert.X509Certificate; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; + +public class SimpleTrustManagerFactoryTest { + + @Test + public void testNotWrap() { + final X509TrustManager tm = new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) { + // NOOP + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) { + // NOOP + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + }; + SimpleTrustManagerFactory factory = new SimpleTrustManagerFactory() { + @Override + protected void engineInit(KeyStore keyStore) { + // NOOP + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + // NOOP + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { tm }; + } + }; + + TrustManager[] tms = factory.getTrustManagers(); + assertEquals(1, tms.length); + assertSame(tm, tms[0]); + } +} diff --git a/microbench/pom.xml b/microbench/pom.xml index 8850eb00316..c066b426aa3 100644 --- a/microbench/pom.xml +++ b/microbench/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-microbench @@ -223,6 +223,13 @@ **/Http2FrameWriterBenchmark.java + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + + diff --git a/microbench/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBufBenchmark.java b/microbench/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBufBenchmark.java index 7ae7f56c0e5..6eb10982ed0 100644 --- a/microbench/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBufBenchmark.java +++ b/microbench/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBufBenchmark.java @@ -35,6 +35,7 @@ public class AbstractReferenceCountedByteBufBenchmark extends AbstractMicrobenchmark { @Param({ + "0", "1", "10", "100", @@ -60,10 +61,16 @@ public void tearDown() { @OutputTimeUnit(TimeUnit.NANOSECONDS) public boolean retainReleaseUncontended() { buf.retain(); - Blackhole.consumeCPU(delay); + delay(); return buf.release(); } + private void delay() { + if (delay > 0) { + Blackhole.consumeCPU(delay); + } + } + @Benchmark @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) @@ -71,7 +78,7 @@ public boolean retainReleaseUncontended() { public boolean createUseAndRelease(Blackhole useBuffer) { ByteBuf unpooled = Unpooled.buffer(1); useBuffer.consume(unpooled); - Blackhole.consumeCPU(delay); + delay(); return unpooled.release(); } @@ -81,7 +88,7 @@ public boolean createUseAndRelease(Blackhole useBuffer) { @GroupThreads(4) public boolean retainReleaseContended() { buf.retain(); - Blackhole.consumeCPU(delay); + delay(); return buf.release(); } } diff --git a/microbench/src/main/java/io/netty/handler/codec/http/HttpRequestEncoderInsertBenchmark.java b/microbench/src/main/java/io/netty/handler/codec/http/HttpRequestEncoderInsertBenchmark.java index 7d7df184c60..18ab53b24c4 100644 --- a/microbench/src/main/java/io/netty/handler/codec/http/HttpRequestEncoderInsertBenchmark.java +++ b/microbench/src/main/java/io/netty/handler/codec/http/HttpRequestEncoderInsertBenchmark.java @@ -23,27 +23,98 @@ import io.netty.util.CharsetUtil; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.Warmup; -import static io.netty.handler.codec.http.HttpConstants.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +import static io.netty.handler.codec.http.HttpConstants.CR; +import static io.netty.handler.codec.http.HttpConstants.LF; +import static io.netty.handler.codec.http.HttpConstants.SP; @State(Scope.Benchmark) @Warmup(iterations = 10) @Measurement(iterations = 20) public class HttpRequestEncoderInsertBenchmark extends AbstractMicrobenchmark { - private final String uri = "http://localhost?eventType=CRITICAL&from=0&to=1497437160327&limit=10&offset=0"; + private static final String[] PARAMS = { + "eventType=CRITICAL", + "from=0", + "to=1497437160327", + "limit=10", + "offset=0" + }; + @Param({"1024", "128000"}) + private int samples; + + private String[] uris; + private int index; private final OldHttpRequestEncoder encoderOld = new OldHttpRequestEncoder(); private final HttpRequestEncoder encoderNew = new HttpRequestEncoder(); + @Setup + public void setup() { + List permutations = new ArrayList(); + permute(PARAMS.clone(), 0, permutations); + + String[] allCombinations = new String[permutations.size()]; + String base = "http://localhost?"; + for (int i = 0; i < permutations.size(); i++) { + StringBuilder sb = new StringBuilder(base); + String[] p = permutations.get(i); + for (int j = 0; j < p.length; j++) { + if (j != 0) { + sb.append('&'); + } + sb.append(p[j]); + } + allCombinations[i] = sb.toString(); + } + uris = new String[samples]; + Random rand = new Random(42); + for (int i = 0; i < uris.length; i++) { + uris[i] = allCombinations[rand.nextInt(allCombinations.length)]; + } + index = 0; + } + + private static void permute(String[] arr, int start, List out) { + if (start == arr.length - 1) { + out.add(Arrays.copyOf(arr, arr.length)); + return; + } + for (int i = start; i < arr.length; i++) { + swap(arr, start, i); + permute(arr, start + 1, out); + swap(arr, start, i); + } + } + + private static void swap(String[] a, int i, int j) { + String t = a[i]; + a[i] = a[j]; + a[j] = t; + } + + private String nextUri() { + if (index >= uris.length) { + index = 0; + } + return uris[index++]; + } + @Benchmark public ByteBuf oldEncoder() throws Exception { ByteBuf buffer = Unpooled.buffer(100); try { encoderOld.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, - HttpMethod.GET, uri)); + HttpMethod.GET, nextUri())); return buffer; } finally { buffer.release(); @@ -55,7 +126,7 @@ public ByteBuf newEncoder() throws Exception { ByteBuf buffer = Unpooled.buffer(100); try { encoderNew.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, - HttpMethod.GET, uri)); + HttpMethod.GET, nextUri())); return buffer; } finally { buffer.release(); diff --git a/microbench/src/main/java/io/netty/handler/codec/http2/Http2RequestTargetConversionBenchmark.java b/microbench/src/main/java/io/netty/handler/codec/http2/Http2RequestTargetConversionBenchmark.java new file mode 100644 index 00000000000..f07a8fba280 --- /dev/null +++ b/microbench/src/main/java/io/netty/handler/codec/http2/Http2RequestTargetConversionBenchmark.java @@ -0,0 +1,151 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.microbench.util.AbstractMicrobenchmark; +import io.netty.util.AsciiString; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.net.URI; +import java.util.concurrent.TimeUnit; + +import static io.netty.util.internal.StringUtil.isNullOrEmpty; + +@State(Scope.Benchmark) +@Warmup(iterations = 5, time = 200, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 5, time = 200, timeUnit = TimeUnit.MILLISECONDS) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class Http2RequestTargetConversionBenchmark extends AbstractMicrobenchmark { + + @Param + public RequestTargetType requestTargetType; + + private HttpRequest request; + + @Setup + public void setup() { + request = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, + HttpMethod.GET, + requestTargetType.requestTarget, + new DefaultHttpHeaders(), + false); + request.headers().set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), HttpScheme.HTTP.name()); + } + + @Benchmark + public void newConversion(Blackhole bh) { + bh.consume(HttpConversionUtil.toHttp2Headers(request, false)); + } + + @Benchmark + public void oldUriConversion(Blackhole bh) { + bh.consume(oldToHttp2Headers(request)); + } + + public enum RequestTargetType { + ORIGIN("/orders/123/items?expand=details"), + ABSOLUTE("http://example.com/orders/123/items?expand=details#section"), + ABSOLUTE_NO_PATH("http://example.com?next=/home#section"), + ABSOLUTE_NO_AUTHORITY("http://?x=1#frag"), + SCHEME_ONLY_ABSOLUTE_PATH("http:/orders/123/items?expand=details"); + + final String requestTarget; + + RequestTargetType(String requestTarget) { + this.requestTarget = requestTarget; + } + } + + private static Http2Headers oldToHttp2Headers(final HttpRequest request) { + HttpHeaders inHeaders = request.headers(); + Http2Headers out = new DefaultHttp2Headers(false, inHeaders.size()); + String host = inHeaders.getAsString(HttpHeaderNames.HOST); + if (HttpUtil.isOriginForm(request.uri()) || HttpUtil.isAsteriskForm(request.uri())) { + out.path(new AsciiString(request.uri())); + oldSetHttp2Scheme(inHeaders, URI.create(""), out); + } else { + URI requestTargetUri = URI.create(request.uri()); + out.path(oldToHttp2Path(requestTargetUri)); + host = isNullOrEmpty(host) ? requestTargetUri.getAuthority() : host; + oldSetHttp2Scheme(inHeaders, requestTargetUri, out); + } + HttpConversionUtil.setHttp2Authority(host, out); + out.method(request.method().asciiName()); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + return out; + } + + private static AsciiString oldToHttp2Path(final URI uri) { + StringBuilder pathBuilder = new StringBuilder(); + if (!isNullOrEmpty(uri.getRawPath())) { + pathBuilder.append(uri.getRawPath()); + } + if (!isNullOrEmpty(uri.getRawQuery())) { + pathBuilder.append('?'); + pathBuilder.append(uri.getRawQuery()); + } + if (!isNullOrEmpty(uri.getRawFragment())) { + pathBuilder.append('#'); + pathBuilder.append(uri.getRawFragment()); + } + String path = pathBuilder.toString(); + return path.isEmpty() ? new AsciiString("/") : new AsciiString(path); + } + + private static void oldSetHttp2Scheme(final HttpHeaders in, final URI uri, final Http2Headers out) { + String value = uri.getScheme(); + if (!isNullOrEmpty(value)) { + out.scheme(new AsciiString(value)); + return; + } + + CharSequence cValue = in.get(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text()); + if (cValue != null) { + out.scheme(AsciiString.of(cValue)); + return; + } + + if (uri.getPort() == HttpScheme.HTTPS.port()) { + out.scheme(HttpScheme.HTTPS.name()); + } else if (uri.getPort() == HttpScheme.HTTP.port()) { + out.scheme(HttpScheme.HTTP.name()); + } else { + throw new IllegalArgumentException( + ":scheme must be specified. see https://tools.ietf.org/html/rfc7540#section-8.1.2.3"); + } + } +} diff --git a/microbench/src/main/java/io/netty/microbench/http/HttpChunkedRequestResponseBenchmark.java b/microbench/src/main/java/io/netty/microbench/http/HttpChunkedRequestResponseBenchmark.java new file mode 100644 index 00000000000..365decd1d7f --- /dev/null +++ b/microbench/src/main/java/io/netty/microbench/http/HttpChunkedRequestResponseBenchmark.java @@ -0,0 +1,114 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.microbench.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.microbench.util.AbstractMicrobenchmark; +import io.netty.util.ReferenceCountUtil; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import static io.netty.handler.codec.http.HttpConstants.CR; +import static io.netty.handler.codec.http.HttpConstants.LF; + +@State(Scope.Thread) +@Warmup(iterations = 10, time = 1) +@Measurement(iterations = 10, time = 1) +public class HttpChunkedRequestResponseBenchmark extends AbstractMicrobenchmark { + private static final int CRLF_SHORT = (CR << 8) + LF; + + ByteBuf POST; + int readerIndex; + int writeIndex; + EmbeddedChannel nettyChannel; + + @Setup + public void setup() { + HttpRequestDecoder httpRequestDecoder = new HttpRequestDecoder( + HttpRequestDecoder.DEFAULT_MAX_INITIAL_LINE_LENGTH, HttpRequestDecoder.DEFAULT_MAX_HEADER_SIZE, + HttpRequestDecoder.DEFAULT_MAX_CHUNK_SIZE, false); + ChannelInboundHandlerAdapter inboundHandlerAdapter = new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object o) { + // this is saving a slow type check on LastHttpContent vs HttpRequest + try { + if (o == LastHttpContent.EMPTY_LAST_CONTENT) { + writeResponse(ctx); + } + } finally { + ReferenceCountUtil.release(o); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.flush(); + } + + private void writeResponse(ChannelHandlerContext ctx) { + ByteBuf buffer = ctx.alloc().buffer(); + // Build the response object. + ByteBufUtil.writeAscii(buffer, "HTTP/1.1 200 OK\r\n"); + ByteBufUtil.writeAscii(buffer, "Content-Length: 0\r\n\r\n"); + ctx.write(buffer, ctx.voidPromise()); + } + }; + nettyChannel = new EmbeddedChannel(httpRequestDecoder, inboundHandlerAdapter); + + ByteBuf buffer = Unpooled.buffer(); + ByteBufUtil.writeAscii(buffer, "POST / HTTP/1.1\r\n"); + ByteBufUtil.writeAscii(buffer, "Content-Type: text/plain\r\n"); + ByteBufUtil.writeAscii(buffer, "Transfer-Encoding: chunked\r\n\r\n"); + ByteBufUtil.writeAscii(buffer, Integer.toHexString(43) + "\r\n"); + buffer.writeZero(43); + buffer.writeShort(CRLF_SHORT); + ByteBufUtil.writeAscii(buffer, Integer.toHexString(18) + + ";extension=kjhkasdhfiushdksjfnskdjfbskdjfbskjdfb\r\n"); + buffer.writeZero(18); + buffer.writeShort(CRLF_SHORT); + ByteBufUtil.writeAscii(buffer, Integer.toHexString(29) + + ";a=12938746238;b=\"lkjkjhskdfhsdkjh\\\"kjshdflkjhdskjhifuwehwi\";c=lkjdshfkjshdiufh\r\n"); + buffer.writeZero(29); + buffer.writeShort(CRLF_SHORT); + ByteBufUtil.writeAscii(buffer, Integer.toHexString(9) + + ";A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A;A\r\n"); + buffer.writeZero(9); + buffer.writeShort(CRLF_SHORT); + ByteBufUtil.writeAscii(buffer, "0\r\n\r\n"); // Last empty chunk + POST = Unpooled.unreleasableBuffer(buffer); + readerIndex = POST.readerIndex(); + writeIndex = POST.writerIndex(); + } + + @Benchmark + public Object netty() { + POST.setIndex(readerIndex, writeIndex); + ByteBuf byteBuf = POST.retainedDuplicate(); + nettyChannel.writeInbound(byteBuf); + return nettyChannel.outboundMessages().poll(); + } +} diff --git a/microbench/src/main/java/io/netty/microbench/http/HttpRequestResponseBenchmark.java b/microbench/src/main/java/io/netty/microbench/http/HttpRequestResponseBenchmark.java index 0716a8fe2f7..54dba8e3ddd 100644 --- a/microbench/src/main/java/io/netty/microbench/http/HttpRequestResponseBenchmark.java +++ b/microbench/src/main/java/io/netty/microbench/http/HttpRequestResponseBenchmark.java @@ -68,7 +68,7 @@ public class HttpRequestResponseBenchmark extends AbstractMicrobenchmark { static class Alloc implements ByteBufAllocator { - private final ByteBuf buf = Unpooled.buffer(); + private final ByteBuf buf = Unpooled.buffer(512); private final int capacity = buf.capacity(); @Override @@ -82,7 +82,8 @@ public ByteBuf buffer(int initialCapacity) { if (initialCapacity <= capacity) { return buffer(); } else { - throw new IllegalArgumentException(); + throw new IllegalArgumentException( + "initialCapacity " + initialCapacity + " is greater than capacity " + capacity); } } @@ -91,7 +92,8 @@ public ByteBuf buffer(int initialCapacity, int maxCapacity) { if (initialCapacity <= capacity) { return buffer(); } else { - throw new IllegalArgumentException(); + throw new IllegalArgumentException( + "initialCapacity " + initialCapacity + " is greater than capacity " + capacity); } } diff --git a/microbench/src/main/java/io/netty/microbench/http/HttpUtilBenchmark.java b/microbench/src/main/java/io/netty/microbench/http/HttpUtilBenchmark.java new file mode 100644 index 00000000000..80b2f1ec736 --- /dev/null +++ b/microbench/src/main/java/io/netty/microbench/http/HttpUtilBenchmark.java @@ -0,0 +1,41 @@ +/* + * Copyright 2025 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.microbench.http; + +import io.netty.handler.codec.http.HttpUtil; +import io.netty.microbench.util.AbstractMicrobenchmark; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.concurrent.TimeUnit; + +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.AverageTime) +@Warmup(iterations = 10, time = 1) +@Measurement(iterations = 10, time = 1) +public class HttpUtilBenchmark extends AbstractMicrobenchmark { + private static final String uri = "https://github.com/netty/netty/blob/893508ce62a7f90464f8e4bf2ac28ecc73ce6608/" + + "handler/src/main/java/io/netty/handler/ssl/util/BouncyCastleSelfSignedCertGenerator.java"; + + @Benchmark + public boolean checkIsEncodingSafeUri() { + return HttpUtil.isEncodingSafeStartLineToken(uri); + } +} diff --git a/pom.xml b/pom.xml index d32129b8bd5..4dfabece678 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ io.netty netty-parent pom - 4.1.128.1.dse + 4.1.135.1.dse Netty https://netty.io/ @@ -53,7 +53,7 @@ https://github.com/netty/netty scm:git:git://github.com/netty/netty.git scm:git:ssh://git@github.com/netty/netty.git - netty-4.1.128.Final + netty-4.1.135.Final @@ -680,7 +680,7 @@ boringssl-snapshot netty-tcnative-boringssl-static - 2.0.75.Final-SNAPSHOT + 2.0.78.Final-SNAPSHOT ${os.detected.classifier} @@ -707,6 +707,9 @@ noPrintGC + + true + -D_ @@ -828,7 +831,7 @@ fedora,suse,arch netty-tcnative - 2.0.74.Final + 2.0.77.Final ${os.detected.classifier} org.conscrypt conscrypt-openjdk-uber @@ -844,9 +847,10 @@ ${os.detected.name}-${os.detected.arch} ${project.basedir}/../common/src/test/resources/logback-test.xml warn - 2.17.2 + 2.25.3 3.0.0 5.12.1 + 0.30.0 false ${java.home} ${testJavaHome}/bin/java @@ -854,7 +858,7 @@ false false 19.3.6 - 1.16.0 + 1.23.0 true false @@ -1014,7 +1018,7 @@ org.bouncycastle bcpkix-jdk15on - 1.69 + 1.70 compile true @@ -1026,7 +1030,7 @@ org.bouncycastle bcprov-jdk15on - 1.69 + 1.70 compile true @@ -1037,7 +1041,7 @@ org.bouncycastle bctls-jdk15on - 1.69 + 1.70 compile true @@ -1056,12 +1060,12 @@ com.ning compress-lzf - 1.0.3 + 1.2.0 - org.lz4 + at.yawk.lz4 lz4-java - 1.8.0 + 1.10.1 com.github.jponge @@ -1225,6 +1229,12 @@ ${junit.version} test + + com.code-intelligence + jazzer-junit + ${jazzer.version} + test + ${project.groupId} netty-build-common @@ -1234,13 +1244,13 @@ org.assertj assertj-core - 3.18.0 + 3.27.7 test org.mockito mockito-core - 2.18.3 + 4.11.0 test @@ -1288,7 +1298,7 @@ org.apache.commons commons-compress - 1.26.0 + 1.28.0 test @@ -1296,7 +1306,7 @@ commons-io commons-io - 2.14.0 + 2.20.0 test @@ -1335,7 +1345,7 @@ io.projectreactor.tools blockhound - 1.0.14.RELEASE + 1.0.16.RELEASE @@ -1589,6 +1599,22 @@ 7 SETTINGS_ENABLE_CONNECT_PROTOCOL was added to the standard HTTP/2 settings. + + true + java.annotation.removed + method void io.netty.channel.ChannelInboundHandlerAdapter::channelInactive(io.netty.channel.ChannelHandlerContext) throws java.lang.Exception @ io.netty.handler.codec.redis.RedisArrayAggregator + method void io.netty.handler.codec.redis.RedisArrayAggregator::channelInactive(io.netty.channel.ChannelHandlerContext) throws java.lang.Exception + @io.netty.channel.ChannelHandlerMask.Skip + Change is harmless for compatibility. Needed for a security fix. + + + true + java.annotation.removed + method void io.netty.channel.ChannelInboundHandlerAdapter::exceptionCaught(io.netty.channel.ChannelHandlerContext, java.lang.Throwable) throws java.lang.Exception @ io.netty.handler.codec.sctp.SctpMessageCompletionHandler + method void io.netty.handler.codec.sctp.SctpMessageCompletionHandler::exceptionCaught(io.netty.channel.ChannelHandlerContext, java.lang.Throwable) throws java.lang.Exception + @io.netty.channel.ChannelHandlerMask.Skip + Change is harmless for compatibility. Needed for a security fix. + diff --git a/resolver-dns-classes-macos/pom.xml b/resolver-dns-classes-macos/pom.xml index 230036b95b1..e5dde986095 100644 --- a/resolver-dns-classes-macos/pom.xml +++ b/resolver-dns-classes-macos/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-resolver-dns-classes-macos diff --git a/resolver-dns-native-macos/pom.xml b/resolver-dns-native-macos/pom.xml index dfa467c1f1e..b628d9de465 100644 --- a/resolver-dns-native-macos/pom.xml +++ b/resolver-dns-native-macos/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-resolver-dns-native-macos diff --git a/resolver-dns-native-macos/src/main/c/netty_resolver_dns_macos.c b/resolver-dns-native-macos/src/main/c/netty_resolver_dns_macos.c index 8e72ca9e898..4a72d87ac88 100644 --- a/resolver-dns-native-macos/src/main/c/netty_resolver_dns_macos.c +++ b/resolver-dns-native-macos/src/main/c/netty_resolver_dns_macos.c @@ -144,13 +144,23 @@ static jobjectArray netty_resolver_dns_macos_resolvers(JNIEnv* env, jclass clazz static JNINativeMethod* createDynamicMethodsTable(const char* packagePrefix) { JNINativeMethod* dynamicMethods = malloc(sizeof(JNINativeMethod) * 1); - + if (dynamicMethods == NULL) { + return NULL; + } char* dynamicTypeName = netty_jni_util_prepend(packagePrefix, "io/netty/resolver/dns/macos/DnsResolver;"); + if (dynamicTypeName == NULL) { + free(dynamicMethods); + return NULL; + } JNINativeMethod* dynamicMethod = &dynamicMethods[0]; dynamicMethod->name = "resolvers"; dynamicMethod->signature = netty_jni_util_prepend("()[L", dynamicTypeName); - dynamicMethod->fnPtr = (void *) netty_resolver_dns_macos_resolvers; free(dynamicTypeName); + if (dynamicMethod->signature == NULL) { + free(dynamicMethods); + return NULL; + } + dynamicMethod->fnPtr = (void *) netty_resolver_dns_macos_resolvers; return dynamicMethods; } diff --git a/resolver-dns/pom.xml b/resolver-dns/pom.xml index addbb1dae09..088dec73c5f 100644 --- a/resolver-dns/pom.xml +++ b/resolver-dns/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-resolver-dns diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index a5cd04de147..5df293112cd 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -83,6 +83,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -102,8 +103,8 @@ public class DnsNameResolver extends InetNameResolver { private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsNameResolver.class); private static final String LOCALHOST = "localhost"; + private static final String DOT_LOCALHOST = '.' + LOCALHOST; private static final String WINDOWS_HOST_NAME; - private static final InetAddress LOCALHOST_ADDRESS; private static final DnsRecord[] EMPTY_ADDITIONALS = new DnsRecord[0]; private static final DnsRecordType[] IPV4_ONLY_RESOLVED_RECORD_TYPES = {DnsRecordType.A}; @@ -136,18 +137,14 @@ public boolean isSharable() { static { if (NetUtil.isIpV4StackPreferred() || !anyInterfaceSupportsIpV6()) { DEFAULT_RESOLVE_ADDRESS_TYPES = ResolvedAddressTypes.IPV4_ONLY; - LOCALHOST_ADDRESS = NetUtil.LOCALHOST4; } else { if (NetUtil.isIpV6AddressesPreferred()) { DEFAULT_RESOLVE_ADDRESS_TYPES = ResolvedAddressTypes.IPV6_PREFERRED; - LOCALHOST_ADDRESS = NetUtil.LOCALHOST6; } else { DEFAULT_RESOLVE_ADDRESS_TYPES = ResolvedAddressTypes.IPV4_PREFERRED; - LOCALHOST_ADDRESS = NetUtil.LOCALHOST4; } } logger.debug("Default ResolvedAddressTypes: {}", DEFAULT_RESOLVE_ADDRESS_TYPES); - logger.debug("Localhost address: {}", LOCALHOST_ADDRESS); String hostName; try { @@ -701,7 +698,7 @@ private InetAddress resolveHostsFileEntry(String hostname) { return null; } InetAddress address = hostsFileEntriesResolver.address(hostname, resolvedAddressTypes); - return address == null && isLocalWindowsHost(hostname) ? LOCALHOST_ADDRESS : address; + return address == null && isLocalHostAddress(hostname)? getLocalHostAddress() : address; } private List resolveHostsFileEntries(String hostname) { @@ -714,23 +711,54 @@ private List resolveHostsFileEntries(String hostname) { .addresses(hostname, resolvedAddressTypes); } else { InetAddress address = hostsFileEntriesResolver.address(hostname, resolvedAddressTypes); - addresses = address != null ? Collections.singletonList(address) : null; + addresses = address != null? Collections.singletonList(address) : null; } - return addresses == null && isLocalWindowsHost(hostname) ? - Collections.singletonList(LOCALHOST_ADDRESS) : addresses; + return addresses == null && isLocalHostAddress(hostname)? + Collections.singletonList(getLocalHostAddress()) : addresses; } /** - * Checks whether the given hostname is the localhost/host (computer) name on Windows OS. - * Windows OS removed the localhost/host (computer) name information from the hosts file in the later versions - * and such hostname cannot be resolved from hosts file. - * See https://github.com/netty/netty/issues/5386 - * See https://github.com/netty/netty/issues/11142 + * Checks whether the given hostname refers to the current computer. This is the case for: + *
    + *
  • localhost.
  • + *
  • any domain within .localhost.
  • + *
  • the hostname of the local computer on Windows
  • + *
+ *

+ * According to RFC 6761 Section 6.3, localhost and subdomains of localhost should be resolved to the loopback + * address by name resolution libraries without querying DNS servers. The hostname of the local machine can usually + * be resolved from the hosts file, but on Windows, this is no longer possible. + * + * @param hostname the hostname that's being looked up + * @return true if the hostname should point to the loopback adress. False otherwise. + * @see Issue 5386 + * @see Issue 11142 + * @see Issue 16744 + * @see RFC 6761 */ - private static boolean isLocalWindowsHost(String hostname) { - return PlatformDependent.isWindows() && - (LOCALHOST.equalsIgnoreCase(hostname) || - (WINDOWS_HOST_NAME != null && WINDOWS_HOST_NAME.equalsIgnoreCase(hostname))); + private static boolean isLocalHostAddress(String hostname) { + if (PlatformDependent.isWindows() && WINDOWS_HOST_NAME != null && + WINDOWS_HOST_NAME.equalsIgnoreCase(hostname)) { + return true; + } + + if (hostname.endsWith(".")) { + hostname = hostname.substring(0, hostname.length() - 1); + } + return hostname.equalsIgnoreCase(LOCALHOST) || hostname.toLowerCase(Locale.US).endsWith(DOT_LOCALHOST); + } + + private InetAddress getLocalHostAddress() { + switch (resolvedAddressTypes) { + case IPV4_ONLY: + case IPV4_PREFERRED: + return NetUtil.LOCALHOST4; + case IPV6_ONLY: + case IPV6_PREFERRED: + return NetUtil.LOCALHOST6; + default: + throw new IllegalStateException("Unknown ResolvedAddressTypes " + resolvedAddressTypes); + } } /** diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverChannelStrategy.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverChannelStrategy.java index cf4a6fd366f..2c1964d7ba5 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverChannelStrategy.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverChannelStrategy.java @@ -22,6 +22,11 @@ public enum DnsNameResolverChannelStrategy { /** * Use the same underlying {@link io.netty.channel.Channel} for all queries produced by a single {@link DnsNameResolver} instance. + *

+ * As the same {@link io.netty.channel.Channel} is used for all queries we will also use the same source port + * for all of these. To minimize the risk of spoofing integrators should ideally use multiple resolvers randomly, + * so that there is source port randomization following the recommendations of + * RFC5452 Section 9.2. */ ChannelPerResolver, /** diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryIdSpace.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryIdSpace.java index 5cad6970ac4..64446f1fcb0 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryIdSpace.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryIdSpace.java @@ -16,9 +16,8 @@ package io.netty.resolver.dns; import io.netty.util.internal.MathUtil; -import io.netty.util.internal.PlatformDependent; -import java.util.Random; +import java.security.SecureRandom; /** * Special data-structure that will allow to retrieve the next query id to use, while still guarantee some sort @@ -34,15 +33,16 @@ final class DnsQueryIdSpace { // If there are other buckets left that have at least 500 usable ids we will drop an unused bucket. private static final int BUCKET_DROP_THRESHOLD = 500; private final DnsQueryIdRange[] idBuckets = new DnsQueryIdRange[BUCKETS]; + private final SecureRandom random = new SecureRandom(); DnsQueryIdSpace() { assert idBuckets.length == MathUtil.findNextPositivePowerOfTwo(idBuckets.length); // We start with 1 bucket. - idBuckets[0] = newBucket(0); + idBuckets[0] = newBucket(0, random); } - private static DnsQueryIdRange newBucket(int idBucketsIdx) { - return new DnsQueryIdRange(BUCKET_SIZE, idBucketsIdx * BUCKET_SIZE); + private static DnsQueryIdRange newBucket(int idBucketsIdx, SecureRandom random) { + return new DnsQueryIdRange(BUCKET_SIZE, idBucketsIdx * BUCKET_SIZE, random); } /** @@ -61,7 +61,7 @@ int nextId() { } } else if (freeIdx == -1 || // Let's make it somehow random which free slot is used. - PlatformDependent.threadLocalRandom().nextBoolean()) { + random.nextBoolean()) { // We have a slot that we can use to create a new bucket if we need to. freeIdx = bucketIdx; } @@ -72,7 +72,7 @@ int nextId() { } // We still have some slots free to store a new bucket. Let's do this now and use it to generate the next id. - DnsQueryIdRange bucket = newBucket(freeIdx); + DnsQueryIdRange bucket = newBucket(freeIdx, random); idBuckets[freeIdx] = bucket; int id = bucket.nextId(); assert id >= 0; @@ -141,11 +141,13 @@ private static final class DnsQueryIdRange { // Holds all possible ids which are stored as unsigned shorts private final short[] ids; private final int startId; + private final SecureRandom random; private int count; - DnsQueryIdRange(int bucketSize, int startId) { + DnsQueryIdRange(int bucketSize, int startId, SecureRandom random) { this.ids = new short[bucketSize]; this.startId = startId; + this.random = random; for (int v = startId; v < bucketSize + startId; v++) { pushId(v); } @@ -178,7 +180,6 @@ void pushId(int id) { } assert id <= startId + ids.length && id >= startId; // pick a slot for our index, and whatever was in that slot before will get moved to the tail. - Random random = PlatformDependent.threadLocalRandom(); int insertionPosition = random.nextInt(count + 1); short moveId = ids[insertionPosition]; short insertId = (short) id; diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java index f3b193f2e1a..844d28b23b0 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java @@ -641,8 +641,9 @@ private void onResponse(final DnsServerAddressStream nameServerAddrStream, final final DnsRecordType type = question.type(); if (type == DnsRecordType.CNAME) { - onResponseCNAME(question, buildAliasMap(envelope.content(), cnameCache(), parent.executor()), - queryLifecycleObserver, promise); + onResponseCNAME(question, + buildAliasMap(question.name(), envelope.content(), cnameCache(), parent.executor()), + queryLifecycleObserver, promise); return; } @@ -832,7 +833,7 @@ private void onExpectedResponse( // We often get a bunch of CNAMES as well when we asked for A/AAAA. final DnsResponse response = envelope.content(); - final Map cnames = buildAliasMap(response, cnameCache(), parent.executor()); + final Map cnames = buildAliasMap(question.name(), response, cnameCache(), parent.executor()); final int answerCount = response.count(DnsSection.ANSWER); boolean found = false; @@ -991,7 +992,8 @@ private void onResponseCNAME( } } - private static Map buildAliasMap(DnsResponse response, DnsCnameCache cache, EventLoop loop) { + private static Map buildAliasMap( + String queryName, DnsResponse response, DnsCnameCache cache, EventLoop loop) { final int answerCount = response.count(DnsSection.ANSWER); Map cnames = null; for (int i = 0; i < answerCount; i ++) { @@ -1022,7 +1024,13 @@ private static Map buildAliasMap(DnsResponse response, DnsCnameC String nameWithDot = hostnameWithDot(name); String mappingWithDot = hostnameWithDot(mapping); if (!nameWithDot.equalsIgnoreCase(mappingWithDot)) { - cache.cache(nameWithDot, mappingWithDot, r.timeToLive(), loop); + String queryNameWithDot = hostnameWithDot(queryName.toLowerCase(Locale.US)); + // Only cache the CNAME if the owner is in the bailiwick of the original query name. + boolean inBailiwick = nameWithDot.equals(queryNameWithDot) || + nameWithDot.endsWith("." + queryNameWithDot); + if (inBailiwick) { + cache.cache(nameWithDot, mappingWithDot, r.timeToLive(), loop); + } cnames.put(name, mapping); } } @@ -1409,7 +1417,7 @@ void handleWithoutAdditionals( } } - private static void cacheUnresolved( + private void cacheUnresolved( AuthoritativeNameServer server, AuthoritativeDnsServerCache authoritativeCache, EventLoop loop) { // We still want to cached the unresolved address server.address = InetSocketAddress.createUnresolved( @@ -1419,11 +1427,20 @@ private static void cacheUnresolved( cache(server, authoritativeCache, loop); } - private static void cache(AuthoritativeNameServer server, AuthoritativeDnsServerCache cache, EventLoop loop) { + private void cache(AuthoritativeNameServer server, AuthoritativeDnsServerCache cache, EventLoop loop) { // Cache NS record if not for a root server as we should never cache for root servers. - if (!server.isRootServer()) { - cache.cache(server.domainName, server.address, server.ttl, loop); + if (server.isRootServer()) { + return; + } + // Bailiwick check (RFC 2181 §5.4.1): only cache a nameserver entry when its zone + // equals the question name or is a subdomain of it. A server that is authoritative + // for a child zone must not be trusted to supply authoritative NS records for a + // parent zone, which would allow cache poisoning of the parent. + if (!server.domainName.equals(questionName) && + !server.domainName.endsWith("." + questionName)) { + return; } + cache.cache(server.domainName, server.address, server.ttl, loop); } /** diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index cc7ed6a7180..5852f9f8e6f 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -68,6 +68,7 @@ import org.apache.directory.server.dns.store.RecordStore; import org.apache.mina.core.buffer.IoBuffer; import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -111,7 +112,9 @@ import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; import static io.netty.handler.codec.dns.DnsRecordType.A; import static io.netty.handler.codec.dns.DnsRecordType.AAAA; @@ -860,6 +863,46 @@ public void testResolveHostNameIpv6(DnsNameResolverChannelStrategy strategy) { testResolve0(strategy, ResolvedAddressTypes.IPV6_ONLY, NetUtil.LOCALHOST6, WINDOWS_HOST_NAME); } + private static List testResolveLocalhostWithoutDNSArgs() { + DnsNameResolverChannelStrategy[] strategies = DnsNameResolverChannelStrategy.values(); + List names = asList("localhost", "localhost.", "test.localhost", "TEsT.LOCalhost", "test.localhost."); + + List output = new ArrayList(); + for (DnsNameResolverChannelStrategy strategy : strategies) { + for (String name : names) { + output.add(new Object[] { strategy, ResolvedAddressTypes.IPV4_ONLY, NetUtil.LOCALHOST4, name }); + output.add(new Object[] { strategy, ResolvedAddressTypes.IPV6_ONLY, NetUtil.LOCALHOST6, name }); + } + } + + return output; + } + + @ParameterizedTest + @MethodSource("testResolveLocalhostWithoutDNSArgs") + public void testResolveLocalhostWithoutDNSOrHostsFile(DnsNameResolverChannelStrategy strategy, + ResolvedAddressTypes addressTypes, InetAddress expectedAddr, + String name) { + DnsNameResolver resolver = newResolver(strategy, addressTypes) + .hostsFileEntriesResolver(new HostsFileEntriesResolver() { + @Override + public InetAddress address(String inetHost, ResolvedAddressTypes resolvedAddressTypes) { + // The hosts file should not be required to resolve localhost addresses. + return null; + } + }) + .build(); + try { + InetAddress address = resolver.resolve(name).syncUninterruptibly().getNow(); + assertEquals(expectedAddr, address); + + // We are resolving the local address, so we shouldn't make any queries. + assertNoQueriesMade(resolver); + } finally { + resolver.close(); + } + } + @ParameterizedTest @EnumSource(DnsNameResolverChannelStrategy.class) public void testResolveNullIpv4(DnsNameResolverChannelStrategy strategy) { @@ -1520,10 +1563,8 @@ InetSocketAddress newRedirectServerAddress(InetAddress server) { assertNull(nsCache.cache.get("netty.io.")); DnsServerAddressStream entries = nsCache.cache.get("record.netty.io."); - // First address should be resolved (as we received a matching additional record), second is unresolved. - assertEquals(2, entries.size()); - assertFalse(entries.next().isUnresolved()); - assertTrue(entries.next().isUnresolved()); + // Should be null because of bailiwick check. + assertNull(entries); assertNull(nsCache.cache.get(hostname)); @@ -1533,27 +1574,15 @@ InetSocketAddress newRedirectServerAddress(InetAddress server) { observer = lifecycleObserverFactory.observers.poll(); assertNotNull(observer); assertTrue(lifecycleObserverFactory.observers.isEmpty()); - assertEquals(2, observer.events.size()); + assertEquals(4, observer.events.size()); writtenEvent1 = (QueryWrittenEvent) observer.events.poll(); - assertEquals(expectedDnsName, writtenEvent1.dnsServerAddress.getHostName()); - assertEquals(dnsServerAuthority.localAddress(), writtenEvent1.dnsServerAddress); - succeededEvent = (QuerySucceededEvent) observer.events.poll(); + QueryRedirectedEvent ev = (QueryRedirectedEvent) observer.events.poll(); - resolver.resolveAll(hostname2).syncUninterruptibly(); + assertInstanceOf(UnknownHostException.class, resolver.resolveAll(hostname2).await().cause()); - observer = lifecycleObserverFactory.observers.poll(); - assertNotNull(observer); - assertTrue(lifecycleObserverFactory.observers.isEmpty()); - assertEquals(2, observer.events.size()); - writtenEvent1 = (QueryWrittenEvent) observer.events.poll(); - assertEquals(expectedDnsName, writtenEvent1.dnsServerAddress.getHostName()); - assertEquals(dnsServerAuthority.localAddress(), writtenEvent1.dnsServerAddress); - succeededEvent = (QuerySucceededEvent) observer.events.poll(); - - // Check that it only queried the cache for record.netty.io. assertNull(nsCache.cacheHits.get("io.")); assertNull(nsCache.cacheHits.get("netty.io.")); - assertNotNull(nsCache.cacheHits.get("record.netty.io.")); + assertNull(nsCache.cacheHits.get("record.netty.io.")); assertNull(nsCache.cacheHits.get("some.record.netty.io.")); } } finally { @@ -1688,19 +1717,8 @@ InetSocketAddress newRedirectServerAddress(InetAddress server) { if (authoritativeDnsServerCache != NoopAuthoritativeDnsServerCache.INSTANCE) { DnsServerAddressStream cached = authoritativeDnsServerCache.get(domain + '.'); - assertEquals(2, cached.size()); - InetSocketAddress ns1Address = InetSocketAddress.createUnresolved( - ns1Name + '.', DefaultDnsServerAddressStreamProvider.DNS_PORT); - InetSocketAddress ns2Address = InetSocketAddress.createUnresolved( - ns2Name + '.', DefaultDnsServerAddressStreamProvider.DNS_PORT); - - if (invalidNsFirst) { - assertEquals(ns2Address, cached.next()); - assertEquals(ns1Address, cached.next()); - } else { - assertEquals(ns1Address, cached.next()); - assertEquals(ns2Address, cached.next()); - } + // We should not cache anything because of bailiwick check + assertNull(cached); } if (cache != NoopDnsCache.INSTANCE) { List ns1Cached = cache.get(ns1Name + '.', null); @@ -1844,7 +1862,9 @@ protected DnsServerAddressStream newRedirectDnsServerStream( DnsServerAddressStream redirected = redirectedRef.get(); assertNotNull(redirected); assertEquals(4, redirected.size()); - assertEquals(4, cached.size()); + + // We should not cache anything because of bailiwick check + assertEquals(0, cached.size()); if (reversed) { assertEquals(ns4Address, redirected.next()); @@ -1857,12 +1877,6 @@ protected DnsServerAddressStream newRedirectDnsServerStream( assertEquals(ns3Address, redirected.next()); assertEquals(ns4Address, redirected.next()); } - - // We should always have the same order in the cache. - assertEquals(ns1Address, cached.get(0)); - assertEquals(ns2Address, cached.get(1)); - assertEquals(ns3Address, cached.get(2)); - assertEquals(ns4Address, cached.get(3)); } finally { resolver.close(); group.shutdownGracefully(0, 0, TimeUnit.SECONDS); @@ -1981,7 +1995,8 @@ protected DnsServerAddressStream newRedirectDnsServerStream( DnsServerAddressStream redirected = redirectedRef.get(); assertNotNull(redirected); assertEquals(6, redirected.size()); - assertEquals(3, cached.size()); + // We should not cache because of bailiwick check. + assertEquals(0, cached.size()); // The redirected addresses should have been retrieven from the DnsCache if not resolved, so these are // fully resolved. @@ -1991,13 +2006,6 @@ protected DnsServerAddressStream newRedirectDnsServerStream( assertEquals(ns3Address, redirected.next()); assertEquals(ns4Address, redirected.next()); assertEquals(ns5Address, redirected.next()); - - // As this address was supplied as ADDITIONAL we should put it resolved into the cache. - assertEquals(ns0Address, cached.get(0)); - assertEquals(ns5Address, cached.get(1)); - - // We should have put the unresolved address in the AuthoritativeDnsServerCache (but only 1 time) - assertEquals(unresolved(ns1Address), cached.get(2)); } finally { resolver.close(); group.shutdownGracefully(0, 0, TimeUnit.SECONDS); @@ -2354,14 +2362,18 @@ public DnsCacheEntry cache(String hostname, DnsRecord[] additionals, Throwable c private static class RedirectingTestDnsServer extends TestDnsServer { private final String dnsAddress; - private final String domain; + private final Set domains; - RedirectingTestDnsServer(String domain, String dnsAddress) { - super(Collections.singleton(domain)); - this.domain = domain; + RedirectingTestDnsServer(Set domains, String dnsAddress) { + super(domains); + this.domains = domains; this.dnsAddress = dnsAddress; } + RedirectingTestDnsServer(String domain, String dnsAddress) { + this(Collections.singleton(domain), dnsAddress); + } + @Override protected DnsMessage filterMessage(DnsMessage message) { // Clear the answers as we want to add our own stuff to test dns redirects. @@ -2369,21 +2381,22 @@ protected DnsMessage filterMessage(DnsMessage message) { message.getAuthorityRecords().clear(); message.getAdditionalRecords().clear(); - String name = domain; - for (int i = 0 ;; i++) { - int idx = name.indexOf('.'); - if (idx <= 0) { - break; - } - name = name.substring(idx + 1); // skip the '.' as well. - String dnsName = "dns" + idx + '.' + domain; - message.getAuthorityRecords().add(newNsRecord(name, dnsName)); - message.getAdditionalRecords().add(newARecord(dnsName, i == 0 ? dnsAddress : "1.2.3." + idx)); + for (String domain : domains) { + String name = domain; + for (int i = 0 ;; i++) { + int idx = name.indexOf('.'); + if (idx <= 0) { + break; + } + name = name.substring(idx + 1); // skip the '.' as well. + String dnsName = "dns" + idx + '.' + domain; + message.getAuthorityRecords().add(newNsRecord(name, dnsName)); + message.getAdditionalRecords().add(newARecord(dnsName, i == 0 ? dnsAddress : "1.2.3." + idx)); - // Add an unresolved NS record (with no additionals as well) - message.getAuthorityRecords().add(newNsRecord(name, "unresolved." + dnsName)); + // Add an unresolved NS record (with no additionals as well) + message.getAuthorityRecords().add(newNsRecord(name, "unresolved." + dnsName)); + } } - return message; } } @@ -3000,6 +3013,80 @@ public boolean clear(String hostname) { } } + @ParameterizedTest + @EnumSource(DnsNameResolverChannelStrategy.class) + public void testCnameCacheBailiwick(DnsNameResolverChannelStrategy strategy) throws Exception { + final Map cache = new ConcurrentHashMap(); + + TestDnsServer dnsServer = new TestDnsServer(new RecordStore() { + @Override + public Set getRecords(QuestionRecord question) throws DnsException { + if ("x.netty.io".equals(question.getDomainName())) { + Set records = new HashSet(); + // Valid CNAME (in bailiwick of query) + records.add(new TestDnsServer.TestResourceRecord( + "x.netty.io", RecordType.CNAME, + Collections.singletonMap( + DnsAttribute.DOMAIN_NAME.toLowerCase(), "cname.netty.io"))); + // Invalid CNAME (out of bailiwick of query) + records.add(new TestDnsServer.TestResourceRecord( + "cname.netty.io", RecordType.CNAME, + Collections.singletonMap( + DnsAttribute.DOMAIN_NAME.toLowerCase(), "evil.com"))); + // Provide an A record to satisfy the resolution + records.add(new TestDnsServer.TestResourceRecord( + "evil.com", RecordType.A, + Collections.singletonMap( + DnsAttribute.IP_ADDRESS.toLowerCase(), "10.0.0.99"))); + return records; + } + return Collections.emptySet(); + } + }); + dnsServer.start(); + DnsNameResolver resolver = null; + try { + DnsNameResolverBuilder builder = newResolver(strategy) + .recursionDesired(true) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY) + .maxQueriesPerResolve(16) + .nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer.localAddress())) + .resolveCache(NoopDnsCache.INSTANCE) + .cnameCache(new DnsCnameCache() { + @Override + public String get(String hostname) { + return cache.get(hostname); + } + + @Override + public void cache(String hostname, String cname, long originalTtl, EventLoop loop) { + cache.put(hostname, cname); + } + + @Override + public void clear() { + } + + @Override + public boolean clear(String hostname) { + return false; + } + }); + resolver = builder.build(); + resolver.resolveAll("x.netty.io").syncUninterruptibly(); + + // The CNAME for x.netty.io should be cached because it was the queried name + assertEquals("cname.netty.io.", cache.get("x.netty.io.")); + // The CNAME for cname.netty.io should NOT be cached because it is out of bailiwick for x.netty.io + assertNull(cache.get("cname.netty.io.")); + } finally { + dnsServer.stop(); + if (resolver != null) { + resolver.close(); + } + } + } + @Test public void testInstanceWithNullPreferredAddressType() { new DnsNameResolver( @@ -3492,8 +3579,8 @@ private static ServerSocket startDnsServerAndCreateServerSocket(TestDnsServer dn serverSocket.close(); if (i == 10) { // We tried 10 times without success - throw new IllegalStateException( - "Unable to bind TestDnsServer and ServerSocket to the same address", e); + Assumptions.abort("Unable to bind TestDnsServer and ServerSocket to the same address: " + + e.getMessage()); } // We could not start the DnsServer which is most likely because the localAddress was already used, // let's retry diff --git a/resolver/pom.xml b/resolver/pom.xml index 4fca568e0f9..3d9df8464c3 100644 --- a/resolver/pom.xml +++ b/resolver/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-resolver diff --git a/testsuite-autobahn/pom.xml b/testsuite-autobahn/pom.xml index af5acaf68b3..94ef81e4d59 100644 --- a/testsuite-autobahn/pom.xml +++ b/testsuite-autobahn/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite-autobahn diff --git a/testsuite-http2/pom.xml b/testsuite-http2/pom.xml index cec0164e73e..8b812d817eb 100644 --- a/testsuite-http2/pom.xml +++ b/testsuite-http2/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite-http2 diff --git a/testsuite-native-image-client-runtime-init/pom.xml b/testsuite-native-image-client-runtime-init/pom.xml index 8fa230f0e6a..236ad055402 100644 --- a/testsuite-native-image-client-runtime-init/pom.xml +++ b/testsuite-native-image-client-runtime-init/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite-native-image-client-runtime-init diff --git a/testsuite-native-image-client/pom.xml b/testsuite-native-image-client/pom.xml index 563d24b1d6f..bbaa1c2142d 100644 --- a/testsuite-native-image-client/pom.xml +++ b/testsuite-native-image-client/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite-native-image-client diff --git a/testsuite-native-image/pom.xml b/testsuite-native-image/pom.xml index 2266f6eac51..d99726a9d93 100644 --- a/testsuite-native-image/pom.xml +++ b/testsuite-native-image/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite-native-image diff --git a/testsuite-native/pom.xml b/testsuite-native/pom.xml index 3b5d008617a..6ea7bd734a9 100644 --- a/testsuite-native/pom.xml +++ b/testsuite-native/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite-native diff --git a/testsuite-osgi/pom.xml b/testsuite-osgi/pom.xml index 0ce12fe0e03..4a02eab4385 100644 --- a/testsuite-osgi/pom.xml +++ b/testsuite-osgi/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite-osgi diff --git a/testsuite-shading/pom.xml b/testsuite-shading/pom.xml index 43f0355a568..d3f2b1553c1 100644 --- a/testsuite-shading/pom.xml +++ b/testsuite-shading/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite-shading diff --git a/testsuite/pom.xml b/testsuite/pom.xml index 289cccfee8f..0242f4c6bdf 100644 --- a/testsuite/pom.xml +++ b/testsuite/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-testsuite diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/CompositeBufferGatheringWriteTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/CompositeBufferGatheringWriteTest.java index 67bac1a1361..ce34256da82 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/CompositeBufferGatheringWriteTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/CompositeBufferGatheringWriteTest.java @@ -99,6 +99,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E if (!(cause instanceof IOException)) { clientReceived.set(cause); latch.countDown(); + } else if (!cause.getMessage().contains("reset")) { + logger.warn("{} client got weird exception", + CompositeBufferGatheringWriteTest.this.getClass(), cause); } } diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramConnectedWriteExceptionTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramConnectedWriteExceptionTest.java new file mode 100644 index 00000000000..c26e3a0fb7c --- /dev/null +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramConnectedWriteExceptionTest.java @@ -0,0 +1,141 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOption; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.DatagramPacket; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import java.net.InetSocketAddress; +import java.net.PortUnreachableException; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DatagramConnectedWriteExceptionTest extends AbstractClientSocketTest { + + @Override + protected List> newFactories() { + return SocketTestPermutation.INSTANCE.datagramSocket(); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + @DisabledOnOs(OS.WINDOWS) + public void testWriteThrowsPortUnreachableException(TestInfo testInfo) throws Throwable { + run(testInfo, new Runner() { + @Override + public void run(Bootstrap bootstrap) throws Throwable { + testWriteExceptionAfterServerStop(bootstrap); + } + }); + } + + protected void testWriteExceptionAfterServerStop(Bootstrap clientBootstrap) throws Throwable { + final CountDownLatch serverReceivedLatch = new CountDownLatch(1); + Bootstrap serverBootstrap = clientBootstrap.clone() + .option(ChannelOption.SO_BROADCAST, false) + .handler(new SimpleChannelInboundHandler() { + + @Override + protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) { + serverReceivedLatch.countDown(); + } + }); + + Channel serverChannel = serverBootstrap.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).sync().channel(); + InetSocketAddress serverAddress = (InetSocketAddress) serverChannel.localAddress(); + + clientBootstrap.option(ChannelOption.AUTO_READ, false) + .handler(new SimpleChannelInboundHandler() { + + @Override + protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) { + // no-op + } + }); + + Channel clientChannel = clientBootstrap.connect(serverAddress).sync().channel(); + + final CountDownLatch clientFirstSendLatch = new CountDownLatch(1); + try { + ByteBuf firstMessage = Unpooled.wrappedBuffer("First message".getBytes(CharsetUtil.UTF_8)); + clientChannel.writeAndFlush(firstMessage) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + clientFirstSendLatch.countDown(); + } + } + }); + + assertTrue(serverReceivedLatch.await(5, TimeUnit.SECONDS), "Server should receive first message"); + assertTrue(clientFirstSendLatch.await(5, TimeUnit.SECONDS), "Client should send first message"); + + serverChannel.close().sync(); + + final AtomicReference writeException = new AtomicReference(); + final CountDownLatch writesCompleteLatch = new CountDownLatch(10); + + for (int i = 0; i < 10; i++) { + ByteBuf message = Unpooled.wrappedBuffer(("Message " + i).getBytes(CharsetUtil.UTF_8)); + clientChannel.writeAndFlush(message) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + writeException.compareAndSet(null, future.cause()); + } + writesCompleteLatch.countDown(); + } + }); + Thread.sleep(50); + } + + assertTrue(writesCompleteLatch.await(5, TimeUnit.SECONDS), "All writes should complete"); + + assertNotNull(writeException.get(), "Should have captured a write exception"); + + assertInstanceOf(PortUnreachableException.class, writeException.get(), "Expected " + + "PortUnreachableException but got: " + writeException.get().getClass().getName()); + } finally { + if (clientChannel != null) { + clientChannel.close().sync(); + } + } + } +} diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java index ed91927ad91..4a4ae3a1eb9 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java @@ -128,7 +128,7 @@ public static Collection data() throws Exception { public void testSslRenegotiationRejected(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate, TestInfo testInfo) throws Throwable { // BoringSSL does not support renegotiation intentionally. - assumeFalse("BoringSSL".equals(OpenSsl.versionString())); + assumeFalse("BoringSSL".equals(OpenSsl.versionString()) || OpenSsl.versionString().startsWith("AWS-LC")); assumeTrue(OpenSsl.isAvailable()); run(testInfo, new Runner() { @Override @@ -206,6 +206,7 @@ public void initChannel(Channel sch) throws Exception { } finally { if (executorService != null) { executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); } } } diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java index a8fab7c4806..63b8a9b0003 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java @@ -381,6 +381,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { clientChannel.close().awaitUninterruptibly(); sc.close().awaitUninterruptibly(); delegatedTaskExecutor.shutdown(); + assertTrue(delegatedTaskExecutor.awaitTermination(5, TimeUnit.SECONDS)); if (serverException.get() != null && !(serverException.get() instanceof IOException)) { throw serverException.get(); diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java index 7cd3ece19df..70a20487950 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java @@ -58,6 +58,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; public class SocketSslGreetingTest extends AbstractSocketTest { @@ -179,6 +180,7 @@ public void initChannel(Channel sch) throws Exception { } finally { if (executorService != null) { executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); } } } diff --git a/transport-blockhound-tests/pom.xml b/transport-blockhound-tests/pom.xml index c109f9ba179..6823c99e56a 100644 --- a/transport-blockhound-tests/pom.xml +++ b/transport-blockhound-tests/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-blockhound-tests diff --git a/transport-blockhound-tests/src/test/java/io/netty/util/internal/NettyBlockHoundIntegrationTest.java b/transport-blockhound-tests/src/test/java/io/netty/util/internal/NettyBlockHoundIntegrationTest.java index 403b97a4189..ed0845ac0ee 100644 --- a/transport-blockhound-tests/src/test/java/io/netty/util/internal/NettyBlockHoundIntegrationTest.java +++ b/transport-blockhound-tests/src/test/java/io/netty/util/internal/NettyBlockHoundIntegrationTest.java @@ -105,9 +105,12 @@ public void testServiceLoader() { @Test public void testBlockingCallsInNettyThreads() throws Exception { - final FutureTask future = new FutureTask<>(() -> { - Thread.sleep(0); - return null; + final FutureTask future = new FutureTask<>(new Callable() { + @Override + public Void call() throws Exception { + Thread.sleep(0); + return null; + } }); GlobalEventExecutor.INSTANCE.execute(future); @@ -173,9 +176,16 @@ protected void run() { }; taskQueue.emulateContention(); CountDownLatch latch = new CountDownLatch(1); - executor.submit(() -> { - executor.execute(() -> { }); // calls addTask - latch.countDown(); + executor.submit(new Runnable() { + @Override + public void run() { + executor.execute(new Runnable() { + @Override + public void run() { + } + }); // calls addTask + latch.countDown(); + } }); taskQueue.waitUntilContented(); taskQueue.removeContention(); @@ -184,9 +194,12 @@ protected void run() { @Test void permittingBlockingCallsInFastThreadLocalThreadSubclass() throws Exception { - final FutureTask future = new FutureTask<>(() -> { - Thread.sleep(0); - return null; + final FutureTask future = new FutureTask<>(new Callable() { + @Override + public Void call() throws Exception { + Thread.sleep(0); + return null; + } }); FastThreadLocalThread thread = new FastThreadLocalThread(future) { @Override @@ -250,6 +263,7 @@ public void testHandshakeWithExecutor() throws Exception { testHandshakeWithExecutor(executorService, "TLSv1.2"); } finally { executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); } } @@ -261,6 +275,7 @@ public void testHandshakeWithExecutorTLSv13() throws Exception { testHandshakeWithExecutor(executorService, "TLSv1.3"); } finally { executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); } } @@ -338,8 +353,12 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { } }) .connect(sc.localAddress()) - .addListener((ChannelFutureListener) future -> - future.channel().writeAndFlush(wrappedBuffer(new byte [] { 1, 2, 3, 4 }))) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + future.channel().writeAndFlush(wrappedBuffer(new byte[]{1, 2, 3, 4})); + } + }) .syncUninterruptibly() .channel(); @@ -362,20 +381,23 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { public void pooledBufferAllocation() throws Exception { AtomicLong iterationCounter = new AtomicLong(); PooledByteBufAllocator allocator = PooledByteBufAllocator.DEFAULT; - FutureTask task = new FutureTask<>(() -> { - List buffers = new ArrayList<>(); - long count; - do { - count = iterationCounter.get(); - } while (count == 0); - for (int i = 0; i < 13; i++) { - int size = 8 << i; - buffers.add(allocator.ioBuffer(size, size)); - } - for (ByteBuf buffer : buffers) { - buffer.release(); + FutureTask task = new FutureTask<>(new Callable() { + @Override + public Void call() throws Exception { + List buffers = new ArrayList<>(); + long count; + do { + count = iterationCounter.get(); + } while (count == 0); + for (int i = 0; i < 13; i++) { + int size = 8 << i; + buffers.add(allocator.ioBuffer(size, size)); + } + for (ByteBuf buffer : buffers) { + buffer.release(); + } + return null; } - return null; }); FastThreadLocalThread thread = new FastThreadLocalThread(task); thread.start(); @@ -431,13 +453,16 @@ protected void run() { CountDownLatch latch = new CountDownLatch(1); List result = new ArrayList<>(); List error = new ArrayList<>(); - executor.execute(() -> { - try { - result.add(callable.call()); - } catch (Throwable t) { - error.add(t); + executor.execute(new Runnable() { + @Override + public void run() { + try { + result.add(callable.call()); + } catch (Throwable t) { + error.add(t); + } + latch.countDown(); } - latch.countDown(); }); latch.await(); assertEquals(0, error.size()); diff --git a/transport-classes-epoll/pom.xml b/transport-classes-epoll/pom.xml index 593265cc3ce..46153d738ef 100644 --- a/transport-classes-epoll/pom.xml +++ b/transport-classes-epoll/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-classes-epoll diff --git a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java index c4ea86f452d..bf52134a1ac 100644 --- a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java +++ b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java @@ -77,8 +77,7 @@ protected Object filterOutboundMessage(Object msg) throws Exception { final class EpollServerSocketUnsafe extends AbstractEpollUnsafe { // Will hold the remote address after accept(...) was successful. // We need 24 bytes for the address as maximum + 1 byte for storing the length. - // So use 26 bytes as it's a power of two. - private final byte[] acceptedAddress = new byte[26]; + private final byte[] acceptedAddress = new byte[25]; @Override public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) { diff --git a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java index 69da8a1e3b8..e606eca38be 100644 --- a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java +++ b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java @@ -129,8 +129,9 @@ public ChannelMetadata metadata() { *
  • {@link EpollChannelConfig#getEpollMode()} must be {@link EpollMode#LEVEL_TRIGGERED} for this and the * target {@link AbstractEpollStreamChannel}
  • * - * + * @deprecated Will be removed in the future. */ + @Deprecated public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final int len) { return spliceTo(ch, len, newPromise()); } @@ -147,8 +148,9 @@ public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final i *
  • {@link EpollChannelConfig#getEpollMode()} must be {@link EpollMode#LEVEL_TRIGGERED} for this and the * target {@link AbstractEpollStreamChannel}
  • * - * + * @deprecated will be removed in the future. */ + @Deprecated public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final int len, final ChannelPromise promise) { if (ch.eventLoop() != eventLoop()) { @@ -182,7 +184,9 @@ public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final i *
  • the {@link FileDescriptor} will not be closed after the {@link ChannelFuture} is notified
  • *
  • this channel must be registered to an event loop or {@link IllegalStateException} will be thrown.
  • * + * @deprecated Will be removed in the future. */ + @Deprecated public final ChannelFuture spliceTo(final FileDescriptor ch, final int offset, final int len) { return spliceTo(ch, offset, len, newPromise()); } @@ -200,7 +204,9 @@ public final ChannelFuture spliceTo(final FileDescriptor ch, final int offset, f *
  • the {@link FileDescriptor} will not be closed after the {@link ChannelPromise} is notified
  • *
  • this channel must be registered to an event loop or {@link IllegalStateException} will be thrown.
  • * + * @deprecated Will be removed in the future. */ + @Deprecated public final ChannelFuture spliceTo(final FileDescriptor ch, final int offset, final int len, final ChannelPromise promise) { checkPositiveOrZero(len, "len"); diff --git a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java index c42ac048467..613e2c2f274 100644 --- a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java +++ b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java @@ -415,7 +415,14 @@ private boolean doWriteMessage(Object msg) throws Exception { return true; } - return doWriteOrSendBytes(data, remoteAddress, false) > 0; + try { + return doWriteOrSendBytes(data, remoteAddress, false) > 0; + } catch (NativeIoException e) { + if (remoteAddress == null) { + throw translateForConnected(e); + } + throw e; + } } private static void checkUnresolved(AddressedEnvelope envelope) { diff --git a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/NativeStaticallyReferencedJniMethods.java b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/NativeStaticallyReferencedJniMethods.java index 62fb003103b..c2aa0a93b2d 100644 --- a/transport-classes-epoll/src/main/java/io/netty/channel/epoll/NativeStaticallyReferencedJniMethods.java +++ b/transport-classes-epoll/src/main/java/io/netty/channel/epoll/NativeStaticallyReferencedJniMethods.java @@ -39,8 +39,6 @@ private NativeStaticallyReferencedJniMethods() { } static native int eagain(); static native long ssizeMax(); static native int tcpMd5SigMaxKeyLen(); - static native int iovMax(); - static native int uioMaxIov(); static native boolean isSupportingSendmmsg(); static native boolean isSupportingRecvmmsg(); static native int tcpFastopenMode(); diff --git a/transport-classes-kqueue/pom.xml b/transport-classes-kqueue/pom.xml index 4242dd74bc3..6aab758e433 100644 --- a/transport-classes-kqueue/pom.xml +++ b/transport-classes-kqueue/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-classes-kqueue diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java index 8a4c56cd191..93a26b7a85d 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java @@ -77,8 +77,7 @@ protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddr final class KQueueServerSocketUnsafe extends AbstractKQueueUnsafe { // Will hold the remote address after accept(...) was successful. // We need 24 bytes for the address as maximum + 1 byte for storing the capacity. - // So use 26 bytes as it's a power of two. - private final byte[] acceptedAddress = new byte[26]; + private final byte[] acceptedAddress = new byte[25]; @Override void readReady(KQueueRecvByteAllocatorHandle allocHandle) { diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java index aa9cd90e959..0a114569ce8 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java @@ -562,7 +562,7 @@ void readReady(final KQueueRecvByteAllocatorHandle allocHandle) { allocHandle.readComplete(); pipeline.fireChannelReadComplete(); - if (close) { + if (close || allocHandle.isReadEOF()) { shutdownInput(false); } } catch (Throwable t) { diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java index 70f848a3e90..52aa4d4fb27 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java @@ -34,6 +34,7 @@ import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.StringUtil; +import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.NetworkInterface; @@ -265,7 +266,11 @@ protected boolean doWriteMessage(Object msg) throws Exception { if (data.hasMemoryAddress()) { long memoryAddress = data.memoryAddress(); if (remoteAddress == null) { - writtenBytes = socket.writeAddress(memoryAddress, data.readerIndex(), data.writerIndex()); + try { + writtenBytes = socket.writeAddress(memoryAddress, data.readerIndex(), data.writerIndex()); + } catch (Errors.NativeIoException e) { + throw translateForConnected(e); + } } else { writtenBytes = socket.sendToAddress(memoryAddress, data.readerIndex(), data.writerIndex(), remoteAddress.getAddress(), remoteAddress.getPort()); @@ -295,6 +300,16 @@ protected boolean doWriteMessage(Object msg) throws Exception { return writtenBytes > 0; } + private static IOException translateForConnected(Errors.NativeIoException e) { + // We need to correctly translate connect errors to match NIO behaviour. + if (e.expectedErr() == Errors.ERROR_ECONNREFUSED_NEGATIVE) { + PortUnreachableException error = new PortUnreachableException(e.getMessage()); + error.initCause(e); + return error; + } + return e; + } + private static void checkUnresolved(AddressedEnvelope envelope) { if (envelope.recipient() instanceof InetSocketAddress && (((InetSocketAddress) envelope.recipient()).isUnresolved())) { diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java index 87081a82c88..99ec7e620fd 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java @@ -97,7 +97,7 @@ private void reallocIfNeeded() { */ void realloc(boolean throwIfFail) { // Double the capacity while it is "sufficiently small", and otherwise increase by 50%. - int newLength = capacity <= 65536 ? capacity << 1 : capacity + capacity >> 1; + int newLength = capacity <= 65536 ? capacity << 1 : capacity + (capacity >> 1); try { ByteBuffer buffer = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(newLength)); diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueRecvByteAllocatorHandle.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueRecvByteAllocatorHandle.java index acd67f40d92..8d6955848a0 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueRecvByteAllocatorHandle.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/KQueueRecvByteAllocatorHandle.java @@ -98,7 +98,7 @@ void numberBytesPending(long numberBytesPending) { } boolean maybeMoreDataToRead() { - /** + /* * kqueue with EV_CLEAR flag set requires that we read until we consume "data" bytes * (see kqueue man). However in order to * respect auto read we supporting reading to stop if auto read is off. If auto read is on we force reading to diff --git a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java index 5c44c57ca45..42ccb20a16b 100644 --- a/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java +++ b/transport-classes-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java @@ -85,7 +85,7 @@ private long memoryOffset(int index) { private void reallocIfNeeded() { if (size == capacity) { // Double the capacity while it is "sufficiently small", and otherwise increase by 50%. - int newLength = capacity <= 65536 ? capacity << 1 : capacity + capacity >> 1; + int newLength = capacity <= 65536 ? capacity << 1 : capacity + (capacity >> 1); ByteBuffer buffer = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(newLength)); // Copy over the old content of the memory and reset the position as we always act on the buffer as if // the position was never increased. diff --git a/transport-native-epoll/pom.xml b/transport-native-epoll/pom.xml index 7ac617deb2a..4148553c31e 100644 --- a/transport-native-epoll/pom.xml +++ b/transport-native-epoll/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-native-epoll diff --git a/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c b/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c index cd1e6abfb14..7528c679749 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c +++ b/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c @@ -507,8 +507,13 @@ static void netty_epoll_linuxsocket_setTcpMd5Sig(JNIEnv* env, jclass clazz, jint } if (key != NULL) { - md5sig.tcpm_keylen = (*env)->GetArrayLength(env, key); - (*env)->GetByteArrayRegion(env, key, 0, md5sig.tcpm_keylen, (void *) &md5sig.tcpm_key); + jint keylen = (*env)->GetArrayLength(env, key); + if (keylen > TCP_MD5SIG_MAXKEYLEN) { + netty_unix_errors_throwIOException(env, "key is too long"); + return; + } + md5sig.tcpm_keylen = (u_int16_t) keylen; + (*env)->GetByteArrayRegion(env, key, 0, keylen, (void *) &md5sig.tcpm_key); if ((*env)->ExceptionCheck(env) == JNI_TRUE) { return; } diff --git a/transport-native-epoll/src/main/c/netty_epoll_native.c b/transport-native-epoll/src/main/c/netty_epoll_native.c index eda94b3992b..7b83549b1b4 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_native.c +++ b/transport-native-epoll/src/main/c/netty_epoll_native.c @@ -251,7 +251,7 @@ static jint netty_epoll_native_epollCreate(JNIEnv* env, jclass clazz) { int err = errno; close(efd); netty_unix_errors_throwChannelExceptionErrorNo(env, "fcntl() failed: ", err); - return err; + return -err; } } return efd; @@ -277,7 +277,7 @@ static inline jint netty_epoll_wait(JNIEnv* env, jint efd, struct epoll_event *e netty_unix_errors_throwRuntimeExceptionErrorNo(env, "clock_gettime() failed: ", errno); return -1; } - deadline = ts.tv_sec * 1000 + ts.tv_nsec / 1000 + timeout; + deadline = ts.tv_sec * 1000 + ts.tv_nsec / 1000000 + timeout; while ((rc = epoll_wait(efd, ev, len, timeout)) < 0) { if (errno != EINTR) { @@ -289,7 +289,7 @@ static inline jint netty_epoll_wait(JNIEnv* env, jint efd, struct epoll_event *e return -1; } - now = ts.tv_sec * 1000 + ts.tv_nsec / 1000; + now = ts.tv_sec * 1000 + ts.tv_nsec / 1000000; if (now >= deadline) { return 0; } @@ -495,6 +495,11 @@ static jint netty_epoll_native_sendmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo for (i = 0; i < len; i++) { jobject packet = (*env)->GetObjectArrayElement(env, packets, i + offset); + if (packet == NULL) { + // This should never happen but just handle it and return early. This way if GetObjectArrayElement(...) + // did put an exception on the stack we will see it and not crash. + return -1; + } jbyteArray address = (jbyteArray) (*env)->GetObjectField(env, packet, packetRecipientAddrFieldId); jint addrLen = (*env)->GetIntField(env, packet, packetRecipientAddrLenFieldId); jint packetSegmentSize = (*env)->GetIntField(env, packet, packetSegmentSizeFieldId); @@ -623,7 +628,7 @@ static jint netty_epoll_native_recvmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo #ifdef IP_RECVORIGDSTADDR int readLocalAddr = 0; if (netty_unix_socket_getOption(env, fd, IPPROTO_IP, IP_RECVORIGDSTADDR, - &readLocalAddr, sizeof(readLocalAddr)) < 0) { + &readLocalAddr, sizeof(readLocalAddr)) != -1 && readLocalAddr != 0) { cntrlbuf = malloc(sizeof(char) * storageSize * len); } #endif // IP_RECVORIGDSTADDR @@ -632,11 +637,16 @@ static jint netty_epoll_native_recvmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo for (i = 0; i < len; i++) { jobject packet = (*env)->GetObjectArrayElement(env, packets, i + offset); + if (packet == NULL) { + // This should never happen but just handle it and return early. This way if GetObjectArrayElement(...) + // did put an exception on the stack we will see it and not crash. + return -1; + } msg[i].msg_hdr.msg_iov = (struct iovec*) (intptr_t) (*env)->GetLongField(env, packet, packetMemoryAddressFieldId); msg[i].msg_hdr.msg_iovlen = (*env)->GetIntField(env, packet, packetCountFieldId); msg[i].msg_hdr.msg_name = addr + i; - msg[i].msg_hdr.msg_namelen = (socklen_t) addrSize; + msg[i].msg_hdr.msg_namelen = (socklen_t) storageSize; if (cntrlbuf != NULL) { msg[i].msg_hdr.msg_control = cntrlbuf + i * storageSize; diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTcpMd5Test.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTcpMd5Test.java index 6bbcb2e660b..17cce3760b9 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTcpMd5Test.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTcpMd5Test.java @@ -16,6 +16,7 @@ package io.netty.channel.epoll; import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelOption; import io.netty.channel.ConnectTimeoutException; @@ -53,10 +54,10 @@ public static void afterClass() { @BeforeEach public void setup() { - Bootstrap bootstrap = new Bootstrap(); + ServerBootstrap bootstrap = new ServerBootstrap(); server = (EpollServerSocketChannel) bootstrap.group(GROUP) .channel(EpollServerSocketChannel.class) - .handler(new ChannelInboundHandlerAdapter()) + .childHandler(new ChannelInboundHandlerAdapter()) .bind(new InetSocketAddress(NetUtil.LOCALHOST4, 0)).syncUninterruptibly().channel(); } @@ -74,10 +75,10 @@ public void testServerSocketChannelOption() throws Exception { @Test public void testServerOption() throws Exception { - Bootstrap bootstrap = new Bootstrap(); + ServerBootstrap bootstrap = new ServerBootstrap(); EpollServerSocketChannel ch = (EpollServerSocketChannel) bootstrap.group(GROUP) .channel(EpollServerSocketChannel.class) - .handler(new ChannelInboundHandlerAdapter()) + .childHandler(new ChannelInboundHandlerAdapter()) .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); ch.config().setOption(EpollChannelOption.TCP_MD5SIG, diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java index 3e09949053e..bfc5fc2fce3 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -33,6 +34,29 @@ public void testIsAvailable() { assertTrue(Epoll.isAvailable()); } + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testEpollWaitTimeoutAccuracy() throws Exception { + final int timeoutMs = 200; + final FileDescriptor epoll = Native.newEpollCreate(); + final EpollEventArray eventArray = new EpollEventArray(8); + try { + long startNs = System.nanoTime(); + // No fds registered, so this will just wait for the timeout. + int ready = Native.epollWait(epoll, eventArray, timeoutMs); + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs); + + assertEquals(0, ready); + // Should have waited at least close to the timeout + assertThat(elapsedMs).isGreaterThanOrEqualTo(timeoutMs - 20); + // Should not have waited vastly longer than the timeout + assertThat(elapsedMs).isLessThan(timeoutMs + 200); + } finally { + eventArray.free(); + epoll.close(); + } + } + // Testcase for https://github.com/netty/netty/issues/8444 @Test @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/LinuxSocketTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/LinuxSocketTest.java index 4fe962e8575..2154f0f28d7 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/LinuxSocketTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/LinuxSocketTest.java @@ -28,6 +28,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; public class LinuxSocketTest { @@ -94,4 +95,21 @@ public void execute() throws Throwable { socket.close(); } } + + @Test + public void testUnixAbstractDomainSocket() throws IOException { + String address = "\0" + UUID.randomUUID(); + + final DomainSocketAddress domainSocketAddress = new DomainSocketAddress(address); + final Socket socket = Socket.newSocketDomain(); + try { + socket.bind(domainSocketAddress); + DomainSocketAddress local = socket.localDomainSocketAddress(); + assertEquals(domainSocketAddress, local); + assertEquals(address, domainSocketAddress.path()); + assertEquals(address, local.path()); + } finally { + socket.close(); + } + } } diff --git a/transport-native-kqueue/pom.xml b/transport-native-kqueue/pom.xml index f6fcc3b87d0..87520671bf4 100644 --- a/transport-native-kqueue/pom.xml +++ b/transport-native-kqueue/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-native-kqueue diff --git a/transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c b/transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c index 8e13979ee0a..19ecdf4d0ab 100644 --- a/transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c +++ b/transport-native-kqueue/src/main/c/netty_kqueue_bsdsocket.c @@ -73,12 +73,17 @@ static jlong netty_kqueue_bsdsocket_sendFile(JNIEnv* env, jclass clazz, jint soc sbytes = 0; res = sendfile(srcFd, socketFd, base_off + off, len, NULL, &sbytes, 0); #endif + // BSD/macOS sendfile passes the offset by value (unlike Linux which takes off_t*). + // When interrupted (EINTR), sbytes reports how many bytes were sent before the signal. + // Advance off so the next iteration resumes from where we left off, not from the start. + off += sbytes; len -= sbytes; } while (res < 0 && ((err = errno) == EINTR)); sbytes = lenBefore - len; if (sbytes > 0) { // update the transferred field in DefaultFileRegion - (*env)->SetLongField(env, fileRegion, transferredFieldId, off + sbytes); + // off has already been advanced by sbytes inside the loop, so it equals the new total. + (*env)->SetLongField(env, fileRegion, transferredFieldId, off); return sbytes; } return res < 0 ? -err : 0; @@ -144,12 +149,36 @@ static void netty_kqueue_bsdsocket_setAcceptFilter(JNIEnv* env, jclass clazz, ji const char* tmpString = NULL; af.af_name[0] = af.af_arg[0] ='\0'; + jsize len = (*env)->GetStringUTFLength(env, afName); + if (len > sizeof(af.af_name)) { + // Too large and so can't be stored + netty_unix_errors_throwChannelExceptionErrorNo(env, "setsockopt() failed: ", EOVERFLOW); + return; + } tmpString = (*env)->GetStringUTFChars(env, afName, NULL); - strncat(af.af_name, tmpString, sizeof(af.af_name) / sizeof(af.af_name[0])); + if (tmpString == NULL) { + // if NULL is returned it failed due OOME + netty_unix_errors_throwChannelExceptionErrorNo(env, "setsockopt() failed: ", ENOMEM); + return; + } + + strlcat(af.af_name, tmpString, sizeof(af.af_name)); (*env)->ReleaseStringUTFChars(env, afName, tmpString); + len = (*env)->GetStringUTFLength(env, afArg); + if (len > sizeof(af.af_arg)) { + // Too large and so can't be stored + netty_unix_errors_throwChannelExceptionErrorNo(env, "setsockopt() failed: ", EOVERFLOW); + return; + } + tmpString = (*env)->GetStringUTFChars(env, afArg, NULL); - strncat(af.af_arg, tmpString, sizeof(af.af_arg) / sizeof(af.af_arg[0])); + if (tmpString == NULL) { + // if NULL is returned it failed due OOME + netty_unix_errors_throwChannelExceptionErrorNo(env, "setsockopt() failed: ", ENOMEM); + return; + } + strlcat(af.af_arg, tmpString, sizeof(af.af_arg)); (*env)->ReleaseStringUTFChars(env, afArg, tmpString); netty_unix_socket_setOption(env, fd, SOL_SOCKET, SO_ACCEPTFILTER, &af, sizeof(af)); @@ -250,7 +279,7 @@ static jobject netty_kqueue_bsdsocket_getPeerCredentials(JNIEnv *env, jclass cla #ifdef LOCAL_PEERPID socklen_t len = sizeof(pid); // Getting the LOCAL_PEERPID is expected to return error in some cases (e.g. server socket FDs) - just return 0. - if (netty_unix_socket_getOption0(fd, SOCK_STREAM, LOCAL_PEERPID, &pid, len) < 0) { + if (netty_unix_socket_getOption0(fd, SOL_LOCAL, LOCAL_PEERPID, &pid, len) < 0) { pid = 0; } #endif diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueCompositeBufferGatheringWriteTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueCompositeBufferGatheringWriteTest.java index f0cd700bc57..0237ac4519f 100644 --- a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueCompositeBufferGatheringWriteTest.java +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueCompositeBufferGatheringWriteTest.java @@ -26,7 +26,7 @@ public class KQueueCompositeBufferGatheringWriteTest extends CompositeBufferGatheringWriteTest { @Override protected List> newFactories() { - return KQueueSocketTestPermutation.INSTANCE.socket(); + return KQueueSocketTestPermutation.INSTANCE.socketWithoutFastOpen(); } @Override diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramConnectedWriteExceptionTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramConnectedWriteExceptionTest.java new file mode 100644 index 00000000000..c964ca8acb3 --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramConnectedWriteExceptionTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2026 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.DatagramConnectedWriteExceptionTest; + +import java.util.List; + +public class KQueueDatagramConnectedWriteExceptionTest extends DatagramConnectedWriteExceptionTest { + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.datagramSocket(); + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketTestPermutation.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketTestPermutation.java index ec5aabca7a5..c7be7595cc7 100644 --- a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketTestPermutation.java +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketTestPermutation.java @@ -45,7 +45,6 @@ class KQueueSocketTestPermutation extends SocketTestPermutation { @Override public List> socket() { - List> list = combo(serverSocket(), clientSocketWithFastOpen()); @@ -54,6 +53,15 @@ public List> socketWithoutFastOpen() { + List> list = + combo(serverSocket(), clientSocket()); + + list.remove(list.size() - 1); // Exclude NIO x NIO test + + return list; + } + @Override public List> serverSocket() { List> toReturn = new ArrayList>(); diff --git a/transport-native-unix-common-tests/pom.xml b/transport-native-unix-common-tests/pom.xml index 4be08a13265..0524c4bc76b 100644 --- a/transport-native-unix-common-tests/pom.xml +++ b/transport-native-unix-common-tests/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-native-unix-common-tests diff --git a/transport-native-unix-common/pom.xml b/transport-native-unix-common/pom.xml index dd6880e6f51..b51a8cb6793 100644 --- a/transport-native-unix-common/pom.xml +++ b/transport-native-unix-common/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-native-unix-common diff --git a/transport-native-unix-common/src/main/c/netty_unix_errors.c b/transport-native-unix-common/src/main/c/netty_unix_errors.c index 1dcac708e22..e54eba6a60f 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_errors.c +++ b/transport-native-unix-common/src/main/c/netty_unix_errors.c @@ -37,13 +37,20 @@ static jmethodID closedChannelExceptionMethodId = NULL; even on platforms where the GNU variant is exposed. Note: `strerrbuf` must be initialized to all zeros prior to calling this function. XSI or GNU functions do not have such a requirement, but our wrappers do. + + Android exposes the XSI variant by default, see + https://cs.android.com/android/platform/superproject/+/android16-release:bionic/libc/include/string.h;l=145?q=string.h */ -#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__) && ! _GNU_SOURCE +#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__ || __ANDROID__) && ! _GNU_SOURCE static inline int strerror_r_xsi(int errnum, char *strerrbuf, size_t buflen) { return strerror_r(errnum, strerrbuf, buflen); } #else static inline int strerror_r_xsi(int errnum, char *strerrbuf, size_t buflen) { + // Clear errno before calling the GNU variant so we can reliably detect failure. + // The GNU strerror_r only sets errno on error; it does not clear a pre-existing value, + // so a stale non-zero errno would otherwise cause a false negative here. + errno = 0; char* tmp = strerror_r(errnum, strerrbuf, buflen); if (strerrbuf[0] == '\0') { // Our output buffer was not used. Copy from tmp. @@ -200,7 +207,9 @@ static jint netty_unix_errors_errorEHOSTUNREACH(JNIEnv* env, jclass clazz) { } static jstring netty_unix_errors_strError(JNIEnv* env, jclass clazz, jint error) { - return (*env)->NewStringUTF(env, strerror(error)); + char strerrbuf[256] = {0}; + strerror_r_xsi(error, strerrbuf, sizeof(strerrbuf)); + return (*env)->NewStringUTF(env, strerrbuf); } // JNI Registered Methods End diff --git a/transport-native-unix-common/src/main/c/netty_unix_filedescriptor.c b/transport-native-unix-common/src/main/c/netty_unix_filedescriptor.c index dcb4b34a015..13030740641 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_filedescriptor.c +++ b/transport-native-unix-common/src/main/c/netty_unix_filedescriptor.c @@ -100,6 +100,10 @@ static jint netty_unix_filedescriptor_close(JNIEnv* env, jclass clazz, jint fd) static jint netty_unix_filedescriptor_open(JNIEnv* env, jclass clazz, jstring path, jint flags) { const char* f_path = (*env)->GetStringUTFChars(env, path, 0); + if (f_path == NULL) { + return -ENOMEM; + } + int res = open(f_path, flags, 0666); (*env)->ReleaseStringUTFChars(env, path, f_path); @@ -243,13 +247,16 @@ static jlong netty_unix_filedescriptor_newPipe(JNIEnv* env, jclass clazz) { } } else { if (pipe(fd) == 0) { - if (fcntl(fd[0], F_SETFD, O_NONBLOCK) < 0) { + // Read current flags and OR-ing in O_NONBLOCK to preserve old flags as well. + int flags0 = fcntl(fd[0], F_GETFL, 0); + if (flags0 < 0 || fcntl(fd[0], F_SETFL, flags0 | O_NONBLOCK) < 0) { int err = errno; close(fd[0]); close(fd[1]); return -err; } - if (fcntl(fd[1], F_SETFD, O_NONBLOCK) < 0) { + int flags1 = fcntl(fd[1], F_GETFL, 0); + if (flags1 < 0 || fcntl(fd[1], F_SETFL, flags1 | O_NONBLOCK) < 0) { int err = errno; close(fd[0]); close(fd[1]); diff --git a/transport-native-unix-common/src/main/c/netty_unix_socket.c b/transport-native-unix-common/src/main/c/netty_unix_socket.c index 0efea5d65a0..1d63de25b3c 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_socket.c +++ b/transport-native-unix-common/src/main/c/netty_unix_socket.c @@ -133,11 +133,23 @@ static jobject createDatagramSocketAddress(JNIEnv* env, const struct sockaddr_st return obj; } -static jobject createDomainDatagramSocketAddress(JNIEnv* env, const struct sockaddr_storage* addr, int len, jobject local) { +static int domainSocketPathLength(const struct sockaddr_un* s, const socklen_t addrlen) { +#ifdef __linux__ + // Linux supports abstract domain sockets so we need to handle it. + // https://man7.org/linux/man-pages/man7/unix.7.html + if (addrlen >= sizeof(sa_family_t) && s->sun_path[0] == '\0') { + // This is an abstract domain socket address + return (addrlen - sizeof(sa_family_t)); + } +#endif + return strlen(s->sun_path); +} + +static jobject createDomainDatagramSocketAddress(JNIEnv* env, const struct sockaddr_storage* addr, const socklen_t addrlen, int len, jobject local) { jclass domainDatagramSocketAddressClass = NULL; jobject obj = NULL; struct sockaddr_un* s = (struct sockaddr_un*) addr; - int pathLength = strlen(s->sun_path); + int pathLength = domainSocketPathLength(s, addrlen); jbyteArray pathBytes = (*env)->NewByteArray(env, pathLength); if (pathBytes == NULL) { return NULL; @@ -157,9 +169,9 @@ static jobject createDomainDatagramSocketAddress(JNIEnv* env, const struct socka return obj; } -static jbyteArray netty_unix_socket_createDomainSocketAddressArray(JNIEnv* env, const struct sockaddr_storage* addr) { +static jbyteArray netty_unix_socket_createDomainSocketAddressArray(JNIEnv* env, const struct sockaddr_storage* addr, const socklen_t addrlen) { struct sockaddr_un* s = (struct sockaddr_un*) addr; - int pathLength = strlen(s->sun_path); + int pathLength = domainSocketPathLength(s, addrlen); jbyteArray pathBytes = (*env)->NewByteArray(env, pathLength); if (pathBytes == NULL) { return NULL; @@ -446,7 +458,7 @@ static jobject _recvFromDomainSocket(JNIEnv* env, jint fd, void* buffer, jint po int err; do { - bzero(&addr, sizeof(addr)); // Zap addr so we can strlen(addr.sun_path) later. See unix(4). + memset(&addr, 0, sizeof(addr)); // Zap addr so we can strlen(addr.sun_path) later. See unix(4). res = recvfrom(fd, buffer + pos, (size_t) (limit - pos), 0, (struct sockaddr*) &addr, &addrlen); // Keep on reading if it was interrupted } while (res == -1 && ((err = errno) == EINTR)); @@ -464,7 +476,7 @@ static jobject _recvFromDomainSocket(JNIEnv* env, jint fd, void* buffer, jint po return NULL; } - return createDomainDatagramSocketAddress(env, &addr, res, NULL); + return createDomainDatagramSocketAddress(env, &addr, addrlen, res, NULL); } static jint _send(JNIEnv* env, jclass clazz, jint fd, void* buffer, jint pos, jint limit) { @@ -687,8 +699,10 @@ static jint netty_unix_socket_accept(JNIEnv* env, jclass clazz, jint fd, jbyteAr if (accept4) { return socketFd; } + // accept4 was not present so need two more sys-calls ... if (fcntl(socketFd, F_SETFD, FD_CLOEXEC) == -1 || fcntl(socketFd, F_SETFL, O_NONBLOCK) == -1) { - // accept4 was not present so need two more sys-calls ... + // close the fd before report the error so we don't leak it. + close(socketFd); return -errno; } return socketFd; @@ -709,7 +723,7 @@ static jbyteArray netty_unix_socket_remoteDomainSocketAddress(JNIEnv* env, jclas if (getpeername(fd, (struct sockaddr*) &addr, &len) == -1) { return NULL; } - return netty_unix_socket_createDomainSocketAddressArray(env, &addr); + return netty_unix_socket_createDomainSocketAddressArray(env, &addr, len); } static jbyteArray netty_unix_socket_localAddress(JNIEnv* env, jclass clazz, jint fd) { @@ -727,7 +741,7 @@ static jbyteArray netty_unix_socket_localDomainSocketAddress(JNIEnv* env, jclass if (getsockname(fd, (struct sockaddr*) &addr, &len) == -1) { return NULL; } - return netty_unix_socket_createDomainSocketAddressArray(env, &addr); + return netty_unix_socket_createDomainSocketAddressArray(env, &addr, len); } static jint netty_unix_socket_newSocketDgramFd(JNIEnv* env, jclass clazz, jboolean ipv6) { @@ -926,10 +940,6 @@ static jint netty_unix_socket_recvFd(JNIEnv* env, jclass clazz, jint fd) { char control[CMSG_SPACE(sizeof(int))] = { 0 }; char iovecData[1]; - descriptorMessage.msg_control = control; - descriptorMessage.msg_controllen = sizeof(control); - descriptorMessage.msg_iov = iov; - descriptorMessage.msg_iovlen = 1; iov[0].iov_base = iovecData; iov[0].iov_len = sizeof(iovecData); @@ -937,6 +947,14 @@ static jint netty_unix_socket_recvFd(JNIEnv* env, jclass clazz, jint fd) { int err; for (;;) { + // Reset descriptorMessage to an initial start at the beginning of the loop as we might run it multiple + // times. + memset(&descriptorMessage, 0, sizeof(descriptorMessage)); + descriptorMessage.msg_control = control; + descriptorMessage.msg_controllen = sizeof(control); + descriptorMessage.msg_iov = iov; + descriptorMessage.msg_iovlen = 1; + do { res = recvmsg(fd, &descriptorMessage, 0); // Keep on reading if we was interrupted @@ -950,21 +968,61 @@ static jint netty_unix_socket_recvFd(JNIEnv* env, jclass clazz, jint fd) { return -err; } - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&descriptorMessage); - if (!cmsg) { - return -errno; + // Walk every cmsg; close any SCM_RIGHTS fds we cannot use so they + // are never silently leaked (e.g. peer sent more than one fd). + jint result = -1; + err = 0; + + // If ancillary data was truncated the kernel auto-closes fds that + // did not fit but it is still an error we must not retry and so should report it back to the caller. + // Beside this we also need to ensure we close all other fds so they not leak. + if (descriptorMessage.msg_flags & MSG_CTRUNC) { + err = EMSGSIZE; } - if ((cmsg->cmsg_len == CMSG_LEN(sizeof(int))) && (cmsg->cmsg_level == SOL_SOCKET) && (cmsg->cmsg_type == SCM_RIGHTS)) { - socketFd = *((int *) CMSG_DATA(cmsg)); - // set as non blocking as we want to use it with kqueue/epoll - if (fcntl(socketFd, F_SETFL, O_NONBLOCK) == -1) { - err = errno; - close(socketFd); - return -err; + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&descriptorMessage); + while (cmsg != NULL) { + if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { + int nfds = (int) ((cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int)); + int* fds = (int*) CMSG_DATA(cmsg); + + if (nfds == 1 && err == 0) { + socketFd = fds[0]; + + // set as non blocking as we want to use it with kqueue/epoll + if (fcntl(socketFd, F_SETFL, O_NONBLOCK) == -1) { + err = errno; + close(socketFd); + } else { + result = socketFd; + } + } else { + int i = 0; + // Peer sent an unexpected number of fds; close them all + // and signal an error so the caller does not retry blindly. + for (i = 0; i < nfds; i++) { + close(fds[i]); + } + if (result >= 0) { + // Already accepted one fd above; undo it. + close(result); + result = -1; + } + // check if we need to update the err or if we already did set it to an error. + if (err == 0) { + err = EINVAL; + } + } } - return socketFd; + cmsg = CMSG_NXTHDR(&descriptorMessage, cmsg); + } + if (result != -1) { + return result; + } + if (err != 0) { + return -err; } + // No SCM_RIGHTS cmsg found and no error; try again. } } diff --git a/transport-rxtx/pom.xml b/transport-rxtx/pom.xml index df1cc7672c6..e2d266cd5d9 100644 --- a/transport-rxtx/pom.xml +++ b/transport-rxtx/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-rxtx diff --git a/transport-sctp/pom.xml b/transport-sctp/pom.xml index 5ce10d35157..1c2ba00c368 100644 --- a/transport-sctp/pom.xml +++ b/transport-sctp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-sctp diff --git a/transport-sctp/src/main/java/io/netty/channel/sctp/DefaultSctpServerChannelConfig.java b/transport-sctp/src/main/java/io/netty/channel/sctp/DefaultSctpServerChannelConfig.java index 2860ba74b54..cde4db83a56 100644 --- a/transport-sctp/src/main/java/io/netty/channel/sctp/DefaultSctpServerChannelConfig.java +++ b/transport-sctp/src/main/java/io/netty/channel/sctp/DefaultSctpServerChannelConfig.java @@ -54,7 +54,8 @@ public DefaultSctpServerChannelConfig( public Map, Object> getOptions() { return getOptions( super.getOptions(), - ChannelOption.SO_RCVBUF, ChannelOption.SO_SNDBUF, SctpChannelOption.SCTP_INIT_MAXSTREAMS); + ChannelOption.SO_RCVBUF, ChannelOption.SO_SNDBUF, ChannelOption.SO_BACKLOG, + SctpChannelOption.SCTP_INIT_MAXSTREAMS); } @SuppressWarnings("unchecked") @@ -66,6 +67,9 @@ public T getOption(ChannelOption option) { if (option == ChannelOption.SO_SNDBUF) { return (T) Integer.valueOf(getSendBufferSize()); } + if (option == ChannelOption.SO_BACKLOG) { + return (T) Integer.valueOf(getBacklog()); + } if (option == SctpChannelOption.SCTP_INIT_MAXSTREAMS) { return (T) getInitMaxStreams(); } @@ -80,6 +84,8 @@ public boolean setOption(ChannelOption option, T value) { setReceiveBufferSize((Integer) value); } else if (option == ChannelOption.SO_SNDBUF) { setSendBufferSize((Integer) value); + } else if (option == ChannelOption.SO_BACKLOG) { + setBacklog((Integer) value); } else if (option == SctpChannelOption.SCTP_INIT_MAXSTREAMS) { setInitMaxStreams((SctpStandardSocketOptions.InitMaxStreams) value); } else { diff --git a/transport-sctp/src/main/java/io/netty/handler/codec/sctp/SctpMessageCompletionHandler.java b/transport-sctp/src/main/java/io/netty/handler/codec/sctp/SctpMessageCompletionHandler.java index 13fe290c488..f6c6669f4cf 100644 --- a/transport-sctp/src/main/java/io/netty/handler/codec/sctp/SctpMessageCompletionHandler.java +++ b/transport-sctp/src/main/java/io/netty/handler/codec/sctp/SctpMessageCompletionHandler.java @@ -17,23 +17,45 @@ package io.netty.handler.codec.sctp; import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; +import io.netty.buffer.CompositeByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandler; import io.netty.channel.sctp.SctpMessage; +import io.netty.handler.codec.CodecException; import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; +import java.util.ArrayList; import java.util.List; +import static io.netty.util.internal.ObjectUtil.checkPositive; + /** * {@link MessageToMessageDecoder} which will take care of handle fragmented {@link SctpMessage}s, so * only complete {@link SctpMessage}s will be forwarded to the next * {@link ChannelInboundHandler}. */ public class SctpMessageCompletionHandler extends MessageToMessageDecoder { - private final IntObjectMap fragments = new IntObjectHashMap(); + private final IntObjectMap> incompleteSctpMessages = new IntObjectHashMap>(); + private final int maxIncompleteSctpMessages; + private final int maxFragments; + + public SctpMessageCompletionHandler() { + this(128, 128); + } + + /** + * Create a new instance. + * + * @param maxIncompleteSctpMessages the maximum number of incomplete sctp message inflight. + * @param maxFragments the maximum number of fragments per sctp message. + */ + public SctpMessageCompletionHandler(int maxIncompleteSctpMessages, int maxFragments) { + super(SctpMessage.class); + this.maxIncompleteSctpMessages = checkPositive(maxIncompleteSctpMessages, "maxIncompleteSctpMessages"); + this.maxFragments = checkPositive(maxFragments, "maxFragments"); + } @Override protected void decode(ChannelHandlerContext ctx, SctpMessage msg, List out) throws Exception { @@ -43,38 +65,59 @@ protected void decode(ChannelHandlerContext ctx, SctpMessage msg, List o final boolean isComplete = msg.isComplete(); final boolean isUnordered = msg.isUnordered(); - ByteBuf frag = fragments.remove(streamIdentifier); + List frag = incompleteSctpMessages.get(streamIdentifier); if (frag == null) { - frag = Unpooled.EMPTY_BUFFER; - } - - if (isComplete && !frag.isReadable()) { - //data chunk is not fragmented - out.add(msg); - } else if (!isComplete && frag.isReadable()) { - //more message to complete - fragments.put(streamIdentifier, Unpooled.wrappedBuffer(frag, byteBuf)); - } else if (isComplete && frag.isReadable()) { - //last message to complete - SctpMessage assembledMsg = new SctpMessage( - protocolIdentifier, - streamIdentifier, - isUnordered, - Unpooled.wrappedBuffer(frag, byteBuf)); - out.add(assembledMsg); + // No previous fragments. + if (isComplete) { + out.add(msg.retain()); + } else { + if (maxIncompleteSctpMessages <= incompleteSctpMessages.size()) { + throw new CodecException( + "Too many incomplete sctp messages in flight: " + maxIncompleteSctpMessages); + } + //first incomplete message + frag = new ArrayList(); + frag.add(byteBuf.retain()); + incompleteSctpMessages.put(streamIdentifier, frag); + } } else { - //first incomplete message - fragments.put(streamIdentifier, byteBuf); + if (maxFragments <= frag.size()) { + throw new CodecException("Too many fragments for sctp message: " + maxFragments); + } + frag.add(byteBuf.retain()); + if (isComplete) { + // Is complete so remove it. + incompleteSctpMessages.remove(streamIdentifier); + CompositeByteBuf composite = ctx.alloc().compositeBuffer(); + + for (int i = 0; i < frag.size(); i++) { + composite.addComponent(true, frag.get(i)); + } + // last message to complete + SctpMessage assembledMsg = new SctpMessage( + protocolIdentifier, + streamIdentifier, + isUnordered, + composite); + out.add(assembledMsg); + } } - byteBuf.retain(); } @Override public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - for (ByteBuf buffer: fragments.values()) { - buffer.release(); + for (List buffers: incompleteSctpMessages.values()) { + for (ByteBuf buffer: buffers) { + buffer.release(); + } } - fragments.clear(); + incompleteSctpMessages.clear(); super.handlerRemoved(ctx); } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + super.exceptionCaught(ctx, cause); + ctx.close(); + } } diff --git a/transport-sctp/src/test/java/io/netty/handler/codec/sctp/SctpMessageCompletionHandlerTest.java b/transport-sctp/src/test/java/io/netty/handler/codec/sctp/SctpMessageCompletionHandlerTest.java index 99ebe94f8ee..299811c3a39 100644 --- a/transport-sctp/src/test/java/io/netty/handler/codec/sctp/SctpMessageCompletionHandlerTest.java +++ b/transport-sctp/src/test/java/io/netty/handler/codec/sctp/SctpMessageCompletionHandlerTest.java @@ -21,13 +21,18 @@ import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.sctp.SctpMessage; +import io.netty.handler.codec.CodecException; import io.netty.util.SuppressForbidden; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; import java.net.SocketAddress; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; public class SctpMessageCompletionHandlerTest { @@ -47,8 +52,68 @@ public void testFragmentsReleased() { assertEquals(0, buffer2.refCnt()); } + @Test + public void testIncompleteMessagesLimited() { + final EmbeddedChannel channel = new EmbeddedChannel(new SctpMessageCompletionHandler(1, 2)); + ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 }); + ByteBuf buffer2 = Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 }); + SctpMessage message = new SctpMessage(new TestMessageInfo(false, 1), buffer); + assertFalse(channel.writeInbound(message)); + assertEquals(1, buffer.refCnt()); + final SctpMessage message2 = new SctpMessage(new TestMessageInfo(false, 2), buffer2); + assertThrows(CodecException.class, new Executable() { + @Override + public void execute() throws Throwable { + channel.writeInbound(message2); + } + }); + // exceptionCaught closes the channel, triggering handlerRemoved which releases all buffered fragments + assertEquals(0, buffer.refCnt()); + assertEquals(0, buffer2.refCnt()); + assertFalse(channel.finish()); + } + + @Test + public void testFragmentsLimited() { + final EmbeddedChannel channel = new EmbeddedChannel(new SctpMessageCompletionHandler(1, 2)); + ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 }); + ByteBuf buffer2 = Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 }); + final ByteBuf buffer3 = Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 }); + + assertFalse(channel.writeInbound(new SctpMessage(new TestMessageInfo(false, 1), buffer))); + assertEquals(1, buffer.refCnt()); + + assertFalse(channel.writeInbound(new SctpMessage(new TestMessageInfo(false, 1), buffer2))); + assertEquals(1, buffer2.refCnt()); + + assertThrows(CodecException.class, new Executable() { + @Override + public void execute() throws Throwable { + channel.writeInbound(new SctpMessage(new TestMessageInfo(true, 1), buffer3)); + } + }); + assertEquals(0, buffer3.refCnt()); + + assertFalse(channel.finish()); + assertEquals(0, buffer.refCnt()); + assertEquals(0, buffer2.refCnt()); + } + + @Test + public void testNotFragmented() { + EmbeddedChannel channel = new EmbeddedChannel(new SctpMessageCompletionHandler()); + ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 }); + SctpMessage message = new SctpMessage(new TestMessageInfo(true, 1), buffer); + assertTrue(channel.writeInbound(message)); + SctpMessage read = channel.readInbound(); + assertSame(message, read); + assertEquals(1, read.refCnt()); + read.release(); + assertFalse(channel.finish()); + } + @SuppressForbidden(reason = "test-only") - private final class TestMessageInfo extends MessageInfo { + private static final class TestMessageInfo extends MessageInfo { private final boolean complete; private final int streamNumber; diff --git a/transport-udt/pom.xml b/transport-udt/pom.xml index f2dcfb7498e..0945fddd843 100644 --- a/transport-udt/pom.xml +++ b/transport-udt/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport-udt diff --git a/transport/pom.xml b/transport/pom.xml index f9e2a81e2df..06695dd7da2 100644 --- a/transport/pom.xml +++ b/transport/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.128.1.dse + 4.1.135.1.dse netty-transport diff --git a/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java b/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java index f319944318d..d82024fe976 100644 --- a/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java +++ b/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java @@ -32,6 +32,7 @@ import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.SocketUtils; import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.logging.InternalLogger; import java.net.InetAddress; @@ -52,6 +53,9 @@ * transports such as datagram (UDP).

    */ public abstract class AbstractBootstrap, C extends Channel> implements Cloneable { + + private static final boolean CLOSE_ON_SET_OPTION_FAILURE = SystemPropertyUtil.getBoolean( + "io.netty.bootstrap.closeOnSetOptionFailure", true); @SuppressWarnings("unchecked") private static final Map.Entry, Object>[] EMPTY_OPTION_ARRAY = new Map.Entry[0]; @SuppressWarnings("unchecked") @@ -357,7 +361,7 @@ final ChannelFuture initAndRegister() { return regFuture; } - abstract void init(Channel channel) throws Exception; + abstract void init(Channel channel) throws Throwable; Collection getInitializerExtensions() { ClassLoader loader = extensionsClassLoader; @@ -474,7 +478,7 @@ static void setAttributes(Channel channel, Map.Entry, Object>[] } static void setChannelOptions( - Channel channel, Map.Entry, Object>[] options, InternalLogger logger) { + Channel channel, Map.Entry, Object>[] options, InternalLogger logger) throws Throwable { for (Map.Entry, Object> e: options) { setChannelOption(channel, e.getKey(), e.getValue(), logger); } @@ -482,7 +486,7 @@ static void setChannelOptions( @SuppressWarnings("unchecked") private static void setChannelOption( - Channel channel, ChannelOption option, Object value, InternalLogger logger) { + Channel channel, ChannelOption option, Object value, InternalLogger logger) throws Throwable { try { if (!channel.config().setOption((ChannelOption) option, value)) { logger.warn("Unknown channel option '{}' for channel '{}' of type '{}'", @@ -492,6 +496,10 @@ private static void setChannelOption( logger.warn( "Failed to set channel option '{}' with value '{}' for channel '{}' of type '{}'", option, value, channel, channel.getClass(), t); + if (CLOSE_ON_SET_OPTION_FAILURE) { + // Only rethrow if we want to close the channel in case of a failure. + throw t; + } } } diff --git a/transport/src/main/java/io/netty/bootstrap/Bootstrap.java b/transport/src/main/java/io/netty/bootstrap/Bootstrap.java index cfba85fe31c..5c71a02d8a1 100644 --- a/transport/src/main/java/io/netty/bootstrap/Bootstrap.java +++ b/transport/src/main/java/io/netty/bootstrap/Bootstrap.java @@ -271,11 +271,12 @@ public void run() { } @Override - void init(Channel channel) { + void init(Channel channel) throws Throwable { ChannelPipeline p = channel.pipeline(); p.addLast(config.handler()); setChannelOptions(channel, newOptionsArray(), logger); + setAttributes(channel, newAttributesArray()); Collection extensions = getInitializerExtensions(); if (!extensions.isEmpty()) { diff --git a/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java b/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java index c8a17fc06f4..b3e14c4e715 100644 --- a/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java +++ b/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java @@ -132,7 +132,7 @@ public ServerBootstrap childHandler(ChannelHandler childHandler) { } @Override - void init(Channel channel) { + void init(Channel channel) throws Throwable { setChannelOptions(channel, newOptionsArray(), logger); setAttributes(channel, newAttributesArray()); @@ -227,7 +227,12 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { child.pipeline().addLast(childHandler); - setChannelOptions(child, childOptions, logger); + try { + setChannelOptions(child, childOptions, logger); + } catch (Throwable cause) { + forceClose(child, cause); + return; + } setAttributes(child, childAttrs); if (!extensions.isEmpty()) { diff --git a/transport/src/main/java/io/netty/channel/ChannelInitializer.java b/transport/src/main/java/io/netty/channel/ChannelInitializer.java index 61d91124cc5..0681b329c70 100644 --- a/transport/src/main/java/io/netty/channel/ChannelInitializer.java +++ b/transport/src/main/java/io/netty/channel/ChannelInitializer.java @@ -128,8 +128,8 @@ private boolean initChannel(ChannelHandlerContext ctx) throws Exception { try { initChannel((C) ctx.channel()); } catch (Throwable cause) { - // Explicitly call exceptionCaught(...) as we removed the handler before calling initChannel(...). - // We do so to prevent multiple calls to initChannel(...). + // Explicitly route the failure into the pipeline. Re-entrance is guarded by + // the initMap.add(ctx) check above; the finally block below removes the handler. exceptionCaught(ctx, cause); } finally { if (!ctx.isRemoved()) { diff --git a/transport/src/main/java/io/netty/channel/local/LocalChannel.java b/transport/src/main/java/io/netty/channel/local/LocalChannel.java index 8e5f9c50c96..82b50da57c3 100644 --- a/transport/src/main/java/io/netty/channel/local/LocalChannel.java +++ b/transport/src/main/java/io/netty/channel/local/LocalChannel.java @@ -78,6 +78,13 @@ public void run() { } }; + private final Runnable finishReadTask = new Runnable() { + @Override + public void run() { + finishPeerRead0(LocalChannel.this); + } + }; + private volatile State state; private volatile LocalChannel peer; private volatile LocalAddress localAddress; @@ -418,21 +425,19 @@ private void finishPeerRead(final LocalChannel peer) { } } - private void runFinishPeerReadTask(final LocalChannel peer) { + private void runFinishTask0() { // If the peer is writing, we must wait until after reads are completed for that peer before we can read. So // we keep track of the task, and coordinate later that our read can't happen until the peer is done. - final Runnable finishPeerReadTask = new Runnable() { - @Override - public void run() { - finishPeerRead0(peer); - } - }; + if (writeInProgress) { + finishReadFuture = eventLoop().submit(finishReadTask); + } else { + eventLoop().execute(finishReadTask); + } + } + + private void runFinishPeerReadTask(final LocalChannel peer) { try { - if (peer.writeInProgress) { - peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask); - } else { - peer.eventLoop().execute(finishPeerReadTask); - } + peer.runFinishTask0(); } catch (Throwable cause) { logger.warn("Closing Local channels {}-{} because exception occurred!", this, peer, cause); close(); @@ -482,7 +487,6 @@ public void connect(final SocketAddress remoteAddress, if (state == State.CONNECTED) { Exception cause = new AlreadyConnectedException(); safeSetFailure(promise, cause); - pipeline().fireExceptionCaught(cause); return; } diff --git a/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java b/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java index bfa2bae8459..4927a337494 100644 --- a/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java +++ b/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java @@ -30,6 +30,7 @@ import io.netty.channel.socket.ChannelInputShutdownReadComplete; import io.netty.channel.socket.SocketChannelConfig; import io.netty.util.internal.StringUtil; +import io.netty.util.internal.ThrowableUtil; import java.io.IOException; import java.nio.channels.SelectableChannel; @@ -115,7 +116,11 @@ private void handleReadException(ChannelPipeline pipeline, ByteBuf byteBuf, Thro if (byteBuf != null) { if (byteBuf.isReadable()) { readPending = false; - pipeline.fireChannelRead(byteBuf); + try { + pipeline.fireChannelRead(byteBuf); + } catch (Exception e) { + ThrowableUtil.addSuppressed(cause, e); + } } else { byteBuf.release(); } diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java index a361f3dc208..9f103b8bf35 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java @@ -555,7 +555,7 @@ public ChannelFuture block( try { key.block(sourceToBlock); } catch (IOException e) { - promise.setFailure(e); + return promise.setFailure(e); } } } diff --git a/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java b/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java index f788899493f..177538bf352 100644 --- a/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java +++ b/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java @@ -85,6 +85,29 @@ public static void destroy() { groupB.terminationFuture().syncUninterruptibly(); } + @Test + public void testSetOptionsThrow() { + final ChannelFuture cf = new Bootstrap() + .group(groupA) + .channelFactory(new ChannelFactory() { + @Override + public Channel newChannel() { + return new TestChannel(); + } + }) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 4242) + .handler(new ChannelInboundHandlerAdapter()) + .register(); + + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() throws Throwable { + cf.syncUninterruptibly(); + } + }); + assertFalse(cf.channel().isActive()); + } + @Test public void testOptionsCopied() { final Bootstrap bootstrapA = new Bootstrap(); @@ -578,4 +601,5 @@ public void run() { }; } } + } diff --git a/transport/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java b/transport/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java index 36ed66cbc5a..c2376f7f159 100644 --- a/transport/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java +++ b/transport/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java @@ -16,6 +16,8 @@ package io.netty.bootstrap; import io.netty.channel.Channel; +import io.netty.channel.ChannelFactory; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -24,6 +26,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalEventLoopGroup; @@ -31,6 +34,7 @@ import io.netty.util.AttributeKey; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; import java.util.UUID; import java.util.concurrent.Callable; @@ -40,12 +44,43 @@ import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class ServerBootstrapTest { + @Test + public void testSetOptionsThrow() { + LocalEventLoopGroup group = new LocalEventLoopGroup(1); + try { + final ChannelFuture cf = new ServerBootstrap() + .group(group) + .channelFactory(new ChannelFactory() { + @Override + public ServerChannel newChannel() { + return new TestServerChannel(); + } + }) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 4242) + .handler(new ChannelInboundHandlerAdapter()) + .childHandler(new ChannelInboundHandlerAdapter()) + .register(); + + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() throws Throwable { + cf.syncUninterruptibly(); + } + }); + assertFalse(cf.channel().isActive()); + } finally { + group.shutdownGracefully(); + } + } + @Test @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) public void testHandlerRegister() throws Exception { @@ -240,4 +275,6 @@ public Object call() throws Exception { clientChannel.close().syncUninterruptibly(); group.shutdownGracefully(); } + + private static final class TestServerChannel extends TestChannel implements ServerChannel { } } diff --git a/transport/src/test/java/io/netty/bootstrap/TestChannel.java b/transport/src/test/java/io/netty/bootstrap/TestChannel.java new file mode 100644 index 00000000000..d654d36fffe --- /dev/null +++ b/transport/src/test/java/io/netty/bootstrap/TestChannel.java @@ -0,0 +1,124 @@ +/* + * Copyright 2025 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.EventLoop; + +import java.net.SocketAddress; + +class TestChannel extends AbstractChannel { + private static final ChannelMetadata METADATA = new ChannelMetadata(false); + private final ChannelConfig config; + private volatile boolean closed; + + TestChannel() { + this(null); + } + + TestChannel(Channel parent) { + super(parent); + config = new TestConfig(this); + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new AbstractUnsafe() { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + promise.setSuccess(); + } + }; + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return true; + } + + @Override + protected SocketAddress localAddress0() { + return null; + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doBind(SocketAddress localAddress) { + // NOOP + } + + @Override + protected void doDisconnect() { + closed = true; + } + + @Override + protected void doClose() { + closed = true; + } + + @Override + protected void doBeginRead() { + // NOOP + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) { + // NOOP + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return !closed; + } + + @Override + public boolean isActive() { + return !closed; + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + private static final class TestConfig extends DefaultChannelConfig { + TestConfig(Channel channel) { + super(channel); + } + + @Override + public boolean setOption(ChannelOption option, T value) { + throw new UnsupportedOperationException("Unsupported channel option: " + option); + } + } +} diff --git a/transport/src/test/java/io/netty/channel/CompleteChannelFutureTest.java b/transport/src/test/java/io/netty/channel/CompleteChannelFutureTest.java index 3c0378b849c..9df99697701 100644 --- a/transport/src/test/java/io/netty/channel/CompleteChannelFutureTest.java +++ b/transport/src/test/java/io/netty/channel/CompleteChannelFutureTest.java @@ -44,7 +44,7 @@ public void shouldNotDoAnythingOnRemove() { ChannelFutureListener l = Mockito.mock(ChannelFutureListener.class); future.removeListener(l); Mockito.verifyNoMoreInteractions(l); - Mockito.verifyZeroInteractions(channel); + Mockito.verifyNoInteractions(channel); } @Test @@ -60,7 +60,7 @@ public void testConstantProperties() throws InterruptedException { assertSame(future, future.awaitUninterruptibly()); assertTrue(future.awaitUninterruptibly(1)); assertTrue(future.awaitUninterruptibly(1, TimeUnit.NANOSECONDS)); - Mockito.verifyZeroInteractions(channel); + Mockito.verifyNoInteractions(channel); } private static class CompleteChannelFutureImpl extends CompleteChannelFuture { diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index fa86dd10a92..6182325dc82 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -47,6 +47,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.net.SocketAddress; import java.util.ArrayDeque; @@ -449,6 +451,183 @@ public void channelRegistered(ChannelHandlerContext ctx) { assertTrue(latch.await(2, TimeUnit.SECONDS)); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testInboundOperationsViaContext(boolean inEventLoop) throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + final ChannelHandler handler = new ChannelHandlerAdapter() { }; + pipeline.addLast(handler); + group.register(pipeline.channel()).syncUninterruptibly(); + final BlockingQueue events = new LinkedBlockingQueue(); + pipeline.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + events.add("channelRegistered"); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) { + events.add("channelUnregistered"); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + events.add("channelActive"); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + events.add("channelInactive"); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + events.add("channelRead"); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + events.add("channelReadComplete"); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + events.add("userEventTriggered"); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + events.add("channelWritabilityChanged"); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + events.add("exceptionCaught"); + } + }); + final ChannelHandlerContext ctx = pipeline.context(handler); + if (inEventLoop) { + pipeline.channel().eventLoop().execute(new Runnable() { + @Override + public void run() { + executeInboundOperations(ctx); + } + }); + } else { + executeInboundOperations(ctx); + } + + assertEquals("channelRegistered", events.take()); + assertEquals("channelUnregistered", events.take()); + assertEquals("channelActive", events.take()); + assertEquals("channelInactive", events.take()); + assertEquals("channelRead", events.take()); + assertEquals("channelReadComplete", events.take()); + assertEquals("userEventTriggered", events.take()); + assertEquals("channelWritabilityChanged", events.take()); + assertEquals("exceptionCaught", events.take()); + assertTrue(events.isEmpty()); + pipeline.removeLast(); + pipeline.channel().close().syncUninterruptibly(); + } + + private static void executeInboundOperations(ChannelHandlerContext ctx) { + ctx.fireChannelRegistered(); + ctx.fireChannelUnregistered(); + ctx.fireChannelActive(); + ctx.fireChannelInactive(); + ctx.fireChannelRead(""); + ctx.fireChannelReadComplete(); + ctx.fireUserEventTriggered(""); + ctx.fireChannelWritabilityChanged(); + ctx.fireExceptionCaught(new Exception()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testOutboundOperationsViaContext(boolean inEventLoop) throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + final ChannelHandler handler = new ChannelHandlerAdapter() { }; + pipeline.addLast(handler); + group.register(pipeline.channel()).syncUninterruptibly(); + final BlockingQueue events = new LinkedBlockingQueue(); + pipeline.addFirst(new ChannelOutboundHandlerAdapter() { + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { + events.add("bind"); + promise.setSuccess(); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) { + events.add("connect"); + promise.setSuccess(); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + events.add("close"); + promise.setSuccess(); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + events.add("deregister"); + promise.setSuccess(); + } + + @Override + public void read(ChannelHandlerContext ctx) { + events.add("read"); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + events.add("write"); + promise.setSuccess(); + } + + @Override + public void flush(ChannelHandlerContext ctx) { + events.add("flush"); + ctx.flush(); + } + }); + final ChannelHandlerContext ctx = pipeline.context(handler); + if (inEventLoop) { + pipeline.channel().eventLoop().execute(new Runnable() { + @Override + public void run() { + executeOutboundOperations(ctx); + } + }); + } else { + executeOutboundOperations(ctx); + } + + assertEquals("bind", events.take()); + assertEquals("connect", events.take()); + assertEquals("close", events.take()); + assertEquals("deregister", events.take()); + assertEquals("read", events.take()); + assertEquals("write", events.take()); + assertEquals("flush", events.take()); + assertTrue(events.isEmpty()); + pipeline.removeFirst(); + pipeline.channel().close().syncUninterruptibly(); + } + + private static void executeOutboundOperations(ChannelHandlerContext ctx) { + ctx.bind(new SocketAddress() { }); + ctx.connect(new SocketAddress() { }); + ctx.close(); + ctx.deregister(); + ctx.read(); + ctx.write(""); + ctx.flush(); + } + @Test public void testPipelineOperation() { ChannelPipeline pipeline = new LocalChannel().pipeline(); diff --git a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java index fdb10e5c0c7..f733748809e 100644 --- a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java +++ b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java @@ -46,6 +46,7 @@ import org.junit.jupiter.api.function.Executable; import java.net.ConnectException; +import java.nio.channels.AlreadyConnectedException; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -881,6 +882,48 @@ public void execute() { }); } + @Test + public void testConnectedAlready() throws Exception { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + final AtomicReference causeRef = new AtomicReference(); + cb.group(group1) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + causeRef.set(cause); + } + }); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new TestHandler()); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).sync().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).sync().channel(); + + ChannelFuture f = cc.connect(sc.localAddress()).awaitUninterruptibly(); + assertInstanceOf(AlreadyConnectedException.class, f.cause()); + cc.close().syncUninterruptibly(); + assertNull(causeRef.get()); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener { private LatchChannelFutureListener(int count) { super(count); diff --git a/transport/src/test/java/io/netty/channel/socket/nio/NioSocketChannelTest.java b/transport/src/test/java/io/netty/channel/socket/nio/NioSocketChannelTest.java index b94618f6518..59b26590b99 100644 --- a/transport/src/test/java/io/netty/channel/socket/nio/NioSocketChannelTest.java +++ b/transport/src/test/java/io/netty/channel/socket/nio/NioSocketChannelTest.java @@ -35,6 +35,8 @@ import io.netty.util.CharsetUtil; import io.netty.util.NetUtil; import io.netty.util.internal.PlatformDependent; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ThrowableUtil; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -165,13 +167,13 @@ public void operationComplete(ChannelFuture future) throws Exception { // Test for https://github.com/netty/netty/issues/4805 @Test - @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + @Timeout(30) public void testChannelReRegisterReadSameEventLoop() throws Exception { testChannelReRegisterRead(true); } @Test - @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + @Timeout(30) public void testChannelReRegisterReadDifferentEventLoop() throws Exception { testChannelReRegisterRead(false); } @@ -179,6 +181,7 @@ public void testChannelReRegisterReadDifferentEventLoop() throws Exception { private static void testChannelReRegisterRead(final boolean sameEventLoop) throws Exception { final EventLoopGroup group = new NioEventLoopGroup(2); final CountDownLatch latch = new CountDownLatch(1); + final Promise eventLoopCheck = group.next().newPromise(); // Just some random bytes byte[] bytes = new byte[1024]; @@ -225,8 +228,17 @@ private void deregister(ChannelHandlerContext ctx, final EventLoop loop) { @Override public void operationComplete(ChannelFuture cf) { Channel channel = cf.channel(); - assertNotSame(loop, channel.eventLoop()); - group.next().register(channel); + Throwable cause = cf.cause(); + if (loop == channel.eventLoop()) { + AssertionError err = new AssertionError("Got same event loop: " + loop); + ThrowableUtil.addSuppressed(err, cause); + eventLoopCheck.tryFailure(err); + } else if (cause != null) { + eventLoopCheck.tryFailure(new AssertionError(cause)); + } else { + eventLoopCheck.trySuccess(null); + group.next().register(channel); + } } }); } @@ -242,6 +254,7 @@ public void operationComplete(ChannelFuture cf) { cc = bootstrap.connect(sc.localAddress()).syncUninterruptibly().channel(); cc.writeAndFlush(Unpooled.wrappedBuffer(bytes)).syncUninterruptibly(); latch.await(); + eventLoopCheck.sync(); } finally { if (cc != null) { cc.close(); @@ -249,13 +262,13 @@ public void operationComplete(ChannelFuture cf) { if (sc != null) { sc.close(); } - group.shutdownGracefully(); + group.shutdownGracefully().sync(); } } @Test - @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) - public void testShutdownOutputAndClose() throws IOException { + @Timeout(30) + public void testShutdownOutputAndClose() throws Exception { NioEventLoopGroup group = new NioEventLoopGroup(1); ServerSocket socket = new ServerSocket(); socket.bind(new InetSocketAddress(0)); @@ -285,7 +298,7 @@ public void testShutdownOutputAndClose() throws IOException { } catch (IOException ignore) { // ignore } - group.shutdownGracefully(); + group.shutdownGracefully().sync(); } }