diff --git a/.changeset/recover-invalid-tool-input.md b/.changeset/recover-invalid-tool-input.md new file mode 100644 index 0000000000..e48258a9b1 --- /dev/null +++ b/.changeset/recover-invalid-tool-input.md @@ -0,0 +1,5 @@ +--- +'@workflow/ai': patch +--- + +DurableAgent now recovers from invalid tool-call input by returning the validation error to the model instead of aborting the stream. diff --git a/packages/ai/src/agent/durable-agent.test.ts b/packages/ai/src/agent/durable-agent.test.ts index f78346ade2..5e7040c920 100644 --- a/packages/ai/src/agent/durable-agent.test.ts +++ b/packages/ai/src/agent/durable-agent.test.ts @@ -2006,6 +2006,160 @@ describe('DurableAgent', () => { }); }); + it('should convert invalid tool input to error-text result instead of failing stream', async () => { + const tools: ToolSet = { + strictTool: { + description: 'A tool with a strict input schema', + inputSchema: z.object({ requiredField: z.string().min(1) }), + execute: async () => ({ ok: true }), + }, + }; + + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const mockMessages: LanguageModelV3Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + ]; + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'strictTool', + // Valid JSON, but violates the schema (empty string fails .min(1)). + input: '{"requiredField":""}', + } as LanguageModelV3ToolCall, + ], + messages: mockMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + // Invalid tool input should be handled gracefully, not reject the stream. + await expect( + agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + }) + ).resolves.not.toThrow(); + + // Verify the validation error was sent back as an error-text tool result + // (so the model can correct its arguments and retry). + expect(mockIterator.next).toHaveBeenCalledTimes(2); + const toolResultsCall = mockIterator.next.mock.calls[1][0]; + expect(toolResultsCall).toBeDefined(); + expect(toolResultsCall).toHaveLength(1); + expect(toolResultsCall[0]).toMatchObject({ + type: 'tool-result', + toolCallId: 'test-call-id', + toolName: 'strictTool', + output: { + type: 'error-text', + }, + }); + expect(toolResultsCall[0].output.value).toContain( + 'Invalid input for tool "strictTool"' + ); + }); + + it('should recover from invalid tool input and execute the corrected retry', async () => { + const execute = vi.fn(async () => ({ ok: true })); + const tools: ToolSet = { + strictTool: { + description: 'A tool with a strict input schema', + inputSchema: z.object({ requiredField: z.string().min(1) }), + execute, + }, + }; + + const mockModel = createMockModel(); + const agent = new DurableAgent({ model: async () => mockModel, tools }); + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + const mockMessages: LanguageModelV3Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + ]; + + const makeToolCall = (input: string): LanguageModelV3ToolCall => ({ + toolCallId: 'test-call-id', + toolName: 'strictTool', + input, + }); + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + // Step 1: model emits invalid args (empty string fails .min(1)). + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [makeToolCall('{"requiredField":""}')], + messages: mockMessages, + }, + }) + // Step 2: model corrects the args after seeing the error-text result. + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [makeToolCall('{"requiredField":"ok"}')], + messages: mockMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + await expect( + agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + }) + ).resolves.not.toThrow(); + + // The tool must NOT run on the invalid call, and MUST run exactly once with + // the corrected input — proving the agent productively recovers, not just + // that the error was fed back. + expect(execute).toHaveBeenCalledTimes(1); + expect(execute.mock.calls[0][0]).toEqual({ requiredField: 'ok' }); + + // First turn fed back an error-text result; second turn produced a success. + expect(mockIterator.next).toHaveBeenCalledTimes(3); + const firstResults = mockIterator.next.mock.calls[1][0]; + expect(firstResults[0].output.type).toBe('error-text'); + const secondResults = mockIterator.next.mock.calls[2][0]; + expect(secondResults[0]).toMatchObject({ + type: 'tool-result', + toolCallId: 'test-call-id', + toolName: 'strictTool', + output: { type: 'json', value: { ok: true } }, + }); + }); + it('should call onFinish with steps and messages when streaming completes', async () => { const mockModel = createMockModel(); diff --git a/packages/ai/src/agent/durable-agent.ts b/packages/ai/src/agent/durable-agent.ts index a4f74ad08b..5b9dc33f69 100644 --- a/packages/ai/src/agent/durable-agent.ts +++ b/packages/ai/src/agent/durable-agent.ts @@ -1742,7 +1742,45 @@ async function executeTool( ); } } - throw parseError; + // Input that fails to parse or validate (even after repair) is recoverable, + // exactly like a tool execution error below: feed the error back to the model + // as an error-text result so the agent can correct the call and retry, instead + // of aborting the entire stream. This aligns with AI SDK's streamText behavior + // for tool failures. Reaches here both for malformed JSON and for the + // re-thrown "Invalid input for tool ..." schema-validation error above. + // + // This path intentionally does not reach `onError` (it no longer throws), + // matching the tool-execution-error path below. Emit an `ai.toolCall` span + // recording the failure so the recovered error stays observable in traces. + const parseErrorMessage = getErrorMessage(parseError); + return recordSpan({ + name: 'ai.toolCall', + telemetry, + attributes: { + 'ai.toolCall.name': toolCall.toolName, + 'ai.toolCall.id': toolCall.toolCallId, + ...(telemetry?.recordOutputs !== false && { + 'ai.toolCall.args': toolCall.input, + }), + }, + fn: (span) => { + if (span) { + // 2 === OTel SpanStatusCode.ERROR (inlined to avoid a hard dependency + // on the optional @opentelemetry/api package). + span.setStatus({ code: 2, message: parseErrorMessage }); + span.setAttributes({ 'ai.toolCall.error': parseErrorMessage }); + } + return { + type: 'tool-result' as const, + toolCallId: toolCall.toolCallId, + toolName: toolCall.toolName, + output: { + type: 'error-text' as const, + value: parseErrorMessage, + }, + }; + }, + }); } return recordSpan({