feat: add two-layer tool call validation system (proactive + reactive) (#249)
Co-authored-by: sisyphus-dev-ai <sisyphus-dev-ai@users.noreply.github.com>
This commit is contained in:
@@ -59,7 +59,8 @@
|
|||||||
"agent-usage-reminder",
|
"agent-usage-reminder",
|
||||||
"non-interactive-env",
|
"non-interactive-env",
|
||||||
"interactive-bash-session",
|
"interactive-bash-session",
|
||||||
"empty-message-sanitizer"
|
"empty-message-sanitizer",
|
||||||
|
"tool-call-validator"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ export const HookNameSchema = z.enum([
|
|||||||
"non-interactive-env",
|
"non-interactive-env",
|
||||||
"interactive-bash-session",
|
"interactive-bash-session",
|
||||||
"empty-message-sanitizer",
|
"empty-message-sanitizer",
|
||||||
|
"tool-call-validator",
|
||||||
])
|
])
|
||||||
|
|
||||||
export const AgentOverrideConfigSchema = z.object({
|
export const AgentOverrideConfigSchema = z.object({
|
||||||
|
|||||||
@@ -21,3 +21,4 @@ export { createKeywordDetectorHook } from "./keyword-detector";
|
|||||||
export { createNonInteractiveEnvHook } from "./non-interactive-env";
|
export { createNonInteractiveEnvHook } from "./non-interactive-env";
|
||||||
export { createInteractiveBashSessionHook } from "./interactive-bash-session";
|
export { createInteractiveBashSessionHook } from "./interactive-bash-session";
|
||||||
export { createEmptyMessageSanitizerHook } from "./empty-message-sanitizer";
|
export { createEmptyMessageSanitizerHook } from "./empty-message-sanitizer";
|
||||||
|
export { createToolCallValidatorHook } from "./tool-call-validator";
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ type RecoveryErrorType =
|
|||||||
| "tool_result_missing"
|
| "tool_result_missing"
|
||||||
| "thinking_block_order"
|
| "thinking_block_order"
|
||||||
| "thinking_disabled_violation"
|
| "thinking_disabled_violation"
|
||||||
|
| "tool_not_found"
|
||||||
| null
|
| null
|
||||||
|
|
||||||
interface MessageInfo {
|
interface MessageInfo {
|
||||||
@@ -143,6 +144,15 @@ function detectErrorType(error: unknown): RecoveryErrorType {
|
|||||||
return "thinking_disabled_violation"
|
return "thinking_disabled_violation"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
message.includes("tool") &&
|
||||||
|
(message.includes("not found") ||
|
||||||
|
message.includes("unknown tool") ||
|
||||||
|
message.includes("invalid tool"))
|
||||||
|
) {
|
||||||
|
return "tool_not_found"
|
||||||
|
}
|
||||||
|
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,6 +253,53 @@ async function recoverThinkingDisabledViolation(
|
|||||||
return anySuccess
|
return anySuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function recoverToolNotFound(
|
||||||
|
client: Client,
|
||||||
|
sessionID: string,
|
||||||
|
failedAssistantMsg: MessageData,
|
||||||
|
error: unknown
|
||||||
|
): Promise<boolean> {
|
||||||
|
const errorMsg = getErrorMessage(error)
|
||||||
|
const toolNameMatch = errorMsg.match(/tool[:\s]+["']?([a-z0-9_-]+)["']?/i)
|
||||||
|
const toolName = toolNameMatch?.[1] ?? "unknown"
|
||||||
|
|
||||||
|
let parts = failedAssistantMsg.parts || []
|
||||||
|
if (parts.length === 0 && failedAssistantMsg.info?.id) {
|
||||||
|
const storedParts = readParts(failedAssistantMsg.info.id)
|
||||||
|
parts = storedParts.map((p) => ({
|
||||||
|
type: p.type === "tool" ? "tool_use" : p.type,
|
||||||
|
id: "callID" in p ? (p as { callID?: string }).callID : p.id,
|
||||||
|
name: "tool" in p ? (p as { tool?: string }).tool : undefined,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
const invalidToolUse = parts.find(
|
||||||
|
(p) => p.type === "tool_use" && "name" in p && p.name === toolName
|
||||||
|
)
|
||||||
|
|
||||||
|
if (!invalidToolUse || !("id" in invalidToolUse)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolResultPart = {
|
||||||
|
type: "tool_result" as const,
|
||||||
|
tool_use_id: invalidToolUse.id,
|
||||||
|
content: `Error: Tool '${toolName}' does not exist. The model attempted to use a tool that is not available. This may indicate the model hallucinated the tool name or the tool was recently removed.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await client.session.prompt({
|
||||||
|
path: { id: sessionID },
|
||||||
|
// @ts-expect-error - SDK types may not include tool_result parts
|
||||||
|
body: { parts: [toolResultPart] },
|
||||||
|
})
|
||||||
|
|
||||||
|
return true
|
||||||
|
} catch {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const PLACEHOLDER_TEXT = "[user interrupted]"
|
const PLACEHOLDER_TEXT = "[user interrupted]"
|
||||||
|
|
||||||
async function recoverEmptyContentMessage(
|
async function recoverEmptyContentMessage(
|
||||||
@@ -369,11 +426,13 @@ export function createSessionRecoveryHook(ctx: PluginInput, options?: SessionRec
|
|||||||
tool_result_missing: "Tool Crash Recovery",
|
tool_result_missing: "Tool Crash Recovery",
|
||||||
thinking_block_order: "Thinking Block Recovery",
|
thinking_block_order: "Thinking Block Recovery",
|
||||||
thinking_disabled_violation: "Thinking Strip Recovery",
|
thinking_disabled_violation: "Thinking Strip Recovery",
|
||||||
|
tool_not_found: "Invalid Tool Recovery",
|
||||||
}
|
}
|
||||||
const toastMessages: Record<RecoveryErrorType & string, string> = {
|
const toastMessages: Record<RecoveryErrorType & string, string> = {
|
||||||
tool_result_missing: "Injecting cancelled tool results...",
|
tool_result_missing: "Injecting cancelled tool results...",
|
||||||
thinking_block_order: "Fixing message structure...",
|
thinking_block_order: "Fixing message structure...",
|
||||||
thinking_disabled_violation: "Stripping thinking blocks...",
|
thinking_disabled_violation: "Stripping thinking blocks...",
|
||||||
|
tool_not_found: "Handling invalid tool call...",
|
||||||
}
|
}
|
||||||
|
|
||||||
await ctx.client.tui
|
await ctx.client.tui
|
||||||
@@ -405,6 +464,8 @@ export function createSessionRecoveryHook(ctx: PluginInput, options?: SessionRec
|
|||||||
const resumeConfig = extractResumeConfig(lastUser, sessionID)
|
const resumeConfig = extractResumeConfig(lastUser, sessionID)
|
||||||
await resumeSession(ctx.client, resumeConfig)
|
await resumeSession(ctx.client, resumeConfig)
|
||||||
}
|
}
|
||||||
|
} else if (errorType === "tool_not_found") {
|
||||||
|
success = await recoverToolNotFound(ctx.client, sessionID, failedMsg, info.error)
|
||||||
}
|
}
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|||||||
5
src/hooks/tool-call-validator/constants.ts
Normal file
5
src/hooks/tool-call-validator/constants.ts
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
export const HOOK_NAME = "tool-call-validator"
|
||||||
|
export const INVALID_TOOL_PLACEHOLDER_PREFIX = "[Invalid tool call: "
|
||||||
|
export const INVALID_TOOL_PLACEHOLDER_SUFFIX = " - tool does not exist]"
|
||||||
|
export const PAIRED_RESULT_PLACEHOLDER =
|
||||||
|
"[Tool result removed - corresponding tool call was invalid]"
|
||||||
98
src/hooks/tool-call-validator/index.ts
Normal file
98
src/hooks/tool-call-validator/index.ts
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import type { Part } from "@opencode-ai/sdk"
|
||||||
|
import type { ToolRegistry } from "../../shared/tool-registry"
|
||||||
|
import type { MessageWithParts, ToolUsePart, ToolResultPart, TextPart } from "./types"
|
||||||
|
import {
|
||||||
|
INVALID_TOOL_PLACEHOLDER_PREFIX,
|
||||||
|
INVALID_TOOL_PLACEHOLDER_SUFFIX,
|
||||||
|
PAIRED_RESULT_PLACEHOLDER,
|
||||||
|
} from "./constants"
|
||||||
|
|
||||||
|
type MessagesTransformHook = {
|
||||||
|
"experimental.chat.messages.transform"?: (
|
||||||
|
input: Record<string, never>,
|
||||||
|
output: { messages: MessageWithParts[] }
|
||||||
|
) => Promise<void>
|
||||||
|
}
|
||||||
|
|
||||||
|
function isToolUsePart(part: Part): part is ToolUsePart {
|
||||||
|
const type = part.type as string
|
||||||
|
return type === "tool_use" || type === "tool"
|
||||||
|
}
|
||||||
|
|
||||||
|
function isToolResultPart(part: Part): part is ToolResultPart {
|
||||||
|
const type = part.type as string
|
||||||
|
return type === "tool_result"
|
||||||
|
}
|
||||||
|
|
||||||
|
function getToolName(part: ToolUsePart): string | undefined {
|
||||||
|
return part.name || part.tool
|
||||||
|
}
|
||||||
|
|
||||||
|
function getToolUseIdFromResult(part: Part): string | undefined {
|
||||||
|
return (part as { tool_use_id?: string }).tool_use_id
|
||||||
|
}
|
||||||
|
|
||||||
|
function createInvalidToolPlaceholder(toolName: string): TextPart {
|
||||||
|
return {
|
||||||
|
type: "text",
|
||||||
|
text: `${INVALID_TOOL_PLACEHOLDER_PREFIX}${toolName}${INVALID_TOOL_PLACEHOLDER_SUFFIX}`,
|
||||||
|
synthetic: true,
|
||||||
|
} as TextPart
|
||||||
|
}
|
||||||
|
|
||||||
|
function createPairedResultPlaceholder(): TextPart {
|
||||||
|
return {
|
||||||
|
type: "text",
|
||||||
|
text: PAIRED_RESULT_PLACEHOLDER,
|
||||||
|
synthetic: true,
|
||||||
|
} as TextPart
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createToolCallValidatorHook(
|
||||||
|
registry: ToolRegistry
|
||||||
|
): MessagesTransformHook {
|
||||||
|
return {
|
||||||
|
"experimental.chat.messages.transform": async (_input, output) => {
|
||||||
|
const { messages } = output
|
||||||
|
|
||||||
|
if (!messages || messages.length === 0) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const invalidatedToolUseIds = new Set<string>()
|
||||||
|
|
||||||
|
for (const message of messages) {
|
||||||
|
if (message.info.role === "user") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
const parts = message.parts
|
||||||
|
if (!parts || parts.length === 0) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for (let i = 0; i < parts.length; i++) {
|
||||||
|
const part = parts[i]
|
||||||
|
|
||||||
|
if (isToolUsePart(part)) {
|
||||||
|
const toolName = getToolName(part)
|
||||||
|
|
||||||
|
if (toolName && !registry.isValidTool(toolName)) {
|
||||||
|
if (part.id) {
|
||||||
|
invalidatedToolUseIds.add(part.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts[i] = createInvalidToolPlaceholder(toolName)
|
||||||
|
}
|
||||||
|
} else if (isToolResultPart(part)) {
|
||||||
|
const toolUseId = getToolUseIdFromResult(part)
|
||||||
|
|
||||||
|
if (toolUseId && invalidatedToolUseIds.has(toolUseId)) {
|
||||||
|
parts[i] = createPairedResultPlaceholder()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
28
src/hooks/tool-call-validator/types.ts
Normal file
28
src/hooks/tool-call-validator/types.ts
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import type { Message, Part } from "@opencode-ai/sdk"
|
||||||
|
|
||||||
|
export interface MessageWithParts {
|
||||||
|
info: Message
|
||||||
|
parts: Part[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ToolUsePart = Part & {
|
||||||
|
type: "tool_use" | "tool"
|
||||||
|
id: string
|
||||||
|
name?: string
|
||||||
|
tool?: string
|
||||||
|
input?: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ToolResultPart = Part & {
|
||||||
|
type: "tool_result"
|
||||||
|
id?: string
|
||||||
|
tool_use_id?: string
|
||||||
|
content?: unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
export type TextPart = Part & {
|
||||||
|
type: "text"
|
||||||
|
id?: string
|
||||||
|
text: string
|
||||||
|
synthetic?: boolean
|
||||||
|
}
|
||||||
26
src/index.ts
26
src/index.ts
@@ -23,6 +23,7 @@ import {
|
|||||||
createNonInteractiveEnvHook,
|
createNonInteractiveEnvHook,
|
||||||
createInteractiveBashSessionHook,
|
createInteractiveBashSessionHook,
|
||||||
createEmptyMessageSanitizerHook,
|
createEmptyMessageSanitizerHook,
|
||||||
|
createToolCallValidatorHook,
|
||||||
} from "./hooks";
|
} from "./hooks";
|
||||||
import { createGoogleAntigravityAuthPlugin } from "./auth/antigravity";
|
import { createGoogleAntigravityAuthPlugin } from "./auth/antigravity";
|
||||||
import {
|
import {
|
||||||
@@ -49,6 +50,7 @@ import { BackgroundManager } from "./features/background-agent";
|
|||||||
import { createBuiltinMcps } from "./mcp";
|
import { createBuiltinMcps } from "./mcp";
|
||||||
import { OhMyOpenCodeConfigSchema, type OhMyOpenCodeConfig, type HookName } from "./config";
|
import { OhMyOpenCodeConfigSchema, type OhMyOpenCodeConfig, type HookName } from "./config";
|
||||||
import { log, deepMerge, getUserConfigDir, addConfigLoadError } from "./shared";
|
import { log, deepMerge, getUserConfigDir, addConfigLoadError } from "./shared";
|
||||||
|
import { createToolRegistry } from "./shared/tool-registry";
|
||||||
import { PLAN_SYSTEM_PROMPT, PLAN_PERMISSION } from "./agents/plan-prompt";
|
import { PLAN_SYSTEM_PROMPT, PLAN_PERMISSION } from "./agents/plan-prompt";
|
||||||
import * as fs from "fs";
|
import * as fs from "fs";
|
||||||
import * as path from "path";
|
import * as path from "path";
|
||||||
@@ -345,16 +347,28 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
|||||||
|
|
||||||
const tmuxAvailable = await getTmuxPath();
|
const tmuxAvailable = await getTmuxPath();
|
||||||
|
|
||||||
return {
|
const allTools = {
|
||||||
...(googleAuthHooks ? { auth: googleAuthHooks.auth } : {}),
|
|
||||||
|
|
||||||
tool: {
|
|
||||||
...builtinTools,
|
...builtinTools,
|
||||||
...backgroundTools,
|
...backgroundTools,
|
||||||
call_omo_agent: callOmoAgent,
|
call_omo_agent: callOmoAgent,
|
||||||
look_at: lookAt,
|
look_at: lookAt,
|
||||||
...(tmuxAvailable ? { interactive_bash } : {}),
|
...(tmuxAvailable ? { interactive_bash } : {}),
|
||||||
},
|
};
|
||||||
|
|
||||||
|
const toolRegistry = createToolRegistry(
|
||||||
|
allTools,
|
||||||
|
{},
|
||||||
|
{}
|
||||||
|
);
|
||||||
|
|
||||||
|
const toolCallValidator = isHookEnabled("tool-call-validator")
|
||||||
|
? createToolCallValidatorHook(toolRegistry)
|
||||||
|
: null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
...(googleAuthHooks ? { auth: googleAuthHooks.auth } : {}),
|
||||||
|
|
||||||
|
tool: allTools,
|
||||||
|
|
||||||
"chat.message": async (input, output) => {
|
"chat.message": async (input, output) => {
|
||||||
await claudeCodeHooks["chat.message"]?.(input, output);
|
await claudeCodeHooks["chat.message"]?.(input, output);
|
||||||
@@ -367,6 +381,8 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
|||||||
) => {
|
) => {
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
await emptyMessageSanitizer?.["experimental.chat.messages.transform"]?.(input, output as any);
|
await emptyMessageSanitizer?.["experimental.chat.messages.transform"]?.(input, output as any);
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
await toolCallValidator?.["experimental.chat.messages.transform"]?.(input, output as any);
|
||||||
},
|
},
|
||||||
|
|
||||||
config: async (config) => {
|
config: async (config) => {
|
||||||
|
|||||||
63
src/shared/tool-registry.ts
Normal file
63
src/shared/tool-registry.ts
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
export interface ToolRegistry {
|
||||||
|
isValidTool(name: string): boolean
|
||||||
|
getAllToolNames(): string[]
|
||||||
|
}
|
||||||
|
|
||||||
|
interface MCPServerConfig {
|
||||||
|
tools?: Array<{ name: string }>
|
||||||
|
[key: string]: unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create tool registry for validation
|
||||||
|
* MCP tools use prefix matching: "serverName_methodName"
|
||||||
|
*/
|
||||||
|
export function createToolRegistry(
|
||||||
|
builtinTools: Record<string, unknown>,
|
||||||
|
dynamicTools: Record<string, unknown>,
|
||||||
|
mcpServers: Record<string, MCPServerConfig>
|
||||||
|
): ToolRegistry {
|
||||||
|
const toolNames = new Set<string>([
|
||||||
|
...Object.keys(builtinTools),
|
||||||
|
...Object.keys(dynamicTools),
|
||||||
|
])
|
||||||
|
|
||||||
|
const mcpPrefixes = new Set<string>()
|
||||||
|
|
||||||
|
for (const serverName of Object.keys(mcpServers)) {
|
||||||
|
mcpPrefixes.add(`${serverName}_`)
|
||||||
|
|
||||||
|
const mcpTools = mcpServers[serverName]?.tools
|
||||||
|
if (mcpTools && Array.isArray(mcpTools)) {
|
||||||
|
for (const tool of mcpTools) {
|
||||||
|
if (tool.name) {
|
||||||
|
toolNames.add(tool.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
isValidTool(name: string): boolean {
|
||||||
|
if (!name || typeof name !== "string" || name.trim() === "") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if (toolNames.has(name)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const prefix of mcpPrefixes) {
|
||||||
|
if (name.startsWith(prefix)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
|
||||||
|
getAllToolNames(): string[] {
|
||||||
|
return Array.from(toolNames).sort()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user