diff --git a/blockrun_llm/__init__.py b/blockrun_llm/__init__.py index ebf37be..2e63490 100644 --- a/blockrun_llm/__init__.py +++ b/blockrun_llm/__init__.py @@ -82,6 +82,8 @@ ChatCompletionChunk, ChatChunkChoice, ChatChunkDelta, + ChatChunkToolCall, + ChatChunkFunctionCall, Model, APIError, PaymentError, @@ -203,6 +205,8 @@ "ChatCompletionChunk", "ChatChunkChoice", "ChatChunkDelta", + "ChatChunkToolCall", + "ChatChunkFunctionCall", "Model", "APIError", "PaymentError", diff --git a/blockrun_llm/client.py b/blockrun_llm/client.py index dee5957..1392e5a 100644 --- a/blockrun_llm/client.py +++ b/blockrun_llm/client.py @@ -55,6 +55,10 @@ SmartChatResponse, RoutingProfile, SearchResult, + stream_choice_content, + stream_choice_finish_reason, + chunk_meta, + chunk_usage_dict, ) from .router import route as route_request from .tx_log import TransactionLogger, decode_settlement_header, _resolve_log_dir @@ -881,16 +885,21 @@ def _iter_and_archive( for chunk in self._iter_sse_chunks(response): if chunk.choices: choice = chunk.choices[0] - if choice.delta.content: - content_parts.append(choice.delta.content) - if choice.finish_reason: - finish_reason = choice.finish_reason - if assembled_id is None and chunk.id: - assembled_id = chunk.id - assembled_model = chunk.model - assembled_created = chunk.created - if chunk.usage is not None: - usage_dict = chunk.usage.model_dump(exclude_none=True) + content = stream_choice_content(choice) + if content: + content_parts.append(content) + fr = stream_choice_finish_reason(choice) + if fr: + finish_reason = fr + if assembled_id is None: + _id, _model, _created = chunk_meta(chunk) + if _id: + assembled_id = _id + assembled_model = _model + assembled_created = _created + _usage = chunk_usage_dict(chunk) + if _usage is not None: + usage_dict = _usage yield chunk # Stream complete (saw [DONE]). Free models have cost_usd == 0; only @@ -2544,16 +2553,21 @@ async def _aiter_and_archive( async for chunk in self._aiter_sse_chunks(response): if chunk.choices: choice = chunk.choices[0] - if choice.delta.content: - content_parts.append(choice.delta.content) - if choice.finish_reason: - finish_reason = choice.finish_reason - if assembled_id is None and chunk.id: - assembled_id = chunk.id - assembled_model = chunk.model - assembled_created = chunk.created - if chunk.usage is not None: - usage_dict = chunk.usage.model_dump(exclude_none=True) + content = stream_choice_content(choice) + if content: + content_parts.append(content) + fr = stream_choice_finish_reason(choice) + if fr: + finish_reason = fr + if assembled_id is None: + _id, _model, _created = chunk_meta(chunk) + if _id: + assembled_id = _id + assembled_model = _model + assembled_created = _created + _usage = chunk_usage_dict(chunk) + if _usage is not None: + usage_dict = _usage yield chunk if cost_usd > 0: diff --git a/blockrun_llm/solana_client.py b/blockrun_llm/solana_client.py index b6153cf..cd53c13 100644 --- a/blockrun_llm/solana_client.py +++ b/blockrun_llm/solana_client.py @@ -33,6 +33,10 @@ APIError, PaymentError, SearchResult, + stream_choice_content, + stream_choice_finish_reason, + chunk_meta, + chunk_usage_dict, ) from .solana_wallet import get_solana_public_key from .tx_log import TransactionLogger, decode_settlement_header, _resolve_log_dir @@ -813,16 +817,21 @@ def _iter_and_archive( for chunk in self._iter_sse_chunks(response): if chunk.choices: choice = chunk.choices[0] - if choice.delta.content: - content_parts.append(choice.delta.content) - if choice.finish_reason: - finish_reason = choice.finish_reason - if assembled_id is None and chunk.id: - assembled_id = chunk.id - assembled_model = chunk.model - assembled_created = chunk.created - if chunk.usage is not None: - usage_dict = chunk.usage.model_dump(exclude_none=True) + content = stream_choice_content(choice) + if content: + content_parts.append(content) + fr = stream_choice_finish_reason(choice) + if fr: + finish_reason = fr + if assembled_id is None: + _id, _model, _created = chunk_meta(chunk) + if _id: + assembled_id = _id + assembled_model = _model + assembled_created = _created + _usage = chunk_usage_dict(chunk) + if _usage is not None: + usage_dict = _usage yield chunk if cost_usd > 0: @@ -2255,16 +2264,21 @@ async def _aiter_and_archive( async for chunk in self._aiter_sse_chunks(response): if chunk.choices: choice = chunk.choices[0] - if choice.delta.content: - content_parts.append(choice.delta.content) - if choice.finish_reason: - finish_reason = choice.finish_reason - if assembled_id is None and chunk.id: - assembled_id = chunk.id - assembled_model = chunk.model - assembled_created = chunk.created - if chunk.usage is not None: - usage_dict = chunk.usage.model_dump(exclude_none=True) + content = stream_choice_content(choice) + if content: + content_parts.append(content) + fr = stream_choice_finish_reason(choice) + if fr: + finish_reason = fr + if assembled_id is None: + _id, _model, _created = chunk_meta(chunk) + if _id: + assembled_id = _id + assembled_model = _model + assembled_created = _created + _usage = chunk_usage_dict(chunk) + if _usage is not None: + usage_dict = _usage yield chunk if cost_usd > 0: diff --git a/blockrun_llm/types.py b/blockrun_llm/types.py index 5c7ebaf..b8eaa7e 100644 --- a/blockrun_llm/types.py +++ b/blockrun_llm/types.py @@ -121,6 +121,44 @@ class Config: # --------------------------------------------------------------------------- +class ChatChunkFunctionCall(BaseModel): + """Streaming function-call delta. The model sends ``name`` on the first + frame and ``arguments`` in fragments afterwards, so both are optional here — + unlike the non-stream :class:`FunctionCall` where both are required.""" + + name: Optional[str] = None + arguments: Optional[str] = None + + class Config: + extra = "allow" + + +class ChatChunkToolCall(BaseModel): + """One streaming tool-call delta. + + OpenAI streams tool calls incrementally: the first frame carries + ``index`` + ``id`` + ``function.name`` (+ empty args), later frames carry + only ``index`` + ``function.arguments`` fragments. Every field is therefore + optional. The strict non-stream :class:`ToolCall` (``id`` / ``function.name`` + / ``arguments`` all required) rejected the argument-fragment frames, which + made ``ChatCompletionChunk(**chunk)`` raise and fall back to + ``model_construct`` — leaving ``choices`` as raw dicts and crashing the + archive loop with ``'dict' object has no attribute 'delta'``. Using this + lenient type keeps streamed tool calls parsing into real objects. + """ + + index: Optional[int] = None + id: Optional[str] = None + # Kept as a free-form ``str`` (not ``Literal["function"]``) so an upstream + # that streams a non-"function" tool type can't fail validation and re-trigger + # the very ``model_construct`` fallback this lenient type exists to avoid. + type: Optional[str] = None + function: Optional[ChatChunkFunctionCall] = None + + class Config: + extra = "allow" + + class ChatChunkDelta(BaseModel): """Incremental ``message`` delta sent over SSE. @@ -132,7 +170,7 @@ class ChatChunkDelta(BaseModel): role: Optional[Literal["system", "user", "assistant", "tool"]] = None content: Optional[str] = None - tool_calls: Optional[List[ToolCall]] = None + tool_calls: Optional[List[ChatChunkToolCall]] = None reasoning_content: Optional[str] = None thinking: Optional[str] = None @@ -168,6 +206,57 @@ class Config: extra = "allow" +def stream_choice_content(choice: Any) -> Optional[str]: + """Text delta from a streaming choice, tolerant of a raw ``dict`` choice. + + A chunk that fails strict validation falls back to ``model_construct``, + which leaves nested ``choices`` as plain dicts. Defensive accessors keep the + stream-archiving loop from crashing on those (``'dict' object has no + attribute 'delta'``); a tool-call frame simply has no content and yields + ``None``. + """ + if isinstance(choice, dict): + delta = choice.get("delta") + return delta.get("content") if isinstance(delta, dict) else None + delta = getattr(choice, "delta", None) + return getattr(delta, "content", None) if delta is not None else None + + +def stream_choice_finish_reason(choice: Any) -> Optional[str]: + """``finish_reason`` from a streaming choice, tolerant of a raw dict choice.""" + if isinstance(choice, dict): + return choice.get("finish_reason") + return getattr(choice, "finish_reason", None) + + +def chunk_meta(chunk: Any) -> "tuple[Optional[str], Optional[str], Optional[int]]": + """``(id, model, created)`` of a chunk, tolerant of a ``model_construct``'d + chunk that omits required fields. + + ``model_construct`` does not populate missing required fields, so a drifted + frame that lost its top-level ``id`` yields a chunk object with no ``id`` + attribute. Reading ``chunk.id`` directly would then raise ``AttributeError`` + and crash the stream-archiving loop — the same failure class the other + accessors here guard against. ``getattr`` keeps those reads safe. + """ + return ( + getattr(chunk, "id", None), + getattr(chunk, "model", None), + getattr(chunk, "created", None), + ) + + +def chunk_usage_dict(chunk: Any) -> Optional[Dict[str, Any]]: + """``usage`` of a chunk as a dict, tolerant of a model_construct'd chunk + whose ``usage`` is a raw dict (no ``.model_dump``).""" + usage = getattr(chunk, "usage", None) + if usage is None: + return None + if isinstance(usage, dict): + return {k: v for k, v in usage.items() if v is not None} + return usage.model_dump(exclude_none=True) + + class Model(BaseModel): """Available model information.""" diff --git a/blockrun_llm/validation.py b/blockrun_llm/validation.py index 43fd909..4693b86 100644 --- a/blockrun_llm/validation.py +++ b/blockrun_llm/validation.py @@ -10,9 +10,12 @@ """ import re -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, TYPE_CHECKING from urllib.parse import urlparse +if TYPE_CHECKING: + from .types import PaymentError + # Localhost domains that are allowed to use HTTP LOCALHOST_DOMAINS = {"localhost", "127.0.0.1"} diff --git a/examples/benchmark_claude.py b/examples/benchmark_claude.py index c25c828..5b0bcba 100644 --- a/examples/benchmark_claude.py +++ b/examples/benchmark_claude.py @@ -85,8 +85,8 @@ def _count_tokens(text: str, model_hint: str = "") -> int: @dataclass class ReqResult: ok: bool - ttft: Optional[float] = None # seconds to first content token - latency: Optional[float] = None # seconds request → last token + ttft: Optional[float] = None # seconds to first content token + latency: Optional[float] = None # seconds request → last token out_tokens: int = 0 error: str = "" @@ -180,7 +180,9 @@ def cache_probe(self) -> float: usage = getattr(resp, "usage", None) if usage is None: return 0.0 - u: Dict[str, Any] = usage.model_dump(exclude_none=True) if hasattr(usage, "model_dump") else dict(usage) + u: Dict[str, Any] = ( + usage.model_dump(exclude_none=True) if hasattr(usage, "model_dump") else dict(usage) + ) prompt_tokens = u.get("prompt_tokens") or 0 cache_read = u.get("cache_read_input_tokens") or 0 cache_creation = u.get("cache_creation_input_tokens") or 0 @@ -212,8 +214,10 @@ def fmt(x: float) -> str: print("\n" + "=" * 56) print(f" Claude E2E benchmark — {self.model} ({self.chain})") print(f" {self.api_url}") - print(f" requests={self.requests} concurrency={self.concurrency} " - f"max_tokens={self.max_tokens}") + print( + f" requests={self.requests} concurrency={self.concurrency} " + f"max_tokens={self.max_tokens}" + ) print("=" * 56) rows = [ ("单个请求吞吐 (token/s)", fmt(statistics.mean(per_req_tps)) if per_req_tps else "nan"), @@ -232,7 +236,9 @@ def fmt(x: float) -> str: for name, val in rows: print(f" {name:<34} {val}") print("-" * 56) - print(f" 样本: 成功 {len(ok)}/{self.requests} 总输出≈{total_out} tokens wall={wall:.2f}s") + print( + f" 样本: 成功 {len(ok)}/{self.requests} 总输出≈{total_out} tokens wall={wall:.2f}s" + ) fails = [r for r in self.results if not r.ok] if fails: print(f" 失败 {len(fails)} 例,示例: {fails[0].error}") @@ -249,15 +255,23 @@ def main() -> None: p.add_argument("--max-tokens", type=int, default=256) p.add_argument("--prompt", default=DEFAULT_PROMPT) p.add_argument("--private-key", default=None, help="wallet key (else env / ~/.blockrun)") - p.add_argument("--cache-probe", action="store_true", - help="add 2 non-streaming calls to measure cache hit rate (extra spend)") + p.add_argument( + "--cache-probe", + action="store_true", + help="add 2 non-streaming calls to measure cache hit rate (extra spend)", + ) args = p.parse_args() api_url = args.api_url or (SOLANA_API_URL if args.chain == "solana" else BASE_API_URL) bench = Bench( - chain=args.chain, model=args.model, api_url=api_url, - requests=args.requests, concurrency=args.concurrency, - prompt=args.prompt, max_tokens=args.max_tokens, private_key=args.private_key, + chain=args.chain, + model=args.model, + api_url=api_url, + requests=args.requests, + concurrency=args.concurrency, + prompt=args.prompt, + max_tokens=args.max_tokens, + private_key=args.private_key, ) print(f"[benchmark] {args.requests} paid streaming requests → {api_url} ({args.model}) …") wall = bench.run_throughput_phase() diff --git a/tests/unit/test_image_poll.py b/tests/unit/test_image_poll.py index 503fb88..5d2c74f 100644 --- a/tests/unit/test_image_poll.py +++ b/tests/unit/test_image_poll.py @@ -13,7 +13,6 @@ from __future__ import annotations -import json from typing import List import httpx @@ -240,7 +239,6 @@ def test_image_poll_surfaces_upstream_failure(monkeypatch: pytest.MonkeyPatch) - monkeypatch.setattr(ImageClient, "IMAGE_POLL_INTERVAL_SECONDS", 0.0) def handler(request: httpx.Request) -> httpx.Response: - path = request.url.path if request.method == "POST": if "PAYMENT-SIGNATURE" not in request.headers: return _payment_required_402(request) diff --git a/tests/unit/test_payment_error_helper.py b/tests/unit/test_payment_error_helper.py index 68f23c8..db1f75d 100644 --- a/tests/unit/test_payment_error_helper.py +++ b/tests/unit/test_payment_error_helper.py @@ -10,7 +10,6 @@ from typing import Any, Dict -import pytest from blockrun_llm.types import PaymentError from blockrun_llm.validation import build_payment_rejected_error @@ -34,7 +33,10 @@ def test_payment_error_carries_status_and_response(self) -> None: exc = PaymentError( "Payment rejected by gateway: transaction_simulation_failed", status_code=402, - response={"message": "Payment settlement failed", "details": "transaction_simulation_failed"}, + response={ + "message": "Payment settlement failed", + "details": "transaction_simulation_failed", + }, ) assert exc.status_code == 402 assert exc.response is not None diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index 345387e..c722111 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -18,7 +18,7 @@ from __future__ import annotations import json -from typing import Iterator, List +from typing import List import httpx import pytest @@ -33,39 +33,49 @@ # Synthetic SSE bodies # --------------------------------------------------------------------------- + def _sse_events(deltas: List[str], finish: str = "stop", model: str = "test/model") -> bytes: """Render a list of content deltas as raw SSE bytes ending with [DONE].""" lines: List[str] = [] # First chunk — role only. lines.append( - "data: " + json.dumps({ - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1700000000, - "model": model, - "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], - }) + "data: " + + json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1700000000, + "model": model, + "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], + } + ) ) # Content chunks. for i, d in enumerate(deltas): lines.append( - "data: " + json.dumps({ + "data: " + + json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1700000000, + "model": model, + "choices": [{"index": 0, "delta": {"content": d}, "finish_reason": None}], + } + ) + ) + # Final chunk with finish_reason. + lines.append( + "data: " + + json.dumps( + { "id": "chatcmpl-test", "object": "chat.completion.chunk", "created": 1700000000, "model": model, - "choices": [{"index": 0, "delta": {"content": d}, "finish_reason": None}], - }) + "choices": [{"index": 0, "delta": {}, "finish_reason": finish}], + } ) - # Final chunk with finish_reason. - lines.append( - "data: " + json.dumps({ - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1700000000, - "model": model, - "choices": [{"index": 0, "delta": {}, "finish_reason": finish}], - }) ) lines.append("data: [DONE]") body = "\n\n".join(lines) + "\n\n" @@ -87,6 +97,7 @@ def _sse_with_garbage(deltas: List[str]) -> bytes: # Mock transports # --------------------------------------------------------------------------- + def _make_free_model_transport(sse_body: bytes, calls: List[httpx.Request]) -> httpx.MockTransport: def handler(request: httpx.Request) -> httpx.Response: calls.append(request) @@ -99,9 +110,7 @@ def handler(request: httpx.Request) -> httpx.Response: return httpx.MockTransport(handler) -def _make_paid_model_transport( - sse_body: bytes, calls: List[httpx.Request] -) -> httpx.MockTransport: +def _make_paid_model_transport(sse_body: bytes, calls: List[httpx.Request]) -> httpx.MockTransport: """First call → 402 with valid payment-required header; second → 200 SSE.""" def handler(request: httpx.Request) -> httpx.Response: @@ -128,6 +137,7 @@ def handler(request: httpx.Request) -> httpx.Response: # Sync tests # --------------------------------------------------------------------------- + class TestSyncStreaming: def test_free_model_streams_without_payment(self): calls: List[httpx.Request] = [] @@ -180,9 +190,10 @@ def test_paid_model_signs_and_retries(self): assert client._session_calls == 1 assert client._session_total_usd > 0 # Streamed content arrives. - assert "".join( - c.choices[0].delta.content for c in chunks if c.choices[0].delta.content - ) == "Paid" + assert ( + "".join(c.choices[0].delta.content for c in chunks if c.choices[0].delta.content) + == "Paid" + ) def test_malformed_chunks_dont_abort_stream(self): calls: List[httpx.Request] = [] @@ -198,9 +209,7 @@ def test_malformed_chunks_dont_abort_stream(self): ) ) # We should have gotten both deltas through, despite the garbage chunk. - joined = "".join( - c.choices[0].delta.content for c in chunks if c.choices[0].delta.content - ) + joined = "".join(c.choices[0].delta.content for c in chunks if c.choices[0].delta.content) assert joined == "AB" def test_paid_path_propagates_payment_rejected(self): @@ -236,6 +245,7 @@ def handler(request: httpx.Request) -> httpx.Response: # Async tests # --------------------------------------------------------------------------- + class TestAsyncStreaming: @pytest.mark.asyncio async def test_async_free_model(self): @@ -255,9 +265,10 @@ async def test_async_free_model(self): chunks.append(chunk) assert len(calls) == 1 - assert "".join( - c.choices[0].delta.content for c in chunks if c.choices[0].delta.content - ) == "Hi!" + assert ( + "".join(c.choices[0].delta.content for c in chunks if c.choices[0].delta.content) + == "Hi!" + ) await client.close() @pytest.mark.asyncio @@ -286,6 +297,7 @@ async def test_async_paid_model_signs_and_retries(self): # 5xx retry tests # --------------------------------------------------------------------------- + def _make_flaky_free_transport( sse_body: bytes, fail_count: int, @@ -301,9 +313,7 @@ def handler(request: httpx.Request) -> httpx.Response: return httpx.Response( status, headers={"content-type": "application/json"}, json={"error": "transient"} ) - return httpx.Response( - 200, headers={"content-type": "text/event-stream"}, content=sse_body - ) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, content=sse_body) return httpx.MockTransport(handler) @@ -350,10 +360,12 @@ def handler(request: httpx.Request) -> httpx.Response: from blockrun_llm.types import APIError with pytest.raises(APIError): - list(client.chat_completion_stream( - "nvidia/deepseek-v4-flash", - [{"role": "user", "content": "hi"}], - )) + list( + client.chat_completion_stream( + "nvidia/deepseek-v4-flash", + [{"role": "user", "content": "hi"}], + ) + ) # 1 + 3 backoffs == 4 probe attempts before raising. assert len(calls) == 1 + len(LLMClient._STREAM_5XX_BACKOFFS) @@ -382,17 +394,17 @@ def handler(request: httpx.Request) -> httpx.Response: paid_calls = sum(1 for c in calls if c.headers.get("PAYMENT-SIGNATURE")) if paid_calls <= 2: return httpx.Response(503, json={"error": "transient"}) - return httpx.Response( - 200, headers={"content-type": "text/event-stream"}, content=body - ) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, content=body) client = LLMClient(private_key=TEST_PRIVATE_KEY) client._client = httpx.Client(transport=httpx.MockTransport(handler)) - chunks = list(client.chat_completion_stream( - "openai/gpt-5.5", - [{"role": "user", "content": "hi"}], - )) + chunks = list( + client.chat_completion_stream( + "openai/gpt-5.5", + [{"role": "user", "content": "hi"}], + ) + ) # 1 probe (402) + 2 paid-503 + 1 paid-200 == 4 total assert len(calls) == 4 assert any(c.choices[0].delta.content == "paid-OK" for c in chunks) @@ -402,6 +414,7 @@ def handler(request: httpx.Request) -> httpx.Response: # Fallback chain tests # --------------------------------------------------------------------------- + class TestStreamingFallback: """``fallback_models`` walks the chain only on retriable pre-stream errors. Once a chunk is yielded, the upstream is committed.""" @@ -415,6 +428,7 @@ def handler(request: httpx.Request) -> httpx.Response: calls.append(request) body = request.read() import json as _json + payload = _json.loads(body) if payload["model"] == "primary/bad": return httpx.Response(503, json={"error": "down"}) @@ -428,11 +442,13 @@ def handler(request: httpx.Request) -> httpx.Response: client = LLMClient(private_key=TEST_PRIVATE_KEY) client._client = httpx.Client(transport=httpx.MockTransport(handler)) - chunks = list(client.chat_completion_stream( - "primary/bad", - [{"role": "user", "content": "hi"}], - fallback_models=["fallback/good"], - )) + chunks = list( + client.chat_completion_stream( + "primary/bad", + [{"role": "user", "content": "hi"}], + fallback_models=["fallback/good"], + ) + ) # 4 hits on primary (1 + 3 retries) all 503 → swap to fallback → 1 success assert len(calls) >= 5 @@ -471,11 +487,13 @@ def handler(request: httpx.Request) -> httpx.Response: # naturally — no exception, no fallback. The fallback handler should # NEVER be invoked because we got valid chunks before the stream # ended. - chunks = list(client.chat_completion_stream( - "primary/bad", - [{"role": "user", "content": "hi"}], - fallback_models=["fallback/good"], - )) + chunks = list( + client.chat_completion_stream( + "primary/bad", + [{"role": "user", "content": "hi"}], + fallback_models=["fallback/good"], + ) + ) # Exactly one upstream call: no fallback because partial chunks were # already yielded. assert len(calls) == 1 @@ -499,10 +517,236 @@ def handler(request: httpx.Request) -> httpx.Response: from blockrun_llm.types import APIError with pytest.raises(APIError): - list(client.chat_completion_stream( - "primary/bad", - [{"role": "user", "content": "hi"}], - fallback_models=["fallback/good"], - )) + list( + client.chat_completion_stream( + "primary/bad", + [{"role": "user", "content": "hi"}], + fallback_models=["fallback/good"], + ) + ) # Single attempt; no retries (400 isn't 5xx), no fallback (400 isn't retriable). assert len(calls) == 1 + + +# --------------------------------------------------------------------------- +# Streamed tool calls — regression for the archive-loop crash +# --------------------------------------------------------------------------- + + +def _sse_with_tool_call(model: str = "anthropic/claude-haiku-4-5") -> bytes: + """SSE for a streamed tool call: role frame, a name frame, then argument- + fragment frames (id/name absent — these used to fail the strict ToolCall + schema), and a final finish=tool_calls frame with usage.""" + frames = [ + {"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}]}, + { + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_1", + "type": "function", + "function": {"name": "get_weather", "arguments": ""}, + } + ] + }, + "finish_reason": None, + } + ] + }, + { + "choices": [ + { + "index": 0, + "delta": {"tool_calls": [{"index": 0, "function": {"arguments": '{"city":'}}]}, + "finish_reason": None, + } + ] + }, + { + "choices": [ + { + "index": 0, + "delta": {"tool_calls": [{"index": 0, "function": {"arguments": '"Paris"}'}}]}, + "finish_reason": None, + } + ] + }, + { + "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + }, + ] + lines = [] + for f in frames: + f = { + "id": "chatcmpl-tc", + "object": "chat.completion.chunk", + "created": 1700000000, + "model": model, + **f, + } + lines.append("data: " + json.dumps(f)) + lines.append("data: [DONE]") + return ("\n\n".join(lines) + "\n\n").encode("utf-8") + + +def _collect_tool_args(chunks: List[ChatCompletionChunk]) -> str: + out: List[str] = [] + for c in chunks: + if not c.choices: + continue + for tc in c.choices[0].delta.tool_calls or []: + if tc.function and tc.function.arguments: + out.append(tc.function.arguments) + return "".join(out) + + +class TestStreamedToolCalls: + """Streamed tool calls must parse + archive without crashing. + + The argument-fragment frames (id/name absent) used to fail the strict + ToolCall schema, fall back to model_construct (leaving choices as dicts), + then crash the archive loop with "'dict' object has no attribute 'delta'". + The PAID path is used so cost_usd > 0 and the archive loop actually runs. + """ + + def test_sync_streamed_tool_call(self): + calls: List[httpx.Request] = [] + client = LLMClient(private_key=TEST_PRIVATE_KEY) + client._client = httpx.Client( + transport=_make_paid_model_transport(_sse_with_tool_call(), calls) + ) + chunks = list( + client.chat_completion_stream( + "openai/gpt-5.5", + [{"role": "user", "content": "weather?"}], + max_tokens=64, + ) + ) + tool_frames = [c for c in chunks if c.choices and c.choices[0].delta.tool_calls] + assert tool_frames, "expected streamed tool_call deltas" + for c in tool_frames: + assert hasattr(c.choices[0], "delta") # parsed object, not a raw dict + assert _collect_tool_args(chunks) == '{"city":"Paris"}' + finishes = [ + c.choices[0].finish_reason for c in chunks if c.choices and c.choices[0].finish_reason + ] + assert finishes == ["tool_calls"] + + @pytest.mark.asyncio + async def test_async_streamed_tool_call(self): + calls: List[httpx.Request] = [] + client = AsyncLLMClient(private_key=TEST_PRIVATE_KEY) + await client._client.aclose() + client._client = httpx.AsyncClient( + transport=_make_paid_model_transport(_sse_with_tool_call(), calls) + ) + chunks: List[ChatCompletionChunk] = [] + async for chunk in client.chat_completion_stream( + "openai/gpt-5.5", + [{"role": "user", "content": "weather?"}], + ): + chunks.append(chunk) + assert _collect_tool_args(chunks) == '{"city":"Paris"}' + await client.close() + + def test_sync_streamed_tool_call_non_function_type(self): + """A non-"function" tool ``type`` must still parse into a real object + rather than re-trigger the strict-validation -> model_construct fallback + (which would leave choices as raw dicts and crash consumers).""" + frames = [ + { + "id": "chatcmpl-tc", + "object": "chat.completion.chunk", + "created": 1700000000, + "model": "anthropic/claude-haiku-4-5", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_1", + "type": "custom", # non-"function" type + "function": { + "name": "get_weather", + "arguments": '{"city":"Paris"}', + }, + } + ] + }, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-tc", + "object": "chat.completion.chunk", + "created": 1700000000, + "model": "anthropic/claude-haiku-4-5", + "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + }, + ] + sse = ( + "\n\n".join("data: " + json.dumps(f) for f in frames) + "\n\ndata: [DONE]\n\n" + ).encode() + calls: List[httpx.Request] = [] + client = LLMClient(private_key=TEST_PRIVATE_KEY) + client._client = httpx.Client(transport=_make_paid_model_transport(sse, calls)) + chunks = list( + client.chat_completion_stream( + "openai/gpt-5.5", + [{"role": "user", "content": "weather?"}], + max_tokens=64, + ) + ) + tool_frames = [c for c in chunks if c.choices and c.choices[0].delta.tool_calls] + assert tool_frames, "expected the non-'function' tool_call delta to parse" + assert tool_frames[0].choices[0].delta.tool_calls[0].type == "custom" + assert _collect_tool_args(chunks) == '{"city":"Paris"}' + + def test_sync_stream_archive_survives_model_construct_chunk_missing_id(self): + """Archive-loop hardening: a frame that omits the required top-level + ``id`` fails strict validation and is yielded via ``model_construct`` + (no ``id`` attribute). The stream-archiving loop must not crash reading + ``chunk.id`` (old behaviour: ``AttributeError``); draining the generator + runs the paid archive path end to end.""" + frames = [ + # Missing "id" -> model_construct, no .id attribute on the chunk. + { + "object": "chat.completion.chunk", + "created": 1700000000, + "model": "anthropic/claude-haiku-4-5", + "choices": [{"index": 0, "delta": {"content": "hi"}, "finish_reason": None}], + }, + { + "id": "chatcmpl-tc", + "object": "chat.completion.chunk", + "created": 1700000000, + "model": "anthropic/claude-haiku-4-5", + "choices": [{"index": 0, "delta": {"content": " there"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + }, + ] + sse = ( + "\n\n".join("data: " + json.dumps(f) for f in frames) + "\n\ndata: [DONE]\n\n" + ).encode() + calls: List[httpx.Request] = [] + client = LLMClient(private_key=TEST_PRIVATE_KEY) + client._client = httpx.Client(transport=_make_paid_model_transport(sse, calls)) + # Must not raise: the archive loop reads chunk.id via the dict/attr-tolerant + # accessor, so a model_construct'd chunk missing id is skipped, not fatal. + chunks = list( + client.chat_completion_stream( + "openai/gpt-5.5", + [{"role": "user", "content": "hi"}], + max_tokens=64, + ) + ) + assert len(chunks) == 2 # both frames yielded, stream drained cleanly diff --git a/tests/unit/test_streaming_solana.py b/tests/unit/test_streaming_solana.py index afafc73..dafc22f 100644 --- a/tests/unit/test_streaming_solana.py +++ b/tests/unit/test_streaming_solana.py @@ -20,7 +20,7 @@ pytest.importorskip("x402") pytest.importorskip("solders") -from blockrun_llm import ChatCompletionChunk, SolanaLLMClient +from blockrun_llm import SolanaLLMClient from blockrun_llm.types import APIError, PaymentError @@ -28,35 +28,45 @@ # Helpers — synthetic SSE bodies (same shape Base tests use) # --------------------------------------------------------------------------- + def _sse_events(deltas: List[str], finish: str = "stop", model: str = "test/model") -> bytes: lines: List[str] = [] lines.append( - "data: " + json.dumps({ - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1700000000, - "model": model, - "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], - }) + "data: " + + json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1700000000, + "model": model, + "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], + } + ) ) for d in deltas: lines.append( - "data: " + json.dumps({ + "data: " + + json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 1700000000, + "model": model, + "choices": [{"index": 0, "delta": {"content": d}, "finish_reason": None}], + } + ) + ) + lines.append( + "data: " + + json.dumps( + { "id": "chatcmpl-test", "object": "chat.completion.chunk", "created": 1700000000, "model": model, - "choices": [{"index": 0, "delta": {"content": d}, "finish_reason": None}], - }) + "choices": [{"index": 0, "delta": {}, "finish_reason": finish}], + } ) - lines.append( - "data: " + json.dumps({ - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "created": 1700000000, - "model": model, - "choices": [{"index": 0, "delta": {}, "finish_reason": finish}], - }) ) lines.append("data: [DONE]") return ("\n\n".join(lines) + "\n\n").encode("utf-8") @@ -74,8 +84,10 @@ def solana_client(): after construction by replacing the x402_client with a fake.""" import unittest.mock as mock - with mock.patch("blockrun_llm.solana_client.register_exact_svm_client"), \ - mock.patch("blockrun_llm.solana_client._create_signer"): + with ( + mock.patch("blockrun_llm.solana_client.register_exact_svm_client"), + mock.patch("blockrun_llm.solana_client._create_signer"), + ): client = SolanaLLMClient( private_key="bogus_not_used_because_signer_is_patched", api_url="https://sol.blockrun.ai/api", @@ -110,17 +122,18 @@ def _patch_sse_helpers(monkeypatch): # Transport builders # --------------------------------------------------------------------------- + def _free_transport(sse_body: bytes, calls: List[httpx.Request]) -> httpx.MockTransport: def handler(request: httpx.Request) -> httpx.Response: calls.append(request) - return httpx.Response( - 200, headers={"content-type": "text/event-stream"}, content=sse_body - ) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, content=sse_body) + return httpx.MockTransport(handler) def _paid_transport(sse_body: bytes, calls: List[httpx.Request]) -> httpx.MockTransport: """First call → 402; second call (with PAYMENT-SIGNATURE) → 200 SSE.""" + def handler(request: httpx.Request) -> httpx.Response: calls.append(request) if "PAYMENT-SIGNATURE" not in request.headers: @@ -132,9 +145,8 @@ def handler(request: httpx.Request) -> httpx.Response: }, json={"error": "Payment Required"}, ) - return httpx.Response( - 200, headers={"content-type": "text/event-stream"}, content=sse_body - ) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, content=sse_body) + return httpx.MockTransport(handler) @@ -145,9 +157,8 @@ def handler(request: httpx.Request) -> httpx.Response: calls.append(request) if len(calls) <= fail_count: return httpx.Response(status, json={"error": "transient"}) - return httpx.Response( - 200, headers={"content-type": "text/event-stream"}, content=sse_body - ) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, content=sse_body) + return httpx.MockTransport(handler) @@ -163,10 +174,12 @@ def test_free_model_streams_directly(self, solana_client, monkeypatch): transport=_free_transport(_sse_events(["Hello", " world"]), calls) ) - chunks = list(solana_client.chat_completion_stream( - "nvidia/deepseek-v4-flash", - [{"role": "user", "content": "hi"}], - )) + chunks = list( + solana_client.chat_completion_stream( + "nvidia/deepseek-v4-flash", + [{"role": "user", "content": "hi"}], + ) + ) assert len(calls) == 1 assert "PAYMENT-SIGNATURE" not in calls[0].headers @@ -180,10 +193,12 @@ def test_paid_model_signs_and_retries(self, solana_client, monkeypatch): transport=_paid_transport(_sse_events(["Paid"]), calls) ) - chunks = list(solana_client.chat_completion_stream( - "openai/gpt-5.5", - [{"role": "user", "content": "hi"}], - )) + chunks = list( + solana_client.chat_completion_stream( + "openai/gpt-5.5", + [{"role": "user", "content": "hi"}], + ) + ) # 1 probe (402) + 1 paid (200) == 2 total assert len(calls) == 2 assert "PAYMENT-SIGNATURE" not in calls[0].headers @@ -200,10 +215,12 @@ def test_retries_5xx_with_backoff(self, solana_client, monkeypatch): transport=_flaky_transport(_sse_events(["OK"]), fail_count=2, calls=calls) ) - chunks = list(solana_client.chat_completion_stream( - "nvidia/deepseek-v4-flash", - [{"role": "user", "content": "hi"}], - )) + chunks = list( + solana_client.chat_completion_stream( + "nvidia/deepseek-v4-flash", + [{"role": "user", "content": "hi"}], + ) + ) # 2 failed + 1 success assert len(calls) == 3 assert any(c.choices[0].delta.content == "OK" for c in chunks) @@ -219,10 +236,12 @@ def handler(request: httpx.Request) -> httpx.Response: solana_client._client = httpx.Client(transport=httpx.MockTransport(handler)) with pytest.raises(APIError): - list(solana_client.chat_completion_stream( - "nvidia/deepseek-v4-flash", - [{"role": "user", "content": "hi"}], - )) + list( + solana_client.chat_completion_stream( + "nvidia/deepseek-v4-flash", + [{"role": "user", "content": "hi"}], + ) + ) # 1 + 3 backoffs == 4 attempts assert len(calls) == 1 + len(SolanaLLMClient._STREAM_5XX_BACKOFFS) @@ -244,11 +263,13 @@ def handler(request: httpx.Request) -> httpx.Response: solana_client._client = httpx.Client(transport=httpx.MockTransport(handler)) - chunks = list(solana_client.chat_completion_stream( - "primary/bad", - [{"role": "user", "content": "hi"}], - fallback_models=["fallback/good"], - )) + chunks = list( + solana_client.chat_completion_stream( + "primary/bad", + [{"role": "user", "content": "hi"}], + fallback_models=["fallback/good"], + ) + ) # 4 calls to primary/bad all 503, then 1 to fallback/good assert len(calls) >= 5 assert any(c.choices[0].delta.content == "FALLBACK" for c in chunks) @@ -273,7 +294,9 @@ def handler(request: httpx.Request) -> httpx.Response: solana_client._client = httpx.Client(transport=httpx.MockTransport(handler)) with pytest.raises(PaymentError): - list(solana_client.chat_completion_stream( - "openai/gpt-5.5", - [{"role": "user", "content": "hi"}], - )) + list( + solana_client.chat_completion_stream( + "openai/gpt-5.5", + [{"role": "user", "content": "hi"}], + ) + )