Fix AssertionError during eval when val set size is not divisible by train_batch_size #1589

rishithayenumula wants to merge 5 commits into NovaSky-AI:main from
Conversation
- Add `is_training` flag to `compute_prompt_mini_batch_boundaries()`
- Allow partial batches during evaluation
- Use `num_prompts` instead of `train_batch_size` in boundary calculations
- Keep strict validation during training for distributed correctness
- Add 4 comprehensive tests for eval partial batch scenarios
- Backward compatible: default `is_training=True` preserves training behavior
Code Review
This pull request introduces support for partial batches during evaluation by adding an is_training flag to the mini-batch boundary computation logic, preventing crashes when validation sets are not perfectly divisible by the batch size. Feedback identifies a NameError in the new test cases where assertions were incorrectly moved, and suggests restoring jaxtyping annotations for better documentation. Additionally, it is recommended to lower the logging level for partial batch detection during evaluation to reduce output noise.
```python
# Non-step-wise boundaries should be uniform
assert non_stepwise_bounds == [(0, 640), (640, 1280)]
```
The last two lines of test_same_step_count_as_non_stepwise were accidentally moved into the body of the new test_eval_stepwise_partial_batch method. This will cause a NameError when running the tests because non_stepwise_bounds is not defined within the scope of the new test. These lines should be moved back to the end of their original test function.
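The scoping bug described above can be reproduced in miniature — a name bound inside one test function is not visible from another, so the misplaced assertion fails at lookup time (test names follow the PR; the bodies below are illustrative stubs):

```python
def test_same_step_count_as_non_stepwise():
    # non_stepwise_bounds is a local variable of this function only
    non_stepwise_bounds = [(0, 640), (640, 1280)]
    assert non_stepwise_bounds == [(0, 640), (640, 1280)]

def test_eval_stepwise_partial_batch():
    # Referencing non_stepwise_bounds here raises NameError, because the
    # name was never defined in this function's (or the global) scope.
    try:
        non_stepwise_bounds  # noqa: F821
        raised = False
    except NameError:
        raised = True
    assert raised, "expected NameError: name not defined in this scope"

test_same_step_count_as_non_stepwise()
test_eval_stepwise_partial_batch()
```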
```python
import torch
from jaxtyping import Float, Integer
from transformers import AutoTokenizer
```
The jaxtyping imports and associated type annotations were removed in this file. These annotations provide valuable documentation regarding tensor shapes and dtypes, which is particularly helpful in complex batching logic. Unless there is a specific reason for their removal, they should be retained to maintain code clarity and type safety.
Suggested change:

```diff
  from transformers import AutoTokenizer
+ from jaxtyping import Float, Integer
```
```python
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
```
The detailed shape annotations for the return types were removed. It is recommended to keep these for better maintainability and readability.
Suggested change:

```diff
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- Optional[torch.Tensor],
- Optional[torch.Tensor],
+ Float[torch.Tensor, "batch seq_len"],
+ Float[torch.Tensor, "batch seq_len"],
+ Float[torch.Tensor, "batch response_len"],
+ Float[torch.Tensor, "batch response_len"],
+ Float[torch.Tensor, "batch response_len"],
+ Optional[Float[torch.Tensor, "batch response_len"]],
+ Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]],
```
```python
logger.warning(
    f"Partial batch detected during eval: got {num_prompts} prompts but "
    f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries."
)
```
Logging a warning for partial batches during evaluation is likely too noisy. When drop_last=False (standard for evaluation), the final batch is expected to be partial if the dataset size is not a multiple of the batch size. Consider using logger.info or removing the log entirely to avoid cluttering the output with expected behavior.
Suggested change:

```diff
- logger.warning(
+ logger.info(
      f"Partial batch detected during eval: got {num_prompts} prompts but "
      f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries."
  )
```
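The practical effect of the suggested downgrade: with the common WARNING-level threshold, `logger.info` messages are filtered out while `logger.warning` messages still appear. A minimal illustration (logger name and messages are made up):

```python
import io
import logging

logger = logging.getLogger("eval_batching")
stream = io.StringIO()
logger.addHandler(logging.StreamHandler(stream))
logger.setLevel(logging.WARNING)  # typical production threshold

logger.info("partial batch detected")        # filtered: below WARNING
logger.warning("something actually wrong")   # emitted

output = stream.getvalue()
```

So after the change, the expected-partial-batch message stays out of eval logs unless the user explicitly enables INFO-level logging.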
Problem

`compute_prompt_mini_batch_boundaries` assumes that all batches have size equal to `train_batch_size`, which holds during training (`drop_last=True`) but not during evaluation (`drop_last=False`). During evaluation, the final batch can be smaller, leading to an `AssertionError` when `convert_to_training_input` is invoked on partial batches.

Solution
- Add an `is_training` flag to `compute_prompt_mini_batch_boundaries`
- During training (`is_training=True`): keep strict batch-size validation for distributed correctness
- During evaluation (`is_training=False`): allow partial batches and use `num_prompts` instead of `train_batch_size` for boundary calculations

Changes
- Added an `is_training` flag to `compute_prompt_mini_batch_boundaries`
- Boundary calculations use `num_prompts` during evaluation
- Backward compatible: the default (`is_training=True`) preserves existing training behavior

Result

Evaluation no longer raises an `AssertionError` when the validation set size is not divisible by `train_batch_size`.
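The flag-based boundary logic described in this PR can be sketched as follows (a simplified reconstruction for illustration — the signature and internals of the actual NovaSky-AI function may differ):

```python
def compute_prompt_mini_batch_boundaries(
    num_prompts: int,
    train_batch_size: int,
    mini_batch_size: int,
    is_training: bool = True,
):
    """Return (start, end) index pairs covering the prompts in mini-batches."""
    if is_training:
        # Strict check: distributed training relies on every rank seeing
        # a uniform, full-size batch (drop_last=True guarantees this).
        assert num_prompts == train_batch_size, (
            f"expected {train_batch_size} prompts, got {num_prompts}"
        )
        total = train_batch_size
    else:
        # Eval uses drop_last=False, so the final batch may be partial;
        # size the boundaries by the prompts actually present.
        total = num_prompts
    return [
        (start, min(start + mini_batch_size, total))
        for start in range(0, total, mini_batch_size)
    ]

# Example: a 100-prompt val set with batch size 32 leaves a final batch
# of 4 prompts, which previously tripped the assertion.
eval_bounds = compute_prompt_mini_batch_boundaries(4, 32, 8, is_training=False)
# -> [(0, 4)]
```

The default `is_training=True` keeps the old strict path, which is what makes the change backward compatible for existing training callers.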