Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/recover-invalid-tool-input.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@workflow/ai': patch
---

DurableAgent now recovers from invalid tool-call input by returning the validation error to the model instead of aborting the stream.
154 changes: 154 additions & 0 deletions packages/ai/src/agent/durable-agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2006,6 +2006,160 @@ describe('DurableAgent', () => {
});
});

it('should convert invalid tool input to error-text result instead of failing stream', async () => {
  // A tool whose schema rejects empty strings: the mocked call below is
  // well-formed JSON that nevertheless fails schema validation.
  const strictTools: ToolSet = {
    strictTool: {
      description: 'A tool with a strict input schema',
      inputSchema: z.object({ requiredField: z.string().min(1) }),
      execute: async () => ({ ok: true }),
    },
  };

  const model = createMockModel();
  const agent = new DurableAgent({
    model: async () => model,
    tools: strictTools,
  });

  const sink = new WritableStream({ write: vi.fn(), close: vi.fn() });

  const prompt: LanguageModelV3Prompt = [
    { role: 'user', content: [{ type: 'text', text: 'test' }] },
  ];

  const { streamTextIterator } = await import('./stream-text-iterator.js');

  // Valid JSON, but violates the schema (empty string fails .min(1)).
  const invalidCall = {
    toolCallId: 'test-call-id',
    toolName: 'strictTool',
    input: '{"requiredField":""}',
  } as LanguageModelV3ToolCall;

  // First step yields the invalid tool call; the second step ends the run.
  const nextStep = vi
    .fn()
    .mockResolvedValueOnce({
      done: false,
      value: { toolCalls: [invalidCall], messages: prompt },
    })
    .mockResolvedValueOnce({ done: true, value: [] });
  const iterator = { next: nextStep };
  vi.mocked(streamTextIterator).mockReturnValue(
    iterator as unknown as MockIterator
  );

  // The stream must settle successfully even though the tool input is invalid.
  await expect(
    agent.stream({
      messages: [{ role: 'user', content: 'test' }],
      writable: sink,
    })
  ).resolves.not.toThrow();

  // The second `next` call receives the tool results produced for step one;
  // it must carry the validation failure as an error-text tool result so the
  // model can correct its arguments and retry.
  expect(nextStep).toHaveBeenCalledTimes(2);
  const fedBack = nextStep.mock.calls[1][0];
  expect(fedBack).toBeDefined();
  expect(fedBack).toHaveLength(1);

  const errorResult = fedBack[0];
  expect(errorResult).toMatchObject({
    type: 'tool-result',
    toolCallId: 'test-call-id',
    toolName: 'strictTool',
    output: { type: 'error-text' },
  });
  expect(errorResult.output.value).toContain(
    'Invalid input for tool "strictTool"'
  );
});

it('should recover from invalid tool input and execute the corrected retry', async () => {
  // Spy executor so the test can observe exactly which inputs reach the tool.
  const runTool = vi.fn(async () => ({ ok: true }));
  const strictTools: ToolSet = {
    strictTool: {
      description: 'A tool with a strict input schema',
      inputSchema: z.object({ requiredField: z.string().min(1) }),
      execute: runTool,
    },
  };

  const model = createMockModel();
  const agent = new DurableAgent({
    model: async () => model,
    tools: strictTools,
  });
  const sink = new WritableStream({ write: vi.fn(), close: vi.fn() });
  const prompt: LanguageModelV3Prompt = [
    { role: 'user', content: [{ type: 'text', text: 'test' }] },
  ];

  // Builds a tool call for `strictTool` carrying the given raw JSON input.
  const callWith = (input: string): LanguageModelV3ToolCall => ({
    toolCallId: 'test-call-id',
    toolName: 'strictTool',
    input,
  });

  // Builds a non-terminal iterator step that emits a single tool call.
  const stepWith = (input: string) => ({
    done: false,
    value: { toolCalls: [callWith(input)], messages: prompt },
  });

  const { streamTextIterator } = await import('./stream-text-iterator.js');
  const nextStep = vi
    .fn()
    // Step 1: model emits invalid args (empty string fails .min(1)).
    .mockResolvedValueOnce(stepWith('{"requiredField":""}'))
    // Step 2: model corrects the args after seeing the error-text result.
    .mockResolvedValueOnce(stepWith('{"requiredField":"ok"}'))
    // Step 3: iteration finishes.
    .mockResolvedValueOnce({ done: true, value: [] });
  const iterator = { next: nextStep };
  vi.mocked(streamTextIterator).mockReturnValue(
    iterator as unknown as MockIterator
  );

  await expect(
    agent.stream({
      messages: [{ role: 'user', content: 'test' }],
      writable: sink,
    })
  ).resolves.not.toThrow();

  // The executor is skipped for the invalid call and runs exactly once, with
  // the corrected input. That shows a real recovery rather than merely an
  // error being reported back.
  expect(runTool).toHaveBeenCalledTimes(1);
  expect(runTool.mock.calls[0][0]).toEqual({ requiredField: 'ok' });

  // Results handed to the iterator: an error-text result after the first
  // turn, then a successful JSON result after the second.
  expect(nextStep).toHaveBeenCalledTimes(3);
  const resultsAfterInvalid = nextStep.mock.calls[1][0];
  expect(resultsAfterInvalid[0].output.type).toBe('error-text');

  const resultsAfterRetry = nextStep.mock.calls[2][0];
  expect(resultsAfterRetry[0]).toMatchObject({
    type: 'tool-result',
    toolCallId: 'test-call-id',
    toolName: 'strictTool',
    output: { type: 'json', value: { ok: true } },
  });
});

it('should call onFinish with steps and messages when streaming completes', async () => {
const mockModel = createMockModel();

Expand Down
40 changes: 39 additions & 1 deletion packages/ai/src/agent/durable-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,45 @@ async function executeTool(
);
}
}
throw parseError;
// Input that fails to parse or validate (even after repair) is recoverable,
// exactly like a tool execution error below: feed the error back to the model
// as an error-text result so the agent can correct the call and retry, instead
// of aborting the entire stream. This aligns with AI SDK's streamText behavior
// for tool failures. Reaches here both for malformed JSON and for the
// re-thrown "Invalid input for tool ..." schema-validation error above.
//
// This path intentionally does not reach `onError` (it no longer throws),
// matching the tool-execution-error path below. Emit an `ai.toolCall` span
// recording the failure so the recovered error stays observable in traces.
const parseErrorMessage = getErrorMessage(parseError);
return recordSpan({
name: 'ai.toolCall',
telemetry,
attributes: {
'ai.toolCall.name': toolCall.toolName,
'ai.toolCall.id': toolCall.toolCallId,
...(telemetry?.recordOutputs !== false && {
'ai.toolCall.args': toolCall.input,
}),
},
fn: (span) => {
if (span) {
// 2 === OTel SpanStatusCode.ERROR (inlined to avoid a hard dependency
// on the optional @opentelemetry/api package).
span.setStatus({ code: 2, message: parseErrorMessage });
span.setAttributes({ 'ai.toolCall.error': parseErrorMessage });
}
return {
type: 'tool-result' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
output: {
type: 'error-text' as const,
value: parseErrorMessage,
},
};
},
});
}

return recordSpan({
Expand Down
Loading