From 7fc0678153567badc889001f659a59af60a47cfa Mon Sep 17 00:00:00 2001 From: pragnyanramtha Date: Sat, 16 May 2026 21:14:57 +0000 Subject: [PATCH] fix: support legacy search query bodies --- pinecone/_internal/data_plane_helpers.py | 83 ++++++++++++++++++- pinecone/async_client/async_index.py | 64 ++++++-------- pinecone/grpc/__init__.py | 70 ++++++++-------- pinecone/index/__init__.py | 64 ++++++-------- tests/unit/test_async_search.py | 30 +++++++ tests/unit/test_async_search_records_alias.py | 2 + tests/unit/test_grpc_index.py | 15 ++++ tests/unit/test_index_search.py | 28 +++++++ tests/unit/test_search_records_alias.py | 2 + 9 files changed, 247 insertions(+), 111 deletions(-) diff --git a/pinecone/_internal/data_plane_helpers.py b/pinecone/_internal/data_plane_helpers.py index 427c5b9fb..39f12afad 100644 --- a/pinecone/_internal/data_plane_helpers.py +++ b/pinecone/_internal/data_plane_helpers.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any from pinecone._internal.config import normalize_host @@ -60,6 +60,87 @@ def _normalize_search_vector_dict(vector: Mapping[str, Any]) -> dict[str, Any]: return result +def _legacy_search_query_to_dict(query: Any) -> dict[str, Any]: + if hasattr(query, "to_dict") and callable(query.to_dict): + raw = query.to_dict() + elif hasattr(query, "as_dict") and callable(query.as_dict): + raw = query.as_dict() + else: + raw = dict(query) + return dict(raw) + + +def _build_search_records_body( + *, + top_k: int | None, + inputs: Mapping[str, Any] | None, + vector: Sequence[float] | Mapping[str, Any] | None, + id: str | None, + filter: Mapping[str, Any] | None, + fields: Sequence[str] | None, + rerank: Mapping[str, Any] | None, + match_terms: Mapping[str, Any] | None, + query: Any | None, + wrap_dense_vector: bool = True, +) -> dict[str, Any]: + if rerank is not None: + if "model" not in rerank: + raise ValidationError("rerank requires 'model' to be specified") + if "rank_fields" not in rerank: + raise ValidationError("rerank requires 'rank_fields' to be specified") + + if query is not None: + if any(value is not None for value in (top_k, inputs, vector, id, filter, match_terms)): + raise ValidationError( + "query cannot be combined with top_k, inputs, vector, id, filter, or match_terms" + ) + query_body = _legacy_search_query_to_dict(query) + if "vector" in query_body and query_body["vector"] is not None: + query_vector = query_body["vector"] + if isinstance(query_vector, Mapping): + query_body["vector"] = _normalize_search_vector_dict(query_vector) + else: + values = list(query_vector) + query_body["vector"] = {"values": values} if wrap_dense_vector else values + else: + if top_k is None: + raise ValidationError("top_k is required unless query is provided") + query_body = {"top_k": top_k} + if inputs is not None: + query_body["inputs"] = inputs + if vector is not None: + if isinstance(vector, Mapping): + query_body["vector"] = _normalize_search_vector_dict(vector) + else: + values = list(vector) + query_body["vector"] = {"values": values} if wrap_dense_vector else values + if id is not None: + query_body["id"] = id + if filter is not None: + query_body["filter"] = filter + if match_terms is not None: + query_body["match_terms"] = match_terms + + top_k_value = query_body.get("top_k") + if not isinstance(top_k_value, int) or top_k_value < 1: + raise ValidationError(f"top_k must be a positive integer, got {top_k_value}") + if ( + query_body.get("inputs") is None + and query_body.get("vector") is None + and query_body.get("id") is None + ): + raise ValidationError( + "At least one of inputs, vector, or id must be provided as a query source" + ) + + body: dict[str, Any] = {"query": query_body} + if fields is not None: + body["fields"] = fields + if rerank is not None: + body["rerank"] = rerank + return body + + def _vector_to_dict(v: Vector) -> dict[str, Any]: """Serialize a Vector to a dict matching the API wire format.""" id_ = v.id diff --git a/pinecone/async_client/async_index.py b/pinecone/async_client/async_index.py index a19db0d6c..f0e2bab08 100644 --- a/pinecone/async_client/async_index.py +++ b/pinecone/async_client/async_index.py @@ -18,7 +18,7 @@ from pinecone._internal.config import PineconeConfig from pinecone._internal.constants import DATA_PLANE_API_VERSION from pinecone._internal.data_plane_helpers import ( - _normalize_search_vector_dict, + _build_search_records_body, _validate_host, _vector_to_dict, ) @@ -40,7 +40,12 @@ UpsertRecordsResponse, UpsertResponse, ) -from pinecone.models.vectors.search import RerankConfig, SearchInputs, SearchRecordsResponse +from pinecone.models.vectors.search import ( + RerankConfig, + SearchInputs, + SearchQuery, + SearchRecordsResponse, +) from pinecone.models.vectors.sparse import SparseValues from pinecone.models.vectors.vector import Vector @@ -926,7 +931,7 @@ async def search( self, *, namespace: str, - top_k: int, + top_k: int | None = None, inputs: SearchInputs | Mapping[str, Any] | None = None, vector: Sequence[float] | Mapping[str, Any] | None = None, id: str | None = None, @@ -934,6 +939,7 @@ async def search( fields: Sequence[str] | None = None, rerank: RerankConfig | Mapping[str, Any] | None = None, match_terms: Mapping[str, Any] | None = None, + query: SearchQuery | Mapping[str, Any] | None = None, timeout: float | None = None, ) -> SearchRecordsResponse: """Search records by text, vector, or ID with optional reranking. @@ -967,6 +973,9 @@ async def search( ``"all"``) and ``"terms"`` (list of strings). Only supported for sparse indexes using ``pinecone-sparse-english-v0``. ``None`` disables term matching. + query (dict[str, Any] | None): Legacy query body containing + ``top_k`` plus one of ``inputs``, ``vector``, or ``id``. Prefer + passing these fields directly. Returns: :class:`SearchRecordsResponse` with hits and usage statistics. @@ -995,40 +1004,19 @@ async def search( raise ValidationError("namespace must be a string") if not namespace or not namespace.strip(): raise ValidationError("namespace must be a non-empty string") - if top_k < 1: - raise ValidationError(f"top_k must be a positive integer, got {top_k}") - if rerank is not None: - if "model" not in rerank: - raise ValidationError("rerank requires 'model' to be specified") - if "rank_fields" not in rerank: - raise ValidationError("rerank requires 'rank_fields' to be specified") - if inputs is None and vector is None and id is None: - raise ValidationError( - "At least one of inputs, vector, or id must be provided as a query source" - ) - - query_body: dict[str, Any] = {"top_k": top_k} - if inputs is not None: - query_body["inputs"] = inputs - if vector is not None: - if isinstance(vector, Mapping): - query_body["vector"] = _normalize_search_vector_dict(vector) - else: - query_body["vector"] = {"values": list(vector)} - if id is not None: - query_body["id"] = id - if filter is not None: - query_body["filter"] = filter - if match_terms is not None: - query_body["match_terms"] = match_terms - - body: dict[str, Any] = {"query": query_body} - if fields is not None: - body["fields"] = fields - if rerank is not None: - body["rerank"] = rerank + body = _build_search_records_body( + top_k=top_k, + inputs=inputs, + vector=vector, + id=id, + filter=filter, + fields=fields, + rerank=rerank, + match_terms=match_terms, + query=query, + ) - logger.info("Searching namespace %r with top_k=%d", namespace, top_k) + logger.info("Searching namespace %r with top_k=%d", namespace, body["query"]["top_k"]) response = await self._http.post( f"/records/namespaces/{namespace}/search", timeout=timeout, json=body ) @@ -1040,7 +1028,7 @@ async def search_records( self, *, namespace: str, - top_k: int, + top_k: int | None = None, inputs: SearchInputs | Mapping[str, Any] | None = None, vector: Sequence[float] | Mapping[str, Any] | None = None, id: str | None = None, @@ -1048,6 +1036,7 @@ async def search_records( fields: Sequence[str] | None = None, rerank: RerankConfig | Mapping[str, Any] | None = None, match_terms: Mapping[str, Any] | None = None, + query: SearchQuery | Mapping[str, Any] | None = None, timeout: float | None = None, ) -> SearchRecordsResponse: """Alias for :meth:`search`. @@ -1063,6 +1052,7 @@ async def search_records( fields=fields, rerank=rerank, match_terms=match_terms, + query=query, timeout=timeout, ) diff --git a/pinecone/grpc/__init__.py b/pinecone/grpc/__init__.py index b268f7b42..87ece2dc6 100644 --- a/pinecone/grpc/__init__.py +++ b/pinecone/grpc/__init__.py @@ -17,7 +17,7 @@ from pinecone._internal.batching import chunked, validate_batch_size, with_progress from pinecone._internal.config import PineconeConfig from pinecone._internal.constants import DATA_PLANE_API_VERSION -from pinecone._internal.data_plane_helpers import _validate_host +from pinecone._internal.data_plane_helpers import _build_search_records_body, _validate_host from pinecone._internal.validation import require_in_range from pinecone._internal.vector_factory import VectorFactory from pinecone.errors.exceptions import ( @@ -45,7 +45,12 @@ UpsertRecordsResponse, UpsertResponse, ) -from pinecone.models.vectors.search import RerankConfig, SearchInputs, SearchRecordsResponse +from pinecone.models.vectors.search import ( + RerankConfig, + SearchInputs, + SearchQuery, + SearchRecordsResponse, +) from pinecone.models.vectors.sparse import SparseValues from pinecone.models.vectors.usage import Usage from pinecone.models.vectors.vector import ScoredVector, Vector @@ -1196,14 +1201,15 @@ def search( self, *, namespace: str, - top_k: int, + top_k: int | None = None, inputs: SearchInputs | Mapping[str, Any] | None = None, - vector: Sequence[float] | None = None, + vector: Sequence[float] | Mapping[str, Any] | None = None, id: str | None = None, filter: Mapping[str, Any] | None = None, fields: Sequence[str] | None = None, rerank: RerankConfig | Mapping[str, Any] | None = None, match_terms: Mapping[str, Any] | None = None, + query: SearchQuery | Mapping[str, Any] | None = None, timeout: float | None = None, ) -> SearchRecordsResponse: """Search records by text, vector, or ID with optional reranking. @@ -1234,6 +1240,9 @@ def search( ``"all"``) and ``"terms"`` (list of strings). Only supported for sparse indexes using ``pinecone-sparse-english-v0``. ``None`` disables term matching. + query (dict[str, Any] | None): Legacy query body containing + ``top_k`` plus one of ``inputs``, ``vector``, or ``id``. Prefer + passing these fields directly. Returns: :class:`SearchRecordsResponse` with hits and usage statistics. @@ -1284,37 +1293,24 @@ def search( raise ValidationError("namespace must be a string") if not namespace or not namespace.strip(): raise ValidationError("namespace must be a non-empty string") - if top_k < 1: - raise ValidationError(f"top_k must be a positive integer, got {top_k}") - if rerank is not None: - if "model" not in rerank: - raise ValidationError("rerank requires 'model' to be specified") - if "rank_fields" not in rerank: - raise ValidationError("rerank requires 'rank_fields' to be specified") - if inputs is None and vector is None and id is None: - raise ValidationError( - "At least one of inputs, vector, or id must be provided as a query source" - ) + body = _build_search_records_body( + top_k=top_k, + inputs=inputs, + vector=vector, + id=id, + filter=filter, + fields=fields, + rerank=rerank, + match_terms=match_terms, + query=query, + wrap_dense_vector=False, + ) - query_body: dict[str, Any] = {"top_k": top_k} - if inputs is not None: - query_body["inputs"] = inputs - if vector is not None: - query_body["vector"] = vector - if id is not None: - query_body["id"] = id - if filter is not None: - query_body["filter"] = filter - if match_terms is not None: - query_body["match_terms"] = match_terms - - body: dict[str, Any] = {"query": query_body} - if fields is not None: - body["fields"] = fields - if rerank is not None: - body["rerank"] = rerank - - logger.info("Searching namespace %r with top_k=%d (via REST)", namespace, top_k) + logger.info( + "Searching namespace %r with top_k=%d (via REST)", + namespace, + body["query"]["top_k"], + ) response = self._http.post( f"/records/namespaces/{namespace}/search", timeout=timeout, json=body ) @@ -1326,14 +1322,15 @@ def search_records( self, *, namespace: str, - top_k: int, + top_k: int | None = None, inputs: SearchInputs | Mapping[str, Any] | None = None, - vector: Sequence[float] | None = None, + vector: Sequence[float] | Mapping[str, Any] | None = None, id: str | None = None, filter: Mapping[str, Any] | None = None, fields: Sequence[str] | None = None, rerank: RerankConfig | Mapping[str, Any] | None = None, match_terms: Mapping[str, Any] | None = None, + query: SearchQuery | Mapping[str, Any] | None = None, timeout: float | None = None, ) -> SearchRecordsResponse: """Alias for :meth:`search`. @@ -1350,6 +1347,7 @@ def search_records( fields=fields, rerank=rerank, match_terms=match_terms, + query=query, timeout=timeout, ) diff --git a/pinecone/index/__init__.py b/pinecone/index/__init__.py index 43abb09a6..f346dee81 100644 --- a/pinecone/index/__init__.py +++ b/pinecone/index/__init__.py @@ -18,7 +18,7 @@ from pinecone._internal.config import PineconeConfig from pinecone._internal.constants import DATA_PLANE_API_VERSION from pinecone._internal.data_plane_helpers import ( - _normalize_search_vector_dict, + _build_search_records_body, _validate_host, _vector_to_dict, ) @@ -40,7 +40,12 @@ UpsertRecordsResponse, UpsertResponse, ) -from pinecone.models.vectors.search import RerankConfig, SearchInputs, SearchRecordsResponse +from pinecone.models.vectors.search import ( + RerankConfig, + SearchInputs, + SearchQuery, + SearchRecordsResponse, +) from pinecone.models.vectors.sparse import SparseValues from pinecone.models.vectors.vector import Vector @@ -1092,7 +1097,7 @@ def search( self, *, namespace: str, - top_k: int, + top_k: int | None = None, inputs: SearchInputs | Mapping[str, Any] | None = None, vector: Sequence[float] | Mapping[str, Any] | None = None, id: str | None = None, @@ -1100,6 +1105,7 @@ def search( fields: Sequence[str] | None = None, rerank: RerankConfig | Mapping[str, Any] | None = None, match_terms: Mapping[str, Any] | None = None, + query: SearchQuery | Mapping[str, Any] | None = None, timeout: float | None = None, ) -> SearchRecordsResponse: """Search records by text, vector, or ID with optional reranking. @@ -1137,6 +1143,9 @@ def search( ``"all"``) and ``"terms"`` (list of strings). Only supported for sparse indexes using ``pinecone-sparse-english-v0``. ``None`` disables term matching. + query (dict[str, Any] | None): Legacy query body containing + ``top_k`` plus one of ``inputs``, ``vector``, or ``id``. Prefer + passing these fields directly. Returns: :class:`SearchRecordsResponse` with hits and usage statistics. @@ -1184,40 +1193,19 @@ def search( raise ValidationError("namespace must be a string") if not namespace or not namespace.strip(): raise ValidationError("namespace must be a non-empty string") - if top_k < 1: - raise ValidationError(f"top_k must be a positive integer, got {top_k}") - if rerank is not None: - if "model" not in rerank: - raise ValidationError("rerank requires 'model' to be specified") - if "rank_fields" not in rerank: - raise ValidationError("rerank requires 'rank_fields' to be specified") - if inputs is None and vector is None and id is None: - raise ValidationError( - "At least one of inputs, vector, or id must be provided as a query source" - ) - - query_body: dict[str, Any] = {"top_k": top_k} - if inputs is not None: - query_body["inputs"] = inputs - if vector is not None: - if isinstance(vector, Mapping): - query_body["vector"] = _normalize_search_vector_dict(vector) - else: - query_body["vector"] = {"values": list(vector)} - if id is not None: - query_body["id"] = id - if filter is not None: - query_body["filter"] = filter - if match_terms is not None: - query_body["match_terms"] = match_terms - - body: dict[str, Any] = {"query": query_body} - if fields is not None: - body["fields"] = fields - if rerank is not None: - body["rerank"] = rerank + body = _build_search_records_body( + top_k=top_k, + inputs=inputs, + vector=vector, + id=id, + filter=filter, + fields=fields, + rerank=rerank, + match_terms=match_terms, + query=query, + ) - logger.info("Searching namespace %r with top_k=%d", namespace, top_k) + logger.info("Searching namespace %r with top_k=%d", namespace, body["query"]["top_k"]) response = self._http.post( f"/records/namespaces/{namespace}/search", timeout=timeout, json=body ) @@ -1229,7 +1217,7 @@ def search_records( self, *, namespace: str, - top_k: int, + top_k: int | None = None, inputs: SearchInputs | Mapping[str, Any] | None = None, vector: Sequence[float] | Mapping[str, Any] | None = None, id: str | None = None, @@ -1237,6 +1225,7 @@ def search_records( fields: Sequence[str] | None = None, rerank: RerankConfig | Mapping[str, Any] | None = None, match_terms: Mapping[str, Any] | None = None, + query: SearchQuery | Mapping[str, Any] | None = None, timeout: float | None = None, ) -> SearchRecordsResponse: """Alias for :meth:`search`. @@ -1253,6 +1242,7 @@ def search_records( fields=fields, rerank=rerank, match_terms=match_terms, + query=query, timeout=timeout, ) diff --git a/tests/unit/test_async_search.py b/tests/unit/test_async_search.py index c4589b893..6010bcdd7 100644 --- a/tests/unit/test_async_search.py +++ b/tests/unit/test_async_search.py @@ -113,6 +113,36 @@ async def test_async_search_with_fields(self) -> None: body = orjson.loads(route.calls.last.request.content) assert body["fields"] == ["chunk_text", "title"] + @respx.mock + @pytest.mark.anyio + async def test_async_search_accepts_legacy_query_body(self) -> None: + route = respx.post(SEARCH_URL_NS).mock( + return_value=httpx.Response(200, json=SEARCH_RESPONSE), + ) + idx = _make_async_index() + query = { + "inputs": {"text": "hello"}, + "top_k": 10, + "filter": {"genre": {"$eq": "sci-fi"}}, + } + await idx.search(namespace="test-ns", query=query, fields=["chunk_text", "title"]) + + import orjson + + body = orjson.loads(route.calls.last.request.content) + assert body["query"] == query + assert body["fields"] == ["chunk_text", "title"] + + @pytest.mark.anyio + async def test_async_search_query_body_cannot_mix_direct_query_params(self) -> None: + idx = _make_async_index() + with pytest.raises(ValidationError, match="query cannot be combined"): + await idx.search( + namespace="test-ns", + query={"inputs": {"text": "hello"}, "top_k": 10}, + top_k=10, + ) + @respx.mock @pytest.mark.anyio async def test_async_search_with_rerank(self) -> None: diff --git a/tests/unit/test_async_search_records_alias.py b/tests/unit/test_async_search_records_alias.py index f316f94fc..0da5203ea 100644 --- a/tests/unit/test_async_search_records_alias.py +++ b/tests/unit/test_async_search_records_alias.py @@ -28,6 +28,7 @@ async def test_search_records_delegates_to_search() -> None: fields=None, rerank=None, match_terms=None, + query=None, timeout=None, ) @@ -60,6 +61,7 @@ async def test_search_records_passes_all_params() -> None: "fields": ["title", "year"], "rerank": {"model": "bge-reranker-v2-m3", "rank_fields": ["text"]}, "match_terms": {"strategy": "all", "terms": ["animal", "duck"]}, + "query": None, "timeout": None, } diff --git a/tests/unit/test_grpc_index.py b/tests/unit/test_grpc_index.py index b6d309388..eb7b121db 100644 --- a/tests/unit/test_grpc_index.py +++ b/tests/unit/test_grpc_index.py @@ -1211,6 +1211,21 @@ def test_search_with_fields_forwarded_at_body_root(self, mock_channel: MagicMock assert body["fields"] == ["text", "title"] assert "fields" not in body["query"] + @respx.mock + def test_search_accepts_legacy_query_body(self, mock_channel: MagicMock) -> None: + import orjson + + route = respx.post(_SEARCH_URL).mock( + return_value=httpx.Response(200, json=_SEARCH_RESPONSE) + ) + idx = _make_grpc_index(mock_channel, host=_INDEX_HOST) + query = {"inputs": {"text": "q"}, "top_k": 5, "filter": {"topic": {"$eq": "ai"}}} + idx.search(namespace="test-ns", query=query, fields=["text", "title"]) + + body = orjson.loads(route.calls.last.request.content) + assert body["query"] == query + assert body["fields"] == ["text", "title"] + @respx.mock def test_search_records_alias(self, mock_channel: MagicMock) -> None: """search_records() is an alias for search() and produces the same result.""" diff --git a/tests/unit/test_index_search.py b/tests/unit/test_index_search.py index f694b9a46..a0eadac4a 100644 --- a/tests/unit/test_index_search.py +++ b/tests/unit/test_index_search.py @@ -108,6 +108,34 @@ def test_search_with_fields(self) -> None: body = orjson.loads(route.calls.last.request.content) assert body["fields"] == ["chunk_text", "title"] + @respx.mock + def test_search_accepts_legacy_query_body(self) -> None: + route = respx.post(SEARCH_URL_NS).mock( + return_value=httpx.Response(200, json=SEARCH_RESPONSE), + ) + idx = _make_index() + query = { + "inputs": {"text": "hello"}, + "top_k": 10, + "filter": {"genre": {"$eq": "sci-fi"}}, + } + idx.search(namespace="test-ns", query=query, fields=["chunk_text", "title"]) + + import orjson + + body = orjson.loads(route.calls.last.request.content) + assert body["query"] == query + assert body["fields"] == ["chunk_text", "title"] + + def test_search_query_body_cannot_mix_direct_query_params(self) -> None: + idx = _make_index() + with pytest.raises(ValidationError, match="query cannot be combined"): + idx.search( + namespace="test-ns", + query={"inputs": {"text": "hello"}, "top_k": 10}, + top_k=10, + ) + @respx.mock def test_search_with_rerank(self) -> None: route = respx.post(SEARCH_URL_NS).mock( diff --git a/tests/unit/test_search_records_alias.py b/tests/unit/test_search_records_alias.py index d521368ca..03aa0bcd7 100644 --- a/tests/unit/test_search_records_alias.py +++ b/tests/unit/test_search_records_alias.py @@ -46,6 +46,7 @@ def test_search_records_delegates_to_search(self) -> None: fields=None, rerank=None, match_terms=None, + query=None, timeout=None, ) @@ -69,6 +70,7 @@ def test_search_records_passes_all_params(self) -> None: "fields": ["title", "genre"], "rerank": {"model": "bge-reranker", "rank_fields": ["text"]}, "match_terms": {"strategy": "all", "terms": ["animal", "duck"]}, + "query": None, "timeout": None, } with patch.object(idx, "search", return_value=expected) as mock_search: