265 lines
7.3 KiB
Go
265 lines
7.3 KiB
Go
package common
|
||
|
||
import (
	"fmt"
	"math"
	"net/http"
	"strconv"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"

	"github.com/gin-gonic/gin"
	"github.com/samber/lo"
)
|
||
|
||
// Capability interfaces for task requests.
type (
	// HasPrompt is implemented by types that expose a text prompt via GetPrompt.
	HasPrompt interface {
		GetPrompt() string
	}

	// HasImage is implemented by types that can report whether they carry
	// image input.
	HasImage interface {
		HasImage() bool
	}
)
|
||
|
||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||
|
||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||
switch channelType {
|
||
case constant.ChannelTypeOpenAI:
|
||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||
case constant.ChannelTypeAzure:
|
||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||
}
|
||
}
|
||
return fullRequestURL
|
||
}
|
||
|
||
func GetAPIVersion(c *gin.Context) string {
|
||
query := c.Request.URL.Query()
|
||
apiVersion := query.Get("api-version")
|
||
if apiVersion == "" {
|
||
apiVersion = c.GetString("api_version")
|
||
}
|
||
return apiVersion
|
||
}
|
||
|
||
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
|
||
return &dto.TaskError{
|
||
Code: code,
|
||
Message: err.Error(),
|
||
StatusCode: statusCode,
|
||
LocalError: localError,
|
||
Error: err,
|
||
}
|
||
}
|
||
|
||
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
|
||
info.Action = action
|
||
c.Set("task_request", requestObj)
|
||
}
|
||
func GetTaskRequest(c *gin.Context) (TaskSubmitReq, error) {
|
||
v, exists := c.Get("task_request")
|
||
if !exists {
|
||
return TaskSubmitReq{}, fmt.Errorf("request not found in context")
|
||
}
|
||
req, ok := v.(TaskSubmitReq)
|
||
if !ok {
|
||
return TaskSubmitReq{}, fmt.Errorf("invalid task request type")
|
||
}
|
||
return req, nil
|
||
}
|
||
|
||
func validatePrompt(prompt string) *dto.TaskError {
|
||
if strings.TrimSpace(prompt) == "" {
|
||
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
|
||
var req TaskSubmitReq
|
||
if _, err := c.MultipartForm(); err != nil {
|
||
return req, err
|
||
}
|
||
|
||
formData := c.Request.PostForm
|
||
req = TaskSubmitReq{
|
||
Prompt: formData.Get("prompt"),
|
||
Model: formData.Get("model"),
|
||
Mode: formData.Get("mode"),
|
||
Image: formData.Get("image"),
|
||
Size: formData.Get("size"),
|
||
Metadata: make(map[string]interface{}),
|
||
}
|
||
|
||
if durationStr := formData.Get("seconds"); durationStr != "" {
|
||
if duration, err := strconv.Atoi(durationStr); err == nil {
|
||
req.Duration = duration
|
||
}
|
||
}
|
||
|
||
if images := formData["images"]; len(images) > 0 {
|
||
req.Images = images
|
||
}
|
||
|
||
for key, values := range formData {
|
||
if len(values) > 0 && !isKnownTaskField(key) {
|
||
if intVal, err := strconv.Atoi(values[0]); err == nil {
|
||
req.Metadata[key] = intVal
|
||
} else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
|
||
req.Metadata[key] = floatVal
|
||
} else {
|
||
req.Metadata[key] = values[0]
|
||
}
|
||
}
|
||
}
|
||
return req, nil
|
||
}
|
||
|
||
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||
var prompt string
|
||
var model string
|
||
var seconds int
|
||
var size string
|
||
var hasInputReference bool
|
||
|
||
var req TaskSubmitReq
|
||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||
return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
|
||
}
|
||
|
||
prompt = req.Prompt
|
||
model = req.Model
|
||
size = req.Size
|
||
seconds, _ = strconv.Atoi(req.Seconds)
|
||
if seconds == 0 {
|
||
seconds = req.Duration
|
||
}
|
||
if req.InputReference != "" {
|
||
req.Images = []string{req.InputReference}
|
||
}
|
||
|
||
if strings.TrimSpace(req.Model) == "" {
|
||
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
|
||
}
|
||
|
||
if req.HasImage() {
|
||
hasInputReference = true
|
||
}
|
||
|
||
if taskErr := validatePrompt(prompt); taskErr != nil {
|
||
return taskErr
|
||
}
|
||
|
||
action := constant.TaskActionTextGenerate
|
||
if hasInputReference {
|
||
action = constant.TaskActionGenerate
|
||
}
|
||
if strings.HasPrefix(model, "sora-2") {
|
||
|
||
if size == "" {
|
||
size = "720x1280"
|
||
}
|
||
|
||
if seconds <= 0 {
|
||
seconds = 4
|
||
}
|
||
|
||
if model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) {
|
||
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
|
||
}
|
||
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
|
||
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
|
||
}
|
||
// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
|
||
}
|
||
|
||
storeTaskRequest(c, info, action, req)
|
||
|
||
return nil
|
||
}
|
||
|
||
// isKnownTaskField reports whether a multipart form field name belongs to the
// fixed set that validateMultipartTaskRequest excludes from req.Metadata.
//
// A switch replaces the previous map literal, which was rebuilt (and
// allocated) on every call, i.e. once per form field.
//
// NOTE(review): "seconds" is not listed. It is parsed into Duration and is
// also copied into Metadata. This is kept as-is in case adaptors read
// Metadata["seconds"]; confirm before adding it here.
func isKnownTaskField(field string) bool {
	switch field {
	case "prompt", "model", "mode", "image", "images", "size", "duration",
		"input_reference": // input_reference is Sora-specific
		return true
	default:
		return false
	}
}
|
||
|
||
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
|
||
var err error
|
||
contentType := c.GetHeader("Content-Type")
|
||
var req TaskSubmitReq
|
||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||
req, err = validateMultipartTaskRequest(c, info, action)
|
||
if err != nil {
|
||
return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
|
||
}
|
||
} else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
|
||
}
|
||
|
||
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
|
||
return taskErr
|
||
}
|
||
|
||
if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
|
||
// 兼容单图上传
|
||
req.Images = []string{req.Image}
|
||
}
|
||
if len(req.Images) == 0 && strings.TrimSpace(req.InputReference) != "" {
|
||
// 与 ValidateMultipartDirect 一致:Sora 等 JSON 路径的参考图视为图生输入
|
||
req.Images = []string{req.InputReference}
|
||
}
|
||
// 去掉仅空白/空串的占位项,避免 images: [""] 被误判为图生
|
||
if len(req.Images) > 0 {
|
||
compact := make([]string, 0, len(req.Images))
|
||
for _, u := range req.Images {
|
||
if s := strings.TrimSpace(u); s != "" {
|
||
compact = append(compact, s)
|
||
}
|
||
}
|
||
req.Images = compact
|
||
}
|
||
|
||
// 将标准 OpenAI 视频字段统一落到 metadata,供各渠道 adaptor 与计费逻辑消费。
|
||
if req.Metadata == nil {
|
||
req.Metadata = make(map[string]interface{})
|
||
}
|
||
if req.N != nil && *req.N > 0 {
|
||
req.Metadata["n"] = *req.N
|
||
}
|
||
if req.FPS != nil && *req.FPS > 0 {
|
||
req.Metadata["fps"] = *req.FPS
|
||
}
|
||
if req.Motion != nil {
|
||
req.Metadata["motion"] = *req.Motion
|
||
}
|
||
if strings.TrimSpace(req.NegativePrompt) != "" {
|
||
req.Metadata["negative_prompt"] = req.NegativePrompt
|
||
}
|
||
if req.Seed != nil {
|
||
req.Metadata["seed"] = *req.Seed
|
||
}
|
||
|
||
// 多个视频 adaptor 误把默认 action 传成 TaskActionGenerate,导致无图请求在任务日志里
|
||
// 仍显示为「图生视频」。仅在请求侧确实无图时降为文生;remix 等已由 ResolveOriginTask 预设的 action 不得覆盖。
|
||
actionToStore := action
|
||
if info.Action == constant.TaskActionRemix {
|
||
actionToStore = constant.TaskActionRemix
|
||
} else if action == constant.TaskActionGenerate && !req.HasImage() {
|
||
actionToStore = constant.TaskActionTextGenerate
|
||
}
|
||
storeTaskRequest(c, info, actionToStore, req)
|
||
return nil
|
||
}
|