feat: add two-layer tool call validation system (proactive + reactive) (#249)

Co-authored-by: sisyphus-dev-ai <sisyphus-dev-ai@users.noreply.github.com>
This commit is contained in:
Sisyphus
2025-12-26 03:36:27 +09:00
committed by GitHub
parent ad2bd673c4
commit 9bc2360d31
9 changed files with 282 additions and 8 deletions

View File

@@ -59,7 +59,8 @@
"agent-usage-reminder",
"non-interactive-env",
"interactive-bash-session",
"empty-message-sanitizer"
"empty-message-sanitizer",
"tool-call-validator"
]
}
},

View File

@@ -64,6 +64,7 @@ export const HookNameSchema = z.enum([
"non-interactive-env",
"interactive-bash-session",
"empty-message-sanitizer",
"tool-call-validator",
])
export const AgentOverrideConfigSchema = z.object({

View File

@@ -21,3 +21,4 @@ export { createKeywordDetectorHook } from "./keyword-detector";
export { createNonInteractiveEnvHook } from "./non-interactive-env";
export { createInteractiveBashSessionHook } from "./interactive-bash-session";
export { createEmptyMessageSanitizerHook } from "./empty-message-sanitizer";
export { createToolCallValidatorHook } from "./tool-call-validator";

View File

@@ -27,6 +27,7 @@ type RecoveryErrorType =
| "tool_result_missing"
| "thinking_block_order"
| "thinking_disabled_violation"
| "tool_not_found"
| null
interface MessageInfo {
@@ -143,6 +144,15 @@ function detectErrorType(error: unknown): RecoveryErrorType {
return "thinking_disabled_violation"
}
if (
message.includes("tool") &&
(message.includes("not found") ||
message.includes("unknown tool") ||
message.includes("invalid tool"))
) {
return "tool_not_found"
}
return null
}
@@ -243,6 +253,53 @@ async function recoverThinkingDisabledViolation(
return anySuccess
}
async function recoverToolNotFound(
client: Client,
sessionID: string,
failedAssistantMsg: MessageData,
error: unknown
): Promise<boolean> {
const errorMsg = getErrorMessage(error)
const toolNameMatch = errorMsg.match(/tool[:\s]+["']?([a-z0-9_-]+)["']?/i)
const toolName = toolNameMatch?.[1] ?? "unknown"
let parts = failedAssistantMsg.parts || []
if (parts.length === 0 && failedAssistantMsg.info?.id) {
const storedParts = readParts(failedAssistantMsg.info.id)
parts = storedParts.map((p) => ({
type: p.type === "tool" ? "tool_use" : p.type,
id: "callID" in p ? (p as { callID?: string }).callID : p.id,
name: "tool" in p ? (p as { tool?: string }).tool : undefined,
}))
}
const invalidToolUse = parts.find(
(p) => p.type === "tool_use" && "name" in p && p.name === toolName
)
if (!invalidToolUse || !("id" in invalidToolUse)) {
return false
}
const toolResultPart = {
type: "tool_result" as const,
tool_use_id: invalidToolUse.id,
content: `Error: Tool '${toolName}' does not exist. The model attempted to use a tool that is not available. This may indicate the model hallucinated the tool name or the tool was recently removed.`,
}
try {
await client.session.prompt({
path: { id: sessionID },
// @ts-expect-error - SDK types may not include tool_result parts
body: { parts: [toolResultPart] },
})
return true
} catch {
return false
}
}
const PLACEHOLDER_TEXT = "[user interrupted]"
async function recoverEmptyContentMessage(
@@ -369,11 +426,13 @@ export function createSessionRecoveryHook(ctx: PluginInput, options?: SessionRec
tool_result_missing: "Tool Crash Recovery",
thinking_block_order: "Thinking Block Recovery",
thinking_disabled_violation: "Thinking Strip Recovery",
tool_not_found: "Invalid Tool Recovery",
}
const toastMessages: Record<RecoveryErrorType & string, string> = {
tool_result_missing: "Injecting cancelled tool results...",
thinking_block_order: "Fixing message structure...",
thinking_disabled_violation: "Stripping thinking blocks...",
tool_not_found: "Handling invalid tool call...",
}
await ctx.client.tui
@@ -405,6 +464,8 @@ export function createSessionRecoveryHook(ctx: PluginInput, options?: SessionRec
const resumeConfig = extractResumeConfig(lastUser, sessionID)
await resumeSession(ctx.client, resumeConfig)
}
} else if (errorType === "tool_not_found") {
success = await recoverToolNotFound(ctx.client, sessionID, failedMsg, info.error)
}
return success

View File

@@ -0,0 +1,5 @@
export const HOOK_NAME = "tool-call-validator"
export const INVALID_TOOL_PLACEHOLDER_PREFIX = "[Invalid tool call: "
export const INVALID_TOOL_PLACEHOLDER_SUFFIX = " - tool does not exist]"
export const PAIRED_RESULT_PLACEHOLDER =
"[Tool result removed - corresponding tool call was invalid]"

View File

@@ -0,0 +1,98 @@
import type { Part } from "@opencode-ai/sdk"
import type { ToolRegistry } from "../../shared/tool-registry"
import type { MessageWithParts, ToolUsePart, ToolResultPart, TextPart } from "./types"
import {
INVALID_TOOL_PLACEHOLDER_PREFIX,
INVALID_TOOL_PLACEHOLDER_SUFFIX,
PAIRED_RESULT_PLACEHOLDER,
} from "./constants"
type MessagesTransformHook = {
"experimental.chat.messages.transform"?: (
input: Record<string, never>,
output: { messages: MessageWithParts[] }
) => Promise<void>
}
function isToolUsePart(part: Part): part is ToolUsePart {
const type = part.type as string
return type === "tool_use" || type === "tool"
}
function isToolResultPart(part: Part): part is ToolResultPart {
const type = part.type as string
return type === "tool_result"
}
function getToolName(part: ToolUsePart): string | undefined {
return part.name || part.tool
}
function getToolUseIdFromResult(part: Part): string | undefined {
return (part as { tool_use_id?: string }).tool_use_id
}
function createInvalidToolPlaceholder(toolName: string): TextPart {
return {
type: "text",
text: `${INVALID_TOOL_PLACEHOLDER_PREFIX}${toolName}${INVALID_TOOL_PLACEHOLDER_SUFFIX}`,
synthetic: true,
} as TextPart
}
function createPairedResultPlaceholder(): TextPart {
return {
type: "text",
text: PAIRED_RESULT_PLACEHOLDER,
synthetic: true,
} as TextPart
}
export function createToolCallValidatorHook(
registry: ToolRegistry
): MessagesTransformHook {
return {
"experimental.chat.messages.transform": async (_input, output) => {
const { messages } = output
if (!messages || messages.length === 0) {
return
}
const invalidatedToolUseIds = new Set<string>()
for (const message of messages) {
if (message.info.role === "user") {
continue
}
const parts = message.parts
if (!parts || parts.length === 0) {
continue
}
for (let i = 0; i < parts.length; i++) {
const part = parts[i]
if (isToolUsePart(part)) {
const toolName = getToolName(part)
if (toolName && !registry.isValidTool(toolName)) {
if (part.id) {
invalidatedToolUseIds.add(part.id)
}
parts[i] = createInvalidToolPlaceholder(toolName)
}
} else if (isToolResultPart(part)) {
const toolUseId = getToolUseIdFromResult(part)
if (toolUseId && invalidatedToolUseIds.has(toolUseId)) {
parts[i] = createPairedResultPlaceholder()
}
}
}
}
},
}
}

View File

@@ -0,0 +1,28 @@
import type { Message, Part } from "@opencode-ai/sdk"
export interface MessageWithParts {
info: Message
parts: Part[]
}
export type ToolUsePart = Part & {
type: "tool_use" | "tool"
id: string
name?: string
tool?: string
input?: Record<string, unknown>
}
export type ToolResultPart = Part & {
type: "tool_result"
id?: string
tool_use_id?: string
content?: unknown
}
export type TextPart = Part & {
type: "text"
id?: string
text: string
synthetic?: boolean
}

View File

@@ -23,6 +23,7 @@ import {
createNonInteractiveEnvHook,
createInteractiveBashSessionHook,
createEmptyMessageSanitizerHook,
createToolCallValidatorHook,
} from "./hooks";
import { createGoogleAntigravityAuthPlugin } from "./auth/antigravity";
import {
@@ -49,6 +50,7 @@ import { BackgroundManager } from "./features/background-agent";
import { createBuiltinMcps } from "./mcp";
import { OhMyOpenCodeConfigSchema, type OhMyOpenCodeConfig, type HookName } from "./config";
import { log, deepMerge, getUserConfigDir, addConfigLoadError } from "./shared";
import { createToolRegistry } from "./shared/tool-registry";
import { PLAN_SYSTEM_PROMPT, PLAN_PERMISSION } from "./agents/plan-prompt";
import * as fs from "fs";
import * as path from "path";
@@ -345,16 +347,28 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
const tmuxAvailable = await getTmuxPath();
const allTools = {
...builtinTools,
...backgroundTools,
call_omo_agent: callOmoAgent,
look_at: lookAt,
...(tmuxAvailable ? { interactive_bash } : {}),
};
const toolRegistry = createToolRegistry(
allTools,
{},
{}
);
const toolCallValidator = isHookEnabled("tool-call-validator")
? createToolCallValidatorHook(toolRegistry)
: null;
return {
...(googleAuthHooks ? { auth: googleAuthHooks.auth } : {}),
tool: {
...builtinTools,
...backgroundTools,
call_omo_agent: callOmoAgent,
look_at: lookAt,
...(tmuxAvailable ? { interactive_bash } : {}),
},
tool: allTools,
"chat.message": async (input, output) => {
await claudeCodeHooks["chat.message"]?.(input, output);
@@ -367,6 +381,8 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
await emptyMessageSanitizer?.["experimental.chat.messages.transform"]?.(input, output as any);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
await toolCallValidator?.["experimental.chat.messages.transform"]?.(input, output as any);
},
config: async (config) => {

View File

@@ -0,0 +1,63 @@
export interface ToolRegistry {
isValidTool(name: string): boolean
getAllToolNames(): string[]
}
interface MCPServerConfig {
tools?: Array<{ name: string }>
[key: string]: unknown
}
/**
* Create tool registry for validation
* MCP tools use prefix matching: "serverName_methodName"
*/
export function createToolRegistry(
builtinTools: Record<string, unknown>,
dynamicTools: Record<string, unknown>,
mcpServers: Record<string, MCPServerConfig>
): ToolRegistry {
const toolNames = new Set<string>([
...Object.keys(builtinTools),
...Object.keys(dynamicTools),
])
const mcpPrefixes = new Set<string>()
for (const serverName of Object.keys(mcpServers)) {
mcpPrefixes.add(`${serverName}_`)
const mcpTools = mcpServers[serverName]?.tools
if (mcpTools && Array.isArray(mcpTools)) {
for (const tool of mcpTools) {
if (tool.name) {
toolNames.add(tool.name)
}
}
}
}
return {
isValidTool(name: string): boolean {
if (!name || typeof name !== "string" || name.trim() === "") {
return false
}
if (toolNames.has(name)) {
return true
}
for (const prefix of mcpPrefixes) {
if (name.startsWith(prefix)) {
return true
}
}
return false
},
getAllToolNames(): string[] {
return Array.from(toolNames).sort()
},
}
}