diff --git a/controller/channel.go b/controller/channel.go index 809a2e932..b2db2b777 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -165,6 +165,30 @@ func GetAllChannels(c *gin.Context) { return } +func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) { + var headers http.Header + switch channel.Type { + case constant.ChannelTypeAnthropic: + headers = GetClaudeAuthHeader(key) + default: + headers = GetAuthHeader(key) + } + + headerOverride := channel.GetHeaderOverride() + for k, v := range headerOverride { + str, ok := v.(string) + if !ok { + return nil, fmt.Errorf("invalid header override for key %s", k) + } + if strings.Contains(str, "{api_key}") { + str = strings.ReplaceAll(str, "{api_key}", key) + } + headers.Set(k, str) + } + + return headers, nil +} + func FetchUpstreamModels(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { @@ -223,14 +247,13 @@ func FetchUpstreamModels(c *gin.Context) { } key = strings.TrimSpace(key) - // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader - var body []byte - switch channel.Type { - case constant.ChannelTypeAnthropic: - body, err = GetResponseBody("GET", url, channel, GetClaudeAuthHeader(key)) - default: - body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) + headers, err := buildFetchModelsHeaders(channel, key) + if err != nil { + common.ApiError(c, err) + return } + + body, err := GetResponseBody("GET", url, channel, headers) if err != nil { common.ApiError(c, err) return