From 9b42f68fcf886426150eea051e707e4ad8c5d6ba Mon Sep 17 00:00:00 2001 From: hualxie Date: Fri, 5 Jun 2026 14:58:30 +0800 Subject: [PATCH 1/8] fix(optim): untie batched constant MatMul for OpenVINO GPU OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched (rank >= 3) MatMul where an operand is a compile-time constant; the same gemm with a dynamic operand, and 2D constant gemm, both compile fine. Transformer disentangled-attention position terms (e.g. DeBERTa) fold to 3D constants and fail to compile with: [GPU] Failed to select implementation for ... type: gemm (compile_graph.cpp:59 selected_impl == nullptr) Add an EP-gated `untie-constant-batched-matmul` surgery that routes the constant operand through Add(const, zero), where zero is a data-dependent runtime [1] tensor (Cast -> Reshape(-1) -> Slice[0:1] -> Sub). This makes the operand runtime-valued so OV's constant folder cannot repack it into a gemm weight, while keeping the single batched MatMul (no perf regression) and leaving numerics unchanged (+0). Wired via autoconf: BatchedConstMatMulValidator detects the pattern and, gated to Intel IHV + GPU, emits a GraphOptimization opportunity the existing autoconf loop auto-applies. Pattern-based, architecture-agnostic. Also makes the model-validator device filter case-insensitive so builds that pass lowercase "gpu" are matched. --- .../analyze/core/information_engine.py | 1 + .../analyze/core/model_validators/__init__.py | 2 + .../batched_const_matmul_validator.py | 129 +++++++++++++++ .../model_validator_manager.py | 24 ++- .../modelkit/optim/capabilities/surgery.py | 17 ++ src/winml/modelkit/optim/pipes/surgery.py | 148 +++++++++++++++++- .../core/model_validators/test_validators.py | 60 +++++++ tests/unit/optim/pipes/test_pipe_surgery.py | 117 ++++++++++++++ 8 files changed, 492 insertions(+), 6 deletions(-) create mode 100644 src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py diff --git a/src/winml/modelkit/analyze/core/information_engine.py b/src/winml/modelkit/analyze/core/information_engine.py index d4c574b16..a3cfe1e1f 100644 --- a/src/winml/modelkit/analyze/core/information_engine.py +++ b/src/winml/modelkit/analyze/core/information_engine.py @@ -297,6 +297,7 @@ def _check_model(self) -> list[Information]: self._model, op_runtime_results=self._op_runtime_results, device=self._device, + ep=self._ep, ) manager_init_ms = int((time.perf_counter() - manager_init_start) * 1000) diff --git a/src/winml/modelkit/analyze/core/model_validators/__init__.py b/src/winml/modelkit/analyze/core/model_validators/__init__.py index 77cd7b924..4504269da 100644 --- a/src/winml/modelkit/analyze/core/model_validators/__init__.py +++ b/src/winml/modelkit/analyze/core/model_validators/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .base import ModelValidator +from .batched_const_matmul_validator import BatchedConstMatMulValidator from .constant_folding_validator import ConstantFoldingValidator from .dynamic_input_validator import DynamicInputValidator from .model_validator_manager import ModelValidatorManager @@ -21,6 +22,7 @@ __all__ = [ + "BatchedConstMatMulValidator", "ConstantFoldingValidator", "DynamicInputValidator", "ModelValidator", diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py new file mode 100644 index 000000000..9c9200546 --- /dev/null +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -0,0 +1,129 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Validator for batched MatMul with a constant operand on OpenVINO GPU. + +OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched +(rank >= 3) MatMul where an operand is a compile-time constant. The identical +gemm with a dynamic operand, and 2D constant gemm, both compile fine. Models +whose batched MatMul weights fold to constants (e.g. transformer disentangled +attention position terms) therefore fail to compile on OpenVINO GPU with: + + [GPU] Failed to select implementation for ... type: gemm + +This validator detects that structural pattern and recommends the +``untie-constant-batched-matmul`` surgery, which makes the constant operand +runtime-valued so gemm implementation selection succeeds. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ...models.information import Action, ActionItem, ActionLevel, Information +from ...utils import infer_ihv_from_ep_name +from .base import ModelValidator + + +if TYPE_CHECKING: + from ...models.onnx_model import ONNXModel + from ...models.runtime_checks import PatternRuntime + +logger = logging.getLogger(__name__) + +# Surgery capability enabled when the pattern is detected (kebab-case to match +# the capability registry / autoconf normalization). +_SURGERY_FLAG = "untie-constant-batched-matmul" + + +class BatchedConstMatMulValidator(ModelValidator): + """Detect batched MatMul with a constant operand (OpenVINO GPU only).""" + + def __init__( + self, + model: ONNXModel, + op_runtime_results: list[PatternRuntime] | None = None, + ep: str | None = None, + device: str | None = None, + ) -> None: + super().__init__(model, op_runtime_results=op_runtime_results) + self.ep = ep + self.device = device + + @property + def validator_name(self) -> str: + """Name of this validator for logging/reporting.""" + return "BatchedConstMatMulValidator" + + @property + def pattern_id(self) -> str: + """Pattern ID for Information objects.""" + return "MODEL/BatchedConstantMatMul" + + def _is_enabled(self) -> bool: + """Only relevant for OpenVINO (Intel IHV) on GPU.""" + if (self.device or "").upper() != "GPU": + return False + if not self.ep: + return False + try: + from ...models.ihv_type import IHVType + + return infer_ihv_from_ep_name(self.ep) == IHVType.INTEL + except Exception: # pragma: no cover - defensive + return False + + def validate(self) -> Information | None: + """Detect batched MatMul with a single constant rank>=3 operand.""" + if not self._is_enabled(): + return None + + initializers = {init.name for init in self.graph.initializer} + rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer} + + offenders: list[str] = [] + for node in self.graph.node: + if node.op_type != "MatMul" or len(node.input) != 2: + continue + const_inputs = [name for name in node.input if name in initializers] + # Exactly one constant operand (two-constant MatMuls fold away and + # never reach gemm impl selection). + if len(const_inputs) != 1: + continue + if rank_by_init.get(const_inputs[0], 0) >= 3: + offenders.append(node.name or const_inputs[0]) + + if not offenders: + return None + + examples = ", ".join(offenders[:3]) + action = Action( + pattern_from_id="", + pattern_to_id="", + level=ActionLevel.REQUIRED, + status=None, + action_items=[ + ActionItem(type="GraphOptimization", optimization_options={_SURGERY_FLAG: True}) + ], + details=( + "Enable untie-constant-batched-matmul surgery so the constant " + "operand becomes runtime-valued and OpenVINO GPU can select a " + "gemm implementation." + ), + ) + explanation = ( + f"Model contains {len(offenders)} batched MatMul(s) with a constant " + f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot " + f"select an implementation for a batched MatMul with a constant " + f"operand, causing a '[GPU] Failed to select implementation ... gemm' " + f"compile failure. The untie-constant-batched-matmul surgery makes " + f"the operand runtime-valued without changing numerics." + ) + return Information( + explanation=explanation, + actions=[action], + pattern_id=self.pattern_id, + status=None, + ) diff --git a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py index 29dd0235c..d658dbe53 100644 --- a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py +++ b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, ClassVar from ...utils.timing_utils import make_timing_logger +from .batched_const_matmul_validator import BatchedConstMatMulValidator from .constant_folding_validator import ConstantFoldingValidator from .dynamic_input_validator import DynamicInputValidator from .pattern_matching_validator import PatternMatchingValidator @@ -64,6 +65,11 @@ class ModelValidatorManager: "class": PatternMatchingValidator, "enabled_devices": None, # All devices }, + "batched_const_matmul": { + "class": BatchedConstMatMulValidator, + "enabled_devices": ["GPU"], # OpenVINO GPU gemm impl-selection issue + "needs_context": True, # validator self-gates on EP (Intel IHV) + }, } def __init__( @@ -72,6 +78,7 @@ def __init__( enabled_validators: list[str] | None = None, op_runtime_results: list[PatternRuntime] | None = None, device: str | None = None, + ep: str | None = None, ) -> None: """Initialize validator manager. @@ -92,6 +99,7 @@ def __init__( self.model_proto = model.get_model() self.op_runtime_results = op_runtime_results or [] self.device = device or "NPU" + self.ep = ep self.enabled_validators = enabled_validators or list(self.VALIDATORS.keys()) # Instantiate enabled validators @@ -102,18 +110,24 @@ def __init__( validator_class = validator_config["class"] enabled_devices = validator_config.get("enabled_devices") - # Check device constraint - if enabled_devices is not None and self.device not in enabled_devices: + # Check device constraint (case-insensitive: callers may pass + # "gpu" or "GPU" depending on the build/analyze entry point). + if enabled_devices is not None and (self.device or "").upper() not in { + d.upper() for d in enabled_devices + }: logger.info( f"Validator '{name}' is not enabled for device '{self.device}'. " f"Only enabled for: {enabled_devices}" ) continue + ctor_kwargs: dict = {"op_runtime_results": self.op_runtime_results} + if validator_config.get("needs_context"): + ctor_kwargs["ep"] = self.ep + ctor_kwargs["device"] = self.device + try: - self.validators.append( - validator_class(self.model, op_runtime_results=self.op_runtime_results) - ) + self.validators.append(validator_class(self.model, **ctor_kwargs)) logger.debug(f"Initialized validator: {name}") except Exception: logger.exception(f"Failed to initialize validator {name}") diff --git a/src/winml/modelkit/optim/capabilities/surgery.py b/src/winml/modelkit/optim/capabilities/surgery.py index 8b2048f00..0b6ec0768 100644 --- a/src/winml/modelkit/optim/capabilities/surgery.py +++ b/src/winml/modelkit/optim/capabilities/surgery.py @@ -37,3 +37,20 @@ category=CapabilityCategory.SURGERY, default=False, ) + +# Route a constant operand of a batched (rank >= 3) MatMul through a runtime +# no-op so it is no longer a compile-time constant. OpenVINO GPU's oneDNN gemm +# cannot select an implementation for a batched MatMul with a constant operand +# (e.g. transformer disentangled-attention position terms that fold to 3D +# constants); making the operand runtime-valued lets gemm impl selection +# succeed without changing numerics or splitting the batched op. +UNTIE_CONSTANT_BATCHED_MATMUL = BoolCapability( + name="untie-constant-batched-matmul", + ort_name=None, # Custom implementation, not ORT optimizer + description=( + "Make a batched MatMul's constant operand runtime-valued so OpenVINO " + "GPU can select a gemm implementation" + ), + category=CapabilityCategory.SURGERY, + default=False, +) diff --git a/src/winml/modelkit/optim/pipes/surgery.py b/src/winml/modelkit/optim/pipes/surgery.py index fa4fa6bcf..fed9b6ce6 100644 --- a/src/winml/modelkit/optim/pipes/surgery.py +++ b/src/winml/modelkit/optim/pipes/surgery.py @@ -38,6 +38,7 @@ SURGERY_CAPABILITIES: dict[str, Any] = caps_dict( surgery.CLAMP_CONSTANT_VALUES, surgery.REMOVE_ISNAN_IN_ATTENTION_MASK, + surgery.UNTIE_CONSTANT_BATCHED_MATMUL, ) @@ -57,6 +58,8 @@ class SurgeryPipeConfig(PipeConfig): fix_nan_attention_mask: Replace -inf attention mask with finite value and remove Softmax->IsNaN->Where NaN guard patterns mask_value: Replacement value for -inf (default: -1e3) + untie_constant_batched_matmul: Make a batched MatMul's constant operand + runtime-valued so OpenVINO GPU can select a gemm implementation verbose: Enable verbose logging """ @@ -64,6 +67,7 @@ class SurgeryPipeConfig(PipeConfig): clamp_min: float = -1e3 clamp_max: float = 1e3 remove_isnan_in_attention_mask: bool = False + untie_constant_batched_matmul: bool = False verbose: bool = False @@ -106,6 +110,7 @@ def build_config(cls, **kwargs: Any) -> SurgeryPipeConfig: clamp_min=kwargs.get("clamp_min", -1e3), clamp_max=kwargs.get("clamp_max", 1e3), remove_isnan_in_attention_mask=kwargs.get("remove_isnan_in_attention_mask", False), + untie_constant_batched_matmul=kwargs.get("untie_constant_batched_matmul", False), verbose=kwargs.get("verbose", False), ) @@ -119,7 +124,11 @@ def should_process(cls, config: SurgeryPipeConfig) -> bool: Returns: True if any surgery operation is enabled """ - return config.clamp_constant_values or config.remove_isnan_in_attention_mask + return ( + config.clamp_constant_values + or config.remove_isnan_in_attention_mask + or config.untie_constant_batched_matmul + ) def process(self, model: onnx.ModelProto, config: SurgeryPipeConfig) -> onnx.ModelProto: """Apply surgery operations to the model. @@ -149,6 +158,9 @@ def process(self, model: onnx.ModelProto, config: SurgeryPipeConfig) -> onnx.Mod if config.remove_isnan_in_attention_mask: model_copy = self._remove_isnan_in_attention_mask(model_copy, config.verbose) + if config.untie_constant_batched_matmul: + model_copy = self._untie_constant_batched_matmul(model_copy, config.verbose) + return model_copy def _clamp_constant_values( @@ -319,3 +331,137 @@ def _remove_isnan_in_attention_mask( ) return model + + # ----------------------------------------------------------------- + # untie-constant-batched-matmul + # ----------------------------------------------------------------- + + def _untie_constant_batched_matmul( + self, + model: onnx.ModelProto, + verbose: bool = False, + ) -> onnx.ModelProto: + """Make a batched MatMul's constant operand runtime-valued. + + OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched + (rank >= 3) MatMul where an operand is a compile-time constant: the same + gemm with a dynamic operand, and 2D constant gemm, both compile fine. + Transformer disentangled-attention position terms depend only on weights, + so they fold into 3D constants and hit this case. + + Fix: route each such constant operand through ``Add(const, zero)`` where + ``zero = Sub(s, s)`` and ``s = ReduceMin(Cast(first_input -> float))``. + ``s`` is data-dependent, so OpenVINO's constant folder cannot collapse + the Add back into a packed gemm weight, yet ``+ 0`` leaves the values + unchanged and the single batched MatMul is preserved (no perf cost). + """ + from onnx import TensorProto, helper, numpy_helper + + graph = model.graph + initializers = {init.name: init for init in graph.initializer} + + # Collect (matmul_node, operand_index) where the operand is a constant + # initializer of rank >= 3. Skip MatMuls whose operands are all constant + # (those fold away entirely and never reach gemm impl selection). + targets: list[tuple[onnx.NodeProto, int]] = [] + for node in graph.node: + if node.op_type != "MatMul" or len(node.input) != 2: + continue + const_idx = [i for i, name in enumerate(node.input) if name in initializers] + if len(const_idx) != 1: + continue + idx = const_idx[0] + if len(initializers[node.input[idx]].dims) >= 3: + targets.append((node, idx)) + + if not targets: + return model + + if not graph.input: + logger.warning( + "SurgeryPipe: untie-constant-batched-matmul: no graph input to " + "derive a runtime value from; skipping %d MatMul(s)", + len(targets), + ) + return model + + prefix = "winml_ovgpu_untie" + first_input = graph.input[0].name + new_nodes: list[onnx.NodeProto] = [] + new_inits: list[onnx.TensorProto] = [] + + # Build a shape-[1] runtime zero from input *data* (not shape — input + # shapes are static and would be folded). Only ubiquitous ops are used + # so the static analyzer handles them: a single input element is sliced + # out and subtracted from itself. A [1] tensor broadcasts against any + # constant operand, regardless of its rank. + xf = f"{prefix}_xf" + new_nodes.append( + helper.make_node("Cast", [first_input], [xf], to=TensorProto.FLOAT, name=xf) + ) + flat = f"{prefix}_flat" + new_inits.append(numpy_helper.from_array(np.array([-1], dtype=np.int64), f"{prefix}_m1")) + new_nodes.append(helper.make_node("Reshape", [xf, f"{prefix}_m1"], [flat], name=flat)) + elem = f"{prefix}_elem" + new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), f"{prefix}_0")) + new_inits.append(numpy_helper.from_array(np.array([1], dtype=np.int64), f"{prefix}_1")) + new_nodes.append( + helper.make_node( + "Slice", [flat, f"{prefix}_0", f"{prefix}_1", f"{prefix}_0"], [elem], name=elem + ) + ) + # zero = elem - elem == 0.0 (data-dependent, so it is not folded away). + zero_f32 = f"{prefix}_zero_f32" + new_nodes.append(helper.make_node("Sub", [elem, elem], [zero_f32], name=zero_f32)) + + # A zero must match each operand's dtype (ONNX has no implicit promotion). + zero_by_dtype: dict[int, str] = {int(TensorProto.FLOAT): zero_f32} + + def zero_for(dtype: int) -> str: + name = zero_by_dtype.get(dtype) + if name is None: + name = f"{prefix}_zero_{dtype}" + new_nodes.append( + helper.make_node("Cast", [zero_f32], [name], to=dtype, name=name) + ) + zero_by_dtype[dtype] = name + return name + + untied = 0 + for node, idx in targets: + const_name = node.input[idx] + dtype = initializers[const_name].data_type + if dtype not in (TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE): + continue + dyn = f"{prefix}_{node.name}_in{idx}".replace("/", "_") + new_nodes.append( + helper.make_node("Add", [const_name, zero_for(dtype)], [dyn], name=dyn) + ) + node.input[idx] = dyn + untied += 1 + if verbose: + logger.info( + " untie-constant-batched-matmul: %s input[%d] %s -> %s", + node.name, + idx, + const_name, + dyn, + ) + + if untied == 0: + return model + + graph.initializer.extend(new_inits) + # Prepend new nodes: their inputs are only graph inputs / initializers, + # so placing them first keeps the graph topologically sorted. + existing = list(graph.node) + del graph.node[:] + graph.node.extend(new_nodes + existing) + + logger.info( + "SurgeryPipe: untie-constant-batched-matmul: untied %d batched " + "MatMul constant operand(s)", + untied, + ) + + return model diff --git a/tests/unit/analyze/core/model_validators/test_validators.py b/tests/unit/analyze/core/model_validators/test_validators.py index ddc83c7be..787224aa6 100644 --- a/tests/unit/analyze/core/model_validators/test_validators.py +++ b/tests/unit/analyze/core/model_validators/test_validators.py @@ -379,3 +379,63 @@ def test_unknown_validator_logs_warning(self, caplog): assert len(manager.validators) == 1 assert manager.validators[0].validator_name == "ConstantFoldingValidator" assert "Unknown validator" in caplog.text + + +def _make_batched_const_matmul_proto(const_rank: int = 3): + """Model: data [2,3,4] @ W(const) [2,4,5] -> out [2,3,5].""" + import numpy as np + from onnx import numpy_helper + + w_shape = [2, 4, 5] if const_rank == 3 else [4, 5] + w = numpy_helper.from_array(np.zeros(w_shape, dtype=np.float32), "W") + matmul = helper.make_node("MatMul", ["data", "W"], ["out"], name="batched_matmul") + graph = helper.make_graph( + [matmul], + "batched_const_matmul", + [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4])], + [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 5])], + initializer=[w], + ) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +class TestBatchedConstMatMulValidator: + """OpenVINO-GPU batched constant MatMul detector.""" + + def _validate(self, proto, ep, device): + from winml.modelkit.analyze.core.model_validators import BatchedConstMatMulValidator + + model = create_onnx_model_wrapper(proto) + return BatchedConstMatMulValidator(model, ep=ep, device=device).validate() + + def test_detects_for_openvino_gpu(self): + """Emits a GraphOptimization action enabling the surgery for OV GPU.""" + info = self._validate(_make_batched_const_matmul_proto(), "openvino", "GPU") + assert info is not None + assert info.pattern_id == "MODEL/BatchedConstantMatMul" + items = info.actions[0].action_items + assert items[0].type == "GraphOptimization" + assert items[0].optimization_options == {"untie-constant-batched-matmul": True} + + def test_skipped_for_openvino_npu(self): + """Device-gated: NPU is unaffected.""" + assert self._validate(_make_batched_const_matmul_proto(), "openvino", "NPU") is None + + def test_skipped_for_non_intel_gpu(self): + """IHV-gated: a non-Intel GPU EP is unaffected.""" + info = self._validate(_make_batched_const_matmul_proto(), "DmlExecutionProvider", "GPU") + assert info is None + + def test_skipped_for_two_dim_constant(self): + """Rank-2 constant gemm compiles on OV GPU; not flagged.""" + info = self._validate(_make_batched_const_matmul_proto(const_rank=2), "openvino", "GPU") + assert info is None + + def test_manager_wires_validator_for_openvino_gpu(self): + """Manager constructs the validator and surfaces the action for OV GPU.""" + model = create_onnx_model_wrapper(_make_batched_const_matmul_proto()) + manager = ModelValidatorManager(model, device="GPU", ep="openvino") + names = [v.validator_name for v in manager.validators] + assert "BatchedConstMatMulValidator" in names + infos = manager.run_all_validators() + assert any(i.pattern_id == "MODEL/BatchedConstantMatMul" for i in infos) diff --git a/tests/unit/optim/pipes/test_pipe_surgery.py b/tests/unit/optim/pipes/test_pipe_surgery.py index 6ab7e0034..fb98763d3 100644 --- a/tests/unit/optim/pipes/test_pipe_surgery.py +++ b/tests/unit/optim/pipes/test_pipe_surgery.py @@ -342,3 +342,120 @@ def test_surgery_pipe_runs_last(self) -> None: # SurgeryPipe should be last in the list assert PIPES[-1].name == "surgery" + + +# ============================================================================= +# UNTIE-CONSTANT-BATCHED-MATMUL TESTS +# ============================================================================= + + +def _make_batched_const_matmul_model( + *, + const_rank: int = 3, + const_on_rhs: bool = True, +) -> onnx.ModelProto: + """Build a model with a batched MatMul that has one constant operand. + + data [2,3,4] @ W(const) [2,4,5] -> out [2,3,5] (const on rhs), or the + transposed arrangement when ``const_on_rhs`` is False. + """ + from onnx import TensorProto, helper + + rng = np.random.RandomState(0) + if const_on_rhs: + data_shape, w_shape, out_shape = [2, 3, 4], [2, 4, 5], [2, 3, 5] + mm_inputs = ["data", "W"] + else: + data_shape, w_shape, out_shape = [2, 4, 5], [2, 3, 4], [2, 3, 5] + mm_inputs = ["W", "data"] + + if const_rank == 2: + w_shape = w_shape[1:] + + w = numpy_helper.from_array(rng.randn(*w_shape).astype(np.float32), "W") + matmul = helper.make_node("MatMul", mm_inputs, ["out"], name="batched_matmul") + graph = helper.make_graph( + [matmul], + "test_batched_const_matmul", + [helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape)], + [helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)], + initializer=[w], + ) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +class TestUntieConstantBatchedMatmulCapability: + """Capability/config plumbing for untie-constant-batched-matmul.""" + + def test_capability_exists(self) -> None: + """Capability is registered with a None ort_name (custom impl).""" + assert "untie-constant-batched-matmul" in SURGERY_CAPABILITIES + assert SURGERY_CAPABILITIES["untie-constant-batched-matmul"].ort_name is None + + def test_build_config_enable_via_kwarg(self) -> None: + """Flag can be toggled through build_config.""" + config = SurgeryPipe.build_config(untie_constant_batched_matmul=True) + assert config.untie_constant_batched_matmul is True + + def test_should_process_true_when_enabled(self) -> None: + """should_process is True when only this surgery is enabled.""" + config = SurgeryPipeConfig(untie_constant_batched_matmul=True) + assert SurgeryPipe.should_process(config) is True + + +class TestUntieConstantBatchedMatmulProcess: + """Graph transform behavior.""" + + def test_constant_operand_becomes_runtime_valued(self) -> None: + """The MatMul no longer consumes the initializer directly.""" + model = _make_batched_const_matmul_model() + result = SurgeryPipe().process( + model, SurgeryPipeConfig(untie_constant_batched_matmul=True) + ) + + matmul = next(n for n in result.graph.node if n.op_type == "MatMul") + initializer_names = {init.name for init in result.graph.initializer} + # No MatMul input is a direct initializer anymore. + assert not (set(matmul.input) & initializer_names) + # An Add node now produces the (formerly constant) operand. + add_nodes = [n for n in result.graph.node if n.op_type == "Add"] + assert len(add_nodes) == 1 + assert add_nodes[0].output[0] in matmul.input + # Graph remains structurally valid. + onnx.checker.check_model(result) + + def test_numerics_unchanged(self) -> None: + """+0 tie leaves outputs bit-for-bit identical on ORT CPU.""" + import onnxruntime as ort + + model = _make_batched_const_matmul_model() + transformed = SurgeryPipe().process( + model, SurgeryPipeConfig(untie_constant_batched_matmul=True) + ) + + rng = np.random.RandomState(7) + feed = {"data": rng.randn(2, 3, 4).astype(np.float32)} + + ref = ort.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ).run(None, feed)[0] + got = ort.InferenceSession( + transformed.SerializeToString(), providers=["CPUExecutionProvider"] + ).run(None, feed)[0] + np.testing.assert_array_equal(ref, got) + + def test_two_dim_constant_is_left_untouched(self) -> None: + """Rank-2 constant gemm compiles on OV GPU, so it must not be rewritten.""" + model = _make_batched_const_matmul_model(const_rank=2) + result = SurgeryPipe().process( + model, SurgeryPipeConfig(untie_constant_batched_matmul=True) + ) + assert not any(n.op_type == "Add" for n in result.graph.node) + + def test_constant_on_lhs_is_handled(self) -> None: + """A constant rank-3 operand on the LHS is untied too.""" + model = _make_batched_const_matmul_model(const_on_rhs=False) + result = SurgeryPipe().process( + model, SurgeryPipeConfig(untie_constant_batched_matmul=True) + ) + assert any(n.op_type == "Add" for n in result.graph.node) From 5d04ce0f2e0c0c57a508e8e72a606f8245c49ec8 Mon Sep 17 00:00:00 2001 From: hualxie Date: Mon, 8 Jun 2026 16:46:03 +0800 Subject: [PATCH 2/8] update --- .../core/model_validators/batched_const_matmul_validator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 9c9200546..188a4bcac 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -113,13 +113,15 @@ def validate(self) -> Information | None: "gemm implementation." ), ) + # https://github.com/openvinotoolkit/openvino/issues/36272 explanation = ( f"Model contains {len(offenders)} batched MatMul(s) with a constant " f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot " f"select an implementation for a batched MatMul with a constant " f"operand, causing a '[GPU] Failed to select implementation ... gemm' " f"compile failure. The untie-constant-batched-matmul surgery makes " - f"the operand runtime-valued without changing numerics." + f"the operand runtime-valued without changing numerics. " + f"It is fixed in openvino==2026.2.0, so no need to apply the surgery if using that version or later." ) return Information( explanation=explanation, From d9f5ca760195107a041fdc2bcd79f72a5ce756dc Mon Sep 17 00:00:00 2001 From: hualxie Date: Mon, 8 Jun 2026 17:04:37 +0800 Subject: [PATCH 3/8] use EPName --- .../core/model_validators/batched_const_matmul_validator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 188a4bcac..62825dfc2 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from ...models.onnx_model import ONNXModel from ...models.runtime_checks import PatternRuntime + from ....utils.constants import EPName logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ def __init__( self, model: ONNXModel, op_runtime_results: list[PatternRuntime] | None = None, - ep: str | None = None, + ep: EPName | None = None, device: str | None = None, ) -> None: super().__init__(model, op_runtime_results=op_runtime_results) @@ -121,7 +122,8 @@ def validate(self) -> Information | None: f"operand, causing a '[GPU] Failed to select implementation ... gemm' " f"compile failure. The untie-constant-batched-matmul surgery makes " f"the operand runtime-valued without changing numerics. " - f"It is fixed in openvino==2026.2.0, so no need to apply the surgery if using that version or later." + f"It is fixed in openvino==2026.2.0, so no need to apply the surgery " + f"if using that version or later." ) return Information( explanation=explanation, From 9dbfb6fad609c1ac9a80f10f3d1a9d586c66f7ef Mon Sep 17 00:00:00 2001 From: hualxie Date: Tue, 9 Jun 2026 09:42:13 +0800 Subject: [PATCH 4/8] sort --- .../core/model_validators/batched_const_matmul_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 62825dfc2..65ea953cd 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -28,9 +28,9 @@ if TYPE_CHECKING: + from ....utils.constants import EPName from ...models.onnx_model import ONNXModel from ...models.runtime_checks import PatternRuntime - from ....utils.constants import EPName logger = logging.getLogger(__name__) From 79bcc3f5ec6473145cd249ab11b4430975888fef Mon Sep 17 00:00:00 2001 From: hualxie Date: Wed, 10 Jun 2026 10:39:27 +0800 Subject: [PATCH 5/8] fix(optim): address review comments for untie batched constant MatMul - Use loop index for the untied operand name instead of node.name, which is optional in ONNX and can be blank/duplicated (would collide and yield an invalid graph). - Update docstring to describe the actual Cast/Reshape/Slice/Sub construction (was stale ReduceMin wording) and document the non-empty-first-input assumption. - Split the Slice starts/ends/axes initializers into distinct named tensors. - Note the Constant-node detection gap in the validator (shared with surgery). - Add a test for two unnamed batched-const MatMuls (name-collision regression). --- .../batched_const_matmul_validator.py | 6 ++ src/winml/modelkit/optim/pipes/surgery.py | 39 ++++++++----- tests/unit/optim/pipes/test_pipe_surgery.py | 56 ++++++++++++++++--- 3 files changed, 78 insertions(+), 23 deletions(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 65ea953cd..d13b7f5ce 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -81,6 +81,12 @@ def validate(self) -> Information | None: if not self._is_enabled(): return None + # Known gap: constants expressed as `Constant` op nodes (rather than + # graph initializers) are not detected here. The `untie-constant-batched + # -matmul` surgery in surgery.py has the same limitation, so detection + # and surgery stay consistent. Most exporters and ORT preprocessing emit + # weights as initializers, so this covers the disentangled-attention case + # in practice; `Constant`-node weights would need handling on both sides. initializers = {init.name for init in self.graph.initializer} rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer} diff --git a/src/winml/modelkit/optim/pipes/surgery.py b/src/winml/modelkit/optim/pipes/surgery.py index fed9b6ce6..8250a19ad 100644 --- a/src/winml/modelkit/optim/pipes/surgery.py +++ b/src/winml/modelkit/optim/pipes/surgery.py @@ -350,10 +350,17 @@ def _untie_constant_batched_matmul( so they fold into 3D constants and hit this case. Fix: route each such constant operand through ``Add(const, zero)`` where - ``zero = Sub(s, s)`` and ``s = ReduceMin(Cast(first_input -> float))``. - ``s`` is data-dependent, so OpenVINO's constant folder cannot collapse + ``zero`` is a runtime ``[1]`` tensor built from the first graph input's + *data*: ``Cast(first_input -> float) -> Reshape([-1]) -> Slice([0:1])`` + yields a single element ``elem``, and ``zero = Sub(elem, elem) == 0.0``. + ``zero`` is data-dependent, so OpenVINO's constant folder cannot collapse the Add back into a packed gemm weight, yet ``+ 0`` leaves the values unchanged and the single batched MatMul is preserved (no perf cost). + + Assumption: the first graph input has at least one element at runtime. + The ``Slice([0:1])`` is out of bounds for a zero-sized input (e.g. a + dynamic batch dimension fed an empty batch), which would raise at + inference time rather than produce a zero. """ from onnx import TensorProto, helper, numpy_helper @@ -403,13 +410,16 @@ def _untie_constant_batched_matmul( new_inits.append(numpy_helper.from_array(np.array([-1], dtype=np.int64), f"{prefix}_m1")) new_nodes.append(helper.make_node("Reshape", [xf, f"{prefix}_m1"], [flat], name=flat)) elem = f"{prefix}_elem" - new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), f"{prefix}_0")) - new_inits.append(numpy_helper.from_array(np.array([1], dtype=np.int64), f"{prefix}_1")) - new_nodes.append( - helper.make_node( - "Slice", [flat, f"{prefix}_0", f"{prefix}_1", f"{prefix}_0"], [elem], name=elem - ) - ) + # Slice(flat, starts=[0], ends=[1], axes=[0]) -> the first element. + # starts and axes are distinct tensors even though both hold [0], so a + # future edit to one role cannot silently corrupt the other. + starts = f"{prefix}_slice_starts" + ends = f"{prefix}_slice_ends" + axis = f"{prefix}_slice_axis" + new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), starts)) + new_inits.append(numpy_helper.from_array(np.array([1], dtype=np.int64), ends)) + new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), axis)) + new_nodes.append(helper.make_node("Slice", [flat, starts, ends, axis], [elem], name=elem)) # zero = elem - elem == 0.0 (data-dependent, so it is not folded away). zero_f32 = f"{prefix}_zero_f32" new_nodes.append(helper.make_node("Sub", [elem, elem], [zero_f32], name=zero_f32)) @@ -421,19 +431,20 @@ def zero_for(dtype: int) -> str: name = zero_by_dtype.get(dtype) if name is None: name = f"{prefix}_zero_{dtype}" - new_nodes.append( - helper.make_node("Cast", [zero_f32], [name], to=dtype, name=name) - ) + new_nodes.append(helper.make_node("Cast", [zero_f32], [name], to=dtype, name=name)) zero_by_dtype[dtype] = name return name untied = 0 - for node, idx in targets: + # Index the loop rather than node.name: node names are optional in ONNX + # and exporters routinely leave them blank or duplicated, so deriving + # `dyn` from the name would collide and produce an invalid graph. + for untie_idx, (node, idx) in enumerate(targets): const_name = node.input[idx] dtype = initializers[const_name].data_type if dtype not in (TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE): continue - dyn = f"{prefix}_{node.name}_in{idx}".replace("/", "_") + dyn = f"{prefix}_untied{untie_idx}_in{idx}" new_nodes.append( helper.make_node("Add", [const_name, zero_for(dtype)], [dyn], name=dyn) ) diff --git a/tests/unit/optim/pipes/test_pipe_surgery.py b/tests/unit/optim/pipes/test_pipe_surgery.py index fb98763d3..16df08d79 100644 --- a/tests/unit/optim/pipes/test_pipe_surgery.py +++ b/tests/unit/optim/pipes/test_pipe_surgery.py @@ -409,9 +409,7 @@ class TestUntieConstantBatchedMatmulProcess: def test_constant_operand_becomes_runtime_valued(self) -> None: """The MatMul no longer consumes the initializer directly.""" model = _make_batched_const_matmul_model() - result = SurgeryPipe().process( - model, SurgeryPipeConfig(untie_constant_batched_matmul=True) - ) + result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) matmul = next(n for n in result.graph.node if n.op_type == "MatMul") initializer_names = {init.name for init in result.graph.initializer} @@ -447,15 +445,55 @@ def test_numerics_unchanged(self) -> None: def test_two_dim_constant_is_left_untouched(self) -> None: """Rank-2 constant gemm compiles on OV GPU, so it must not be rewritten.""" model = _make_batched_const_matmul_model(const_rank=2) - result = SurgeryPipe().process( - model, SurgeryPipeConfig(untie_constant_batched_matmul=True) - ) + result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) assert not any(n.op_type == "Add" for n in result.graph.node) def test_constant_on_lhs_is_handled(self) -> None: """A constant rank-3 operand on the LHS is untied too.""" model = _make_batched_const_matmul_model(const_on_rhs=False) - result = SurgeryPipe().process( - model, SurgeryPipeConfig(untie_constant_batched_matmul=True) - ) + result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) assert any(n.op_type == "Add" for n in result.graph.node) + + def test_duplicate_node_names_do_not_collide(self) -> None: + """Two target MatMuls with empty names produce a valid graph. + + Node names are optional in ONNX; exporters routinely leave them blank. + The generated dynamic-operand names must be unique regardless, or the + transformed graph would have colliding tensor names and fail validation. + """ + from onnx import TensorProto, helper + + rng = np.random.RandomState(0) + w1 = numpy_helper.from_array(rng.randn(2, 4, 5).astype(np.float32), "W1") + w2 = numpy_helper.from_array(rng.randn(2, 5, 6).astype(np.float32), "W2") + # Both MatMuls deliberately left unnamed (name=""). + mm1 = helper.make_node("MatMul", ["data", "W1"], ["mid"], name="") + mm2 = helper.make_node("MatMul", ["mid", "W2"], ["out"], name="") + graph = helper.make_graph( + [mm1, mm2], + "test_dup_names", + [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4])], + [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 6])], + initializer=[w1, w2], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) + + # Both constants are untied and the graph stays structurally valid. + add_nodes = [n for n in result.graph.node if n.op_type == "Add"] + assert len(add_nodes) == 2 + assert len({n.output[0] for n in add_nodes}) == 2 + onnx.checker.check_model(result) + + # Numerics are unchanged versus the original model. + import onnxruntime as ort + + feed = {"data": np.random.RandomState(7).randn(2, 3, 4).astype(np.float32)} + ref = ort.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ).run(None, feed)[0] + got = ort.InferenceSession( + result.SerializeToString(), providers=["CPUExecutionProvider"] + ).run(None, feed)[0] + np.testing.assert_array_equal(ref, got) From c2897acc28ad0278801a38f01c1fe61340a1b43b Mon Sep 17 00:00:00 2001 From: hualxie Date: Wed, 10 Jun 2026 11:43:55 +0800 Subject: [PATCH 6/8] feat(pattern): untie batched constant MatMul via rewrite instead of surgery Re-implements the OpenVINO-GPU batched-const-MatMul workaround as a pattern rewrite (match -> replace) instead of a SurgeryPipe transform, so node generation goes through the existing PatternMatcher/PatternRewriter framework. - New patterns (pattern/batched_const_matmul_patterns.py): - BatchedConstMatMulPattern (source): matches a bare MatMul with exactly one rank->=3 constant operand. Overrides check_skeleton_result to skip the base symbolic-dim rejection, since the dynamic activation operand legitimately carries symbolic dims. - UntiedBatchedConstMatMulPattern (target): emits MatMul(dyn, Add(const, zero)) where zero is a [1] runtime tensor derived from the MatMul's own dynamic operand (Reshape([-1]) -> Slice([0:1]) -> Sub). Deriving zero from the dynamic operand keeps the replacement local (the rewriter only wires to the matched subgraph's boundary tensors) and removes the surgery's dependency on graph.input[0] and its empty-first-input edge case. Operands share a dtype, so no Cast is needed. - Wires the rule into pattern/rules/default.json as capability "batchedconstmatmul-untied" (enabled:false so it stays out of general matching; applied only when the capability is turned on). - Repoints BatchedConstMatMulValidator to emit the rewrite capability flag; the validator still supplies the Intel-IHV + GPU gating and the autoconf trigger. - PatternRewriter: tolerate symbolic/dynamic operand dims when building the dummy input array for a target's get_onnx_model (previously crashed). The SurgeryPipe untie implementation is left in place but is no longer driven by autoconf; it can be removed in a follow-up. --- .../batched_const_matmul_validator.py | 21 +- src/winml/modelkit/pattern/__init__.py | 6 + src/winml/modelkit/pattern/base.py | 9 +- .../pattern/batched_const_matmul_patterns.py | 322 ++++++++++++++++++ src/winml/modelkit/pattern/rules/default.json | 18 + .../core/model_validators/test_validators.py | 4 +- .../test_pipe_rewrite_batched_const_matmul.py | 144 ++++++++ 7 files changed, 512 insertions(+), 12 deletions(-) create mode 100644 src/winml/modelkit/pattern/batched_const_matmul_patterns.py create mode 100644 tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index d13b7f5ce..9bcc28a61 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -13,8 +13,11 @@ [GPU] Failed to select implementation for ... type: gemm This validator detects that structural pattern and recommends the -``untie-constant-batched-matmul`` surgery, which makes the constant operand -runtime-valued so gemm implementation selection succeeds. +``batchedconstmatmul-untied`` rewrite, which routes the constant operand through +``Add(const, runtime-zero)`` so it becomes runtime-valued and gemm +implementation selection succeeds. The rewrite (source/target patterns) lives in +``modelkit.pattern.batched_const_matmul_patterns``; this validator only supplies +the EP/device gating and the autoconf trigger that enables it. """ from __future__ import annotations @@ -34,9 +37,11 @@ logger = logging.getLogger(__name__) -# Surgery capability enabled when the pattern is detected (kebab-case to match -# the capability registry / autoconf normalization). -_SURGERY_FLAG = "untie-constant-batched-matmul" +# Rewrite capability enabled when the pattern is detected. Kebab-case to match +# the capability registry / autoconf normalization; derived from the JSON rule +# (source flag "batchedconstmatmul" + target flag "untied") in +# pattern/rules/default.json. +_REWRITE_FLAG = "batchedconstmatmul-untied" class BatchedConstMatMulValidator(ModelValidator): @@ -112,10 +117,10 @@ def validate(self) -> Information | None: level=ActionLevel.REQUIRED, status=None, action_items=[ - ActionItem(type="GraphOptimization", optimization_options={_SURGERY_FLAG: True}) + ActionItem(type="GraphOptimization", optimization_options={_REWRITE_FLAG: True}) ], details=( - "Enable untie-constant-batched-matmul surgery so the constant " + "Enable the batchedconstmatmul-untied rewrite so the constant " "operand becomes runtime-valued and OpenVINO GPU can select a " "gemm implementation." ), @@ -126,7 +131,7 @@ def validate(self) -> Information | None: f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot " f"select an implementation for a batched MatMul with a constant " f"operand, causing a '[GPU] Failed to select implementation ... gemm' " - f"compile failure. The untie-constant-batched-matmul surgery makes " + f"compile failure. The batchedconstmatmul-untied rewrite makes " f"the operand runtime-valued without changing numerics. " f"It is fixed in openvino==2026.2.0, so no need to apply the surgery " f"if using that version or later." diff --git a/src/winml/modelkit/pattern/__init__.py b/src/winml/modelkit/pattern/__init__.py index 38f0112c4..b6138887e 100644 --- a/src/winml/modelkit/pattern/__init__.py +++ b/src/winml/modelkit/pattern/__init__.py @@ -30,6 +30,10 @@ opschema_to_pattern_schema, register_pattern_input_generator, ) +from .batched_const_matmul_patterns import ( + BatchedConstMatMulPattern, + UntiedBatchedConstMatMulPattern, +) from .conv2d_inplace_linear_patterns import ( Conv2DInplaceLinear2DPattern, Conv2DInplaceLinear2DPatternInputGenerator, @@ -85,6 +89,7 @@ __all__ = [ "MATMUL_ADD_SCHEMA", + "BatchedConstMatMulPattern", "Conv2DInplaceLinear2DPattern", "Conv2DInplaceLinear2DPatternInputGenerator", "Conv2DInplaceLinear3DPattern", @@ -139,6 +144,7 @@ "TransposedSingleLayerNormalizationPatternInputGenerator", "TransposedSingleRMSNormalizationPattern", "TransposedSingleRMSNormalizationPatternInputGenerator", + "UntiedBatchedConstMatMulPattern", "get_pattern_input_generator", "get_registered_pattern_input_generators", "make_single_op_pattern", diff --git a/src/winml/modelkit/pattern/base.py b/src/winml/modelkit/pattern/base.py index 245e5d51b..611e13d41 100644 --- a/src/winml/modelkit/pattern/base.py +++ b/src/winml/modelkit/pattern/base.py @@ -2035,14 +2035,19 @@ def _allocate_graph_node_key(node: Any) -> str: is_constant_map[input_name] = info.is_constant if info.value is not None: inputs[input_name] = info.value - elif info.shape is not None: + elif info.shape is not None and all( + isinstance(dim, int) for dim in info.shape + ): # Create a dummy array with the shape for internal constant computation dtype_str = match_result.type_param_to_type[input_param.type_str] np_dtype = SupportedONNXType.from_onnx_type(dtype_str).np_type # For shape computation, create a zero array inputs[input_name] = np.zeros(info.shape, dtype=np_dtype) else: - # No shape info, skip this input + # No value, or a symbolic/dynamic shape that cannot be + # materialized into a dummy array: skip this input. A + # target whose get_onnx_model needs concrete operand + # shapes would already have been rejected at match time. is_constant_map[input_name] = False else: is_constant_map[input_name] = False diff --git a/src/winml/modelkit/pattern/batched_const_matmul_patterns.py b/src/winml/modelkit/pattern/batched_const_matmul_patterns.py new file mode 100644 index 000000000..c122b4a58 --- /dev/null +++ b/src/winml/modelkit/pattern/batched_const_matmul_patterns.py @@ -0,0 +1,322 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Pattern + rewrite for batched MatMul with a constant operand on OpenVINO GPU. + +OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched +(rank >= 3) MatMul where an operand is a compile-time constant. The identical +gemm with a dynamic operand, and 2D constant gemm, both compile fine. Models +whose batched MatMul weights fold to constants (e.g. transformer disentangled +attention position terms) therefore fail to compile on OpenVINO GPU with: + + [GPU] Failed to select implementation for ... type: gemm + +The rewrite makes the constant operand runtime-valued without changing numerics: +it routes the constant through ``Add(const, zero)`` where ``zero`` is a ``[1]`` +runtime tensor derived from the MatMul's *own dynamic operand* +(``Reshape([-1]) -> Slice([0:1]) -> Sub(elem, elem) == 0``). Because ``zero`` is +data-dependent, OpenVINO's constant folder cannot collapse the ``Add`` back into +a packed gemm weight, yet ``+ 0`` leaves the values unchanged and the single +batched MatMul is preserved (no per-head decomposition, no perf regression). + +Deriving ``zero`` from the dynamic operand (rather than a graph input) keeps the +replacement *local*: the rewriter only wires the target's nodes to the matched +subgraph's own boundary tensors. The two MatMul operands share a dtype (ONNX +MatMul requires it), so ``zero`` automatically matches the constant's dtype with +no Cast. + +The source pattern matches a bare ``MatMul`` and validates the constant-operand +structure in ``check_skeleton_result``; it deliberately does **not** call the +base implementation, which rejects matches whose non-constant input has +symbolic/dynamic dimensions — exactly the activation operand here. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +from onnx.defs import OpSchema + +from ..onnx import ONNXDomain +from .base import Pattern, PatternSchema, Skeleton +from .match import PatternMatchResult + + +if TYPE_CHECKING: + from onnx import ModelProto + + from .match import SkeletonMatchResult + + +# Minimum operand rank that triggers the OpenVINO GPU gemm impl-selection failure. +_MIN_BATCHED_RANK = 3 + + +# Source and target share this schema so PatternRewriter's schema-equality +# assertion holds (a MatMul: two same-typed operands -> one output). +_BATCHED_CONST_MATMUL_SCHEMA = PatternSchema( + name="BatchedConstMatMulPattern", + doc=( + "Batched (rank >= 3) MatMul with exactly one constant operand.\n" + "Computes Y = MatMul(A, B) where one of A/B is a compile-time constant " + "of rank >= 3 and the other is runtime-valued. Targeted by the untie " + "rewrite for OpenVINO GPU, whose oneDNN gemm cannot select an " + "implementation for this shape." + ), + type_constraints=[ + OpSchema.TypeConstraintParam( + type_param_str="T", + allowed_type_strs=[ + "tensor(float16)", + "tensor(float)", + "tensor(double)", + ], + description="Constrain operands and output to float tensors.", + ) + ], + inputs=[ + OpSchema.FormalParameter( + name="A", + type_str="T", + description="First MatMul operand (constant or runtime).", + param_option=OpSchema.FormalParameterOption.Single, + is_homogeneous=True, + min_arity=1, + ), + OpSchema.FormalParameter( + name="B", + type_str="T", + description="Second MatMul operand (constant or runtime).", + param_option=OpSchema.FormalParameterOption.Single, + is_homogeneous=True, + min_arity=1, + ), + ], + outputs=[ + OpSchema.FormalParameter( + name="Y", + type_str="T", + description="MatMul output.", + param_option=OpSchema.FormalParameterOption.Single, + is_homogeneous=True, + min_arity=1, + ) + ], +) + + +class BatchedConstMatMulPattern(Pattern): + """Source: a MatMul with exactly one rank->=3 constant operand.""" + + def get_skeleton(self) -> Skeleton: + """Return a single-MatMul skeleton with two virtual inputs.""" + return Skeleton( + node_op_types=["MatMul"], + node_domains=[ONNXDomain.AI_ONNX], + edges=[ + (-1, 0, 0, 0), # input A -> MatMul[0] + (-2, 0, 0, 1), # input B -> MatMul[1] + ], + exit_nodes=[0], + n_inputs=2, + ) + + def get_internal_constants_and_attributes( + self, + inputs: dict[str, np.ndarray], + attributes: dict[str, Any], + is_constant_map: dict[str, bool], + domain_versions: dict[ONNXDomain, int], + ) -> tuple[list[tuple[int, int, np.ndarray]], dict[tuple[int, str], Any]]: + """No internal constants or attributes for a bare MatMul.""" + return [], {} + + def get_schema(self) -> PatternSchema: + """Return the shared batched-const-MatMul schema.""" + return _BATCHED_CONST_MATMUL_SCHEMA + + def check_skeleton_result( + self, skeleton_match_result: SkeletonMatchResult + ) -> PatternMatchResult | None: + """Accept only a MatMul with exactly one rank->=3 constant operand. + + This does not call ``super().check_skeleton_result``: the base + implementation rejects matches whose non-constant input carries + symbolic/dynamic dimensions, which is precisely the activation operand + of a transformer MatMul. The structural predicate here depends only on + the constant operand's rank, so symbolic activation dims are fine. + """ + input_infos = self._build_input_infos(skeleton_match_result) + + const_names = [name for name, info in input_infos.items() if info.is_constant] + # Exactly one constant operand: all-constant MatMuls fold away entirely + # and never reach gemm impl selection; zero-constant MatMuls are unaffected. + if len(const_names) != 1: + return None + + const_info = input_infos[const_names[0]] + rank = _operand_rank(const_info) + if rank is None or rank < _MIN_BATCHED_RANK: + return None + + # Build the PatternMatchResult directly (mirrors the tail of the base + # implementation, minus the symbolic-dim rejection). + schema = self.get_schema() + type_param_to_type = self._infer_type_mapping(skeleton_match_result) + schema_input_to_value = { + param.name: skeleton_match_result.inputs[idx] + for idx, param in enumerate(schema.inputs) + if idx < len(skeleton_match_result.inputs) + } + schema_output_to_value = {} + if schema.outputs and skeleton_match_result.output: + schema_output_to_value[schema.outputs[0].name] = skeleton_match_result.output + + return PatternMatchResult( + skeleton_match_result=skeleton_match_result, + schema_input_to_value=schema_input_to_value, + schema_output_to_value=schema_output_to_value, + type_param_to_type=type_param_to_type, + attributes={}, + input_infos=input_infos, + ) + + +class UntiedBatchedConstMatMulPattern(Pattern): + """Target: MatMul with the constant operand routed through ``Add(const, zero)``. + + ``zero`` is a ``[1]`` runtime tensor derived from the dynamic operand, so the + rewrite stays local and OpenVINO's constant folder cannot repack the operand + into a gemm weight. ``get_onnx_model`` is overridden to emit the subgraph, + since which operand is constant is only known at rewrite time. + """ + + def get_skeleton(self) -> Skeleton: + """Return a representative skeleton (unused by the rewriter). + + The replacement is built in :meth:`get_onnx_model`; this skeleton exists + only to satisfy the abstract base and documents the canonical RHS-constant + topology ``Reshape -> Slice -> Sub -> Add -> MatMul``. + """ + return Skeleton( + node_op_types=["Reshape", "Slice", "Sub", "Add", "MatMul"], + node_domains=[ONNXDomain.AI_ONNX] * 5, + edges=[ + (-1, 0, 0, 0), # dynamic A -> Reshape[0] + (0, 0, 1, 0), # Reshape -> Slice[0] + (1, 0, 2, 0), # Slice -> Sub[0] + (1, 0, 2, 1), # Slice -> Sub[1] + (-2, 0, 3, 0), # constant B -> Add[0] + (2, 0, 3, 1), # Sub(zero) -> Add[1] + (-1, 0, 4, 0), # dynamic A -> MatMul[0] + (3, 0, 4, 1), # Add(untied B) -> MatMul[1] + ], + exit_nodes=[4], + n_inputs=2, + ) + + def get_internal_constants_and_attributes( + self, + inputs: dict[str, np.ndarray], + attributes: dict[str, Any], + is_constant_map: dict[str, bool], + domain_versions: dict[ONNXDomain, int], + ) -> tuple[list[tuple[int, int, np.ndarray]], dict[tuple[int, str], Any]]: + """No declarative constants; the subgraph is built in get_onnx_model.""" + return [], {} + + def get_schema(self) -> PatternSchema: + """Return the shared batched-const-MatMul schema.""" + return _BATCHED_CONST_MATMUL_SCHEMA + + def get_onnx_model( + self, + inputs: dict[str, np.ndarray], + attributes: dict[str, Any], + is_constant_map: dict[str, bool], + output_dtypes: list[str], + domain_versions: dict[ONNXDomain, int], + prefix: str = "", + input_names: list[str] | None = None, + output_names: list[str] | None = None, + ) -> ModelProto: + """Emit ``MatMul(dyn, Add(const, zero(dyn)))`` preserving operand order. + + The constant operand (per ``is_constant_map``) is routed through + ``Add(const, zero)``; ``zero`` is a ``[1]`` runtime tensor built from the + dynamic operand. Operand slots are preserved so the MatMul semantics are + unchanged. + """ + from onnx import helper, numpy_helper + + schema = self.get_schema() + if input_names is None: + input_names = [param.name for param in schema.inputs] + if output_names is None: + output_names = [param.name for param in schema.outputs] + + # Identify which operand is the constant (schema order: A=0, B=1). + param_names = [param.name for param in schema.inputs] + const_idx = 0 if is_constant_map.get(param_names[0]) else 1 + dyn_idx = 1 - const_idx + const_name = input_names[const_idx] + dyn_name = input_names[dyn_idx] + out_name = output_names[0] + + nodes = [] + initializers = [] + + def _init(arr: np.ndarray, name: str) -> str: + initializers.append(numpy_helper.from_array(arr, name)) + return name + + neg1 = _init(np.array([-1], dtype=np.int64), f"{prefix}neg1") + starts = _init(np.array([0], dtype=np.int64), f"{prefix}slice_starts") + ends = _init(np.array([1], dtype=np.int64), f"{prefix}slice_ends") + axis = _init(np.array([0], dtype=np.int64), f"{prefix}slice_axis") + + flat = f"{prefix}flat" + elem = f"{prefix}elem" + zero = f"{prefix}zero" + untied = f"{prefix}untied" + + # flat = Reshape(dyn, [-1]); elem = Slice(flat, [0:1]); zero = elem - elem. + # dyn and const share a dtype (ONNX MatMul), so zero needs no Cast. + nodes.append(helper.make_node("Reshape", [dyn_name, neg1], [flat], name=f"{prefix}Reshape")) + nodes.append( + helper.make_node("Slice", [flat, starts, ends, axis], [elem], name=f"{prefix}Slice") + ) + nodes.append(helper.make_node("Sub", [elem, elem], [zero], name=f"{prefix}Sub")) + nodes.append(helper.make_node("Add", [const_name, zero], [untied], name=f"{prefix}Add")) + + # Preserve original operand order in the rebuilt MatMul. + mm_inputs = [untied, dyn_name] if const_idx == 0 else [dyn_name, untied] + nodes.append(helper.make_node("MatMul", mm_inputs, [out_name], name=f"{prefix}MatMul")) + + opset_imports = [ + helper.make_opsetid(domain.schema_domain, version) + for domain, version in domain_versions.items() + ] + graph = helper.make_graph( + nodes=nodes, + name=f"{prefix}untied_batched_const_matmul", + inputs=[], + outputs=[], + initializer=initializers, + ) + model = helper.make_model( + graph, producer_name="winmlcli-pattern-generator", opset_imports=opset_imports + ) + model.ir_version = 11 + return model + + +def _operand_rank(info: Any) -> int | None: + """Return an operand's rank from its InputInfo (shape, else constant value).""" + if info.shape is not None: + return len(info.shape) + if info.value is not None: + return int(info.value.ndim) + return None diff --git a/src/winml/modelkit/pattern/rules/default.json b/src/winml/modelkit/pattern/rules/default.json index bb9794e2e..8dddf6a70 100644 --- a/src/winml/modelkit/pattern/rules/default.json +++ b/src/winml/modelkit/pattern/rules/default.json @@ -180,6 +180,24 @@ "reason": "Merged axes reduce Transpose dimensionality for better hardware compatibility" } ] + }, + { + "pattern_id": "SUBGRAPH/BatchedConstMatMulPattern", + "pattern_class": "BatchedConstMatMulPattern", + "module": "winml.modelkit.pattern.batched_const_matmul_patterns", + "enabled": false, + "flag_name": "batchedconstmatmul", + "description": "Batched (rank >= 3) MatMul with one constant operand (OpenVINO GPU gemm impl-selection workaround)", + "alternatives": [ + { + "pattern_to_id": "SUBGRAPH/UntiedBatchedConstMatMulPattern", + "pattern_class": "UntiedBatchedConstMatMulPattern", + "module": "winml.modelkit.pattern.batched_const_matmul_patterns", + "priority": 1, + "flag_name": "untied", + "reason": "Route the constant operand through Add(const, runtime-zero) so OpenVINO GPU can select a gemm implementation" + } + ] } ] } diff --git a/tests/unit/analyze/core/model_validators/test_validators.py b/tests/unit/analyze/core/model_validators/test_validators.py index 787224aa6..437e166a2 100644 --- a/tests/unit/analyze/core/model_validators/test_validators.py +++ b/tests/unit/analyze/core/model_validators/test_validators.py @@ -409,13 +409,13 @@ def _validate(self, proto, ep, device): return BatchedConstMatMulValidator(model, ep=ep, device=device).validate() def test_detects_for_openvino_gpu(self): - """Emits a GraphOptimization action enabling the surgery for OV GPU.""" + """Emits a GraphOptimization action enabling the untie rewrite for OV GPU.""" info = self._validate(_make_batched_const_matmul_proto(), "openvino", "GPU") assert info is not None assert info.pattern_id == "MODEL/BatchedConstantMatMul" items = info.actions[0].action_items assert items[0].type == "GraphOptimization" - assert items[0].optimization_options == {"untie-constant-batched-matmul": True} + assert items[0].optimization_options == {"batchedconstmatmul-untied": True} def test_skipped_for_openvino_npu(self): """Device-gated: NPU is unaffected.""" diff --git a/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py b/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py new file mode 100644 index 000000000..86b90a1c0 --- /dev/null +++ b/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py @@ -0,0 +1,144 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""RewritePipe tests for the batchedconstmatmul-untied rewrite. + +Exercises the source/target patterns in +``winml.modelkit.pattern.batched_const_matmul_patterns`` end-to-end through +``RewritePipe``: a batched (rank >= 3) MatMul with one constant operand is +rewritten so the constant flows through ``Add(const, runtime-zero)``, leaving +numerics unchanged while making the operand runtime-valued. + +Cardinal Rules followed: +- #1: No hardcoded model architectures or node names. +- #2: All expected results generated by code at runtime (ORT reference run). +- #3: Tests run and pass under pytest. +""" + +from __future__ import annotations + +import numpy as np +import onnx +import onnxruntime as ort +from onnx import TensorProto, helper, numpy_helper + +from winml.modelkit.optim.pipes.rewrite import RewritePipe + + +_CAP_KWARG = "batchedconstmatmul_untied" + + +def _build_model( + *, + const_on_rhs: bool, + const_rank: int = 3, + dynamic_batch: bool = True, + both_constant: bool = False, +) -> onnx.ModelProto: + """Build a single-MatMul model with one rank-``const_rank`` constant operand. + + The dynamic operand carries a symbolic batch dim when ``dynamic_batch`` is + True, which is the case the base pattern matcher would normally reject. + """ + rng = np.random.RandomState(0) + batch: int | str = "batch" if dynamic_batch else 2 + + if const_on_rhs: + # dyn [batch,3,4] @ const W -> [batch,3,5] + dyn_shape = [batch, 3, 4] + w_arr = rng.randn(1, 4, 5).astype(np.float32) + if const_rank == 2: + w_arr = w_arr[0] # [4,5] + out_shape = [batch, 3, 5] + mm_inputs = ["dyn", "W"] + else: + # const W @ dyn [batch,4,5] -> [batch,3,5] + dyn_shape = [batch, 4, 5] + w_arr = rng.randn(1, 3, 4).astype(np.float32) + if const_rank == 2: + w_arr = w_arr[0] # [3,4] + out_shape = [batch, 3, 5] + mm_inputs = ["W", "dyn"] + + initializers = [numpy_helper.from_array(w_arr, "W")] + graph_inputs = [helper.make_tensor_value_info("dyn", TensorProto.FLOAT, dyn_shape)] + + if both_constant: + # Replace the dynamic operand with a second constant of matching shape. + dyn_concrete = [d if isinstance(d, int) else 2 for d in dyn_shape] + initializers.append( + numpy_helper.from_array(rng.randn(*dyn_concrete).astype(np.float32), "dyn") + ) + graph_inputs = [] + + matmul = helper.make_node("MatMul", mm_inputs, ["out"], name="batched_matmul") + graph = helper.make_graph( + [matmul], + "batched_const_matmul", + graph_inputs, + [helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)], + initializer=initializers, + ) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +def _run(model: onnx.ModelProto, feed: dict[str, np.ndarray]) -> np.ndarray: + sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + return sess.run(None, feed)[0] + + +def _rewrite(model: onnx.ModelProto) -> onnx.ModelProto: + pipe = RewritePipe() + return pipe.process(model, pipe.build_config(**{_CAP_KWARG: True})) + + +class TestBatchedConstMatMulUntieRewrite: + """End-to-end rewrite behavior via RewritePipe.""" + + def test_constant_operand_becomes_runtime_valued(self) -> None: + """The MatMul no longer consumes the initializer directly; an Add feeds it.""" + model = _build_model(const_on_rhs=True) + result = _rewrite(model) + + matmul = next(n for n in result.graph.node if n.op_type == "MatMul") + initializer_names = {init.name for init in result.graph.initializer} + assert not (set(matmul.input) & initializer_names) + add_nodes = [n for n in result.graph.node if n.op_type == "Add"] + assert len(add_nodes) == 1 + assert add_nodes[0].output[0] in matmul.input + onnx.checker.check_model(result) + + def test_numerics_unchanged_rhs_constant(self) -> None: + """+0 untie leaves outputs identical (constant on RHS, dynamic batch).""" + model = _build_model(const_on_rhs=True) + result = _rewrite(model) + feed = {"dyn": np.random.RandomState(7).randn(2, 3, 4).astype(np.float32)} + np.testing.assert_array_equal(_run(model, feed), _run(result, feed)) + + def test_numerics_unchanged_lhs_constant(self) -> None: + """Constant on the LHS is untied and stays numerically identical.""" + model = _build_model(const_on_rhs=False) + result = _rewrite(model) + assert any(n.op_type == "Add" for n in result.graph.node) + feed = {"dyn": np.random.RandomState(7).randn(2, 4, 5).astype(np.float32)} + np.testing.assert_array_equal(_run(model, feed), _run(result, feed)) + + def test_symbolic_dynamic_dims_are_matched(self) -> None: + """The dynamic operand may carry symbolic dims (the real transformer case).""" + model = _build_model(const_on_rhs=True, dynamic_batch=True) + result = _rewrite(model) + # Match succeeded despite the symbolic batch dim on the activation operand. + assert any(n.op_type == "Add" for n in result.graph.node) + + def test_rank2_constant_is_left_untouched(self) -> None: + """A rank-2 constant gemm compiles on OV GPU, so it must not be rewritten.""" + model = _build_model(const_on_rhs=True, const_rank=2) + result = _rewrite(model) + assert not any(n.op_type == "Add" for n in result.graph.node) + + def test_all_constant_matmul_is_left_untouched(self) -> None: + """A MatMul whose operands are both constant folds away; do not rewrite it.""" + model = _build_model(const_on_rhs=True, both_constant=True) + result = _rewrite(model) + assert not any(n.op_type == "Add" for n in result.graph.node) From 689d6a1ff4a46ad757d98bd06dc8a6f746416036 Mon Sep 17 00:00:00 2001 From: hualxie Date: Wed, 10 Jun 2026 11:56:02 +0800 Subject: [PATCH 7/8] refactor(optim): remove orphaned untie-constant-batched-matmul surgery The batched-const-MatMul workaround is now implemented as the batchedconstmatmul-untied rewrite, so the SurgeryPipe path is dead code. Removes: - UNTIE_CONSTANT_BATCHED_MATMUL capability. - SurgeryPipeConfig.untie_constant_batched_matmul and its build_config / should_process / process wiring. - SurgeryPipe._untie_constant_batched_matmul. - The corresponding SurgeryPipe unit tests. Updates the validator's known-gap comment to reference the rewrite source pattern instead of the removed surgery. --- .../batched_const_matmul_validator.py | 11 +- .../modelkit/optim/capabilities/surgery.py | 17 -- src/winml/modelkit/optim/pipes/surgery.py | 159 +----------------- tests/unit/optim/pipes/test_pipe_surgery.py | 155 ----------------- 4 files changed, 7 insertions(+), 335 deletions(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 9bcc28a61..9679bbe1b 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -87,10 +87,11 @@ def validate(self) -> Information | None: return None # Known gap: constants expressed as `Constant` op nodes (rather than - # graph initializers) are not detected here. The `untie-constant-batched - # -matmul` surgery in surgery.py has the same limitation, so detection - # and surgery stay consistent. Most exporters and ORT preprocessing emit - # weights as initializers, so this covers the disentangled-attention case + # graph initializers) are not detected here. The batchedconstmatmul-untied + # rewrite's source pattern matches on initializer operands too, so + # detection and rewrite stay consistent. Most exporters and ORT + # preprocessing emit weights as initializers, so this covers the + # disentangled-attention case # in practice; `Constant`-node weights would need handling on both sides. initializers = {init.name for init in self.graph.initializer} rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer} @@ -133,7 +134,7 @@ def validate(self) -> Information | None: f"operand, causing a '[GPU] Failed to select implementation ... gemm' " f"compile failure. The batchedconstmatmul-untied rewrite makes " f"the operand runtime-valued without changing numerics. " - f"It is fixed in openvino==2026.2.0, so no need to apply the surgery " + f"It is fixed in openvino==2026.2.0, so no need to apply the rewrite " f"if using that version or later." ) return Information( diff --git a/src/winml/modelkit/optim/capabilities/surgery.py b/src/winml/modelkit/optim/capabilities/surgery.py index 0b6ec0768..8b2048f00 100644 --- a/src/winml/modelkit/optim/capabilities/surgery.py +++ b/src/winml/modelkit/optim/capabilities/surgery.py @@ -37,20 +37,3 @@ category=CapabilityCategory.SURGERY, default=False, ) - -# Route a constant operand of a batched (rank >= 3) MatMul through a runtime -# no-op so it is no longer a compile-time constant. OpenVINO GPU's oneDNN gemm -# cannot select an implementation for a batched MatMul with a constant operand -# (e.g. transformer disentangled-attention position terms that fold to 3D -# constants); making the operand runtime-valued lets gemm impl selection -# succeed without changing numerics or splitting the batched op. -UNTIE_CONSTANT_BATCHED_MATMUL = BoolCapability( - name="untie-constant-batched-matmul", - ort_name=None, # Custom implementation, not ORT optimizer - description=( - "Make a batched MatMul's constant operand runtime-valued so OpenVINO " - "GPU can select a gemm implementation" - ), - category=CapabilityCategory.SURGERY, - default=False, -) diff --git a/src/winml/modelkit/optim/pipes/surgery.py b/src/winml/modelkit/optim/pipes/surgery.py index 8250a19ad..fa4fa6bcf 100644 --- a/src/winml/modelkit/optim/pipes/surgery.py +++ b/src/winml/modelkit/optim/pipes/surgery.py @@ -38,7 +38,6 @@ SURGERY_CAPABILITIES: dict[str, Any] = caps_dict( surgery.CLAMP_CONSTANT_VALUES, surgery.REMOVE_ISNAN_IN_ATTENTION_MASK, - surgery.UNTIE_CONSTANT_BATCHED_MATMUL, ) @@ -58,8 +57,6 @@ class SurgeryPipeConfig(PipeConfig): fix_nan_attention_mask: Replace -inf attention mask with finite value and remove Softmax->IsNaN->Where NaN guard patterns mask_value: Replacement value for -inf (default: -1e3) - untie_constant_batched_matmul: Make a batched MatMul's constant operand - runtime-valued so OpenVINO GPU can select a gemm implementation verbose: Enable verbose logging """ @@ -67,7 +64,6 @@ class SurgeryPipeConfig(PipeConfig): clamp_min: float = -1e3 clamp_max: float = 1e3 remove_isnan_in_attention_mask: bool = False - untie_constant_batched_matmul: bool = False verbose: bool = False @@ -110,7 +106,6 @@ def build_config(cls, **kwargs: Any) -> SurgeryPipeConfig: clamp_min=kwargs.get("clamp_min", -1e3), clamp_max=kwargs.get("clamp_max", 1e3), remove_isnan_in_attention_mask=kwargs.get("remove_isnan_in_attention_mask", False), - untie_constant_batched_matmul=kwargs.get("untie_constant_batched_matmul", False), verbose=kwargs.get("verbose", False), ) @@ -124,11 +119,7 @@ def should_process(cls, config: SurgeryPipeConfig) -> bool: Returns: True if any surgery operation is enabled """ - return ( - config.clamp_constant_values - or config.remove_isnan_in_attention_mask - or config.untie_constant_batched_matmul - ) + return config.clamp_constant_values or config.remove_isnan_in_attention_mask def process(self, model: onnx.ModelProto, config: SurgeryPipeConfig) -> onnx.ModelProto: """Apply surgery operations to the model. @@ -158,9 +149,6 @@ def process(self, model: onnx.ModelProto, config: SurgeryPipeConfig) -> onnx.Mod if config.remove_isnan_in_attention_mask: model_copy = self._remove_isnan_in_attention_mask(model_copy, config.verbose) - if config.untie_constant_batched_matmul: - model_copy = self._untie_constant_batched_matmul(model_copy, config.verbose) - return model_copy def _clamp_constant_values( @@ -331,148 +319,3 @@ def _remove_isnan_in_attention_mask( ) return model - - # ----------------------------------------------------------------- - # untie-constant-batched-matmul - # ----------------------------------------------------------------- - - def _untie_constant_batched_matmul( - self, - model: onnx.ModelProto, - verbose: bool = False, - ) -> onnx.ModelProto: - """Make a batched MatMul's constant operand runtime-valued. - - OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched - (rank >= 3) MatMul where an operand is a compile-time constant: the same - gemm with a dynamic operand, and 2D constant gemm, both compile fine. - Transformer disentangled-attention position terms depend only on weights, - so they fold into 3D constants and hit this case. - - Fix: route each such constant operand through ``Add(const, zero)`` where - ``zero`` is a runtime ``[1]`` tensor built from the first graph input's - *data*: ``Cast(first_input -> float) -> Reshape([-1]) -> Slice([0:1])`` - yields a single element ``elem``, and ``zero = Sub(elem, elem) == 0.0``. - ``zero`` is data-dependent, so OpenVINO's constant folder cannot collapse - the Add back into a packed gemm weight, yet ``+ 0`` leaves the values - unchanged and the single batched MatMul is preserved (no perf cost). - - Assumption: the first graph input has at least one element at runtime. - The ``Slice([0:1])`` is out of bounds for a zero-sized input (e.g. a - dynamic batch dimension fed an empty batch), which would raise at - inference time rather than produce a zero. - """ - from onnx import TensorProto, helper, numpy_helper - - graph = model.graph - initializers = {init.name: init for init in graph.initializer} - - # Collect (matmul_node, operand_index) where the operand is a constant - # initializer of rank >= 3. Skip MatMuls whose operands are all constant - # (those fold away entirely and never reach gemm impl selection). - targets: list[tuple[onnx.NodeProto, int]] = [] - for node in graph.node: - if node.op_type != "MatMul" or len(node.input) != 2: - continue - const_idx = [i for i, name in enumerate(node.input) if name in initializers] - if len(const_idx) != 1: - continue - idx = const_idx[0] - if len(initializers[node.input[idx]].dims) >= 3: - targets.append((node, idx)) - - if not targets: - return model - - if not graph.input: - logger.warning( - "SurgeryPipe: untie-constant-batched-matmul: no graph input to " - "derive a runtime value from; skipping %d MatMul(s)", - len(targets), - ) - return model - - prefix = "winml_ovgpu_untie" - first_input = graph.input[0].name - new_nodes: list[onnx.NodeProto] = [] - new_inits: list[onnx.TensorProto] = [] - - # Build a shape-[1] runtime zero from input *data* (not shape — input - # shapes are static and would be folded). Only ubiquitous ops are used - # so the static analyzer handles them: a single input element is sliced - # out and subtracted from itself. A [1] tensor broadcasts against any - # constant operand, regardless of its rank. - xf = f"{prefix}_xf" - new_nodes.append( - helper.make_node("Cast", [first_input], [xf], to=TensorProto.FLOAT, name=xf) - ) - flat = f"{prefix}_flat" - new_inits.append(numpy_helper.from_array(np.array([-1], dtype=np.int64), f"{prefix}_m1")) - new_nodes.append(helper.make_node("Reshape", [xf, f"{prefix}_m1"], [flat], name=flat)) - elem = f"{prefix}_elem" - # Slice(flat, starts=[0], ends=[1], axes=[0]) -> the first element. - # starts and axes are distinct tensors even though both hold [0], so a - # future edit to one role cannot silently corrupt the other. - starts = f"{prefix}_slice_starts" - ends = f"{prefix}_slice_ends" - axis = f"{prefix}_slice_axis" - new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), starts)) - new_inits.append(numpy_helper.from_array(np.array([1], dtype=np.int64), ends)) - new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), axis)) - new_nodes.append(helper.make_node("Slice", [flat, starts, ends, axis], [elem], name=elem)) - # zero = elem - elem == 0.0 (data-dependent, so it is not folded away). - zero_f32 = f"{prefix}_zero_f32" - new_nodes.append(helper.make_node("Sub", [elem, elem], [zero_f32], name=zero_f32)) - - # A zero must match each operand's dtype (ONNX has no implicit promotion). - zero_by_dtype: dict[int, str] = {int(TensorProto.FLOAT): zero_f32} - - def zero_for(dtype: int) -> str: - name = zero_by_dtype.get(dtype) - if name is None: - name = f"{prefix}_zero_{dtype}" - new_nodes.append(helper.make_node("Cast", [zero_f32], [name], to=dtype, name=name)) - zero_by_dtype[dtype] = name - return name - - untied = 0 - # Index the loop rather than node.name: node names are optional in ONNX - # and exporters routinely leave them blank or duplicated, so deriving - # `dyn` from the name would collide and produce an invalid graph. - for untie_idx, (node, idx) in enumerate(targets): - const_name = node.input[idx] - dtype = initializers[const_name].data_type - if dtype not in (TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE): - continue - dyn = f"{prefix}_untied{untie_idx}_in{idx}" - new_nodes.append( - helper.make_node("Add", [const_name, zero_for(dtype)], [dyn], name=dyn) - ) - node.input[idx] = dyn - untied += 1 - if verbose: - logger.info( - " untie-constant-batched-matmul: %s input[%d] %s -> %s", - node.name, - idx, - const_name, - dyn, - ) - - if untied == 0: - return model - - graph.initializer.extend(new_inits) - # Prepend new nodes: their inputs are only graph inputs / initializers, - # so placing them first keeps the graph topologically sorted. - existing = list(graph.node) - del graph.node[:] - graph.node.extend(new_nodes + existing) - - logger.info( - "SurgeryPipe: untie-constant-batched-matmul: untied %d batched " - "MatMul constant operand(s)", - untied, - ) - - return model diff --git a/tests/unit/optim/pipes/test_pipe_surgery.py b/tests/unit/optim/pipes/test_pipe_surgery.py index 16df08d79..6ab7e0034 100644 --- a/tests/unit/optim/pipes/test_pipe_surgery.py +++ b/tests/unit/optim/pipes/test_pipe_surgery.py @@ -342,158 +342,3 @@ def test_surgery_pipe_runs_last(self) -> None: # SurgeryPipe should be last in the list assert PIPES[-1].name == "surgery" - - -# ============================================================================= -# UNTIE-CONSTANT-BATCHED-MATMUL TESTS -# ============================================================================= - - -def _make_batched_const_matmul_model( - *, - const_rank: int = 3, - const_on_rhs: bool = True, -) -> onnx.ModelProto: - """Build a model with a batched MatMul that has one constant operand. - - data [2,3,4] @ W(const) [2,4,5] -> out [2,3,5] (const on rhs), or the - transposed arrangement when ``const_on_rhs`` is False. - """ - from onnx import TensorProto, helper - - rng = np.random.RandomState(0) - if const_on_rhs: - data_shape, w_shape, out_shape = [2, 3, 4], [2, 4, 5], [2, 3, 5] - mm_inputs = ["data", "W"] - else: - data_shape, w_shape, out_shape = [2, 4, 5], [2, 3, 4], [2, 3, 5] - mm_inputs = ["W", "data"] - - if const_rank == 2: - w_shape = w_shape[1:] - - w = numpy_helper.from_array(rng.randn(*w_shape).astype(np.float32), "W") - matmul = helper.make_node("MatMul", mm_inputs, ["out"], name="batched_matmul") - graph = helper.make_graph( - [matmul], - "test_batched_const_matmul", - [helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape)], - [helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)], - initializer=[w], - ) - return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) - - -class TestUntieConstantBatchedMatmulCapability: - """Capability/config plumbing for untie-constant-batched-matmul.""" - - def test_capability_exists(self) -> None: - """Capability is registered with a None ort_name (custom impl).""" - assert "untie-constant-batched-matmul" in SURGERY_CAPABILITIES - assert SURGERY_CAPABILITIES["untie-constant-batched-matmul"].ort_name is None - - def test_build_config_enable_via_kwarg(self) -> None: - """Flag can be toggled through build_config.""" - config = SurgeryPipe.build_config(untie_constant_batched_matmul=True) - assert config.untie_constant_batched_matmul is True - - def test_should_process_true_when_enabled(self) -> None: - """should_process is True when only this surgery is enabled.""" - config = SurgeryPipeConfig(untie_constant_batched_matmul=True) - assert SurgeryPipe.should_process(config) is True - - -class TestUntieConstantBatchedMatmulProcess: - """Graph transform behavior.""" - - def test_constant_operand_becomes_runtime_valued(self) -> None: - """The MatMul no longer consumes the initializer directly.""" - model = _make_batched_const_matmul_model() - result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) - - matmul = next(n for n in result.graph.node if n.op_type == "MatMul") - initializer_names = {init.name for init in result.graph.initializer} - # No MatMul input is a direct initializer anymore. - assert not (set(matmul.input) & initializer_names) - # An Add node now produces the (formerly constant) operand. - add_nodes = [n for n in result.graph.node if n.op_type == "Add"] - assert len(add_nodes) == 1 - assert add_nodes[0].output[0] in matmul.input - # Graph remains structurally valid. - onnx.checker.check_model(result) - - def test_numerics_unchanged(self) -> None: - """+0 tie leaves outputs bit-for-bit identical on ORT CPU.""" - import onnxruntime as ort - - model = _make_batched_const_matmul_model() - transformed = SurgeryPipe().process( - model, SurgeryPipeConfig(untie_constant_batched_matmul=True) - ) - - rng = np.random.RandomState(7) - feed = {"data": rng.randn(2, 3, 4).astype(np.float32)} - - ref = ort.InferenceSession( - model.SerializeToString(), providers=["CPUExecutionProvider"] - ).run(None, feed)[0] - got = ort.InferenceSession( - transformed.SerializeToString(), providers=["CPUExecutionProvider"] - ).run(None, feed)[0] - np.testing.assert_array_equal(ref, got) - - def test_two_dim_constant_is_left_untouched(self) -> None: - """Rank-2 constant gemm compiles on OV GPU, so it must not be rewritten.""" - model = _make_batched_const_matmul_model(const_rank=2) - result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) - assert not any(n.op_type == "Add" for n in result.graph.node) - - def test_constant_on_lhs_is_handled(self) -> None: - """A constant rank-3 operand on the LHS is untied too.""" - model = _make_batched_const_matmul_model(const_on_rhs=False) - result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) - assert any(n.op_type == "Add" for n in result.graph.node) - - def test_duplicate_node_names_do_not_collide(self) -> None: - """Two target MatMuls with empty names produce a valid graph. - - Node names are optional in ONNX; exporters routinely leave them blank. - The generated dynamic-operand names must be unique regardless, or the - transformed graph would have colliding tensor names and fail validation. - """ - from onnx import TensorProto, helper - - rng = np.random.RandomState(0) - w1 = numpy_helper.from_array(rng.randn(2, 4, 5).astype(np.float32), "W1") - w2 = numpy_helper.from_array(rng.randn(2, 5, 6).astype(np.float32), "W2") - # Both MatMuls deliberately left unnamed (name=""). - mm1 = helper.make_node("MatMul", ["data", "W1"], ["mid"], name="") - mm2 = helper.make_node("MatMul", ["mid", "W2"], ["out"], name="") - graph = helper.make_graph( - [mm1, mm2], - "test_dup_names", - [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4])], - [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 6])], - initializer=[w1, w2], - ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) - - result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) - - # Both constants are untied and the graph stays structurally valid. - add_nodes = [n for n in result.graph.node if n.op_type == "Add"] - assert len(add_nodes) == 2 - assert len({n.output[0] for n in add_nodes}) == 2 - onnx.checker.check_model(result) - - # Numerics are unchanged versus the original model. - import onnxruntime as ort - - feed = {"data": np.random.RandomState(7).randn(2, 3, 4).astype(np.float32)} - ref = ort.InferenceSession( - model.SerializeToString(), providers=["CPUExecutionProvider"] - ).run(None, feed)[0] - got = ort.InferenceSession( - result.SerializeToString(), providers=["CPUExecutionProvider"] - ).run(None, feed)[0] - np.testing.assert_array_equal(ref, got) From c829d29e95a02ff67e81071fd5012feb404df66c Mon Sep 17 00:00:00 2001 From: xieofxie Date: Wed, 10 Jun 2026 14:29:17 +0800 Subject: [PATCH 8/8] Potential fix for pull request finding 'CodeQL / Module is imported with 'import' and 'import from'' Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- .../test_pipe_rewrite_batched_const_matmul.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py b/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py index 86b90a1c0..711e79aca 100644 --- a/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py +++ b/tests/unit/optim/pipes/test_pipe_rewrite_batched_const_matmul.py @@ -21,7 +21,6 @@ import numpy as np import onnx import onnxruntime as ort -from onnx import TensorProto, helper, numpy_helper from winml.modelkit.optim.pipes.rewrite import RewritePipe @@ -61,26 +60,28 @@ def _build_model( out_shape = [batch, 3, 5] mm_inputs = ["W", "dyn"] - initializers = [numpy_helper.from_array(w_arr, "W")] - graph_inputs = [helper.make_tensor_value_info("dyn", TensorProto.FLOAT, dyn_shape)] + initializers = [onnx.numpy_helper.from_array(w_arr, "W")] + graph_inputs = [ + onnx.helper.make_tensor_value_info("dyn", onnx.TensorProto.FLOAT, dyn_shape) + ] if both_constant: # Replace the dynamic operand with a second constant of matching shape. dyn_concrete = [d if isinstance(d, int) else 2 for d in dyn_shape] initializers.append( - numpy_helper.from_array(rng.randn(*dyn_concrete).astype(np.float32), "dyn") + onnx.numpy_helper.from_array(rng.randn(*dyn_concrete).astype(np.float32), "dyn") ) graph_inputs = [] - matmul = helper.make_node("MatMul", mm_inputs, ["out"], name="batched_matmul") - graph = helper.make_graph( + matmul = onnx.helper.make_node("MatMul", mm_inputs, ["out"], name="batched_matmul") + graph = onnx.helper.make_graph( [matmul], "batched_const_matmul", graph_inputs, - [helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)], + [onnx.helper.make_tensor_value_info("out", onnx.TensorProto.FLOAT, out_shape)], initializer=initializers, ) - return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + return onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 17)]) def _run(model: onnx.ModelProto, feed: dict[str, np.ndarray]) -> np.ndarray: