diff --git a/src/hooks/todo-continuation-enforcer.ts b/src/hooks/todo-continuation-enforcer.ts index f75e58f..2ab3245 100644 --- a/src/hooks/todo-continuation-enforcer.ts +++ b/src/hooks/todo-continuation-enforcer.ts @@ -8,7 +8,6 @@ import { } from "../features/hook-message-injector" import type { BackgroundManager } from "../features/background-agent" import { log } from "../shared/logger" -import { isNonInteractive } from "./non-interactive-env/detector" const HOOK_NAME = "todo-continuation-enforcer" @@ -37,6 +36,32 @@ Incomplete tasks remain in your todo list. Continue working on the next pending - Mark each task complete when finished - Do not stop until all tasks are done` +const COUNTDOWN_SECONDS = 2 +const TOAST_DURATION_MS = 900 +const MIN_INJECTION_INTERVAL_MS = 10_000 + +// ============================================================================ +// STATE MACHINE TYPES +// ============================================================================ + +type SessionMode = + | "idle" // Observed idle, no countdown started yet + | "countingDown" // Waiting N seconds before injecting + | "injecting" // Currently calling session.prompt + | "recovering" // Session recovery in progress (external control) + | "errorBypass" // Bypass mode after session.error/interrupt + +interface SessionState { + version: number // Monotonic generation token - increment to invalidate pending callbacks + mode: SessionMode + timer?: ReturnType // Pending countdown timer + lastInjectedAt?: number // Timestamp of last injection (anti-spam) +} + +// ============================================================================ +// HELPER FUNCTIONS +// ============================================================================ + function getMessageDir(sessionID: string): string | null { if (!existsSync(MESSAGE_STORAGE)) return null @@ -68,104 +93,338 @@ function detectInterrupt(error: unknown): boolean { return false } -const COUNTDOWN_SECONDS = 2 -const TOAST_DURATION_MS = 900 // Slightly less than 1s so toasts don't overlap - -interface CountdownState { - secondsRemaining: number - intervalId: ReturnType +function getIncompleteCount(todos: Todo[]): number { + return todos.filter(t => t.status !== "completed" && t.status !== "cancelled").length } +// ============================================================================ +// MAIN IMPLEMENTATION +// ============================================================================ + export function createTodoContinuationEnforcer( ctx: PluginInput, options: TodoContinuationEnforcerOptions = {} ): TodoContinuationEnforcer { const { backgroundManager } = options - const remindedSessions = new Set() - const interruptedSessions = new Set() - const errorSessions = new Set() - const recoveringSessions = new Set() - const pendingCountdowns = new Map() - const preemptivelyInjectedSessions = new Set() + + // Single source of truth: per-session state machine + const sessions = new Map() + + // ============================================================================ + // STATE HELPERS + // ============================================================================ + + function getOrCreateState(sessionID: string): SessionState { + let state = sessions.get(sessionID) + if (!state) { + state = { version: 0, mode: "idle" } + sessions.set(sessionID, state) + } + return state + } + + function clearTimer(state: SessionState): void { + if (state.timer) { + clearTimeout(state.timer) + state.timer = undefined + } + } + + /** + * Cancel any pending countdown by incrementing version and clearing timer. + * This invalidates any async callbacks that were started with the old version. + */ + function cancelCountdown(sessionID: string, reason: string): void { + const state = sessions.get(sessionID) + if (!state) return + + if (state.mode === "countingDown" || state.timer) { + state.version++ + clearTimer(state) + state.mode = "idle" + log(`[${HOOK_NAME}] Countdown cancelled`, { sessionID, reason, newVersion: state.version }) + } + } + + /** + * Check if this is the main session (not a subagent session). + */ + function isMainSession(sessionID: string): boolean { + const mainSessionID = getMainSessionID() + // If no main session is set, allow all. If set, only allow main. + return !mainSessionID || sessionID === mainSessionID + } + + // ============================================================================ + // EXTERNAL API + // ============================================================================ const markRecovering = (sessionID: string): void => { - recoveringSessions.add(sessionID) + const state = getOrCreateState(sessionID) + cancelCountdown(sessionID, "entering recovery mode") + state.mode = "recovering" + log(`[${HOOK_NAME}] Session marked as recovering`, { sessionID }) } const markRecoveryComplete = (sessionID: string): void => { - recoveringSessions.delete(sessionID) + const state = sessions.get(sessionID) + if (state && state.mode === "recovering") { + state.mode = "idle" + log(`[${HOOK_NAME}] Session recovery complete`, { sessionID }) + } } + // ============================================================================ + // TOAST HELPER + // ============================================================================ + + async function showCountdownToast(seconds: number, incompleteCount: number): Promise { + await ctx.client.tui.showToast({ + body: { + title: "Todo Continuation", + message: `Resuming in ${seconds}s... (${incompleteCount} tasks remaining)`, + variant: "warning" as const, + duration: TOAST_DURATION_MS, + }, + }).catch(() => {}) + } + + // ============================================================================ + // CORE INJECTION LOGIC + // ============================================================================ + + async function executeInjection(sessionID: string, capturedVersion: number): Promise { + const state = sessions.get(sessionID) + if (!state) return + + // Version check: if version changed since we started, abort + if (state.version !== capturedVersion) { + log(`[${HOOK_NAME}] Injection aborted: version mismatch`, { + sessionID, capturedVersion, currentVersion: state.version + }) + return + } + + // Mode check: must still be in countingDown mode + if (state.mode !== "countingDown") { + log(`[${HOOK_NAME}] Injection aborted: mode changed`, { + sessionID, mode: state.mode + }) + return + } + + // Throttle check: minimum interval between injections + if (state.lastInjectedAt) { + const elapsed = Date.now() - state.lastInjectedAt + if (elapsed < MIN_INJECTION_INTERVAL_MS) { + log(`[${HOOK_NAME}] Injection throttled: too soon since last injection`, { + sessionID, elapsedMs: elapsed, minIntervalMs: MIN_INJECTION_INTERVAL_MS + }) + state.mode = "idle" + return + } + } + + state.mode = "injecting" + + // Re-verify todos (CRITICAL: always re-check before injecting) + let todos: Todo[] = [] + try { + const response = await ctx.client.session.todo({ path: { id: sessionID } }) + todos = (response.data ?? response) as Todo[] + } catch (err) { + log(`[${HOOK_NAME}] Failed to fetch todos for injection`, { sessionID, error: String(err) }) + state.mode = "idle" + return + } + + // Version check again after async operation + if (state.version !== capturedVersion) { + log(`[${HOOK_NAME}] Injection aborted after todo fetch: version mismatch`, { sessionID }) + state.mode = "idle" + return + } + + const incompleteCount = getIncompleteCount(todos) + if (incompleteCount === 0) { + log(`[${HOOK_NAME}] No incomplete todos at injection time`, { sessionID, total: todos.length }) + state.mode = "idle" + return + } + + // Skip entirely if background tasks are running (no false positives) + const hasRunningBgTasks = backgroundManager + ? backgroundManager.getTasksByParentSession(sessionID).some((t) => t.status === "running") + : false + + if (hasRunningBgTasks) { + log(`[${HOOK_NAME}] Skipped: background tasks still running`, { sessionID }) + state.mode = "idle" + return + } + + // Get previous message agent info + const messageDir = getMessageDir(sessionID) + const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null + + // Check write permission + const agentHasWritePermission = !prevMessage?.tools || + (prevMessage.tools.write !== false && prevMessage.tools.edit !== false) + + if (!agentHasWritePermission) { + log(`[${HOOK_NAME}] Skipped: agent lacks write permission`, { + sessionID, agent: prevMessage?.agent, tools: prevMessage?.tools + }) + state.mode = "idle" + return + } + + const prompt = `${CONTINUATION_PROMPT}\n\n[Status: ${todos.length - incompleteCount}/${todos.length} completed, ${incompleteCount} remaining]` + + // Final version check right before API call (last-mile race mitigation) + if (state.version !== capturedVersion) { + log(`[${HOOK_NAME}] Injection aborted: version changed before API call`, { sessionID }) + state.mode = "idle" + return + } + + try { + log(`[${HOOK_NAME}] Injecting continuation prompt`, { + sessionID, + agent: prevMessage?.agent, + incompleteCount + }) + + await ctx.client.session.prompt({ + path: { id: sessionID }, + body: { + agent: prevMessage?.agent, + parts: [{ type: "text", text: prompt }], + }, + query: { directory: ctx.directory }, + }) + + state.lastInjectedAt = Date.now() + log(`[${HOOK_NAME}] Continuation prompt injected successfully`, { sessionID }) + } catch (err) { + log(`[${HOOK_NAME}] Prompt injection failed`, { sessionID, error: String(err) }) + } + + state.mode = "idle" + } + + // ============================================================================ + // COUNTDOWN STARTER + // ============================================================================ + + function startCountdown(sessionID: string, incompleteCount: number): void { + const state = getOrCreateState(sessionID) + + // Cancel any existing countdown + cancelCountdown(sessionID, "starting new countdown") + + // Increment version for this new countdown + state.version++ + state.mode = "countingDown" + const capturedVersion = state.version + + log(`[${HOOK_NAME}] Starting countdown`, { + sessionID, + seconds: COUNTDOWN_SECONDS, + version: capturedVersion, + incompleteCount + }) + + // Show initial toast + showCountdownToast(COUNTDOWN_SECONDS, incompleteCount) + + // Show countdown toasts + let secondsRemaining = COUNTDOWN_SECONDS + const toastInterval = setInterval(() => { + // Check if countdown was cancelled + if (state.version !== capturedVersion) { + clearInterval(toastInterval) + return + } + secondsRemaining-- + if (secondsRemaining > 0) { + showCountdownToast(secondsRemaining, incompleteCount) + } + }, 1000) + + // Schedule the injection + state.timer = setTimeout(() => { + clearInterval(toastInterval) + clearTimer(state) + executeInjection(sessionID, capturedVersion) + }, COUNTDOWN_SECONDS * 1000) + } + + // ============================================================================ + // EVENT HANDLER + // ============================================================================ + const handler = async ({ event }: { event: { type: string; properties?: unknown } }): Promise => { const props = event.properties as Record | undefined + // ------------------------------------------------------------------------- + // SESSION.ERROR - Enter error bypass mode + // ------------------------------------------------------------------------- if (event.type === "session.error") { const sessionID = props?.sessionID as string | undefined - if (sessionID) { - const isInterrupt = detectInterrupt(props?.error) - errorSessions.add(sessionID) - if (isInterrupt) { - interruptedSessions.add(sessionID) - } - log(`[${HOOK_NAME}] session.error received`, { sessionID, isInterrupt, error: props?.error }) - - const countdown = pendingCountdowns.get(sessionID) - if (countdown) { - clearInterval(countdown.intervalId) - pendingCountdowns.delete(sessionID) - } - } + if (!sessionID) return + + const isInterrupt = detectInterrupt(props?.error) + const state = getOrCreateState(sessionID) + + cancelCountdown(sessionID, isInterrupt ? "user interrupt" : "session error") + state.mode = "errorBypass" + + log(`[${HOOK_NAME}] session.error received`, { sessionID, isInterrupt, error: props?.error }) return } + // ------------------------------------------------------------------------- + // SESSION.IDLE - Main trigger for todo continuation + // ------------------------------------------------------------------------- if (event.type === "session.idle") { const sessionID = props?.sessionID as string | undefined if (!sessionID) return log(`[${HOOK_NAME}] session.idle received`, { sessionID }) - const mainSessionID = getMainSessionID() - if (mainSessionID && sessionID !== mainSessionID) { - log(`[${HOOK_NAME}] Skipped: not main session`, { sessionID, mainSessionID }) + // Skip if not main session + if (!isMainSession(sessionID)) { + log(`[${HOOK_NAME}] Skipped: not main session`, { sessionID }) return } - const existingCountdown = pendingCountdowns.get(sessionID) - if (existingCountdown) { - clearInterval(existingCountdown.intervalId) - pendingCountdowns.delete(sessionID) - log(`[${HOOK_NAME}] Cancelled existing countdown`, { sessionID }) - } + const state = getOrCreateState(sessionID) - // Check if session is in recovery mode - if so, skip entirely without clearing state - if (recoveringSessions.has(sessionID)) { + // Skip if in recovery mode + if (state.mode === "recovering") { log(`[${HOOK_NAME}] Skipped: session in recovery mode`, { sessionID }) return } - const shouldBypass = interruptedSessions.has(sessionID) || errorSessions.has(sessionID) - - if (shouldBypass) { - interruptedSessions.delete(sessionID) - errorSessions.delete(sessionID) - log(`[${HOOK_NAME}] Skipped: error/interrupt bypass`, { sessionID }) + // Skip if in error bypass mode (clear it for next time) + if (state.mode === "errorBypass") { + state.mode = "idle" + log(`[${HOOK_NAME}] Skipped: error bypass (cleared for next idle)`, { sessionID }) return } - if (remindedSessions.has(sessionID)) { - log(`[${HOOK_NAME}] Skipped: already reminded this session`, { sessionID }) + // Skip if already counting down or injecting + if (state.mode === "countingDown" || state.mode === "injecting") { + log(`[${HOOK_NAME}] Skipped: already ${state.mode}`, { sessionID }) return } - // Check for incomplete todos BEFORE starting countdown + // Fetch todos let todos: Todo[] = [] try { - log(`[${HOOK_NAME}] Fetching todos for session`, { sessionID }) - const response = await ctx.client.session.todo({ - path: { id: sessionID }, - }) + const response = await ctx.client.session.todo({ path: { id: sessionID } }) todos = (response.data ?? response) as Todo[] - log(`[${HOOK_NAME}] Todo API response`, { sessionID, todosCount: todos?.length ?? 0 }) } catch (err) { log(`[${HOOK_NAME}] Todo API error`, { sessionID, error: String(err) }) return @@ -176,231 +435,93 @@ export function createTodoContinuationEnforcer( return } - const incomplete = todos.filter( - (t) => t.status !== "completed" && t.status !== "cancelled" - ) - - if (incomplete.length === 0) { + const incompleteCount = getIncompleteCount(todos) + if (incompleteCount === 0) { log(`[${HOOK_NAME}] All todos completed`, { sessionID, total: todos.length }) return } - log(`[${HOOK_NAME}] Found incomplete todos, starting countdown`, { sessionID, incomplete: incomplete.length, total: todos.length }) + log(`[${HOOK_NAME}] Found incomplete todos`, { + sessionID, + incomplete: incompleteCount, + total: todos.length + }) - const showCountdownToast = async (seconds: number): Promise => { - await ctx.client.tui.showToast({ - body: { - title: "Todo Continuation", - message: `Resuming in ${seconds}s... (${incomplete.length} tasks remaining)`, - variant: "warning" as const, - duration: TOAST_DURATION_MS, - }, - }).catch(() => {}) - } - - const executeAfterCountdown = async (): Promise => { - pendingCountdowns.delete(sessionID) - log(`[${HOOK_NAME}] Countdown finished, executing continuation`, { sessionID }) - - // Re-check conditions after countdown - if (recoveringSessions.has(sessionID)) { - log(`[${HOOK_NAME}] Abort: session entered recovery mode during countdown`, { sessionID }) - return - } - - if (interruptedSessions.has(sessionID) || errorSessions.has(sessionID)) { - log(`[${HOOK_NAME}] Abort: error/interrupt occurred during countdown`, { sessionID }) - interruptedSessions.delete(sessionID) - errorSessions.delete(sessionID) - return - } - - let freshTodos: Todo[] = [] - try { - log(`[${HOOK_NAME}] Re-verifying todos after countdown`, { sessionID }) - const response = await ctx.client.session.todo({ - path: { id: sessionID }, - }) - freshTodos = (response.data ?? response) as Todo[] - log(`[${HOOK_NAME}] Fresh todo count`, { sessionID, todosCount: freshTodos?.length ?? 0 }) - } catch (err) { - log(`[${HOOK_NAME}] Failed to re-verify todos`, { sessionID, error: String(err) }) - return - } - - const freshIncomplete = freshTodos.filter( - (t) => t.status !== "completed" && t.status !== "cancelled" - ) - - if (freshIncomplete.length === 0) { - log(`[${HOOK_NAME}] Abort: no incomplete todos after countdown`, { sessionID, total: freshTodos.length }) - return - } - - log(`[${HOOK_NAME}] Confirmed incomplete todos, proceeding with injection`, { sessionID, incomplete: freshIncomplete.length, total: freshTodos.length }) - - remindedSessions.add(sessionID) - - try { - // Get previous message's agent info to respect agent mode - const messageDir = getMessageDir(sessionID) - const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null - - const agentHasWritePermission = !prevMessage?.tools || (prevMessage.tools.write !== false && prevMessage.tools.edit !== false) - if (!agentHasWritePermission) { - log(`[${HOOK_NAME}] Skipped: previous agent lacks write permission`, { sessionID, agent: prevMessage?.agent, tools: prevMessage?.tools }) - remindedSessions.delete(sessionID) - return - } - - log(`[${HOOK_NAME}] Injecting continuation prompt`, { sessionID, agent: prevMessage?.agent }) - await ctx.client.session.prompt({ - path: { id: sessionID }, - body: { - agent: prevMessage?.agent, - parts: [ - { - type: "text", - text: `${CONTINUATION_PROMPT}\n\n[Status: ${freshTodos.length - freshIncomplete.length}/${freshTodos.length} completed, ${freshIncomplete.length} remaining]`, - }, - ], - }, - query: { directory: ctx.directory }, - }) - log(`[${HOOK_NAME}] Continuation prompt injected successfully`, { sessionID }) - } catch (err) { - log(`[${HOOK_NAME}] Prompt injection failed`, { sessionID, error: String(err) }) - remindedSessions.delete(sessionID) - } - } - - let secondsRemaining = COUNTDOWN_SECONDS - showCountdownToast(secondsRemaining).catch(() => {}) - - const intervalId = setInterval(() => { - secondsRemaining-- - - if (secondsRemaining <= 0) { - clearInterval(intervalId) - pendingCountdowns.delete(sessionID) - executeAfterCountdown() - return - } - - const countdown = pendingCountdowns.get(sessionID) - if (!countdown) { - clearInterval(intervalId) - return - } - - countdown.secondsRemaining = secondsRemaining - showCountdownToast(secondsRemaining).catch(() => {}) - }, 1000) - - pendingCountdowns.set(sessionID, { secondsRemaining, intervalId }) + // Start countdown + startCountdown(sessionID, incompleteCount) + return } + // ------------------------------------------------------------------------- + // MESSAGE.UPDATED - Cancel countdown on activity + // ------------------------------------------------------------------------- if (event.type === "message.updated") { const info = props?.info as Record | undefined const sessionID = info?.sessionID as string | undefined const role = info?.role as string | undefined const finish = info?.finish as string | undefined - log(`[${HOOK_NAME}] message.updated received`, { sessionID, role, finish }) - - if (sessionID && role === "user") { - const countdown = pendingCountdowns.get(sessionID) - if (countdown) { - clearInterval(countdown.intervalId) - pendingCountdowns.delete(sessionID) - log(`[${HOOK_NAME}] Cancelled countdown on user message`, { sessionID }) - } - remindedSessions.delete(sessionID) - preemptivelyInjectedSessions.delete(sessionID) + + if (!sessionID) return + + // User message: Always cancel countdown + if (role === "user") { + cancelCountdown(sessionID, "user message received") + return } - if (sessionID && role === "assistant" && finish) { - remindedSessions.delete(sessionID) - preemptivelyInjectedSessions.delete(sessionID) - log(`[${HOOK_NAME}] Cleared reminded/preemptive state on assistant finish`, { sessionID }) - - const isTerminalFinish = finish && !["tool-calls", "unknown"].includes(finish) - if (isTerminalFinish && isNonInteractive()) { - log(`[${HOOK_NAME}] Terminal finish in non-interactive mode`, { sessionID, finish }) - - const mainSessionID = getMainSessionID() - if (mainSessionID && sessionID !== mainSessionID) { - log(`[${HOOK_NAME}] Skipped preemptive: not main session`, { sessionID, mainSessionID }) - return - } - - if (preemptivelyInjectedSessions.has(sessionID)) { - log(`[${HOOK_NAME}] Skipped preemptive: already injected`, { sessionID }) - return - } - - if (recoveringSessions.has(sessionID) || errorSessions.has(sessionID) || interruptedSessions.has(sessionID)) { - log(`[${HOOK_NAME}] Skipped preemptive: session in error/recovery state`, { sessionID }) - return - } - - const hasRunningBgTasks = backgroundManager - ? backgroundManager.getTasksByParentSession(sessionID).some((t) => t.status === "running") - : false - - let hasIncompleteTodos = false - try { - const response = await ctx.client.session.todo({ path: { id: sessionID } }) - const todos = (response.data ?? response) as Todo[] - hasIncompleteTodos = todos?.some((t) => t.status !== "completed" && t.status !== "cancelled") ?? false - } catch { - log(`[${HOOK_NAME}] Failed to fetch todos for preemptive check`, { sessionID }) - } - - if (hasRunningBgTasks || hasIncompleteTodos) { - log(`[${HOOK_NAME}] Preemptive injection needed`, { sessionID, hasRunningBgTasks, hasIncompleteTodos }) - preemptivelyInjectedSessions.add(sessionID) - - try { - const messageDir = getMessageDir(sessionID) - const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null - - const prompt = hasRunningBgTasks - ? "[SYSTEM] Background tasks are still running. Wait for their completion before proceeding." - : CONTINUATION_PROMPT - - await ctx.client.session.prompt({ - path: { id: sessionID }, - body: { - agent: prevMessage?.agent, - parts: [{ type: "text", text: prompt }], - }, - query: { directory: ctx.directory }, - }) - log(`[${HOOK_NAME}] Preemptive injection successful`, { sessionID }) - } catch (err) { - log(`[${HOOK_NAME}] Preemptive injection failed`, { sessionID, error: String(err) }) - preemptivelyInjectedSessions.delete(sessionID) - } - } - } + // Assistant message WITHOUT finish: Agent is working, cancel countdown + if (role === "assistant" && !finish) { + cancelCountdown(sessionID, "assistant is working (streaming)") + return } + + // Assistant message WITH finish: Agent finished a turn (let session.idle handle it) + if (role === "assistant" && finish) { + log(`[${HOOK_NAME}] Assistant turn finished`, { sessionID, finish }) + return + } + return } + // ------------------------------------------------------------------------- + // MESSAGE.PART.UPDATED - Cancel countdown on streaming activity + // ------------------------------------------------------------------------- + if (event.type === "message.part.updated") { + const info = props?.info as Record | undefined + const sessionID = info?.sessionID as string | undefined + const role = info?.role as string | undefined + + if (sessionID && role === "assistant") { + cancelCountdown(sessionID, "assistant streaming") + } + return + } + + // ------------------------------------------------------------------------- + // TOOL EVENTS - Cancel countdown when tools are executing + // ------------------------------------------------------------------------- + if (event.type === "tool.execute.before" || event.type === "tool.execute.after") { + const sessionID = props?.sessionID as string | undefined + if (sessionID) { + cancelCountdown(sessionID, `tool execution (${event.type})`) + } + return + } + + // ------------------------------------------------------------------------- + // SESSION.DELETED - Cleanup + // ------------------------------------------------------------------------- if (event.type === "session.deleted") { const sessionInfo = props?.info as { id?: string } | undefined if (sessionInfo?.id) { - remindedSessions.delete(sessionInfo.id) - interruptedSessions.delete(sessionInfo.id) - errorSessions.delete(sessionInfo.id) - recoveringSessions.delete(sessionInfo.id) - preemptivelyInjectedSessions.delete(sessionInfo.id) - - const countdown = pendingCountdowns.get(sessionInfo.id) - if (countdown) { - clearInterval(countdown.intervalId) - pendingCountdowns.delete(sessionInfo.id) + const state = sessions.get(sessionInfo.id) + if (state) { + clearTimer(state) } + sessions.delete(sessionInfo.id) + log(`[${HOOK_NAME}] Session deleted, state cleaned up`, { sessionID: sessionInfo.id }) } + return } }