fix: preserve custom provider prefixes in think mode model switching (#451)

When using custom providers with model ID prefixes (e.g., vertex_ai/claude-sonnet-4-5),
the think mode switcher was stripping the prefix when mapping to high variants,
causing routing failures in custom LLM proxies.

Changes:
- Add extractModelPrefix() to parse and preserve routing prefixes such as vertex_ai/ and openai/
- Update getHighVariant() to preserve prefix when mapping to -high variants
- Update isAlreadyHighVariant() to check base model name (without prefix)
- Update getThinkingConfig() to check capability using base model name
- Add comprehensive tests for custom provider prefix scenarios

This fix ensures backward compatibility while supporting custom providers
that use prefixed model IDs for routing.

Fixes an issue where think mode broke custom providers that use prefixed model IDs
by stripping the routing prefix during model variant switching.
This commit is contained in:
hqone
2026-01-03 15:21:44 +01:00
committed by GitHub
parent fc76ea9d93
commit 6fbc5ba582
2 changed files with 175 additions and 5 deletions

View File

@@ -322,4 +322,140 @@ describe("think-mode switcher", () => {
expect(config.maxTokens).toBe(64000)
})
})
describe("Custom provider prefixes support", () => {
  describe("getHighVariant with prefixes", () => {
    it("should preserve vertex_ai/ prefix when getting high variant", () => {
      // #given a vertex_ai/-prefixed model ID
      // #then the -high variant keeps the routing prefix
      expect(getHighVariant("vertex_ai/claude-sonnet-4-5")).toBe("vertex_ai/claude-sonnet-4-5-high")
    })

    it("should preserve openai/ prefix when getting high variant", () => {
      // #given an openai/-prefixed model ID
      // #then the -high variant keeps the routing prefix
      expect(getHighVariant("openai/gpt-5-2")).toBe("openai/gpt-5-2-high")
    })

    it("should handle prefixes with dots in version numbers", () => {
      // #given a prefixed model ID that uses a dotted version
      // #then the version is hyphen-normalized and the prefix is untouched
      expect(getHighVariant("vertex_ai/claude-opus-4.5")).toBe("vertex_ai/claude-opus-4-5-high")
    })

    it("should handle multiple different prefixes", () => {
      // #given a range of custom routing prefixes, each paired with its expected result
      const expectations: ReadonlyArray<readonly [string, string]> = [
        ["azure/gpt-5", "azure/gpt-5-high"],
        ["bedrock/claude-sonnet-4-5", "bedrock/claude-sonnet-4-5-high"],
        ["custom-llm/gemini-3-pro", "custom-llm/gemini-3-pro-high"],
      ]
      for (const [input, expected] of expectations) {
        expect(getHighVariant(input)).toBe(expected)
      }
    })

    it("should return null for prefixed models without high variant mapping", () => {
      // #given prefixed model IDs whose base model has no high variant
      const unmapped = ["vertex_ai/unknown-model", "custom/llama-3-70b"]
      for (const id of unmapped) {
        expect(getHighVariant(id)).toBeNull()
      }
    })

    it("should return null for already-high prefixed models", () => {
      // #given prefixed model IDs that already carry the -high suffix
      const alreadyHigh = ["vertex_ai/claude-opus-4-5-high", "openai/gpt-5-2-high"]
      for (const id of alreadyHigh) {
        expect(getHighVariant(id)).toBeNull()
      }
    })
  })

  describe("isAlreadyHighVariant with prefixes", () => {
    it("should detect -high suffix in prefixed models", () => {
      // #given prefixed model IDs ending in -high
      const highIDs = [
        "vertex_ai/claude-opus-4-5-high",
        "openai/gpt-5-2-high",
        "custom/gemini-3-pro-high",
      ]
      for (const id of highIDs) {
        expect(isAlreadyHighVariant(id)).toBe(true)
      }
    })

    it("should return false for prefixed base models", () => {
      // #given prefixed model IDs without the -high suffix
      const baseIDs = ["vertex_ai/claude-opus-4-5", "openai/gpt-5-2"]
      for (const id of baseIDs) {
        expect(isAlreadyHighVariant(id)).toBe(false)
      }
    })

    it("should handle prefixed models with dots", () => {
      // #given prefixed model IDs with dotted versions, paired with the expected verdict
      const expectations: ReadonlyArray<readonly [string, boolean]> = [
        ["vertex_ai/gpt-5.2", false],
        ["vertex_ai/gpt-5.2-high", true],
      ]
      for (const [id, expected] of expectations) {
        expect(isAlreadyHighVariant(id)).toBe(expected)
      }
    })
  })

  describe("getThinkingConfig with prefixes", () => {
    it("should return null for custom providers (not in THINKING_CONFIGS)", () => {
      // #given a custom provider routing a prefixed Claude model
      const result = getThinkingConfig("dia-llm", "vertex_ai/claude-sonnet-4-5")
      // #then no config is produced, because the provider is not in THINKING_CONFIGS
      expect(result).toBeNull()
    })

    it("should work with prefixed models on known providers", () => {
      // #given a known provider (anthropic) with a prefixed model;
      // the capability check has to look at the base model name
      const result = getThinkingConfig("anthropic", "custom-prefix/claude-opus-4-5")
      // #then a thinking config comes back (the base model is capable)
      expect(result).not.toBeNull()
      expect(result?.thinking).toBeDefined()
    })

    it("should return null for prefixed models that are already high", () => {
      // #given a prefixed model that is already a -high variant
      const result = getThinkingConfig("anthropic", "vertex_ai/claude-opus-4-5-high")
      // #then no config is produced
      expect(result).toBeNull()
    })
  })

  describe("Real-world custom provider scenario", () => {
    it("should handle LLM proxy with vertex_ai prefix correctly", () => {
      // #given a custom LLM proxy provider that routes via the vertex_ai/ prefix
      const proxyProvider = "dia-llm"
      const proxiedModel = "vertex_ai/claude-sonnet-4-5"

      // #when resolving the high variant
      const upgraded = getHighVariant(proxiedModel)

      // #then the prefix is preserved
      expect(upgraded).toBe("vertex_ai/claude-sonnet-4-5-high")

      // #and the already-high check tells the two apart
      expect(isAlreadyHighVariant(proxiedModel)).toBe(false)
      expect(isAlreadyHighVariant(upgraded!)).toBe(true)

      // #and the custom provider gets no thinking config (it is not anthropic),
      // which keeps incompatible thinking configs away from custom providers
      expect(getThinkingConfig(proxyProvider, proxiedModel)).toBeNull()
    })

    it("should not break when switching to high variant in think mode", () => {
      // #given think mode upgrading a vertex_ai/ Claude model
      const upgraded = getHighVariant("vertex_ai/claude-opus-4-5")

      // #then the upgraded ID is the prefixed high variant
      expect(upgraded).toBe("vertex_ai/claude-opus-4-5-high")

      // #and it is recognized as already high
      expect(isAlreadyHighVariant(upgraded!)).toBe(true)

      // #and a second upgrade attempt yields null
      expect(getHighVariant(upgraded!)).toBeNull()
    })
  })
})
})

