From 4418a5a6f71cbdbfe4da55d7d2c4e8fef9fe8ada Mon Sep 17 00:00:00 2001
From: Kevin Read
Date: Fri, 24 Apr 2026 21:40:37 +0200
Subject: [PATCH 1/2] =?UTF-8?q?docs:=20add=20Qwen3-Reranker=20=E2=86=92=20?=
 =?UTF-8?q?OpenVINO=20IR=20conversion=20recipe?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Covers the working Python-API path (optimum.intel.OVModelForCausalLM)
after hitting a silent-truncate bug in optimum-cli export openvino for
the Qwen3-Reranker-4B + int8 path. Documents prerequisites,
step-by-step conversion, verification, and how to wire the resulting
IR into openarc_config.json as a rerank model under the optimum
engine.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 docs/openvino_qwen3.md | 95 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 95 insertions(+)
 create mode 100644 docs/openvino_qwen3.md

diff --git a/docs/openvino_qwen3.md b/docs/openvino_qwen3.md
new file mode 100644
index 0000000..850c353
--- /dev/null
+++ b/docs/openvino_qwen3.md
@@ -0,0 +1,95 @@
+# Converting Qwen3-Reranker to OpenVINO IR for OpenArc
+
+Working recipe for converting a Qwen3-Reranker HuggingFace checkpoint to INT8 OpenVINO IR, validated on `Qwen/Qwen3-Reranker-4B` (2026-04-24).
+
+## Why not `optimum-cli`
+
+The obvious approach — `optimum-cli export openvino --weight-format int8` — **silently truncated** the output on every combination we tried (optimum-intel 1.27.0.dev + openvino 2026.1.0.dev, plus the notebook-pinned transformers 4.55.4 / torch 2.9.1 / openvino 2026.1.0 release). NNCF reported 100% weight compression, the process exited 0 with no traceback, but it wrote a **0-byte `openvino_model.xml`** and a ~13 MB stub `openvino_model.bin`. Bypassing the HF cache and using a local source copy did not help.
+
+The same stack via the `optimum.intel` Python API produced a correct 4.03 GB bin + 3.16 MB xml on the first attempt. Use the Python API.
+
+## Prerequisites
+
+OpenArc's `.venv` already has everything needed — `optimum[openvino]`, `openvino`, `nncf`, `transformers`, `torch`. No extra install.
+
+The source model must sit on a writable local path. If you rely on the HF Hub cache, make sure `~/.cache/huggingface/hub/models----/` is writable by the current user. A root-owned cache (e.g. populated by a container) produces `Permission denied` warnings during `.no_exist/` writes that are *technically* non-fatal for load, but correlate with other breakage — `chown -R $USER:$USER` the cache dirs if you see them.
+
+## Conversion
+
+```python
+# convert_qwen3_reranker.py
+import os
+from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
+from openvino_tokenizers import convert_tokenizer
+import openvino as ov
+from transformers import AutoTokenizer
+
+SRC = "Qwen/Qwen3-Reranker-4B"  # HF id or local path
+OUT = "/data/openvino-models/reranker/qwen3-4b-reranker"
+
+os.makedirs(OUT, exist_ok=True)
+
+# 1. Load + convert + INT8 weight-only compression
+qc = OVWeightQuantizationConfig(bits=8, sym=False)
+model = OVModelForCausalLM.from_pretrained(
+    SRC,
+    export=True,
+    use_cache=False,  # reranker does a single scoring forward pass; no KV cache
+    quantization_config=qc,
+    compile=False,
+)
+model.save_pretrained(OUT)  # writes openvino_model.xml/.bin + openvino_config.json
+
+# 2. Copy HF tokenizer artifacts (save_pretrained does NOT)
+tok = AutoTokenizer.from_pretrained(SRC, padding_side="left")
+tok.save_pretrained(OUT)
+
+# 3. Emit OpenVINO tokenizer + detokenizer (optional but standard)
+ov_tok, ov_detok = convert_tokenizer(tok, with_detokenizer=True)
+ov.save_model(ov_tok, os.path.join(OUT, "openvino_tokenizer.xml"))
+ov.save_model(ov_detok, os.path.join(OUT, "openvino_detokenizer.xml"))
+```
+
+Run with `PYTHONUNBUFFERED=1` so progress bars flush. Expected peak RSS is roughly `model_fp_size + NNCF_overhead`, so plan on ≥16 GB RAM + swap headroom for the 4B. The compression phase is CPU-bound and took ~20 s on this host; the forward-pass trace and save together add another minute or two.
+
+Quantization notes:
+- `OVWeightQuantizationConfig(bits=8, sym=False)` is weight-only INT8 asymmetric, per-channel — the same thing `--weight-format int8` does with NNCF internally. No calibration dataset needed.
+- Do **not** use `--quant-mode int8` / full activation quantization for a reranker without a domain-matched calibration set; accuracy drops quickly.
+- For 4-bit, switch to `bits=4` and consider `sym=True`, `group_size=128`, and optionally `dataset="wikitext2"` with `awq=True, scale_estimation=True` for better recovery (see the sketch after this list). INT8 is the right default for rerankers.
+
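+A minimal sketch of that 4-bit configuration, for orientation only. The keyword arguments are real `OVWeightQuantizationConfig` parameters, but the values are the generic starting points named above, not settings validated on this model:
+
+```python
+# Sketch: 4-bit weight-only compression. Drop-in replacement for `qc` above;
+# the rest of the conversion script is unchanged.
+qc4 = OVWeightQuantizationConfig(
+    bits=4,
+    sym=True,               # symmetric quantization, as suggested above
+    group_size=128,         # per-group scales limit 4-bit accuracy loss
+    dataset="wikitext2",    # generic calibration text for AWQ / scale search
+    awq=True,               # activation-aware weight rescaling
+    scale_estimation=True,  # refine per-group scales against the dataset
+)
+```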
+
+## Verification
+
+Quick smoke test — load the exported model and run one forward pass:
+
+```python
+import torch
+from optimum.intel import OVModelForCausalLM
+from transformers import AutoTokenizer
+
+OUT = "/data/openvino-models/reranker/qwen3-4b-reranker"  # the export dir from above
+
+m = OVModelForCausalLM.from_pretrained(OUT, device="CPU", use_cache=False, export=False)
+tok = AutoTokenizer.from_pretrained(OUT, padding_side="left")
+inp = tok("Query: Paris\nDocument: Paris is the capital of France.\nRelevant:", return_tensors="pt")
+with torch.no_grad():
+    out = m(**inp)
+print(out.logits.shape)  # expect (1, seq_len, 151669)
+```
+
+A vocab dimension of 151669 and a non-empty bin/xml on disk are the two signals that the export is real — not a stub.
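+
+If you want an actual relevance score out of that forward pass, the published Qwen3-Reranker usage scores a (query, document) pair as P("yes") vs P("no") at the final position. A minimal sketch on top of the smoke test above, assuming that convention (the authoritative implementation is whatever `optimum_rr.py` does):
+
+```python
+yes_id = tok.convert_tokens_to_ids("yes")
+no_id = tok.convert_tokens_to_ids("no")
+pair = out.logits[0, -1, [no_id, yes_id]]      # logits at the last position
+score = torch.softmax(pair, dim=-1)[1].item()  # P("yes") as relevance score
+print(f"relevance ~ {score:.4f}")
+```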
+
+## Wiring into OpenArc
+
+Add the model to `openarc_config.json` alongside any existing reranker entry:
+
+```json
+"qwen3-4b-reranker": {
+    "model_name": "qwen3-4b-reranker",
+    "model_path": "/data/openvino-models/reranker/qwen3-4b-reranker",
+    "model_type": "rerank",
+    "engine": "optimum",
+    "device": "GPU",
+    "runtime_config": {},
+    "vlm_type": null
+}
+```
+
+`engine: optimum` wires it into `src/engine/optimum/optimum_rr.py`, which uses `AutoTokenizer` + an `OVModelForCausalLM` forward pass per (query, document) pair — so the HF tokenizer files in the output dir are what gets loaded at runtime. The `openvino_tokenizer.xml` produced above is unused by this engine today but is standard for OV IR packages.

From 408146781c56e8635cfb8dc2c94ef51e1722217b Mon Sep 17 00:00:00 2001
From: Kevin Read
Date: Fri, 24 Apr 2026 21:41:17 +0200
Subject: [PATCH 2/2] feat(embed): dispatch pooling mode from
 sentence-transformers config

Optimum_EMB previously always applied last_token_pool to the encoder
output, which is correct for Qwen3-Embedding but wrong for
encoder-style sentence-transformers models (CLS pooling) or
mean-pooled ones. Load now inspects 1_Pooling/config.json and picks
cls / mean / last accordingly, defaulting to last when the file is
absent so existing Qwen3-Embedding deployments keep their behavior.
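
For reference, the sentence-transformers file being inspected looks
like this for a CLS-pooled model such as bge-m3 (illustrative values;
only the two pooling_mode_* keys named above are read, anything else
is ignored):

    {
      "word_embedding_dimension": 1024,
      "pooling_mode_cls_token": true,
      "pooling_mode_mean_tokens": false
    }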

runtime_config may set "pool_mode" to pin the choice explicitly and
protect against upgrade regressions on models whose shipped ST config
would otherwise change pooling:

    "runtime_config": {"pool_mode": "last"}

Unknown values raise ValueError on load rather than silently falling
through to last-token.

Tests: 14 unit tests cover each pool fn, the auto-detect ladder, the
runtime-config override, and its validation. Two integration tests
(following the existing local-model-path skip pattern) load a real
bge-m3 IR and verify the CLS auto-detect + override behavior
end-to-end.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 src/engine/optimum/optimum_emb.py         |  66 +++++++-
 src/tests/test_optimum_emb_integration.py |  74 +++++++++
 src/tests/test_optimum_emb_unit.py        | 185 ++++++++++++++++++++--
 3 files changed, 306 insertions(+), 19 deletions(-)

diff --git a/src/engine/optimum/optimum_emb.py b/src/engine/optimum/optimum_emb.py
index 2d0cbd3..c3324c8 100755
--- a/src/engine/optimum/optimum_emb.py
+++ b/src/engine/optimum/optimum_emb.py
@@ -2,7 +2,9 @@
 
 import asyncio
 import gc
+import json
 import logging
+from pathlib import Path
 from typing import Any, AsyncIterator, Dict, List, Union
 
 import torch
@@ -22,13 +24,44 @@
 
 
+_VALID_POOL_MODES = ("cls", "mean", "last")
+
+
 class Optimum_EMB:
-    
     def __init__(self, load_config: ModelLoadConfig):
         self.model_path = None
         self.encoder_tokenizer = None
         self.load_config = load_config
-    
+        self.pool_mode = "last"
+
+    @staticmethod
+    def _detect_pool_mode(model_path: str) -> str:
+        pool_cfg = Path(model_path) / "1_Pooling" / "config.json"
+        if not pool_cfg.is_file():
+            return "last"
+        try:
+            cfg = json.loads(pool_cfg.read_text())
+        except (OSError, json.JSONDecodeError):
+            return "last"
+        if cfg.get("pooling_mode_cls_token"):
+            return "cls"
+        if cfg.get("pooling_mode_mean_tokens"):
+            return "mean"
+        return "last"
+
+    @staticmethod
+    def _resolve_pool_mode(loader: ModelLoadConfig) -> str:
+        override = (loader.runtime_config or {}).get("pool_mode")
+        if override is not None:
+            if override not in _VALID_POOL_MODES:
+                raise ValueError(
+                    f"Unknown pool_mode {override!r} for model {loader.model_name!r}; "
+                    f"expected one of {_VALID_POOL_MODES}"
+                )
+            return override
+        return Optimum_EMB._detect_pool_mode(loader.model_path)
+
     def last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
         left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
         if left_padding:
@@ -38,8 +71,26 @@ def last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) ->
         batch_size = last_hidden_states.shape[0]
         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
+    @staticmethod
+    def cls_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+        return last_hidden_states[:, 0]
+
+    @staticmethod
+    def mean_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+        mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
+        summed = (last_hidden_states * mask).sum(dim=1)
+        counts = mask.sum(dim=1).clamp(min=1e-9)
+        return summed / counts
+
+    def pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+        if self.pool_mode == "cls":
+            return self.cls_pool(last_hidden_states, attention_mask)
+        if self.pool_mode == "mean":
+            return self.mean_pool(last_hidden_states, attention_mask)
+        return self.last_token_pool(last_hidden_states, attention_mask)
+
     async def generate_embeddings(self, tok_config: PreTrainedTokenizerConfig) -> AsyncIterator[Union[Dict[str, Any], str]]:
-    
+
         # Tokenize the input texts
         batch_dict = self.tokenizer(
             text=tok_config.text,
@@ -65,7 +116,7 @@ async def generate_embeddings(self, tok_config: PreTrainedTokenizerConfig) -> As
             )
             batch_dict.to(self.model.device)
             outputs = self.model(**batch_dict)
-            embeddings = self.last_token_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
+            embeddings = self.pool(outputs.last_hidden_state, batch_dict["attention_mask"])
             # normalize embeddings
             if tok_config.return_tensors=="pt":
                 embeddings = F.normalize(embeddings, p=2, dim=1)
@@ -81,12 +132,13 @@ def load_model(self, loader: ModelLoadConfig):
             loader: ModelLoadConfig containing model_path, device, engine, and runtime_config.
         """
-        self.model = OVModelForFeatureExtraction.from_pretrained(loader.model_path, 
-                                                                 device=loader.device, 
+        self.model = OVModelForFeatureExtraction.from_pretrained(loader.model_path,
+                                                                 device=loader.device,
                                                                  export=False)
         self.tokenizer = AutoTokenizer.from_pretrained(loader.model_path)
 
-        logging.info(f"Model loaded successfully: {loader.model_name}")
+        self.pool_mode = self._resolve_pool_mode(loader)
+        logging.info(f"Model loaded successfully: {loader.model_name} (pool_mode={self.pool_mode})")
 
     async def unload_model(self, registry: ModelRegistry, model_name: str) -> bool:
         """Unregister model from registry and free memory resources.
diff --git a/src/tests/test_optimum_emb_integration.py b/src/tests/test_optimum_emb_integration.py
index 6602d3d..aa16a82 100644
--- a/src/tests/test_optimum_emb_integration.py
+++ b/src/tests/test_optimum_emb_integration.py
@@ -11,6 +11,7 @@
 
 MODEL_PATH = Path("/mnt/Ironwolf-4TB/Models/Pytorch/Qwen/Qwen3-Embed-0.6B-INT8-ASYM-ov")
+BGE_M3_PATH = Path("/data/openvino-models/embeddings/bge-m3")
 UNIT_TEST_PATH = Path(__file__).with_name("test_optimum_emb_unit.py")
 
 _UNIT_TESTS_PASSED: bool | None = None
 
@@ -82,3 +83,76 @@ async def _run():
 
     finally:
         asyncio.run(emb.unload_model(_DummyRegistry(), load_config.model_name))
+
+
+def test_bge_m3_cls_pool_cpu_integration() -> None:
+    """Load the converted bge-m3 OV IR and verify CLS pooling is auto-selected.
+
+    The sentence-transformers metadata shipped with bge-m3 declares CLS pooling,
+    so loading without a runtime_config override should pick it up automatically
+    and emit a unit-normed 1024-dim vector.
+    """
+    _ensure_unit_tests_pass()
+    if not BGE_M3_PATH.exists():
+        pytest.skip(f"bge-m3 OV IR not found at {BGE_M3_PATH}")
+
+    load_config = ModelLoadConfig(
+        model_path=str(BGE_M3_PATH),
+        model_name="integration-bge-m3",
+        model_type=ModelType.EMB,
+        engine=EngineType.OV_OPTIMUM,
+        device="CPU",
+        runtime_config={},
+    )
+
+    emb = Optimum_EMB(load_config)
+    emb.load_model(load_config)
+    try:
+        assert emb.pool_mode == "cls", f"expected cls auto-detect, got {emb.pool_mode!r}"
+
+        tok_config = PreTrainedTokenizerConfig(
+            text=["What is the capital of France?"],
+            padding="longest",
+            truncation=True,
+            max_length=64,
+            return_tensors="pt",
+        )
+
+        async def _run():
+            vectors = []
+            async for item in emb.generate_embeddings(tok_config):
+                vectors.append(item)
+            return vectors
+
+        outputs = asyncio.run(_run())
+        assert len(outputs) == 1
+        vec = outputs[0]
+        assert isinstance(vec, list) and len(vec) == 1
+        assert len(vec[0]) == 1024
+        import math
+        norm = math.sqrt(sum(x * x for x in vec[0]))
+        assert abs(norm - 1.0) < 1e-3, f"expected unit-normed, got ||v||={norm}"
+    finally:
+        asyncio.run(emb.unload_model(_DummyRegistry(), load_config.model_name))
+
+
+def test_bge_m3_runtime_config_override_forces_last_pool() -> None:
+    """runtime_config `pool_mode` override beats the shipped sentence-transformers metadata."""
+    _ensure_unit_tests_pass()
+    if not BGE_M3_PATH.exists():
+        pytest.skip(f"bge-m3 OV IR not found at {BGE_M3_PATH}")
+
+    load_config = ModelLoadConfig(
+        model_path=str(BGE_M3_PATH),
+        model_name="integration-bge-m3-override",
+        model_type=ModelType.EMB,
+        engine=EngineType.OV_OPTIMUM,
+        device="CPU",
+        runtime_config={"pool_mode": "last"},
+    )
+
+    emb = Optimum_EMB(load_config)
+    emb.load_model(load_config)
+    try:
+        assert emb.pool_mode == "last"
+    finally:
+        asyncio.run(emb.unload_model(_DummyRegistry(), load_config.model_name))
diff --git a/src/tests/test_optimum_emb_unit.py b/src/tests/test_optimum_emb_unit.py
index 2823002..d7cd52d 100644
--- a/src/tests/test_optimum_emb_unit.py
+++ b/src/tests/test_optimum_emb_unit.py
@@ -1,4 +1,6 @@
 import asyncio
+import json
+from pathlib import Path
 from types import SimpleNamespace
 from unittest.mock import AsyncMock, MagicMock
 
@@ -40,31 +42,87 @@ def test_last_token_pool_uses_sequence_length(load_config: ModelLoadConfig) -> N
     assert torch.equal(pooled, expected)
 
 
-def test_generate_embeddings_returns_normalized_vectors(load_config: ModelLoadConfig) -> None:
+def test_cls_pool_returns_first_token() -> None:
+    states = torch.tensor([
+        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
+    ])
+    attention = torch.tensor([[1, 1, 0], [1, 1, 1]])
+
+    pooled = Optimum_EMB.cls_pool(states, attention)
+
+    expected = torch.tensor([[1.0, 2.0], [7.0, 8.0]])
+    assert torch.equal(pooled, expected)
+
+
+def test_mean_pool_ignores_padding() -> None:
+    states = torch.tensor([
+        [[2.0, 4.0], [4.0, 8.0], [100.0, 200.0]],  # last token is padding
+    ])
+    attention = torch.tensor([[1, 1, 0]])
+
+    pooled = Optimum_EMB.mean_pool(states, attention)
+
+    expected = torch.tensor([[3.0, 6.0]])
+    assert torch.allclose(pooled, expected)
+
+
+def test_detect_pool_mode_defaults_to_last(tmp_path: Path) -> None:
+    # No 1_Pooling/ dir => last-token (Qwen3-Embedding behavior preserved)
+    assert Optimum_EMB._detect_pool_mode(str(tmp_path)) == "last"
+
+
+@pytest.mark.parametrize(
+    "cfg, expected",
+    [
+        ({"pooling_mode_cls_token": True}, "cls"),
+        ({"pooling_mode_mean_tokens": True}, "mean"),
+        ({"pooling_mode_cls_token": False, "pooling_mode_mean_tokens": False}, "last"),
+    ],
+)
+def test_detect_pool_mode_reads_sentence_transformers_config(
+    tmp_path: Path, cfg: dict, expected: str
+) -> None:
+    pool_dir = tmp_path / "1_Pooling"
+    pool_dir.mkdir()
+    (pool_dir / "config.json").write_text(json.dumps(cfg))
+
+    assert Optimum_EMB._detect_pool_mode(str(tmp_path)) == expected
+
+
+def _make_emb_with_dummy_pipeline(
+    load_config: ModelLoadConfig,
+    hidden_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+    pool_mode: str = "last",
+):
     emb = Optimum_EMB(load_config)
+    emb.pool_mode = pool_mode
 
     class DummyBatch(dict):
         def to(self, device):  # noqa: D401 - mimic HuggingFace BatchEncoding
             self["moved_to"] = device
             return self
 
-    attention_mask = torch.tensor([[1, 1, 1]])
-    batch = DummyBatch({"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": attention_mask})
-
-    tokenizer_mock = MagicMock(return_value=batch)
-    emb.tokenizer = tokenizer_mock
-
-    outputs = torch.tensor([
-        [[0.5, 0.5], [0.0, 1.0], [1.0, 0.0]],
-    ])
+    batch = DummyBatch({"input_ids": torch.zeros_like(attention_mask), "attention_mask": attention_mask})
+    emb.tokenizer = MagicMock(return_value=batch)
 
     class DummyModel:
         device = "cpu"
 
         def __call__(self, **kwargs):
-            return SimpleNamespace(last_hidden_state=outputs)
+            return SimpleNamespace(last_hidden_state=hidden_states)
 
     emb.model = DummyModel()
+    return emb
+
+
+def test_generate_embeddings_returns_normalized_vectors(load_config: ModelLoadConfig) -> None:
+    outputs = torch.tensor([
+        [[0.5, 0.5], [0.0, 1.0], [1.0, 0.0]],
+    ])
+    attention_mask = torch.tensor([[1, 1, 1]])
+    emb = _make_emb_with_dummy_pipeline(load_config, outputs, attention_mask, pool_mode="last")
 
     tok_config = PreTrainedTokenizerConfig(
         text=["hello world"],
@@ -85,7 +143,33 @@ async def _run():
     assert len(vectors) == 1
     vec = torch.tensor(vectors[0])
     assert torch.allclose(torch.norm(vec, dim=1), torch.ones(1), atol=1e-6)
-    tokenizer_mock.assert_called_once()
+
+
+def test_generate_embeddings_uses_cls_pool_when_configured(load_config: ModelLoadConfig) -> None:
+    # First token is the CLS; make it obviously distinct from the others.
+    outputs = torch.tensor([
+        [[3.0, 4.0], [100.0, 100.0], [-100.0, -100.0]],
+    ])
+    attention_mask = torch.tensor([[1, 1, 1]])
+    emb = _make_emb_with_dummy_pipeline(load_config, outputs, attention_mask, pool_mode="cls")
+
+    tok_config = PreTrainedTokenizerConfig(
+        text=["hello world"],
+        padding="longest",
+        truncation=True,
+        max_length=16,
+        return_tensors="pt",
+    )
+
+    async def _run():
+        collected = []
+        async for item in emb.generate_embeddings(tok_config):
+            collected.append(item)
+        return collected
+
+    vectors = asyncio.run(_run())
+    # [3, 4] normalized = [0.6, 0.8]
+    assert torch.allclose(torch.tensor(vectors[0]), torch.tensor([[0.6, 0.8]]), atol=1e-6)
 
 
 def test_load_model_initializes_pipeline(monkeypatch: pytest.MonkeyPatch, load_config: ModelLoadConfig) -> None:
@@ -110,6 +194,83 @@ def test_load_model_initializes_pipeline(monkeypatch: pytest.MonkeyPatch, load_c
     emb_module.AutoTokenizer.from_pretrained.assert_called_once_with(load_config.model_path)
     assert emb.model is model_instance
     assert emb.tokenizer is tokenizer_instance
+    # Nonexistent path => default pooling preserved.
+    assert emb.pool_mode == "last"
+
+
+def test_load_model_picks_up_cls_pool_from_sentence_transformers_config(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    pool_dir = tmp_path / "1_Pooling"
+    pool_dir.mkdir()
+    (pool_dir / "config.json").write_text(json.dumps({"pooling_mode_cls_token": True}))
+
+    cfg = ModelLoadConfig(
+        model_path=str(tmp_path),
+        model_name="bge-m3-like",
+        model_type=ModelType.EMB,
+        engine=EngineType.OV_OPTIMUM,
+        device="CPU",
+        runtime_config={},
+    )
+    emb = Optimum_EMB(cfg)
+    monkeypatch.setattr(
+        emb_module.OVModelForFeatureExtraction, "from_pretrained", MagicMock(return_value=MagicMock())
+    )
+    monkeypatch.setattr(emb_module.AutoTokenizer, "from_pretrained", MagicMock(return_value=MagicMock()))
+
+    emb.load_model(cfg)
+
+    assert emb.pool_mode == "cls"
+
+
+def test_runtime_config_pool_mode_override_beats_autodetect(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    # Model ships sentence-transformers metadata saying CLS...
+    pool_dir = tmp_path / "1_Pooling"
+    pool_dir.mkdir()
+    (pool_dir / "config.json").write_text(json.dumps({"pooling_mode_cls_token": True}))
+
+    # ...but the operator pins last-token pooling via runtime_config.
+    cfg = ModelLoadConfig(
+        model_path=str(tmp_path),
+        model_name="override-wins",
+        model_type=ModelType.EMB,
+        engine=EngineType.OV_OPTIMUM,
+        device="CPU",
+        runtime_config={"pool_mode": "last"},
+    )
+    emb = Optimum_EMB(cfg)
+    monkeypatch.setattr(
+        emb_module.OVModelForFeatureExtraction, "from_pretrained", MagicMock(return_value=MagicMock())
+    )
+    monkeypatch.setattr(emb_module.AutoTokenizer, "from_pretrained", MagicMock(return_value=MagicMock()))
+
+    emb.load_model(cfg)
+
+    assert emb.pool_mode == "last"
+
+
+def test_load_model_rejects_unknown_pool_mode_override(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    cfg = ModelLoadConfig(
+        model_path=str(tmp_path),
+        model_name="typo",
+        model_type=ModelType.EMB,
+        engine=EngineType.OV_OPTIMUM,
+        device="CPU",
+        runtime_config={"pool_mode": "clas"},  # typo of "cls"
+    )
+    emb = Optimum_EMB(cfg)
+    monkeypatch.setattr(
+        emb_module.OVModelForFeatureExtraction, "from_pretrained", MagicMock(return_value=MagicMock())
+    )
+    monkeypatch.setattr(emb_module.AutoTokenizer, "from_pretrained", MagicMock(return_value=MagicMock()))
+
+    with pytest.raises(ValueError, match="pool_mode"):
+        emb.load_model(cfg)
 
 
 def test_unload_model_resets_state(monkeypatch: pytest.MonkeyPatch, load_config: ModelLoadConfig) -> None: