From 71eacb78194296378eb277e9897a233a3687c387 Mon Sep 17 00:00:00 2001 From: Qiong Wu Date: Mon, 8 Jun 2026 15:50:56 +0800 Subject: [PATCH] fix: pre-quantized models run optimize, only skip quantize - Pre-quantized (QDQ) models now run full optimize+autoconf+analyze, only skip quantize to preserve QDQ structure. - Cache is_quantized_onnx() result to avoid redundant ONNX file reads. - skip_optimize: restore original semantics (run optimize_onnx without autoconf loop, label as skipped). Labels now match actual work done. - Fix generate_build_config onnx_path bug for 'winml build -m model.onnx'. --- src/winml/modelkit/build/hf.py | 65 ++++++++++++++-------------- src/winml/modelkit/build/onnx.py | 59 +++++++++++++------------ src/winml/modelkit/commands/build.py | 5 +-- tests/unit/build/test_hf.py | 22 +++++----- tests/unit/build/test_onnx.py | 22 +++++----- 5 files changed, 85 insertions(+), 88 deletions(-) diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index eaae70ef4..fa1e4e7e8 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -244,44 +244,43 @@ def _name(base: str) -> str: # config.quant=None. This is_quantized_onnx() check is redundant in that # path but kept for backward compatibility when build_hf_model() # is called directly with a hand-built config. - is_pre_quantized = is_quantized_onnx(current_path) or skip_optimize + is_qdq = is_quantized_onnx(current_path) + is_pre_quantized = is_qdq or skip_optimize - if is_pre_quantized: + if is_qdq: logger.info( "Pre-quantized model detected (QDQ nodes present). " - "Skipping optimize + quantize, running analyze-only." + "Skipping quantize to preserve QDQ structure." ) + + # Build kwargs for run_optimize_analyze_loop. When skip_optimize is set, + # we still run optimize_onnx (applies config.optim flags) but skip the + # autoconf re-optimization loop to preserve the original behavior. + run_kwargs: dict = dict( + model_path=current_path, + optimized_path=optimized_path, + config=config, + ep=ep, + device=device, + **onnx_kwargs, + ) + if not skip_optimize: + run_kwargs["max_optim_iterations"] = hack_max_optim_iterations + run_kwargs["allow_unsupported_nodes"] = allow_unsupported_nodes + run_kwargs["analyze_output_path"] = analyze_result_path + + logger.info("Optimizing ONNX model...") + ( + current_path, + opt_elapsed, + analyze_iterations, + analyze_unsupported_nodes, + analyze_details, + ) = run_optimize_analyze_loop(**run_kwargs) + + if skip_optimize: stages_skipped.append("optimize") - # Optimize+analyze only, no autoconf re-optimization - current_path, _, analyze_iterations, analyze_unsupported_nodes, analyze_details = ( - run_optimize_analyze_loop( - model_path=current_path, - optimized_path=optimized_path, - config=config, - ep=ep, - device=device, - **onnx_kwargs, - ) - ) else: - logger.info("Optimizing ONNX model...") - ( - current_path, - opt_elapsed, - analyze_iterations, - analyze_unsupported_nodes, - analyze_details, - ) = run_optimize_analyze_loop( - model_path=current_path, - optimized_path=optimized_path, - config=config, - ep=ep, - device=device, - max_optim_iterations=hack_max_optim_iterations, - allow_unsupported_nodes=allow_unsupported_nodes, - analyze_output_path=analyze_result_path, - **onnx_kwargs, - ) stage_timings["optimize"] = opt_elapsed stages_completed.append("optimize") logger.info("Optimize done (%.1fs) -> %s", opt_elapsed, optimized_path) @@ -303,7 +302,7 @@ def _name(base: str) -> str: # Defensive fallback: catches the edge case where a direct caller # provides config.quant != None but the model already has QDQ nodes # (e.g., hand-built config without running generate_*_build_config). - if is_quantized_onnx(current_path): + if is_qdq: logger.warning( "Model already contains QDQ nodes, skipping quantization. " "Set config.quant=None to silence this warning." diff --git a/src/winml/modelkit/build/onnx.py b/src/winml/modelkit/build/onnx.py index 2e7424e99..ecc18eb46 100644 --- a/src/winml/modelkit/build/onnx.py +++ b/src/winml/modelkit/build/onnx.py @@ -141,7 +141,7 @@ def build_onnx_model( copy_onnx_model(onnx_path, current_path) # ========================================================================= - # [1] OPTIMIZE + ANALYZE (or ANALYZE-ONLY for pre-quantized) + # [1] OPTIMIZE + ANALYZE # FIXME: Stages [1]-[4] (optimize, quantize, compile, finalize) are # duplicated between build_onnx_model() and build_hf_model(). Extract # into a shared run_build_stages() function in common.py. @@ -152,40 +152,39 @@ def build_onnx_model( # config.quant=None. This is_quantized_onnx() check is redundant in that # path but kept for backward compatibility when build_onnx_model() # is called directly with a hand-built config. - is_pre_quantized = is_quantized_onnx(current_path) or skip_optimize + is_qdq = is_quantized_onnx(current_path) + is_pre_quantized = is_qdq or skip_optimize - if is_pre_quantized: + if is_qdq: logger.info( "Pre-quantized model detected (QDQ nodes present). " - "Skipping optimize + quantize, running analyze-only." + "Skipping quantize to preserve QDQ structure." ) + + # Build kwargs for run_optimize_analyze_loop. When skip_optimize is set, + # we still run optimize_onnx (applies config.optim flags) but skip the + # autoconf re-optimization loop to preserve the original behavior. + run_kwargs: dict = dict( + model_path=current_path, + optimized_path=optimized_path, + config=config, + ep=ep, + device=device, + **onnx_kwargs, + ) + if not skip_optimize: + run_kwargs["max_optim_iterations"] = hack_max_optim_iterations + run_kwargs["allow_unsupported_nodes"] = allow_unsupported_nodes + run_kwargs["analyze_output_path"] = analyze_result_path + + logger.info("Optimizing ONNX model...") + current_path, opt_elapsed, analyze_iters, analyze_unsupported, analyze_details = ( + run_optimize_analyze_loop(**run_kwargs) + ) + + if skip_optimize: stages_skipped.append("optimize") - # Optimize+analyze only, no autoconf re-optimization - current_path, _, analyze_iters, analyze_unsupported, analyze_details = ( - run_optimize_analyze_loop( - model_path=current_path, - optimized_path=optimized_path, - config=config, - ep=ep, - device=device, - **onnx_kwargs, - ) - ) else: - logger.info("Optimizing ONNX model...") - current_path, opt_elapsed, analyze_iters, analyze_unsupported, analyze_details = ( - run_optimize_analyze_loop( - model_path=current_path, - optimized_path=optimized_path, - config=config, - ep=ep, - device=device, - max_optim_iterations=hack_max_optim_iterations, - allow_unsupported_nodes=allow_unsupported_nodes, - analyze_output_path=analyze_result_path, - **onnx_kwargs, - ) - ) stage_timings["optimize"] = opt_elapsed stages_completed.append("optimize") logger.info("Optimize done (%.1fs) -> %s", opt_elapsed, optimized_path) @@ -207,7 +206,7 @@ def build_onnx_model( # Defensive fallback: catches the edge case where a direct caller # provides config.quant != None but the model already has QDQ nodes # (e.g., hand-built config without running generate_*_build_config). - if is_quantized_onnx(current_path): + if is_qdq: logger.warning( "Model already contains QDQ nodes, skipping quantization. " "Set config.quant=None to silence this warning." diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index 7fc735baf..38e88d31a 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -591,6 +591,7 @@ def build( model_id, trust_remote_code=trust_remote_code, device=device, + onnx_path=model_id if cli_utils.is_onnx_file_path(model_id) else None, ) if no_quant: config_or_configs.quant = None @@ -643,9 +644,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None: # scratch state when the user passes the wrong file or a # hand-edited config (#P1 UX). _configs_to_validate: list[WinMLBuildConfig] = ( - config_or_configs - if isinstance(config_or_configs, list) - else [config_or_configs] + config_or_configs if isinstance(config_or_configs, list) else [config_or_configs] ) try: for _cfg in _configs_to_validate: diff --git a/tests/unit/build/test_hf.py b/tests/unit/build/test_hf.py index 38203af44..77d286c4a 100644 --- a/tests/unit/build/test_hf.py +++ b/tests/unit/build/test_hf.py @@ -795,10 +795,10 @@ def test_autoconf_merges_config_for_downstream( class TestBuildHfPreQuantized: """Test pre-quantized detection in HF build pipeline.""" - def test_post_export_qdq_skips_optimize_and_quantize( + def test_post_export_qdq_skips_quantize_but_runs_optimize( self, tmp_path: Path, sample_config, mock_pipeline ) -> None: - """If exported ONNX has QDQ nodes, skip optimize+quantize.""" + """If exported ONNX has QDQ nodes, run optimize but skip quantize.""" mock_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -807,9 +807,9 @@ def test_post_export_qdq_skips_optimize_and_quantize( output_dir=output_dir, pytorch_model=mock_pipeline["model"], ) - assert "optimize" in result.stages_skipped + assert "optimize" in result.stages_completed + assert "optimize" not in result.stages_skipped assert "quantize" in result.stages_skipped - assert "optimize" not in result.stages_completed assert "quantize" not in result.stages_completed mock_pipeline["optimize"].assert_called_once() mock_pipeline["quantize"].assert_not_called() @@ -844,10 +844,10 @@ def test_post_export_qdq_still_compiles( assert "compile" in result.stages_completed mock_pipeline["compile"].assert_called_once() - def test_post_export_qdq_runs_analyze_only( + def test_post_export_qdq_runs_optimize_and_analyze( self, tmp_path: Path, sample_config, mock_pipeline ) -> None: - """Pre-quantized path runs optimize but skips autoconf (no analyze).""" + """Pre-quantized path runs full optimize with autoconf (analyze runs).""" mock_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -856,9 +856,9 @@ def test_post_export_qdq_runs_analyze_only( output_dir=output_dir, pytorch_model=mock_pipeline["model"], ) - # max_optim_iterations=0 means no analyze loop runs - mock_pipeline["analyze"].assert_not_called() mock_pipeline["optimize"].assert_called_once() + # Full autoconf: analyze is called by run_optimize_analyze_loop + mock_pipeline["analyze"].assert_called() def test_skip_optimize_kwarg(self, tmp_path: Path, sample_config, mock_pipeline) -> None: """skip_optimize=True forces optimize+quantize skip.""" @@ -913,17 +913,17 @@ def test_analyze_output_path_respects_cache_key( for call in mock_pipeline["analyze"].call_args_list: assert call.kwargs["output_path"] == expected - def test_no_output_path_for_prequantized( + def test_analyze_output_path_for_prequantized( self, tmp_path: Path, sample_config, mock_pipeline ) -> None: - """Pre-quantized path never calls analyze_onnx (no JSON written).""" + """Pre-quantized path runs analyze with output_path (full autoconf).""" mock_pipeline["is_quantized_onnx"].return_value = True build_hf_model( config=sample_config, output_dir=tmp_path / "output", pytorch_model=mock_pipeline["model"], ) - mock_pipeline["analyze"].assert_not_called() + mock_pipeline["analyze"].assert_called() def test_analyze_onnx_writes_json_to_disk(self, tmp_path: Path) -> None: """analyze_onnx with output_path writes a valid JSON file.""" diff --git a/tests/unit/build/test_onnx.py b/tests/unit/build/test_onnx.py index 1c1322907..ab99183b1 100644 --- a/tests/unit/build/test_onnx.py +++ b/tests/unit/build/test_onnx.py @@ -363,10 +363,10 @@ def test_build_onnx_non_quantized_proceeds( assert "quantize" in result.stages_completed mock_onnx_pipeline["quantize"].assert_called_once() - def test_pre_quantized_skips_optimize_and_quantize( + def test_pre_quantized_skips_quantize_but_runs_optimize( self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline ) -> None: - """QDQ model skips both optimize AND quantize stages.""" + """QDQ model runs optimize (with autoconf) but skips quantize.""" mock_onnx_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -375,9 +375,9 @@ def test_pre_quantized_skips_optimize_and_quantize( config=sample_onnx_config, output_dir=output_dir, ) - assert "optimize" in result.stages_skipped + assert "optimize" in result.stages_completed + assert "optimize" not in result.stages_skipped assert "quantize" in result.stages_skipped - assert "optimize" not in result.stages_completed assert "quantize" not in result.stages_completed mock_onnx_pipeline["optimize"].assert_called_once() mock_onnx_pipeline["quantize"].assert_not_called() @@ -397,10 +397,10 @@ def test_pre_quantized_still_compiles( assert "compile" in result.stages_completed mock_onnx_pipeline["compile"].assert_called_once() - def test_pre_quantized_runs_analyze_only( + def test_pre_quantized_runs_optimize_and_analyze( self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline ) -> None: - """Pre-quantized path runs optimize but skips autoconf (no analyze).""" + """Pre-quantized path runs full optimize with autoconf (analyze runs).""" mock_onnx_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -409,9 +409,9 @@ def test_pre_quantized_runs_analyze_only( config=sample_onnx_config, output_dir=output_dir, ) - # max_optim_iterations=0 means no analyze loop runs - mock_onnx_pipeline["analyze"].assert_not_called() mock_onnx_pipeline["optimize"].assert_called_once() + # Full autoconf: analyze is called by run_optimize_analyze_loop + mock_onnx_pipeline["analyze"].assert_called() def test_skip_optimize_kwarg( self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline @@ -574,10 +574,10 @@ def test_analyze_onnx_called_with_output_path( for call in mock_onnx_pipeline["analyze"].call_args_list: assert call.kwargs["output_path"] == output_dir / "analyze_result.json" - def test_no_output_path_for_prequantized( + def test_analyze_output_path_for_prequantized( self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline ) -> None: - """Pre-quantized path never calls analyze_onnx (no JSON written).""" + """Pre-quantized path runs analyze with output_path (full autoconf).""" mock_onnx_pipeline["is_quantized_onnx"].return_value = True build_onnx_model(fake_onnx, config=sample_onnx_config, output_dir=tmp_path / "output") - mock_onnx_pipeline["analyze"].assert_not_called() + mock_onnx_pipeline["analyze"].assert_called()