diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py index 30ce9194..68287e42 100644 --- a/scripts/prepare_hidden_states.py +++ b/scripts/prepare_hidden_states.py @@ -618,7 +618,7 @@ def main(): print_with_rank("Loading/building dataset cache...") dataset = Dataset.from_generator( generator=safe_conversations_generator, - gen_kwargs={"file_path": args.data_path}, + gen_kwargs={"file_path": args.data_path, "is_vlm": args.is_vlm}, cache_dir=os.path.join( os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "cache", diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 0bd157b3..ab262979 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -456,7 +456,7 @@ def build_dataloaders( cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() train_dataset = Dataset.from_generator( generator=safe_conversations_generator, - gen_kwargs={"file_path": args.train_data_path}, + gen_kwargs={"file_path": args.train_data_path, "is_vlm": args.is_vlm}, ) is_online = ( args.train_data_path is not None and args.train_hidden_states_path is None @@ -507,7 +507,7 @@ def build_dataloaders( if args.eval_data_path is not None: eval_dataset = Dataset.from_generator( generator=safe_conversations_generator, - gen_kwargs={"file_path": args.eval_data_path}, + gen_kwargs={"file_path": args.eval_data_path, "is_vlm": args.is_vlm}, ) eval_eagle3_dataset = build_eagle3_dataset( eval_dataset, diff --git a/specforge/utils.py b/specforge/utils.py index af4d627c..83f81faa 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -328,12 +328,15 @@ def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh): ) -def safe_conversations_generator(file_path): +def safe_conversations_generator(file_path, is_vlm=False): """ Generator that: 1. Extracts the 'conversations' field. 2. Preserves all original fields within each message. 3. [Key step] Converts all list/dict-type field values to strings to resolve mixed-type conflicts (e.g., for Arrow compatibility). + Args: + file_path: Path to the JSONL file. + is_vlm: If True, include 'image' field for vision-language models. Default False. """ with open(file_path, "r", encoding="utf-8") as f: for i, line in enumerate(f): @@ -376,8 +379,14 @@ def safe_conversations_generator(file_path): cleaned_convs.append(new_msg) - # Build result with conversations - result = {"conversations": cleaned_convs} + if is_vlm: + image = row.get("image", "") + result = { + "conversations": cleaned_convs, + "image": image, + } + else: + result = {"conversations": cleaned_convs} # Preserve 'tools' field if present if "tools" in row: