
[train] Support streaming mini-batch (non-blocking async training)#1607

Open
rishithayenumula wants to merge 6 commits into NovaSky-AI:main from rishithayenumula:feature/streaming-mini-batch

Conversation


@rishithayenumula rishithayenumula commented May 1, 2026

Summary

This PR introduces streaming mini-batch training in the fully async trainer by removing the blocking wait on the generation buffer.

Motivation

Currently, the trainer blocks while waiting for a full mini-batch using await buffer.get(), which limits overlap between generation and training.

This change enables training to start as soon as enough data is available, improving async pipeline utilization.

Key Changes

  • Replaced blocking await buffer.get() loop with non-blocking polling (qsize() + get_nowait())

  • Training now triggers once policy_mini_batch_size items are available

  • Preserved existing logic for:

    • UID tracking
    • staleness management
    • global step updates
    • weight synchronization

Behavior

  • Trainer no longer idles waiting for full batch
  • Better overlap between generation and training
  • Enables streaming-style execution
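The polling pattern the PR describes (qsize() plus get_nowait() instead of a blocking await buffer.get()) can be sketched minimally as follows. This is an illustrative sketch with generic names, not the trainer's code; the actual trainer collects generation output groups and triggers on policy_mini_batch_size:

```python
import asyncio


async def collect_mini_batch(buffer: asyncio.Queue, mini_batch_size: int) -> list:
    """Poll until enough items are buffered, then drain without waiting.

    Sketch of the qsize() + get_nowait() pattern described in this PR.
    """
    while buffer.qsize() < mini_batch_size:
        # Yield control so generation workers keep running on the event loop.
        await asyncio.sleep(0.01)
    batch = []
    for _ in range(mini_batch_size):
        batch.append(buffer.get_nowait())  # cannot raise: size was just checked
    return batch


async def main():
    buf = asyncio.Queue()
    for i in range(4):
        buf.put_nowait(i)
    batch = await collect_mini_batch(buf, mini_batch_size=3)
    print(batch)  # [0, 1, 2]; one item remains buffered for the next batch


asyncio.run(main())
```

Note that the review comments below argue a plain `await buffer.get()` loop achieves the same event-loop-friendly behavior without the polling latency.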

Notes

  • May slightly increase policy lag, but it remains bounded
  • A full GSM8k run was not executed due to environment constraints, but the logic integrates cleanly with the existing pipeline

Related Issue

Closes #1204



- Add is_training flag to compute_prompt_mini_batch_boundaries()
- Allow partial batches during evaluation
- Use num_prompts instead of train_batch_size in boundary calculations
- Keep strict validation during training for distributed correctness
- Add 4 comprehensive tests for eval partial batch scenarios
- Backward compatible: default is_training=True preserves training behavior
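A hypothetical sketch of the boundary change these commits describe. The function name, the is_training flag, and the strict-during-training / partial-during-eval behavior come from the commit messages above; the body is illustrative, not the repository's implementation:

```python
def compute_prompt_mini_batch_boundaries(
    num_prompts: int, mini_batch_size: int, is_training: bool = True
) -> list:
    """Split num_prompts into (start, end) mini-batch index ranges.

    Training keeps strict validation (even division is required for
    distributed correctness); evaluation allows a trailing partial batch.
    """
    if is_training and num_prompts % mini_batch_size != 0:
        raise ValueError(
            f"num_prompts ({num_prompts}) must be divisible by "
            f"mini_batch_size ({mini_batch_size}) during training"
        )
    boundaries = []
    for start in range(0, num_prompts, mini_batch_size):
        boundaries.append((start, min(start + mini_batch_size, num_prompts)))
    return boundaries


# Evaluation tolerates a partial final batch:
print(compute_prompt_mini_batch_boundaries(10, 4, is_training=False))
# [(0, 4), (4, 8), (8, 10)]
```

With the default is_training=True, uneven inputs still raise, which is what makes the change backward compatible for training callers.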
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for partial batches during evaluation by adding an is_training flag to the mini-batch boundary computation logic, supported by new unit tests. However, several issues were identified: the .gitignore file was accidentally truncated, removing critical exclusion rules. In fully_async_trainer.py, the replacement of await queue.get() with a polling loop is an inefficient anti-pattern that should be reverted. Furthermore, the removal of a try...finally block in the generator worker loop risks crashing the trainer and leaking resources during task cancellation.

Comment thread .gitignore
Comment on lines +1 to +5
.venv/
.venv311/
__pycache__/
/wandb/
**/*.egg-info/
# hydra logs
/outputs/
/data/lcb

# MkDocs build output (generated during build)
docs/public/api-ref/

# Documentation cache
.doctrees/
.cache/
.pytest_cache/

# NOTE (sumanthrh): Don't add .env to gitignore. .env file when passed to uv is used to set env vars for each ray worker process.
# If it's in .gitignore then it won't be a part of the working directory shipped by uv and your env vars will not be set.
# This will just appear as a warning (silent failure) and you're gonna have a bad time.
# .env

# .env files inside directories can be ignored
/skyrl-gym/.env

/skyrl-gym/.venv

# build
/skyrl-gym/build
/skyrl-gym/dist

*.log
nohup.out
tensorboard_log/

# SQLite database files
*.db

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
!docs/lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Jupyter Notebook
.ipynb_checkpoints

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# MkDocs build output
site/

# IDEs and editors
.idea/
.vscode/

# OS generated files
.DS_Store
Thumbs.db

# Hydra outputs
outputs/

# Local artifacts
tinker.db

# Alembic - don't track pycache
tx/tinker/alembic/__pycache__/

# SQLite databases (tracked in git by default, but ignore if created locally)
*.db
*.db-journal
*.db-wal
*.db-shm
*.pyc
*.egg-info/ No newline at end of file

critical

The .gitignore file has been significantly truncated, removing over 100 lines of rules. This deletes critical ignores for environment variables (.env), build artifacts (outputs/, dist/), IDE settings (.vscode/, .idea/), and various cache directories. This appears to be an accidental change that should be reverted to prevent committing sensitive information or large binary artifacts to the repository.

Comment on lines +591 to +593
if "slot_acquired" in locals() and slot_acquired:
raise RuntimeError("Generation workers should only be cancelled when they finish running.")
return

high

The removal of the try...finally block around the RuntimeError means that this exception will now propagate and crash the trainer at the end of every epoch. When generator_tasks are cancelled (line 471), any worker currently generating will raise this RuntimeError, which is no longer suppressed. Furthermore, if a worker is cancelled after acquiring a slot but before finishing, the running count in the staleness_manager is leaked, which will cause the validate_state_at_epoch_end assertion to fail. The worker should instead release the slot gracefully upon cancellation.

Suggested change
- if "slot_acquired" in locals() and slot_acquired:
- raise RuntimeError("Generation workers should only be cancelled when they finish running.")
- return
+ if "slot_acquired" in locals() and slot_acquired:
+ await self._staleness_manager.on_rollout_rejected()
+ return
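The hazard described here, a cancelled worker leaking its acquired slot, is the standard case for try/finally around resource ownership in asyncio. A self-contained sketch; SlotCounter is a toy stand-in for the staleness manager's running-slot accounting, not the project's API:

```python
import asyncio


class SlotCounter:
    """Toy stand-in for running-slot accounting in a staleness manager."""

    def __init__(self):
        self.running = 0

    async def acquire(self):
        self.running += 1

    async def release(self):
        self.running -= 1


async def worker(slots: SlotCounter):
    slot_acquired = False
    try:
        await slots.acquire()
        slot_acquired = True
        await asyncio.sleep(3600)  # long-running generation; may be cancelled here
    finally:
        # Runs on CancelledError too, so the running count is never leaked.
        if slot_acquired:
            await slots.release()


async def main():
    slots = SlotCounter()
    task = asyncio.create_task(worker(slots))
    await asyncio.sleep(0)  # let the worker start and acquire its slot
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    print(slots.running)  # 0: the slot is released despite the cancellation


asyncio.run(main())
```

Raising a RuntimeError on cancellation instead, as the PR does, propagates out of the worker and leaves the count incremented, which is exactly the end-of-epoch assertion failure the reviewer predicts.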

Comment on lines +390 to +401
while generation_output_group_buffer.qsize() < self.mini_batch_size:
# Sleep briefly to avoid busy waiting while generation workers keep running.
await asyncio.sleep(0.01)
logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}")
for _ in range(self.mini_batch_size):
# We do finish-time FIFO here (not schedule-time FIFO)
cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
buffer_pbar.update(1)
buffer_pbar.set_postfix({"buffer qsize": generation_output_group_buffer.qsize()})
buffer_pbar.close()
try:
cur_generation_group_mini_batch.append(generation_output_group_buffer.get_nowait())
except asyncio.QueueEmpty as e:
raise AssertionError(
"Generation buffer unexpectedly drained while collecting a mini-batch."
) from e

medium

Replacing await queue.get() with a polling loop using qsize() and asyncio.sleep(0.01) is an anti-pattern in asynchronous programming. It introduces unnecessary latency (up to 10ms per check) and CPU overhead compared to the built-in synchronization of asyncio.Queue. The original implementation using await buffer.get() in a loop was already non-blocking for the event loop and more efficient, as it leverages the queue's internal notification system to wake up the task exactly when data is available. The motivation of 'removing the blocking wait' seems to be a misunderstanding of how await works in this context.

Suggested change
- while generation_output_group_buffer.qsize() < self.mini_batch_size:
- # Sleep briefly to avoid busy waiting while generation workers keep running.
- await asyncio.sleep(0.01)
- logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}")
- for _ in range(self.mini_batch_size):
- # We do finish-time FIFO here (not schedule-time FIFO)
- cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
- buffer_pbar.update(1)
- buffer_pbar.set_postfix({"buffer qsize": generation_output_group_buffer.qsize()})
- buffer_pbar.close()
- try:
- cur_generation_group_mini_batch.append(generation_output_group_buffer.get_nowait())
- except asyncio.QueueEmpty as e:
- raise AssertionError(
- "Generation buffer unexpectedly drained while collecting a mini-batch."
- ) from e
+ while len(cur_generation_group_mini_batch) < self.mini_batch_size:
+ cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
+ logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}")
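The reviewer's point, that `await queue.get()` suspends only the waiting task rather than the event loop, can be demonstrated with a minimal producer/consumer sketch (generic names, not the trainer's code):

```python
import asyncio


async def producer(q: asyncio.Queue):
    for i in range(3):
        await asyncio.sleep(0)  # simulate generation work between items
        await q.put(i)


async def consumer(q: asyncio.Queue, n: int) -> list:
    batch = []
    while len(batch) < n:
        # Suspends this task only; the producer keeps running on the same
        # event loop, and the queue wakes us the moment an item is available.
        batch.append(await q.get())
    return batch


async def main():
    q = asyncio.Queue()
    prod = asyncio.create_task(producer(q))
    batch = await consumer(q, 3)
    await prod
    print(batch)  # [0, 1, 2]


asyncio.run(main())
```

The consumer never spins or sleeps on a timer; the queue's internal notification delivers each item as it is produced, which is why the original `await buffer.get()` loop already overlaps generation with training.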

Contributor

@devin-ai-integration devin-ai-integration Bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 4 additional findings.



Development

Successfully merging this pull request may close these issues.

[train] Support streaming mini-batch (i.e. mini-batch-level overlapping async training)
