Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 32 additions & 33 deletions src/winml/modelkit/build/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,44 +244,43 @@ def _name(base: str) -> str:
# config.quant=None. This is_quantized_onnx() check is redundant in that
# path but kept for backward compatibility when build_hf_model()
# is called directly with a hand-built config.
is_pre_quantized = is_quantized_onnx(current_path) or skip_optimize
is_qdq = is_quantized_onnx(current_path)
is_pre_quantized = is_qdq or skip_optimize

if is_pre_quantized:
if is_qdq:
logger.info(
"Pre-quantized model detected (QDQ nodes present). "
"Skipping optimize + quantize, running analyze-only."
"Skipping quantize to preserve QDQ structure."
)

# Build kwargs for run_optimize_analyze_loop. When skip_optimize is set,
# we still run optimize_onnx (applies config.optim flags) but skip the
# autoconf re-optimization loop to preserve the original behavior.
run_kwargs: dict = dict(
model_path=current_path,
optimized_path=optimized_path,
config=config,
ep=ep,
device=device,
**onnx_kwargs,
)
if not skip_optimize:
run_kwargs["max_optim_iterations"] = hack_max_optim_iterations
run_kwargs["allow_unsupported_nodes"] = allow_unsupported_nodes
run_kwargs["analyze_output_path"] = analyze_result_path

logger.info("Optimizing ONNX model...")
(
current_path,
opt_elapsed,
analyze_iterations,
analyze_unsupported_nodes,
analyze_details,
) = run_optimize_analyze_loop(**run_kwargs)

