95 changes: 95 additions & 0 deletions docs/openvino_qwen3.md
@@ -0,0 +1,95 @@
# Converting Qwen3-Reranker to OpenVINO IR for OpenArc

Working recipe for converting a Qwen3-Reranker HuggingFace checkpoint to INT8 OpenVINO IR, validated on `Qwen/Qwen3-Reranker-4B` (2026-04-24).

## Why not `optimum-cli`

The obvious approach — `optimum-cli export openvino --weight-format int8` — **silently truncated** the output on every combination we tried (optimum-intel 1.27.0.dev + openvino 2026.1.0.dev, plus the notebook-pinned transformers 4.55.4 / torch 2.9.1 / openvino 2026.1.0 release). NNCF reported 100% weight compression, the process exited 0 with no traceback, but it wrote a **0-byte `openvino_model.xml`** and a ~13 MB stub `openvino_model.bin`. Bypassing the HF cache and using a local source copy did not help.

The same stack via the `optimum.intel` Python API produced a correct 4.03 GB bin + 3.16 MB xml on the first attempt. Use the Python API.

## Prerequisites

OpenArc's `.venv` already has everything needed — `optimum[openvino]`, `openvino`, `nncf`, `transformers`, `torch`. No extra install.

The source model must be on a writable local path. If you rely on the HF Hub cache, make sure `~/.cache/huggingface/hub/models--<org>--<name>/` is writable by the current user. A root-owned cache (e.g. one populated by a container) produces `Permission denied` warnings during `.no_exist/` writes; these are technically non-fatal for loading, but they correlate with other breakage, so `chown -R $USER:$USER` the cache dirs if you see them. A quick check is sketched below.
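
A minimal sketch of that check, assuming the default `huggingface_hub` cache layout (the exact snapshot directory name is an assumption, not something this recipe depends on):

```python
# Sketch only: confirm the HF cache snapshot is writable before starting the export.
import os
from pathlib import Path

cache_dir = Path.home() / ".cache/huggingface/hub/models--Qwen--Qwen3-Reranker-4B"
if cache_dir.exists():
    print(cache_dir, "writable:", os.access(cache_dir, os.W_OK))
```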

## Conversion

```python
# convert_qwen3_reranker.py
import os
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
from openvino_tokenizers import convert_tokenizer
import openvino as ov
from transformers import AutoTokenizer

SRC = "Qwen/Qwen3-Reranker-4B" # HF id or local path
OUT = "/data/openvino-models/reranker/qwen3-4b-reranker"

os.makedirs(OUT, exist_ok=True)

# 1. Load + convert + INT8 weight-only compression
qc = OVWeightQuantizationConfig(bits=8, sym=False)
model = OVModelForCausalLM.from_pretrained(
    SRC,
    export=True,
    use_cache=False,   # reranker does a single scoring forward pass; no KV cache
    quantization_config=qc,
    compile=False,
)
model.save_pretrained(OUT) # writes openvino_model.xml/.bin + openvino_config.json

# 2. Copy HF tokenizer artifacts (save_pretrained does NOT)
tok = AutoTokenizer.from_pretrained(SRC, padding_side="left")
tok.save_pretrained(OUT)

# 3. Emit OpenVINO tokenizer + detokenizer (optional but standard)
ov_tok, ov_detok = convert_tokenizer(tok, with_detokenizer=True)
ov.save_model(ov_tok, os.path.join(OUT, "openvino_tokenizer.xml"))
ov.save_model(ov_detok, os.path.join(OUT, "openvino_detokenizer.xml"))
```

Run with `PYTHONUNBUFFERED=1` so progress bars flush. Expected peak RSS is roughly `model_fp_size + NNCF_overhead`, so plan on ≥16 GB RAM + swap headroom for the 4B. The compression phase is CPU-bound and took ~20 s on this host; the forward-pass trace and save together add another minute or two.

Quantization notes:
- `OVWeightQuantizationConfig(bits=8, sym=False)` is weight-only INT8 asymmetric, per-channel — the same thing `--weight-format int8` does with NNCF internally. No calibration dataset needed.
- Do **not** use `--quant-mode int8` / full activation quantization for a reranker without a domain-matched calibration set; accuracy drops quickly.
- For 4-bit, switch to `bits=4` and consider `sym=True`, `group_size=128`, and optionally `dataset="wikitext2"` with `awq=True, scale_estimation=True` for better recovery (see the sketch below). INT8 is the right default for rerankers.
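
If you do go the 4-bit route, a minimal sketch of that config, using the parameter names listed in the note above; the values are starting points, not settings validated for this model:

```python
# Hedged 4-bit variant of the INT8 config above; not validated on Qwen3-Reranker-4B.
# Parameter names follow the note above; check your optimum-intel version's OVWeightQuantizationConfig.
from optimum.intel import OVWeightQuantizationConfig

qc_int4 = OVWeightQuantizationConfig(
    bits=4,
    sym=True,
    group_size=128,
    dataset="wikitext2",      # calibration text used by AWQ / scale estimation
    awq=True,
    scale_estimation=True,
)
# Pass qc_int4 as quantization_config= in the same from_pretrained(..., export=True) call.
```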

## Verification

Quick smoke test — load the exported model and run one forward pass:

```python
import torch
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

OUT = "/data/openvino-models/reranker/qwen3-4b-reranker"  # same path as the conversion script

m = OVModelForCausalLM.from_pretrained(OUT, device="CPU", use_cache=False, export=False)
tok = AutoTokenizer.from_pretrained(OUT, padding_side="left")
inp = tok("Query: Paris\nDocument: Paris is the capital of France.\nRelevant:", return_tensors="pt")
with torch.no_grad():
    out = m(**inp)
print(out.logits.shape)  # expect (1, seq_len, 151669)
```

A vocab dimension of 151669 and a non-empty bin/xml on disk are the two signals that the export is real — not a stub.
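
To catch the stub-export failure mode described earlier without loading anything, a small on-disk check (the bin-size threshold is a rough assumption for a 4B INT8 export, not an exact figure):

```python
# Sanity-check the exported IR on disk: a stub export shows a 0-byte XML and a ~13 MB bin.
from pathlib import Path

out = Path("/data/openvino-models/reranker/qwen3-4b-reranker")
xml_size = (out / "openvino_model.xml").stat().st_size
bin_size = (out / "openvino_model.bin").stat().st_size
print(f"xml: {xml_size} bytes, bin: {bin_size / 1e9:.2f} GB")
assert xml_size > 0, "stub export: empty openvino_model.xml"
assert bin_size > 3e9, "bin far smaller than expected (~4 GB) for the INT8 4B model"
```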

## Wiring into OpenArc

Add the model to `openarc_config.json` alongside any existing reranker entry:

```json
"qwen3-4b-reranker": {
"model_name": "qwen3-4b-reranker",
"model_path": "/data/openvino-models/reranker/qwen3-4b-reranker",
"model_type": "rerank",
"engine": "optimum",
"device": "GPU",
"runtime_config": {},
"vlm_type": null
}
```

`engine: optimum` wires it into `src/engine/optimum/optimum_rr.py`, which uses `AutoTokenizer` + an `OVModelForCausalLM` forward pass per (query, document) pair — so the HF tokenizer files in the output dir are what gets loaded at runtime. The `openvino_tokenizer.xml` produced above is unused by this engine today but is standard for OV IR packages.
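
For orientation, here is a minimal sketch of what "one forward pass per (query, document) pair" can look like for a Qwen3 reranker. This is not the code in `src/engine/optimum/optimum_rr.py`; the prompt template and the choice of "yes"/"no" answer tokens are illustrative assumptions borrowed from the smoke test above, not the model's documented scoring format:

```python
# Illustrative per-pair scoring, NOT OpenArc's implementation.
import torch
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

OUT = "/data/openvino-models/reranker/qwen3-4b-reranker"
model = OVModelForCausalLM.from_pretrained(OUT, device="CPU", use_cache=False, export=False)
tok = AutoTokenizer.from_pretrained(OUT, padding_side="left")

def score(query: str, document: str) -> float:
    prompt = f"Query: {query}\nDocument: {document}\nRelevant:"   # toy prompt, mirrors the smoke test
    inp = tok(prompt, return_tensors="pt")
    with torch.no_grad():
        last_logits = model(**inp).logits[0, -1]                  # vocab logits at the final position
    yes_id = tok.convert_tokens_to_ids("yes")
    no_id = tok.convert_tokens_to_ids("no")
    probs = torch.softmax(last_logits[[yes_id, no_id]], dim=-1)
    return probs[0].item()                                        # P("yes") as the relevance score
```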
66 changes: 59 additions & 7 deletions src/engine/optimum/optimum_emb.py
@@ -2,7 +2,9 @@

import asyncio
import gc
import json
import logging
from pathlib import Path
from typing import Any, AsyncIterator, Dict, List, Union

import torch
@@ -22,13 +24,44 @@



_VALID_POOL_MODES = ("cls", "mean", "last")


class Optimum_EMB:

    def __init__(self, load_config: ModelLoadConfig):
        self.model_path = None
        self.encoder_tokenizer = None
        self.load_config = load_config

        self.pool_mode = "last"

    @staticmethod
    def _detect_pool_mode(model_path: str) -> str:
        # sentence-transformers exports ship their pooling choice in 1_Pooling/config.json
        pool_cfg = Path(model_path) / "1_Pooling" / "config.json"
        if not pool_cfg.is_file():
            return "last"
        try:
            cfg = json.loads(pool_cfg.read_text())
        except (OSError, json.JSONDecodeError):
            return "last"
        if cfg.get("pooling_mode_cls_token"):
            return "cls"
        if cfg.get("pooling_mode_mean_tokens"):
            return "mean"
        return "last"

    @staticmethod
    def _resolve_pool_mode(loader: ModelLoadConfig) -> str:
        override = (loader.runtime_config or {}).get("pool_mode")
        if override is not None:
            if override not in _VALID_POOL_MODES:
                raise ValueError(
                    f"Unknown pool_mode {override!r} for model {loader.model_name!r}; "
                    f"expected one of {_VALID_POOL_MODES}"
                )
            return override
        return Optimum_EMB._detect_pool_mode(loader.model_path)

    def last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
@@ -38,8 +71,26 @@ def last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) ->
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    @staticmethod
    def cls_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        return last_hidden_states[:, 0]

    @staticmethod
    def mean_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
        summed = (last_hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        if self.pool_mode == "cls":
            return self.cls_pool(last_hidden_states, attention_mask)
        if self.pool_mode == "mean":
            return self.mean_pool(last_hidden_states, attention_mask)
        return self.last_token_pool(last_hidden_states, attention_mask)

    async def generate_embeddings(self, tok_config: PreTrainedTokenizerConfig) -> AsyncIterator[Union[Dict[str, Any], str]]:

        # Tokenize the input texts
        batch_dict = self.tokenizer(
            text=tok_config.text,
@@ -65,7 +116,7 @@ async def generate_embeddings(self, tok_config: PreTrainedTokenizerConfig) -> AsyncIterator[Union[Dict[str, Any], str]]:
        )
        batch_dict.to(self.model.device)
        outputs = self.model(**batch_dict)
        embeddings = self.last_token_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
        embeddings = self.pool(outputs.last_hidden_state, batch_dict["attention_mask"])
        # normalize embeddings
        if tok_config.return_tensors=="pt":
            embeddings = F.normalize(embeddings, p=2, dim=1)
@@ -81,12 +132,13 @@ def load_model(self, loader: ModelLoadConfig):
        loader: ModelLoadConfig containing model_path, device, engine, and runtime_config.
        """

        self.model = OVModelForFeatureExtraction.from_pretrained(loader.model_path,
                                                                  device=loader.device,
                                                                  export=False)

        self.tokenizer = AutoTokenizer.from_pretrained(loader.model_path)
        logging.info(f"Model loaded successfully: {loader.model_name}")
        self.pool_mode = self._resolve_pool_mode(loader)
        logging.info(f"Model loaded successfully: {loader.model_name} (pool_mode={self.pool_mode})")

    async def unload_model(self, registry: ModelRegistry, model_name: str) -> bool:
        """Unregister model from registry and free memory resources.
74 changes: 74 additions & 0 deletions src/tests/test_optimum_emb_integration.py
@@ -11,6 +11,7 @@


MODEL_PATH = Path("/mnt/Ironwolf-4TB/Models/Pytorch/Qwen/Qwen3-Embed-0.6B-INT8-ASYM-ov")
BGE_M3_PATH = Path("/data/openvino-models/embeddings/bge-m3")
UNIT_TEST_PATH = Path(__file__).with_name("test_optimum_emb_unit.py")

_UNIT_TESTS_PASSED: bool | None = None
@@ -82,3 +83,76 @@ async def _run():
    finally:
        asyncio.run(emb.unload_model(_DummyRegistry(), load_config.model_name))


def test_bge_m3_cls_pool_cpu_integration() -> None:
"""Load the converted bge-m3 OV IR and verify CLS pooling is auto-selected.

The sentence-transformers metadata shipped with bge-m3 declares CLS pooling,
so loading without a runtime_config override should pick it up automatically
and emit a unit-normed 1024-dim vector.
"""
_ensure_unit_tests_pass()
if not BGE_M3_PATH.exists():
pytest.skip(f"bge-m3 OV IR not found at {BGE_M3_PATH}")

load_config = ModelLoadConfig(
model_path=str(BGE_M3_PATH),
model_name="integration-bge-m3",
model_type=ModelType.EMB,
engine=EngineType.OV_OPTIMUM,
device="CPU",
runtime_config={},
)

emb = Optimum_EMB(load_config)
emb.load_model(load_config)
try:
assert emb.pool_mode == "cls", f"expected cls auto-detect, got {emb.pool_mode!r}"

tok_config = PreTrainedTokenizerConfig(
text=["What is the capital of France?"],
padding="longest",
truncation=True,
max_length=64,
return_tensors="pt",
)

async def _run():
vectors = []
async for item in emb.generate_embeddings(tok_config):
vectors.append(item)
return vectors

outputs = asyncio.run(_run())
assert len(outputs) == 1
vec = outputs[0]
assert isinstance(vec, list) and len(vec) == 1
assert len(vec[0]) == 1024
import math
norm = math.sqrt(sum(x * x for x in vec[0]))
assert abs(norm - 1.0) < 1e-3, f"expected unit-normed, got ||v||={norm}"
finally:
asyncio.run(emb.unload_model(_DummyRegistry(), load_config.model_name))


def test_bge_m3_runtime_config_override_forces_last_pool() -> None:
"""runtime_config `pool_mode` override beats the shipped sentence-transformers metadata."""
_ensure_unit_tests_pass()
if not BGE_M3_PATH.exists():
pytest.skip(f"bge-m3 OV IR not found at {BGE_M3_PATH}")

load_config = ModelLoadConfig(
model_path=str(BGE_M3_PATH),
model_name="integration-bge-m3-override",
model_type=ModelType.EMB,
engine=EngineType.OV_OPTIMUM,
device="CPU",
runtime_config={"pool_mode": "last"},
)

emb = Optimum_EMB(load_config)
emb.load_model(load_config)
try:
assert emb.pool_mode == "last"
finally:
asyncio.run(emb.unload_model(_DummyRegistry(), load_config.model_name))