Skip to content
1 change: 1 addition & 0 deletions src/winml/modelkit/analyze/core/information_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def _check_model(self) -> list[Information]:
self._model,
op_runtime_results=self._op_runtime_results,
device=self._device,
ep=self._ep,
)
manager_init_ms = int((time.perf_counter() - manager_init_start) * 1000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

from .base import ModelValidator
from .batched_const_matmul_validator import BatchedConstMatMulValidator
from .constant_folding_validator import ConstantFoldingValidator
from .dynamic_input_validator import DynamicInputValidator
from .model_validator_manager import ModelValidatorManager
Expand All @@ -21,6 +22,7 @@


__all__ = [
"BatchedConstMatMulValidator",
"ConstantFoldingValidator",
"DynamicInputValidator",
"ModelValidator",
Expand Down
13 changes: 12 additions & 1 deletion src/winml/modelkit/analyze/core/model_validators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


if TYPE_CHECKING:
from ....utils.constants import EPName
from ...models.information import Information
from ...models.onnx_model import ONNXModel
from ...models.runtime_checks import PatternRuntime
Expand All @@ -34,19 +35,26 @@ class ModelValidator(ABC):
model_proto: ONNX ModelProto extracted from model
graph: Shorthand for model_proto.graph
op_runtime_results: List of PatternRuntime results from runtime checker (optional)
ep: Execution provider name (optional)
device: Device type, e.g. "NPU", "GPU", "CPU" (optional)
"""

def __init__(
self,
model: ONNXModel,
op_runtime_results: list[PatternRuntime] | None = None,
ep: EPName | None = None,
device: str | None = None,
) -> None:
"""Initialize validator with ONNX model and optional runtime results.
"""Initialize validator with ONNX model and optional context.

Args:
model: ONNXModel wrapper to validate
op_runtime_results: List of PatternRuntime results from runtime checker.
Used to enrich validators with OP-level information.
ep: Execution provider name. Validators that gate on EP read this.
device: Device type (e.g., "NPU", "GPU", "CPU"). Validators that gate
on device read this.

Raises:
ValueError: If model is invalid
Expand All @@ -55,6 +63,9 @@ def __init__(
self.model_proto = model.get_model()
self.graph = self.model_proto.graph
self.op_runtime_results = op_runtime_results or []
# Annotate explicitly so the EPName Literal is not widened to ``str``.
self.ep: EPName | None = ep
self.device = device

logger.debug(
f"Initialized {self.validator_name} for model with {len(self.graph.node)} nodes"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Validator for batched MatMul with a constant operand on OpenVINO GPU.

OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched
(rank >= 3) MatMul where an operand is a compile-time constant. The identical
gemm with a dynamic operand, and 2D constant gemm, both compile fine. Models
whose batched MatMul weights fold to constants (e.g. transformer disentangled
attention position terms) therefore fail to compile on OpenVINO GPU with:

[GPU] Failed to select implementation for ... type: gemm

This validator detects that structural pattern and recommends the
``untie-constant-batched-matmul`` surgery, which makes the constant operand
runtime-valued so gemm implementation selection succeeds.
"""

from __future__ import annotations

import logging

from ...models.information import Action, ActionItem, ActionLevel, Information
from ...utils import infer_ihv_from_ep_name
from .base import ModelValidator


logger = logging.getLogger(__name__)

# Surgery capability enabled when the pattern is detected (kebab-case to match
# the capability registry / autoconf normalization).
# The validator emits this string as the key of the ``optimization_options``
# dict on its ``GraphOptimization`` action item, so it has to stay identical to
# the ``name`` of the capability declared in ``optim/capabilities/surgery.py``.
_SURGERY_FLAG = "untie-constant-batched-matmul"


class BatchedConstMatMulValidator(ModelValidator):
    """Detect batched MatMul with a constant operand (OpenVINO GPU only).

    The validator is a no-op unless the target device is ``GPU`` and the
    execution provider maps to the Intel IHV. When active, it scans the graph
    for ``MatMul`` nodes that have exactly one initializer operand of rank >= 3
    and, if any are found, returns an ``Information`` carrying a REQUIRED
    ``GraphOptimization`` action that enables the
    ``untie-constant-batched-matmul`` surgery.
    """

    @property
    def validator_name(self) -> str:
        """Name of this validator for logging/reporting."""
        return "BatchedConstMatMulValidator"

    @property
    def pattern_id(self) -> str:
        """Pattern ID for Information objects."""
        return "MODEL/BatchedConstantMatMul"

    def _is_enabled(self) -> bool:
        """Return True only for an Intel-IHV execution provider on GPU.

        The device comparison is case-insensitive. A missing device, a missing
        EP, or any failure while resolving the EP's IHV disables the validator
        instead of raising.
        """
        if (self.device or "").upper() != "GPU":
            return False
        ep = self.ep
        if not ep:
            return False
        try:
            # Imported lazily, as in the original implementation.
            from ...models.ihv_type import IHVType

            return infer_ihv_from_ep_name(ep) == IHVType.INTEL
        except Exception:  # pragma: no cover - defensive
            # Stay best-effort (never break analysis), but leave a trace so a
            # silently disabled validator can be diagnosed from debug logs.
            logger.debug(
                "Could not infer IHV for EP %r; %s is disabled",
                ep,
                self.validator_name,
                exc_info=True,
            )
            return False

    def validate(self) -> Information | None:
        """Detect batched MatMul with a single constant rank>=3 operand.

        Returns:
            ``None`` when the validator is disabled or no offending node is
            found; otherwise an ``Information`` with one REQUIRED action that
            turns on the ``untie-constant-batched-matmul`` surgery.
        """
        if not self._is_enabled():
            return None

        # Known gap: constants expressed as `Constant` op nodes (rather than
        # graph initializers) are not detected here. The `untie-constant-batched
        # -matmul` surgery in surgery.py has the same limitation, so detection
        # and surgery stay consistent. Most exporters and ORT preprocessing emit
        # weights as initializers, so this covers the disentangled-attention case
        # in practice; `Constant`-node weights would need handling on both sides.
        #
        # One mapping serves both purposes: key membership tells us whether an
        # input is an initializer, the value is that initializer's rank.
        rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer}

        offenders: list[str] = []
        for node in self.graph.node:
            if node.op_type != "MatMul" or len(node.input) != 2:
                continue
            const_inputs = [name for name in node.input if name in rank_by_init]
            # Exactly one constant operand (two-constant MatMuls fold away and
            # never reach gemm impl selection).
            if len(const_inputs) != 1:
                continue
            if rank_by_init[const_inputs[0]] >= 3:
                # Unnamed nodes are reported by the name of their constant.
                offenders.append(node.name or const_inputs[0])

        if not offenders:
            return None

        # Several unnamed MatMuls may share one (tied) constant and would then
        # produce the same label; show each label once while the reported count
        # below still reflects every offending node.
        examples = ", ".join(list(dict.fromkeys(offenders))[:3])
        action = Action(
            pattern_from_id="",
            pattern_to_id="",
            level=ActionLevel.REQUIRED,
            status=None,
            action_items=[
                ActionItem(type="GraphOptimization", optimization_options={_SURGERY_FLAG: True})
            ],
            details=(
                "Enable untie-constant-batched-matmul surgery so the constant "
                "operand becomes runtime-valued and OpenVINO GPU can select a "
                "gemm implementation."
            ),
        )
        # https://github.com/openvinotoolkit/openvino/issues/36272
        explanation = (
            f"Model contains {len(offenders)} batched MatMul(s) with a constant "
            f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot "
            f"select an implementation for a batched MatMul with a constant "
            f"operand, causing a '[GPU] Failed to select implementation ... gemm' "
            f"compile failure. The untie-constant-batched-matmul surgery makes "
            f"the operand runtime-valued without changing numerics. "
            f"It is fixed in openvino==2026.2.0, so no need to apply the surgery "
            f"if using that version or later."
        )
        return Information(
            explanation=explanation,
            actions=[action],
            pattern_id=self.pattern_id,
            status=None,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, ClassVar

from ...utils.timing_utils import make_timing_logger
from .batched_const_matmul_validator import BatchedConstMatMulValidator
from .constant_folding_validator import ConstantFoldingValidator
from .dynamic_input_validator import DynamicInputValidator
from .pattern_matching_validator import PatternMatchingValidator
Expand All @@ -23,6 +24,7 @@


if TYPE_CHECKING:
from ....utils.constants import EPName
from ...models.information import Information
from ...models.onnx_model import ONNXModel
from ...models.runtime_checks import PatternRuntime
Expand Down Expand Up @@ -64,14 +66,20 @@ class ModelValidatorManager:
"class": PatternMatchingValidator,
"enabled_devices": None, # All devices
},
"batched_const_matmul": {
Comment thread
xieofxie marked this conversation as resolved.
"class": BatchedConstMatMulValidator,
"enabled_devices": ["GPU"], # OpenVINO GPU gemm impl-selection issue
},
}

