tokenFactory/relay/channel/task/doubao/adaptor.go

331 lines
10 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package doubao
import (
"bytes"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/samber/lo"
)
// ============================
// Request / Response structures
// ============================

type (
	// ContentItem is one element of the Doubao request "content" array.
	// Type tells upstream which sibling field carries the data: "text" items
	// use Text and "image_url" items use ImageURL (video/audio presumably
	// follow the same pattern). Every field is omitted from the JSON when it
	// is left at its zero value.
	ContentItem struct {
		Type     string    `json:"type,omitempty"`
		Text     string    `json:"text,omitempty"`
		ImageURL *MediaURL `json:"image_url,omitempty"`
		VideoURL *MediaURL `json:"video_url,omitempty"`
		AudioURL *MediaURL `json:"audio_url,omitempty"`
		Role     string    `json:"role,omitempty"`
	}

	// MediaURL wraps a single media location referenced by a ContentItem.
	MediaURL struct {
		URL string `json:"url,omitempty"`
	}
)
// requestPayload is the JSON body sent to the Doubao task-creation endpoint.
//
// convertToRequestPayload fills it from the generic task request: Model starts
// as the requested model (BuildRequestBody may overwrite it with the upstream
// model), Content holds the image items plus one trailing prompt text item,
// and Duration is taken from the request's Seconds when it is positive. The
// other fields are presumably set from request metadata through
// taskcommon.UnmarshalMetadata — confirm against that helper. Model is always
// serialized; every other field is tagged omitempty.
type requestPayload struct {
	Model                 string         `json:"model"`
	Content               []ContentItem  `json:"content,omitempty"`
	CallbackURL           string         `json:"callback_url,omitempty"`
	ReturnLastFrame       *dto.BoolValue `json:"return_last_frame,omitempty"`
	ServiceTier           string         `json:"service_tier,omitempty"`
	ExecutionExpiresAfter *dto.IntValue  `json:"execution_expires_after,omitempty"`
	GenerateAudio         *dto.BoolValue `json:"generate_audio,omitempty"`
	Draft                 *dto.BoolValue `json:"draft,omitempty"`
	Tools                 []struct {
		Type string `json:"type,omitempty"`
	} `json:"tools,omitempty"`
	Resolution  string         `json:"resolution,omitempty"`
	Ratio       string         `json:"ratio,omitempty"`
	Duration    *dto.IntValue  `json:"duration,omitempty"`
	Frames      *dto.IntValue  `json:"frames,omitempty"`
	Seed        *dto.IntValue  `json:"seed,omitempty"`
	CameraFixed *dto.BoolValue `json:"camera_fixed,omitempty"`
	Watermark   *dto.BoolValue `json:"watermark,omitempty"`
}
// responsePayload is the part of the task-creation response that DoResponse
// reads: only the upstream task ID. An empty ID is treated as an invalid
// response.
type responsePayload struct {
	ID string `json:"id"` // upstream task_id
}
// responseTask mirrors the Doubao task-query response.
//
// ParseTaskResult maps Status (trimmed and compared case-insensitively),
// Content.VideoURL, the Usage token counts and Error.Message onto
// relaycommon.TaskInfo. ConvertToOpenAIVideo decodes it from the stored task
// data and reads Status, Content.VideoURL and Error. The remaining fields are
// decoded but not read by either of those methods.
type responseTask struct {
	ID      string `json:"id"`
	Model   string `json:"model"`
	Status  string `json:"status"`
	Content struct {
		VideoURL string `json:"video_url"`
	} `json:"content"`
	Seed            int    `json:"seed"`
	Resolution      string `json:"resolution"`
	Duration        int    `json:"duration"`
	Ratio           string `json:"ratio"`
	FramesPerSecond int    `json:"framespersecond"`
	ServiceTier     string `json:"service_tier"`
	Tools           []struct {
		Type string `json:"type"`
	} `json:"tools"`
	Usage struct {
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
		ToolUsage        struct {
			WebSearch int `json:"web_search"`
		} `json:"tool_usage"`
	} `json:"usage"`
	Error struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	} `json:"error"`
	CreatedAt int64 `json:"created_at"`
	UpdatedAt int64 `json:"updated_at"`
}
// ============================
// Adaptor implementation
// ============================

