Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 195 additions & 7 deletions src/winml/modelkit/analyze/core/runtime_checker_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2491,22 +2491,210 @@ def get_pattern_id(is_qdq: bool) -> str:
parquet_rules_ms = _elapsed_ms(parquet_rules_start)
return _finish(final_result, outcome="parquet_rules")

def _load_parquet_pattern_rule_table(
self,
pattern_name: str,
op_domain: ONNXDomain,
opset_version: int,
) -> tuple[pd.DataFrame | None, Path, _ParquetConditionTree | None]:
"""Load per-pattern parquet rule table with cache.

Returns:
tuple[pd.DataFrame | None, Path, _ParquetConditionTree | None]:
Loaded dataframe when available, otherwise None,
the resolved parquet path used for lookup,
and optional pre-built condition tree.
"""
parquet_name = (
f"{pattern_name}_{self.ep_name}_{self.device_type.upper()}_{op_domain.name}"
f"_opset{opset_version}.parquet"
)
parquet_path = resolve_rule_parquet_path(parquet_name)

cache_key = (pattern_name, op_domain.value, opset_version, False)
if cache_key in self._parquet_rule_table_cache:
_log_parquet_cache_hit(parquet_path, scope="instance")
return (
self._parquet_rule_table_cache[cache_key],
parquet_path,
self._parquet_condition_tree_cache.get(cache_key),
)

table_df = _get_or_load_parquet_table_global(parquet_path)
condition_tree = _build_condition_tree(table_df)
self._parquet_rule_table_cache[cache_key] = table_df
self._parquet_condition_tree_cache[cache_key] = condition_tree
return table_df, parquet_path, condition_tree

