tokenFactory/relay/common/relay_utils.go

265 lines
7.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package common
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
)
// HasPrompt is implemented by any type that can return a prompt string
// through GetPrompt.
type HasPrompt interface {
// GetPrompt returns the prompt text held by the implementing value.
GetPrompt() string
}
// HasImage is implemented by any type that can answer, through its HasImage
// method, whether it carries image input.
// NOTE(review): what exactly counts as "image input" is decided by each
// implementation and is not visible here.
type HasImage interface {
// HasImage reports whether the implementing value has an image attached.
HasImage() bool
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case constant.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case constant.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
// GetAPIVersion returns the API version for the current request.
//
// The "api-version" query parameter wins when it is present and non-empty;
// otherwise the value stored on the gin context under "api_version" is
// returned (an empty string if that is unset too).
func GetAPIVersion(c *gin.Context) string {
	if fromQuery := c.Request.URL.Query().Get("api-version"); fromQuery != "" {
		return fromQuery
	}
	return c.GetString("api_version")
}
// createTaskError wraps err in a dto.TaskError.
//
// code, statusCode and localError are copied onto the result unchanged, the
// message is taken from err.Error(), and err itself is kept in the Error
// field. err must be non-nil, because err.Error() is called unconditionally.
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
	taskErr := new(dto.TaskError)
	taskErr.Code = code
	taskErr.Message = err.Error()
	taskErr.StatusCode = statusCode
	taskErr.LocalError = localError
	taskErr.Error = err
	return taskErr
}
// storeTaskRequest records the outcome of request validation: action is
// written to info.Action and requestObj is saved on the gin context under the
// "task_request" key, where GetTaskRequest looks it up.
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
	const contextKey = "task_request"

	info.Action = action
	c.Set(contextKey, requestObj)
}
// GetTaskRequest fetches the TaskSubmitReq saved on the gin context under the
// "task_request" key.
//
// It returns an error when nothing is stored under that key, or when the
// stored value is not a TaskSubmitReq; in both cases the returned request is
// the zero value.
func GetTaskRequest(c *gin.Context) (TaskSubmitReq, error) {
	stored, found := c.Get("task_request")
	if !found {
		return TaskSubmitReq{}, fmt.Errorf("request not found in context")
	}
	if taskReq, isTaskReq := stored.(TaskSubmitReq); isTaskReq {
		return taskReq, nil
	}
	return TaskSubmitReq{}, fmt.Errorf("invalid task request type")
}
// validatePrompt checks that prompt contains something other than whitespace.
// It returns nil for a usable prompt and a local 400 "invalid_request"
// TaskError otherwise.
func validatePrompt(prompt string) *dto.TaskError {
	if trimmed := strings.TrimSpace(prompt); trimmed != "" {
		return nil
	}
	return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
}
// validateMultipartTaskRequest parses a multipart/form-data request into a
// TaskSubmitReq.
//
// prompt, model, mode, image and size are copied from the form verbatim and
// every "images" value is kept. The clip length is read from "seconds" and,
// when that is missing or not an integer, from "duration". Every form key that
// isKnownTaskField does not recognise is stored in Metadata under its own name,
// using its first value as converted by coerceFormValue.
//
// info and action are not used; they are kept so the signature stays stable
// for callers. A non-nil error means the multipart body could not be parsed,
// in which case the returned request is the zero value.
func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
	if _, err := c.MultipartForm(); err != nil {
		return TaskSubmitReq{}, err
	}

	// Parsing the multipart body also fills Request.PostForm with the
	// non-file fields, which is all this function reads.
	formData := c.Request.PostForm
	req := TaskSubmitReq{
		Prompt:   formData.Get("prompt"),
		Model:    formData.Get("model"),
		Mode:     formData.Get("mode"),
		Image:    formData.Get("image"),
		Size:     formData.Get("size"),
		Metadata: make(map[string]interface{}),
	}

	// "duration" is a known field (so it never reaches Metadata) and would be
	// lost entirely if it were not honoured here. "seconds" keeps priority.
	for _, field := range []string{"seconds", "duration"} {
		raw := formData.Get(field)
		if raw == "" {
			continue
		}
		if secs, err := strconv.Atoi(raw); err == nil {
			req.Duration = secs
			break
		}
	}

	if images := formData["images"]; len(images) > 0 {
		req.Images = images
	}

	// Everything that is not a first-class field is forwarded as metadata.
	// "seconds" is not listed by isKnownTaskField, so it is forwarded too.
	for key, values := range formData {
		if len(values) == 0 || isKnownTaskField(key) {
			continue
		}
		req.Metadata[key] = coerceFormValue(values[0])
	}
	return req, nil
}

// coerceFormValue converts a raw form value to an int when it parses as one,
// to a float64 when it parses as a finite float, and otherwise returns the
// string unchanged.
func coerceFormValue(raw string) interface{} {
	if n, err := strconv.Atoi(raw); err == nil {
		return n
	}
	// strconv.ParseFloat also accepts "NaN", "Inf" and "Infinity" in any
	// letter case. Such text must stay a string: a non-finite float is not
	// what the client sent and cannot be encoded by encoding/json.
	// f*0 equals 0 only for finite f (it is NaN for NaN and for +/-Inf).
	if f, err := strconv.ParseFloat(raw, 64); err == nil && f*0 == 0 {
		return f
	}
	return raw
}
// ValidateMultipartDirect decodes the request body into a TaskSubmitReq,
// validates it, and on success stores the request and the resolved action via
// storeTaskRequest.
//
// Checks run in this order, and the first failure is returned as a local 400
// TaskError:
//   - the body must decode ("invalid_json");
//   - model must not be blank ("missing_model");
//   - prompt must not be blank ("invalid_request");
//   - for models whose name starts with "sora-2", size (defaulting to
//     "720x1280" when omitted) must be one the exact model accepts
//     ("invalid_size").
//
// The stored action is TaskActionGenerate when the request has an image and
// TaskActionTextGenerate otherwise. It returns nil on success.
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
	var req TaskSubmitReq
	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
		return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
	}

	// A non-blank input_reference replaces the image list. Blank or
	// whitespace-only entries are dropped so that placeholders such as
	// images: [""] do not turn a text-only request into an image request.
	if ref := strings.TrimSpace(req.InputReference); ref != "" {
		req.Images = []string{ref}
	} else if len(req.Images) > 0 {
		kept := make([]string, 0, len(req.Images))
		for _, img := range req.Images {
			if s := strings.TrimSpace(img); s != "" {
				kept = append(kept, s)
			}
		}
		req.Images = kept
	}

	if strings.TrimSpace(req.Model) == "" {
		return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
	}
	if taskErr := validatePrompt(req.Prompt); taskErr != nil {
		return taskErr
	}

	action := constant.TaskActionTextGenerate
	if req.HasImage() {
		action = constant.TaskActionGenerate
	}

	if strings.HasPrefix(req.Model, "sora-2") {
		// The default is used for validation only; req.Size is stored as sent.
		size := req.Size
		if size == "" {
			size = "720x1280"
		}
		if req.Model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) {
			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
		}
		if req.Model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
		}
		// OtherRatios is now set in the Sora adaptor's EstimateBilling, so
		// the clip length is no longer needed here.
	}

	storeTaskRequest(c, info, action, req)
	return nil
}
// isKnownTaskField reports whether field is one of the form keys handled as a
// first-class task field. validateMultipartTaskRequest copies every other key
// into the request metadata. The match is exact and case-sensitive.
func isKnownTaskField(field string) bool {
	switch field {
	case "prompt", "model", "mode", "image", "images", "size", "duration":
		return true
	case "input_reference": // Sora-specific field
		return true
	default:
		return false
	}
}
// ValidateBasicTaskRequest decodes and normalises a task submission, then
// stores it together with the resolved action via storeTaskRequest.
//
// A body whose Content-Type starts with "multipart/form-data" is parsed by
// validateMultipartTaskRequest; anything else goes through
// common.UnmarshalBodyReusable. The prompt must be non-blank. The image list
// is filled from Image or InputReference when empty and stripped of blank
// entries, and the standard video fields are mirrored into Metadata.
//
// It returns nil on success or a local 400 TaskError for the first failure.
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
	var req TaskSubmitReq

	if isMultipart := strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data"); isMultipart {
		parsed, parseErr := validateMultipartTaskRequest(c, info, action)
		if parseErr != nil {
			return createTaskError(parseErr, "invalid_multipart_form", http.StatusBadRequest, true)
		}
		req = parsed
	} else if decodeErr := common.UnmarshalBodyReusable(c, &req); decodeErr != nil {
		return createTaskError(decodeErr, "invalid_request", http.StatusBadRequest, true)
	}

	if promptErr := validatePrompt(req.Prompt); promptErr != nil {
		return promptErr
	}

	if len(req.Images) == 0 {
		switch {
		case strings.TrimSpace(req.Image) != "":
			// Accept the single-image upload form.
			req.Images = []string{req.Image}
		case strings.TrimSpace(req.InputReference) != "":
			// Same as ValidateMultipartDirect: a reference image on the JSON
			// path (Sora and similar) counts as image-to-video input.
			req.Images = []string{req.InputReference}
		}
	}

	// Drop empty or whitespace-only placeholders so that images: [""] is not
	// mistaken for image-to-video. A fresh slice is built, so the incoming
	// slice's backing array is left untouched.
	if count := len(req.Images); count > 0 {
		nonBlank := make([]string, 0, count)
		for _, candidate := range req.Images {
			trimmed := strings.TrimSpace(candidate)
			if trimmed == "" {
				continue
			}
			nonBlank = append(nonBlank, trimmed)
		}
		req.Images = nonBlank
	}

	// Mirror the standard OpenAI video fields into metadata, where the channel
	// adaptors and the billing logic read them.
	if req.Metadata == nil {
		req.Metadata = make(map[string]interface{})
	}
	meta := req.Metadata
	if req.N != nil && *req.N > 0 {
		meta["n"] = *req.N
	}
	if req.FPS != nil && *req.FPS > 0 {
		meta["fps"] = *req.FPS
	}
	if req.Motion != nil {
		meta["motion"] = *req.Motion
	}
	if strings.TrimSpace(req.NegativePrompt) != "" {
		meta["negative_prompt"] = req.NegativePrompt
	}
	if req.Seed != nil {
		meta["seed"] = *req.Seed
	}

	// Several video adaptors mistakenly pass TaskActionGenerate as the default
	// action, so image-less requests still appear as "image-to-video" in the
	// task log. Downgrade to text-to-video only when the request really has no
	// image. An action already preset by ResolveOriginTask (such as remix)
	// must not be overwritten.
	resolved := action
	switch {
	case info.Action == constant.TaskActionRemix:
		resolved = constant.TaskActionRemix
	case action == constant.TaskActionGenerate && !req.HasImage():
		resolved = constant.TaskActionTextGenerate
	}

	storeTaskRequest(c, info, resolved, req)
	return nil
}