// TaskAdaptor relays video-generation tasks to the Doubao task API.
// Init copies the channel type, base URL and API key from the relay info.
// The embedded taskcommon.BaseBilling presumably supplies default billing
// hooks (not visible here — confirm).
type TaskAdaptor struct {
	taskcommon.BaseBilling
	ChannelType int
	apiKey      string
	baseURL     string
}
// Init stores the channel type, channel base URL and API key from info on the
// adaptor so the request-building methods can use them later.
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
	a.apiKey, a.baseURL, a.ChannelType = info.ApiKey, info.ChannelBaseUrl, info.ChannelType
}
// ValidateRequestAndSetAction delegates to the shared
// relaycommon.ValidateBasicTaskRequest, passing constant.TaskActionGenerate as
// the action. Any validation failure is returned as the task error.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
	taskErr = relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
	return taskErr
}
// BuildRequestURL returns the Doubao task-creation endpoint under the channel
// base URL. Trailing slashes on the base URL are dropped so that a base such
// as "https://host/" does not yield "https://host//api/...".
func (a *TaskAdaptor) BuildRequestURL(_ *relaycommon.RelayInfo) (string, error) {
	return strings.TrimRight(a.baseURL, "/") + "/api/v3/contents/generations/tasks", nil
}
// BuildRequestHeader declares JSON for both the request body and the expected
// response, and authenticates with the channel API key as a Bearer token.
func (a *TaskAdaptor) BuildRequestHeader(_ *gin.Context, req *http.Request, _ *relaycommon.RelayInfo) error {
	h := req.Header
	h.Set("Content-Type", "application/json")
	h.Set("Accept", "application/json")
	h.Set("Authorization", "Bearer "+a.apiKey)
	return nil
}
// BuildRequestBody reads the parsed task request from the gin context,
// converts it into the Doubao payload and returns it JSON-encoded.
//
// Model handling goes one of two ways. When info.UseRelayTaskUpstreamModel()
// is true, the payload's model is replaced by info.UpstreamModelName.
// Otherwise info.UpstreamModelName is set to the model the payload carries.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
	taskReq, err := relaycommon.GetTaskRequest(c)
	if err != nil {
		return nil, err
	}

	payload, err := a.convertToRequestPayload(&taskReq)
	if err != nil {
		return nil, errors.Wrap(err, "convert request payload failed")
	}

	switch {
	case info.UseRelayTaskUpstreamModel():
		payload.Model = info.UpstreamModelName
	default:
		info.UpstreamModelName = payload.Model
	}

	raw, err := common.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return bytes.NewReader(raw), nil
}
// DoRequest sends requestBody upstream through channel.DoTaskApiRequest. The
// adaptor itself is passed so the helper can call back into it (presumably
// for BuildRequestURL / BuildRequestHeader — confirm in that helper).
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	resp, err := channel.DoTaskApiRequest(a, c, info, requestBody)
	return resp, err
}
// DoResponse handles the upstream task-creation response.
//
// It reads and closes resp.Body and decodes the Doubao task ID. On success it
// writes an OpenAI-style video object (ID = info.PublicTaskID, Model =
// info.OriginModelName) to the client with HTTP 200. It then returns the
// upstream task ID and the raw upstream body.
//
// On any failure taskErr is non-nil, the other results are zero and nothing
// is written to the client.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
	// Close on every path; the body used to leak when ReadAll failed.
	defer func() { _ = resp.Body.Close() }()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}

	var dResp responsePayload
	if err := common.Unmarshal(responseBody, &dResp); err != nil {
		return "", nil, service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
	}
	if dResp.ID == "" {
		// Keep the upstream body so an upstream error payload is not reduced
		// to a bare "task_id is empty".
		return "", nil, service.TaskErrorWrapper(fmt.Errorf("task_id is empty, body: %s", responseBody), "invalid_response", http.StatusInternalServerError)
	}

	ov := dto.NewOpenAIVideo()
	ov.ID = info.PublicTaskID
	ov.CreatedAt = dto.FormatTimeUnixRFC3339(time.Now().Unix())
	ov.Model = info.OriginModelName
	c.JSON(http.StatusOK, ov)
	return dResp.ID, responseBody, nil
}
// FetchTask queries the upstream status of a single Doubao generation task.
//
// Parameters:
//   - body must carry a non-empty string under "task_id"; otherwise an error
//     is returned before any network call.
//   - baseUrl is the channel base URL; trailing slashes are dropped.
//   - key is sent as a Bearer token.
//   - proxy is passed to service.GetHttpClientWithProxy.
//
// On success the caller owns the returned response and must close its body.
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
	taskID, ok := body["task_id"].(string)
	if !ok || taskID == "" {
		// An empty ID would address the collection path ".../tasks/" instead
		// of a single task.
		return nil, fmt.Errorf("invalid task_id")
	}

	uri := fmt.Sprintf("%s/api/v3/contents/generations/tasks/%s", strings.TrimRight(baseUrl, "/"), taskID)
	req, err := http.NewRequest(http.MethodGet, uri, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+key)

	client, err := service.GetHttpClientWithProxy(proxy)
	if err != nil {
		return nil, fmt.Errorf("new proxy http client failed: %w", err)
	}
	return client.Do(req)
}
// GetModelList reports the models this adaptor advertises, taken from the
// package-level ModelList.
func (a *TaskAdaptor) GetModelList() []string {
	models := ModelList
	return models
}
// GetChannelName reports this adaptor's channel name, taken from the
// package-level ChannelName.
func (a *TaskAdaptor) GetChannelName() string {
	name := ChannelName
	return name
}
// convertToRequestPayload maps a generic task submission onto the Doubao
// request payload. Content is assembled in three steps:
//
//  1. Add one "image_url" item per entry of req.Images.
//  2. Let taskcommon.UnmarshalMetadata apply req.Metadata onto the payload.
//     This is presumably a JSON overlay that can set any payload field,
//     including Content and Model — confirm against that helper.
//  3. Drop every "text" item and append a single trailing "text" item holding
//     req.Prompt.
//
// A req.Seconds that parses to a positive integer overrides the duration.
// A metadata error is wrapped and returned.
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
	payload := &requestPayload{
		Model:   req.Model,
		Content: []ContentItem{},
	}

	if req.HasImage() {
		for _, imageURL := range req.Images {
			payload.Content = append(payload.Content, ContentItem{
				Type:     "image_url",
				ImageURL: &MediaURL{URL: imageURL},
			})
		}
	}

	if err := taskcommon.UnmarshalMetadata(req.Metadata, payload); err != nil {
		return nil, errors.Wrap(err, "unmarshal metadata failed")
	}

	// Non-numeric Seconds parse as 0 and leave the duration untouched.
	seconds, _ := strconv.Atoi(req.Seconds)
	if seconds > 0 {
		payload.Duration = lo.ToPtr(dto.IntValue(seconds))
	}

	// Exactly one text item, always last, carrying the prompt.
	content := make([]ContentItem, 0, len(payload.Content)+1)
	for _, item := range payload.Content {
		if item.Type != "text" {
			content = append(content, item)
		}
	}
	payload.Content = append(content, ContentItem{Type: "text", Text: req.Prompt})

	return payload, nil
}
// ParseTaskResult decodes a Doubao task-query response and maps it onto the
// relay's TaskInfo.
//
// The upstream status is trimmed and lower-cased before matching.
// Unrecognised statuses are reported as in progress (30%) so polling
// continues.
//
// On success the video URL and token usage are copied; usage is used for
// ratio-based billing. On failure Reason carries the upstream error message,
// falling back to the error code and then to the status, so it is never empty.
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
	var resTask responseTask
	if err := common.Unmarshal(respBody, &resTask); err != nil {
		return nil, errors.Wrap(err, "unmarshal task result failed")
	}

	taskResult := relaycommon.TaskInfo{Code: 0}

	// Upstream (including Seedance 2.x) may return status enums in mixed
	// case, e.g. SUCCESS / Succeeded.
	statusNorm := strings.ToLower(strings.TrimSpace(resTask.Status))

	// Map Doubao status to internal status.
	switch statusNorm {
	case "pending", "queued", "submitted":
		taskResult.Status = model.TaskStatusQueued
		taskResult.Progress = "10%"
	case "processing", "running", "in_progress":
		taskResult.Status = model.TaskStatusInProgress
		taskResult.Progress = "50%"
	case "succeeded", "success", "completed":
		taskResult.Status = model.TaskStatusSuccess
		taskResult.Progress = "100%"
		taskResult.Url = resTask.Content.VideoURL
		// Usage is parsed so the task can be billed by ratio.
		taskResult.CompletionTokens = resTask.Usage.CompletionTokens
		taskResult.TotalTokens = resTask.Usage.TotalTokens
	case "failed", "error", "cancelled", "canceled", "expired":
		// Ark reports "expired" for tasks that exceed their execution window
		// (see execution_expires_after). Without this case such tasks would
		// fall into the default branch and keep being reported as in progress.
		taskResult.Status = model.TaskStatusFailure
		taskResult.Progress = "100%"
		taskResult.Reason = resTask.Error.Message
		if taskResult.Reason == "" {
			taskResult.Reason = resTask.Error.Code
		}
		if taskResult.Reason == "" {
			taskResult.Reason = "task " + statusNorm
		}
	default:
		// Unknown status, treat as processing.
		taskResult.Status = model.TaskStatusInProgress
		taskResult.Progress = "30%"
	}
	return &taskResult, nil
}
// ConvertToOpenAIVideo renders a stored Doubao task as a JSON-encoded
// OpenAI-style video object.
//
// ID, status, progress, timestamps and model come from originTask itself. The
// video URL (exposed as metadata "url") and any error come from the upstream
// response stored in originTask.Data. An error object is attached only when
// that stored upstream status is a terminal failure.
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
	var dResp responseTask
	if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
		return nil, errors.Wrap(err, "unmarshal doubao task data failed")
	}

	openAIVideo := dto.NewOpenAIVideo()
	openAIVideo.ID = originTask.TaskID
	openAIVideo.Status = originTask.Status.ToVideoStatus()
	openAIVideo.SetProgressStr(originTask.Progress)
	openAIVideo.SetMetadata("url", dResp.Content.VideoURL)
	openAIVideo.CreatedAt = dto.FormatTimeUnixRFC3339(originTask.CreatedAt)
	if originTask.FinishTime > 0 {
		openAIVideo.CompletedAt = dto.FormatTimeUnixRFC3339(originTask.FinishTime)
	}
	openAIVideo.Model = originTask.Properties.OriginModelName

	// Terminal failure statuses (matched case-insensitively). "expired" is
	// what Ark reports for tasks that time out, so it also carries the error.
	switch strings.ToLower(strings.TrimSpace(dResp.Status)) {
	case "failed", "error", "cancelled", "canceled", "expired":
		openAIVideo.Error = &dto.OpenAIVideoError{
			Message: dResp.Error.Message,
			Code:    dResp.Error.Code,
		}
	}
	return common.Marshal(openAIVideo)
}