From 9bc2360d31a195e3af27f279316a1840bafe48cc Mon Sep 17 00:00:00 2001 From: Sisyphus Date: Fri, 26 Dec 2025 03:36:27 +0900 Subject: [PATCH] feat: add two-layer tool call validation system (proactive + reactive) (#249) Co-authored-by: sisyphus-dev-ai --- assets/oh-my-opencode.schema.json | 3 +- src/config/schema.ts | 1 + src/hooks/index.ts | 1 + src/hooks/session-recovery/index.ts | 61 ++++++++++++++ src/hooks/tool-call-validator/constants.ts | 5 ++ src/hooks/tool-call-validator/index.ts | 98 ++++++++++++++++++++++ src/hooks/tool-call-validator/types.ts | 28 +++++++ src/index.ts | 30 +++++-- src/shared/tool-registry.ts | 63 ++++++++++++++ 9 files changed, 282 insertions(+), 8 deletions(-) create mode 100644 src/hooks/tool-call-validator/constants.ts create mode 100644 src/hooks/tool-call-validator/index.ts create mode 100644 src/hooks/tool-call-validator/types.ts create mode 100644 src/shared/tool-registry.ts diff --git a/assets/oh-my-opencode.schema.json b/assets/oh-my-opencode.schema.json index a1948fa..ceb2096 100644 --- a/assets/oh-my-opencode.schema.json +++ b/assets/oh-my-opencode.schema.json @@ -59,7 +59,8 @@ "agent-usage-reminder", "non-interactive-env", "interactive-bash-session", - "empty-message-sanitizer" + "empty-message-sanitizer", + "tool-call-validator" ] } }, diff --git a/src/config/schema.ts b/src/config/schema.ts index a5d428e..6e7f3c5 100644 --- a/src/config/schema.ts +++ b/src/config/schema.ts @@ -64,6 +64,7 @@ export const HookNameSchema = z.enum([ "non-interactive-env", "interactive-bash-session", "empty-message-sanitizer", + "tool-call-validator", ]) export const AgentOverrideConfigSchema = z.object({ diff --git a/src/hooks/index.ts b/src/hooks/index.ts index 9b7bf86..95e819f 100644 --- a/src/hooks/index.ts +++ b/src/hooks/index.ts @@ -21,3 +21,4 @@ export { createKeywordDetectorHook } from "./keyword-detector"; export { createNonInteractiveEnvHook } from "./non-interactive-env"; export { createInteractiveBashSessionHook } from "./interactive-bash-session"; export { createEmptyMessageSanitizerHook } from "./empty-message-sanitizer"; +export { createToolCallValidatorHook } from "./tool-call-validator"; diff --git a/src/hooks/session-recovery/index.ts b/src/hooks/session-recovery/index.ts index 89fb460..ba4e90e 100644 --- a/src/hooks/session-recovery/index.ts +++ b/src/hooks/session-recovery/index.ts @@ -27,6 +27,7 @@ type RecoveryErrorType = | "tool_result_missing" | "thinking_block_order" | "thinking_disabled_violation" + | "tool_not_found" | null interface MessageInfo { @@ -143,6 +144,15 @@ function detectErrorType(error: unknown): RecoveryErrorType { return "thinking_disabled_violation" } + if ( + message.includes("tool") && + (message.includes("not found") || + message.includes("unknown tool") || + message.includes("invalid tool")) + ) { + return "tool_not_found" + } + return null } @@ -243,6 +253,53 @@ async function recoverThinkingDisabledViolation( return anySuccess } +async function recoverToolNotFound( + client: Client, + sessionID: string, + failedAssistantMsg: MessageData, + error: unknown +): Promise { + const errorMsg = getErrorMessage(error) + const toolNameMatch = errorMsg.match(/tool[:\s]+["']?([a-z0-9_-]+)["']?/i) + const toolName = toolNameMatch?.[1] ?? "unknown" + + let parts = failedAssistantMsg.parts || [] + if (parts.length === 0 && failedAssistantMsg.info?.id) { + const storedParts = readParts(failedAssistantMsg.info.id) + parts = storedParts.map((p) => ({ + type: p.type === "tool" ? "tool_use" : p.type, + id: "callID" in p ? (p as { callID?: string }).callID : p.id, + name: "tool" in p ? (p as { tool?: string }).tool : undefined, + })) + } + + const invalidToolUse = parts.find( + (p) => p.type === "tool_use" && "name" in p && p.name === toolName + ) + + if (!invalidToolUse || !("id" in invalidToolUse)) { + return false + } + + const toolResultPart = { + type: "tool_result" as const, + tool_use_id: invalidToolUse.id, + content: `Error: Tool '${toolName}' does not exist. The model attempted to use a tool that is not available. This may indicate the model hallucinated the tool name or the tool was recently removed.`, + } + + try { + await client.session.prompt({ + path: { id: sessionID }, + // @ts-expect-error - SDK types may not include tool_result parts + body: { parts: [toolResultPart] }, + }) + + return true + } catch { + return false + } +} + const PLACEHOLDER_TEXT = "[user interrupted]" async function recoverEmptyContentMessage( @@ -369,11 +426,13 @@ export function createSessionRecoveryHook(ctx: PluginInput, options?: SessionRec tool_result_missing: "Tool Crash Recovery", thinking_block_order: "Thinking Block Recovery", thinking_disabled_violation: "Thinking Strip Recovery", + tool_not_found: "Invalid Tool Recovery", } const toastMessages: Record = { tool_result_missing: "Injecting cancelled tool results...", thinking_block_order: "Fixing message structure...", thinking_disabled_violation: "Stripping thinking blocks...", + tool_not_found: "Handling invalid tool call...", } await ctx.client.tui @@ -405,6 +464,8 @@ export function createSessionRecoveryHook(ctx: PluginInput, options?: SessionRec const resumeConfig = extractResumeConfig(lastUser, sessionID) await resumeSession(ctx.client, resumeConfig) } + } else if (errorType === "tool_not_found") { + success = await recoverToolNotFound(ctx.client, sessionID, failedMsg, info.error) } return success diff --git a/src/hooks/tool-call-validator/constants.ts b/src/hooks/tool-call-validator/constants.ts new file mode 100644 index 0000000..8adc8ea --- /dev/null +++ b/src/hooks/tool-call-validator/constants.ts @@ -0,0 +1,5 @@ +export const HOOK_NAME = "tool-call-validator" +export const INVALID_TOOL_PLACEHOLDER_PREFIX = "[Invalid tool call: " +export const INVALID_TOOL_PLACEHOLDER_SUFFIX = " - tool does not exist]" +export const PAIRED_RESULT_PLACEHOLDER = + "[Tool result removed - corresponding tool call was invalid]" diff --git a/src/hooks/tool-call-validator/index.ts b/src/hooks/tool-call-validator/index.ts new file mode 100644 index 0000000..596d924 --- /dev/null +++ b/src/hooks/tool-call-validator/index.ts @@ -0,0 +1,98 @@ +import type { Part } from "@opencode-ai/sdk" +import type { ToolRegistry } from "../../shared/tool-registry" +import type { MessageWithParts, ToolUsePart, ToolResultPart, TextPart } from "./types" +import { + INVALID_TOOL_PLACEHOLDER_PREFIX, + INVALID_TOOL_PLACEHOLDER_SUFFIX, + PAIRED_RESULT_PLACEHOLDER, +} from "./constants" + +type MessagesTransformHook = { + "experimental.chat.messages.transform"?: ( + input: Record, + output: { messages: MessageWithParts[] } + ) => Promise +} + +function isToolUsePart(part: Part): part is ToolUsePart { + const type = part.type as string + return type === "tool_use" || type === "tool" +} + +function isToolResultPart(part: Part): part is ToolResultPart { + const type = part.type as string + return type === "tool_result" +} + +function getToolName(part: ToolUsePart): string | undefined { + return part.name || part.tool +} + +function getToolUseIdFromResult(part: Part): string | undefined { + return (part as { tool_use_id?: string }).tool_use_id +} + +function createInvalidToolPlaceholder(toolName: string): TextPart { + return { + type: "text", + text: `${INVALID_TOOL_PLACEHOLDER_PREFIX}${toolName}${INVALID_TOOL_PLACEHOLDER_SUFFIX}`, + synthetic: true, + } as TextPart +} + +function createPairedResultPlaceholder(): TextPart { + return { + type: "text", + text: PAIRED_RESULT_PLACEHOLDER, + synthetic: true, + } as TextPart +} + +export function createToolCallValidatorHook( + registry: ToolRegistry +): MessagesTransformHook { + return { + "experimental.chat.messages.transform": async (_input, output) => { + const { messages } = output + + if (!messages || messages.length === 0) { + return + } + + const invalidatedToolUseIds = new Set() + + for (const message of messages) { + if (message.info.role === "user") { + continue + } + + const parts = message.parts + if (!parts || parts.length === 0) { + continue + } + + for (let i = 0; i < parts.length; i++) { + const part = parts[i] + + if (isToolUsePart(part)) { + const toolName = getToolName(part) + + if (toolName && !registry.isValidTool(toolName)) { + if (part.id) { + invalidatedToolUseIds.add(part.id) + } + + parts[i] = createInvalidToolPlaceholder(toolName) + } + } else if (isToolResultPart(part)) { + const toolUseId = getToolUseIdFromResult(part) + + if (toolUseId && invalidatedToolUseIds.has(toolUseId)) { + parts[i] = createPairedResultPlaceholder() + } + } + } + } + }, + } +} diff --git a/src/hooks/tool-call-validator/types.ts b/src/hooks/tool-call-validator/types.ts new file mode 100644 index 0000000..682958b --- /dev/null +++ b/src/hooks/tool-call-validator/types.ts @@ -0,0 +1,28 @@ +import type { Message, Part } from "@opencode-ai/sdk" + +export interface MessageWithParts { + info: Message + parts: Part[] +} + +export type ToolUsePart = Part & { + type: "tool_use" | "tool" + id: string + name?: string + tool?: string + input?: Record +} + +export type ToolResultPart = Part & { + type: "tool_result" + id?: string + tool_use_id?: string + content?: unknown +} + +export type TextPart = Part & { + type: "text" + id?: string + text: string + synthetic?: boolean +} diff --git a/src/index.ts b/src/index.ts index d31a498..3dab1b1 100644 --- a/src/index.ts +++ b/src/index.ts @@ -23,6 +23,7 @@ import { createNonInteractiveEnvHook, createInteractiveBashSessionHook, createEmptyMessageSanitizerHook, + createToolCallValidatorHook, } from "./hooks"; import { createGoogleAntigravityAuthPlugin } from "./auth/antigravity"; import { @@ -49,6 +50,7 @@ import { BackgroundManager } from "./features/background-agent"; import { createBuiltinMcps } from "./mcp"; import { OhMyOpenCodeConfigSchema, type OhMyOpenCodeConfig, type HookName } from "./config"; import { log, deepMerge, getUserConfigDir, addConfigLoadError } from "./shared"; +import { createToolRegistry } from "./shared/tool-registry"; import { PLAN_SYSTEM_PROMPT, PLAN_PERMISSION } from "./agents/plan-prompt"; import * as fs from "fs"; import * as path from "path"; @@ -345,16 +347,28 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => { const tmuxAvailable = await getTmuxPath(); + const allTools = { + ...builtinTools, + ...backgroundTools, + call_omo_agent: callOmoAgent, + look_at: lookAt, + ...(tmuxAvailable ? { interactive_bash } : {}), + }; + + const toolRegistry = createToolRegistry( + allTools, + {}, + {} + ); + + const toolCallValidator = isHookEnabled("tool-call-validator") + ? createToolCallValidatorHook(toolRegistry) + : null; + return { ...(googleAuthHooks ? { auth: googleAuthHooks.auth } : {}), - tool: { - ...builtinTools, - ...backgroundTools, - call_omo_agent: callOmoAgent, - look_at: lookAt, - ...(tmuxAvailable ? { interactive_bash } : {}), - }, + tool: allTools, "chat.message": async (input, output) => { await claudeCodeHooks["chat.message"]?.(input, output); @@ -367,6 +381,8 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => { ) => { // eslint-disable-next-line @typescript-eslint/no-explicit-any await emptyMessageSanitizer?.["experimental.chat.messages.transform"]?.(input, output as any); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await toolCallValidator?.["experimental.chat.messages.transform"]?.(input, output as any); }, config: async (config) => { diff --git a/src/shared/tool-registry.ts b/src/shared/tool-registry.ts new file mode 100644 index 0000000..87f27b8 --- /dev/null +++ b/src/shared/tool-registry.ts @@ -0,0 +1,63 @@ +export interface ToolRegistry { + isValidTool(name: string): boolean + getAllToolNames(): string[] +} + +interface MCPServerConfig { + tools?: Array<{ name: string }> + [key: string]: unknown +} + +/** + * Create tool registry for validation + * MCP tools use prefix matching: "serverName_methodName" + */ +export function createToolRegistry( + builtinTools: Record, + dynamicTools: Record, + mcpServers: Record +): ToolRegistry { + const toolNames = new Set([ + ...Object.keys(builtinTools), + ...Object.keys(dynamicTools), + ]) + + const mcpPrefixes = new Set() + + for (const serverName of Object.keys(mcpServers)) { + mcpPrefixes.add(`${serverName}_`) + + const mcpTools = mcpServers[serverName]?.tools + if (mcpTools && Array.isArray(mcpTools)) { + for (const tool of mcpTools) { + if (tool.name) { + toolNames.add(tool.name) + } + } + } + } + + return { + isValidTool(name: string): boolean { + if (!name || typeof name !== "string" || name.trim() === "") { + return false + } + + if (toolNames.has(name)) { + return true + } + + for (const prefix of mcpPrefixes) { + if (name.startsWith(prefix)) { + return true + } + } + + return false + }, + + getAllToolNames(): string[] { + return Array.from(toolNames).sort() + }, + } +}