diff --git a/src/hooks/think-mode/switcher.test.ts b/src/hooks/think-mode/switcher.test.ts index 8791b59..cdd1cb0 100644 --- a/src/hooks/think-mode/switcher.test.ts +++ b/src/hooks/think-mode/switcher.test.ts @@ -322,4 +322,140 @@ describe("think-mode switcher", () => { expect(config.maxTokens).toBe(64000) }) }) + + describe("Custom provider prefixes support", () => { + describe("getHighVariant with prefixes", () => { + it("should preserve vertex_ai/ prefix when getting high variant", () => { + // #given a model ID with vertex_ai/ prefix + const variant = getHighVariant("vertex_ai/claude-sonnet-4-5") + + // #then should return high variant with prefix preserved + expect(variant).toBe("vertex_ai/claude-sonnet-4-5-high") + }) + + it("should preserve openai/ prefix when getting high variant", () => { + // #given a model ID with openai/ prefix + const variant = getHighVariant("openai/gpt-5-2") + + // #then should return high variant with prefix preserved + expect(variant).toBe("openai/gpt-5-2-high") + }) + + it("should handle prefixes with dots in version numbers", () => { + // #given a model ID with prefix and dots + const variant = getHighVariant("vertex_ai/claude-opus-4.5") + + // #then should normalize dots and preserve prefix + expect(variant).toBe("vertex_ai/claude-opus-4-5-high") + }) + + it("should handle multiple different prefixes", () => { + // #given various custom prefixes + expect(getHighVariant("azure/gpt-5")).toBe("azure/gpt-5-high") + expect(getHighVariant("bedrock/claude-sonnet-4-5")).toBe("bedrock/claude-sonnet-4-5-high") + expect(getHighVariant("custom-llm/gemini-3-pro")).toBe("custom-llm/gemini-3-pro-high") + }) + + it("should return null for prefixed models without high variant mapping", () => { + // #given prefixed model IDs without high variant mapping + expect(getHighVariant("vertex_ai/unknown-model")).toBeNull() + expect(getHighVariant("custom/llama-3-70b")).toBeNull() + }) + + it("should return null for already-high prefixed models", () => { + // #given prefixed model IDs that are already high + expect(getHighVariant("vertex_ai/claude-opus-4-5-high")).toBeNull() + expect(getHighVariant("openai/gpt-5-2-high")).toBeNull() + }) + }) + + describe("isAlreadyHighVariant with prefixes", () => { + it("should detect -high suffix in prefixed models", () => { + // #given prefixed model IDs with -high suffix + expect(isAlreadyHighVariant("vertex_ai/claude-opus-4-5-high")).toBe(true) + expect(isAlreadyHighVariant("openai/gpt-5-2-high")).toBe(true) + expect(isAlreadyHighVariant("custom/gemini-3-pro-high")).toBe(true) + }) + + it("should return false for prefixed base models", () => { + // #given prefixed base model IDs without -high suffix + expect(isAlreadyHighVariant("vertex_ai/claude-opus-4-5")).toBe(false) + expect(isAlreadyHighVariant("openai/gpt-5-2")).toBe(false) + }) + + it("should handle prefixed models with dots", () => { + // #given prefixed model IDs with dots + expect(isAlreadyHighVariant("vertex_ai/gpt-5.2")).toBe(false) + expect(isAlreadyHighVariant("vertex_ai/gpt-5.2-high")).toBe(true) + }) + }) + + describe("getThinkingConfig with prefixes", () => { + it("should return null for custom providers (not in THINKING_CONFIGS)", () => { + // #given custom provider with prefixed Claude model + const config = getThinkingConfig("dia-llm", "vertex_ai/claude-sonnet-4-5") + + // #then should return null (custom provider not in THINKING_CONFIGS) + expect(config).toBeNull() + }) + + it("should work with prefixed models on known providers", () => { + // #given known provider (anthropic) with prefixed model + // This tests that the base model name is correctly extracted for capability check + const config = getThinkingConfig("anthropic", "custom-prefix/claude-opus-4-5") + + // #then should return thinking config (base model is capable) + expect(config).not.toBeNull() + expect(config?.thinking).toBeDefined() + }) + + it("should return null for prefixed models that are already high", () => { + // #given prefixed already-high model + const config = getThinkingConfig("anthropic", "vertex_ai/claude-opus-4-5-high") + + // #then should return null + expect(config).toBeNull() + }) + }) + + describe("Real-world custom provider scenario", () => { + it("should handle LLM proxy with vertex_ai prefix correctly", () => { + // #given a custom LLM proxy provider using vertex_ai/ prefix + const providerID = "dia-llm" + const modelID = "vertex_ai/claude-sonnet-4-5" + + // #when getting high variant + const highVariant = getHighVariant(modelID) + + // #then should preserve the prefix + expect(highVariant).toBe("vertex_ai/claude-sonnet-4-5-high") + + // #and when checking if already high + expect(isAlreadyHighVariant(modelID)).toBe(false) + expect(isAlreadyHighVariant(highVariant!)).toBe(true) + + // #and when getting thinking config for custom provider + const config = getThinkingConfig(providerID, modelID) + + // #then should return null (custom provider, not anthropic) + // This prevents applying incompatible thinking configs to custom providers + expect(config).toBeNull() + }) + + it("should not break when switching to high variant in think mode", () => { + // #given think mode switching vertex_ai/claude model to high variant + const original = "vertex_ai/claude-opus-4-5" + const high = getHighVariant(original) + + // #then the high variant should be valid + expect(high).toBe("vertex_ai/claude-opus-4-5-high") + + // #and should be recognized as already high + expect(isAlreadyHighVariant(high!)).toBe(true) + + // #and switching again should return null (already high) + expect(getHighVariant(high!)).toBeNull() + }) + }) + }) }) diff --git a/src/hooks/think-mode/switcher.ts b/src/hooks/think-mode/switcher.ts index ddf8f76..e99ce54 100644 --- a/src/hooks/think-mode/switcher.ts +++ b/src/hooks/think-mode/switcher.ts @@ -16,6 +16,26 @@ * inconsistencies defensively while maintaining backwards compatibility. */ +/** + * Extracts provider-specific prefix from model ID (if present). + * Custom providers may use prefixes for routing (e.g., vertex_ai/, openai/). + * + * @example + * extractModelPrefix("vertex_ai/claude-sonnet-4-5") // { prefix: "vertex_ai/", base: "claude-sonnet-4-5" } + * extractModelPrefix("claude-sonnet-4-5") // { prefix: "", base: "claude-sonnet-4-5" } + * extractModelPrefix("openai/gpt-5.2") // { prefix: "openai/", base: "gpt-5.2" } + */ +function extractModelPrefix(modelID: string): { prefix: string; base: string } { + const slashIndex = modelID.indexOf("/") + if (slashIndex === -1) { + return { prefix: "", base: modelID } + } + return { + prefix: modelID.slice(0, slashIndex + 1), + base: modelID.slice(slashIndex + 1), + } +} + /** * Normalizes model IDs to use consistent hyphen formatting. * GitHub Copilot may use dots (claude-opus-4.5) but our maps use hyphens (claude-opus-4-5). @@ -25,6 +45,7 @@ * normalizeModelID("claude-opus-4.5") // "claude-opus-4-5" * normalizeModelID("gemini-3.5-pro") // "gemini-3-5-pro" * normalizeModelID("gpt-5.2") // "gpt-5-2" + * normalizeModelID("vertex_ai/claude-opus-4.5") // "vertex_ai/claude-opus-4-5" */ function normalizeModelID(modelID: string): string { // Replace dots with hyphens when followed by a digit @@ -142,16 +163,27 @@ const THINKING_CAPABLE_MODELS = { export function getHighVariant(modelID: string): string | null { const normalized = normalizeModelID(modelID) + const { prefix, base } = extractModelPrefix(normalized) - if (ALREADY_HIGH.has(normalized)) { + // Check if already high variant (with or without prefix) + if (ALREADY_HIGH.has(base) || base.endsWith("-high")) { return null } - return HIGH_VARIANT_MAP[normalized] ?? null + + // Look up high variant for base model + const highBase = HIGH_VARIANT_MAP[base] + if (!highBase) { + return null + } + + // Preserve prefix in the high variant + return prefix + highBase } export function isAlreadyHighVariant(modelID: string): boolean { const normalized = normalizeModelID(modelID) - return ALREADY_HIGH.has(normalized) || normalized.endsWith("-high") + const { base } = extractModelPrefix(normalized) + return ALREADY_HIGH.has(base) || base.endsWith("-high") } type ThinkingProvider = keyof typeof THINKING_CONFIGS @@ -165,6 +197,7 @@ export function getThinkingConfig( modelID: string ): Record | null { const normalized = normalizeModelID(modelID) + const { base } = extractModelPrefix(normalized) if (isAlreadyHighVariant(normalized)) { return null @@ -179,9 +212,10 @@ export function getThinkingConfig( const config = THINKING_CONFIGS[resolvedProvider] const capablePatterns = THINKING_CAPABLE_MODELS[resolvedProvider] - const modelLower = normalized.toLowerCase() + // Check capability using base model name (without prefix) + const baseLower = base.toLowerCase() const isCapable = capablePatterns.some((pattern) => - modelLower.includes(pattern.toLowerCase()) + baseLower.includes(pattern.toLowerCase()) ) return isCapable ? config : null