def __init__(
self,
model: ONNXModel,
enabled_validators: list[str] | None = None,
op_runtime_results: list[PatternRuntime] | None = None,
device: str | None = None,
*,
device: str,
ep: EPName,
) -> None:
"""Initialize validator manager.

Expand All @@ -83,6 +91,7 @@ def __init__(
Used to enrich validators with OP-level information.
device: Device type (e.g., "NPU", "GPU", "CPU").
Used to filter validators based on device constraints.
ep: Execution provider name. Forwarded to validators that gate on EP.

Raises:
ValueError: If model is not valid ONNXModel instance
Expand All @@ -91,7 +100,8 @@ def __init__(
self.model = model
self.model_proto = model.get_model()
self.op_runtime_results = op_runtime_results or []
self.device = device or "NPU"
self.device = device
self.ep = ep
self.enabled_validators = enabled_validators or list(self.VALIDATORS.keys())

# Instantiate enabled validators
Expand All @@ -102,18 +112,25 @@ def __init__(
validator_class = validator_config["class"]
enabled_devices = validator_config.get("enabled_devices")

# Check device constraint
if enabled_devices is not None and self.device not in enabled_devices:
# Check device constraint (case-insensitive: callers may pass
# "gpu" or "GPU" depending on the build/analyze entry point).
if enabled_devices is not None and (self.device or "").upper() not in {
d.upper() for d in enabled_devices
}:
logger.info(
f"Validator '{name}' is not enabled for device '{self.device}'. "
f"Only enabled for: {enabled_devices}"
)
continue

ctor_kwargs: dict = {
"op_runtime_results": self.op_runtime_results,
"ep": self.ep,
"device": self.device,
}

try:
self.validators.append(
validator_class(self.model, op_runtime_results=self.op_runtime_results)
)
self.validators.append(validator_class(self.model, **ctor_kwargs))
logger.debug(f"Initialized validator: {name}")
except Exception:
logger.exception(f"Failed to initialize validator {name}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,13 @@
import json
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar
from typing import ClassVar

from ...models import ModelTag
from ...models.information import Action, ActionLevel, Information
from .base import ModelValidator


if TYPE_CHECKING:
from ...models.onnx_model import ONNXModel
from ...models.runtime_checks import PatternRuntime

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -95,19 +91,6 @@ class PatternMatchingValidator(ModelValidator):
),
]

def __init__(
self,
model: ONNXModel,
op_runtime_results: list[PatternRuntime] | None = None,
) -> None:
"""Initialize validator.