def run_for_subgraph(
self,
pattern_match: PatternMatchResult,
run_unknown_op: bool = False,
) -> PatternRuntime:
"""Run runtime check for subgraph pattern via per-node checks."""
"""Run runtime check for subgraph pattern via parquet rule lookup.

Strategy mirrors ``run_for_node``'s parquet-based path:
1. Extract conditions from pattern match.
2. Resolve the pattern parquet table for
``(pattern_name, ep, device, domain, opset)``.
3. Look up the matching row and return its compile/run result.
4. Fallback to per-node checking when the table or matching row is missing.

Args:
pattern_match: PatternMatchResult containing pattern information.
run_unknown_op: If True, attempt local EP check for unknown ops in
the per-node fallback path.

Returns:
PatternRuntime with check results.
"""
pattern_id = pattern_match.pattern.pattern_id
pattern_name = pattern_match.pattern.__class__.__name__
logger.info(
"Pattern-level aggregated rules are removed; checking individual operators for '%s'",
pattern_name,

# Step 1: Extract conditions from PatternMatchResult
try:
conditions, infinite_properties = get_query_conditions_for_pattern(
pattern_match,
pattern_name,
self.opset_versions,
dynamic_axis_strict_mode=self.dynamic_axis_strict_mode,
)
except OpOptionalInputSupportError as e:
logger.error("OpOptionalInputSupportError for pattern '%s': %s", pattern_name, e)
return PatternRuntime(
pattern_id=pattern_id,
result=RuntimeTestResult(
compile=False,
run=False,
no_data=True,
reason="optional_input_properties_not_found",
debug_details={
"pattern_name": pattern_name,
"error_message": str(e),
"table_path": "",
"table_file": "",
},
),
alternatives=self.alternatives,
pattern_match=pattern_match,
)
except Exception as e:
logger.error("Failed to extract conditions for pattern '%s': %s", pattern_name, e)
return PatternRuntime(
pattern_id=pattern_id,
result=RuntimeTestResult(
compile=False,
run=False,
no_data=True,
reason="pattern_conditions_extraction_failed",
debug_details={
"pattern_name": pattern_name,
"error_message": str(e),
},
),
alternatives=self.alternatives,
pattern_match=pattern_match,
)

# Step 2: Determine domain & opset for the parquet table file name.
# Prefer ai.onnx (com.microsoft opset is always 1, so we use ai.onnx
# for naming when available); otherwise fall back to the first domain.
if ONNXDomain.AI_ONNX in self.opset_versions:
table_domain = ONNXDomain.AI_ONNX
table_opset = self.opset_versions[ONNXDomain.AI_ONNX]
else:
table_domain, table_opset = next(iter(self.opset_versions.items()))

# Step 3: Load the parquet rule table for this pattern
table_df, parquet_path, condition_tree = self._load_parquet_pattern_rule_table(
pattern_name, table_domain, table_opset
)
return self._run_for_subgraph_per_node(
pattern_match,
parquet_file = parquet_path.name
parquet_path_norm = _normalize_table_path(parquet_path)

if table_df is None:
logger.info(
"No pattern parquet '%s' found for '%s', checking individual operators",
parquet_file,
pattern_name,
)
return self._run_for_subgraph_per_node(pattern_match, pattern_name, run_unknown_op)

# Step 4: Build filter conditions and look up the matching row
pattern_columns = condition_tree.condition_columns if condition_tree is not None else []
table_filter_conditions = _build_table_filter_conditions(
conditions,
pattern_columns,
infinite_properties,
f"pattern {pattern_name}",
)
parquet_filter_conditions = {
k: encode_rule_condition_value_for_parquet(v)
for k, v in table_filter_conditions.items()
}
query_signature = _build_query_signature(pattern_columns, parquet_filter_conditions)

cache_key = (
pattern_name,
run_unknown_op,
table_domain.value,
table_opset,
False,
query_signature,
)
if cache_key in self._node_result_cache:
cached = self._node_result_cache[cache_key]
return PatternRuntime(
pattern_id=pattern_id,
result=cached.result,
alternatives=self.alternatives,
pattern_match=pattern_match,
)

matched_row = None
row_position = _lookup_row_position_in_condition_tree(
condition_tree, parquet_filter_conditions
)
if row_position is not None:
matched_row = table_df.iloc[row_position]
else:
ret = query_table_exact_match(table_df, parquet_filter_conditions)
if not ret.empty:
matched_row = ret.iloc[0]

if matched_row is None:
logger.info(
"Pattern parquet '%s' loaded but properties not matched for '%s': %s",
parquet_file,
pattern_name,
table_filter_conditions,
)
return self._run_for_subgraph_per_node(pattern_match, pattern_name, run_unknown_op)

compile_run = matched_row.get("compile_run_success", (False, False))
compile_result = bool(compile_run[0])
run_result = bool(compile_run[1])

result = RuntimeTestResult(
compile=compile_result,
run=run_result,
reason="",
no_data=False,
debug_details={
"table_path": parquet_path_norm,
"table_file": parquet_file,
"opset_version": table_opset,
"lookup_columns": pattern_columns,
"query_signature": query_signature,
},
)
pattern_runtime = PatternRuntime(
pattern_id=pattern_id,
result=result,
alternatives=self.alternatives,
pattern_match=pattern_match,
)
self._node_result_cache[cache_key] = pattern_runtime
return pattern_runtime

def _run_for_subgraph_per_node(
self,
Expand Down
6 changes: 6 additions & 0 deletions src/winml/modelkit/pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@
ReshapeTransposeReshapeOverlyHighDimPattern,
ReshapeTransposeReshapeOverlyHighDimPatternInputGenerator,
)
from .unsqueeze_cast_patterns import (
UnsqueezeCastPattern,
UnsqueezeCastPatternInputGenerator,
)


__all__ = [
Expand Down Expand Up @@ -139,6 +143,8 @@
"TransposedSingleLayerNormalizationPatternInputGenerator",
"TransposedSingleRMSNormalizationPattern",
"TransposedSingleRMSNormalizationPatternInputGenerator",
"UnsqueezeCastPattern",
"UnsqueezeCastPatternInputGenerator",
"get_pattern_input_generator",
"get_registered_pattern_input_generators",
"make_single_op_pattern",
Expand Down
19 changes: 13 additions & 6 deletions src/winml/modelkit/pattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,14 +965,21 @@ def _infer_type_mapping(self, skeleton_match_result: "SkeletonMatchResult") -> d
Dictionary mapping type parameters (e.g., 'T') to actual types (e.g., 'tensor(float)').
"""
schema = self.get_schema()
type_param_to_type = {}
matcher = skeleton_match_result.matcher
type_param_to_type: dict[str, str] = {}

for idx, input_param in enumerate(schema.inputs):
if idx < len(skeleton_match_result.inputs):
tensor_name = skeleton_match_result.inputs[idx]
actual_type = skeleton_match_result.matcher.get_tensor_type_str(tensor_name)
if actual_type and input_param.type_str:
type_param_to_type[input_param.type_str] = actual_type
if idx < len(skeleton_match_result.inputs) and input_param.type_str:
actual_type = matcher.get_tensor_type_str(skeleton_match_result.inputs[idx])
if actual_type:
type_param_to_type.setdefault(input_param.type_str, actual_type)

if schema.outputs and skeleton_match_result.output:
output_param = schema.outputs[0]
if output_param.type_str:
actual_type = matcher.get_tensor_type_str(skeleton_match_result.output)
if actual_type:
type_param_to_type.setdefault(output_param.type_str, actual_type)

return type_param_to_type

Expand Down
8 changes: 8 additions & 0 deletions src/winml/modelkit/pattern/rules/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@
"reason": "Merged axes reduce Transpose dimensionality for better hardware compatibility"
}
]
},
{
"pattern_id": "SUBGRAPH/UnsqueezeCastPattern",
"pattern_class": "UnsqueezeCastPattern",
"module": "winml.modelkit.pattern.unsqueeze_cast_patterns",
"enabled": true,
"flag_name": "unsqueezecast",
"description": "Unsqueeze followed by Cast(to=FLOAT) on a constant-axes Unsqueeze"
}
]
}
Loading
Loading