View File

@@ -16,6 +16,26 @@
* inconsistencies defensively while maintaining backwards compatibility.
*/
/**
 * Splits a model ID into its provider routing prefix and base model name.
 * Custom providers may use prefixes for routing (e.g., vertex_ai/, openai/).
 *
 * The split happens at the LAST slash, so a nested route keeps every routing
 * segment in the prefix and the base is always the bare model name.
 * `prefix + base` always equals the input.
 *
 * @param modelID - Model ID, with or without a routing prefix.
 * @returns `prefix` including its trailing slash ("" when the ID has no slash)
 *          and `base`, the remainder after the last slash.
 *
 * @example
 * extractModelPrefix("vertex_ai/claude-sonnet-4-5") // { prefix: "vertex_ai/", base: "claude-sonnet-4-5" }
 * extractModelPrefix("claude-sonnet-4-5") // { prefix: "", base: "claude-sonnet-4-5" }
 * extractModelPrefix("openai/gpt-5.2") // { prefix: "openai/", base: "gpt-5.2" }
 * extractModelPrefix("openrouter/anthropic/claude-opus-4-5") // { prefix: "openrouter/anthropic/", base: "claude-opus-4-5" }
 */
function extractModelPrefix(modelID: string): { prefix: string; base: string } {
  // lastIndexOf yields -1 when there is no slash, so splitAt becomes 0 and
  // the whole ID lands in `base` with an empty prefix.
  const splitAt = modelID.lastIndexOf("/") + 1
  return {
    prefix: modelID.slice(0, splitAt),
    base: modelID.slice(splitAt),
  }
}
/**
* Normalizes model IDs to use consistent hyphen formatting.
* GitHub Copilot may use dots (claude-opus-4.5) but our maps use hyphens (claude-opus-4-5).
@@ -25,6 +45,7 @@
* normalizeModelID("claude-opus-4.5") // "claude-opus-4-5"
* normalizeModelID("gemini-3.5-pro") // "gemini-3-5-pro"
* normalizeModelID("gpt-5.2") // "gpt-5-2"
* normalizeModelID("vertex_ai/claude-opus-4.5") // "vertex_ai/claude-opus-4-5"
*/
function normalizeModelID(modelID: string): string {
// Replace dots with hyphens when followed by a digit
@@ -142,16 +163,27 @@ const THINKING_CAPABLE_MODELS = {
/**
 * Resolves the "-high" variant of a model ID, keeping any provider routing
 * prefix (e.g. "vertex_ai/") in front of the result.
 *
 * @param modelID - Model ID, optionally prefixed; it is passed through
 *                  normalizeModelID before the lookup.
 * @returns The high-variant ID with the original prefix re-attached, or null
 *          when the model is already a high variant or its base name has no
 *          entry in HIGH_VARIANT_MAP.
 */
export function getHighVariant(modelID: string): string | null {
  const normalized = normalizeModelID(modelID)
  const { prefix, base } = extractModelPrefix(normalized)

  // Check the base name so prefixed IDs are recognized as already high too.
  if (ALREADY_HIGH.has(base) || base.endsWith("-high")) {
    return null
  }

  // The map is keyed by base model name, so look up without the prefix.
  const highBase = HIGH_VARIANT_MAP[base]
  // Only a non-empty string is a usable mapping; anything else is treated as
  // "no high variant" rather than being concatenated onto the prefix.
  if (typeof highBase !== "string" || highBase.length === 0) {
    return null
  }

  // Preserve the routing prefix in the high variant.
  return prefix + highBase
}
/**
 * Reports whether a model ID already denotes a high variant.
 *
 * The check runs on the base model name (routing prefix removed) after
 * normalizeModelID has been applied, so prefixed and un-prefixed IDs are
 * treated alike.
 *
 * @param modelID - Model ID, optionally prefixed.
 * @returns true when the base name is in ALREADY_HIGH or ends with "-high".
 */
export function isAlreadyHighVariant(modelID: string): boolean {
  const normalized = normalizeModelID(modelID)
  const { base } = extractModelPrefix(normalized)
  return ALREADY_HIGH.has(base) || base.endsWith("-high")
}
type ThinkingProvider = keyof typeof THINKING_CONFIGS
@@ -165,6 +197,7 @@ export function getThinkingConfig(
modelID: string
): Record<string, unknown> | null {
const normalized = normalizeModelID(modelID)
const { base } = extractModelPrefix(normalized)
if (isAlreadyHighVariant(normalized)) {
return null
@@ -179,9 +212,10 @@ export function getThinkingConfig(
const config = THINKING_CONFIGS[resolvedProvider]
const capablePatterns = THINKING_CAPABLE_MODELS[resolvedProvider]
const modelLower = normalized.toLowerCase()
// Check capability using base model name (without prefix)
const baseLower = base.toLowerCase()
const isCapable = capablePatterns.some((pattern) =>
modelLower.includes(pattern.toLowerCase())
baseLower.includes(pattern.toLowerCase())
)
return isCapable ? config : null