feat: add two-layer tool call validation system (proactive + reactive) (#249)
Co-authored-by: sisyphus-dev-ai <sisyphus-dev-ai@users.noreply.github.com>
This commit is contained in:
@@ -59,7 +59,8 @@
|
||||
"agent-usage-reminder",
|
||||
"non-interactive-env",
|
||||
"interactive-bash-session",
|
||||
"empty-message-sanitizer"
|
||||
"empty-message-sanitizer",
|
||||
"tool-call-validator"
|
||||
]
|
||||
}
|
||||
},
|
||||
|
||||
@@ -64,6 +64,7 @@ export const HookNameSchema = z.enum([
|
||||
"non-interactive-env",
|
||||
"interactive-bash-session",
|
||||
"empty-message-sanitizer",
|
||||
"tool-call-validator",
|
||||
])
|
||||
|
||||
export const AgentOverrideConfigSchema = z.object({
|
||||
|
||||
@@ -21,3 +21,4 @@ export { createKeywordDetectorHook } from "./keyword-detector";
|
||||
export { createNonInteractiveEnvHook } from "./non-interactive-env";
|
||||
export { createInteractiveBashSessionHook } from "./interactive-bash-session";
|
||||
export { createEmptyMessageSanitizerHook } from "./empty-message-sanitizer";
|
||||
export { createToolCallValidatorHook } from "./tool-call-validator";
|
||||
|
||||
@@ -27,6 +27,7 @@ type RecoveryErrorType =
|
||||
| "tool_result_missing"
|
||||
| "thinking_block_order"
|
||||
| "thinking_disabled_violation"
|
||||
| "tool_not_found"
|
||||
| null
|
||||
|
||||
interface MessageInfo {
|
||||
@@ -143,6 +144,15 @@ function detectErrorType(error: unknown): RecoveryErrorType {
|
||||
return "thinking_disabled_violation"
|
||||
}
|
||||
|
||||
if (
|
||||
message.includes("tool") &&
|
||||
(message.includes("not found") ||
|
||||
message.includes("unknown tool") ||
|
||||
message.includes("invalid tool"))
|
||||
) {
|
||||
return "tool_not_found"
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
@@ -243,6 +253,53 @@ async function recoverThinkingDisabledViolation(
|
||||
return anySuccess
|
||||
}
|
||||
|
||||
async function recoverToolNotFound(
|
||||
client: Client,
|
||||
sessionID: string,
|
||||
failedAssistantMsg: MessageData,
|
||||
error: unknown
|
||||
): Promise<boolean> {
|
||||
const errorMsg = getErrorMessage(error)
|
||||
const toolNameMatch = errorMsg.match(/tool[:\s]+["']?([a-z0-9_-]+)["']?/i)
|
||||
const toolName = toolNameMatch?.[1] ?? "unknown"
|
||||
|
||||
let parts = failedAssistantMsg.parts || []
|
||||
if (parts.length === 0 && failedAssistantMsg.info?.id) {
|
||||
const storedParts = readParts(failedAssistantMsg.info.id)
|
||||
parts = storedParts.map((p) => ({
|
||||
type: p.type === "tool" ? "tool_use" : p.type,
|
||||
id: "callID" in p ? (p as { callID?: string }).callID : p.id,
|
||||
name: "tool" in p ? (p as { tool?: string }).tool : undefined,
|
||||
}))
|
||||
}
|
||||
|
||||
const invalidToolUse = parts.find(
|
||||
(p) => p.type === "tool_use" && "name" in p && p.name === toolName
|
||||
)
|
||||
|
||||
if (!invalidToolUse || !("id" in invalidToolUse)) {
|
||||
return false
|
||||
}
|
||||
|
||||
const toolResultPart = {
|
||||
type: "tool_result" as const,
|
||||
tool_use_id: invalidToolUse.id,
|
||||
content: `Error: Tool '${toolName}' does not exist. The model attempted to use a tool that is not available. This may indicate the model hallucinated the tool name or the tool was recently removed.`,
|
||||
}
|
||||
|
||||
try {
|
||||
await client.session.prompt({
|
||||
path: { id: sessionID },
|
||||
// @ts-expect-error - SDK types may not include tool_result parts
|
||||
body: { parts: [toolResultPart] },
|
||||
})
|
||||
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const PLACEHOLDER_TEXT = "[user interrupted]"
|
||||
|
||||
async function recoverEmptyContentMessage(
|
||||
@@ -369,11 +426,13 @@ export function createSessionRecoveryHook(ctx: PluginInput, options?: SessionRec
|
||||
tool_result_missing: "Tool Crash Recovery",
|
||||
thinking_block_order: "Thinking Block Recovery",
|
||||
thinking_disabled_violation: "Thinking Strip Recovery",
|
||||
tool_not_found: "Invalid Tool Recovery",
|
||||
}
|
||||
const toastMessages: Record<RecoveryErrorType & string, string> = {
|
||||
tool_result_missing: "Injecting cancelled tool results...",
|
||||
thinking_block_order: "Fixing message structure...",
|
||||
thinking_disabled_violation: "Stripping thinking blocks...",
|
||||
tool_not_found: "Handling invalid tool call...",
|
||||
}
|
||||
|
||||
await ctx.client.tui
|
||||
@@ -405,6 +464,8 @@ export function createSessionRecoveryHook(ctx: PluginInput, options?: SessionRec
|
||||
const resumeConfig = extractResumeConfig(lastUser, sessionID)
|
||||
await resumeSession(ctx.client, resumeConfig)
|
||||
}
|
||||
} else if (errorType === "tool_not_found") {
|
||||
success = await recoverToolNotFound(ctx.client, sessionID, failedMsg, info.error)
|
||||
}
|
||||
|
||||
return success
|
||||
|
||||
5
src/hooks/tool-call-validator/constants.ts
Normal file
5
src/hooks/tool-call-validator/constants.ts
Normal file
@@ -0,0 +1,5 @@
|
||||
export const HOOK_NAME = "tool-call-validator"
|
||||
export const INVALID_TOOL_PLACEHOLDER_PREFIX = "[Invalid tool call: "
|
||||
export const INVALID_TOOL_PLACEHOLDER_SUFFIX = " - tool does not exist]"
|
||||
export const PAIRED_RESULT_PLACEHOLDER =
|
||||
"[Tool result removed - corresponding tool call was invalid]"
|
||||
98
src/hooks/tool-call-validator/index.ts
Normal file
98
src/hooks/tool-call-validator/index.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
import type { Part } from "@opencode-ai/sdk"
|
||||
import type { ToolRegistry } from "../../shared/tool-registry"
|
||||
import type { MessageWithParts, ToolUsePart, ToolResultPart, TextPart } from "./types"
|
||||
import {
|
||||
INVALID_TOOL_PLACEHOLDER_PREFIX,
|
||||
INVALID_TOOL_PLACEHOLDER_SUFFIX,
|
||||
PAIRED_RESULT_PLACEHOLDER,
|
||||
} from "./constants"
|
||||
|
||||
type MessagesTransformHook = {
|
||||
"experimental.chat.messages.transform"?: (
|
||||
input: Record<string, never>,
|
||||
output: { messages: MessageWithParts[] }
|
||||
) => Promise<void>
|
||||
}
|
||||
|
||||
function isToolUsePart(part: Part): part is ToolUsePart {
|
||||
const type = part.type as string
|
||||
return type === "tool_use" || type === "tool"
|
||||
}
|
||||
|
||||
function isToolResultPart(part: Part): part is ToolResultPart {
|
||||
const type = part.type as string
|
||||
return type === "tool_result"
|
||||
}
|
||||
|
||||
function getToolName(part: ToolUsePart): string | undefined {
|
||||
return part.name || part.tool
|
||||
}
|
||||
|
||||
function getToolUseIdFromResult(part: Part): string | undefined {
|
||||
return (part as { tool_use_id?: string }).tool_use_id
|
||||
}
|
||||
|
||||
function createInvalidToolPlaceholder(toolName: string): TextPart {
|
||||
return {
|
||||
type: "text",
|
||||
text: `${INVALID_TOOL_PLACEHOLDER_PREFIX}${toolName}${INVALID_TOOL_PLACEHOLDER_SUFFIX}`,
|
||||
synthetic: true,
|
||||
} as TextPart
|
||||
}
|
||||
|
||||
function createPairedResultPlaceholder(): TextPart {
|
||||
return {
|
||||
type: "text",
|
||||
text: PAIRED_RESULT_PLACEHOLDER,
|
||||
synthetic: true,
|
||||
} as TextPart
|
||||
}
|
||||
|
||||
export function createToolCallValidatorHook(
|
||||
registry: ToolRegistry
|
||||
): MessagesTransformHook {
|
||||
return {
|
||||
"experimental.chat.messages.transform": async (_input, output) => {
|
||||
const { messages } = output
|
||||
|
||||
if (!messages || messages.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const invalidatedToolUseIds = new Set<string>()
|
||||
|
||||
for (const message of messages) {
|
||||
if (message.info.role === "user") {
|
||||
continue
|
||||
}
|
||||
|
||||
const parts = message.parts
|
||||
if (!parts || parts.length === 0) {
|
||||
continue
|
||||
}
|
||||
|
||||
for (let i = 0; i < parts.length; i++) {
|
||||
const part = parts[i]
|
||||
|
||||
if (isToolUsePart(part)) {
|
||||
const toolName = getToolName(part)
|
||||
|
||||
if (toolName && !registry.isValidTool(toolName)) {
|
||||
if (part.id) {
|
||||
invalidatedToolUseIds.add(part.id)
|
||||
}
|
||||
|
||||
parts[i] = createInvalidToolPlaceholder(toolName)
|
||||
}
|
||||
} else if (isToolResultPart(part)) {
|
||||
const toolUseId = getToolUseIdFromResult(part)
|
||||
|
||||
if (toolUseId && invalidatedToolUseIds.has(toolUseId)) {
|
||||
parts[i] = createPairedResultPlaceholder()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
28
src/hooks/tool-call-validator/types.ts
Normal file
28
src/hooks/tool-call-validator/types.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
import type { Message, Part } from "@opencode-ai/sdk"
|
||||
|
||||
export interface MessageWithParts {
|
||||
info: Message
|
||||
parts: Part[]
|
||||
}
|
||||
|
||||
export type ToolUsePart = Part & {
|
||||
type: "tool_use" | "tool"
|
||||
id: string
|
||||
name?: string
|
||||
tool?: string
|
||||
input?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export type ToolResultPart = Part & {
|
||||
type: "tool_result"
|
||||
id?: string
|
||||
tool_use_id?: string
|
||||
content?: unknown
|
||||
}
|
||||
|
||||
export type TextPart = Part & {
|
||||
type: "text"
|
||||
id?: string
|
||||
text: string
|
||||
synthetic?: boolean
|
||||
}
|
||||
26
src/index.ts
26
src/index.ts
@@ -23,6 +23,7 @@ import {
|
||||
createNonInteractiveEnvHook,
|
||||
createInteractiveBashSessionHook,
|
||||
createEmptyMessageSanitizerHook,
|
||||
createToolCallValidatorHook,
|
||||
} from "./hooks";
|
||||
import { createGoogleAntigravityAuthPlugin } from "./auth/antigravity";
|
||||
import {
|
||||
@@ -49,6 +50,7 @@ import { BackgroundManager } from "./features/background-agent";
|
||||
import { createBuiltinMcps } from "./mcp";
|
||||
import { OhMyOpenCodeConfigSchema, type OhMyOpenCodeConfig, type HookName } from "./config";
|
||||
import { log, deepMerge, getUserConfigDir, addConfigLoadError } from "./shared";
|
||||
import { createToolRegistry } from "./shared/tool-registry";
|
||||
import { PLAN_SYSTEM_PROMPT, PLAN_PERMISSION } from "./agents/plan-prompt";
|
||||
import * as fs from "fs";
|
||||
import * as path from "path";
|
||||
@@ -345,16 +347,28 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
|
||||
const tmuxAvailable = await getTmuxPath();
|
||||
|
||||
return {
|
||||
...(googleAuthHooks ? { auth: googleAuthHooks.auth } : {}),
|
||||
|
||||
tool: {
|
||||
const allTools = {
|
||||
...builtinTools,
|
||||
...backgroundTools,
|
||||
call_omo_agent: callOmoAgent,
|
||||
look_at: lookAt,
|
||||
...(tmuxAvailable ? { interactive_bash } : {}),
|
||||
},
|
||||
};
|
||||
|
||||
const toolRegistry = createToolRegistry(
|
||||
allTools,
|
||||
{},
|
||||
{}
|
||||
);
|
||||
|
||||
const toolCallValidator = isHookEnabled("tool-call-validator")
|
||||
? createToolCallValidatorHook(toolRegistry)
|
||||
: null;
|
||||
|
||||
return {
|
||||
...(googleAuthHooks ? { auth: googleAuthHooks.auth } : {}),
|
||||
|
||||
tool: allTools,
|
||||
|
||||
"chat.message": async (input, output) => {
|
||||
await claudeCodeHooks["chat.message"]?.(input, output);
|
||||
@@ -367,6 +381,8 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
) => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
await emptyMessageSanitizer?.["experimental.chat.messages.transform"]?.(input, output as any);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
await toolCallValidator?.["experimental.chat.messages.transform"]?.(input, output as any);
|
||||
},
|
||||
|
||||
config: async (config) => {
|
||||
|
||||
63
src/shared/tool-registry.ts
Normal file
63
src/shared/tool-registry.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
export interface ToolRegistry {
|
||||
isValidTool(name: string): boolean
|
||||
getAllToolNames(): string[]
|
||||
}
|
||||
|
||||
interface MCPServerConfig {
|
||||
tools?: Array<{ name: string }>
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/**
|
||||
* Create tool registry for validation
|
||||
* MCP tools use prefix matching: "serverName_methodName"
|
||||
*/
|
||||
export function createToolRegistry(
|
||||
builtinTools: Record<string, unknown>,
|
||||
dynamicTools: Record<string, unknown>,
|
||||
mcpServers: Record<string, MCPServerConfig>
|
||||
): ToolRegistry {
|
||||
const toolNames = new Set<string>([
|
||||
...Object.keys(builtinTools),
|
||||
...Object.keys(dynamicTools),
|
||||
])
|
||||
|
||||
const mcpPrefixes = new Set<string>()
|
||||
|
||||
for (const serverName of Object.keys(mcpServers)) {
|
||||
mcpPrefixes.add(`${serverName}_`)
|
||||
|
||||
const mcpTools = mcpServers[serverName]?.tools
|
||||
if (mcpTools && Array.isArray(mcpTools)) {
|
||||
for (const tool of mcpTools) {
|
||||
if (tool.name) {
|
||||
toolNames.add(tool.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isValidTool(name: string): boolean {
|
||||
if (!name || typeof name !== "string" || name.trim() === "") {
|
||||
return false
|
||||
}
|
||||
|
||||
if (toolNames.has(name)) {
|
||||
return true
|
||||
}
|
||||
|
||||
for (const prefix of mcpPrefixes) {
|
||||
if (name.startsWith(prefix)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
},
|
||||
|
||||
getAllToolNames(): string[] {
|
||||
return Array.from(toolNames).sort()
|
||||
},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user