if skip_optimize:
stages_skipped.append("optimize")
# Optimize+analyze only, no autoconf re-optimization
current_path, _, analyze_iterations, analyze_unsupported_nodes, analyze_details = (
run_optimize_analyze_loop(
model_path=current_path,
optimized_path=optimized_path,
config=config,
ep=ep,
device=device,
**onnx_kwargs,
)
)
else:
logger.info("Optimizing ONNX model...")
(
current_path,
opt_elapsed,
analyze_iterations,
analyze_unsupported_nodes,
analyze_details,
) = run_optimize_analyze_loop(
model_path=current_path,
optimized_path=optimized_path,
config=config,
ep=ep,
device=device,
max_optim_iterations=hack_max_optim_iterations,
allow_unsupported_nodes=allow_unsupported_nodes,
analyze_output_path=analyze_result_path,
**onnx_kwargs,
)
stage_timings["optimize"] = opt_elapsed
stages_completed.append("optimize")
logger.info("Optimize done (%.1fs) -> %s", opt_elapsed, optimized_path)
Expand All @@ -303,7 +302,7 @@ def _name(base: str) -> str:
# Defensive fallback: catches the edge case where a direct caller
# provides config.quant != None but the model already has QDQ nodes
# (e.g., hand-built config without running generate_*_build_config).
if is_quantized_onnx(current_path):
if is_qdq:
logger.warning(
"Model already contains QDQ nodes, skipping quantization. "
"Set config.quant=None to silence this warning."
Expand Down
59 changes: 29 additions & 30 deletions src/winml/modelkit/build/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def build_onnx_model(
copy_onnx_model(onnx_path, current_path)

# =========================================================================
# [1] OPTIMIZE + ANALYZE (or ANALYZE-ONLY for pre-quantized)
# [1] OPTIMIZE + ANALYZE
# FIXME: Stages [1]-[4] (optimize, quantize, compile, finalize) are
# duplicated between build_onnx_model() and build_hf_model(). Extract
# into a shared run_build_stages() function in common.py.
Expand All @@ -152,40 +152,39 @@ def build_onnx_model(
# config.quant=None. This is_quantized_onnx() check is redundant in that
# path but kept for backward compatibility when build_onnx_model()
# is called directly with a hand-built config.
is_pre_quantized = is_quantized_onnx(current_path) or skip_optimize
is_qdq = is_quantized_onnx(current_path)
is_pre_quantized = is_qdq or skip_optimize

if is_pre_quantized:
if is_qdq:
logger.info(
"Pre-quantized model detected (QDQ nodes present). "
"Skipping optimize + quantize, running analyze-only."
"Skipping quantize to preserve QDQ structure."
)
Comment thread
DingmaomaoBJTU marked this conversation as resolved.

# Build kwargs for run_optimize_analyze_loop. When skip_optimize is set,
# we still run optimize_onnx (applies config.optim flags) but skip the
# autoconf re-optimization loop to preserve the original behavior.
run_kwargs: dict = dict(
model_path=current_path,
optimized_path=optimized_path,
config=config,
ep=ep,
device=device,
**onnx_kwargs,
)
if not skip_optimize:
run_kwargs["max_optim_iterations"] = hack_max_optim_iterations
run_kwargs["allow_unsupported_nodes"] = allow_unsupported_nodes
run_kwargs["analyze_output_path"] = analyze_result_path

logger.info("Optimizing ONNX model...")
current_path, opt_elapsed, analyze_iters, analyze_unsupported, analyze_details = (
run_optimize_analyze_loop(**run_kwargs)
)

if skip_optimize:
stages_skipped.append("optimize")
# Optimize+analyze only, no autoconf re-optimization
current_path, _, analyze_iters, analyze_unsupported, analyze_details = (
run_optimize_analyze_loop(
model_path=current_path,
optimized_path=optimized_path,
config=config,
ep=ep,
device=device,
**onnx_kwargs,
)
)
else:
logger.info("Optimizing ONNX model...")
current_path, opt_elapsed, analyze_iters, analyze_unsupported, analyze_details = (
run_optimize_analyze_loop(
model_path=current_path,
optimized_path=optimized_path,
config=config,
ep=ep,
device=device,
max_optim_iterations=hack_max_optim_iterations,
allow_unsupported_nodes=allow_unsupported_nodes,
analyze_output_path=analyze_result_path,
**onnx_kwargs,
)
)
stage_timings["optimize"] = opt_elapsed
stages_completed.append("optimize")
logger.info("Optimize done (%.1fs) -> %s", opt_elapsed, optimized_path)
Expand All @@ -207,7 +206,7 @@ def build_onnx_model(
# Defensive fallback: catches the edge case where a direct caller
# provides config.quant != None but the model already has QDQ nodes
# (e.g., hand-built config without running generate_*_build_config).
if is_quantized_onnx(current_path):
if is_qdq:
logger.warning(
"Model already contains QDQ nodes, skipping quantization. "
"Set config.quant=None to silence this warning."
Expand Down
5 changes: 2 additions & 3 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ def build(
model_id,
trust_remote_code=trust_remote_code,
device=device,
onnx_path=model_id if cli_utils.is_onnx_file_path(model_id) else None,
Comment thread
DingmaomaoBJTU marked this conversation as resolved.
)
if no_quant:
config_or_configs.quant = None
Expand Down Expand Up @@ -643,9 +644,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
# scratch state when the user passes the wrong file or a
# hand-edited config (#P1 UX).
_configs_to_validate: list[WinMLBuildConfig] = (
config_or_configs
if isinstance(config_or_configs, list)
else [config_or_configs]
config_or_configs if isinstance(config_or_configs, list) else [config_or_configs]
)
try:
for _cfg in _configs_to_validate:
Expand Down
22 changes: 11 additions & 11 deletions tests/unit/build/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,10 +795,10 @@ def test_autoconf_merges_config_for_downstream(
class TestBuildHfPreQuantized:
"""Test pre-quantized detection in HF build pipeline."""

def test_post_export_qdq_skips_optimize_and_quantize(
def test_post_export_qdq_skips_quantize_but_runs_optimize(
self, tmp_path: Path, sample_config, mock_pipeline
) -> None:
"""If exported ONNX has QDQ nodes, skip optimize+quantize."""
"""If exported ONNX has QDQ nodes, run optimize but skip quantize."""
mock_pipeline["is_quantized_onnx"].return_value = True

output_dir = tmp_path / "output"
Expand All @@ -807,9 +807,9 @@ def test_post_export_qdq_skips_optimize_and_quantize(
output_dir=output_dir,
pytorch_model=mock_pipeline["model"],
)
assert "optimize" in result.stages_skipped
assert "optimize" in result.stages_completed
assert "optimize" not in result.stages_skipped
assert "quantize" in result.stages_skipped
assert "optimize" not in result.stages_completed
assert "quantize" not in result.stages_completed
mock_pipeline["optimize"].assert_called_once()
mock_pipeline["quantize"].assert_not_called()
Expand Down Expand Up @@ -844,10 +844,10 @@ def test_post_export_qdq_still_compiles(
assert "compile" in result.stages_completed
mock_pipeline["compile"].assert_called_once()

def test_post_export_qdq_runs_analyze_only(
def test_post_export_qdq_runs_optimize_and_analyze(
self, tmp_path: Path, sample_config, mock_pipeline
) -> None:
"""Pre-quantized path runs optimize but skips autoconf (no analyze)."""
"""Pre-quantized path runs full optimize with autoconf (analyze runs)."""
mock_pipeline["is_quantized_onnx"].return_value = True

output_dir = tmp_path / "output"
Expand All @@ -856,9 +856,9 @@ def test_post_export_qdq_runs_analyze_only(
output_dir=output_dir,
pytorch_model=mock_pipeline["model"],
)
# max_optim_iterations=0 means no analyze loop runs
mock_pipeline["analyze"].assert_not_called()
mock_pipeline["optimize"].assert_called_once()
# Full autoconf: analyze is called by run_optimize_analyze_loop
mock_pipeline["analyze"].assert_called()

def test_skip_optimize_kwarg(self, tmp_path: Path, sample_config, mock_pipeline) -> None:
"""skip_optimize=True forces optimize+quantize skip."""
Expand Down Expand Up @@ -913,17 +913,17 @@ def test_analyze_output_path_respects_cache_key(
for call in mock_pipeline["analyze"].call_args_list:
assert call.kwargs["output_path"] == expected

def test_no_output_path_for_prequantized(
def test_analyze_output_path_for_prequantized(
self, tmp_path: Path, sample_config, mock_pipeline
) -> None:
"""Pre-quantized path never calls analyze_onnx (no JSON written)."""
"""Pre-quantized path runs analyze with output_path (full autoconf)."""
mock_pipeline["is_quantized_onnx"].return_value = True
build_hf_model(
config=sample_config,
output_dir=tmp_path / "output",
pytorch_model=mock_pipeline["model"],
)
mock_pipeline["analyze"].assert_not_called()
mock_pipeline["analyze"].assert_called()

def test_analyze_onnx_writes_json_to_disk(self, tmp_path: Path) -> None:
"""analyze_onnx with output_path writes a valid JSON file."""
Expand Down
22 changes: 11 additions & 11 deletions tests/unit/build/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,10 @@ def test_build_onnx_non_quantized_proceeds(
assert "quantize" in result.stages_completed
mock_onnx_pipeline["quantize"].assert_called_once()

def test_pre_quantized_skips_optimize_and_quantize(
def test_pre_quantized_skips_quantize_but_runs_optimize(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline
) -> None:
"""QDQ model skips both optimize AND quantize stages."""
"""QDQ model runs optimize (with autoconf) but skips quantize."""
mock_onnx_pipeline["is_quantized_onnx"].return_value = True

output_dir = tmp_path / "output"
Expand All @@ -375,9 +375,9 @@ def test_pre_quantized_skips_optimize_and_quantize(
config=sample_onnx_config,
output_dir=output_dir,
)
assert "optimize" in result.stages_skipped
assert "optimize" in result.stages_completed
assert "optimize" not in result.stages_skipped
assert "quantize" in result.stages_skipped
assert "optimize" not in result.stages_completed
assert "quantize" not in result.stages_completed
mock_onnx_pipeline["optimize"].assert_called_once()
mock_onnx_pipeline["quantize"].assert_not_called()
Expand All @@ -397,10 +397,10 @@ def test_pre_quantized_still_compiles(
assert "compile" in result.stages_completed
mock_onnx_pipeline["compile"].assert_called_once()

def test_pre_quantized_runs_analyze_only(
def test_pre_quantized_runs_optimize_and_analyze(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline
) -> None:
"""Pre-quantized path runs optimize but skips autoconf (no analyze)."""
"""Pre-quantized path runs full optimize with autoconf (analyze runs)."""
mock_onnx_pipeline["is_quantized_onnx"].return_value = True

output_dir = tmp_path / "output"
Expand All @@ -409,9 +409,9 @@ def test_pre_quantized_runs_analyze_only(
config=sample_onnx_config,
output_dir=output_dir,
)
# max_optim_iterations=0 means no analyze loop runs
mock_onnx_pipeline["analyze"].assert_not_called()
mock_onnx_pipeline["optimize"].assert_called_once()
# Full autoconf: analyze is called by run_optimize_analyze_loop
mock_onnx_pipeline["analyze"].assert_called()

def test_skip_optimize_kwarg(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline
Expand Down Expand Up @@ -574,10 +574,10 @@ def test_analyze_onnx_called_with_output_path(
for call in mock_onnx_pipeline["analyze"].call_args_list:
assert call.kwargs["output_path"] == output_dir / "analyze_result.json"

def test_no_output_path_for_prequantized(
def test_analyze_output_path_for_prequantized(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline
) -> None:
"""Pre-quantized path never calls analyze_onnx (no JSON written)."""
"""Pre-quantized path runs analyze with output_path (full autoconf)."""
mock_onnx_pipeline["is_quantized_onnx"].return_value = True
build_onnx_model(fake_onnx, config=sample_onnx_config, output_dir=tmp_path / "output")
mock_onnx_pipeline["analyze"].assert_not_called()
mock_onnx_pipeline["analyze"].assert_called()
Loading