tokenFactory/controller/channel-test.go

1609 lines
55 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay"
taskalivideo "github.com/QuantumNous/new-api/relay/channel/task/alivideo"
taskopenaivideo "github.com/QuantumNous/new-api/relay/channel/task/openaivideo"
tasktencentvod "github.com/QuantumNous/new-api/relay/channel/task/tencentvod"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/samber/lo"
"github.com/tidwall/gjson"
"github.com/gin-gonic/gin"
)
// testResult 渠道测试一次调用的结果recordedModelName 为与模型元数据/操练场 model_name 对齐的用户侧名称(非上游 UpstreamModelName仅成功路径会填充。
type testResult struct {
context *gin.Context
localErr error
tokenFactoryError *types.TokenFactoryError
recordedModelName string
}
// tokenFactoryOpenVideoTestHeuristic 在「端点类型留空」且模型元数据未指明端点时,
// 判断 TokenFactoryOpen(60) 是否应按视频任务入口测试。
// 与真实客户端一致:视频走 POST /v1/video/generations而非 chat 或外部 Hidream 的 /v1/videos/generations。
//
// 注意:豆包 LLMdoubao-seed-*与豆包视频doubao-seedance-*)均含 "doubao"
// 仅 seedance 等视频族关键字应命中;勿用裸 "doubao" 匹配,否则建站渠道会把对话模型误测为视频。
func tokenFactoryOpenVideoTestHeuristic(modelName string) bool {
s := strings.ToLower(strings.TrimSpace(modelName))
if s == "" {
return false
}
if strings.Contains(s, "seedance") || strings.Contains(s, "sora") ||
strings.Contains(s, "kling") || strings.Contains(s, "wan") ||
strings.Contains(s, "vidu") || strings.Contains(s, "veo") ||
strings.Contains(s, "jimeng") ||
strings.Contains(s, "hailuo") || strings.Contains(s, "minimax") ||
strings.Contains(s, "text2video") || strings.Contains(s, "image2video") {
return true
}
if strings.HasPrefix(strings.TrimSpace(modelName), "Video-") {
return true
}
return false
}
// tokenFactoryOpenTestEndpointFromMeta 若模型元数据/能力表仅声明视频端点,则返回 tokenfactory-video否则留空走 chat 探测。
func tokenFactoryOpenTestEndpointFromMeta(modelName string) string {
eps := model.GetModelSupportEndpointTypes(modelName)
if len(eps) == 0 {
return ""
}
hasNonVideo := false
for _, et := range eps {
switch et {
case constant.EndpointTypeTokenFactoryVideo,
constant.EndpointTypeOpenAIVideo,
constant.EndpointTypeOpenAIVideoGW,
constant.EndpointTypeVideoGenerator,
constant.EndpointTypeTencentCloudVODVideo:
default:
hasNonVideo = true
}
}
if !hasNonVideo {
return string(constant.EndpointTypeTokenFactoryVideo)
}
return ""
}
// tfOpenUpstreamModelForChannelTest 对齐 relay/helper.ModelMappedHelper 中 TokenFactoryOpen 路由改写:
// 将发往上游的模型名设为 {model}/{route_slug} 或旧版 alias/model/channelNo。
func tfOpenUpstreamModelForChannelTest(c *gin.Context, originModel string, upstreamAfterMapping string) string {
tfRoute := strings.TrimSpace(common.GetContextKeyString(c, constant.ContextKeyTFOpenUpstreamChannelRoute))
if tfRoute == "" {
return upstreamAfterMapping
}
modelForUpstream := strings.TrimSpace(originModel)
if strings.HasSuffix(modelForUpstream, ratio_setting.CompactModelSuffix) {
modelForUpstream = strings.TrimSuffix(modelForUpstream, ratio_setting.CompactModelSuffix)
}
if strings.HasPrefix(tfRoute, "legacy|") {
parts := strings.SplitN(tfRoute, "|", 3)
if len(parts) == 3 {
alias := parts[1]
channelNo := parts[2]
if alias != "" && channelNo != "" {
return alias + "/" + modelForUpstream + "/" + channelNo
}
}
return upstreamAfterMapping
}
if slug := strings.TrimSpace(tfRoute); slug != "" {
return modelForUpstream + "/" + slug
}
return upstreamAfterMapping
}
func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string {
normalized := strings.TrimSpace(endpointType)
if normalized != "" {
return normalized
}
if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) {
return string(constant.EndpointTypeOpenAIResponseCompact)
}
if channel != nil && channel.Type == constant.ChannelTypeCodex {
return string(constant.EndpointTypeOpenAIResponse)
}
if channel != nil && channel.Type == constant.ChannelTypeOpenAIVideo {
return string(constant.EndpointTypeOpenAIVideoGW)
}
if channel != nil && channel.Type == constant.ChannelTypeTokenFactoryOpen {
if ep := tokenFactoryOpenTestEndpointFromMeta(modelName); ep != "" {
return ep
}
if tokenFactoryOpenVideoTestHeuristic(modelName) {
return string(constant.EndpointTypeTokenFactoryVideo)
}
}
if channel != nil && channel.Type == constant.ChannelTypeVideoGenerator {
return string(constant.EndpointTypeVideoGenerator)
}
if channel != nil && channel.Type == constant.ChannelTypeTencentCloudVideo {
return string(constant.EndpointTypeTencentCloudVODVideo)
}
if channel != nil && channel.Type == constant.ChannelTypeTencentCloudImage {
return string(constant.EndpointTypeTencentCloudVODImage)
}
if channel != nil && channel.Type == constant.ChannelTypeAliVideo {
return string(constant.EndpointTypeAliVideo)
}
return normalized
}
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
tik := time.Now()
var unsupportedTestChannelTypes = []int{
constant.ChannelTypeMidjourney,
constant.ChannelTypeMidjourneyPlus,
constant.ChannelTypeSunoAPI,
constant.ChannelTypeKling,
constant.ChannelTypeJimeng,
constant.ChannelTypeDoubaoVideo,
constant.ChannelTypeVidu,
}
if lo.Contains(unsupportedTestChannelTypes, channel.Type) {
channelTypeName := constant.GetChannelTypeName(channel.Type)
return testResult{
localErr: fmt.Errorf("%s channel test is not supported", channelTypeName),
}
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
testModel = strings.TrimSpace(testModel)
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = strings.TrimSpace(*channel.TestModel)
} else {
models := channel.GetModels()
if len(models) > 0 {
testModel = strings.TrimSpace(models[0])
}
if testModel == "" {
testModel = "gpt-4o-mini"
}
}
}
endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType)
requestPath := "/v1/chat/completions"
// 如果指定了端点类型,使用指定的端点类型
if endpointType != "" {
if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
requestPath = endpointInfo.Path
}
} else {
// 如果没有指定端点类型,使用原有的自动检测逻辑
if strings.Contains(strings.ToLower(testModel), "rerank") {
requestPath = "/v1/rerank"
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
strings.Contains(testModel, "embed") ||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
// VolcEngine 图像生成模型
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
requestPath = "/v1/images/generations"
}
// responses-only models
if strings.Contains(strings.ToLower(testModel), "codex") {
requestPath = "/v1/responses"
}
// responses compaction models (must use /v1/responses/compact)
if strings.HasSuffix(testModel, ratio_setting.CompactModelSuffix) {
requestPath = "/v1/responses/compact"
}
}
if strings.HasPrefix(requestPath, "/v1/responses/compact") {
testModel = ratio_setting.WithCompactModelSuffix(testModel)
}
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil,
Header: make(http.Header),
}
cache, err := model.GetUserCache(1)
if err != nil {
return testResult{
localErr: err,
tokenFactoryError: nil,
}
}
cache.WriteContext(c)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(1, false)
c.Set("group", group)
tokenFactoryError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
if tokenFactoryError != nil {
return testResult{
context: c,
localErr: tokenFactoryError,
tokenFactoryError: tokenFactoryError,
}
}
// 视频生成端点走任务式异步上游协议Sora /v1/videos、OpenAI 视频网关 /v1/videos/generations 等),
// 与同步 chat/embeddings/image 走的 relay.GetAdaptor 流程不兼容,因此在这里直接旁路:
// 仅校验上游能正确接收任务创建请求并返回 task_id不做轮询。
if endpointType == string(constant.EndpointTypeOpenAIVideo) ||
endpointType == string(constant.EndpointTypeOpenAIVideoGW) ||
endpointType == string(constant.EndpointTypeTokenFactoryVideo) ||
endpointType == string(constant.EndpointTypeVideoGenerator) ||
endpointType == string(constant.EndpointTypeTencentCloudVODVideo) ||
endpointType == string(constant.EndpointTypeAliVideo) {
return testChannelVideo(c, channel, testModel, endpointType, tik)
}
// Determine relay format based on endpoint type or request path
var relayFormat types.RelayFormat
if endpointType != "" {
// 根据指定的端点类型设置 relayFormat
switch constant.EndpointType(endpointType) {
case constant.EndpointTypeOpenAI:
relayFormat = types.RelayFormatOpenAI
case constant.EndpointTypeOpenAIResponse:
relayFormat = types.RelayFormatOpenAIResponses
case constant.EndpointTypeOpenAIResponseCompact:
relayFormat = types.RelayFormatOpenAIResponsesCompaction
case constant.EndpointTypeAnthropic:
relayFormat = types.RelayFormatClaude
case constant.EndpointTypeGemini:
relayFormat = types.RelayFormatGemini
case constant.EndpointTypeJinaRerank:
relayFormat = types.RelayFormatRerank
case constant.EndpointTypeImageGeneration:
relayFormat = types.RelayFormatOpenAIImage
case constant.EndpointTypeTencentCloudVODImage:
relayFormat = types.RelayFormatOpenAIImage
case constant.EndpointTypeEmbeddings:
relayFormat = types.RelayFormatEmbedding
default:
relayFormat = types.RelayFormatOpenAI
}
} else {
// 根据请求路径自动检测
relayFormat = types.RelayFormatOpenAI
if c.Request.URL.Path == "/v1/embeddings" {
relayFormat = types.RelayFormatEmbedding
}
if c.Request.URL.Path == "/v1/images/generations" {
relayFormat = types.RelayFormatOpenAIImage
}
if c.Request.URL.Path == "/v1/messages" {
relayFormat = types.RelayFormatClaude
}
if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
relayFormat = types.RelayFormatGemini
}
if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
relayFormat = types.RelayFormatRerank
}
if c.Request.URL.Path == "/v1/responses" {
relayFormat = types.RelayFormatOpenAIResponses
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/responses/compact") {
relayFormat = types.RelayFormatOpenAIResponsesCompaction
}
}
request := buildTestRequest(testModel, endpointType, channel, isStream)
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
}
}
info.IsChannelTest = true
info.InitChannelMeta(c)
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
}
}
testModel = info.UpstreamModelName
// 更新请求中的模型名称
request.SetModelName(testModel)
apiType, _ := common.ChannelType2APIType(channel.Type)
if info.RelayMode == relayconstant.RelayModeResponsesCompact &&
apiType != constant.APITypeOpenAI &&
apiType != constant.APITypeCodex {
return testResult{
context: c,
localErr: fmt.Errorf("responses compaction test only supports openai/codex channels, got api type %d", apiType),
tokenFactoryError: types.NewError(fmt.Errorf("unsupported api type: %d", apiType), types.ErrorCodeInvalidApiType),
}
}
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return testResult{
context: c,
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
tokenFactoryError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
}
}
//// 创建一个用于日志的 info 副本,移除 ApiKey
//logInfo := info
//logInfo.ApiKey = ""
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeModelPriceError),
}
}
adaptor.Init(info)
var convertedRequest any
// 根据 RelayMode 选择正确的转换函数
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
// Embedding 请求 - request 已经是正确的类型
if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
} else {
return testResult{
context: c,
localErr: errors.New("invalid embedding request type"),
tokenFactoryError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
}
}
case relayconstant.RelayModeImagesGenerations:
// 图像生成请求 - request 已经是正确的类型
if imageReq, ok := request.(*dto.ImageRequest); ok {
convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
} else {
return testResult{
context: c,
localErr: errors.New("invalid image request type"),
tokenFactoryError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
}
}
case relayconstant.RelayModeRerank:
// Rerank 请求 - request 已经是正确的类型
if rerankReq, ok := request.(*dto.RerankRequest); ok {
convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
} else {
return testResult{
context: c,
localErr: errors.New("invalid rerank request type"),
tokenFactoryError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
}
}
case relayconstant.RelayModeResponses:
// Response 请求 - request 已经是正确的类型
if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
} else {
return testResult{
context: c,
localErr: errors.New("invalid response request type"),
tokenFactoryError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
}
}
case relayconstant.RelayModeResponsesCompact:
// Response compaction request - convert to OpenAIResponsesRequest before adapting
switch req := request.(type) {
case *dto.OpenAIResponsesCompactionRequest:
convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, dto.OpenAIResponsesRequest{
Model: req.Model,
Input: req.Input,
Instructions: req.Instructions,
PreviousResponseID: req.PreviousResponseID,
})
case *dto.OpenAIResponsesRequest:
convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *req)
default:
return testResult{
context: c,
localErr: errors.New("invalid response compaction request type"),
tokenFactoryError: types.NewError(errors.New("invalid response compaction request type"), types.ErrorCodeConvertRequestFailed),
}
}
default:
// Chat/Completion 等其他请求类型
if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
} else {
return testResult{
context: c,
localErr: errors.New("invalid general request type"),
tokenFactoryError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
}
}
}
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
}
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
}
}
//jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
//if err != nil {
// return testResult{
// context: c,
// localErr: err,
// tokenFactoryError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
// }
//}
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
if err != nil {
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
return testResult{
context: c,
localErr: fixedErr,
tokenFactoryError: relaycommon.TokenFactoryErrorFromParamOverride(fixedErr),
}
}
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid),
}
}
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
}
}
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
common.SysError(fmt.Sprintf(
"channel test bad response: channel_id=%d name=%s type=%d model=%s endpoint_type=%s status=%d err=%v",
channel.Id,
channel.Name,
channel.Type,
testModel,
endpointType,
httpResp.StatusCode,
err,
))
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
}
}
}
// 腾讯云图片模型测试:只校验是否成功提交任务(返回 TaskId不等待任务完成与 URL 回填。
// 这样可避免 DescribeTaskDetail/DescribeMediaInfos 带来的 30~40 秒测试时延。
if endpointType == string(constant.EndpointTypeTencentCloudVODImage) {
if httpResp == nil || httpResp.Body == nil {
err := errors.New("empty upstream response")
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
raw, readErr := io.ReadAll(httpResp.Body)
_ = httpResp.Body.Close()
if readErr != nil {
return testResult{
context: c,
localErr: readErr,
tokenFactoryError: types.NewOpenAIError(readErr, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
taskID := strings.TrimSpace(gjson.GetBytes(raw, "Response.TaskId").String())
if taskID == "" {
errMsg := strings.TrimSpace(gjson.GetBytes(raw, "Response.Error.Message").String())
if errMsg == "" {
errMsg = strings.TrimSpace(gjson.GetBytes(raw, "Response.Error.Code").String())
}
if errMsg == "" {
errMsg = fmt.Sprintf("submit succeeded but missing TaskId, body=%s", truncateForError(string(raw)))
}
err := errors.New(errMsg)
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
common.SysLog(fmt.Sprintf("tencent image test channel #%d accepted, task_id=%s", channel.Id, taskID))
recordedName := strings.TrimSpace(info.OriginModelName)
if recordedName == "" {
recordedName = strings.TrimSpace(common.GetContextKeyString(c, constant.ContextKeyOriginalModel))
}
return testResult{
context: c,
localErr: nil,
tokenFactoryError: nil,
recordedModelName: recordedName,
}
}
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
if respErr != nil {
return testResult{
context: c,
localErr: respErr,
tokenFactoryError: respErr,
}
}
usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens())
if usageErr != nil {
return testResult{
context: c,
localErr: usageErr,
tokenFactoryError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
result := w.Result()
respBody, err := readTestResponseBody(result.Body, isStream)
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
info.SetEstimatePromptTokens(usage.PromptTokens)
quota := 0
if !priceData.UsePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
if priceData.ModelRatio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
}
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
ModelName: info.OriginModelName,
TokenName: "模型测试",
Quota: quota,
Content: "模型测试",
UseTimeSeconds: int(consumedTime),
IsStream: info.IsStream,
Group: info.UsingGroup,
Other: other,
})
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
// 与 model_test_results、操练场 GetUserModels 的 model_name 一致,供 Upsert 使用
recordedName := strings.TrimSpace(info.OriginModelName)
if recordedName == "" {
recordedName = strings.TrimSpace(common.GetContextKeyString(c, constant.ContextKeyOriginalModel))
}
return testResult{
context: c,
localErr: nil,
tokenFactoryError: nil,
recordedModelName: recordedName,
}
}
// truncateForError 把请求/响应内容截短到 800 字符以内,避免错误消息过长。
func truncateForError(s string) string {
const maxLen = 800
s = strings.TrimSpace(s)
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...(truncated)"
}
// testChannelVideo 处理视频生成类端点OpenAI Sora /v1/videos、外部视频网关 /v1/videos/generations、
// TokenFactory 统一入口 /v1/video/generations 等)的渠道测试。
// 视频生成是任务式异步接口,这里只验证上游能否正确创建任务(返回 task_id不做轮询
// 避免长时间阻塞和测试期间产生真实视频生成费用。
func testChannelVideo(c *gin.Context, channel *model.Channel, testModel string, endpointType string, tik time.Time) testResult {
// TokenFactoryOpen(60) 指向上游 TokenFactory 时,真实路由是 /v1/video/generations不是外部 Hidream 的 /v1/videos/generations。
if channel != nil && channel.Type == constant.ChannelTypeTokenFactoryOpen &&
endpointType == string(constant.EndpointTypeOpenAIVideoGW) {
endpointType = string(constant.EndpointTypeTokenFactoryVideo)
}
endpoint, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType))
if !ok {
err := fmt.Errorf("unsupported video endpoint type: %s", endpointType)
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeInvalidApiType),
}
}
// 模型映射:与正式 relay 对齐TokenFactoryOpen(60) 指向上游 TF 时跳过 model_mapping见 relay/helper/model_mapped.go
originModel := strings.TrimSpace(testModel)
upstreamModel := originModel
if channel == nil || channel.Type != constant.ChannelTypeTokenFactoryOpen {
if mapping := strings.TrimSpace(c.GetString("model_mapping")); mapping != "" && mapping != "{}" {
modelMap := make(map[string]string)
if err := common.UnmarshalJsonStr(mapping, &modelMap); err == nil {
current := upstreamModel
visited := map[string]bool{current: true}
for {
next, exists := modelMap[current]
if !exists || next == "" || next == current {
break
}
if visited[next] {
break
}
visited[next] = true
current = next
}
upstreamModel = current
}
}
}
upstreamModel = tfOpenUpstreamModelForChannelTest(c, originModel, upstreamModel)
apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
if apiKey == "" {
apiKey = channel.Key
}
baseURL := strings.TrimRight(channel.GetBaseURL(), "/")
if baseURL == "" {
err := fmt.Errorf("channel base_url is empty")
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeChannelBaseUrlEmpty),
}
}
fullURL := baseURL + endpoint.Path
var bodyMap map[string]any
switch constant.EndpointType(endpointType) {
case constant.EndpointTypeOpenAIVideo:
// OpenAI Sora 风格POST /v1/videosbody 字段参考 https://platform.openai.com/docs/api-reference/videos/create
bodyMap = map[string]any{
"model": upstreamModel,
"prompt": "a cute cat dancing in a sunny garden",
"size": "720x1280",
"seconds": "4",
}
case constant.EndpointTypeTencentCloudVODVideo:
// 腾讯云官方 VOD 视频接口必须使用 TC3 签名和 X-TC-* 公共头,不能直接 Bearer 调上游。
cred, credErr := tasktencentvod.ParseCredentials(apiKey)
if credErr != nil {
return testResult{
context: c,
localErr: credErr,
tokenFactoryError: types.NewError(credErr, types.ErrorCodeChannelInvalidKey),
}
}
modelName, modelVersion := tasktencentvod.SplitCombinedModel(upstreamModel)
if strings.TrimSpace(modelName) == "" || strings.TrimSpace(modelVersion) == "" {
invalidModelErr := fmt.Errorf("invalid tencent vod model %q, expected ModelName-ModelVersion", upstreamModel)
return testResult{
context: c,
localErr: invalidModelErr,
tokenFactoryError: types.NewError(invalidModelErr, types.ErrorCodeBadRequestBody),
}
}
signedBody := map[string]any{
"SubAppId": cred.SubAppID,
"ModelName": modelName,
"ModelVersion": modelVersion,
"Prompt": "a cute cat dancing in a sunny garden",
}
signedPayload, marshalErr := common.Marshal(signedBody)
if marshalErr != nil {
return testResult{
context: c,
localErr: marshalErr,
tokenFactoryError: types.NewError(marshalErr, types.ErrorCodeJsonMarshalFailed),
}
}
signedResp, reqErr := tasktencentvod.SignedPOSTJSON(strings.TrimSpace(channel.GetSetting().Proxy), baseURL, cred.Region, cred, "CreateAigcVideoTask", signedPayload)
if reqErr != nil {
return testResult{
context: c,
localErr: reqErr,
tokenFactoryError: types.NewOpenAIError(reqErr, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
}
}
defer func() { _ = signedResp.Body.Close() }()
respBody, readErr := io.ReadAll(signedResp.Body)
if readErr != nil {
return testResult{
context: c,
localErr: readErr,
tokenFactoryError: types.NewOpenAIError(readErr, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
common.SysLog(fmt.Sprintf("video test channel #%d response: status=%d, body=%s", channel.Id, signedResp.StatusCode, string(respBody)))
if signedResp.StatusCode != http.StatusOK {
msg := detectErrorMessageFromJSONBytes(respBody)
if msg == "" {
msg = strings.TrimSpace(string(respBody))
}
if msg == "" {
msg = fmt.Sprintf("upstream returned status %d", signedResp.StatusCode)
}
bodyErr := fmt.Errorf("status=%d, body=%s", signedResp.StatusCode, msg)
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponse, http.StatusInternalServerError),
}
}
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "Response.TaskId").String())
if taskID == "" {
bodyErr := fmt.Errorf("upstream did not return task_id, body: %s", string(respBody))
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
common.SysLog(fmt.Sprintf("video test channel #%d ok, task_id=%s", channel.Id, taskID))
return testResult{
context: c,
localErr: nil,
tokenFactoryError: nil,
recordedModelName: originModel,
}
case constant.EndpointTypeOpenAIVideoGW:
// OpenAI 视频网关:根据 base URL 自动选 MaaSHidream 官方)或 ARKByteDance 兼容代理)。
// 提交路径统一是 /v1/videos/generations是否带 /api/maas/gw 前缀由用户在 base URL 中决定。
// 两种协议的 body 结构其实一致content 数组仅模型字段名不同MaaS 用 model_idARK 用 model。
// 数字人等少数 MaaS 平铺字段模型image+sound_file需要在自定义参数里手动调整 body
// 测试入口只发主流 Seedance/Doubao 系列能通过校验的最小集合。
modelKey := "model"
if taskopenaivideo.DetectProtocol(baseURL) == taskopenaivideo.ProtocolMaaS {
modelKey = "model_id"
}
bodyMap = map[string]any{
modelKey: upstreamModel,
"content": []map[string]any{
{"type": "text", "text": "a cute cat dancing in a sunny garden --duration 5"},
},
}
case constant.EndpointTypeTokenFactoryVideo:
// 与 RelayTask 解析一致TaskSubmitReq JSONrelay/common/relay_utils.go ValidateBasicTaskRequest
// 勿附带 n/fps/motion 等可选字段ValidateBasicTaskRequest 会把它们写入 metadata
// openaivideo adaptor 会将 metadata 逐项并入上游 bodyHidream/MaaS 可能因未知顶层字段返回 Invalid input params。
// prompt 内附带 --duration 与 OpenAI 视频网关测试一致,兼容部分上游对文生参数的校验。
bodyMap = map[string]any{
"model": upstreamModel,
"prompt": "a cute cat dancing in a sunny garden --duration 5",
"duration": 5,
"size": "960x540",
}
case constant.EndpointTypeVideoGenerator:
bodyMap = map[string]any{
"model": upstreamModel,
"content": []map[string]any{
{"type": "text", "text": "a cute cat dancing in a sunny garden"},
},
"parameters": map[string]any{
"duration": 5,
"resolution": "720P",
"ratio": "16:9",
"watermark": false,
},
}
case constant.EndpointTypeAliVideo:
fullURL = taskalivideo.SubmitURL(baseURL)
bodyMap = map[string]any{
"model": upstreamModel,
"input": map[string]any{
"prompt": "a cute cat dancing in a sunny garden",
},
"parameters": map[string]any{
"resolution": "720P",
"ratio": "16:9",
"duration": 5,
},
}
default:
err := fmt.Errorf("unsupported video endpoint type: %s", endpointType)
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeInvalidApiType),
}
}
bodyBytes, err := common.Marshal(bodyMap)
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
}
}
method := endpoint.Method
if method == "" {
method = http.MethodPost
}
httpReq, err := http.NewRequestWithContext(c.Request.Context(), method, fullURL, bytes.NewReader(bodyBytes))
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
}
}
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
httpReq.Header.Set("Content-Type", "application/json")
if constant.EndpointType(endpointType) == constant.EndpointTypeAliVideo {
httpReq.Header.Set("X-DashScope-Async", "enable")
}
common.SysLog(fmt.Sprintf(
"video test channel #%d (%s) endpoint=%s url=%s model=%s -> upstream=%s, request body: %s",
channel.Id, channel.Name, endpointType, fullURL, originModel, upstreamModel, string(bodyBytes),
))
client := service.GetHttpClient()
resp, err := client.Do(httpReq)
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
}
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return testResult{
context: c,
localErr: err,
tokenFactoryError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
// 无论成功失败都打印完整响应,便于排查上游字段对不上的问题。
common.SysLog(fmt.Sprintf(
"video test channel #%d response: status=%d, body=%s",
channel.Id, resp.StatusCode, string(respBody),
))
if resp.StatusCode != http.StatusOK {
msg := detectErrorMessageFromJSONBytes(respBody)
if msg == "" {
msg = strings.TrimSpace(string(respBody))
}
if msg == "" {
msg = fmt.Sprintf("upstream returned status %d", resp.StatusCode)
}
bodyErr := fmt.Errorf("status=%d, body=%s", resp.StatusCode, msg)
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponse, http.StatusInternalServerError),
}
}
var taskID string
switch constant.EndpointType(endpointType) {
case constant.EndpointTypeOpenAIVideo:
// OpenAI Sora 风格:顶层 id新接口或 task_id旧接口兼容
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "task_id").String())
}
if errMsg := strings.TrimSpace(gjson.GetBytes(respBody, "error.message").String()); errMsg != "" {
bodyErr := fmt.Errorf("upstream error: %s, body: %s", errMsg, truncateForError(string(respBody)))
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
case constant.EndpointTypeOpenAIVideoGW, constant.EndpointTypeTokenFactoryVideo:
// OpenAI 视频网关:两种响应格式根据顶层字段自动判断。
// MaaS{"code":0,"message":"","result":{"task_id":"..."}}
// 失败时 code != 0错误消息在 message/messasge 里。
// ARK {"id":"cgt-...","model":"...","status":"queued",...}
// 失败时 {"error":{"code":"...","message":"...","type":"..."}}
if errMsg := strings.TrimSpace(gjson.GetBytes(respBody, "error.message").String()); errMsg != "" {
// 把完整响应附在错误里返回给前端,方便用户直接看到上游对哪些字段不满。
bodyErr := fmt.Errorf(
"upstream error: %s | request: %s | response: %s",
errMsg,
truncateForError(string(bodyBytes)),
truncateForError(string(respBody)),
)
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
// MaaS顶层 code 字段存在且 != 0 表示失败。注意 code=0 是合法的成功值,
// 所以要先判断字段是否存在,再判断是否非零。
if codeRes := gjson.GetBytes(respBody, "code"); codeRes.Exists() && codeRes.Int() != 0 {
errMsg := strings.TrimSpace(gjson.GetBytes(respBody, "message").String())
if errMsg == "" {
errMsg = strings.TrimSpace(gjson.GetBytes(respBody, "messasge").String())
}
if errMsg == "" {
errMsg = fmt.Sprintf("upstream returned code=%d", codeRes.Int())
}
bodyErr := fmt.Errorf(
"upstream error: %s | request: %s | response: %s",
errMsg,
truncateForError(string(bodyBytes)),
truncateForError(string(respBody)),
)
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
// 优先取 ARK 风格的顶层 id再取 MaaS 风格的 result.task_id最后兜底顶层 task_id。
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "result.task_id").String())
}
if taskID == "" {
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "task_id").String())
}
if taskID == "" {
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "data.id").String())
}
case constant.EndpointTypeAliVideo:
if codeStr := strings.TrimSpace(gjson.GetBytes(respBody, "code").String()); codeStr != "" && codeStr != "0" {
errMsg := strings.TrimSpace(gjson.GetBytes(respBody, "message").String())
if errMsg == "" {
errMsg = fmt.Sprintf("upstream returned code=%s", codeStr)
}
bodyErr := fmt.Errorf("upstream error: %s, body: %s", errMsg, truncateForError(string(respBody)))
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "output.task_id").String())
case constant.EndpointTypeVideoGenerator:
if codeRes := gjson.GetBytes(respBody, "status"); codeRes.Exists() && codeRes.Int() != 0 {
errMsg := strings.TrimSpace(gjson.GetBytes(respBody, "message").String())
if errMsg == "" {
errMsg = fmt.Sprintf("upstream returned status=%d", codeRes.Int())
}
bodyErr := fmt.Errorf(
"upstream error: %s | request: %s | response: %s",
errMsg,
truncateForError(string(bodyBytes)),
truncateForError(string(respBody)),
)
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "result.task_id").String())
if taskID == "" {
taskID = strings.TrimSpace(gjson.GetBytes(respBody, "task_id").String())
}
}
if taskID == "" {
bodyErr := fmt.Errorf("upstream did not return task_id, body: %s", string(respBody))
return testResult{
context: c,
localErr: bodyErr,
tokenFactoryError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
milliseconds := time.Since(tik).Milliseconds()
common.SysLog(fmt.Sprintf("video test channel #%d ok, task_id=%s, took %dms, body: %s",
channel.Id, taskID, milliseconds, string(respBody)))
group, _ := model.GetUserGroup(1, false)
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
ChannelId: channel.Id,
ModelName: originModel,
TokenName: "模型测试",
Quota: 0,
Content: fmt.Sprintf("模型测试-视频生成(task_id=%s)", taskID),
UseTimeSeconds: int(milliseconds / 1000),
IsStream: false,
Group: group,
})
return testResult{
context: c,
localErr: nil,
tokenFactoryError: nil,
recordedModelName: originModel,
}
}
func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
switch u := usageAny.(type) {
case *dto.Usage:
return u, nil
case dto.Usage:
return &u, nil
case nil:
if !isStream {
return nil, errors.New("usage is nil")
}
usage := &dto.Usage{
PromptTokens: estimatePromptTokens,
}
usage.TotalTokens = usage.PromptTokens
return usage, nil
default:
if !isStream {
return nil, fmt.Errorf("invalid usage type: %T", usageAny)
}
usage := &dto.Usage{
PromptTokens: estimatePromptTokens,
}
usage.TotalTokens = usage.PromptTokens
return usage, nil
}
}
func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) {
defer func() { _ = body.Close() }()
const maxStreamLogBytes = 8 << 10
if isStream {
return io.ReadAll(io.LimitReader(body, maxStreamLogBytes))
}
return io.ReadAll(body)
}
func detectErrorFromTestResponseBody(respBody []byte) error {
b := bytes.TrimSpace(respBody)
if len(b) == 0 {
return nil
}
if message := detectErrorMessageFromJSONBytes(b); message != "" {
return fmt.Errorf("upstream error: %s", message)
}
for _, line := range bytes.Split(b, []byte{'\n'}) {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
if message := detectErrorMessageFromJSONBytes(payload); message != "" {
return fmt.Errorf("upstream error: %s", message)
}
}
return nil
}
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
if len(jsonBytes) == 0 {
return ""
}
if jsonBytes[0] != '{' && jsonBytes[0] != '[' {
return ""
}
errVal := gjson.GetBytes(jsonBytes, "error")
if !errVal.Exists() || errVal.Type == gjson.Null {
return ""
}
message := gjson.GetBytes(jsonBytes, "error.message").String()
if message == "" {
message = gjson.GetBytes(jsonBytes, "error.error.message").String()
}
if message == "" && errVal.Type == gjson.String {
message = errVal.String()
}
if message == "" {
message = errVal.Raw
}
message = strings.TrimSpace(message)
if message == "" {
return "upstream returned error payload"
}
return message
}
func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request {
testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`)
// 根据端点类型构建不同的测试请求
if endpointType != "" {
switch constant.EndpointType(endpointType) {
case constant.EndpointTypeEmbeddings:
// 返回 EmbeddingRequest
return &dto.EmbeddingRequest{
Model: model,
Input: []any{"hello world"},
}
case constant.EndpointTypeImageGeneration, constant.EndpointTypeTencentCloudVODImage:
// 返回 ImageRequest
return &dto.ImageRequest{
Model: model,
Prompt: "a cute cat",
N: lo.ToPtr(uint(1)),
Size: "1024x1024",
}
case constant.EndpointTypeJinaRerank:
// 返回 RerankRequest
return &dto.RerankRequest{
Model: model,
Query: "What is Deep Learning?",
Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
TopN: lo.ToPtr(2),
}
case constant.EndpointTypeOpenAIResponse:
// 返回 OpenAIResponsesRequest
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Stream: lo.ToPtr(isStream),
}
case constant.EndpointTypeOpenAIResponseCompact:
// 返回 OpenAIResponsesCompactionRequest
return &dto.OpenAIResponsesCompactionRequest{
Model: model,
Input: testResponsesInput,
}
case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
// 返回 GeneralOpenAIRequest
maxTokens := uint(16)
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
maxTokens = 3000
}
req := &dto.GeneralOpenAIRequest{
Model: model,
Stream: lo.ToPtr(isStream),
Messages: []dto.Message{
{
Role: "user",
Content: "hi",
},
},
MaxTokens: lo.ToPtr(maxTokens),
}
if isStream {
req.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
return req
}
}
// 自动检测逻辑(保持原有行为)
if strings.Contains(strings.ToLower(model), "rerank") {
return &dto.RerankRequest{
Model: model,
Query: "What is Deep Learning?",
Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
TopN: lo.ToPtr(2),
}
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") ||
strings.Contains(model, "bge-") {
// 返回 EmbeddingRequest
return &dto.EmbeddingRequest{
Model: model,
Input: []any{"hello world"},
}
}
// Responses compaction models (must use /v1/responses/compact)
if strings.HasSuffix(model, ratio_setting.CompactModelSuffix) {
return &dto.OpenAIResponsesCompactionRequest{
Model: model,
Input: testResponsesInput,
}
}
// Responses-only models (e.g. codex series)
if strings.Contains(strings.ToLower(model), "codex") {
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Stream: lo.ToPtr(isStream),
}
}
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
testRequest := &dto.GeneralOpenAIRequest{
Model: model,
Stream: lo.ToPtr(isStream),
Messages: []dto.Message{
{
Role: "user",
Content: "hi",
},
},
}
if isStream {
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = lo.ToPtr(uint(16))
} else if strings.Contains(model, "thinking") {
if !strings.Contains(model, "claude") {
testRequest.MaxTokens = lo.ToPtr(uint(50))
}
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = lo.ToPtr(uint(3000))
} else {
testRequest.MaxTokens = lo.ToPtr(uint(16))
}
return testRequest
}
func TestChannel(c *gin.Context) {
channelId, err := strconv.Atoi(c.Param("id"))
if err != nil {
common.ApiError(c, err)
return
}
channel, err := model.CacheGetChannel(channelId)
if err != nil {
channel, err = model.GetChannelById(channelId, true)
if err != nil {
common.ApiError(c, err)
return
}
}
// 供应商仅允许测试自己归属的渠道。
if c.GetInt("role") < common.RoleAdminUser && channel.OwnerUserID != c.GetInt("id") {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "无权测试其他供应商渠道",
})
return
}
//defer func() {
// if channel.ChannelInfo.IsMultiKey {
// go func() { _ = channel.SaveChannelInfo() }()
// }
//}()
testModel := c.Query("model")
endpointType := c.Query("endpoint_type")
isStream, _ := strconv.ParseBool(c.Query("stream"))
tik := time.Now()
result := testChannel(channel, testModel, endpointType, isStream)
milliseconds := time.Since(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
modelForRecord := modelNameForChannelTestRecord(channel, testModel, result)
if result.localErr != nil {
go channel.UpdateTestResult(false, milliseconds, result.localErr.Error(), modelForRecord)
go func() {
_ = model.UpsertModelTestResult(channel.Id, modelForRecord, false, milliseconds, result.localErr.Error())
}()
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": result.localErr.Error(),
"time": consumedTime,
})
return
}
if result.tokenFactoryError != nil {
go channel.UpdateTestResult(false, milliseconds, result.tokenFactoryError.Error(), modelForRecord)
go func() {
_ = model.UpsertModelTestResult(channel.Id, modelForRecord, false, milliseconds, result.tokenFactoryError.Error())
}()
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": result.tokenFactoryError.Error(),
"time": consumedTime,
})
return
}
go channel.UpdateTestResult(true, milliseconds, "", modelForRecord)
go func() { _ = model.UpsertModelTestResult(channel.Id, modelForRecord, true, milliseconds, "") }()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
})
}
// collectModelsForScheduledChannelTest 返回定时全量测试要对渠道串行探测的模型名列表(与非空 models 的批量上架测试一致);无任何绑定时返回单元素空串以走 testChannel 内置默认模型。
func collectModelsForScheduledChannelTest(channel *model.Channel) []string {
raw := channel.GetModels()
out := make([]string, 0, len(raw))
for _, m := range raw {
m = strings.TrimSpace(m)
if m != "" {
out = append(out, m)
}
}
if len(out) == 0 && channel.TestModel != nil {
tm := strings.TrimSpace(*channel.TestModel)
if tm != "" {
out = append(out, tm)
}
}
if len(out) == 0 {
return []string{""}
}
return out
}
// modelNameForChannelTestRecord 与 TestChannel HTTP 接口写入逻辑一致:优先 relay 解析名,其次本次请求模型名,再兜底渠道配置。
func modelNameForChannelTestRecord(channel *model.Channel, requestedModel string, result testResult) string {
modelForRecord := strings.TrimSpace(result.recordedModelName)
if modelForRecord != "" {
return modelForRecord
}
modelForRecord = strings.TrimSpace(requestedModel)
if modelForRecord != "" {
return modelForRecord
}
if channel.TestModel != nil && strings.TrimSpace(*channel.TestModel) != "" {
return strings.TrimSpace(*channel.TestModel)
}
models := channel.GetModels()
if len(models) > 0 {
return strings.TrimSpace(models[0])
}
return ""
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("测试已在运行中")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
if getChannelErr != nil {
return getChannelErr
}
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
gopool.Go(func() {
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
defer func() {
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
for _, channel := range channels {
if channel.Status == common.ChannelStatusManuallyDisabled {
continue
}
statusAtStart := channel.Status
isInitiallyEnabled := statusAtStart == common.ChannelStatusEnabled
isInitiallyAutoDisabled := statusAtStart == common.ChannelStatusAutoDisabled
modelsToRun := collectModelsForScheduledChannelTest(channel)
var firstBanCtx *gin.Context
var firstBanTfErr *types.TokenFactoryError
var enableCtx *gin.Context
shouldEnableAfterTests := false
for _, modelName := range modelsToRun {
tik := time.Now()
result := testChannel(channel, modelName, "", false)
ms := time.Since(tik).Milliseconds()
testMessage := ""
if result.localErr != nil {
testMessage = result.localErr.Error()
} else if result.tokenFactoryError != nil {
testMessage = result.tokenFactoryError.Error()
}
testSuccess := result.localErr == nil && result.tokenFactoryError == nil
modelForRecord := modelNameForChannelTestRecord(channel, modelName, result)
channel.UpdateTestResult(testSuccess, ms, testMessage, modelForRecord)
_ = model.UpsertModelTestResult(channel.Id, modelForRecord, testSuccess, ms, testMessage)
tfErr := result.tokenFactoryError
shouldBanThis := false
var banTfErr *types.TokenFactoryError
if tfErr != nil {
shouldBanThis = service.ShouldDisableChannel(channel.Type, tfErr)
banTfErr = tfErr
}
if common.AutomaticDisableChannelEnabled && !shouldBanThis {
if ms > disableThreshold {
errSlow := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(ms)/1000.0, float64(disableThreshold)/1000.0)
banTfErr = types.NewOpenAIError(errSlow, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
shouldBanThis = true
}
}
if shouldBanThis && firstBanTfErr == nil && banTfErr != nil {
firstBanCtx = result.context
firstBanTfErr = banTfErr
}
if testSuccess && isInitiallyAutoDisabled {
shouldEnableAfterTests = true
if enableCtx == nil && result.context != nil {
enableCtx = result.context
}
}
time.Sleep(common.RequestInterval)
}
if isInitiallyEnabled && firstBanTfErr != nil && channel.GetAutoBan() {
if firstBanCtx != nil {
processChannelError(firstBanCtx, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(firstBanCtx, constant.ContextKeyChannelKey), channel.GetAutoBan()), firstBanTfErr)
} else {
common.SysError(fmt.Sprintf("scheduled channel test: ban signaled but context nil for channel id=%d", channel.Id))
}
}
if isInitiallyAutoDisabled && shouldEnableAfterTests && service.ShouldEnableChannel(nil, common.ChannelStatusAutoDisabled) {
usingKey := ""
if enableCtx != nil {
usingKey = common.GetContextKeyString(enableCtx, constant.ContextKeyChannelKey)
}
service.EnableChannel(channel.Id, usingKey, channel.Name)
}
}
if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
})
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(true)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}
var autoTestChannelsOnce sync.Once
func AutomaticallyTestChannels() {
// 只在Master节点定时测试渠道
if !common.IsMasterNode {
return
}
autoTestChannelsOnce.Do(func() {
for {
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
time.Sleep(1 * time.Minute)
continue
}
for {
frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
time.Sleep(time.Duration(int(math.Round(frequency))) * time.Minute)
common.SysLog(fmt.Sprintf("automatically test channels with interval %f minutes", frequency))
common.SysLog("automatically testing all channels")
_ = testAllChannels(false)
common.SysLog("automatically channel test finished")
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
break
}
}
}
})
}