diff --git a/src/winml/modelkit/commands/compile.py b/src/winml/modelkit/commands/compile.py index 1387d7907..895d6a4d8 100644 --- a/src/winml/modelkit/commands/compile.py +++ b/src/winml/modelkit/commands/compile.py @@ -28,11 +28,11 @@ from ..onnx import is_compiled_onnx from ..sysinfo import resolve_device, resolve_eps from ..utils import cli as cli_utils -from ..utils.constants import normalize_ep_name +from ..utils.constants import COMPILER_NAMES, ORT_SESSION_COMPILER, normalize_ep_name if TYPE_CHECKING: - from ..utils.constants import EPName, EPNameOrAlias + from ..utils.constants import CompilerName, EPName, EPNameOrAlias from ..utils.logging import configure_logging @@ -45,8 +45,10 @@ "--model", "-m", required=False, + multiple=True, type=click.Path(exists=True, path_type=Path), - help="Input ONNX model file (required unless --list)", + help="Input ONNX model file. Repeat -m to compile multiple models with a shared " + "EP context (weight sharing). Required unless --list.", ) @cli_utils.output_option("Output file path (e.g., model_compiled.onnx)") @click.option( @@ -74,9 +76,10 @@ ) @click.option( "--compiler", - type=click.Choice(["ort", "qairt"]), + type=click.Choice(list(COMPILER_NAMES)), default="ort", - help="Compiler backend (default: ort)", + help="Compiler backend (default: ort). 'ort_session' compiles via " + "ort.InferenceSession (ep.context_enable) — required for shared-context multi-model.", ) @click.option( "--qnn-sdk-root", @@ -102,7 +105,7 @@ @click.pass_context def compile( ctx: click.Context, - model: Path | None, + model: tuple[Path, ...], output: Path | None, output_dir: Path | None, device: str, @@ -110,7 +113,7 @@ def compile( validate: bool, verbose: int, quiet: bool, - compiler: str, + compiler: CompilerName, qnn_sdk_root: Path | None, embed: bool, list_compilers_flag: bool, @@ -140,9 +143,13 @@ def compile( # Apply build config defaults (CLI explicit options take precedence). # Read raw JSON so missing keys are distinguishable from dataclass defaults. + config_provider_options: dict[str, str] = {} if config_file is not None: _, raw_cfg = cli_utils.load_build_config(config_file) cc = raw_cfg.get("compile") or {} + # EP provider options (e.g. QNN htp_arch/soc_model/vtcm_mb) for the compile session. + if "provider_options" in cc: + config_provider_options = dict(cc["provider_options"]) if not cli_utils.is_cli_provided(ctx, "ep") and "execution_provider" in cc: ep = cc["execution_provider"] if not cli_utils.is_cli_provided(ctx, "compiler") and "compiler" in cc: @@ -176,18 +183,29 @@ def compile( click.echo(list_compilers(provider)) return - # Validate model is provided when not listing - if model is None: + # Validate model(s) provided when not listing + if not model: raise click.UsageError("Missing option '--model' / '-m'.") - - if is_compiled_onnx(model): - raise click.ClickException( - f"{model} is already a compiled EPContext model and cannot be re-compiled. " - "Run 'winml compile' on the original ONNX model." + models = list(model) + + for m in models: + if is_compiled_onnx(m): + raise click.ClickException( + f"{m} is already a compiled EPContext model and cannot be re-compiled. " + "Run 'winml compile' on the original ONNX model." + ) + + # Multiple models share one EP context and are written by filename into a + # directory, so a single -o/--output file path is ambiguous: require --output-dir + # (and forbid -o/--output). + if len(models) > 1 and (output is not None or output_dir is None): + raise click.UsageError( + "Multiple --model inputs are written by filename into a directory; " + "pass --output-dir (and not -o/--output)." ) # Import compiler (late import to speed up CLI) - from ..compiler import WinMLCompileConfig, compile_onnx + from ..compiler import WinMLCompileConfig, compile_multiple_onnx, compile_onnx # Resolve EP from device + ep flags provider = _resolve_compile_provider(resolved_device, ep) @@ -203,18 +221,24 @@ def compile( config.validate = validate config.verbose = bool(verbose) - # Set compiler options + # Set compiler options. The compiler choice selects the backend: + # "ort_session" -> ort.InferenceSession, else ort.ModelCompiler / qairt. config.ep_config.compiler = compiler config.ep_config.qnn_sdk_root = qnn_sdk_root config.ep_config.embed_context = embed + # EP provider options supplied via --config (compile.provider_options). + if config_provider_options: + config.ep_config.provider_options.update(config_provider_options) # Show info - console.print(f"[bold blue]Input:[/bold blue] {model}") + console.print(f"[bold blue]Input:[/bold blue] {', '.join(str(m) for m in models)}") console.print(f"[bold blue]Device:[/bold blue] {resolved_device}") if ep: console.print(f"[bold blue]EP:[/bold blue] {ep}") console.print(f"[bold blue]Provider:[/bold blue] {provider}") console.print(f"[bold blue]Compiler:[/bold blue] {compiler}") + if len(models) > 1: + console.print(f"[bold blue]Shared EP context:[/bold blue] yes ({len(models)} models)") if qnn_sdk_root: console.print(f"[bold blue]SDK root:[/bold blue] {qnn_sdk_root}") # Resolve output path: -o (file) takes precedence over --output-dir @@ -225,31 +249,48 @@ def compile( console.print(f"[bold blue]Output dir:[/bold blue] {output_dir}") try: - console.print("\n[bold]Compiling model...[/bold]") - result = compile_onnx(model, output_path=resolved_output, config=config) - - if result.success: - if config.ep_config.enable_ep_context and not result.output_path: - console.print( - "\n[bold yellow]Warning:[/bold yellow] Compilation finished " - "but no output file was written to the output directory." - ) - raise click.ClickException( - "No output file produced. Check EP context support for " - f"provider '{config.ep_config.provider}'." - ) - console.print("\n[bold green]Success![/bold green] Model compiled") - if result.output_path: - console.print(f"[dim]Output: {result.output_path}[/dim]") - if result.compile_time: - console.print(f"[dim]Compile time: {result.compile_time:.2f}s[/dim]") - if result.total_time: - console.print(f"[dim]Total time: {result.total_time:.2f}s[/dim]") + console.print("\n[bold]Compiling model(s)...[/bold]") + if len(models) == 1 and compiler != ORT_SESSION_COMPILER: + # Default path: single model via ort.ModelCompiler (staged pipeline). + results = [compile_onnx(models[0], output_path=resolved_output, config=config)] else: - console.print("\n[bold red]Compilation failed:[/bold red]") - for error in result.errors: - console.print(f" {error}") - raise click.ClickException("Compilation failed") + # Multi-model (shared EP context) and/or inference-session backend. + # Multiple models require --output-dir (a directory, enforced above); a + # single inference_session model may use -o (a file) or --output-dir. + results = compile_multiple_onnx(models, resolved_output, config) + + # Report every model's result (not just the first failure). + multi = len(results) > 1 + failures = 0 + for model_path, result in zip(models, results, strict=True): + label = f" — {model_path.name}" if multi else "" + if result.success: + if config.ep_config.enable_ep_context and not result.output_path: + # Compiled but no artifact landed: a warning, not a failure. + console.print( + "\n[bold yellow]Warning:[/bold yellow] Compilation finished but " + f"no output file was written to the output directory.{label}" + ) + continue + console.print(f"\n[bold green]Success![/bold green] Model compiled{label}") + if result.output_path: + console.print(f"[dim]Output: {result.output_path}[/dim]") + if result.compile_time: + console.print(f"[dim]Compile time: {result.compile_time:.2f}s[/dim]") + if result.total_time: + console.print(f"[dim]Total time: {result.total_time:.2f}s[/dim]") + else: + failures += 1 + console.print(f"\n[bold red]Compilation failed:[/bold red]{label}") + for error in result.errors: + console.print(f" {error}") + + if failures: + raise click.ClickException( + f"Compilation failed for {failures} of {len(results)} model(s)." + if multi + else "Compilation failed" + ) except click.ClickException: raise diff --git a/src/winml/modelkit/compiler/__init__.py b/src/winml/modelkit/compiler/__init__.py index 99c0eb42d..bd49317e5 100644 --- a/src/winml/modelkit/compiler/__init__.py +++ b/src/winml/modelkit/compiler/__init__.py @@ -41,7 +41,7 @@ # (mypy, CodeQL) visibility into what ``__all__`` actually exports without # triggering the heavy imports at runtime. if TYPE_CHECKING: - from .compiler import Compiler, compile_onnx, list_compilers + from .compiler import Compiler, compile_multiple_onnx, compile_onnx, list_compilers from .stages.compile import CompileStage from .stages.optimize import OptimizeStage from .stages.qformat import QFormatConvertStage @@ -49,11 +49,19 @@ def __getattr__(name: str) -> Any: """Lazy-load heavy symbols that pull in session/torch to speed up import.""" - if name in {"Compiler", "compile_onnx", "list_compilers"}: - from .compiler import Compiler, compile_onnx, list_compilers + if name in {"Compiler", "compile_multiple_onnx", "compile_onnx", "list_compilers"}: + from .compiler import ( + Compiler, + compile_multiple_onnx, + compile_onnx, + list_compilers, + ) globals().update( - Compiler=Compiler, compile_onnx=compile_onnx, list_compilers=list_compilers + Compiler=Compiler, + compile_multiple_onnx=compile_multiple_onnx, + compile_onnx=compile_onnx, + list_compilers=list_compilers, ) return globals()[name] @@ -84,6 +92,7 @@ def __getattr__(name: str) -> Any: "QFormatConvertStage", "WinMLCompileConfig", "clear_transforms", + "compile_multiple_onnx", "compile_onnx", "get_transforms_for_ep", "list_compilers", diff --git a/src/winml/modelkit/compiler/compiler.py b/src/winml/modelkit/compiler/compiler.py index 6267daca6..97c924328 100644 --- a/src/winml/modelkit/compiler/compiler.py +++ b/src/winml/modelkit/compiler/compiler.py @@ -6,6 +6,7 @@ from __future__ import annotations +import logging import tempfile import time from pathlib import Path @@ -15,16 +16,23 @@ from .result import CompileResult +logger = logging.getLogger(__name__) + + if TYPE_CHECKING: - from ..utils.constants import EPName + from collections.abc import Sequence + + import onnxruntime as ort + + from ..utils.constants import CompilerName, EPName from .configs import WinMLCompileConfig from .stages.base import BaseStage # EP → available compilers. Keys are canonical EPName (or None for the default). -EP_COMPILER_MAPPING: dict[EPName | None, list[str]] = { - "QNNExecutionProvider": ["ort", "qairt"], - None: ["ort"], +EP_COMPILER_MAPPING: dict[EPName | None, list[CompilerName]] = { + "QNNExecutionProvider": ["ort", "ort_session", "qairt"], + None: ["ort", "ort_session"], } @@ -53,6 +61,24 @@ class Compiler: # Registered stages (in execution order) _stages: list[type[BaseStage]] | None = None + def __init__(self, n_total_models: int = 1) -> None: + """Create a compiler. + + Args: + n_total_models: Total number of models compiled by this instance. When + >1, the models share a single EP context (weight sharing) and the + same shared ``SessionOptions`` is reused across every ``compile``. + + The compile backend (ort.ModelCompiler vs ort.InferenceSession) is taken from + the config's ``compiler`` setting ("ort_session" selects the + InferenceSession backend), surfaced via ``CompileContext.use_inference_session``. + """ + self.n_total_models = n_total_models + # The shared SessionOptions: created by CompileStage on the first model and + # reused for the rest (kept here so it survives between compile() calls). + self.shared_session_options: ort.SessionOptions | None = None + self.n_compiled_models = 0 + @classmethod def _get_stages(cls) -> list[type[BaseStage]]: """Lazy initialization of stages.""" @@ -103,12 +129,17 @@ def compile( work_dir = Path(temp_dir.name) try: - # Create context from config + # Create context from config. Multi-model / weight-sharing state is + # threaded through so CompileStage can pick the backend, reuse the shared + # SessionOptions, and detect the last (stop_share) model. context = CompileContext( model_path=model_path, config=config.to_dict(), work_dir=work_dir, verbose=config.verbose, + n_compiled_models=self.n_compiled_models, + n_total_models=self.n_total_models, + shared_session_options=self.shared_session_options, ) if output_path is not None: @@ -129,6 +160,11 @@ def compile( else: context.log(f"Skipping stage: {stage_cls.name}") + # Carry the shared SessionOptions (created/reused by CompileStage) forward + # so the next model in a shared-context run reuses the same EP + group. + self.shared_session_options = context.shared_session_options + self.n_compiled_models += 1 + # Build result total_time = time.time() - start_time result = self._build_result(context, total_time) @@ -199,3 +235,73 @@ def compile_onnx( """ compiler = Compiler() return compiler.compile(model_path=model_path, output_path=output_path, config=config) + + +def compile_multiple_onnx( + model_paths: Sequence[str | Path], + output_path: str | Path | None = None, + config: WinMLCompileConfig | None = None, +) -> list[CompileResult]: + """Compile one or more ONNX models, sharing a single EP context when >1. + + A single :class:`Compiler` (``n_total_models=len(model_paths)``) compiles every + model in sequence, reusing one shared ``SessionOptions`` so the weights are shared + across the compiled EPContext models. The backend is taken from + ``config.ep_config.compiler``: ``ort.ModelCompiler`` (default) or + ``ort.InferenceSession`` when it is ``"ort_session"``. + + Args: + model_paths: Input ONNX model paths. + output_path: Where to write the compiled model(s). + + * With a **single** model it may be a **file** path (the exact + ``*_ctx.onnx``) or a **directory** (``_ctx.onnx`` is written into + it); ``None`` writes next to the input. + * With **multiple** models it **must be a directory** — each model is + written as ``_ctx.onnx`` there, with same-named inputs disambiguated + by an integer suffix on the later one(s) (with a warning), e.g. + ``model_ctx.onnx`` then ``model_1_ctx.onnx``. + config: Compilation configuration. ``None`` skips compilation (passthrough). + + Returns: + One :class:`CompileResult` per input model, in order. + """ + paths = [Path(mp) for mp in model_paths] + out = Path(output_path) if output_path is not None else None + # A path with a suffix (e.g. ".onnx") is a file; otherwise it's a directory. + out_is_file = out is not None and bool(out.suffix) + + if len(paths) > 1 and (out is None or out_is_file): + raise ValueError( + "output_path must be a directory when compiling multiple models " + f"(shared EP context), got {output_path!r}" + ) + + # Backend is taken from config.ep_config.compiler ("ort_session" selects + # the InferenceSession backend), surfaced via CompileContext.use_inference_session. + compiler = Compiler(n_total_models=len(paths)) + # Compiled in order so the shared context accumulates and the last model flushes it. + # When writing into a directory, outputs are keyed by filename stem, so disambiguate + # same-named inputs by suffixing the later one(s) instead of overwriting. + results: list[CompileResult] = [] + seen_stems: dict[str, int] = {} + for p in paths: + count = seen_stems.get(p.stem, 0) + seen_stems[p.stem] = count + 1 + out_stem = p.stem if count == 0 else f"{p.stem}_{count}" + if count > 0: + logger.warning( + "Input model name %r repeats; writing its compiled output as " + "'%s_ctx.onnx' to avoid overwriting the earlier one.", + p.name, + out_stem, + ) + if out is None: + resolved = None + elif out_is_file: + # Single-model file path: write exactly there. + resolved = out + else: + resolved = out / f"{out_stem}_ctx.onnx" + results.append(compiler.compile(model_path=p, output_path=resolved, config=config)) + return results diff --git a/src/winml/modelkit/compiler/configs.py b/src/winml/modelkit/compiler/configs.py index 2059c9528..50f2c71a5 100644 --- a/src/winml/modelkit/compiler/configs.py +++ b/src/winml/modelkit/compiler/configs.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from ..utils.constants import EPAlias, EPName +from ..utils.constants import CompilerName, EPAlias, EPName if TYPE_CHECKING: @@ -37,7 +37,8 @@ class EPConfig: provider_options: EP-specific options as key=value dict enable_ep_context: Generate EPContext model with pre-compiled graph embed_context: Embed context in ONNX (True) or external .bin file (False) - compiler: Compiler backend ("ort" or "qairt") + compiler: Compiler backend ("ort", "ort_session", or "qairt"). + "ort_session" selects the ort.InferenceSession backend. qnn_sdk_root: Path to QAIRT SDK root (required when compiler is "qairt") device: Target device ("npu", "gpu", "cpu", "auto") """ @@ -46,7 +47,7 @@ class EPConfig: provider_options: dict[str, str] = field(default_factory=dict) enable_ep_context: bool = True embed_context: bool = False - compiler: str = "ort" + compiler: CompilerName = "ort" qnn_sdk_root: Path | None = None device: str = "auto" diff --git a/src/winml/modelkit/compiler/context.py b/src/winml/modelkit/compiler/context.py index 18fb22679..fda351c50 100644 --- a/src/winml/modelkit/compiler/context.py +++ b/src/winml/modelkit/compiler/context.py @@ -14,6 +14,8 @@ import onnx import onnxruntime as ort +from ..utils.constants import ORT_SESSION_COMPILER + if TYPE_CHECKING: from ..utils.constants import EPAlias @@ -44,6 +46,16 @@ class CompileContext: # Session (set during compile) session: ort.InferenceSession | None = None + # Multi-model / shared-EP-context compilation state (driven by Compiler). + # n_compiled_models: how many models the Compiler has already compiled (0-based + # index of the current model). + # n_total_models: total models in this compile run (>1 enables weight sharing). + # shared_session_options: the shared ort.SessionOptions created on the first model + # and reused for the rest (the EP is added once and the share group lives on it). + n_compiled_models: int = 0 + n_total_models: int = 1 + shared_session_options: ort.SessionOptions | None = None + # Output paths output_path: Path | None = None context_binary_path: Path | None = None @@ -91,6 +103,14 @@ def execution_provider(self) -> EPAlias: """Get target execution provider.""" return cast("EPAlias", self.config.get("execution_provider", "qnn")) + @property + def use_inference_session(self) -> bool: + """Whether to use the ort.InferenceSession backend (vs ort.ModelCompiler). + + True iff the configured compiler is ``"ort_session"``. + """ + return self.config.get("compiler") == ORT_SESSION_COMPILER + @property def enable_ep_context(self) -> bool: """Whether to generate EPContext model.""" diff --git a/src/winml/modelkit/compiler/stages/compile.py b/src/winml/modelkit/compiler/stages/compile.py index 94b476956..4bc1c28c4 100644 --- a/src/winml/modelkit/compiler/stages/compile.py +++ b/src/winml/modelkit/compiler/stages/compile.py @@ -16,7 +16,7 @@ from ...onnx import load_onnx, save_onnx from ...session import WinMLQairtSession, WinMLSession -from ...utils.constants import normalize_ep_name +from ...utils.constants import ORT_SESSION_COMPILER, normalize_ep_name from ..configs import WinMLCompileConfig from .base import BaseStage @@ -45,62 +45,25 @@ def should_run(cls, context: CompileContext) -> bool: return True def process(self, context: CompileContext) -> CompileContext: - """Execute compilation.""" + """Execute compilation. + + Two compile paths, selected from the multi-model state on the context: + + * ``_compile_single_model_compiler`` (single model, default): the existing + ``WinMLSession`` / ``ort.ModelCompiler`` path — unchanged. + * ``_compile_multiple`` (``use_inference_session`` and/or + ``n_total_models > 1``): reuses one shared ``SessionOptions`` so multiple + models share a single EP context (weight sharing); the backend is + ``ort.InferenceSession`` when requested, else ``ort.ModelCompiler``. + """ context.log("Starting compile stage") start_time = time.time() try: - # Resolve session class from compiler config - compiler = context.config.get("compiler", "ort") - session_cls = COMPILER_SESSION_MAPPING[compiler] - - # Determine final output directory (default: same as input model) - output_dir = self._get_output_dir(context) - context.log(f"Output directory: {output_dir}") - - # Ensure model is saved to disk (may be in work_dir if modified) - model_path = self._ensure_model_file(context) - context.log(f"Model path: {model_path}") - - ep_config = WinMLCompileConfig.from_dict(context.config).ep_config - # Derive the target device from the runtime session so the compile - # stage stays aligned with the actual EPContext filename produced by - # WinMLSession instead of carrying device metadata in provider_options. - device = context.config.get("device", "auto") - explicit_ep = normalize_ep_name(ep_config.provider) - session_cls_name = getattr(session_cls, "__name__", session_cls.__class__.__name__) - context.log(f"Creating {session_cls_name} for device: {device}") - winml_session = session_cls( - onnx_path=model_path, - device=device, - ep_config=ep_config, - ep=explicit_ep, - ) - winml_session.compile() - - # Get the underlying session for validation and info collection - session = winml_session._session - context.session = session - - resolved_device = getattr(winml_session, "_device", device) - if isinstance(resolved_device, str) and resolved_device: - device = resolved_device.lower() - - # Log actual providers used - if session is not None: - actual_providers = session.get_providers() - context.log(f"Actual providers: {actual_providers}") - - # Validate if requested - if context.validate: - self._validate_model(session, context) - - # Collect model info - self._collect_model_info(session, context) - - # Find and relocate EPContext files to output directory - if ep_config.enable_ep_context: - self._finalize_output(context, model_path, output_dir, device=device) + if context.use_inference_session or context.n_total_models > 1: + self._compile_multiple(context) + else: + self._compile_single_model_compiler(context) except Exception as e: context.add_error(f"Compilation failed: {e}") @@ -113,6 +76,162 @@ def process(self, context: CompileContext) -> CompileContext: return context + def _compile_single_model_compiler(self, context: CompileContext) -> None: + """Single-model compile via ``WinMLSession`` (``ort.ModelCompiler``).""" + # Resolve session class from compiler config. "ort_session" must not + # reach here — it routes to _compile_multiple via context.use_inference_session. + compiler = context.config.get("compiler", "ort") + if compiler == ORT_SESSION_COMPILER: + raise ValueError( + f"{ORT_SESSION_COMPILER!r} is handled by the inference-session path, " + "not the single-model ModelCompiler path." + ) + session_cls = COMPILER_SESSION_MAPPING[compiler] + + # Determine final output directory (default: same as input model) + output_dir = self._get_output_dir(context) + context.log(f"Output directory: {output_dir}") + + # Ensure model is saved to disk (may be in work_dir if modified) + model_path = self._ensure_model_file(context) + context.log(f"Model path: {model_path}") + + ep_config = WinMLCompileConfig.from_dict(context.config).ep_config + # Derive the target device from the runtime session so the compile + # stage stays aligned with the actual EPContext filename produced by + # WinMLSession instead of carrying device metadata in provider_options. + device = context.config.get("device", "auto") + explicit_ep = normalize_ep_name(ep_config.provider) + session_cls_name = getattr(session_cls, "__name__", session_cls.__class__.__name__) + context.log(f"Creating {session_cls_name} for device: {device}") + winml_session = session_cls( + onnx_path=model_path, + device=device, + ep_config=ep_config, + ep=explicit_ep, + ) + winml_session.compile() + + # Get the underlying session for validation and info collection + session = winml_session._session + context.session = session + + resolved_device = getattr(winml_session, "_device", device) + if isinstance(resolved_device, str) and resolved_device: + device = resolved_device.lower() + + # Log actual providers used + if session is not None: + actual_providers = session.get_providers() + context.log(f"Actual providers: {actual_providers}") + + # Validate if requested + if context.validate: + self._validate_model(session, context) + + # Collect model info + self._collect_model_info(session, context) + + # Find and relocate EPContext files to output directory + if ep_config.enable_ep_context: + self._finalize_output(context, model_path, output_dir, device=device) + + def _compile_multiple(self, context: CompileContext) -> None: + """Multi-model / inference-session compile with a shared EP context. + + The shared ``SessionOptions`` (``context.shared_session_options``) is created on + the first model — the EP is added once and, for a multi-model run, the + ``ep.share_ep_contexts`` group is opened on it — then reused for every model. + ``ep.stop_share_ep_contexts`` is added before the final model so the shared + weights binary is flushed. + """ + import onnxruntime as ort + + from ...sysinfo.device import resolve_device, resolve_eps + from ...utils.constants import DEVICE_TO_DEVICE_TYPE + from ...winml import add_ep_for_device, register_execution_providers + + ep_config = WinMLCompileConfig.from_dict(context.config).ep_config + multi = context.n_total_models > 1 + is_last = context.n_compiled_models >= context.n_total_models - 1 + use_is = context.use_inference_session + + output_dir = self._get_output_dir(context) + output_dir.mkdir(parents=True, exist_ok=True) + model_path = self._ensure_model_file(context) + # Honor an explicit output filename (e.g. the de-duplicated _ctx.onnx + # that compile_multiple_onnx assigns); otherwise derive it from the model stem. + user_output = context.config.get("output_path") + if user_output and Path(user_output).suffix == ".onnx": + ctx_path = Path(user_output) + else: + ctx_path = output_dir / f"{context.model_path.stem}_ctx.onnx" + backend = "inference_session" if use_is else "model_compiler" + context.log( + f"[{backend}] compiling {model_path.name} " + f"({context.n_compiled_models + 1}/{context.n_total_models}) -> {ctx_path.name}" + ) + + # Build the shared SessionOptions once; reuse it for subsequent models. + sess_options = context.shared_session_options + if sess_options is None: + register_execution_providers(ort=True) + resolved_device, _ = resolve_device(context.config.get("device", "auto")) + ep = normalize_ep_name(ep_config.provider) or resolve_eps(resolved_device)[0] + device_type = DEVICE_TO_DEVICE_TYPE.get(resolved_device.upper()) + + sess_options = ort.SessionOptions() + if use_is: + sess_options.add_session_config_entry("ep.context_enable", "1") + sess_options.add_session_config_entry( + "ep.context_embed_mode", "1" if ep_config.embed_context else "0" + ) + if multi: + sess_options.add_session_config_entry("ep.share_ep_contexts", "1") + if not add_ep_for_device( + sess_options, ep, device_type, dict(ep_config.provider_options) + ): + raise RuntimeError(f"Could not add {ep} for device type {device_type}") + context.shared_session_options = sess_options # captured by Compiler for reuse + + # Last model in a shared run flushes the shared context. + if multi and is_last: + sess_options.add_session_config_entry("ep.stop_share_ep_contexts", "1") + + if use_is: + # InferenceSession backend: ep.context_file_path writes the EPContext + # wrapper; constructing the session performs the compile. + sess_options.add_session_config_entry("ep.context_file_path", str(ctx_path)) + session = ort.InferenceSession(str(model_path), sess_options=sess_options) + context.session = session + if session.get_providers(): + context.log(f"Actual providers: {session.get_providers()}") + # Models compiled this way are loadable; validate (run) when requested. + if context.validate: + self._validate_model(session, context) + # Collect I/O info regardless of validation. + self._collect_model_info(session, context) + else: + # ModelCompiler backend: compile straight to the EPContext file. No + # session is created here (smoke path — outputs are checked, not loaded). + ort.ModelCompiler( + sess_options, + str(model_path), + embed_compiled_data_into_model=ep_config.embed_context, + ).compile_to_file(str(ctx_path)) + + if ctx_path.exists(): + context.output_path = ctx_path + bins = [ + f + for f in output_dir.glob(f"{ctx_path.stem}*.bin") + if not f.name.endswith("_schematic.bin") + ] + if bins: + context.context_binary_path = bins[0] + else: + context.add_warning(f"No EPContext produced for {model_path.name}") + def _get_output_dir(self, context: CompileContext) -> Path: """Determine the output directory for compiled model. diff --git a/src/winml/modelkit/utils/constants.py b/src/winml/modelkit/utils/constants.py index b62e9e4c1..7b45d153a 100644 --- a/src/winml/modelkit/utils/constants.py +++ b/src/winml/modelkit/utils/constants.py @@ -49,6 +49,21 @@ EPNameOrAlias: TypeAlias = EPName | EPAlias +# Compile backends selectable via ``--compiler`` (see commands/compile.py): +# "ort" -> ort.ModelCompiler (default) +# "ort_session" -> ort.InferenceSession (ep.context_enable) +# "qairt" -> QAIRT SDK compiler +CompilerName = Literal["ort", "ort_session", "qairt"] + +# The ``--compiler`` choice that selects the ort.InferenceSession backend (the others +# go through ort.ModelCompiler / the QAIRT SDK). Referenced wherever the backend is +# branched on, so the magic string lives in exactly one place. +ORT_SESSION_COMPILER: CompilerName = "ort_session" + +# Runtime-iterable form of ``CompilerName`` (e.g. for the CLI choice list). +COMPILER_NAMES: tuple[CompilerName, ...] = get_args(CompilerName) + + # Supported execution providers — derived from the ``EPName`` Literal above so # that ``utils.constants`` stays leaf-level (no import dependency on sysinfo). # Membership parity with ``sysinfo.device._EP_DEVICE_MAP`` is enforced by diff --git a/tests/e2e/test_compile_e2e.py b/tests/e2e/test_compile_e2e.py index 58157518d..d26615ff4 100644 --- a/tests/e2e/test_compile_e2e.py +++ b/tests/e2e/test_compile_e2e.py @@ -35,15 +35,17 @@ from pathlib import Path from typing import TYPE_CHECKING +import numpy as np import onnx import pytest from click.testing import CliRunner +from onnx import TensorProto, helper from tests.e2e.require_ep import require_ep, require_not_ep from winml.modelkit.commands.compile import compile as compile_cmd from winml.modelkit.onnx import is_compiled_onnx from winml.modelkit.utils import normalize_ep_name -from winml.modelkit.utils.constants import EP_SUPPORTED_DEVICES +from winml.modelkit.utils.constants import EP_SUPPORTED_DEVICES, ORT_SESSION_COMPILER if TYPE_CHECKING: @@ -164,6 +166,9 @@ def assert_by_run_inference( device: str, ep: str, sample_input: dict, + reference_model: Path | None = None, + rtol: float = 1e-2, + atol: float = 1e-2, ) -> None: """Bind ``ep`` + ``device`` and run one inference call on the compiled artifact. @@ -171,13 +176,34 @@ def assert_by_run_inference( EP), this asserts the artifact specifically loads and runs on the requested ``(device, ep)`` pair. Catches the case where the compile succeeded against a different EP/device than the user asked for. + + When ``reference_model`` is given, the original (pre-compile) model is run on + the CPU EP with the same input and the compiled output is checked against it + with :func:`numpy.allclose` — a correctness check that the compiled graph still + computes the same result, not just that it runs. """ + import onnxruntime as ort + from winml.modelkit.session import WinMLSession session = WinMLSession(out_path, device=device, ep=ep) outputs = session.run(sample_input) assert outputs, "Inference produced no outputs" + if reference_model is not None: + ref_sess = ort.InferenceSession(str(reference_model), providers=["CPUExecutionProvider"]) + ref_names = [o.name for o in ref_sess.get_outputs()] + ref_outputs = ref_sess.run(None, sample_input) + for name, ref in zip(ref_names, ref_outputs, strict=True): + assert name in outputs, f"Compiled model missing output {name!r}" + got = np.asarray(outputs[name], dtype=np.float32) + ref = np.asarray(ref, dtype=np.float32) + assert np.allclose(got, ref, rtol=rtol, atol=atol), ( + f"Compiled output {name!r} differs from CPU reference " + f"(max abs diff {np.max(np.abs(got - ref)):.4g}, " + f"rtol={rtol}, atol={atol})" + ) + def _find_qairt_sdk_root() -> Path | None: """Locate an installed QAIRT SDK on this host, or None.""" @@ -860,3 +886,123 @@ def test_bad_input_no_ep_covers_device(simple_matmul_onnx: Path) -> None: src_hash, simple_matmul_onnx, ) + + +# =========================================================================== +# Compile backend (ort.ModelCompiler vs ort.InferenceSession) + multi-model +# shared EP context (qnn-only) +# =========================================================================== + + +@pytest.fixture +def shared_weight_models(tmp_path: Path) -> tuple[Path, Path]: + """Two MatMul models sharing the SAME weight but with different input shapes. + + Mirrors the prefill/decode (ctx/iter) pattern that QNN weight sharing targets: + one ``[K, K]`` weight ``B`` reused across both graphs while the leading sequence + dimension differs (4 vs 1). Returns ``(seq4_model, seq1_model)``. + """ + np.random.seed(7) + k = 4 + b_values = np.random.randn(k, k).astype(np.float32) + + def _build(seq: int, name: str) -> Path: + a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, seq, k]) + c = helper.make_tensor_value_info("C", TensorProto.FLOAT, [1, seq, k]) + b = helper.make_tensor("B", TensorProto.FLOAT, [k, k], b_values.flatten().tolist()) + node = helper.make_node("MatMul", ["A", "B"], ["C"], name="matmul") + graph = helper.make_graph([node], "shared_matmul", [a], [c], [b]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 7 + onnx.checker.check_model(model) + path = tmp_path / name + onnx.save(model, str(path)) + return path + + return _build(4, "shared_seq4.onnx"), _build(1, "shared_seq1.onnx") + + +def _sample_for(model_path: Path) -> dict[str, np.ndarray]: + """Random input matching a ``shared_weight_models`` graph's declared shape.""" + dims = onnx.load(str(model_path)).graph.input[0].type.tensor_type.shape.dim + shape = [d.dim_value for d in dims] + return {"A": np.random.randn(*shape).astype(np.float32)} + + +@pytest.mark.e2e +def test_default_backend_uses_model_compiler( + simple_matmul_onnx: Path, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """By default (``--compiler ort``), a single-model compile is driven by + ``ort.ModelCompiler``.""" + require_ep("qnn") + import onnxruntime as ort + + real = ort.ModelCompiler + calls: list[int] = [] + + def _spy(*args: object, **kwargs: object) -> object: + calls.append(1) + return real(*args, **kwargs) + + monkeypatch.setattr(ort, "ModelCompiler", _spy) + + out = tmp_path / "default.onnx" + result = _invoke("-m", str(simple_matmul_onnx), "--ep", "qnn", "-o", str(out)) + assert result.exit_code == 0, result.output + assert calls, "Default single-model compile should use ort.ModelCompiler" + assert is_compiled_onnx(out) + + +@pytest.mark.e2e +@pytest.mark.parametrize( + "use_inference_session", + [False, True], + ids=["model_compiler", "inference_session"], +) +def test_multi_model_shared_weights( + use_inference_session: bool, + shared_weight_models: tuple[Path, Path], + tmp_path: Path, +) -> None: + """Multiple models with shared weights compile to a single shared EP context in + BOTH backends (``ort.ModelCompiler`` default, ``ort.InferenceSession`` opt-in). + + Both backends are smoke-checked (files + one shared weights bin). The + inference_session output is additionally loaded and run on QNN; the + model_compiler output is smoke-only (not loaded). + """ + require_ep("qnn") + m_seq4, m_seq1 = shared_weight_models + out_dir = tmp_path / "out" + out_dir.mkdir() + + cmd = ["-m", str(m_seq4), "-m", str(m_seq1), "--ep", "qnn", "--output-dir", str(out_dir)] + if use_inference_session: + cmd += ["--compiler", ORT_SESSION_COMPILER] + result = _invoke(*cmd) + assert result.exit_code == 0, result.output + assert "Success! Model compiled" in result.output, result.output + # The InferenceSession backend is selected via --compiler ort_session. + if use_inference_session: + assert ORT_SESSION_COMPILER in result.output + else: + assert ORT_SESSION_COMPILER not in result.output + + # Both compiled wrappers exist + exactly one shared weights bin (weight sharing). + ctx4 = out_dir / f"{m_seq4.stem}_ctx.onnx" + ctx1 = out_dir / f"{m_seq1.stem}_ctx.onnx" + assert ctx4.is_file() and is_compiled_onnx(ctx4), f"missing/invalid {ctx4}" + assert ctx1.is_file() and is_compiled_onnx(ctx1), f"missing/invalid {ctx1}" + bins = [p for p in out_dir.glob("*.bin") if not p.name.endswith("_schematic.bin")] + assert len(bins) == 1, f"Expected one shared weights bin, got {[b.name for b in bins]}" + + # inference_session output is runnable; model_compiler output is smoke-only (no load). + # Run on QNN and np.allclose-check against the original model on CPU. + if use_inference_session: + assert_by_run_inference( + ctx4, device="npu", ep="qnn", sample_input=_sample_for(m_seq4), reference_model=m_seq4 + ) + assert_by_run_inference( + ctx1, device="npu", ep="qnn", sample_input=_sample_for(m_seq1), reference_model=m_seq1 + ) diff --git a/tests/unit/compiler/test_compile_command.py b/tests/unit/compiler/test_compile_command.py index f4d3a1de0..9aa1547d5 100644 --- a/tests/unit/compiler/test_compile_command.py +++ b/tests/unit/compiler/test_compile_command.py @@ -20,6 +20,7 @@ from click.testing import CliRunner from winml.modelkit.cli import main +from winml.modelkit.utils.constants import ORT_SESSION_COMPILER @pytest.fixture @@ -341,6 +342,132 @@ def test_compile_device_propagates_to_provider_options( assert config.ep_config.provider_options.get("device_type") == "NPU" assert config.ep_config.device == "npu" + def test_multiple_models_reject_output_file(self, runner: CliRunner, tmp_path: Path) -> None: + """Multiple -m inputs with -o/--output (a file) are rejected: use --output-dir. + + Several models share one EP context and are written by filename into a + directory, so a single output file path is ambiguous. + """ + m1 = tmp_path / "m1.onnx" + m2 = tmp_path / "m2.onnx" + self._create_simple_onnx(m1) + self._create_simple_onnx(m2) + out_file = tmp_path / "out.onnx" + + result = runner.invoke(main, ["compile", "-m", str(m1), "-m", str(m2), "-o", str(out_file)]) + + assert result.exit_code != 0 + assert "output-dir" in result.output.lower(), result.output + + def test_multiple_models_require_output_dir(self, runner: CliRunner, tmp_path: Path) -> None: + """Multiple -m inputs with neither -o nor --output-dir are rejected. + + --output-dir is mandatory for multi-model compiles (the compiled models are + written by filename into that directory). + """ + m1 = tmp_path / "m1.onnx" + m2 = tmp_path / "m2.onnx" + self._create_simple_onnx(m1) + self._create_simple_onnx(m2) + + result = runner.invoke(main, ["compile", "-m", str(m1), "-m", str(m2)]) + + assert result.exit_code != 0 + assert "output-dir" in result.output.lower(), result.output + + @patch("winml.modelkit.compiler.compile_multiple_onnx") + def test_multiple_models_with_output_dir_calls_compile_multiple( + self, + mock_compile_multiple: MagicMock, + runner: CliRunner, + tmp_path: Path, + ) -> None: + """Multiple -m inputs with --output-dir compile via compile_multiple_onnx.""" + m1 = tmp_path / "m1.onnx" + m2 = tmp_path / "m2.onnx" + self._create_simple_onnx(m1) + self._create_simple_onnx(m2) + out_dir = tmp_path / "out" + + mock_result = MagicMock() + mock_result.success = True + mock_result.output_path = out_dir / "m2_ctx.onnx" + mock_result.compile_time = 1.0 + mock_result.total_time = 1.5 + mock_compile_multiple.return_value = [mock_result, mock_result] + + result = runner.invoke( + main, + [ + "compile", + "-m", + str(m1), + "-m", + str(m2), + "--device", + "npu", + "--ep", + "qnn", + "--output-dir", + str(out_dir), + ], + ) + + assert result.exit_code == 0, result.output + assert mock_compile_multiple.called + call_args = mock_compile_multiple.call_args + # First positional arg is the ordered list of input models. + passed_models = call_args.args[0] + assert [str(m) for m in passed_models] == [str(m1), str(m2)] + # Second positional arg is the output target — the --output-dir directory. + assert call_args.args[1] == out_dir + # Backend is carried on the config's compiler; defaults to "ort" (ModelCompiler). + assert call_args.args[2].ep_config.compiler == "ort" + + @patch("winml.modelkit.compiler.compile_multiple_onnx") + def test_ort_session_compiler_sets_config( + self, + mock_compile_multiple: MagicMock, + runner: CliRunner, + tmp_path: Path, + ) -> None: + """--compiler ort_session is carried on the config used for compilation.""" + m1 = tmp_path / "m1.onnx" + m2 = tmp_path / "m2.onnx" + self._create_simple_onnx(m1) + self._create_simple_onnx(m2) + out_dir = tmp_path / "out" + + mock_result = MagicMock() + mock_result.success = True + mock_result.output_path = out_dir / "m2_ctx.onnx" + mock_result.compile_time = 1.0 + mock_result.total_time = 1.5 + mock_compile_multiple.return_value = [mock_result, mock_result] + + result = runner.invoke( + main, + [ + "compile", + "-m", + str(m1), + "-m", + str(m2), + "--device", + "npu", + "--ep", + "qnn", + "--output-dir", + str(out_dir), + "--compiler", + ORT_SESSION_COMPILER, + ], + ) + + assert result.exit_code == 0, result.output + # The compiler choice is applied onto the config that drives compilation. + assert mock_compile_multiple.call_args.args[2].ep_config.compiler == ORT_SESSION_COMPILER + def _create_simple_onnx(self, path: Path) -> None: """Create a simple ONNX model for testing.""" import onnx diff --git a/tests/unit/compiler/test_compile_multiple.py b/tests/unit/compiler/test_compile_multiple.py new file mode 100644 index 000000000..3b31829b7 --- /dev/null +++ b/tests/unit/compiler/test_compile_multiple.py @@ -0,0 +1,129 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for ``compile_multiple_onnx`` output-name handling. + +``Compiler`` is mocked so these exercise the per-model output naming / de-dup +logic only — no real compilation or EP runtime is needed. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from winml.modelkit.compiler import compile_multiple_onnx + + +def _output_names(mock_compiler_cls: MagicMock) -> list[str]: + """Filenames passed as ``output_path`` to each ``Compiler.compile`` call, in order.""" + calls = mock_compiler_cls.return_value.compile.call_args_list + return [Path(c.kwargs["output_path"]).name for c in calls] + + +class TestCompileMultipleNaming: + @patch("winml.modelkit.compiler.compiler.Compiler") + def test_duplicate_names_suffixed_with_warning( + self, mock_compiler_cls: MagicMock, tmp_path: Path, caplog: pytest.LogCaptureFixture + ) -> None: + """Two inputs with the same filename: the later one gets an ``_1`` suffix and warns.""" + m1 = tmp_path / "a" / "model.onnx" + m2 = tmp_path / "b" / "model.onnx" + out_dir = tmp_path / "out" + mock_compiler_cls.return_value.compile.return_value = MagicMock(success=True) + + with caplog.at_level(logging.WARNING): + results = compile_multiple_onnx([m1, m2], out_dir) + + assert len(results) == 2 + names = _output_names(mock_compiler_cls) + assert names == ["model_ctx.onnx", "model_1_ctx.onnx"] + # Both land in the requested output directory. + for c in mock_compiler_cls.return_value.compile.call_args_list: + assert Path(c.kwargs["output_path"]).parent == out_dir + assert "repeats" in caplog.text + + @patch("winml.modelkit.compiler.compiler.Compiler") + def test_triple_duplicate_names_increment( + self, mock_compiler_cls: MagicMock, tmp_path: Path, caplog: pytest.LogCaptureFixture + ) -> None: + """Three same-named inputs increment the suffix: _ , _1, _2.""" + models = [tmp_path / d / "m.onnx" for d in ("a", "b", "c")] + mock_compiler_cls.return_value.compile.return_value = MagicMock(success=True) + + with caplog.at_level(logging.WARNING): + compile_multiple_onnx(models, tmp_path / "out") + + assert _output_names(mock_compiler_cls) == [ + "m_ctx.onnx", + "m_1_ctx.onnx", + "m_2_ctx.onnx", + ] + assert caplog.text.count("repeats") == 2 + + @patch("winml.modelkit.compiler.compiler.Compiler") + def test_unique_names_no_suffix_no_warning( + self, mock_compiler_cls: MagicMock, tmp_path: Path, caplog: pytest.LogCaptureFixture + ) -> None: + """Distinct filenames keep their stems and emit no warning.""" + m1 = tmp_path / "a.onnx" + m2 = tmp_path / "b.onnx" + mock_compiler_cls.return_value.compile.return_value = MagicMock(success=True) + + with caplog.at_level(logging.WARNING): + compile_multiple_onnx([m1, m2], tmp_path / "out") + + assert _output_names(mock_compiler_cls) == ["a_ctx.onnx", "b_ctx.onnx"] + assert "repeats" not in caplog.text + + +def _single_output(mock_compiler_cls: MagicMock) -> Path | None: + """The ``output_path`` passed to the single ``Compiler.compile`` call.""" + out = mock_compiler_cls.return_value.compile.call_args.kwargs["output_path"] + return Path(out) if out is not None else None + + +class TestCompileMultipleOutputPath: + @patch("winml.modelkit.compiler.compiler.Compiler") + def test_multiple_models_require_directory( + self, mock_compiler_cls: MagicMock, tmp_path: Path + ) -> None: + """Multiple models with a file output_path (has a suffix) is rejected.""" + m1 = tmp_path / "a" / "m.onnx" + m2 = tmp_path / "b" / "m.onnx" + with pytest.raises(ValueError, match="must be a directory"): + compile_multiple_onnx([m1, m2], tmp_path / "out.onnx") + + @patch("winml.modelkit.compiler.compiler.Compiler") + def test_multiple_models_reject_none_output( + self, mock_compiler_cls: MagicMock, tmp_path: Path + ) -> None: + """Multiple models with no output_path is rejected (would break shared context).""" + m1 = tmp_path / "a" / "m.onnx" + m2 = tmp_path / "b" / "m.onnx" + with pytest.raises(ValueError, match="must be a directory"): + compile_multiple_onnx([m1, m2], None) + + @patch("winml.modelkit.compiler.compiler.Compiler") + def test_single_model_file_output_path( + self, mock_compiler_cls: MagicMock, tmp_path: Path + ) -> None: + """A single model accepts a file output_path and writes exactly there.""" + mock_compiler_cls.return_value.compile.return_value = MagicMock(success=True) + out_file = tmp_path / "custom_name.onnx" + compile_multiple_onnx([tmp_path / "model.onnx"], out_file) + assert _single_output(mock_compiler_cls) == out_file + + @patch("winml.modelkit.compiler.compiler.Compiler") + def test_single_model_dir_output_path( + self, mock_compiler_cls: MagicMock, tmp_path: Path + ) -> None: + """A single model with a directory output_path writes _ctx.onnx into it.""" + mock_compiler_cls.return_value.compile.return_value = MagicMock(success=True) + out_dir = tmp_path / "out" + compile_multiple_onnx([tmp_path / "model.onnx"], out_dir) + assert _single_output(mock_compiler_cls) == out_dir / "model_ctx.onnx"