331 lines
10 KiB
Go
331 lines
10 KiB
Go
package doubao
|
||
|
||
import (
|
||
"bytes"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/model"
|
||
"github.com/QuantumNous/new-api/relay/channel"
|
||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||
"github.com/QuantumNous/new-api/service"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/pkg/errors"
|
||
"github.com/samber/lo"
|
||
)
|
||
|
||
// ============================
|
||
// Request / Response structures
|
||
// ============================
|
||
|
||
type ContentItem struct {
|
||
Type string `json:"type,omitempty"`
|
||
Text string `json:"text,omitempty"`
|
||
ImageURL *MediaURL `json:"image_url,omitempty"`
|
||
VideoURL *MediaURL `json:"video_url,omitempty"`
|
||
AudioURL *MediaURL `json:"audio_url,omitempty"`
|
||
Role string `json:"role,omitempty"`
|
||
}
|
||
|
||
// MediaURL wraps a single media location and serializes as {"url": "..."}.
// An empty URL is omitted, which yields {}.
type MediaURL struct {
	URL string `json:"url,omitempty"` // absolute media URL as supplied by the caller
}
|
||
|
||
type requestPayload struct {
|
||
Model string `json:"model"`
|
||
Content []ContentItem `json:"content,omitempty"`
|
||
CallbackURL string `json:"callback_url,omitempty"`
|
||
ReturnLastFrame *dto.BoolValue `json:"return_last_frame,omitempty"`
|
||
ServiceTier string `json:"service_tier,omitempty"`
|
||
ExecutionExpiresAfter *dto.IntValue `json:"execution_expires_after,omitempty"`
|
||
GenerateAudio *dto.BoolValue `json:"generate_audio,omitempty"`
|
||
Draft *dto.BoolValue `json:"draft,omitempty"`
|
||
Tools []struct {
|
||
Type string `json:"type,omitempty"`
|
||
} `json:"tools,omitempty"`
|
||
Resolution string `json:"resolution,omitempty"`
|
||
Ratio string `json:"ratio,omitempty"`
|
||
Duration *dto.IntValue `json:"duration,omitempty"`
|
||
Frames *dto.IntValue `json:"frames,omitempty"`
|
||
Seed *dto.IntValue `json:"seed,omitempty"`
|
||
CameraFixed *dto.BoolValue `json:"camera_fixed,omitempty"`
|
||
Watermark *dto.BoolValue `json:"watermark,omitempty"`
|
||
}
|
||
|
||
// responsePayload captures the one field DoResponse needs from the
// task-creation response; every other key in that response is ignored.
type responsePayload struct {
	// ID is the upstream task_id used for later status polling.
	ID string `json:"id"`
}
|
||
|
||
// responseTask is the upstream task object. ParseTaskResult decodes it from a
// task-status response, and ConvertToOpenAIVideo decodes it from stored task
// data. Of its fields, this file reads Status, Content.VideoURL, Usage and
// Error; the rest are decoded but not consulted here.
type responseTask struct {
	// Identity and lifecycle.
	ID     string `json:"id"`
	Model  string `json:"model"`
	Status string `json:"status"`

	// Generated output.
	Content struct {
		VideoURL string `json:"video_url"`
	} `json:"content"`

	// Generation parameters as returned by upstream.
	Seed            int    `json:"seed"`
	Resolution      string `json:"resolution"`
	Duration        int    `json:"duration"`
	Ratio           string `json:"ratio"`
	FramesPerSecond int    `json:"framespersecond"`
	ServiceTier     string `json:"service_tier"`
	Tools           []struct {
		Type string `json:"type"`
	} `json:"tools"`

	// Token and tool accounting.
	Usage struct {
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
		ToolUsage        struct {
			WebSearch int `json:"web_search"`
		} `json:"tool_usage"`
	} `json:"usage"`

	// Error details; only read for failed/cancelled statuses.
	Error struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	} `json:"error"`

	// Upstream timestamps; not consulted in this file.
	CreatedAt int64 `json:"created_at"`
	UpdatedAt int64 `json:"updated_at"`
}
|
||
|
||
// ============================
|
||
// Adaptor implementation
|
||
// ============================
|
||
|
||
type TaskAdaptor struct {
|
||
taskcommon.BaseBilling
|
||
ChannelType int
|
||
apiKey string
|
||
baseURL string
|
||
}
|
||
|
||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||
a.ChannelType = info.ChannelType
|
||
a.baseURL = info.ChannelBaseUrl
|
||
a.apiKey = info.ApiKey
|
||
}
|
||
|
||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||
// Accept only POST /v1/video/generations as "generate" action.
|
||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||
}
|
||
|
||
// BuildRequestURL constructs the upstream URL.
|
||
func (a *TaskAdaptor) BuildRequestURL(_ *relaycommon.RelayInfo) (string, error) {
|
||
return fmt.Sprintf("%s/api/v3/contents/generations/tasks", a.baseURL), nil
|
||
}
|
||
|
||
// BuildRequestHeader sets required headers.
|
||
func (a *TaskAdaptor) BuildRequestHeader(_ *gin.Context, req *http.Request, _ *relaycommon.RelayInfo) error {
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Accept", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+a.apiKey)
|
||
return nil
|
||
}
|
||
|
||
// BuildRequestBody converts request into Doubao specific format.
|
||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||
req, err := relaycommon.GetTaskRequest(c)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
body, err := a.convertToRequestPayload(&req)
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, "convert request payload failed")
|
||
}
|
||
if info.UseRelayTaskUpstreamModel() {
|
||
body.Model = info.UpstreamModelName
|
||
} else {
|
||
info.UpstreamModelName = body.Model
|
||
}
|
||
data, err := common.Marshal(body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return bytes.NewReader(data), nil
|
||
}
|
||
|
||
// DoRequest delegates to common helper.
|
||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||
}
|
||
|
||
// DoResponse handles upstream response, returns taskID etc.
|
||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
_ = resp.Body.Close()
|
||
|
||
// Parse Doubao response
|
||
var dResp responsePayload
|
||
if err := common.Unmarshal(responseBody, &dResp); err != nil {
|
||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
if dResp.ID == "" {
|
||
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
ov := dto.NewOpenAIVideo()
|
||
ov.ID = info.PublicTaskID
|
||
ov.CreatedAt = dto.FormatTimeUnixRFC3339(time.Now().Unix())
|
||
ov.Model = info.OriginModelName
|
||
|
||
c.JSON(http.StatusOK, ov)
|
||
return dResp.ID, responseBody, nil
|
||
}
|
||
|
||
// FetchTask fetch task status
|
||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||
taskID, ok := body["task_id"].(string)
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid task_id")
|
||
}
|
||
|
||
uri := fmt.Sprintf("%s/api/v3/contents/generations/tasks/%s", baseUrl, taskID)
|
||
|
||
req, err := http.NewRequest(http.MethodGet, uri, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
req.Header.Set("Accept", "application/json")
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+key)
|
||
|
||
client, err := service.GetHttpClientWithProxy(proxy)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||
}
|
||
return client.Do(req)
|
||
}
|
||
|
||
func (a *TaskAdaptor) GetModelList() []string {
|
||
return ModelList
|
||
}
|
||
|
||
func (a *TaskAdaptor) GetChannelName() string {
|
||
return ChannelName
|
||
}
|
||
|
||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||
r := requestPayload{
|
||
Model: req.Model,
|
||
Content: []ContentItem{},
|
||
}
|
||
|
||
// Add images if present
|
||
if req.HasImage() {
|
||
for _, imgURL := range req.Images {
|
||
r.Content = append(r.Content, ContentItem{
|
||
Type: "image_url",
|
||
ImageURL: &MediaURL{
|
||
URL: imgURL,
|
||
},
|
||
})
|
||
}
|
||
}
|
||
|
||
metadata := req.Metadata
|
||
if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil {
|
||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||
}
|
||
|
||
if sec, _ := strconv.Atoi(req.Seconds); sec > 0 {
|
||
r.Duration = lo.ToPtr(dto.IntValue(sec))
|
||
}
|
||
|
||
r.Content = lo.Reject(r.Content, func(c ContentItem, _ int) bool { return c.Type == "text" })
|
||
r.Content = append(r.Content, ContentItem{
|
||
Type: "text",
|
||
Text: req.Prompt,
|
||
})
|
||
|
||
return &r, nil
|
||
}
|
||
|
||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||
resTask := responseTask{}
|
||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||
}
|
||
|
||
taskResult := relaycommon.TaskInfo{
|
||
Code: 0,
|
||
}
|
||
|
||
// 上游(含 Seedance 2.x)可能返回大小写混合的状态枚举,例如 SUCCESS / Succeeded
|
||
statusNorm := strings.ToLower(strings.TrimSpace(resTask.Status))
|
||
|
||
// Map Doubao status to internal status
|
||
switch statusNorm {
|
||
case "pending", "queued", "submitted":
|
||
taskResult.Status = model.TaskStatusQueued
|
||
taskResult.Progress = "10%"
|
||
case "processing", "running", "in_progress":
|
||
taskResult.Status = model.TaskStatusInProgress
|
||
taskResult.Progress = "50%"
|
||
case "succeeded", "success", "completed":
|
||
taskResult.Status = model.TaskStatusSuccess
|
||
taskResult.Progress = "100%"
|
||
taskResult.Url = resTask.Content.VideoURL
|
||
// 解析 usage 信息用于按倍率计费
|
||
taskResult.CompletionTokens = resTask.Usage.CompletionTokens
|
||
taskResult.TotalTokens = resTask.Usage.TotalTokens
|
||
case "failed", "error", "cancelled", "canceled":
|
||
taskResult.Status = model.TaskStatusFailure
|
||
taskResult.Progress = "100%"
|
||
taskResult.Reason = resTask.Error.Message
|
||
default:
|
||
// Unknown status, treat as processing
|
||
taskResult.Status = model.TaskStatusInProgress
|
||
taskResult.Progress = "30%"
|
||
}
|
||
|
||
return &taskResult, nil
|
||
}
|
||
|
||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||
var dResp responseTask
|
||
if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
|
||
return nil, errors.Wrap(err, "unmarshal doubao task data failed")
|
||
}
|
||
|
||
openAIVideo := dto.NewOpenAIVideo()
|
||
openAIVideo.ID = originTask.TaskID
|
||
openAIVideo.Status = originTask.Status.ToVideoStatus()
|
||
openAIVideo.SetProgressStr(originTask.Progress)
|
||
openAIVideo.SetMetadata("url", dResp.Content.VideoURL)
|
||
openAIVideo.CreatedAt = dto.FormatTimeUnixRFC3339(originTask.CreatedAt)
|
||
if originTask.FinishTime > 0 {
|
||
openAIVideo.CompletedAt = dto.FormatTimeUnixRFC3339(originTask.FinishTime)
|
||
}
|
||
openAIVideo.Model = originTask.Properties.OriginModelName
|
||
|
||
st := strings.ToLower(strings.TrimSpace(dResp.Status))
|
||
if st == "failed" || st == "error" || st == "cancelled" || st == "canceled" {
|
||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||
Message: dResp.Error.Message,
|
||
Code: dResp.Error.Code,
|
||
}
|
||
}
|
||
|
||
return common.Marshal(openAIVideo)
|
||
}
|