diff --git a/src/winml/modelkit/analyze/core/information_engine.py b/src/winml/modelkit/analyze/core/information_engine.py index d4c574b16..a3cfe1e1f 100644 --- a/src/winml/modelkit/analyze/core/information_engine.py +++ b/src/winml/modelkit/analyze/core/information_engine.py @@ -297,6 +297,7 @@ def _check_model(self) -> list[Information]: self._model, op_runtime_results=self._op_runtime_results, device=self._device, + ep=self._ep, ) manager_init_ms = int((time.perf_counter() - manager_init_start) * 1000) diff --git a/src/winml/modelkit/analyze/core/model_validators/__init__.py b/src/winml/modelkit/analyze/core/model_validators/__init__.py index 77cd7b924..4504269da 100644 --- a/src/winml/modelkit/analyze/core/model_validators/__init__.py +++ b/src/winml/modelkit/analyze/core/model_validators/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .base import ModelValidator +from .batched_const_matmul_validator import BatchedConstMatMulValidator from .constant_folding_validator import ConstantFoldingValidator from .dynamic_input_validator import DynamicInputValidator from .model_validator_manager import ModelValidatorManager @@ -21,6 +22,7 @@ __all__ = [ + "BatchedConstMatMulValidator", "ConstantFoldingValidator", "DynamicInputValidator", "ModelValidator", diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py new file mode 100644 index 000000000..9679bbe1b --- /dev/null +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -0,0 +1,145 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Validator for batched MatMul with a constant operand on OpenVINO GPU. + +OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched +(rank >= 3) MatMul where an operand is a compile-time constant. The identical +gemm with a dynamic operand, and 2D constant gemm, both compile fine. Models +whose batched MatMul weights fold to constants (e.g. transformer disentangled +attention position terms) therefore fail to compile on OpenVINO GPU with: + + [GPU] Failed to select implementation for ... type: gemm + +This validator detects that structural pattern and recommends the +``batchedconstmatmul-untied`` rewrite, which routes the constant operand through +``Add(const, runtime-zero)`` so it becomes runtime-valued and gemm +implementation selection succeeds. The rewrite (source/target patterns) lives in +``modelkit.pattern.batched_const_matmul_patterns``; this validator only supplies +the EP/device gating and the autoconf trigger that enables it. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ...models.information import Action, ActionItem, ActionLevel, Information +from ...utils import infer_ihv_from_ep_name +from .base import ModelValidator + + +if TYPE_CHECKING: + from ....utils.constants import EPName + from ...models.onnx_model import ONNXModel + from ...models.runtime_checks import PatternRuntime + +logger = logging.getLogger(__name__) + +# Rewrite capability enabled when the pattern is detected. Kebab-case to match +# the capability registry / autoconf normalization; derived from the JSON rule +# (source flag "batchedconstmatmul" + target flag "untied") in +# pattern/rules/default.json. +_REWRITE_FLAG = "batchedconstmatmul-untied" + + +class BatchedConstMatMulValidator(ModelValidator): + """Detect batched MatMul with a constant operand (OpenVINO GPU only).""" + + def __init__( + self, + model: ONNXModel, + op_runtime_results: list[PatternRuntime] | None = None, + ep: EPName | None = None, + device: str | None = None, + ) -> None: + super().__init__(model, op_runtime_results=op_runtime_results) + self.ep = ep + self.device = device + + @property + def validator_name(self) -> str: + """Name of this validator for logging/reporting.""" + return "BatchedConstMatMulValidator" + + @property + def pattern_id(self) -> str: + """Pattern ID for Information objects.""" + return "MODEL/BatchedConstantMatMul" + + def _is_enabled(self) -> bool: + """Only relevant for OpenVINO (Intel IHV) on GPU.""" + if (self.device or "").upper() != "GPU": + return False + if not self.ep: + return False + try: + from ...models.ihv_type import IHVType + + return infer_ihv_from_ep_name(self.ep) == IHVType.INTEL + except Exception: # pragma: no cover - defensive + return False + + def validate(self) -> Information | None: + """Detect batched MatMul with a single constant rank>=3 operand.""" + if not self._is_enabled(): + return None + + # Known gap: constants expressed as `Constant` op nodes (rather than + # graph initializers) are not detected here. The batchedconstmatmul-untied + # rewrite's source pattern matches on initializer operands too, so + # detection and rewrite stay consistent. Most exporters and ORT + # preprocessing emit weights as initializers, so this covers the + # disentangled-attention case + # in practice; `Constant`-node weights would need handling on both sides. + initializers = {init.name for init in self.graph.initializer} + rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer} + + offenders: list[str] = [] + for node in self.graph.node: + if node.op_type != "MatMul" or len(node.input) != 2: + continue + const_inputs = [name for name in node.input if name in initializers] + # Exactly one constant operand (two-constant MatMuls fold away and + # never reach gemm impl selection). + if len(const_inputs) != 1: + continue + if rank_by_init.get(const_inputs[0], 0) >= 3: + offenders.append(node.name or const_inputs[0]) + + if not offenders: + return None + + examples = ", ".join(offenders[:3]) + action = Action( + pattern_from_id="", + pattern_to_id="", + level=ActionLevel.REQUIRED, + status=None, + action_items=[ + ActionItem(type="GraphOptimization", optimization_options={_REWRITE_FLAG: True}) + ], + details=( + "Enable the batchedconstmatmul-untied rewrite so the constant " + "operand becomes runtime-valued and OpenVINO GPU can select a " + "gemm implementation." + ), + ) + # https://github.com/openvinotoolkit/openvino/issues/36272 + explanation = ( + f"Model contains {len(offenders)} batched MatMul(s) with a constant " + f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot " + f"select an implementation for a batched MatMul with a constant " + f"operand, causing a '[GPU] Failed to select implementation ... gemm' " + f"compile failure. The batchedconstmatmul-untied rewrite makes " + f"the operand runtime-valued without changing numerics. " + f"It is fixed in openvino==2026.2.0, so no need to apply the rewrite " + f"if using that version or later." + ) + return Information( + explanation=explanation, + actions=[action], + pattern_id=self.pattern_id, + status=None, + ) diff --git a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py index 29dd0235c..d658dbe53 100644 --- a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py +++ b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, ClassVar from ...utils.timing_utils import make_timing_logger +from .batched_const_matmul_validator import BatchedConstMatMulValidator from .constant_folding_validator import ConstantFoldingValidator from .dynamic_input_validator import DynamicInputValidator from .pattern_matching_validator import PatternMatchingValidator @@ -64,6 +65,11 @@ class ModelValidatorManager: "class": PatternMatchingValidator, "enabled_devices": None, # All devices }, + "batched_const_matmul": { + "class": BatchedConstMatMulValidator, + "enabled_devices": ["GPU"], # OpenVINO GPU gemm impl-selection issue + "needs_context": True, # validator self-gates on EP (Intel IHV) + }, } def __init__( @@ -72,6 +78,7 @@ def __init__( enabled_validators: list[str] | None = None, op_runtime_results: list[PatternRuntime] | None = None, device: str | None = None, + ep: str | None = None, ) -> None: """Initialize validator manager. @@ -92,6 +99,7 @@ def __init__( self.model_proto = model.get_model() self.op_runtime_results = op_runtime_results or [] self.device = device or "NPU" + self.ep = ep self.enabled_validators = enabled_validators or list(self.VALIDATORS.keys()) # Instantiate enabled validators @@ -102,18 +110,24 @@ def __init__( validator_class = validator_config["class"] enabled_devices = validator_config.get("enabled_devices") - # Check device constraint - if enabled_devices is not None and self.device not in enabled_devices: + # Check device constraint (case-insensitive: callers may pass + # "gpu" or "GPU" depending on the build/analyze entry point). + if enabled_devices is not None and (self.device or "").upper() not in { + d.upper() for d in enabled_devices + }: logger.info( f"Validator '{name}' is not enabled for device '{self.device}'. " f"Only enabled for: {enabled_devices}" ) continue + ctor_kwargs: dict = {"op_runtime_results": self.op_runtime_results} + if validator_config.get("needs_context"): + ctor_kwargs["ep"] = self.ep + ctor_kwargs["device"] = self.device + try: - self.validators.append( - validator_class(self.model, op_runtime_results=self.op_runtime_results) - ) + self.validators.append(validator_class(self.model, **ctor_kwargs)) logger.debug(f"Initialized validator: {name}") except Exception: logger.exception(f"Failed to initialize validator {name}") diff --git a/src/winml/modelkit/pattern/__init__.py b/src/winml/modelkit/pattern/__init__.py index 38f0112c4..b6138887e 100644 --- a/src/winml/modelkit/pattern/__init__.py +++ b/src/winml/modelkit/pattern/__init__.py @@ -30,6 +30,10 @@ opschema_to_pattern_schema, register_pattern_input_generator, ) +from .batched_const_matmul_patterns import ( + BatchedConstMatMulPattern, + UntiedBatchedConstMatMulPattern, +) from .conv2d_inplace_linear_patterns import ( Conv2DInplaceLinear2DPattern, Conv2DInplaceLinear2DPatternInputGenerator, @@ -85,6 +89,7 @@ __all__ = [ "MATMUL_ADD_SCHEMA", + "BatchedConstMatMulPattern", "Conv2DInplaceLinear2DPattern", "Conv2DInplaceLinear2DPatternInputGenerator", "Conv2DInplaceLinear3DPattern", @@ -139,6 +144,7 @@ "TransposedSingleLayerNormalizationPatternInputGenerator", "TransposedSingleRMSNormalizationPattern", "TransposedSingleRMSNormalizationPatternInputGenerator", + "UntiedBatchedConstMatMulPattern", "get_pattern_input_generator", "get_registered_pattern_input_generators", "make_single_op_pattern", diff --git a/src/winml/modelkit/pattern/base.py b/src/winml/modelkit/pattern/base.py index 245e5d51b..611e13d41 100644 --- a/src/winml/modelkit/pattern/base.py +++ b/src/winml/modelkit/pattern/base.py @@ -2035,14 +2035,19 @@ def _allocate_graph_node_key(node: Any) -> str: is_constant_map[input_name] = info.is_constant if info.value is not None: inputs[input_name] = info.value - elif info.shape is not None: + elif info.shape is not None and all( + isinstance(dim, int) for dim in info.shape + ): # Create a dummy array with the shape for internal constant computation dtype_str = match_result.type_param_to_type[input_param.type_str] np_dtype = SupportedONNXType.from_onnx_type(dtype_str).np_type # For shape computation, create a zero array inputs[input_name] = np.zeros(info.shape, dtype=np_dtype) else: - # No shape info, skip this input + # No value, or a symbolic/dynamic shape that cannot be + # materialized into a dummy array: skip this input. A + # target whose get_onnx_model needs concrete operand + # shapes would already have been rejected at match time. is_constant_map[input_name] = False else: is_constant_map[input_name] = False diff --git a/src/winml/modelkit/pattern/batched_const_matmul_patterns.py b/src/winml/modelkit/pattern/batched_const_matmul_patterns.py new file mode 100644 index 000000000..c122b4a58 --- /dev/null +++ b/src/winml/modelkit/pattern/batched_const_matmul_patterns.py @@ -0,0 +1,322 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Pattern + rewrite for batched MatMul with a constant operand on OpenVINO GPU. + +OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched +(rank >= 3) MatMul where an operand is a compile-time constant. The identical +gemm with a dynamic operand, and 2D constant gemm, both compile fine. Models +whose batched MatMul weights fold to constants (e.g. transformer disentangled +attention position terms) therefore fail to compile on OpenVINO GPU with: + + [GPU] Failed to select implementation for ... type: gemm + +The rewrite makes the constant operand runtime-valued without changing numerics: +it routes the constant through ``Add(const, zero)`` where ``zero`` is a ``[1]`` +runtime tensor derived from the MatMul's *own dynamic operand* +(``Reshape([-1]) -> Slice([0:1]) -> Sub(elem, elem) == 0``). Because ``zero`` is +data-dependent, OpenVINO's constant folder cannot collapse the ``Add`` back into +a packed gemm weight, yet ``+ 0`` leaves the values unchanged and the single +batched MatMul is preserved (no per-head decomposition, no perf regression). + +Deriving ``zero`` from the dynamic operand (rather than a graph input) keeps the +replacement *local*: the rewriter only wires the target's nodes to the matched +subgraph's own boundary tensors. The two MatMul operands share a dtype (ONNX +MatMul requires it), so ``zero`` automatically matches the constant's dtype with +no Cast. + +The source pattern matches a bare ``MatMul`` and validates the constant-operand +structure in ``check_skeleton_result``; it deliberately does **not** call the +base implementation, which rejects matches whose non-constant input has +symbolic/dynamic dimensions — exactly the activation operand here. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +from onnx.defs import OpSchema + +from ..onnx import ONNXDomain +from .base import Pattern, PatternSchema, Skeleton +from .match import PatternMatchResult + + +if TYPE_CHECKING: + from onnx import ModelProto + + from .match import SkeletonMatchResult + + +# Minimum operand rank that triggers the OpenVINO GPU gemm impl-selection failure. +_MIN_BATCHED_RANK = 3 + + +# Source and target share this schema so PatternRewriter's schema-equality +# assertion holds (a MatMul: two same-typed operands -> one output). +_BATCHED_CONST_MATMUL_SCHEMA = PatternSchema( + name="BatchedConstMatMulPattern", + doc=( + "Batched (rank >= 3) MatMul with exactly one constant operand.\n" + "Computes Y = MatMul(A, B) where one of A/B is a compile-time constant " + "of rank >= 3 and the other is runtime-valued. Targeted by the untie " + "rewrite for OpenVINO GPU, whose oneDNN gemm cannot select an " + "implementation for this shape." + ), + type_constraints=[ + OpSchema.TypeConstraintParam( + type_param_str="T", + allowed_type_strs=[ + "tensor(float16)", + "tensor(float)", + "tensor(double)", + ], + description="Constrain operands and output to float tensors.", + ) + ], + inputs=[ + OpSchema.FormalParameter( + name="A", + type_str="T", + description="First MatMul operand (constant or runtime).", + param_option=OpSchema.FormalParameterOption.Single, + is_homogeneous=True, + min_arity=1, + ), + OpSchema.FormalParameter( + name="B", + type_str="T", + description="Second MatMul operand (constant or runtime).", + param_option=OpSchema.FormalParameterOption.Single, + is_homogeneous=True, + min_arity=1, + ), + ], + outputs=[ + OpSchema.FormalParameter( + name="Y", + type_str="T", + description="MatMul output.", + param_option=OpSchema.FormalParameterOption.Single, + is_homogeneous=True, + min_arity=1, + ) + ], +) + + +class BatchedConstMatMulPattern(Pattern): + """Source: a MatMul with exactly one rank->=3 constant operand.""" + + def get_skeleton(self) -> Skeleton: + """Return a single-MatMul skeleton with two virtual inputs.""" + return Skeleton( + node_op_types=["MatMul"], + node_domains=[ONNXDomain.AI_ONNX], + edges=[ + (-1, 0, 0, 0), # input A -> MatMul[0] + (-2, 0, 0, 1), # input B -> MatMul[1] + ], + exit_nodes=[0], + n_inputs=2, + ) + + def get_internal_constants_and_attributes( + self, + inputs: dict[str, np.ndarray], + attributes: dict[str, Any], + is_constant_map: dict[str, bool], + domain_versions: dict[ONNXDomain, int], + ) -> tuple[list[tuple[int, int, np.ndarray]], dict[tuple[int, str], Any]]: + """No internal constants or attributes for a bare MatMul.""" + return [], {} + + def get_schema(self) -> PatternSchema: + """Return the shared batched-const-MatMul schema.""" + return _BATCHED_CONST_MATMUL_SCHEMA + + def check_skeleton_result( + self, skeleton_match_result: SkeletonMatchResult + ) -> PatternMatchResult | None: + """Accept only a MatMul with exactly one rank->=3 constant operand. + + This does not call ``super().check_skeleton_result``: the base + implementation rejects matches whose non-constant input carries + symbolic/dynamic dimensions, which is precisely the activation operand + of a transformer MatMul. The structural predicate here depends only on + the constant operand's rank, so symbolic activation dims are fine. + """ + input_infos = self._build_input_infos(skeleton_match_result) + + const_names = [name for name, info in input_infos.items() if info.is_constant] + # Exactly one constant operand: all-constant MatMuls fold away entirely + # and never reach gemm impl selection; zero-constant MatMuls are unaffected. + if len(const_names) != 1: + return None + + const_info = input_infos[const_names[0]] + rank = _operand_rank(const_info) + if rank is None or rank < _MIN_BATCHED_RANK: + return None + + # Build the PatternMatchResult directly (mirrors the tail of the base + # implementation, minus the symbolic-dim rejection). + schema = self.get_schema() + type_param_to_type = self._infer_type_mapping(skeleton_match_result) + schema_input_to_value = { + param.name: skeleton_match_result.inputs[idx] + for idx, param in enumerate(schema.inputs) + if idx < len(skeleton_match_result.inputs) + } + schema_output_to_value = {} + if schema.outputs and skeleton_match_result.output: + schema_output_to_value[schema.outputs[0].name] = skeleton_match_result.output + + return PatternMatchResult( + skeleton_match_result=skeleton_match_result, + schema_input_to_value=schema_input_to_value, + schema_output_to_value=schema_output_to_value, + type_param_to_type=type_param_to_type, + attributes={}, + input_infos=input_infos, + ) + + +class UntiedBatchedConstMatMulPattern(Pattern): + """Target: MatMul with the constant operand routed through ``Add(const, zero)``. + + ``zero`` is a ``[1]`` runtime tensor derived from the dynamic operand, so the + rewrite stays local and OpenVINO's constant folder cannot repack the operand + into a gemm weight. ``get_onnx_model`` is overridden to emit the subgraph, + since which operand is constant is only known at rewrite time. + """ + + def get_skeleton(self) -> Skeleton: + """Return a representative skeleton (unused by the rewriter). + + The replacement is built in :meth:`get_onnx_model`; this skeleton exists + only to satisfy the abstract base and documents the canonical RHS-constant + topology ``Reshape -> Slice -> Sub -> Add -> MatMul``. + """ + return Skeleton( + node_op_types=["Reshape", "Slice", "Sub", "Add", "MatMul"], + node_domains=[ONNXDomain.AI_ONNX] * 5, + edges=[ + (-1, 0, 0, 0), # dynamic A -> Reshape[0] + (0, 0, 1, 0), # Reshape -> Slice[0] + (1, 0, 2, 0), # Slice -> Sub[0] + (1, 0, 2, 1), # Slice -> Sub[1] + (-2, 0, 3, 0), # constant B -> Add[0] + (2, 0, 3, 1), # Sub(zero) -> Add[1] + (-1, 0, 4, 0), # dynamic A -> MatMul[0] + (3, 0, 4, 1), # Add(untied B) -> MatMul[1] + ], + exit_nodes=[4], + n_inputs=2, + ) + + def get_internal_constants_and_attributes( + self, + inputs: dict[str, np.ndarray], + attributes: dict[str, Any], + is_constant_map: dict[str, bool], + domain_versions: dict[ONNXDomain, int], + ) -> tuple[list[tuple[int, int, np.ndarray]], dict[tuple[int, str], Any]]: + """No declarative constants; the subgraph is built in get_onnx_model.""" + return [], {} + + def get_schema(self) -> PatternSchema: + """Return the shared batched-const-MatMul schema.""" + return _BATCHED_CONST_MATMUL_SCHEMA + + def get_onnx_model( + self, + inputs: dict[str, np.ndarray], + attributes: dict[str, Any], + is_constant_map: dict[str, bool], + output_dtypes: list[str], + domain_versions: dict[ONNXDomain, int], + prefix: str = "", + input_names: list[str] | None = None, + output_names: list[str] | None = None, + ) -> ModelProto: + """Emit ``MatMul(dyn, Add(const, zero(dyn)))`` preserving operand order. + + The constant operand (per ``is_constant_map``) is routed through + ``Add(const, zero)``; ``zero`` is a ``[1]`` runtime tensor built from the + dynamic operand. Operand slots are preserved so the MatMul semantics are + unchanged. + """ + from onnx import helper, numpy_helper + + schema = self.get_schema() + if input_names is None: + input_names = [param.name for param in schema.inputs] + if output_names is None: + output_names = [param.name for param in schema.outputs] + + # Identify which operand is the constant (schema order: A=0, B=1). + param_names = [param.name for param in schema.inputs] + const_idx = 0 if is_constant_map.get(param_names[0]) else 1 + dyn_idx = 1 - const_idx + const_name = input_names[const_idx] + dyn_name = input_names[dyn_idx] + out_name = output_names[0] + + nodes = [] + initializers = [] + + def _init(arr: np.ndarray, name: str) -> str: + initializers.append(numpy_helper.from_array(arr, name)) + return name + + neg1 = _init(np.array([-1], dtype=np.int64), f"{prefix}neg1") + starts = _init(np.array([0], dtype=np.int64), f"{prefix}slice_starts") + ends = _init(np.array([1], dtype=np.int64), f"{prefix}slice_ends") + axis = _init(np.array([0], dtype=np.int64), f"{prefix}slice_axis") + + flat = f"{prefix}flat" + elem = f"{prefix}elem" + zero = f"{prefix}zero" + untied = f"{prefix}untied" + + # flat = Reshape(dyn, [-1]); elem = Slice(flat, [0:1]); zero = elem - elem. + # dyn and const share a dtype (ONNX MatMul), so zero needs no Cast. + nodes.append(helper.make_node("Reshape", [dyn_name, neg1], [flat], name=f"{prefix}Reshape")) + nodes.append( + helper.make_node("Slice", [flat, starts, ends, axis], [elem], name=f"{prefix}Slice") + ) + nodes.append(helper.make_node("Sub", [elem, elem], [zero], name=f"{prefix}Sub")) + nodes.append(helper.make_node("Add", [const_name, zero], [untied], name=f"{prefix}Add")) + + # Preserve original operand order in the rebuilt MatMul. + mm_inputs = [untied, dyn_name] if const_idx == 0 else [dyn_name, untied] + nodes.append(helper.make_node("MatMul", mm_inputs, [out_name], name=f"{prefix}MatMul")) + + opset_imports = [ + helper.make_opsetid(domain.schema_domain, version) + for domain, version in domain_versions.items() + ] + graph = helper.make_graph( + nodes=nodes, + name=f"{prefix}untied_batched_const_matmul", + inputs=[], + outputs=[], + initializer=initializers, + ) + model = helper.make_model( + graph, producer_name="winmlcli-pattern-generator", opset_imports=opset_imports + ) + model.ir_version = 11 + return model + + +def _operand_rank(info: Any) -> int | None: + """Return an operand's rank from its InputInfo (shape, else constant value).""" + if info.shape is not None: + return len(info.shape) + if info.value is not None: + return int(info.value.ndim) + return None diff --git a/src/winml/modelkit/pattern/rules/default.json b/src/winml/modelkit/pattern/rules/default.json index bb9794e2e..8dddf6a70 100644 --- a/src/winml/modelkit/pattern/rules/default.json +++ b/src/winml/modelkit/pattern/rules/default.json @@ -180,6 +180,24 @@ "reason": "Merged axes reduce Transpose dimensionality for better hardware compatibility" } ] + }, + { + "pattern_id": "SUBGRAPH/BatchedConstMatMulPattern", + "pattern_class": "BatchedConstMatMulPattern", + "module": "winml.modelkit.pattern.batched_const_matmul_patterns", + "enabled": false, + "flag_name": "batchedconstmatmul", + "description": "Batched (rank >= 3) MatMul with one constant operand (OpenVINO GPU gemm impl-selection workaround)", + "alternatives": [ + { + "pattern_to_id": "SUBGRAPH/UntiedBatchedConstMatMulPattern", + "pattern_class": "UntiedBatchedConstMatMulPattern", + "module": "winml.modelkit.pattern.batched_const_matmul_patterns", + "priority": 1, + "flag_name": "untied", + "reason": "Route the constant operand through Add(const, runtime-zero) so OpenVINO GPU can select a gemm implementation" + } + ] } ] } diff --git a/tests/unit/analyze/core/model_validators/test_validators.py b/tests/unit/analyze/core/model_validators/test_validators.py index ddc83c7be..437e166a2 100644 --- a/tests/unit/analyze/core/model_validators/test_validators.py +++ b/tests/unit/analyze/core/model_validators/test_validators.py @@ -379,3 +379,63 @@ def test_unknown_validator_logs_warning(self, caplog): assert len(manager.validators) == 1 assert manager.validators[0].validator_name == "ConstantFoldingValidator" assert "Unknown validator" in caplog.text + + +def _make_batched_const_matmul_proto(const_rank: int = 3): + """Model: data [2,3,4] @ W(const) [2,4,5] -> out [2,3,5].""" + import numpy as np + from onnx import numpy_helper + + w_shape = [2, 4, 5] if const_rank == 3 else [4, 5] + w = numpy_helper.from_array(np.zeros(w_shape, dtype=np.float32), "W") + matmul = helper.make_node("MatMul", ["data", "W"], ["out"], name="batched_matmul") + graph = helper.make_graph( + [matmul], + "batched_const_matmul", + [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4])], + [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 5])], + initializer=[w], + ) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +class TestBatchedConstMatMulValidator: + """OpenVINO-GPU batched constant MatMul detector.""" + + def _validate(self, proto, ep, device): + from winml.modelkit.analyze.core.model_validators import BatchedConstMatMulValidator + + model = create_onnx_model_wrapper(proto) + return BatchedConstMatMulValidator(model, ep=ep, device=device).validate() + + def test_detects_for_openvino_gpu(self): + """Emits a GraphOptimization action enabling the untie rewrite for OV GPU.""" + info = self._validate(_make_batched_const_matmul_proto(), "openvino", "GPU") + assert info is not None + assert info.pattern_id == "MODEL/BatchedConstantMatMul" + items = info.actions[0].action_items + assert items[0].type == "GraphOptimization" + assert items[0].optimization_options == {"batchedconstmatmul-untied": True} + + def test_skipped_for_openvino_npu(self): + """Device-gated: NPU is unaffected.""" + assert self._validate(_make_batched_const_matmul_proto(), "openvino", "NPU") is None + + def test_skipped_for_non_intel_gpu(self): + """IHV-gated: a non-Intel GPU EP is unaffected.""" + info = self._validate(_make_batched_const_matmul_proto(), "DmlExecutionProvider", "GPU") + assert info is None + + def test_skipped_for_two_dim_constant(self): + """Rank-2 constant gemm compiles on OV GPU; not flagged.""" + info = self._validate(_make_batched_const_matmul_proto(const_rank=2), "openvino", "GPU") + assert info is None + + def test_manager_wires_validator_for_openvino_gpu(self): + """Manager constructs the validator and surfaces the action for OV GPU.""" + model = create_onnx_model_wrapper(_make_batched_const_matmul_proto()) + manager = ModelValidatorManager(model, device="GPU", ep="openvino") + names = [v.validator_name for v in manager.validators] + assert "BatchedConstMatMulValidator" in names + infos = manager.run_all_validators() + assert any(i.pattern_id == "MODEL/BatchedConstantMatMul" for i in infos) diff --git a/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py b/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py new file mode 100644 index 000000000..711e79aca --- /dev/null +++ b/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py @@ -0,0 +1,145 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""RewritePipe tests for the batchedconstmatmul-untied rewrite. + +Exercises the source/target patterns in +``winml.modelkit.pattern.batched_const_matmul_patterns`` end-to-end through +``RewritePipe``: a batched (rank >= 3) MatMul with one constant operand is +rewritten so the constant flows through ``Add(const, runtime-zero)``, leaving +numerics unchanged while making the operand runtime-valued. + +Cardinal Rules followed: +- #1: No hardcoded model architectures or node names. +- #2: All expected results generated by code at runtime (ORT reference run). +- #3: Tests run and pass under pytest. +""" + +from __future__ import annotations + +import numpy as np +import onnx +import onnxruntime as ort + +from winml.modelkit.optim.pipes.rewrite import RewritePipe + + +_CAP_KWARG = "batchedconstmatmul_untied" + + +def _build_model( + *, + const_on_rhs: bool, + const_rank: int = 3, + dynamic_batch: bool = True, + both_constant: bool = False, +) -> onnx.ModelProto: + """Build a single-MatMul model with one rank-``const_rank`` constant operand. + + The dynamic operand carries a symbolic batch dim when ``dynamic_batch`` is + True, which is the case the base pattern matcher would normally reject. + """ + rng = np.random.RandomState(0) + batch: int | str = "batch" if dynamic_batch else 2 + + if const_on_rhs: + # dyn [batch,3,4] @ const W -> [batch,3,5] + dyn_shape = [batch, 3, 4] + w_arr = rng.randn(1, 4, 5).astype(np.float32) + if const_rank == 2: + w_arr = w_arr[0] # [4,5] + out_shape = [batch, 3, 5] + mm_inputs = ["dyn", "W"] + else: + # const W @ dyn [batch,4,5] -> [batch,3,5] + dyn_shape = [batch, 4, 5] + w_arr = rng.randn(1, 3, 4).astype(np.float32) + if const_rank == 2: + w_arr = w_arr[0] # [3,4] + out_shape = [batch, 3, 5] + mm_inputs = ["W", "dyn"] + + initializers = [onnx.numpy_helper.from_array(w_arr, "W")] + graph_inputs = [ + onnx.helper.make_tensor_value_info("dyn", onnx.TensorProto.FLOAT, dyn_shape) + ] + + if both_constant: + # Replace the dynamic operand with a second constant of matching shape. + dyn_concrete = [d if isinstance(d, int) else 2 for d in dyn_shape] + initializers.append( + onnx.numpy_helper.from_array(rng.randn(*dyn_concrete).astype(np.float32), "dyn") + ) + graph_inputs = [] + + matmul = onnx.helper.make_node("MatMul", mm_inputs, ["out"], name="batched_matmul") + graph = onnx.helper.make_graph( + [matmul], + "batched_const_matmul", + graph_inputs, + [onnx.helper.make_tensor_value_info("out", onnx.TensorProto.FLOAT, out_shape)], + initializer=initializers, + ) + return onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 17)]) + + +def _run(model: onnx.ModelProto, feed: dict[str, np.ndarray]) -> np.ndarray: + sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + return sess.run(None, feed)[0] + + +def _rewrite(model: onnx.ModelProto) -> onnx.ModelProto: + pipe = RewritePipe() + return pipe.process(model, pipe.build_config(**{_CAP_KWARG: True})) + + +class TestBatchedConstMatMulUntieRewrite: + """End-to-end rewrite behavior via RewritePipe.""" + + def test_constant_operand_becomes_runtime_valued(self) -> None: + """The MatMul no longer consumes the initializer directly; an Add feeds it.""" + model = _build_model(const_on_rhs=True) + result = _rewrite(model) + + matmul = next(n for n in result.graph.node if n.op_type == "MatMul") + initializer_names = {init.name for init in result.graph.initializer} + assert not (set(matmul.input) & initializer_names) + add_nodes = [n for n in result.graph.node if n.op_type == "Add"] + assert len(add_nodes) == 1 + assert add_nodes[0].output[0] in matmul.input + onnx.checker.check_model(result) + + def test_numerics_unchanged_rhs_constant(self) -> None: + """+0 untie leaves outputs identical (constant on RHS, dynamic batch).""" + model = _build_model(const_on_rhs=True) + result = _rewrite(model) + feed = {"dyn": np.random.RandomState(7).randn(2, 3, 4).astype(np.float32)} + np.testing.assert_array_equal(_run(model, feed), _run(result, feed)) + + def test_numerics_unchanged_lhs_constant(self) -> None: + """Constant on the LHS is untied and stays numerically identical.""" + model = _build_model(const_on_rhs=False) + result = _rewrite(model) + assert any(n.op_type == "Add" for n in result.graph.node) + feed = {"dyn": np.random.RandomState(7).randn(2, 4, 5).astype(np.float32)} + np.testing.assert_array_equal(_run(model, feed), _run(result, feed)) + + def test_symbolic_dynamic_dims_are_matched(self) -> None: + """The dynamic operand may carry symbolic dims (the real transformer case).""" + model = _build_model(const_on_rhs=True, dynamic_batch=True) + result = _rewrite(model) + # Match succeeded despite the symbolic batch dim on the activation operand. + assert any(n.op_type == "Add" for n in result.graph.node) + + def test_rank2_constant_is_left_untouched(self) -> None: + """A rank-2 constant gemm compiles on OV GPU, so it must not be rewritten.""" + model = _build_model(const_on_rhs=True, const_rank=2) + result = _rewrite(model) + assert not any(n.op_type == "Add" for n in result.graph.node) + + def test_all_constant_matmul_is_left_untouched(self) -> None: + """A MatMul whose operands are both constant folds away; do not rewrite it.""" + model = _build_model(const_on_rhs=True, both_constant=True) + result = _rewrite(model) + assert not any(n.op_type == "Add" for n in result.graph.node)