Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 82 additions & 41 deletions src/winml/modelkit/commands/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
from ..onnx import is_compiled_onnx
from ..sysinfo import resolve_device, resolve_eps
from ..utils import cli as cli_utils
from ..utils.constants import normalize_ep_name
from ..utils.constants import COMPILER_NAMES, ORT_SESSION_COMPILER, normalize_ep_name


if TYPE_CHECKING:
from ..utils.constants import EPName, EPNameOrAlias
from ..utils.constants import CompilerName, EPName, EPNameOrAlias
from ..utils.logging import configure_logging


Expand All @@ -45,8 +45,10 @@
"--model",
"-m",
required=False,
multiple=True,
type=click.Path(exists=True, path_type=Path),
help="Input ONNX model file (required unless --list)",
help="Input ONNX model file. Repeat -m to compile multiple models with a shared "
"EP context (weight sharing). Required unless --list.",
)
@cli_utils.output_option("Output file path (e.g., model_compiled.onnx)")
@click.option(
Expand Down Expand Up @@ -74,9 +76,10 @@
)
@click.option(
"--compiler",
type=click.Choice(["ort", "qairt"]),
type=click.Choice(list(COMPILER_NAMES)),
default="ort",
help="Compiler backend (default: ort)",
help="Compiler backend (default: ort). 'ort_session' compiles via "
"ort.InferenceSession (ep.context_enable) — required for shared-context multi-model.",
)
@click.option(
"--qnn-sdk-root",
Expand All @@ -102,15 +105,15 @@
@click.pass_context
def compile(
ctx: click.Context,
model: Path | None,
model: tuple[Path, ...],
output: Path | None,
output_dir: Path | None,
device: str,
ep: EPNameOrAlias | None,
validate: bool,
verbose: int,
quiet: bool,
compiler: str,
compiler: CompilerName,
qnn_sdk_root: Path | None,
embed: bool,
list_compilers_flag: bool,
Expand Down Expand Up @@ -140,9 +143,13 @@ def compile(

# Apply build config defaults (CLI explicit options take precedence).
# Read raw JSON so missing keys are distinguishable from dataclass defaults.
config_provider_options: dict[str, str] = {}
if config_file is not None:
_, raw_cfg = cli_utils.load_build_config(config_file)
cc = raw_cfg.get("compile") or {}
# EP provider options (e.g. QNN htp_arch/soc_model/vtcm_mb) for the compile session.
if "provider_options" in cc:
config_provider_options = dict(cc["provider_options"])
if not cli_utils.is_cli_provided(ctx, "ep") and "execution_provider" in cc:
ep = cc["execution_provider"]
if not cli_utils.is_cli_provided(ctx, "compiler") and "compiler" in cc:
Expand Down Expand Up @@ -176,18 +183,29 @@ def compile(
click.echo(list_compilers(provider))
return

# Validate model is provided when not listing
if model is None:
# Validate model(s) provided when not listing
if not model:
raise click.UsageError("Missing option '--model' / '-m'.")

if is_compiled_onnx(model):
raise click.ClickException(
f"{model} is already a compiled EPContext model and cannot be re-compiled. "
"Run 'winml compile' on the original ONNX model."
models = list(model)

for m in models:
if is_compiled_onnx(m):
raise click.ClickException(
f"{m} is already a compiled EPContext model and cannot be re-compiled. "
"Run 'winml compile' on the original ONNX model."
)

# Multiple models share one EP context and are written by filename into a
# directory, so a single -o/--output file path is ambiguous: require --output-dir
# (and forbid -o/--output).
if len(models) > 1 and (output is not None or output_dir is None):
raise click.UsageError(
"Multiple --model inputs are written by filename into a directory; "
"pass --output-dir (and not -o/--output)."
)

# Import compiler (late import to speed up CLI)
from ..compiler import WinMLCompileConfig, compile_onnx
from ..compiler import WinMLCompileConfig, compile_multiple_onnx, compile_onnx

# Resolve EP from device + ep flags
provider = _resolve_compile_provider(resolved_device, ep)
Expand All @@ -203,18 +221,24 @@ def compile(
config.validate = validate
config.verbose = bool(verbose)

# Set compiler options
# Set compiler options. The compiler choice selects the backend:
# "ort_session" -> ort.InferenceSession, else ort.ModelCompiler / qairt.
config.ep_config.compiler = compiler
config.ep_config.qnn_sdk_root = qnn_sdk_root
config.ep_config.embed_context = embed
# EP provider options supplied via --config (compile.provider_options).
if config_provider_options:
config.ep_config.provider_options.update(config_provider_options)

# Show info
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Input:[/bold blue] {', '.join(str(m) for m in models)}")
console.print(f"[bold blue]Device:[/bold blue] {resolved_device}")
if ep:
console.print(f"[bold blue]EP:[/bold blue] {ep}")
console.print(f"[bold blue]Provider:[/bold blue] {provider}")
console.print(f"[bold blue]Compiler:[/bold blue] {compiler}")
if len(models) > 1:
console.print(f"[bold blue]Shared EP context:[/bold blue] yes ({len(models)} models)")
if qnn_sdk_root:
console.print(f"[bold blue]SDK root:[/bold blue] {qnn_sdk_root}")
# Resolve output path: -o (file) takes precedence over --output-dir
Expand All @@ -225,31 +249,48 @@ def compile(
console.print(f"[bold blue]Output dir:[/bold blue] {output_dir}")

try:
console.print("\n[bold]Compiling model...[/bold]")
result = compile_onnx(model, output_path=resolved_output, config=config)

if result.success:
if config.ep_config.enable_ep_context and not result.output_path:
console.print(
"\n[bold yellow]Warning:[/bold yellow] Compilation finished "
"but no output file was written to the output directory."
)
raise click.ClickException(
"No output file produced. Check EP context support for "
f"provider '{config.ep_config.provider}'."
)
console.print("\n[bold green]Success![/bold green] Model compiled")
if result.output_path:
console.print(f"[dim]Output: {result.output_path}[/dim]")
if result.compile_time:
console.print(f"[dim]Compile time: {result.compile_time:.2f}s[/dim]")
if result.total_time:
console.print(f"[dim]Total time: {result.total_time:.2f}s[/dim]")
console.print("\n[bold]Compiling model(s)...[/bold]")
if len(models) == 1 and compiler != ORT_SESSION_COMPILER:
# Default path: single model via ort.ModelCompiler (staged pipeline).
results = [compile_onnx(models[0], output_path=resolved_output, config=config)]
else:
console.print("\n[bold red]Compilation failed:[/bold red]")
for error in result.errors:
console.print(f" {error}")
raise click.ClickException("Compilation failed")
# Multi-model (shared EP context) and/or inference-session backend.
# Multiple models require --output-dir (a directory, enforced above); a
# single inference_session model may use -o (a file) or --output-dir.
results = compile_multiple_onnx(models, resolved_output, config)

# Report every model's result (not just the first failure).
multi = len(results) > 1
failures = 0
for model_path, result in zip(models, results, strict=True):
label = f" — {model_path.name}" if multi else ""
if result.success:
if config.ep_config.enable_ep_context and not result.output_path:
# Compiled but no artifact landed: a warning, not a failure.
console.print(
"\n[bold yellow]Warning:[/bold yellow] Compilation finished but "
f"no output file was written to the output directory.{label}"
)
continue
console.print(f"\n[bold green]Success![/bold green] Model compiled{label}")
if result.output_path:
console.print(f"[dim]Output: {result.output_path}[/dim]")
if result.compile_time:
console.print(f"[dim]Compile time: {result.compile_time:.2f}s[/dim]")
if result.total_time:
console.print(f"[dim]Total time: {result.total_time:.2f}s[/dim]")
else:
failures += 1
console.print(f"\n[bold red]Compilation failed:[/bold red]{label}")
for error in result.errors:
console.print(f" {error}")

if failures:
raise click.ClickException(
f"Compilation failed for {failures} of {len(results)} model(s)."
if multi
else "Compilation failed"
)

except click.ClickException:
raise
Expand Down
17 changes: 13 additions & 4 deletions src/winml/modelkit/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,27 @@
# (mypy, CodeQL) visibility into what ``__all__`` actually exports without
# triggering the heavy imports at runtime.
if TYPE_CHECKING:
from .compiler import Compiler, compile_onnx, list_compilers
from .compiler import Compiler, compile_multiple_onnx, compile_onnx, list_compilers
from .stages.compile import CompileStage
from .stages.optimize import OptimizeStage
from .stages.qformat import QFormatConvertStage


def __getattr__(name: str) -> Any:
"""Lazy-load heavy symbols that pull in session/torch to speed up import."""
if name in {"Compiler", "compile_onnx", "list_compilers"}:
from .compiler import Compiler, compile_onnx, list_compilers
if name in {"Compiler", "compile_multiple_onnx", "compile_onnx", "list_compilers"}:
from .compiler import (
Compiler,
compile_multiple_onnx,
compile_onnx,
list_compilers,
)

globals().update(
Compiler=Compiler, compile_onnx=compile_onnx, list_compilers=list_compilers
Compiler=Compiler,
compile_multiple_onnx=compile_multiple_onnx,
compile_onnx=compile_onnx,
list_compilers=list_compilers,
)
return globals()[name]

Expand Down Expand Up @@ -84,6 +92,7 @@ def __getattr__(name: str) -> Any:
"QFormatConvertStage",
"WinMLCompileConfig",
"clear_transforms",
"compile_multiple_onnx",
"compile_onnx",
"get_transforms_for_ep",
"list_compilers",
Expand Down
Loading
Loading