Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/prepare_hidden_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def main():
print_with_rank("Loading/building dataset cache...")
dataset = Dataset.from_generator(
generator=safe_conversations_generator,
gen_kwargs={"file_path": args.data_path},
gen_kwargs={"file_path": args.data_path, "is_vlm": args.is_vlm},
cache_dir=os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"cache",
Expand Down
4 changes: 2 additions & 2 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def build_dataloaders(
cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
train_dataset = Dataset.from_generator(
generator=safe_conversations_generator,
gen_kwargs={"file_path": args.train_data_path},
gen_kwargs={"file_path": args.train_data_path, "is_vlm": args.is_vlm},
)
is_online = (
args.train_data_path is not None and args.train_hidden_states_path is None
Expand Down Expand Up @@ -507,7 +507,7 @@ def build_dataloaders(
if args.eval_data_path is not None:
eval_dataset = Dataset.from_generator(
generator=safe_conversations_generator,
gen_kwargs={"file_path": args.eval_data_path},
gen_kwargs={"file_path": args.eval_data_path, "is_vlm": args.is_vlm},
)
eval_eagle3_dataset = build_eagle3_dataset(
eval_dataset,
Expand Down
15 changes: 12 additions & 3 deletions specforge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,15 @@ def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh):
)


def safe_conversations_generator(file_path):
def safe_conversations_generator(file_path, is_vlm=False):
"""
Generator that:
1. Extracts the 'conversations' field.
2. Preserves all original fields within each message.
3. [Key step] Converts all list/dict-type field values to strings to resolve mixed-type conflicts (e.g., for Arrow compatibility).

Args:
file_path: Path to the JSONL file.
is_vlm: If True, each yielded record also carries the row's 'image' field (an empty string when the row has none), for vision-language models. Defaults to False.
"""
with open(file_path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
Expand Down Expand Up @@ -376,8 +379,14 @@ def safe_conversations_generator(file_path):

cleaned_convs.append(new_msg)

# Build result with conversations
result = {"conversations": cleaned_convs}
if is_vlm:
image = row.get("image", "")
result = {
"conversations": cleaned_convs,
"image": image,
}
else:
result = {"conversations": cleaned_convs}

# Preserve 'tools' field if present
if "tools" in row:
Expand Down