
Fix AssertionError during eval when val set size is not divisible by train_batch_size #1589

Open
rishithayenumula wants to merge 5 commits into NovaSky-AI:main from rishithayenumula:fix-eval-batch-assertion

Conversation

@rishithayenumula

@rishithayenumula rishithayenumula commented Apr 29, 2026

Problem

compute_prompt_mini_batch_boundaries assumes that all batches have size equal to train_batch_size, which holds during training (drop_last=True) but not during evaluation (drop_last=False).

During evaluation, the final batch can be smaller, leading to an AssertionError when convert_to_training_input is invoked on partial batches.
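The failure mode can be sketched with simple arithmetic (the sizes below are illustrative, not taken from the PR):

```python
# Minimal sketch of the failure mode; val_size and train_batch_size are
# illustrative values, not the actual SkyRL configuration.
val_size, train_batch_size = 1000, 64

# With drop_last=False, the eval loader emits 15 full batches of 64,
# then a final partial batch of 1000 % 64 == 40 prompts.
last_batch = val_size % train_batch_size or train_batch_size
print(last_batch)  # 40

# The strict size check that is valid for training (drop_last=True)
# then raises on the final eval batch:
try:
    assert last_batch == train_batch_size  # the check this PR relaxes for eval
except AssertionError:
    print(f"eval crash: got {last_batch} prompts, expected {train_batch_size}")
```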


Solution

  • Added an is_training flag to compute_prompt_mini_batch_boundaries
  • Training mode (is_training=True):
    • Retains strict assertion to enforce full batches
  • Evaluation mode (is_training=False):
    • Allows partial batches
    • Uses num_prompts instead of train_batch_size for boundary calculations
    • Logs a warning for partial batches
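A minimal sketch of the patched helper, assuming it returns contiguous `(start, end)` index pairs per mini-batch; the signature and body below are reconstructed from this description, not copied from the SkyRL source:

```python
# Hypothetical sketch of the patched helper; names, signature, and the
# mini_batch_size parameter are assumptions based on the PR description.
import logging

logger = logging.getLogger(__name__)


def compute_prompt_mini_batch_boundaries(
    num_prompts: int,
    train_batch_size: int,
    mini_batch_size: int,
    is_training: bool = True,
) -> list[tuple[int, int]]:
    """Split a batch of prompts into contiguous (start, end) mini-batch spans."""
    if is_training:
        # Training uses drop_last=True, so every batch must be full.
        assert num_prompts == train_batch_size, (
            f"expected {train_batch_size} prompts, got {num_prompts}"
        )
        effective = train_batch_size
    else:
        # Evaluation uses drop_last=False; the final batch may be partial.
        if num_prompts != train_batch_size:
            logger.info(
                "Partial batch during eval: %d prompts (train_batch_size=%d)",
                num_prompts,
                train_batch_size,
            )
        effective = num_prompts

    return [
        (start, min(start + mini_batch_size, effective))
        for start in range(0, effective, mini_batch_size)
    ]
```

The default `is_training=True` keeps existing training call sites unchanged; only eval paths need to opt in to partial-batch handling.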

Changes

  • Added is_training flag to compute_prompt_mini_batch_boundaries
  • Updated boundary logic to use num_prompts
  • Preserved strict validation during training
  • Added tests for evaluation with partial batches
  • Maintained backward compatibility (default is_training=True)

Result

  • Prevents crashes during evaluation when validation set size is not divisible by train_batch_size
  • Keeps training behavior unchanged
  • Improves robustness for custom evaluation pipelines

Fixes #1583

- Add is_training flag to compute_prompt_mini_batch_boundaries()
- Allow partial batches during evaluation
- Use num_prompts instead of train_batch_size in boundary calculations
- Keep strict validation during training for distributed correctness
- Add 4 comprehensive tests for eval partial batch scenarios
- Backward compatible: default is_training=True preserves training behavior
Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for partial batches during evaluation by adding an is_training flag to the mini-batch boundary computation logic, preventing crashes when validation sets are not perfectly divisible by the batch size. Feedback identifies a NameError in the new test cases where assertions were incorrectly moved, and suggests restoring jaxtyping annotations for better documentation. Additionally, it is recommended to lower the logging level for partial batch detection during evaluation to reduce output noise.

Comment thread tests/train/test_prompt_mini_batch.py Outdated
Comment on lines 281 to 282
# Non-step-wise boundaries should be uniform
assert non_stepwise_bounds == [(0, 640), (640, 1280)]
Contributor


high

The last two lines of test_same_step_count_as_non_stepwise were accidentally moved into the body of the new test_eval_stepwise_partial_batch method. This will cause a NameError when running the tests because non_stepwise_bounds is not defined within the scope of the new test. These lines should be moved back to the end of their original test function.


import torch
from jaxtyping import Float, Integer
from transformers import AutoTokenizer
Contributor


medium

The jaxtyping imports and associated type annotations were removed in this file. These annotations provide valuable documentation regarding tensor shapes and dtypes, which is particularly helpful in complex batching logic. Unless there is a specific reason for their removal, they should be retained to maintain code clarity and type safety.

Suggested change
      from transformers import AutoTokenizer
+     from jaxtyping import Float, Integer

Comment thread skyrl/train/dataset/preprocess.py Outdated
Comment on lines +41 to +47
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
Contributor


medium

The detailed shape annotations for the return types were removed. It is recommended to keep these for better maintainability and readability.

Suggested change
-     torch.Tensor,
-     torch.Tensor,
-     torch.Tensor,
-     torch.Tensor,
-     torch.Tensor,
-     Optional[torch.Tensor],
-     Optional[torch.Tensor],
+     Float[torch.Tensor, "batch seq_len"],
+     Float[torch.Tensor, "batch seq_len"],
+     Float[torch.Tensor, "batch response_len"],
+     Float[torch.Tensor, "batch response_len"],
+     Float[torch.Tensor, "batch response_len"],
+     Optional[Float[torch.Tensor, "batch response_len"]],
+     Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]],

Comment thread skyrl/train/dataset/preprocess.py Outdated
Comment on lines +258 to +261
logger.warning(
f"Partial batch detected during eval: got {num_prompts} prompts but "
f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries."
)
Contributor


medium

Logging a warning for partial batches during evaluation is likely too noisy. When drop_last=False (standard for evaluation), the final batch is expected to be partial if the dataset size is not a multiple of the batch size. Consider using logger.info or removing the log entirely to avoid cluttering the output with expected behavior.

Suggested change
-     logger.warning(
+     logger.info(
          f"Partial batch detected during eval: got {num_prompts} prompts but "
          f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries."
      )

devin-ai-integration[bot]

This comment was marked as resolved.

@CharlieFRuan CharlieFRuan self-assigned this Apr 29, 2026


Development

Successfully merging this pull request may close these issues.

AssertionError crash on eval when val-set size is not a multiple of train_batch_size