Args:
model: ONNXModel wrapper to validate
op_runtime_results: List of PatternRuntime results from runtime checker
"""
super().__init__(model, op_runtime_results)

@property
def validator_name(self) -> str:
"""Return validator name."""
Expand Down
17 changes: 17 additions & 0 deletions src/winml/modelkit/optim/capabilities/surgery.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,20 @@
category=CapabilityCategory.SURGERY,
default=False,
)

# Route a constant operand of a batched (rank >= 3) MatMul through a runtime
# no-op so it is no longer a compile-time constant. OpenVINO GPU's oneDNN gemm
# cannot select an implementation for a batched MatMul with a constant operand
# (e.g. transformer disentangled-attention position terms that fold to 3D
# constants); making the operand runtime-valued lets gemm impl selection
# succeed without changing numerics or splitting the batched op.
#
# NOTE: ``name`` must stay identical to ``_SURGERY_FLAG`` in
# analyze/core/model_validators/batched_const_matmul_validator.py, which emits
# this string as the optimization option key when the pattern is detected.
# Off by default (``default=False``).
UNTIE_CONSTANT_BATCHED_MATMUL = BoolCapability(
    name="untie-constant-batched-matmul",
    ort_name=None,  # Custom implementation, not ORT optimizer
    description=(
        "Make a batched MatMul's constant operand runtime-valued so OpenVINO "
        "GPU can select a gemm implementation"
    ),
    category=CapabilityCategory.SURGERY,
    default=False,
)
Loading
Loading