[train] Support streaming mini-batch (non-blocking async training) #1607

rishithayenumula wants to merge 6 commits into NovaSky-AI:main from
Conversation
- Add `is_training` flag to `compute_prompt_mini_batch_boundaries()`
- Allow partial batches during evaluation
- Use `num_prompts` instead of `train_batch_size` in boundary calculations
- Keep strict validation during training for distributed correctness
- Add 4 comprehensive tests for eval partial-batch scenarios
- Backward compatible: default `is_training=True` preserves training behavior
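A rough sketch of the boundary logic described above. The real function lives in the trainer; the exact signature, parameter names, and error message here are illustrative, not the PR's actual code:

```python
def compute_prompt_mini_batch_boundaries(num_prompts, mini_batch_size, is_training=True):
    """Return (start, end) index pairs for each mini-batch of prompts.

    During training, num_prompts must divide evenly into mini-batches so that
    every rank sees the same number of batches (distributed correctness);
    during evaluation a trailing partial batch is allowed.
    """
    if is_training and num_prompts % mini_batch_size != 0:
        raise ValueError(
            f"num_prompts ({num_prompts}) must be divisible by "
            f"mini_batch_size ({mini_batch_size}) during training"
        )
    boundaries = []
    for start in range(0, num_prompts, mini_batch_size):
        # During eval the final boundary may be a partial batch.
        boundaries.append((start, min(start + mini_batch_size, num_prompts)))
    return boundaries
```

With `is_training=False`, `compute_prompt_mini_batch_boundaries(10, 4, is_training=False)` yields a trailing partial batch `(8, 10)`; with the default `is_training=True`, the same call raises, preserving the strict training-time validation.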
Code Review
This pull request introduces support for partial batches during evaluation by adding an `is_training` flag to the mini-batch boundary computation logic, supported by new unit tests. However, several issues were identified: the `.gitignore` file was accidentally truncated, removing critical exclusion rules. In `fully_async_trainer.py`, the replacement of `await queue.get()` with a polling loop is an inefficient anti-pattern that should be reverted. Furthermore, the removal of a `try...finally` block in the generator worker loop risks crashing the trainer and leaking resources during task cancellation.
```gitignore
.venv/
.venv311/
__pycache__/
/wandb/
**/*.egg-info/
# hydra logs
/outputs/
/data/lcb

# MkDocs build output (generated during build)
docs/public/api-ref/

# Documentation cache
.doctrees/
.cache/
.pytest_cache/

# NOTE (sumanthrh): Don't add .env to gitignore. .env file when passed to uv is used to set env vars for each ray worker process.
# If it's in .gitignore then it won't be a part of the working directory shipped by uv and your env vars will not be set.
# This will just appear as a warning (silent failure) and you're gonna have a bad time.
# .env

# .env files inside directories can be ignored
/skyrl-gym/.env

/skyrl-gym/.venv

# build
/skyrl-gym/build
/skyrl-gym/dist

*.log
nohup.out
tensorboard_log/

# SQLite database files
*.db

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
!docs/lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Jupyter Notebook
.ipynb_checkpoints

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# MkDocs build output
site/

# IDEs and editors
.idea/
.vscode/

# OS generated files
.DS_Store
Thumbs.db

# Hydra outputs
outputs/

# Local artifacts
tinker.db

# Alembic - don't track pycache
tx/tinker/alembic/__pycache__/

# SQLite databases (tracked in git by default, but ignore if created locally)
*.db
*.db-journal
*.db-wal
*.db-shm
*.pyc
*.egg-info/
```
The `.gitignore` file has been significantly truncated, removing over 100 lines of rules. This deletes critical ignores for environment variables (`.env`), build artifacts (`outputs/`, `dist/`), IDE settings (`.vscode/`, `.idea/`), and various cache directories. This appears to be an accidental change that should be reverted to prevent committing sensitive information or large binary artifacts to the repository.
```python
if "slot_acquired" in locals() and slot_acquired:
    raise RuntimeError("Generation workers should only be cancelled when they finish running.")
return
```
The removal of the `try...finally` block around the `RuntimeError` means that this exception will now propagate and crash the trainer at the end of every epoch. When `generator_tasks` are cancelled (line 471), any worker currently generating will raise this `RuntimeError`, which is no longer suppressed. Furthermore, if a worker is cancelled after acquiring a slot but before finishing, the running count in the `staleness_manager` is leaked, which will cause the `validate_state_at_epoch_end` assertion to fail. The worker should instead release the slot gracefully upon cancellation.
Suggested change:

```python
if "slot_acquired" in locals() and slot_acquired:
    await self._staleness_manager.on_rollout_rejected()
return
```
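A self-contained sketch of the graceful release the review asks for. The `FakeManager`-style `acquire_slot()` method and the `generate_once` callable are illustrative stand-ins; only `on_rollout_rejected()` comes from the suggestion above:

```python
import asyncio

async def generation_worker(staleness_manager, generate_once):
    """One generation attempt that returns its concurrency slot if cancelled mid-flight."""
    slot_acquired = False
    try:
        await staleness_manager.acquire_slot()  # hypothetical acquire API
        slot_acquired = True
        result = await generate_once()
        slot_acquired = False  # slot ownership passes on with the finished rollout
        return result
    except asyncio.CancelledError:
        if slot_acquired:
            # Return the slot so the staleness manager's running count stays
            # balanced and the epoch-end validation can still pass.
            await staleness_manager.on_rollout_rejected()
        raise
```

Catching `asyncio.CancelledError`, releasing the slot, and re-raising keeps cancellation semantics intact while avoiding the leaked running count the review describes.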
```python
while generation_output_group_buffer.qsize() < self.mini_batch_size:
    # Sleep briefly to avoid busy waiting while generation workers keep running.
    await asyncio.sleep(0.01)
logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}")
for _ in range(self.mini_batch_size):
    # We do finish-time FIFO here (not schedule-time FIFO)
    cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
    buffer_pbar.update(1)
    buffer_pbar.set_postfix({"buffer qsize": generation_output_group_buffer.qsize()})
buffer_pbar.close()
try:
    cur_generation_group_mini_batch.append(generation_output_group_buffer.get_nowait())
except asyncio.QueueEmpty as e:
    raise AssertionError(
        "Generation buffer unexpectedly drained while collecting a mini-batch."
    ) from e
```
Replacing `await queue.get()` with a polling loop using `qsize()` and `asyncio.sleep(0.01)` is an anti-pattern in asynchronous programming. It introduces unnecessary latency (up to 10ms per check) and CPU overhead compared to the built-in synchronization of `asyncio.Queue`. The original implementation using `await buffer.get()` in a loop was already non-blocking for the event loop and more efficient, as it leverages the queue's internal notification system to wake up the task exactly when data is available. The motivation of 'removing the blocking wait' seems to be a misunderstanding of how `await` works in this context.
Suggested change (restore the original `await`-based collection loop):

```python
while len(cur_generation_group_mini_batch) < self.mini_batch_size:
    cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}")
```
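A minimal, runnable illustration of the reviewer's point: `await queue.get()` suspends only the consuming coroutine, not the event loop, so producers keep running and the consumer wakes the moment an item arrives. All names here are illustrative, not the trainer's actual code:

```python
import asyncio

async def producer(queue, n):
    # Simulates generation workers finishing rollouts one at a time.
    for i in range(n):
        await asyncio.sleep(0.001)
        await queue.put(i)

async def collect_mini_batch(queue, mini_batch_size):
    # `await queue.get()` parks this coroutine on the queue's internal
    # waiter list; no polling, no sleep interval, no wasted wakeups.
    batch = []
    for _ in range(mini_batch_size):
        batch.append(await queue.get())
    return batch

async def main():
    queue = asyncio.Queue()
    prod = asyncio.create_task(producer(queue, 4))
    # The consumer starts before any item exists, yet never busy-waits.
    batch = await collect_mini_batch(queue, 4)
    await prod
    return batch

print(asyncio.run(main()))  # prints [0, 1, 2, 3]
```

The polling version would check `qsize()` every 10ms regardless of when items actually land; the `await`-based version is woken by the queue itself, which is both lower-latency and cheaper.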
Summary

This PR introduces streaming mini-batch training in the fully async trainer by removing the blocking wait on the generation buffer.

Motivation

Currently, the trainer blocks while waiting for a full mini-batch using `await buffer.get()`, which limits overlap between generation and training. This change enables training to start as soon as enough data is available, improving async pipeline utilization.

Key Changes

- Replaced the blocking `await buffer.get()` loop with non-blocking polling (`qsize()` + `get_nowait()`)
- Training now triggers once `policy_mini_batch_size` items are available
- Preserved existing logic for:

Behavior

Notes

Related Issue

Closes #1204