fix(todo-continuation-enforcer): preserve model/provider from nearest message

When injecting continuation prompts, extract and pass the model field

(providerID + modelID) from the nearest stored message, matching the

pattern used in background-agent/manager.ts and session-recovery.

Also updated tests to capture the model field for verification.
This commit is contained in:
YeonGyu-Kim
2026-01-02 19:07:44 +09:00
parent bf3dd91da2
commit 823f12d88d
2 changed files with 74 additions and 68 deletions

View File

@@ -1,12 +1,12 @@
import type { PluginInput } from "@opencode-ai/plugin"
import { existsSync, readdirSync } from "node:fs"
import { join } from "node:path"
import type { PluginInput } from "@opencode-ai/plugin"
import type { BackgroundManager } from "../features/background-agent"
import { getMainSessionID, subagentSessions } from "../features/claude-code-session-state"
import {
findNearestMessageWithFields,
MESSAGE_STORAGE,
findNearestMessageWithFields,
MESSAGE_STORAGE,
} from "../features/hook-message-injector"
import type { BackgroundManager } from "../features/background-agent"
import { log } from "../shared/logger"
const HOOK_NAME = "todo-continuation-enforcer"
@@ -62,22 +62,22 @@ function getMessageDir(sessionID: string): string | null {
function isAbortError(error: unknown): boolean {
if (!error) return false
if (typeof error === "object") {
const errObj = error as Record<string, unknown>
const name = errObj.name as string | undefined
const message = (errObj.message as string | undefined)?.toLowerCase() ?? ""
if (name === "MessageAbortedError" || name === "AbortError") return true
if (name === "DOMException" && message.includes("abort")) return true
if (message.includes("aborted") || message.includes("cancelled") || message.includes("interrupted")) return true
}
if (typeof error === "string") {
const lower = error.toLowerCase()
return lower.includes("abort") || lower.includes("cancel") || lower.includes("interrupt")
}
return false
}
@@ -104,7 +104,7 @@ export function createTodoContinuationEnforcer(
function cancelCountdown(sessionID: string): void {
const state = sessions.get(sessionID)
if (!state) return
if (state.countdownTimer) {
clearTimeout(state.countdownTimer)
state.countdownTimer = undefined
@@ -148,7 +148,7 @@ export function createTodoContinuationEnforcer(
async function injectContinuation(sessionID: string, incompleteCount: number, total: number): Promise<void> {
const state = sessions.get(sessionID)
if (state?.isRecovering) {
log(`[${HOOK_NAME}] Skipped injection: in recovery`, { sessionID })
return
@@ -183,9 +183,9 @@ export function createTodoContinuationEnforcer(
const messageDir = getMessageDir(sessionID)
const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null
const hasWritePermission = !prevMessage?.tools ||
const hasWritePermission = !prevMessage?.tools ||
(prevMessage.tools.write !== false && prevMessage.tools.edit !== false)
if (!hasWritePermission) {
log(`[${HOOK_NAME}] Skipped: agent lacks write permission`, { sessionID, agent: prevMessage?.agent })
return
@@ -199,18 +199,23 @@ export function createTodoContinuationEnforcer(
const prompt = `${CONTINUATION_PROMPT}\n\n[Status: ${todos.length - freshIncompleteCount}/${todos.length} completed, ${freshIncompleteCount} remaining]`
const modelField = prevMessage?.model?.providerID && prevMessage?.model?.modelID
? { providerID: prevMessage.model.providerID, modelID: prevMessage.model.modelID }
: undefined
try {
log(`[${HOOK_NAME}] Injecting continuation`, { sessionID, agent: prevMessage?.agent, incompleteCount: freshIncompleteCount })
log(`[${HOOK_NAME}] Injecting continuation`, { sessionID, agent: prevMessage?.agent, model: modelField, incompleteCount: freshIncompleteCount })
await ctx.client.session.prompt({
path: { id: sessionID },
body: {
agent: prevMessage?.agent,
model: modelField,
parts: [{ type: "text", text: prompt }],
},
query: { directory: ctx.directory },
})
log(`[${HOOK_NAME}] Injection successful`, { sessionID })
} catch (err) {
log(`[${HOOK_NAME}] Injection failed`, { sessionID, error: String(err) })
@@ -250,7 +255,7 @@ export function createTodoContinuationEnforcer(
const isAbort = isAbortError(props?.error)
state.lastEventWasAbortError = isAbort
cancelCountdown(sessionID)
log(`[${HOOK_NAME}] session.error`, { sessionID, isAbort })
return
}
@@ -264,7 +269,7 @@ export function createTodoContinuationEnforcer(
const mainSessionID = getMainSessionID()
const isMainSession = sessionID === mainSessionID
const isBackgroundTaskSession = subagentSessions.has(sessionID)
if (mainSessionID && !isMainSession && !isBackgroundTaskSession) {
log(`[${HOOK_NAME}] Skipped: not main or background task session`, { sessionID })
return