Skip to content
1 change: 1 addition & 0 deletions src/winml/modelkit/analyze/core/information_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def _check_model(self) -> list[Information]:
self._model,
op_runtime_results=self._op_runtime_results,
device=self._device,
ep=self._ep,
)
manager_init_ms = int((time.perf_counter() - manager_init_start) * 1000)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

from .base import ModelValidator
from .batched_const_matmul_validator import BatchedConstMatMulValidator
from .constant_folding_validator import ConstantFoldingValidator
from .dynamic_input_validator import DynamicInputValidator
from .model_validator_manager import ModelValidatorManager
Expand All @@ -21,6 +22,7 @@


__all__ = [
"BatchedConstMatMulValidator",
"ConstantFoldingValidator",
"DynamicInputValidator",
"ModelValidator",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Validator for batched MatMul with a constant operand on OpenVINO GPU.

OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched
(rank >= 3) MatMul where an operand is a compile-time constant. The identical
gemm with a dynamic operand, and 2D constant gemm, both compile fine. Models
whose batched MatMul weights fold to constants (e.g. transformer disentangled
attention position terms) therefore fail to compile on OpenVINO GPU with:

[GPU] Failed to select implementation for ... type: gemm

This validator detects that structural pattern and recommends the
``batchedconstmatmul-untied`` rewrite, which routes the constant operand through
``Add(const, runtime-zero)`` so it becomes runtime-valued and gemm
implementation selection succeeds. The rewrite (source/target patterns) lives in
``modelkit.pattern.batched_const_matmul_patterns``; this validator only supplies
the EP/device gating and the autoconf trigger that enables it.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from ...models.information import Action, ActionItem, ActionLevel, Information
from ...utils import infer_ihv_from_ep_name
from .base import ModelValidator


if TYPE_CHECKING:
from ....utils.constants import EPName
from ...models.onnx_model import ONNXModel
from ...models.runtime_checks import PatternRuntime

# Module-level logger keyed on this module's dotted name.
logger = logging.getLogger(__name__)

# Name of the rewrite capability switched on when the pattern is detected.
# Kebab-case to match the capability registry / autoconf normalization; derived
# from the JSON rule (source flag "batchedconstmatmul" + target flag "untied")
# in pattern/rules/default.json. It is used as the key of the
# ``optimization_options`` mapping attached to the recommended action.
_REWRITE_FLAG = "batchedconstmatmul-untied"


class BatchedConstMatMulValidator(ModelValidator):
    """Detect batched MatMul with a constant operand (OpenVINO GPU only).

    The validator is a no-op unless the target device is ``GPU`` and the
    execution provider resolves to the Intel IHV. When active, it scans the
    graph for two-input ``MatMul`` nodes where exactly one input is a graph
    initializer of rank >= 3, and reports them together with a ``REQUIRED``
    action that enables the ``batchedconstmatmul-untied`` rewrite.
    """

    def __init__(
        self,
        model: ONNXModel,
        op_runtime_results: list[PatternRuntime] | None = None,
        ep: EPName | None = None,
        device: str | None = None,
    ) -> None:
        """Store the EP/device context used to gate the validator.

        Args:
            model: Model to inspect; forwarded to the base validator.
            op_runtime_results: Optional runtime results; forwarded to the
                base validator.
            ep: Execution provider name. ``None`` or empty disables the
                validator.
            device: Target device name, compared case-insensitively with
                ``"GPU"``. ``None`` disables the validator.
        """
        super().__init__(model, op_runtime_results=op_runtime_results)
        self.ep = ep
        self.device = device

    @property
    def validator_name(self) -> str:
        """Name of this validator for logging/reporting."""
        return "BatchedConstMatMulValidator"

    @property
    def pattern_id(self) -> str:
        """Pattern ID for Information objects."""
        return "MODEL/BatchedConstantMatMul"

    def _is_enabled(self) -> bool:
        """Return True only for an Intel-IHV execution provider on GPU.

        Any failure while resolving the IHV is treated as "not enabled" so the
        validator never breaks analysis, but the failure is logged so that a
        silently skipped check can be diagnosed.
        """
        if (self.device or "").upper() != "GPU":
            return False
        if not self.ep:
            return False
        try:
            # Imported lazily, as in the original implementation.
            from ...models.ihv_type import IHVType

            return infer_ihv_from_ep_name(self.ep) == IHVType.INTEL
        except Exception:  # pragma: no cover - defensive
            # Best-effort gate: stay disabled rather than crash, but leave a
            # trace instead of swallowing the error without a word.
            logger.debug(
                "%s disabled: could not infer IHV for EP %r",
                self.validator_name,
                self.ep,
                exc_info=True,
            )
            return False

    def validate(self) -> Information | None:
        """Detect batched MatMul with a single constant rank>=3 operand.

        Returns:
            ``None`` when the validator is disabled for the current EP/device
            or when no offending node is found; otherwise an ``Information``
            carrying a ``REQUIRED`` action that enables the
            ``batchedconstmatmul-untied`` graph optimization.
        """
        if not self._is_enabled():
            return None

        # Known gap: constants expressed as `Constant` op nodes (rather than
        # graph initializers) are not detected here. The
        # batchedconstmatmul-untied rewrite's source pattern matches on
        # initializer operands too, so detection and rewrite stay consistent.
        # Most exporters and ORT preprocessing emit weights as initializers,
        # so this covers the disentangled-attention case in practice;
        # `Constant`-node weights would need handling on both sides.
        #
        # A single name -> rank mapping serves both the "is this input an
        # initializer?" test and the rank lookup.
        rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer}

        offenders: list[str] = []
        for node in self.graph.node:
            if node.op_type != "MatMul" or len(node.input) != 2:
                continue
            const_inputs = [name for name in node.input if name in rank_by_init]
            # Exactly one constant operand (two-constant MatMuls fold away and
            # never reach gemm impl selection).
            if len(const_inputs) != 1:
                continue
            const_name = const_inputs[0]
            if rank_by_init[const_name] >= 3:
                # Unnamed nodes are reported by their constant operand's name.
                offenders.append(node.name or const_name)

        if not offenders:
            return None

        # Only the first few offenders are listed to keep the message short.
        examples = ", ".join(offenders[:3])
        action = Action(
            pattern_from_id="",
            pattern_to_id="",
            level=ActionLevel.REQUIRED,
            status=None,
            action_items=[
                ActionItem(type="GraphOptimization", optimization_options={_REWRITE_FLAG: True})
            ],
            details=(
                "Enable the batchedconstmatmul-untied rewrite so the constant "
                "operand becomes runtime-valued and OpenVINO GPU can select a "
                "gemm implementation."
            ),
        )
        # https://github.com/openvinotoolkit/openvino/issues/36272
        explanation = (
            f"Model contains {len(offenders)} batched MatMul(s) with a constant "
            f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot "
            "select an implementation for a batched MatMul with a constant "
            "operand, causing a '[GPU] Failed to select implementation ... gemm' "
            "compile failure. The batchedconstmatmul-untied rewrite makes "
            "the operand runtime-valued without changing numerics. "
            "It is fixed in openvino==2026.2.0, so no need to apply the rewrite "
            "if using that version or later."
        )
        return Information(
            explanation=explanation,
            actions=[action],
            pattern_id=self.pattern_id,
            status=None,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, ClassVar

from ...utils.timing_utils import make_timing_logger
from .batched_const_matmul_validator import BatchedConstMatMulValidator
from .constant_folding_validator import ConstantFoldingValidator
from .dynamic_input_validator import DynamicInputValidator
from .pattern_matching_validator import PatternMatchingValidator
Expand Down Expand Up @@ -64,6 +65,11 @@ class ModelValidatorManager:
"class": PatternMatchingValidator,
"enabled_devices": None, # All devices
},
"batched_const_matmul": {
"class": BatchedConstMatMulValidator,
"enabled_devices": ["GPU"], # OpenVINO GPU gemm impl-selection issue
"needs_context": True, # validator self-gates on EP (Intel IHV)
},
}

