class Qwen3TransformerCalibReader:
    """Calibration data reader for the transformer-only Qwen3 ONNX graph.

    The embedding layer is stripped from the ONNX graph, so ``inputs_embeds``
    is produced by running the supplied ``embed_tokens`` module in PyTorch.
    The remaining feeds (``attention_mask``, ``position_ids`` and zero-filled
    ``past_{i}_key`` / ``past_{i}_value``) are synthesized here.

    Every sample is built eagerly in ``__init__``; :meth:`get_next` replays
    them one at a time and :meth:`rewind` restarts the replay.
    """

    def __init__(
        self,
        embed_tokens: torch.nn.Module,
        config: Any,
        token_ids_list: list[torch.Tensor],
        *,
        seq_len: int,
        max_cache_len: int,
    ) -> None:
        """Build all calibration feeds up front.

        Args:
            embed_tokens: Module mapping token ids to embeddings.
            config: Object exposing ``num_hidden_layers``,
                ``num_key_value_heads`` and either ``head_dim`` or
                ``hidden_size`` / ``num_attention_heads``.
            token_ids_list: One ``[1, n_tokens]`` id tensor per sample.
            seq_len: Static sequence length every sample is cut/padded to.
            max_cache_len: Length of the attention mask and of each
                ``past_*`` cache buffer.
        """
        self.embed = embed_tokens
        self.cfg = config
        self.seq_len = seq_len
        self.max_cache_len = max_cache_len
        self.num_layers = config.num_hidden_layers
        self.num_kv_heads = config.num_key_value_heads

        # ``head_dim`` may be missing *or* present-but-None; derive it from
        # the attention geometry in both cases.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim

        self._samples = list(self._build_samples(token_ids_list))
        self._iter: Iterator[dict[str, np.ndarray]] | None = None
        self.rewind()

    def _build_samples(
        self, token_ids_list: list[torch.Tensor]
    ) -> Iterator[dict[str, np.ndarray]]:
        """Yield one feed dict per entry of ``token_ids_list``."""
        # The empty KV cache is identical for every layer and every sample,
        # so a single zero buffer is shared by all of them.
        kv_shape = (1, self.num_kv_heads, self.max_cache_len, self.head_dim)
        empty_cache = np.zeros(kv_shape, dtype=np.float32)

        for ids in token_ids_list:
            # Right-truncate / right-pad to seq_len so we feed the static
            # graph shape.
            ids = ids[:, : self.seq_len]
            real_len = ids.shape[1]
            if real_len < self.seq_len:
                pad = torch.zeros(
                    (1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device
                )
                ids = torch.cat([ids, pad], dim=1)

            with torch.no_grad():
                # Already float32 after ``.to``; no extra ``astype`` copy needed.
                embeds = self.embed(ids).to(torch.float32).cpu().numpy()

            # attention_mask: ones for ``real_len`` positions at the END of
            # the max_cache buffer, zeros elsewhere.
            # NOTE(review): the ids above are padded on the right, while the
            # mask marks the trailing slots as real; confirm this matches the
            # runtime's cache layout for prompts shorter than seq_len.
            attn_mask = np.zeros((1, self.max_cache_len), dtype=np.int64)
            if real_len:
                # Guarded: ``-0:`` would select the whole buffer for an
                # empty prompt.
                attn_mask[0, -real_len:] = 1

            # position_ids: 0..seq_len-1 for every slot, padded ones included.
            position_ids = np.arange(self.seq_len, dtype=np.int64)[None, :]

            feed: dict[str, np.ndarray] = {
                "inputs_embeds": embeds,
                "attention_mask": attn_mask,
                "position_ids": position_ids,
            }
            for layer in range(self.num_layers):
                feed[f"past_{layer}_key"] = empty_cache
                feed[f"past_{layer}_value"] = empty_cache
            yield feed

    def get_next(self) -> dict[str, np.ndarray] | None:
        """Return the next feed, or ``None`` once all samples are consumed."""
        if self._iter is None:
            return None
        return next(self._iter, None)

    def rewind(self) -> None:
        """Restart iteration from the first sample."""
        self._iter = iter(self._samples)
def _tokenize_prompts(
    tokenizer: Any, prompts: list[str], num_samples: int
) -> list[torch.Tensor]:
    """Return ``num_samples`` token-id tensors, cycling through ``prompts``.

    Each prompt is wrapped in a single-turn user chat message and rendered
    with the tokenizer's chat template (generation prompt on, thinking off)
    so calibration inputs are formatted like the runtime's. Every distinct
    prompt is templated and tokenized only once; repeated slots reuse the
    same tensor object.

    Args:
        tokenizer: Object providing ``apply_chat_template`` and a
            ``__call__`` whose result exposes ``input_ids``.
        prompts: Prompt strings to cycle through.
        num_samples: Number of tensors to return.

    Returns:
        A list of ``num_samples`` tensors (empty when ``num_samples <= 0``).

    Raises:
        ValueError: If ``prompts`` is empty while ``num_samples`` is positive.
    """
    if num_samples <= 0:
        return []
    if not prompts:
        raise ValueError("prompts must be non-empty when num_samples > 0")

    encoded: dict[str, torch.Tensor] = {}
    out: list[torch.Tensor] = []
    for i in range(num_samples):
        prompt = prompts[i % len(prompts)]
        if prompt not in encoded:
            text = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
            encoded[prompt] = tokenizer([text], return_tensors="pt").input_ids
        out.append(encoded[prompt])
    return out
def _bit_width_tag(type_name: str) -> str:
    """Return the trailing digits of a quant type name (``"uint16"`` -> ``"16"``).

    Falls back to the whole name when it carries no trailing digits.
    """
    prefix = type_name.rstrip("0123456789")
    return type_name[len(prefix):] or type_name


def _uncompiled_onnx_path(final_path: Path) -> Path:
    """Return the ``*_optimized.onnx`` sibling of a ``*_model.onnx`` path.

    Paths without the ``_model.onnx`` suffix are returned unchanged. When the
    sibling is missing, a warning is printed and ``final_path`` is returned.
    """
    compiled_suffix = "_model.onnx"
    if not final_path.name.endswith(compiled_suffix):
        return final_path
    stem = final_path.name[: -len(compiled_suffix)]
    optimized = final_path.with_name(f"{stem}_optimized.onnx")
    if optimized.exists():
        return optimized
    print(
        f"WARNING: {optimized.name} not found next to {final_path.name}; "
        "falling back to the compiled model (surgery will likely fail)."
    )
    return final_path


def quantize_built_model(
    model: WinMLCompositeModel,
    *,
    model_id: str = DEFAULT_MODEL_ID,
    max_cache_len: int = DEFAULT_MAX_CACHE,
    prefill_seq: int = DEFAULT_PREFILL_SEQ,
    num_samples: int = DEFAULT_NUM_SAMPLES,
    weight_type: str = "uint8",
    activation_type: str = "uint16",
) -> dict[str, Path]:
    """Run surgery + transformer-only quantization on an already-built composite.

    Reuses the ONNX files referenced by ``model.sub_models`` so this can be
    called after a build step without re-exporting. For every known sub-model
    (``decoder_prefill``, ``decoder_gen``) it writes
    ``<stem>_transformer.onnx`` via ``make_transformer_only`` and then
    ``<stem>_transformer_w<W>a<A>.quant.onnx`` via ``quantize_onnx``, where
    ``<W>`` / ``<A>`` are the bit widths taken from the trailing digits of
    ``weight_type`` / ``activation_type`` (``w8a16`` for the defaults).

    Args:
        model: Built composite whose sub-models expose ``_onnx_path``.
        model_id: Hugging Face id used to load the embedding layer and
            tokenizer for calibration.
        max_cache_len: Cache length used for the calibration feeds.
        prefill_seq: Sequence length of the ``decoder_prefill`` graph.
        num_samples: Number of calibration samples.
        weight_type: Weight quantization type name, e.g. ``"uint8"``.
        activation_type: Activation quantization type name, e.g. ``"uint16"``.

    Returns:
        Mapping of sub-model name → quantized ONNX path.

    Raises:
        SystemExit: With code 1 when ``quantize_onnx`` reports failure
            (script-style exit, kept for existing callers).
    """
    # ``_model.onnx`` is the *compiled* QNN EPContext blob — surgery needs the
    # uncompiled graph that ``build.hf`` emits next to it.
    sub_paths: dict[str, Path] = {
        name: _uncompiled_onnx_path(Path(sub._onnx_path))
        for name, sub in model.sub_models.items()
    }

    for name, p in sub_paths.items():
        print(f"  {name}: {p}")

    print("\n=== Loading HF embed_tokens for calibration ===")
    hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    hf_model.eval()
    embed_tokens = hf_model.get_input_embeddings()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    token_ids_list = _tokenize_prompts(tokenizer, DEFAULT_PROMPTS, num_samples)

    seq_by_sub = {
        "decoder_prefill": prefill_seq,
        "decoder_gen": DEFAULT_GEN_SEQ,
    }
    # Derive the tag from the trailing digits instead of fixed-width slices,
    # so e.g. uint16 weights give "w16" rather than "w6".
    quant_tag = f"_w{_bit_width_tag(weight_type)}a{_bit_width_tag(activation_type)}"

    quant_paths: dict[str, Path] = {}
    for sub_name, fused_path in sub_paths.items():
        if sub_name not in seq_by_sub:
            print(f"\n--- Skipping unknown sub-model {sub_name!r} ---")
            continue

        seq_len = seq_by_sub[sub_name]
        transformer_path = fused_path.with_name(fused_path.stem + "_transformer.onnx")
        quant_path = transformer_path.with_name(
            transformer_path.stem + f"{quant_tag}.quant.onnx"
        )

        print(f"\n=== Surgery: {sub_name} (seq_len={seq_len}) ===")
        print(f"  in : {fused_path}")
        print(f"  out: {transformer_path}")
        make_transformer_only(fused_path, transformer_path)

        print(f"\n=== Quantize (transformer only): {sub_name} ===")
        print(f"  out: {quant_path}")
        reader = Qwen3TransformerCalibReader(
            embed_tokens,
            hf_model.config,
            token_ids_list,
            seq_len=seq_len,
            max_cache_len=max_cache_len,
        )
        cfg = WinMLQuantizationConfig(
            samples=num_samples,
            weight_type=weight_type,  # type: ignore[arg-type]
            activation_type=activation_type,  # type: ignore[arg-type]
            calibration_method="minmax",
            calibration_data=reader,
        )
        result = quantize_onnx(transformer_path, output_path=quant_path, config=cfg)
        if not result.success:
            print("  FAILED:")
            for err in result.errors:
                print(f"    {err}")
            raise SystemExit(1)
        print(
            f"  ok — {result.nodes_quantized} QDQ nodes inserted in "
            f"{result.total_time_seconds:.1f}s"
        )
        quant_paths[sub_name] = quant_path

    print("\n=== Done ===")
    return quant_paths
def _dim(d: onnx.TensorShapeProto.Dimension) -> int | str:
    """Return a dimension as its fixed value, its symbolic name, or ``"?"``."""
    if not d.HasField("dim_value"):
        # No concrete size: use the symbolic name, or "?" when it is empty.
        return d.dim_param if d.dim_param else "?"
    return d.dim_value
def make_transformer_only(
    model_path: str | Path,
    output_path: str | Path,
    *,
    input_ids_name: str = "input_ids",
    logits_name: str = "logits",
    inputs_embeds_name: str = "inputs_embeds",
    output_hidden_states_name: str = "output_hidden_states",
) -> Path:
    """Strip the embedding lookup and the lm_head node from a Qwen3 ONNX.

    The first node consuming ``input_ids_name`` is removed and its output is
    replaced by a new graph input; the node producing ``logits_name`` is
    removed and its activation input becomes a new graph output. Initializers
    that fed only the removed nodes are dropped; a weight still consumed
    elsewhere (e.g. one shared by both ends of the model) is kept.

    Args:
        model_path: Path to the fused decoder ONNX (logits output, input_ids input).
        output_path: Destination for the transformer-only ONNX.
        input_ids_name: Name of the input_ids graph input to drop.
        logits_name: Name of the logits graph output to drop.
        inputs_embeds_name: Name of the new embeddings graph input. Consumers
            of the removed embedding output are rewired to this name.
        output_hidden_states_name: Name of the new hidden-state graph output.
            The tensor that fed the lm_head node is renamed to this name.

    Returns:
        The output path.

    Raises:
        RuntimeError: If the expected input/output, the embedding weight, or
            the producer of the hidden states cannot be located.
    """
    model_path = Path(model_path)
    output_path = Path(output_path)

    model = load_onnx(model_path, load_weights=True, validate=False)
    graph = model.graph
    # Snapshot taken before anything is removed, so a weight shared by the
    # embedding and the lm_head is still recognised as an initializer below.
    init_by_name = {init.name: init for init in graph.initializer}

    # -------------------- Embedding removal --------------------
    embed_idx = next(
        (i for i, n in enumerate(graph.node) if input_ids_name in n.input),
        None,
    )
    if embed_idx is None:
        msg = f"No node consumes graph input {input_ids_name!r}"
        raise RuntimeError(msg)

    embed_node = graph.node[embed_idx]
    embed_out_name = embed_node.output[0]

    embed_weight = next(
        (
            init_by_name[ipt]
            for ipt in embed_node.input
            if ipt in init_by_name and len(init_by_name[ipt].dims) == 2
        ),
        None,
    )
    if embed_weight is None:
        msg = f"Could not find 2-D embedding weight initializer on node {embed_node.name!r}"
        raise RuntimeError(msg)
    hidden_size = int(embed_weight.dims[1])
    # Gather's output has the element type of its data input; for any other
    # op keep the historical FLOAT declaration.
    embed_elem_type = (
        embed_weight.data_type if embed_node.op_type == "Gather" else TensorProto.FLOAT
    )

    ids_input = next((i for i in graph.input if i.name == input_ids_name), None)
    if ids_input is None:
        msg = f"Graph has no input named {input_ids_name!r}"
        raise RuntimeError(msg)
    ids_dims = ids_input.type.tensor_type.shape.dim
    if len(ids_dims) < 2:
        msg = f"Graph input {input_ids_name!r} must have rank >= 2, got {len(ids_dims)}"
        raise RuntimeError(msg)
    batch_dim = _dim(ids_dims[0])
    seq_dim = _dim(ids_dims[1])

    logger.info(
        "Removing embedding node %r (%s) — exposing %r as new input %r [%s, %s, %d]",
        embed_node.name,
        embed_node.op_type,
        embed_out_name,
        inputs_embeds_name,
        batch_dim,
        seq_dim,
        hidden_size,
    )

    new_embed_input = helper.make_tensor_value_info(
        inputs_embeds_name,
        embed_elem_type,
        [batch_dim, seq_dim, hidden_size],
    )

    del graph.node[embed_idx]
    graph.input.remove(ids_input)
    graph.input.append(new_embed_input)

    # Rewire any consumer of the removed embedding output to the new input.
    for n in graph.node:
        for i, name in enumerate(n.input):
            if name == embed_out_name:
                n.input[i] = inputs_embeds_name

    # -------------------- lm_head removal --------------------
    lmh_idx = next(
        (i for i, n in enumerate(graph.node) if logits_name in n.output),
        None,
    )
    if lmh_idx is None:
        msg = f"No node produces graph output {logits_name!r}"
        raise RuntimeError(msg)

    lmh_node = graph.node[lmh_idx]
    lmh_inputs = [name for name in lmh_node.input if name]
    lmh_weights = [name for name in lmh_inputs if name in init_by_name]
    activations = [name for name in lmh_inputs if name not in init_by_name]
    if not activations:
        msg = f"lm_head node {lmh_node.name!r} has no non-initializer input ({list(lmh_node.input)})"
        raise RuntimeError(msg)
    # The activation is the first non-initializer operand (input 0 of MatMul/Gemm).
    hidden_in = activations[0]

    logits_output = next((o for o in graph.output if o.name == logits_name), None)
    if logits_output is None:
        msg = f"Graph has no output named {logits_name!r}"
        raise RuntimeError(msg)

    logger.info(
        "Removing lm_head node %r (%s) — exposing %r as new output %r",
        lmh_node.name,
        lmh_node.op_type,
        hidden_in,
        output_hidden_states_name,
    )

    # Prefer the recorded type of the hidden-state tensor. When no value_info
    # exists, assume it matches the embeddings' element type (TODO confirm for
    # mixed-precision graphs).
    hidden_vi = next((vi for vi in graph.value_info if vi.name == hidden_in), None)
    hidden_elem_type = hidden_vi.type.tensor_type.elem_type if hidden_vi is not None else 0
    if not hidden_elem_type:
        hidden_elem_type = embed_elem_type

    new_hidden_output = helper.make_tensor_value_info(
        output_hidden_states_name,
        hidden_elem_type,
        [batch_dim, seq_dim, hidden_size],
    )

    del graph.node[lmh_idx]
    if not any(hidden_in in n.output for n in graph.node):
        msg = f"No node produces lm_head input {hidden_in!r}; cannot expose it as a graph output"
        raise RuntimeError(msg)

    graph.output.remove(logits_output)
    # Hidden states go first in the output list.
    graph.output.insert(0, new_hidden_output)

    # Rename ``hidden_in`` everywhere so its producer emits the new graph
    # output name and any other consumer keeps reading the same tensor.
    for n in graph.node:
        for i, name in enumerate(n.output):
            if name == hidden_in:
                n.output[i] = output_hidden_states_name
        for i, name in enumerate(n.input):
            if name == hidden_in:
                n.input[i] = output_hidden_states_name

    # Drop value_info entries that describe tensor names which no longer exist.
    stale_names = {embed_out_name, hidden_in}
    for vi in [v for v in graph.value_info if v.name in stale_names]:
        graph.value_info.remove(vi)

    # -------------------- Initializer cleanup --------------------
    # Done last, against the final node list, so a weight that is still
    # referenced is never removed.
    still_used = {name for n in graph.node for name in n.input}
    removed: set[str] = set()
    for weight_name in [embed_weight.name, *lmh_weights]:
        if weight_name in still_used or weight_name in removed:
            continue
        graph.initializer.remove(init_by_name[weight_name])
        removed.add(weight_name)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_onnx(model, output_path)
    logger.info("Wrote transformer-only ONNX → %s", output_path)
    return output_path
# Shape configuration shared by the build step and the optional quantize step.
MODEL_ID = "Qwen/Qwen3-0.6B"
MAX_CACHE_LEN = 256
PREFILL_SEQ_LEN = 64
GEN_SEQ_LEN = 1


def main() -> None:
    """Build the composite model, run one generation, optionally quantize.

    Everything lives behind the ``__main__`` guard so that importing or
    collecting this file (its name matches ``test_*.py``) does not trigger a
    model build.
    """
    model = WinMLCompositeModel.from_pretrained(
        MODEL_ID,
        task="text-generation",
        config=WinMLBuildConfig(quant=None),
        precision="fp16",
        device="npu",
        ep="qnn",
        force_rebuild=False,
        sub_model_kwargs={
            "decoder_prefill": {
                "shape_config": {"max_cache_len": MAX_CACHE_LEN, "seq_len": PREFILL_SEQ_LEN}
            },
            "decoder_gen": {
                "shape_config": {"max_cache_len": MAX_CACHE_LEN, "seq_len": GEN_SEQ_LEN}
            },
        },
    )

    # Verify ONNX I/O shapes
    for name, sub in model.sub_models.items():
        io = sub.io_config
        shapes = dict(zip(io["input_names"], io["input_shapes"]))
        print(f"\n=== {name} ===")
        for k, v in shapes.items():
            print(f"  {k}: {v}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    prompt = "8 * 7 = ?"
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(**model_inputs)

    # Keep only the newly generated tokens (drop the echoed prompt).
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
    content = tokenizer.decode(output_ids, skip_special_tokens=True)
    print("\nAnswer:", content)

    if os.environ.get("QUANTIZE") == "1":
        # Reuse the already-built decoder_prefill/decoder_gen ONNX files:
        # surgery (strip embed + lm_head) + transformer-only quantize.
        print("\n=== QUANTIZE=1 — running transformer-only quantization ===")
        from qwen3_quantize import quantize_built_model

        quantize_built_model(
            model,
            model_id=MODEL_ID,
            max_cache_len=MAX_CACHE_LEN,
            prefill_seq=PREFILL_SEQ_LEN,
        )


if __name__ == "__main__":
    main()