diff --git a/src/hooks/todo-continuation-enforcer.test.ts b/src/hooks/todo-continuation-enforcer.test.ts index c98b1de..91bf6e3 100644 --- a/src/hooks/todo-continuation-enforcer.test.ts +++ b/src/hooks/todo-continuation-enforcer.test.ts @@ -1,11 +1,11 @@ -import { describe, expect, test, beforeEach, afterEach, mock } from "bun:test" +import { afterEach, beforeEach, describe, expect, test } from "bun:test" -import { createTodoContinuationEnforcer } from "./todo-continuation-enforcer" -import { setMainSession, subagentSessions } from "../features/claude-code-session-state" import type { BackgroundManager } from "../features/background-agent" +import { setMainSession, subagentSessions } from "../features/claude-code-session-state" +import { createTodoContinuationEnforcer } from "./todo-continuation-enforcer" describe("todo-continuation-enforcer", () => { - let promptCalls: Array<{ sessionID: string; agent?: string; text: string }> + let promptCalls: Array<{ sessionID: string; agent?: string; model?: { providerID?: string; modelID?: string }; text: string }> let toastCalls: Array<{ title: string; message: string }> function createMockPluginInput() { @@ -20,6 +20,7 @@ describe("todo-continuation-enforcer", () => { promptCalls.push({ sessionID: opts.path.id, agent: opts.body.agent, + model: opts.body.model, text: opts.body.parts[0].text, }) return {} @@ -41,8 +42,8 @@ describe("todo-continuation-enforcer", () => { function createMockBackgroundManager(runningTasks: boolean = false): BackgroundManager { return { - getTasksByParentSession: () => runningTasks - ? [{ status: "running" }] + getTasksByParentSession: () => runningTasks + ? [{ status: "running" }] : [], } as any } @@ -229,9 +230,9 @@ describe("todo-continuation-enforcer", () => { // #when - user sends message immediately (before 2s countdown) await hook.handler({ - event: { - type: "message.updated", - properties: { info: { sessionID, role: "user" } } + event: { + type: "message.updated", + properties: { info: { sessionID, role: "user" } } }, }) @@ -255,9 +256,9 @@ describe("todo-continuation-enforcer", () => { // #when - assistant starts responding await new Promise(r => setTimeout(r, 500)) await hook.handler({ - event: { - type: "message.part.updated", - properties: { info: { sessionID, role: "assistant" } } + event: { + type: "message.part.updated", + properties: { info: { sessionID, role: "assistant" } } }, }) @@ -418,12 +419,12 @@ describe("todo-continuation-enforcer", () => { // #when - abort error occurs (with abort-specific error) await hook.handler({ - event: { - type: "session.error", - properties: { - sessionID, - error: { name: "MessageAbortedError", message: "The operation was aborted" } - } + event: { + type: "session.error", + properties: { + sessionID, + error: { name: "MessageAbortedError", message: "The operation was aborted" } + } }, }) @@ -447,20 +448,20 @@ describe("todo-continuation-enforcer", () => { // #when - abort error occurs await hook.handler({ - event: { - type: "session.error", - properties: { - sessionID, - error: { name: "MessageAbortedError", message: "The operation was aborted" } - } + event: { + type: "session.error", + properties: { + sessionID, + error: { name: "MessageAbortedError", message: "The operation was aborted" } + } }, }) // #when - assistant sends a message (intervening event clears abort state) await hook.handler({ - event: { - type: "message.updated", - properties: { info: { sessionID, role: "assistant" } } + event: { + type: "message.updated", + properties: { info: { sessionID, role: "assistant" } } }, }) @@ -484,12 +485,12 @@ describe("todo-continuation-enforcer", () => { // #when - abort error occurs await hook.handler({ - event: { - type: "session.error", - properties: { - sessionID, - error: { message: "aborted" } - } + event: { + type: "session.error", + properties: { + sessionID, + error: { message: "aborted" } + } }, }) @@ -518,12 +519,12 @@ describe("todo-continuation-enforcer", () => { // #when - non-abort error occurs (e.g., network error, API error) await hook.handler({ - event: { - type: "session.error", - properties: { - sessionID, - error: { name: "NetworkError", message: "Connection failed" } - } + event: { + type: "session.error", + properties: { + sessionID, + error: { name: "NetworkError", message: "Connection failed" } + } }, }) @@ -547,12 +548,12 @@ describe("todo-continuation-enforcer", () => { // #when - abort error occurs await hook.handler({ - event: { - type: "session.error", - properties: { - sessionID, - error: { name: "AbortError", message: "cancelled" } - } + event: { + type: "session.error", + properties: { + sessionID, + error: { name: "AbortError", message: "cancelled" } + } }, }) @@ -584,17 +585,17 @@ describe("todo-continuation-enforcer", () => { // #when - first abort error await hook.handler({ - event: { - type: "session.error", - properties: { sessionID, error: { message: "aborted" } } + event: { + type: "session.error", + properties: { sessionID, error: { message: "aborted" } } }, }) // #when - second abort error (immediately before idle) await hook.handler({ - event: { - type: "session.error", - properties: { sessionID, error: { message: "interrupted" } } + event: { + type: "session.error", + properties: { sessionID, error: { message: "interrupted" } } }, }) diff --git a/src/hooks/todo-continuation-enforcer.ts b/src/hooks/todo-continuation-enforcer.ts index 5300d50..1812d3f 100644 --- a/src/hooks/todo-continuation-enforcer.ts +++ b/src/hooks/todo-continuation-enforcer.ts @@ -1,12 +1,12 @@ +import type { PluginInput } from "@opencode-ai/plugin" import { existsSync, readdirSync } from "node:fs" import { join } from "node:path" -import type { PluginInput } from "@opencode-ai/plugin" +import type { BackgroundManager } from "../features/background-agent" import { getMainSessionID, subagentSessions } from "../features/claude-code-session-state" import { - findNearestMessageWithFields, - MESSAGE_STORAGE, + findNearestMessageWithFields, + MESSAGE_STORAGE, } from "../features/hook-message-injector" -import type { BackgroundManager } from "../features/background-agent" import { log } from "../shared/logger" const HOOK_NAME = "todo-continuation-enforcer" @@ -62,22 +62,22 @@ function getMessageDir(sessionID: string): string | null { function isAbortError(error: unknown): boolean { if (!error) return false - + if (typeof error === "object") { const errObj = error as Record const name = errObj.name as string | undefined const message = (errObj.message as string | undefined)?.toLowerCase() ?? "" - + if (name === "MessageAbortedError" || name === "AbortError") return true if (name === "DOMException" && message.includes("abort")) return true if (message.includes("aborted") || message.includes("cancelled") || message.includes("interrupted")) return true } - + if (typeof error === "string") { const lower = error.toLowerCase() return lower.includes("abort") || lower.includes("cancel") || lower.includes("interrupt") } - + return false } @@ -104,7 +104,7 @@ export function createTodoContinuationEnforcer( function cancelCountdown(sessionID: string): void { const state = sessions.get(sessionID) if (!state) return - + if (state.countdownTimer) { clearTimeout(state.countdownTimer) state.countdownTimer = undefined @@ -148,7 +148,7 @@ export function createTodoContinuationEnforcer( async function injectContinuation(sessionID: string, incompleteCount: number, total: number): Promise { const state = sessions.get(sessionID) - + if (state?.isRecovering) { log(`[${HOOK_NAME}] Skipped injection: in recovery`, { sessionID }) return @@ -183,9 +183,9 @@ export function createTodoContinuationEnforcer( const messageDir = getMessageDir(sessionID) const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null - const hasWritePermission = !prevMessage?.tools || + const hasWritePermission = !prevMessage?.tools || (prevMessage.tools.write !== false && prevMessage.tools.edit !== false) - + if (!hasWritePermission) { log(`[${HOOK_NAME}] Skipped: agent lacks write permission`, { sessionID, agent: prevMessage?.agent }) return @@ -199,18 +199,23 @@ export function createTodoContinuationEnforcer( const prompt = `${CONTINUATION_PROMPT}\n\n[Status: ${todos.length - freshIncompleteCount}/${todos.length} completed, ${freshIncompleteCount} remaining]` + const modelField = prevMessage?.model?.providerID && prevMessage?.model?.modelID + ? { providerID: prevMessage.model.providerID, modelID: prevMessage.model.modelID } + : undefined + try { - log(`[${HOOK_NAME}] Injecting continuation`, { sessionID, agent: prevMessage?.agent, incompleteCount: freshIncompleteCount }) - + log(`[${HOOK_NAME}] Injecting continuation`, { sessionID, agent: prevMessage?.agent, model: modelField, incompleteCount: freshIncompleteCount }) + await ctx.client.session.prompt({ path: { id: sessionID }, body: { agent: prevMessage?.agent, + model: modelField, parts: [{ type: "text", text: prompt }], }, query: { directory: ctx.directory }, }) - + log(`[${HOOK_NAME}] Injection successful`, { sessionID }) } catch (err) { log(`[${HOOK_NAME}] Injection failed`, { sessionID, error: String(err) }) @@ -250,7 +255,7 @@ export function createTodoContinuationEnforcer( const isAbort = isAbortError(props?.error) state.lastEventWasAbortError = isAbort cancelCountdown(sessionID) - + log(`[${HOOK_NAME}] session.error`, { sessionID, isAbort }) return } @@ -264,7 +269,7 @@ export function createTodoContinuationEnforcer( const mainSessionID = getMainSessionID() const isMainSession = sessionID === mainSessionID const isBackgroundTaskSession = subagentSessions.has(sessionID) - + if (mainSessionID && !isMainSession && !isBackgroundTaskSession) { log(`[${HOOK_NAME}] Skipped: not main or background task session`, { sessionID }) return