diff --git a/src/winml/modelkit/build/onnx.py b/src/winml/modelkit/build/onnx.py index 2e7424e99..5c48c2e7c 100644 --- a/src/winml/modelkit/build/onnx.py +++ b/src/winml/modelkit/build/onnx.py @@ -42,6 +42,7 @@ def build_onnx_model( rebuild: bool = False, ep: EPNameOrAlias | None = None, device: str | None = None, + cache_key: str | None = None, **kwargs: Any, ) -> BuildResult: """Build from a pre-exported ONNX model. @@ -58,6 +59,9 @@ def build_onnx_model( rebuild: Force rebuild even if output exists. ep: Target execution provider for the analyzer (e.g., ``"qnn"``). device: Target device for the analyzer (e.g., ``"NPU"``). + cache_key: Optional prefix for artifact filenames, enabling multiple + task/config variants to coexist in one directory. When set, all + artifact files are prefixed (e.g., ``"{cache_key}_model.onnx"``). **kwargs: Additional options: - ``hack_max_optim_iterations`` (int, default 3): Max analyzer iterations. 0 disables analyzer. @@ -101,15 +105,19 @@ def build_onnx_model( start_time = time.monotonic() output_dir.mkdir(parents=True, exist_ok=True) + # Artifact naming — optionally prefixed when cache_key is set so that + # multiple task/config variants can coexist in one directory. + def _name(base: str) -> str: + return f"{cache_key}_{base}" if cache_key else base + # Define output paths - stem = onnx_path.stem - optimized_path = output_dir / f"{stem}_optimized.onnx" - quantized_path = output_dir / f"{stem}_quantized.onnx" - compiled_path = output_dir / f"{stem}_compiled.onnx" - final_path = output_dir / "model.onnx" - config_path = output_dir / "winml_build_config.json" - manifest_path = output_dir / "build_manifest.json" - analyze_result_path = output_dir / "analyze_result.json" + optimized_path = output_dir / _name("optimized.onnx") + quantized_path = output_dir / _name("quantized.onnx") + compiled_path = output_dir / _name("compiled.onnx") + final_path = output_dir / _name("model.onnx") + config_path = output_dir / _name("winml_build_config.json") + manifest_path = output_dir / _name("build_manifest.json") + analyze_result_path = output_dir / _name("analyze_result.json") # Check for existing artifact (skip build if present and not rebuilding) if final_path.exists() and not rebuild: @@ -124,10 +132,12 @@ def build_onnx_model( # Rebuild: clean old ONNX artifacts to prevent stale files if rebuild: - for old in output_dir.glob("*.onnx"): + pattern = f"{cache_key}_*.onnx" if cache_key else "*.onnx" + for old in output_dir.glob(pattern): old.unlink() logger.debug("Removed old artifact: %s", old.name) - for old in output_dir.glob("*.onnx.data"): + data_pattern = f"{cache_key}_*.onnx.data" if cache_key else "*.onnx.data" + for old in output_dir.glob(data_pattern): old.unlink() logger.debug("Removed old external data sidecar: %s", old.name) diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index a14d5b162..0b7e6c129 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -185,7 +185,7 @@ def from_onnx( # Skip build for compiled models or explicit skip. # Check is_compiled_onnx directly — don't rely on config shape alone # because auto+auto also produces quant=None, compile=None for raw models. - from ..onnx import is_compiled_onnx + from ..onnx import get_onnx_model_hash, is_compiled_onnx if skip_build or is_compiled_onnx(onnx_path): logger.info("Skipping build (compiled model or explicit skip). Using original ONNX.") @@ -199,10 +199,15 @@ def from_onnx( ep=ep, ) - # Resolve output directory + # Resolve output directory and cache key + task_abbrev = get_task_abbrev(resolved_task) if resolved_task else "onnx" + cache_key = get_cache_key(task_abbrev, config.generate_cache_key()) if use_cache: cache_dir_path = get_cache_dir(override=cache_dir) - output_dir = get_model_dir(onnx_path.stem, cache_dir=cache_dir_path) + output_dir = get_model_dir( + f"onnx-{get_onnx_model_hash(onnx_path)}", + cache_dir=cache_dir_path, + ) else: import tempfile @@ -221,6 +226,7 @@ def from_onnx( rebuild=force_rebuild, ep=ep, device=device, + cache_key=cache_key, allow_unsupported_nodes=allow_unsupported_nodes, **kwargs, ) diff --git a/src/winml/modelkit/onnx/__init__.py b/src/winml/modelkit/onnx/__init__.py index a3bc49d51..4f6255386 100644 --- a/src/winml/modelkit/onnx/__init__.py +++ b/src/winml/modelkit/onnx/__init__.py @@ -15,7 +15,7 @@ from .domains import ONNXDomain from .dtypes import SupportedONNXType, remove_optional_from_type_annotation -from .external_data import copy_onnx_model +from .external_data import copy_onnx_model, get_onnx_model_hash from .io import InputTensorSpec, OutputTensorSpec, generate_inputs_from_onnx, get_io_config from .metadata import capture_metadata, restore_metadata from .persistence import cleanup_onnx, load_onnx, save_onnx @@ -36,6 +36,7 @@ "generate_inputs_from_onnx", "get_io_config", "get_model_size", + "get_onnx_model_hash", "infer_onnx_shapes", "infer_shapes", "is_compiled_onnx", diff --git a/src/winml/modelkit/onnx/external_data.py b/src/winml/modelkit/onnx/external_data.py index 658a4847c..afe0c3805 100644 --- a/src/winml/modelkit/onnx/external_data.py +++ b/src/winml/modelkit/onnx/external_data.py @@ -16,6 +16,7 @@ from __future__ import annotations +import hashlib import logging import shutil from pathlib import Path @@ -154,6 +155,37 @@ def has_external_data(model_path: str | Path) -> bool: return len(get_external_data_files(model_path)) > 0 +def _update_hash_from_file(hash_obj: Any, path: Path) -> None: + """Stream *path* into an existing hash object.""" + with path.open("rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + hash_obj.update(chunk) + + +def get_onnx_model_hash(model_path: str | Path) -> str: + """Compute a content hash for an ONNX model and referenced external data.""" + model_path = Path(model_path).resolve() + hash_obj = hashlib.sha256() + _update_hash_from_file(hash_obj, model_path) + + try: + external_files = get_external_data_files(model_path) + except Exception: + logger.debug("Could not inspect ONNX external data for hashing: %s", model_path) + external_files = [] + + for location in external_files: + data_path = Path(location) + if not data_path.is_absolute(): + data_path = model_path.parent / data_path + hash_obj.update(b"\0external-data\0") + hash_obj.update(location.replace("\\", "/").encode("utf-8")) + hash_obj.update(b"\0") + _update_hash_from_file(hash_obj, data_path) + + return hash_obj.hexdigest()[:16] + + def copy_onnx_model( src: str | Path, dst: str | Path, diff --git a/tests/unit/build/test_onnx.py b/tests/unit/build/test_onnx.py index 1c1322907..79fb108cf 100644 --- a/tests/unit/build/test_onnx.py +++ b/tests/unit/build/test_onnx.py @@ -581,3 +581,86 @@ def test_no_output_path_for_prequantized( mock_onnx_pipeline["is_quantized_onnx"].return_value = True build_onnx_model(fake_onnx, config=sample_onnx_config, output_dir=tmp_path / "output") mock_onnx_pipeline["analyze"].assert_not_called() + + +# ============================================================================= +# CACHE KEY TESTS +# ============================================================================= + + +class TestBuildOnnxCacheKey: + """Test cache_key parameter for artifact naming.""" + + def test_no_cache_key_produces_model_onnx( + self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline + ) -> None: + """cache_key=None (default) produces model.onnx as the final artifact.""" + output_dir = tmp_path / "output" + result = build_onnx_model( + fake_onnx, + config=sample_onnx_config_minimal, + output_dir=output_dir, + ) + assert result.final_onnx_path == output_dir / "model.onnx" + + def test_cache_key_prefixes_final_artifact( + self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline + ) -> None: + """cache_key prefixes the final artifact filename.""" + output_dir = tmp_path / "output" + result = build_onnx_model( + fake_onnx, + config=sample_onnx_config_minimal, + output_dir=output_dir, + cache_key="imgcls_abc1234567890123", + ) + assert result.final_onnx_path == output_dir / "imgcls_abc1234567890123_model.onnx" + + def test_cache_key_prefixes_config_path( + self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline + ) -> None: + """cache_key prefixes the config JSON filename.""" + output_dir = tmp_path / "output" + result = build_onnx_model( + fake_onnx, + config=sample_onnx_config_minimal, + output_dir=output_dir, + cache_key="imgcls_abc1234567890123", + ) + assert result.config_path == output_dir / "imgcls_abc1234567890123_winml_build_config.json" + assert result.config_path.exists() + + def test_cache_key_reuse_checks_prefixed_path( + self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline + ) -> None: + """Existing prefixed model.onnx is reused when rebuild=False.""" + output_dir = tmp_path / "output" + output_dir.mkdir() + (output_dir / "imgcls_abc1234567890123_model.onnx").write_text("existing") + + result = build_onnx_model( + fake_onnx, + config=sample_onnx_config_minimal, + output_dir=output_dir, + cache_key="imgcls_abc1234567890123", + ) + assert result.reused is True + mock_onnx_pipeline["optimize"].assert_not_called() + + def test_cache_key_rebuild_does_not_remove_unrelated_artifacts( + self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline + ) -> None: + """rebuild=True with cache_key removes only matching prefixed files, not unrelated ones.""" + output_dir = tmp_path / "output" + output_dir.mkdir() + other = output_dir / "other_model.onnx" + other.write_text("other-model") + + build_onnx_model( + fake_onnx, + config=sample_onnx_config_minimal, + output_dir=output_dir, + cache_key="imgcls_abc1234567890123", + rebuild=True, + ) + assert other.exists(), "unrelated artifacts should not be removed" diff --git a/tests/unit/models/auto/test_auto_onnx.py b/tests/unit/models/auto/test_auto_onnx.py index 44f7c6202..03ef1b472 100644 --- a/tests/unit/models/auto/test_auto_onnx.py +++ b/tests/unit/models/auto/test_auto_onnx.py @@ -13,6 +13,7 @@ from __future__ import annotations +import hashlib from typing import TYPE_CHECKING, ClassVar from unittest.mock import MagicMock, patch @@ -235,10 +236,126 @@ def test_passes_ep_from_kwargs(self, fake_onnx: Path, tmp_path: Path): # ============================================================================= -# from_onnx dict dispatch → WinMLCompositeModel.from_onnx +# from_onnx cache dir and cache_key tests # ============================================================================= +class TestFromOnnxCacheDirAndKey: + """Verify from_onnx uses content-addressed model dirs and passes cache_key.""" + + def test_uses_content_hash_for_model_dir(self, fake_onnx: Path, tmp_path: Path): + """from_onnx uses the ONNX content hash as model_id for get_model_dir.""" + with ( + patch("winml.modelkit.onnx.is_compiled_onnx", return_value=False), + patch("winml.modelkit.onnx.is_quantized_onnx", return_value=False), + patch( + "winml.modelkit.sysinfo.resolve_device", + return_value=("cpu", ["cpu"]), + ), + patch( + "winml.modelkit.config.precision.resolve_eps", + return_value=["CPUExecutionProvider"], + ), + patch("winml.modelkit.build.build_onnx_model") as mock_build, + patch("winml.modelkit.models.auto.get_winml_class") as mock_get_class, + patch("winml.modelkit.models.auto.get_model_dir") as mock_get_model_dir, + ): + mock_build.return_value = _make_build_result(tmp_path) + mock_get_class.return_value = lambda **kw: MagicMock() + mock_get_model_dir.return_value = tmp_path / "model_dir" + + WinMLAutoModel.from_onnx( + fake_onnx, + task="image-classification", + device="cpu", + ) + + mock_get_model_dir.assert_called_once() + model_id_arg = mock_get_model_dir.call_args.args[0] + expected_hash = hashlib.sha256(fake_onnx.read_bytes()).hexdigest()[:16] + assert model_id_arg == f"onnx-{expected_hash}" + assert model_id_arg != str(fake_onnx.resolve()) + + def test_replacing_same_path_content_gets_different_model_dir(self, tmp_path: Path): + """Replacing an ONNX file at the same path changes its cache model dir.""" + from winml.modelkit.cache import get_cache_dir, get_model_dir + from winml.modelkit.onnx import get_onnx_model_hash + + onnx_path = tmp_path / "model.onnx" + cache = get_cache_dir() + + onnx_path.write_bytes(b"first-content") + model_dir_a = get_model_dir(f"onnx-{get_onnx_model_hash(onnx_path)}", cache_dir=cache) + + onnx_path.write_bytes(b"second-content") + model_dir_b = get_model_dir(f"onnx-{get_onnx_model_hash(onnx_path)}", cache_dir=cache) + + assert model_dir_a != model_dir_b + + def test_onnx_model_hash_includes_external_data(self, tmp_path: Path): + """Changing external data changes the ONNX model content hash.""" + import numpy as np + import onnx + + from winml.modelkit.onnx import get_onnx_model_hash + + onnx_path = tmp_path / "external.onnx" + data_path = tmp_path / "external.onnx.data" + tensor = onnx.helper.make_tensor( + "weight", + onnx.TensorProto.FLOAT, + [4], + np.arange(4, dtype=np.float32).tobytes(), + raw=True, + ) + graph = onnx.helper.make_graph([], "external-data-test", [], [], [tensor]) + model = onnx.helper.make_model(graph) + onnx.save_model( + model, + str(onnx_path), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=data_path.name, + size_threshold=0, + ) + + original_hash = get_onnx_model_hash(onnx_path) + data_path.write_bytes(data_path.read_bytes() + b"changed") + + assert get_onnx_model_hash(onnx_path) != original_hash + + def test_passes_cache_key_to_build_onnx_model(self, fake_onnx: Path, tmp_path: Path): + """from_onnx computes and passes a cache_key to build_onnx_model.""" + with ( + patch("winml.modelkit.onnx.is_compiled_onnx", return_value=False), + patch("winml.modelkit.onnx.is_quantized_onnx", return_value=False), + patch( + "winml.modelkit.sysinfo.resolve_device", + return_value=("cpu", ["cpu"]), + ), + patch( + "winml.modelkit.config.precision.resolve_eps", + return_value=["CPUExecutionProvider"], + ), + patch("winml.modelkit.build.build_onnx_model") as mock_build, + patch("winml.modelkit.models.auto.get_winml_class") as mock_get_class, + ): + mock_build.return_value = _make_build_result(tmp_path) + mock_get_class.return_value = lambda **kw: MagicMock() + + WinMLAutoModel.from_onnx( + fake_onnx, + task="image-classification", + device="cpu", + ) + + call_kwargs = mock_build.call_args.kwargs + assert "cache_key" in call_kwargs + # cache_key must be non-empty and contain the task abbreviation + assert call_kwargs["cache_key"] + assert "imgcls" in call_kwargs["cache_key"] + + class TestFromOnnxDictDispatch: """from_onnx with dict onnx_path delegates to WinMLCompositeModel.from_onnx."""