tokenFactory/controller/relay.go

801 lines
28 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/samber/lo"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// resolveRelayPriceData computes the PriceData for a relay attempt.
//
// Non-image requests (and a nil relayInfo) are priced by helper.ModelPriceHelper.
// For image generation / image edit requests the per-image pricing helper
// (helper.TryModelPriceHelperImage) is tried first. If it yields no price but
// per-image table pricing is configured for the channel/model, an error is
// returned rather than silently falling back. Otherwise the request is priced
// by helper.ModelPriceHelperForImageFallback.
func resolveRelayPriceData(c *gin.Context, relayInfo *relaycommon.RelayInfo, tokens int, meta *types.TokenCountMeta) (types.PriceData, error) {
	isImageMode := relayInfo != nil &&
		(relayInfo.RelayMode == relayconstant.RelayModeImagesGenerations ||
			relayInfo.RelayMode == relayconstant.RelayModeImagesEdits)
	if !isImageMode {
		return helper.ModelPriceHelper(c, relayInfo, tokens, meta)
	}

	// ChannelId is only readable once ChannelMeta has been initialized.
	var channelID int
	if relayInfo.ChannelMeta != nil {
		channelID = relayInfo.ChannelId
	}
	tableConfigured := helper.HasImagePerImageTablePricing(channelID, relayInfo.OriginModelName) ||
		helper.HasImagePerImageTablePricingForInfo(channelID, relayInfo)

	imagePrice, matched, priceErr := helper.TryModelPriceHelperImage(c, relayInfo)
	if priceErr != nil {
		return types.PriceData{}, priceErr
	}
	if matched {
		return imagePrice, nil
	}

	if tableConfigured {
		// A table exists but no rule matched: refuse to guess a price.
		name := relayInfo.OriginModelName
		return types.PriceData{}, fmt.Errorf(
			"图片模型 %s 已配置按张分辨率价格,但未能匹配有效价格,请检查文生图/图生图规则或兜底每张价Image model %s per-image pricing configured but no price matched",
			name, name,
		)
	}
	return helper.ModelPriceHelperForImageFallback(c, relayInfo, tokens, meta)
}
// relayHandler dispatches an OpenAI-compatible relay request to the helper
// responsible for info.RelayMode. Modes not listed explicitly (e.g. chat
// completions) are handled by relay.TextHelper.
func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.TokenFactoryError {
	switch info.RelayMode {
	case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
		return relay.ImageHelper(c, info)
	case relayconstant.RelayModeAudioSpeech,
		relayconstant.RelayModeAudioTranslation,
		relayconstant.RelayModeAudioTranscription:
		return relay.AudioHelper(c, info)
	case relayconstant.RelayModeRerank:
		return relay.RerankHelper(c, info)
	case relayconstant.RelayModeEmbeddings:
		return relay.EmbeddingHelper(c, info)
	case relayconstant.RelayModeResponses, relayconstant.RelayModeResponsesCompact:
		return relay.ResponsesHelper(c, info)
	default:
		return relay.TextHelper(c, info)
	}
}
// geminiRelayHandler routes Gemini-format requests: any request whose URL path
// contains "embed" goes to the Gemini embedding handler, everything else to
// relay.GeminiHelper.
func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.TokenFactoryError {
	if !strings.Contains(c.Request.URL.Path, "embed") {
		return relay.GeminiHelper(c, info)
	}
	return relay.GeminiEmbeddingHandler(c, info)
}
// Relay handles a synchronous relay request in the given wire format
// (OpenAI-compatible, OpenAI Realtime over websocket, Claude, Gemini, ...).
//
// Steps:
//  1. For the realtime format, upgrade the connection to a websocket.
//  2. Parse/validate the request and build the RelayInfo.
//  3. Optionally run the prompt sensitive-word check, then estimate prompt tokens.
//  4. Pick the first channel, resolve pricing and pre-consume quota (unless free).
//  5. Try the request on up to common.RetryTimes+1 channels, stopping on
//     success or when shouldRetry says no.
//
// Errors are reported by assigning tokenFactoryError and returning. The
// deferred handlers then refund pre-consumed quota (once billing started) and
// write the error to the client in the format matching relayFormat.
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
	requestId := c.GetString(common.RequestIdKey)
	var (
		tokenFactoryError *types.TokenFactoryError
		ws                *websocket.Conn
	)

	if relayFormat == types.RelayFormatOpenAIRealtime {
		var err error
		ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
		if err != nil {
			helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
			return
		}
		defer ws.Close()
	}

	// Write the final error (if any) to the client. This defer is registered
	// before the billing defer below, so it runs after it and sees the
	// normalized error.
	defer func() {
		if tokenFactoryError == nil {
			return
		}
		logger.LogError(c, fmt.Sprintf("relay error: %s", tokenFactoryError.Error()))
		tokenFactoryError.SetMessage(common.MessageWithRequestId(tokenFactoryError.Error(), requestId))
		switch relayFormat {
		case types.RelayFormatOpenAIRealtime:
			helper.WssError(c, ws, tokenFactoryError.ToOpenAIError())
		case types.RelayFormatClaude:
			c.JSON(tokenFactoryError.StatusCode, gin.H{
				"type":  "error",
				"error": tokenFactoryError.ToClaudeError(),
			})
		default:
			c.JSON(tokenFactoryError.StatusCode, gin.H{
				"error": tokenFactoryError.ToOpenAIError(),
			})
		}
	}()

	request, err := helper.GetAndValidateRequest(c, relayFormat)
	if err != nil {
		// Map "request body too large" to 413 so clients can handle it correctly.
		if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
			tokenFactoryError = types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
		} else {
			tokenFactoryError = types.NewError(err, types.ErrorCodeInvalidRequest)
		}
		return
	}
	relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
	if err != nil {
		tokenFactoryError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
		return
	}

	needSensitiveCheck := setting.ShouldCheckPromptSensitive()
	needCountToken := constant.CountToken
	// Avoid building the potentially huge CombineText (strings.Join) when neither
	// token counting nor the sensitive-word check needs it.
	var meta *types.TokenCountMeta
	if needSensitiveCheck || needCountToken {
		meta = request.GetTokenCountMeta()
	} else {
		meta = fastTokenCountMetaForPricing(request)
	}
	if needSensitiveCheck && meta != nil {
		if contains, words := service.CheckSensitiveText(meta.CombineText); contains {
			logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
			// err is always nil at this point (GenRelayInfo succeeded), so give
			// the error a real cause instead of wrapping nil. The matched words
			// are only logged, never echoed back to the client.
			tokenFactoryError = types.NewError(errors.New("sensitive words detected in request"), types.ErrorCodeSensitiveWordsDetected)
			return
		}
	}
	tokens, err := service.EstimateRequestToken(c, meta, relayInfo)
	if err != nil {
		tokenFactoryError = types.NewError(err, types.ErrorCodeCountTokenFailed)
		return
	}
	relayInfo.SetEstimatePromptTokens(tokens)

	retryParam := &service.RetryParam{
		Ctx:        c,
		TokenGroup: relayInfo.TokenGroup,
		ModelName:  relayInfo.OriginModelName,
		Retry:      common.GetPointer(0),
	}
	relayInfo.RetryIndex = 0
	relayInfo.LastError = nil

	// Select the first channel before pricing so pre-consume uses the selected
	// channel's pricing.
	firstChannel, firstChannelErr := getChannel(c, relayInfo, retryParam)
	if firstChannelErr != nil {
		logger.LogError(c, firstChannelErr.Error())
		tokenFactoryError = firstChannelErr
		return
	}
	if relayFormat != types.RelayFormatTask {
		if tfErr := errVideoTaskChannelOnNonTaskRelay(firstChannel); tfErr != nil {
			tokenFactoryError = tfErr
			return
		}
	}
	if relayInfo.ChannelMeta == nil {
		relayInfo.InitChannelMeta(c)
	}
	priceData, err := resolveRelayPriceData(c, relayInfo, tokens, meta)
	if err != nil {
		tokenFactoryError = types.NewError(err, types.ErrorCodeModelPriceError)
		return
	}
	if priceData.FreeModel {
		logger.LogInfo(c, fmt.Sprintf("模型 %s 免费,跳过预扣费", relayInfo.OriginModelName))
	} else {
		tokenFactoryError = service.PreConsumeBilling(c, priceData.QuotaToPreConsume, relayInfo)
		if tokenFactoryError != nil {
			return
		}
	}

	// On failure: normalize the error, return pre-consumed quota (only when
	// billing was actually set up) and charge a violation fee if applicable.
	defer func() {
		if tokenFactoryError == nil {
			return
		}
		tokenFactoryError = service.NormalizeViolationFeeError(tokenFactoryError)
		if relayInfo.Billing != nil {
			relayInfo.Billing.Refund(c)
		}
		service.ChargeViolationFeeIfNeeded(c, relayInfo, tokenFactoryError)
	}()

	// reportChannelError records chErr as the last error of this attempt and
	// hands it to processChannelError (error log / possible auto-disable).
	reportChannelError := func(channel *model.Channel, chErr *types.TokenFactoryError) {
		relayInfo.LastError = chErr
		processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), chErr)
	}

	for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
		relayInfo.RetryIndex = retryParam.GetRetry()
		channel := firstChannel
		if retryParam.GetRetry() > 0 {
			var channelErr *types.TokenFactoryError
			channel, channelErr = getChannel(c, relayInfo, retryParam)
			if channelErr != nil {
				logger.LogError(c, channelErr.Error())
				tokenFactoryError = channelErr
				break
			}
		}
		if relayFormat != types.RelayFormatTask {
			if tfErr := errVideoTaskChannelOnNonTaskRelay(channel); tfErr != nil {
				tokenFactoryError = tfErr
				reportChannelError(channel, tokenFactoryError)
				break
			}
		}

		// Refresh ChannelMeta (and with it the PriceData) after the channel
		// decision of this attempt. On retries getChannel has switched the
		// context to another channel while relayInfo.ChannelMeta still describes
		// the previous one, so a nil-check alone would price the attempt with
		// stale channel data.
		if retryParam.GetRetry() > 0 || relayInfo.ChannelMeta == nil {
			relayInfo.InitChannelMeta(c)
		}
		if _, priceErr := resolveRelayPriceData(c, relayInfo, tokens, meta); priceErr != nil {
			tokenFactoryError = types.NewError(priceErr, types.ErrorCodeModelPriceError)
			reportChannelError(channel, tokenFactoryError)
			if !shouldRetry(c, tokenFactoryError, common.RetryTimes-retryParam.GetRetry()) {
				break
			}
			continue
		}

		addUsedChannel(c, channel.Id)
		bodyStorage, bodyErr := common.GetBodyStorage(c)
		if bodyErr != nil {
			// Ensure a consistent 413 for oversized bodies even when the error
			// surfaces later (e.g. on the retry path).
			if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
				tokenFactoryError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
			} else {
				tokenFactoryError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
			}
			break
		}
		c.Request.Body = io.NopCloser(bodyStorage)

		switch relayFormat {
		case types.RelayFormatOpenAIRealtime:
			tokenFactoryError = relay.WssHelper(c, relayInfo)
		case types.RelayFormatClaude:
			tokenFactoryError = relay.ClaudeHelper(c, relayInfo)
		case types.RelayFormatGemini:
			tokenFactoryError = geminiRelayHandler(c, relayInfo)
		default:
			tokenFactoryError = relayHandler(c, relayInfo)
		}
		if tokenFactoryError == nil {
			relayInfo.LastError = nil
			return
		}

		tokenFactoryError = service.NormalizeViolationFeeError(tokenFactoryError)
		reportChannelError(channel, tokenFactoryError)
		if !shouldRetry(c, tokenFactoryError, common.RetryTimes-retryParam.GetRetry()) {
			break
		}
	}

	// Log the channel path taken when more than one channel was tried.
	// Entries are decimal channel ids (see addUsedChannel).
	if useChannel := c.GetStringSlice("use_channel"); len(useChannel) > 1 {
		logger.LogInfo(c, fmt.Sprintf("重试:%s", strings.Join(useChannel, "->")))
	}
}
// upgrader upgrades OpenAI Realtime relay requests to websocket connections.
var upgrader = websocket.Upgrader{
	// Every Origin is accepted (cross-origin allowed).
	// NOTE(review): this is only safe as long as realtime auth never relies on
	// ambient browser credentials such as cookies — confirm.
	CheckOrigin: func(*http.Request) bool { return true },
	// Subprotocols accepted during the WS handshake. Any protocol a client sends
	// via Sec-WebSocket-Protocol must be declared here.
	// TODO: add other protocols.
	Subprotocols: []string{"realtime"},
}
// addUsedChannel appends channelId, formatted as a decimal string, to the
// "use_channel" list stored in the gin context. The list is read back for the
// retry log line and for error-log admin info.
func addUsedChannel(c *gin.Context, channelId int) {
	c.Set("use_channel", append(c.GetStringSlice("use_channel"), fmt.Sprint(channelId)))
}
// fastTokenCountMetaForPricing builds a cheap TokenCountMeta for pricing when
// neither token counting nor the sensitive-word check is enabled.
//
// Only MaxTokens is filled in (from max_tokens / max_completion_tokens /
// max_output_tokens depending on the request type). CombineText is left empty
// to avoid large allocations. Image requests use their full
// GetTokenCountMeta, because image pricing needs it. A nil request yields an
// empty meta.
func fastTokenCountMetaForPricing(request dto.Request) *types.TokenCountMeta {
	if request == nil {
		return &types.TokenCountMeta{}
	}
	var maxTokens int
	switch req := request.(type) {
	case *dto.ImageRequest:
		// Pricing for image requests depends on ImagePriceRatio; safe to compute
		// even when CountToken is disabled.
		return req.GetTokenCountMeta()
	case *dto.GeneralOpenAIRequest:
		// Use the larger of max_tokens and max_completion_tokens.
		legacy := lo.FromPtrOr(req.MaxTokens, uint(0))
		completion := lo.FromPtrOr(req.MaxCompletionTokens, uint(0))
		maxTokens = int(legacy)
		if completion > legacy {
			maxTokens = int(completion)
		}
	case *dto.OpenAIResponsesRequest:
		maxTokens = int(lo.FromPtrOr(req.MaxOutputTokens, uint(0)))
	case *dto.ClaudeRequest:
		maxTokens = int(lo.FromPtr(req.MaxTokens))
	}
	return &types.TokenCountMeta{
		TokenType: types.TokenTypeTokenizer,
		MaxTokens: maxTokens,
	}
}
// errVideoTaskChannelOnNonTaskRelay rejects video-task channels (Sora / OpenAI
// video / Tencent Cloud video, etc.; whatever constant.IsVideoTaskChannel
// reports) on non-task relay routes. Those channels may only be used through
// RelayTask and ModelPriceHelperVideo. If one were hit via
// /v1/chat/completions or similar, it would be billed by text tokens, corrupting
// console logs and profit sharing.
//
// Returns nil for a nil channel or a non-video channel. Otherwise it returns a
// 400 invalid-request error marked as non-retryable.
func errVideoTaskChannelOnNonTaskRelay(ch *model.Channel) *types.TokenFactoryError {
	if ch != nil && constant.IsVideoTaskChannel(ch.Type) {
		cause := fmt.Errorf(
			"当前模型命中渠道「%s」(type=%d) 为视频任务渠道,仅支持视频任务接口(如 POST /v1/videos 或操练场 POST /api/playground/videos"+
				"不能使用聊天补全、嵌入、图生等非任务接口,否则会按文本 token 误计费;请改用视频任务 API 或为该模型配置文本类渠道",
			ch.Name, ch.Type,
		)
		return types.NewErrorWithStatusCode(cause, types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}
	return nil
}
// getChannel picks the channel for the current attempt (retryParam.GetRetry()).
//
// Selection order:
//   - attempt 0: reuse the enabled channel already chosen by the distributor
//     middleware (constant.ContextKeyChannelId). This avoids re-routing over
//     specific_channel_id semantics.
//   - attempts > 0: refuse to switch when a token-specific channel
//     (playground specific_channel_id) or forced channel routing is active.
//   - otherwise use order[attempt] from the smart-route channel order, if
//     that channel is enabled.
//   - with "forced supplier + any channel" the candidate pool is already
//     restricted, so an exhausted order means the supplier has no more
//     channels. Do not fall back to the global random pick, which would
//     cross suppliers.
//   - finally fall back to service.CacheGetRandomSatisfiedChannel.
//
// Channels from the smart-route and random paths are installed in the
// context via middleware.SetupContextForSelectedChannel, and
// info.PriceData.GroupRatioInfo is recomputed on those paths.
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.TokenFactoryError) {
	if info.ChannelMeta == nil {
		info.InitChannelMeta(c)
	}
	attempt := retryParam.GetRetry()

	// enabledCached returns the cached channel with the given id, or nil if it
	// cannot be loaded or is not enabled.
	enabledCached := func(id int) *model.Channel {
		ch, err := model.CacheGetChannel(id)
		if err != nil || ch == nil || ch.Status != common.ChannelStatusEnabled {
			return nil
		}
		return ch
	}

	if attempt == 0 {
		if selectedID := common.GetContextKeyInt(c, constant.ContextKeyChannelId); selectedID > 0 {
			if ch := enabledCached(selectedID); ch != nil {
				return ch, nil
			}
		}
	} else {
		_, pinned := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
		if !pinned {
			_, pinned = common.GetContextKey(c, constant.ContextKeyForcedChannelID)
		}
		if pinned {
			return nil, types.NewError(
				fmt.Errorf("已指定渠道,禁用重试切换渠道"),
				types.ErrorCodeGetChannelFailed,
				types.ErrOptionWithSkipRetry(),
			)
		}
	}

	if orderAny, hasOrder := common.GetContextKey(c, constant.ContextKeySmartRouteChannelOrder); hasOrder {
		if order, isInts := orderAny.([]int); isInts && len(order) > 0 && attempt < len(order) {
			if ch := enabledCached(order[attempt]); ch != nil {
				info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
				if setupErr := middleware.SetupContextForSelectedChannel(c, ch, info.OriginModelName); setupErr != nil {
					return nil, setupErr
				}
				return ch, nil
			}
		}
	}

	if _, forced := service.ForcedSupplierFromContext(c); forced {
		return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 在指定供应商内已无可用渠道retry", retryParam.TokenGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
	}

	picked, selectGroup, pickErr := service.CacheGetRandomSatisfiedChannel(retryParam)
	info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
	switch {
	case pickErr != nil:
		return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, info.OriginModelName, pickErr.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
	case picked == nil:
		return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在retry", selectGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
	}
	if setupErr := middleware.SetupContextForSelectedChannel(c, picked, info.OriginModelName); setupErr != nil {
		return nil, setupErr
	}
	return picked, nil
}
// shouldRetry reports whether a failed relay attempt should be retried on
// another channel. retryTimes is the number of retries still available.
//
// Evaluated in order:
//   - no error / channel-affinity failure / pinned channel (playground
//     specific_channel_id, token-specific channel, forced routing) → no;
//   - channel errors → yes;
//   - skip-retry errors or no retries left → no;
//   - 2xx → no;
//   - out-of-range status codes → yes;
//   - error codes configured as always-skip → no;
//   - anything else → operation_setting.ShouldRetryByStatusCode.
func shouldRetry(c *gin.Context, openaiErr *types.TokenFactoryError, retryTimes int) bool {
	if openaiErr == nil || service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
		return false
	}
	// Never switch away from an explicitly chosen channel.
	if _, pinned := c.Get("specific_channel_id"); pinned {
		return false
	}
	if _, pinned := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId); pinned {
		return false
	}
	if _, pinned := common.GetContextKey(c, constant.ContextKeyForcedChannelID); pinned {
		return false
	}

	switch {
	case types.IsChannelError(openaiErr):
		return true
	case types.IsSkipRetryError(openaiErr), retryTimes <= 0:
		return false
	}

	status := openaiErr.StatusCode
	switch {
	case status >= 200 && status < 300:
		return false
	case status < 100 || status > 599:
		return true
	case operation_setting.IsAlwaysSkipRetryCode(openaiErr.GetErrorCode()):
		return false
	default:
		return operation_setting.ShouldRetryByStatusCode(status)
	}
}
// processChannelError logs a channel failure and performs the follow-ups:
//   - when service.ShouldDisableChannel says so and the channel has auto-ban
//     enabled, the channel is disabled asynchronously (gopool);
//   - when error logging is enabled and the error is recordable, an error-log
//     record is written via model.RecordErrorLog with request/channel details.
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.TokenFactoryError) {
	logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))

	// Do not read channel info from the context inside the goroutine: during
	// asynchronous processing it may no longer match this channel. channelError
	// is captured by value instead.
	if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan {
		gopool.Go(func() {
			service.DisableChannel(channelError, err.ErrorWithStatusCode())
		})
	}

	if !constant.ErrorLogEnabled || !types.IsRecordErrorLog(err) {
		return
	}

	// Persist the error log (e.g. to MySQL).
	userId := c.GetInt("id")
	tokenName := c.GetString("token_name")
	modelName := c.GetString("original_model")
	tokenId := c.GetInt("token_id")
	userGroup := c.GetString("group")
	channelId := c.GetInt("channel_id")

	other := map[string]any{}
	if c.Request != nil && c.Request.URL != nil {
		other["request_path"] = c.Request.URL.Path
	}
	other["error_type"] = err.GetErrorType()
	other["error_code"] = err.GetErrorCode()
	other["status_code"] = err.StatusCode
	other["channel_id"] = channelId
	other["channel_name"] = c.GetString("channel_name")
	other["channel_type"] = c.GetInt("channel_type")

	adminInfo := map[string]any{
		"use_channel": c.GetStringSlice("use_channel"),
	}
	if common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey) {
		adminInfo["is_multi_key"] = true
		adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
	}
	service.AppendChannelAffinityAdminInfo(c, adminInfo)
	other["admin_info"] = adminInfo

	// Elapsed time since request start; falls back to 0s if the start time
	// was never recorded.
	startedAt := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
	if startedAt.IsZero() {
		startedAt = time.Now()
	}
	elapsedSeconds := int(time.Since(startedAt).Seconds())

	model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, elapsedSeconds, false, userGroup, other)
}
// RelayMidjourney handles Midjourney-proxy requests. It dispatches on
// relayInfo.RelayMode:
//   - notify;
//   - task fetch (by id or by condition);
//   - image seed;
//   - swap face;
//   - everything else is a submit.
//
// When the handler returns an error response, it is written as JSON with
// status 400 (429 with a rewritten message for Midjourney code 30, i.e. a
// saturated group) and logged together with the channel id from the context.
func RelayMidjourney(c *gin.Context) {
	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
			"type":        "upstream_error",
			"code":        4,
		})
		return
	}

	var mjErr *dto.MidjourneyResponse
	switch relayInfo.RelayMode {
	case relayconstant.RelayModeMidjourneyNotify:
		mjErr = relay.RelayMidjourneyNotify(c)
	case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
		mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
	case relayconstant.RelayModeMidjourneyTaskImageSeed:
		mjErr = relay.RelayMidjourneyTaskImageSeed(c)
	case relayconstant.RelayModeSwapFace:
		mjErr = relay.RelaySwapFace(c, relayInfo)
	default:
		mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
	}
	if mjErr == nil {
		// Success: the helpers above produce the response. Nothing is printed
		// here, so successful requests no longer emit a "<nil>" line on stdout.
		return
	}

	// Raw dump of the failed Midjourney response to the standard logger,
	// taken before the code-30 message rewrite below.
	log.Println(mjErr)

	statusCode := http.StatusBadRequest
	if mjErr.Code == 30 {
		mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
		statusCode = http.StatusTooManyRequests
	}
	description := fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)
	c.JSON(statusCode, gin.H{
		"description": description,
		"type":        "upstream_error",
		"code":        mjErr.Code,
	})
	channelId := c.GetInt("channel_id")
	logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, description))
}
// RelayNotImplemented answers with HTTP 501 and an OpenAI-style error body
// (code "api_not_implemented").
func RelayNotImplemented(c *gin.Context) {
	c.JSON(http.StatusNotImplemented, gin.H{
		"error": types.OpenAIError{
			Message: "API not implemented",
			Type:    "token_factory_error",
			Param:   "",
			Code:    "api_not_implemented",
		},
	})
}
// RelayNotFound answers unknown relay routes with HTTP 404 and an OpenAI-style
// "invalid_request_error" naming the request method and path.
func RelayNotFound(c *gin.Context) {
	message := fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path)
	c.JSON(http.StatusNotFound, gin.H{
		"error": types.OpenAIError{
			Message: message,
			Type:    "invalid_request_error",
			Param:   "",
			Code:    "",
		},
	})
}
// RelayTaskFetch serves task status queries. A failure to build the relay
// info yields a 500 TaskError. Errors from relay.RelayTaskFetch are written
// via respondTaskError.
func RelayTaskFetch(c *gin.Context) {
	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, &dto.TaskError{
			Code:       "gen_relay_info_failed",
			Message:    err.Error(),
			StatusCode: http.StatusInternalServerError,
		})
		return
	}
	taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode)
	if taskErr == nil {
		return
	}
	respondTaskError(c, taskErr)
}
// RelayTask submits an asynchronous task (e.g. video generation) to an
// upstream channel, retrying on other channels while shouldRetryTaskRelay
// allows it.
//
// Channel choice per attempt: a locked channel (relayInfo.LockedChannel) is
// reused on every attempt. Otherwise getChannel selects one.
//
// On success it:
//   - settles billing with the actual quota (service.ResolveActualTaskQuotaOnSubmit
//     + service.SettleBilling);
//   - logs the consumption;
//   - inserts a model.Task carrying the upstream task id, the API key used for
//     submission and a billing snapshot.
//
// Settle/insert failures are only reported via common.SysError. On failure the
// deferred handler refunds relayInfo.Billing (when set) and the error is
// written with respondTaskError.
//
// NOTE(review): the success response body is not written in this function —
// presumably relay.RelayTaskSubmit writes it; confirm. Billing pre-consumption
// is not visible here either (relayInfo.Billing is presumably set by
// ResolveOriginTask or RelayTaskSubmit).
func RelayTask(c *gin.Context) {
	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, &dto.TaskError{
			Code:       "gen_relay_info_failed",
			Message:    err.Error(),
			StatusCode: http.StatusInternalServerError,
		})
		return
	}
	if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil {
		respondTaskError(c, taskErr)
		return
	}
	var result *relay.TaskSubmitResult
	var taskErr *dto.TaskError
	// Refund the billing session if the request ultimately failed.
	defer func() {
		if taskErr != nil && relayInfo.Billing != nil {
			relayInfo.Billing.Refund(c)
		}
	}()
	retryParam := &service.RetryParam{
		Ctx:        c,
		TokenGroup: relayInfo.TokenGroup,
		ModelName:  relayInfo.OriginModelName,
		Retry:      common.GetPointer(0),
	}
	for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
		var channel *model.Channel
		if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil {
			// Locked channel: always reuse it. Its context is only re-installed
			// on retries (attempt 0 keeps the context as it is).
			channel = lockedCh
			if retryParam.GetRetry() > 0 {
				if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil {
					taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError)
					break
				}
			}
		} else {
			var channelErr *types.TokenFactoryError
			channel, channelErr = getChannel(c, relayInfo, retryParam)
			if channelErr != nil {
				logger.LogError(c, channelErr.Error())
				taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError)
				break
			}
		}
		addUsedChannel(c, channel.Id)
		bodyStorage, bodyErr := common.GetBodyStorage(c)
		if bodyErr != nil {
			// Oversized bodies map to 413, other read failures to 400.
			if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
				taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
			} else {
				taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest)
			}
			break
		}
		c.Request.Body = io.NopCloser(bodyStorage)
		result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
		if taskErr == nil {
			break
		}
		// Only upstream (non-local) failures count against the channel.
		if !taskErr.LocalError {
			processChannelError(c,
				*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
					common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
				types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
		}
		if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) {
			break
		}
	}
	// Log the channel path when more than one channel was tried.
	useChannel := c.GetStringSlice("use_channel")
	if len(useChannel) > 1 {
		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
		logger.LogInfo(c, retryLogStr)
	}
	// ── Success: settle billing + log consumption + insert task ──
	// NOTE(review): result is dereferenced without a nil check. This relies on
	// the loop having run at least once (common.RetryTimes >= 0) and on
	// RelayTaskSubmit returning a non-nil result with a nil error.
	if taskErr == nil {
		actualQuota := service.ResolveActualTaskQuotaOnSubmit(c, relayInfo, result.TaskData, result.Quota)
		result.Quota = actualQuota
		relayInfo.PriceData.Quota = actualQuota
		if settleErr := service.SettleBilling(c, relayInfo, actualQuota); settleErr != nil {
			common.SysError("settle task billing error: " + settleErr.Error())
		}
		service.LogTaskConsumption(c, relayInfo)
		task := model.InitTask(result.Platform, relayInfo)
		// Best effort: keep the original request as the task input.
		if req, err := relaycommon.GetTaskRequest(c); err == nil {
			if reqBytes, mErr := common.Marshal(req); mErr == nil {
				task.Properties.Input = string(reqBytes)
			}
		}
		task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
		task.PrivateData.TfOpenVideoUpstreamStyle = relayInfo.TfOpenVideoUpstreamStyle
		if k := strings.TrimSpace(relayInfo.ApiKey); k != "" {
			// Poll the upstream (e.g. Tencent Cloud DescribeTaskDetail) with the
			// same key used for submission, so multi-key channels don't poll with
			// the wrong key.
			task.PrivateData.Key = k
		}
		task.PrivateData.BillingSource = relayInfo.BillingSource
		task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
		task.PrivateData.TokenId = relayInfo.TokenId
		task.PrivateData.TokenName = c.GetString("token_name")
		// A discount already present on the price data overrides the channel's
		// configured discount.
		chDiscPct := model.ResolveChannelPriceDiscountPercent(relayInfo.ChannelId)
		if relayInfo.PriceData.ChannelPriceDiscount != nil {
			chDiscPct = *relayInfo.PriceData.ChannelPriceDiscount
		}
		// Snapshot of the pricing inputs stored with the task.
		task.PrivateData.BillingContext = &model.TaskBillingContext{
			ModelPrice:                  relayInfo.PriceData.ModelPrice,
			GroupRatio:                  relayInfo.PriceData.GroupRatioInfo.GroupRatio,
			ModelRatio:                  relayInfo.PriceData.ModelRatio,
			OtherRatios:                 relayInfo.PriceData.OtherRatios,
			OriginModelName:             relayInfo.OriginModelName,
			PerCallBilling:              common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
			ChannelPriceDiscountPercent: chDiscPct,
		}
		task.Quota = actualQuota
		task.Data = result.TaskData
		task.Action = relayInfo.Action
		if insertErr := task.Insert(); insertErr != nil {
			common.SysError("insert task error: " + insertErr.Error())
		}
	}
	if taskErr != nil {
		respondTaskError(c, taskErr)
	}
}
// respondTaskError writes taskErr as the JSON response with taskErr.StatusCode.
// For 429 (rate limited) the message is first replaced, in place, with a
// user-facing "upstream load of this group is saturated, try again later"
// notice.
func respondTaskError(c *gin.Context, taskErr *dto.TaskError) {
	const saturatedMessage = "当前分组上游负载已饱和,请稍后再试"
	if taskErr.StatusCode == http.StatusTooManyRequests {
		taskErr.Message = saturatedMessage
	}
	c.JSON(taskErr.StatusCode, taskErr)
}
// shouldRetryTaskRelay reports whether a failed task submission should be
// retried on another channel. retryTimes is the number of retries still
// available; channelId is currently unused.
//
// No retry when:
//   - there is no error;
//   - channel affinity forbids it;
//   - no retries remain;
//   - the request is pinned to a channel (playground specific_channel_id,
//     token-specific channel, or forced channel routing — the same pinning
//     checks as shouldRetry).
//
// Otherwise:
//   - 429, 307 and 5xx are retried, except 5xx codes configured via
//     operation_setting.IsAlwaysSkipRetryStatusCode;
//   - 400, 408, local errors and 2xx are not;
//   - anything else is retried.
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
	if taskErr == nil {
		return false
	}
	if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
		return false
	}
	if retryTimes <= 0 {
		return false
	}
	if _, ok := c.Get("specific_channel_id"); ok {
		return false
	}
	if _, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId); ok {
		return false
	}
	// Under forced channel routing getChannel refuses to switch channels on
	// retry. Retrying would only replace the real upstream error with a
	// get_channel_failed error, so stop here, as shouldRetry does.
	if _, ok := common.GetContextKey(c, constant.ContextKeyForcedChannelID); ok {
		return false
	}

	code := taskErr.StatusCode
	switch {
	case code == http.StatusTooManyRequests, code == http.StatusTemporaryRedirect:
		return true
	case code/100 == 5:
		// Timeout-like codes configured as "always skip retry" are not retried.
		return !operation_setting.IsAlwaysSkipRetryStatusCode(code)
	case code == http.StatusBadRequest:
		return false
	case code == http.StatusRequestTimeout:
		// Upstream (e.g. Azure) processing timeout: do not retry.
		return false
	case taskErr.LocalError:
		return false
	case code/100 == 2:
		return false
	default:
		return true
	}
}