Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/winml/modelkit/build/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def build_onnx_model(
rebuild: bool = False,
ep: EPNameOrAlias | None = None,
device: str | None = None,
cache_key: str | None = None,
**kwargs: Any,
) -> BuildResult:
"""Build from a pre-exported ONNX model.
Expand All @@ -58,6 +59,9 @@ def build_onnx_model(
rebuild: Force rebuild even if output exists.
ep: Target execution provider for the analyzer (e.g., ``"qnn"``).
device: Target device for the analyzer (e.g., ``"NPU"``).
cache_key: Optional prefix for artifact filenames, enabling multiple
task/config variants to coexist in one directory. When set, all
artifact files are prefixed (e.g., ``"{cache_key}_model.onnx"``).
**kwargs: Additional options:
- ``hack_max_optim_iterations`` (int, default 3): Max analyzer
iterations. 0 disables analyzer.
Expand Down Expand Up @@ -101,15 +105,19 @@ def build_onnx_model(
start_time = time.monotonic()
output_dir.mkdir(parents=True, exist_ok=True)

# Artifact naming — optionally prefixed when cache_key is set so that
# multiple task/config variants can coexist in one directory.
def _name(base: str) -> str:
return f"{cache_key}_{base}" if cache_key else base

# Define output paths
stem = onnx_path.stem
optimized_path = output_dir / f"{stem}_optimized.onnx"
quantized_path = output_dir / f"{stem}_quantized.onnx"
compiled_path = output_dir / f"{stem}_compiled.onnx"
final_path = output_dir / "model.onnx"
config_path = output_dir / "winml_build_config.json"
manifest_path = output_dir / "build_manifest.json"
analyze_result_path = output_dir / "analyze_result.json"
optimized_path = output_dir / _name("optimized.onnx")
Comment thread
chinazhangchao marked this conversation as resolved.
quantized_path = output_dir / _name("quantized.onnx")
compiled_path = output_dir / _name("compiled.onnx")
final_path = output_dir / _name("model.onnx")
config_path = output_dir / _name("winml_build_config.json")
manifest_path = output_dir / _name("build_manifest.json")
analyze_result_path = output_dir / _name("analyze_result.json")

# Check for existing artifact (skip build if present and not rebuilding)
if final_path.exists() and not rebuild:
Expand All @@ -124,10 +132,12 @@ def build_onnx_model(

# Rebuild: clean old ONNX artifacts to prevent stale files
if rebuild:
for old in output_dir.glob("*.onnx"):
pattern = f"{cache_key}_*.onnx" if cache_key else "*.onnx"
for old in output_dir.glob(pattern):
old.unlink()
logger.debug("Removed old artifact: %s", old.name)
for old in output_dir.glob("*.onnx.data"):
data_pattern = f"{cache_key}_*.onnx.data" if cache_key else "*.onnx.data"
for old in output_dir.glob(data_pattern):
old.unlink()
logger.debug("Removed old external data sidecar: %s", old.name)

Expand Down
12 changes: 9 additions & 3 deletions src/winml/modelkit/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def from_onnx(
# Skip build for compiled models or explicit skip.
# Check is_compiled_onnx directly — don't rely on config shape alone
# because auto+auto also produces quant=None, compile=None for raw models.
from ..onnx import is_compiled_onnx
from ..onnx import get_onnx_model_hash, is_compiled_onnx

if skip_build or is_compiled_onnx(onnx_path):
logger.info("Skipping build (compiled model or explicit skip). Using original ONNX.")
Expand All @@ -199,10 +199,15 @@ def from_onnx(
ep=ep,
)

# Resolve output directory
# Resolve output directory and cache key
task_abbrev = get_task_abbrev(resolved_task) if resolved_task else "onnx"
cache_key = get_cache_key(task_abbrev, config.generate_cache_key())
if use_cache:
cache_dir_path = get_cache_dir(override=cache_dir)
output_dir = get_model_dir(onnx_path.stem, cache_dir=cache_dir_path)
output_dir = get_model_dir(
f"onnx-{get_onnx_model_hash(onnx_path)}",
cache_dir=cache_dir_path,
)
else:
import tempfile

Expand All @@ -221,6 +226,7 @@ def from_onnx(
rebuild=force_rebuild,
ep=ep,
device=device,
cache_key=cache_key,
allow_unsupported_nodes=allow_unsupported_nodes,
**kwargs,
)
Expand Down
3 changes: 2 additions & 1 deletion src/winml/modelkit/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .domains import ONNXDomain
from .dtypes import SupportedONNXType, remove_optional_from_type_annotation
from .external_data import copy_onnx_model
from .external_data import copy_onnx_model, get_onnx_model_hash
from .io import InputTensorSpec, OutputTensorSpec, generate_inputs_from_onnx, get_io_config
from .metadata import capture_metadata, restore_metadata
from .persistence import cleanup_onnx, load_onnx, save_onnx
Expand All @@ -36,6 +36,7 @@
"generate_inputs_from_onnx",
"get_io_config",
"get_model_size",
"get_onnx_model_hash",
"infer_onnx_shapes",
"infer_shapes",
"is_compiled_onnx",
Expand Down
32 changes: 32 additions & 0 deletions src/winml/modelkit/onnx/external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import hashlib
import logging
import shutil
from pathlib import Path
Expand Down Expand Up @@ -154,6 +155,37 @@ def has_external_data(model_path: str | Path) -> bool:
return len(get_external_data_files(model_path)) > 0


def _update_hash_from_file(hash_obj: Any, path: Path) -> None:
"""Stream *path* into an existing hash object."""
with path.open("rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
hash_obj.update(chunk)


def get_onnx_model_hash(model_path: str | Path) -> str:
"""Compute a content hash for an ONNX model and referenced external data."""
model_path = Path(model_path).resolve()
hash_obj = hashlib.sha256()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about use hash path + size + mtime? weight could be big

_update_hash_from_file(hash_obj, model_path)

try:
external_files = get_external_data_files(model_path)
except Exception:
logger.debug("Could not inspect ONNX external data for hashing: %s", model_path)
external_files = []

for location in external_files:
data_path = Path(location)
if not data_path.is_absolute():
data_path = model_path.parent / data_path
hash_obj.update(b"\0external-data\0")
hash_obj.update(location.replace("\\", "/").encode("utf-8"))
hash_obj.update(b"\0")
_update_hash_from_file(hash_obj, data_path)

return hash_obj.hexdigest()[:16]


def copy_onnx_model(
src: str | Path,
dst: str | Path,
Expand Down
83 changes: 83 additions & 0 deletions tests/unit/build/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,86 @@ def test_no_output_path_for_prequantized(
mock_onnx_pipeline["is_quantized_onnx"].return_value = True
build_onnx_model(fake_onnx, config=sample_onnx_config, output_dir=tmp_path / "output")
mock_onnx_pipeline["analyze"].assert_not_called()


# =============================================================================
# CACHE KEY TESTS
# =============================================================================


class TestBuildOnnxCacheKey:
"""Test cache_key parameter for artifact naming."""

def test_no_cache_key_produces_model_onnx(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline
) -> None:
"""cache_key=None (default) produces model.onnx as the final artifact."""
output_dir = tmp_path / "output"
result = build_onnx_model(
fake_onnx,
config=sample_onnx_config_minimal,
output_dir=output_dir,
)
assert result.final_onnx_path == output_dir / "model.onnx"

def test_cache_key_prefixes_final_artifact(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline
) -> None:
"""cache_key prefixes the final artifact filename."""
output_dir = tmp_path / "output"
result = build_onnx_model(
fake_onnx,
config=sample_onnx_config_minimal,
output_dir=output_dir,
cache_key="imgcls_abc1234567890123",
)
assert result.final_onnx_path == output_dir / "imgcls_abc1234567890123_model.onnx"

def test_cache_key_prefixes_config_path(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline
) -> None:
"""cache_key prefixes the config JSON filename."""
output_dir = tmp_path / "output"
result = build_onnx_model(
fake_onnx,
config=sample_onnx_config_minimal,
output_dir=output_dir,
cache_key="imgcls_abc1234567890123",
)
assert result.config_path == output_dir / "imgcls_abc1234567890123_winml_build_config.json"
assert result.config_path.exists()

def test_cache_key_reuse_checks_prefixed_path(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline
) -> None:
"""Existing prefixed model.onnx is reused when rebuild=False."""
output_dir = tmp_path / "output"
output_dir.mkdir()
(output_dir / "imgcls_abc1234567890123_model.onnx").write_text("existing")

result = build_onnx_model(
fake_onnx,
config=sample_onnx_config_minimal,
output_dir=output_dir,
cache_key="imgcls_abc1234567890123",
)
assert result.reused is True
mock_onnx_pipeline["optimize"].assert_not_called()

def test_cache_key_rebuild_does_not_remove_unrelated_artifacts(
self, tmp_path: Path, fake_onnx: Path, sample_onnx_config_minimal, mock_onnx_pipeline
) -> None:
"""rebuild=True with cache_key removes only matching prefixed files, not unrelated ones."""
output_dir = tmp_path / "output"
output_dir.mkdir()
other = output_dir / "other_model.onnx"
other.write_text("other-model")

build_onnx_model(
fake_onnx,
config=sample_onnx_config_minimal,
output_dir=output_dir,
cache_key="imgcls_abc1234567890123",
rebuild=True,
)
assert other.exists(), "unrelated artifacts should not be removed"
119 changes: 118 additions & 1 deletion tests/unit/models/auto/test_auto_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from __future__ import annotations

import hashlib
from typing import TYPE_CHECKING, ClassVar
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -235,10 +236,126 @@ def test_passes_ep_from_kwargs(self, fake_onnx: Path, tmp_path: Path):


# =============================================================================
# from_onnx dict dispatch → WinMLCompositeModel.from_onnx
# from_onnx cache dir and cache_key tests
# =============================================================================


class TestFromOnnxCacheDirAndKey:
"""Verify from_onnx uses content-addressed model dirs and passes cache_key."""

def test_uses_content_hash_for_model_dir(self, fake_onnx: Path, tmp_path: Path):
"""from_onnx uses the ONNX content hash as model_id for get_model_dir."""
with (
patch("winml.modelkit.onnx.is_compiled_onnx", return_value=False),
patch("winml.modelkit.onnx.is_quantized_onnx", return_value=False),
patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("cpu", ["cpu"]),
),
patch(
"winml.modelkit.config.precision.resolve_eps",
return_value=["CPUExecutionProvider"],
),
patch("winml.modelkit.build.build_onnx_model") as mock_build,
patch("winml.modelkit.models.auto.get_winml_class") as mock_get_class,
patch("winml.modelkit.models.auto.get_model_dir") as mock_get_model_dir,
):
mock_build.return_value = _make_build_result(tmp_path)
mock_get_class.return_value = lambda **kw: MagicMock()
mock_get_model_dir.return_value = tmp_path / "model_dir"

WinMLAutoModel.from_onnx(
fake_onnx,
task="image-classification",
device="cpu",
)

mock_get_model_dir.assert_called_once()
model_id_arg = mock_get_model_dir.call_args.args[0]
expected_hash = hashlib.sha256(fake_onnx.read_bytes()).hexdigest()[:16]
assert model_id_arg == f"onnx-{expected_hash}"
assert model_id_arg != str(fake_onnx.resolve())

def test_replacing_same_path_content_gets_different_model_dir(self, tmp_path: Path):
"""Replacing an ONNX file at the same path changes its cache model dir."""
from winml.modelkit.cache import get_cache_dir, get_model_dir
from winml.modelkit.onnx import get_onnx_model_hash

onnx_path = tmp_path / "model.onnx"
cache = get_cache_dir()

onnx_path.write_bytes(b"first-content")
model_dir_a = get_model_dir(f"onnx-{get_onnx_model_hash(onnx_path)}", cache_dir=cache)

onnx_path.write_bytes(b"second-content")
model_dir_b = get_model_dir(f"onnx-{get_onnx_model_hash(onnx_path)}", cache_dir=cache)

assert model_dir_a != model_dir_b

def test_onnx_model_hash_includes_external_data(self, tmp_path: Path):
"""Changing external data changes the ONNX model content hash."""
import numpy as np
import onnx
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed

from winml.modelkit.onnx import get_onnx_model_hash

onnx_path = tmp_path / "external.onnx"
data_path = tmp_path / "external.onnx.data"
tensor = onnx.helper.make_tensor(
"weight",
onnx.TensorProto.FLOAT,
[4],
np.arange(4, dtype=np.float32).tobytes(),
raw=True,
)
graph = onnx.helper.make_graph([], "external-data-test", [], [], [tensor])
model = onnx.helper.make_model(graph)
onnx.save_model(
model,
str(onnx_path),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=data_path.name,
size_threshold=0,
)

original_hash = get_onnx_model_hash(onnx_path)
data_path.write_bytes(data_path.read_bytes() + b"changed")

assert get_onnx_model_hash(onnx_path) != original_hash

def test_passes_cache_key_to_build_onnx_model(self, fake_onnx: Path, tmp_path: Path):
"""from_onnx computes and passes a cache_key to build_onnx_model."""
with (
patch("winml.modelkit.onnx.is_compiled_onnx", return_value=False),
patch("winml.modelkit.onnx.is_quantized_onnx", return_value=False),
patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("cpu", ["cpu"]),
),
patch(
"winml.modelkit.config.precision.resolve_eps",
return_value=["CPUExecutionProvider"],
),
patch("winml.modelkit.build.build_onnx_model") as mock_build,
patch("winml.modelkit.models.auto.get_winml_class") as mock_get_class,
):
mock_build.return_value = _make_build_result(tmp_path)
mock_get_class.return_value = lambda **kw: MagicMock()

WinMLAutoModel.from_onnx(
fake_onnx,
task="image-classification",
device="cpu",
)

call_kwargs = mock_build.call_args.kwargs
assert "cache_key" in call_kwargs
# cache_key must be non-empty and contain the task abbreviation
assert call_kwargs["cache_key"]
assert "imgcls" in call_kwargs["cache_key"]


class TestFromOnnxDictDispatch:
"""from_onnx with dict onnx_path delegates to WinMLCompositeModel.from_onnx."""

Expand Down
Loading