def __init__(
Expand All @@ -72,6 +78,7 @@ def __init__(
enabled_validators: list[str] | None = None,
op_runtime_results: list[PatternRuntime] | None = None,
device: str | None = None,
ep: str | None = None,
) -> None:
"""Initialize validator manager.

Expand All @@ -92,6 +99,7 @@ def __init__(
self.model_proto = model.get_model()
self.op_runtime_results = op_runtime_results or []
self.device = device or "NPU"
self.ep = ep
self.enabled_validators = enabled_validators or list(self.VALIDATORS.keys())

# Instantiate enabled validators
Expand All @@ -102,18 +110,24 @@ def __init__(
validator_class = validator_config["class"]
enabled_devices = validator_config.get("enabled_devices")

# Check device constraint
if enabled_devices is not None and self.device not in enabled_devices:
# Check device constraint (case-insensitive: callers may pass
# "gpu" or "GPU" depending on the build/analyze entry point).
if enabled_devices is not None and (self.device or "").upper() not in {
d.upper() for d in enabled_devices
}:
logger.info(
f"Validator '{name}' is not enabled for device '{self.device}'. "
f"Only enabled for: {enabled_devices}"
)
continue

ctor_kwargs: dict = {"op_runtime_results": self.op_runtime_results}
if validator_config.get("needs_context"):
ctor_kwargs["ep"] = self.ep
ctor_kwargs["device"] = self.device

try:
self.validators.append(
validator_class(self.model, op_runtime_results=self.op_runtime_results)
)
self.validators.append(validator_class(self.model, **ctor_kwargs))
logger.debug(f"Initialized validator: {name}")
except Exception:
logger.exception(f"Failed to initialize validator {name}")
Expand Down
6 changes: 6 additions & 0 deletions src/winml/modelkit/pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
opschema_to_pattern_schema,
register_pattern_input_generator,
)
from .batched_const_matmul_patterns import (
BatchedConstMatMulPattern,
UntiedBatchedConstMatMulPattern,
)
from .conv2d_inplace_linear_patterns import (
Conv2DInplaceLinear2DPattern,
Conv2DInplaceLinear2DPatternInputGenerator,
Expand Down Expand Up @@ -85,6 +89,7 @@

__all__ = [
"MATMUL_ADD_SCHEMA",
"BatchedConstMatMulPattern",
"Conv2DInplaceLinear2DPattern",
"Conv2DInplaceLinear2DPatternInputGenerator",
"Conv2DInplaceLinear3DPattern",
Expand Down Expand Up @@ -139,6 +144,7 @@
"TransposedSingleLayerNormalizationPatternInputGenerator",
"TransposedSingleRMSNormalizationPattern",
"TransposedSingleRMSNormalizationPatternInputGenerator",
"UntiedBatchedConstMatMulPattern",
"get_pattern_input_generator",
"get_registered_pattern_input_generators",
"make_single_op_pattern",
Expand Down
9 changes: 7 additions & 2 deletions src/winml/modelkit/pattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,14 +2035,19 @@ def _allocate_graph_node_key(node: Any) -> str:
is_constant_map[input_name] = info.is_constant
if info.value is not None:
inputs[input_name] = info.value
elif info.shape is not None:
elif info.shape is not None and all(
isinstance(dim, int) for dim in info.shape
):
# Create a dummy array with the shape for internal constant computation
dtype_str = match_result.type_param_to_type[input_param.type_str]
np_dtype = SupportedONNXType.from_onnx_type(dtype_str).np_type
# For shape computation, create a zero array
inputs[input_name] = np.zeros(info.shape, dtype=np_dtype)
else:
# No shape info, skip this input
# No value, or a symbolic/dynamic shape that cannot be
# materialized into a dummy array: skip this input. A
# target whose get_onnx_model needs concrete operand
# shapes would already have been rejected at match time.
is_constant_map[input_name] = False
else:
is_constant_map[input_name] = False
Expand Down
Loading
Loading