792 lines
23 KiB
Go
792 lines
23 KiB
Go
package ratio_setting
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/types"
|
|
"github.com/shopspring/decimal"
|
|
)
|
|
|
|
// RequestTierModeProgressive is the name of the progressive tier pricing mode.
// It is the default assigned when a rule has a blank mode, and the only mode
// ValidateRequestTierRule accepts.
const RequestTierModeProgressive = "progressive"
|
|
|
|
// RequestTierSegment describes a single tier of a progressive pricing rule.
type RequestTierSegment struct {
	// UpTo is the cumulative token count at which this tier ends. A value of
	// 0 marks an open-ended tier that absorbs all remaining tokens.
	UpTo int64 `json:"up_to"`
	// Ratio is the multiplier applied to the tokens that fall inside this tier.
	Ratio float64 `json:"ratio"`
}
|
|
|
|
// RequestTierPricingRule 用于模板系统,保留用于兼容性
|
|
type RequestTierPricingRule struct {
|
|
Mode string `json:"mode,omitempty"`
|
|
Input []RequestTierSegment `json:"input,omitempty"`
|
|
Output []RequestTierSegment `json:"output,omitempty"`
|
|
CacheRead []RequestTierSegment `json:"cache_read,omitempty"`
|
|
CacheWrite []RequestTierSegment `json:"cache_write,omitempty"`
|
|
}
|
|
|
|
// 新的独立阶梯倍率结构
|
|
type TierSegments struct {
|
|
Segments []RequestTierSegment `json:"segments,omitempty"`
|
|
}
|
|
|
|
type RequestTierPricingTemplate struct {
|
|
Name string `json:"name,omitempty"`
|
|
RequestTierPricingRule
|
|
}
|
|
|
|
// RequestTierPricingBreakdownItem records how one tier contributed to a tiered
// token computation.
type RequestTierPricingBreakdownItem struct {
	// From is the lower bound at which this tier started.
	From int64 `json:"from"`
	// To is the tier's configured upper bound (0 for an open-ended tier).
	To int64 `json:"to"`
	// Tokens is the decimal string of the token amount charged in this tier.
	Tokens string `json:"tokens"`
	// Ratio is the multiplier applied to Tokens.
	Ratio float64 `json:"ratio"`
	// Result is the decimal string of Tokens multiplied by Ratio.
	Result string `json:"result"`
}
|
|
|
|
type RequestTierPricingBreakdown struct {
|
|
InputBefore string `json:"input_before,omitempty"`
|
|
InputAfter string `json:"input_after,omitempty"`
|
|
OutputBefore string `json:"output_before,omitempty"`
|
|
OutputAfter string `json:"output_after,omitempty"`
|
|
CacheReadBefore string `json:"cache_read_before,omitempty"`
|
|
CacheReadAfter string `json:"cache_read_after,omitempty"`
|
|
CacheWriteBefore string `json:"cache_write_before,omitempty"`
|
|
CacheWriteAfter string `json:"cache_write_after,omitempty"`
|
|
Details map[string][]RequestTierPricingBreakdownItem `json:"details,omitempty"`
|
|
}
|
|
|
|
var requestTierPricingTemplatesMap = types.NewRWMap[string, RequestTierPricingTemplate]()
|
|
|
|
// 新的四个独立阶梯倍率 Map
|
|
var modelTierRatioMap = types.NewRWMap[string, TierSegments]()
|
|
var completionTierRatioMap = types.NewRWMap[string, TierSegments]()
|
|
var cacheTierRatioMap = types.NewRWMap[string, TierSegments]()
|
|
var createCacheTierRatioMap = types.NewRWMap[string, TierSegments]()
|
|
|
|
// 渠道级别的四个独立阶梯倍率 Map
|
|
var channelModelTierRatioMap = types.NewRWMap[string, map[string]TierSegments]()
|
|
var channelCompletionTierRatioMap = types.NewRWMap[string, map[string]TierSegments]()
|
|
var channelCacheTierRatioMap = types.NewRWMap[string, map[string]TierSegments]()
|
|
var channelCreateCacheTierRatioMap = types.NewRWMap[string, map[string]TierSegments]()
|
|
|
|
func normalizeRequestTierSegments(segments []RequestTierSegment) []RequestTierSegment {
|
|
out := make([]RequestTierSegment, 0, len(segments))
|
|
for _, segment := range segments {
|
|
if segment.Ratio < 0 {
|
|
continue
|
|
}
|
|
out = append(out, segment)
|
|
}
|
|
sort.SliceStable(out, func(i, j int) bool {
|
|
if out[i].UpTo == 0 {
|
|
return false
|
|
}
|
|
if out[j].UpTo == 0 {
|
|
return true
|
|
}
|
|
return out[i].UpTo < out[j].UpTo
|
|
})
|
|
return out
|
|
}
|
|
|
|
func normalizeRequestTierRule(rule RequestTierPricingRule) RequestTierPricingRule {
|
|
if strings.TrimSpace(rule.Mode) == "" {
|
|
rule.Mode = RequestTierModeProgressive
|
|
}
|
|
rule.Input = normalizeRequestTierSegments(rule.Input)
|
|
rule.Output = normalizeRequestTierSegments(rule.Output)
|
|
rule.CacheRead = normalizeRequestTierSegments(rule.CacheRead)
|
|
rule.CacheWrite = normalizeRequestTierSegments(rule.CacheWrite)
|
|
return rule
|
|
}
|
|
|
|
func validateRequestTierSegments(name string, segments []RequestTierSegment) error {
|
|
if len(segments) == 0 {
|
|
return nil
|
|
}
|
|
previous := int64(0)
|
|
for i, segment := range segments {
|
|
if segment.Ratio < 0 {
|
|
return fmt.Errorf("%s 第 %d 档 ratio 不能小于 0", name, i+1)
|
|
}
|
|
if segment.UpTo == 0 {
|
|
if i != len(segments)-1 {
|
|
return fmt.Errorf("%s 只有最后一档 up_to 可以为 0", name)
|
|
}
|
|
continue
|
|
}
|
|
if segment.UpTo <= previous {
|
|
return fmt.Errorf("%s 第 %d 档 up_to 必须递增", name, i+1)
|
|
}
|
|
previous = segment.UpTo
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateRequestTierRule(rule RequestTierPricingRule) error {
|
|
mode := strings.TrimSpace(rule.Mode)
|
|
if mode == "" {
|
|
mode = RequestTierModeProgressive
|
|
}
|
|
if mode != RequestTierModeProgressive {
|
|
return errors.New("仅支持 progressive 阶梯计费模式")
|
|
}
|
|
if err := validateRequestTierSegments("input", rule.Input); err != nil {
|
|
return err
|
|
}
|
|
if err := validateRequestTierSegments("output", rule.Output); err != nil {
|
|
return err
|
|
}
|
|
if err := validateRequestTierSegments("cache_read", rule.CacheRead); err != nil {
|
|
return err
|
|
}
|
|
if err := validateRequestTierSegments("cache_write", rule.CacheWrite); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func normalizeRequestTierRuleMap(src map[string]RequestTierPricingRule) (map[string]RequestTierPricingRule, error) {
|
|
dst := make(map[string]RequestTierPricingRule, len(src))
|
|
for modelName, rule := range src {
|
|
name := FormatMatchingModelName(strings.TrimSpace(modelName))
|
|
if name == "" {
|
|
continue
|
|
}
|
|
rule = normalizeRequestTierRule(rule)
|
|
if err := ValidateRequestTierRule(rule); err != nil {
|
|
return nil, fmt.Errorf("%s: %w", name, err)
|
|
}
|
|
dst[name] = rule
|
|
}
|
|
return dst, nil
|
|
}
|
|
|
|
// ========== Handlers for the four new standalone tier-ratio maps ==========
|
|
|
|
func normalizeTierSegmentsMap(src map[string]TierSegments) (map[string]TierSegments, error) {
|
|
dst := make(map[string]TierSegments, len(src))
|
|
for modelName, tier := range src {
|
|
name := FormatMatchingModelName(strings.TrimSpace(modelName))
|
|
if name == "" {
|
|
continue
|
|
}
|
|
tier.Segments = normalizeRequestTierSegments(tier.Segments)
|
|
if err := validateRequestTierSegments("segments", tier.Segments); err != nil {
|
|
return nil, fmt.Errorf("%s: %w", name, err)
|
|
}
|
|
dst[name] = tier
|
|
}
|
|
return dst, nil
|
|
}
|
|
|
|
// ModelTierRatio
|
|
func ModelTierRatio2JSONString() string {
|
|
return modelTierRatioMap.MarshalJSONString()
|
|
}
|
|
|
|
func UpdateModelTierRatioByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
modelTierRatioMap.Clear()
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
var parsed map[string]TierSegments
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
return err
|
|
}
|
|
normalized, err := normalizeTierSegmentsMap(parsed)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
modelTierRatioMap.Clear()
|
|
modelTierRatioMap.AddAll(normalized)
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
|
|
func GetModelTierRatio(model string) (TierSegments, bool) {
|
|
return modelTierRatioMap.Get(FormatMatchingModelName(model))
|
|
}
|
|
|
|
func GetModelTierRatioCopy() map[string]TierSegments {
|
|
return modelTierRatioMap.ReadAll()
|
|
}
|
|
|
|
// ChannelModelTierRatio
|
|
func ChannelModelTierRatio2JSONString() string {
|
|
return channelModelTierRatioMap.MarshalJSONString()
|
|
}
|
|
|
|
func UpdateChannelModelTierRatioByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
channelModelTierRatioMap.Clear()
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
var parsed map[string]map[string]TierSegments
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
return err
|
|
}
|
|
normalized := make(map[string]map[string]TierSegments, len(parsed))
|
|
for channelID, rules := range parsed {
|
|
id, convErr := strconv.Atoi(strings.TrimSpace(channelID))
|
|
if convErr != nil {
|
|
continue
|
|
}
|
|
key := normalizeChannelID(id)
|
|
if key == "" {
|
|
continue
|
|
}
|
|
normalizedRules, err := normalizeTierSegmentsMap(rules)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
normalized[key] = normalizedRules
|
|
}
|
|
channelModelTierRatioMap.Clear()
|
|
channelModelTierRatioMap.AddAll(normalized)
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
|
|
func GetChannelModelTierRatio(channelID int, model string) (TierSegments, bool) {
|
|
key := normalizeChannelID(channelID)
|
|
if key == "" {
|
|
return TierSegments{}, false
|
|
}
|
|
channelRules, ok := channelModelTierRatioMap.Get(key)
|
|
if !ok {
|
|
return TierSegments{}, false
|
|
}
|
|
rule, ok := channelRules[FormatMatchingModelName(model)]
|
|
return rule, ok
|
|
}
|
|
|
|
func GetChannelModelTierRatioCopy() map[string]map[string]TierSegments {
|
|
return channelModelTierRatioMap.ReadAll()
|
|
}
|
|
|
|
// CompletionTierRatio
|
|
func CompletionTierRatio2JSONString() string {
|
|
return completionTierRatioMap.MarshalJSONString()
|
|
}
|
|
|
|
func UpdateCompletionTierRatioByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
completionTierRatioMap.Clear()
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
var parsed map[string]TierSegments
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
return err
|
|
}
|
|
normalized, err := normalizeTierSegmentsMap(parsed)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
completionTierRatioMap.Clear()
|
|
completionTierRatioMap.AddAll(normalized)
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
|
|
func GetCompletionTierRatio(model string) (TierSegments, bool) {
|
|
return completionTierRatioMap.Get(FormatMatchingModelName(model))
|
|
}
|
|
|
|
func GetCompletionTierRatioCopy() map[string]TierSegments {
|
|
return completionTierRatioMap.ReadAll()
|
|
}
|
|
|
|
// ChannelCompletionTierRatio
|
|
func ChannelCompletionTierRatio2JSONString() string {
|
|
return channelCompletionTierRatioMap.MarshalJSONString()
|
|
}
|
|
|
|
func UpdateChannelCompletionTierRatioByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
channelCompletionTierRatioMap.Clear()
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
var parsed map[string]map[string]TierSegments
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
return err
|
|
}
|
|
normalized := make(map[string]map[string]TierSegments, len(parsed))
|
|
for channelID, rules := range parsed {
|
|
id, convErr := strconv.Atoi(strings.TrimSpace(channelID))
|
|
if convErr != nil {
|
|
continue
|
|
}
|
|
key := normalizeChannelID(id)
|
|
if key == "" {
|
|
continue
|
|
}
|
|
normalizedRules, err := normalizeTierSegmentsMap(rules)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
normalized[key] = normalizedRules
|
|
}
|
|
channelCompletionTierRatioMap.Clear()
|
|
channelCompletionTierRatioMap.AddAll(normalized)
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
|
|
func GetChannelCompletionTierRatio(channelID int, model string) (TierSegments, bool) {
|
|
key := normalizeChannelID(channelID)
|
|
if key == "" {
|
|
return TierSegments{}, false
|
|
}
|
|
channelRules, ok := channelCompletionTierRatioMap.Get(key)
|
|
if !ok {
|
|
return TierSegments{}, false
|
|
}
|
|
rule, ok := channelRules[FormatMatchingModelName(model)]
|
|
return rule, ok
|
|
}
|
|
|
|
func GetChannelCompletionTierRatioCopy() map[string]map[string]TierSegments {
|
|
return channelCompletionTierRatioMap.ReadAll()
|
|
}
|
|
|
|
// CacheTierRatio
|
|
func CacheTierRatio2JSONString() string {
|
|
return cacheTierRatioMap.MarshalJSONString()
|
|
}
|
|
|
|
func UpdateCacheTierRatioByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
cacheTierRatioMap.Clear()
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
var parsed map[string]TierSegments
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
return err
|
|
}
|
|
normalized, err := normalizeTierSegmentsMap(parsed)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cacheTierRatioMap.Clear()
|
|
cacheTierRatioMap.AddAll(normalized)
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
|
|
func GetCacheTierRatio(model string) (TierSegments, bool) {
|
|
return cacheTierRatioMap.Get(FormatMatchingModelName(model))
|
|
}
|
|
|
|
func GetCacheTierRatioCopy() map[string]TierSegments {
|
|
return cacheTierRatioMap.ReadAll()
|
|
}
|
|
|
|
// ChannelCacheTierRatio
|
|
func ChannelCacheTierRatio2JSONString() string {
|
|
return channelCacheTierRatioMap.MarshalJSONString()
|
|
}
|
|
|
|
func UpdateChannelCacheTierRatioByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
channelCacheTierRatioMap.Clear()
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
var parsed map[string]map[string]TierSegments
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
return err
|
|
}
|
|
normalized := make(map[string]map[string]TierSegments, len(parsed))
|
|
for channelID, rules := range parsed {
|
|
id, convErr := strconv.Atoi(strings.TrimSpace(channelID))
|
|
if convErr != nil {
|
|
continue
|
|
}
|
|
key := normalizeChannelID(id)
|
|
if key == "" {
|
|
continue
|
|
}
|
|
normalizedRules, err := normalizeTierSegmentsMap(rules)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
normalized[key] = normalizedRules
|
|
}
|
|
channelCacheTierRatioMap.Clear()
|
|
channelCacheTierRatioMap.AddAll(normalized)
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
|
|
func GetChannelCacheTierRatio(channelID int, model string) (TierSegments, bool) {
|
|
key := normalizeChannelID(channelID)
|
|
if key == "" {
|
|
return TierSegments{}, false
|
|
}
|
|
channelRules, ok := channelCacheTierRatioMap.Get(key)
|
|
if !ok {
|
|
return TierSegments{}, false
|
|
}
|
|
rule, ok := channelRules[FormatMatchingModelName(model)]
|
|
return rule, ok
|
|
}
|
|
|
|
func GetChannelCacheTierRatioCopy() map[string]map[string]TierSegments {
|
|
return channelCacheTierRatioMap.ReadAll()
|
|
}
|
|
|
|
// CreateCacheTierRatio
|
|
func CreateCacheTierRatio2JSONString() string {
|
|
return createCacheTierRatioMap.MarshalJSONString()
|
|
}
|
|
|
|
func UpdateCreateCacheTierRatioByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
createCacheTierRatioMap.Clear()
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
var parsed map[string]TierSegments
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
return err
|
|
}
|
|
normalized, err := normalizeTierSegmentsMap(parsed)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
createCacheTierRatioMap.Clear()
|
|
createCacheTierRatioMap.AddAll(normalized)
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
|
|
func GetCreateCacheTierRatio(model string) (TierSegments, bool) {
|
|
return createCacheTierRatioMap.Get(FormatMatchingModelName(model))
|
|
}
|
|
|
|
func GetCreateCacheTierRatioCopy() map[string]TierSegments {
|
|
return createCacheTierRatioMap.ReadAll()
|
|
}
|
|
|
|
// ChannelCreateCacheTierRatio
|
|
func ChannelCreateCacheTierRatio2JSONString() string {
|
|
return channelCreateCacheTierRatioMap.MarshalJSONString()
|
|
}
|
|
|
|
func UpdateChannelCreateCacheTierRatioByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
channelCreateCacheTierRatioMap.Clear()
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
var parsed map[string]map[string]TierSegments
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
return err
|
|
}
|
|
normalized := make(map[string]map[string]TierSegments, len(parsed))
|
|
for channelID, rules := range parsed {
|
|
id, convErr := strconv.Atoi(strings.TrimSpace(channelID))
|
|
if convErr != nil {
|
|
continue
|
|
}
|
|
key := normalizeChannelID(id)
|
|
if key == "" {
|
|
continue
|
|
}
|
|
normalizedRules, err := normalizeTierSegmentsMap(rules)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
normalized[key] = normalizedRules
|
|
}
|
|
channelCreateCacheTierRatioMap.Clear()
|
|
channelCreateCacheTierRatioMap.AddAll(normalized)
|
|
InvalidateExposedDataCache()
|
|
return nil
|
|
}
|
|
|
|
func GetChannelCreateCacheTierRatio(channelID int, model string) (TierSegments, bool) {
|
|
key := normalizeChannelID(channelID)
|
|
if key == "" {
|
|
return TierSegments{}, false
|
|
}
|
|
channelRules, ok := channelCreateCacheTierRatioMap.Get(key)
|
|
if !ok {
|
|
return TierSegments{}, false
|
|
}
|
|
rule, ok := channelRules[FormatMatchingModelName(model)]
|
|
return rule, ok
|
|
}
|
|
|
|
func GetChannelCreateCacheTierRatioCopy() map[string]map[string]TierSegments {
|
|
return channelCreateCacheTierRatioMap.ReadAll()
|
|
}
|
|
|
|
func RequestTierPricingTemplates2JSONString() string {
|
|
return requestTierPricingTemplatesMap.MarshalJSONString()
|
|
}
|
|
|
|
func requestTierTemplateID(template RequestTierPricingTemplate, index int, existing map[string]RequestTierPricingTemplate) string {
|
|
base := strings.TrimSpace(template.Name)
|
|
if base == "" {
|
|
base = "template"
|
|
}
|
|
base = strings.Map(func(r rune) rune {
|
|
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' {
|
|
return r
|
|
}
|
|
return '_'
|
|
}, strings.ToLower(base))
|
|
base = strings.Trim(base, "_")
|
|
if base == "" {
|
|
base = "template"
|
|
}
|
|
prefix := fmt.Sprintf("tpl_%s_%d", base, time.Now().UnixNano())
|
|
id := fmt.Sprintf("%s_%d", prefix, index+1)
|
|
for suffix := 2; ; suffix++ {
|
|
if _, ok := existing[id]; !ok {
|
|
return id
|
|
}
|
|
id = fmt.Sprintf("%s_%d_%d", prefix, index+1, suffix)
|
|
}
|
|
}
|
|
|
|
func UpdateRequestTierPricingTemplatesByJSONString(jsonStr string) error {
|
|
trimmed := strings.TrimSpace(jsonStr)
|
|
if trimmed == "" {
|
|
requestTierPricingTemplatesMap.Clear()
|
|
return nil
|
|
}
|
|
var parsed map[string]RequestTierPricingTemplate
|
|
if err := common.UnmarshalJsonStr(trimmed, &parsed); err != nil {
|
|
var list []RequestTierPricingTemplate
|
|
if listErr := common.UnmarshalJsonStr(trimmed, &list); listErr != nil {
|
|
return err
|
|
}
|
|
parsed = make(map[string]RequestTierPricingTemplate, len(list))
|
|
for index, template := range list {
|
|
parsed[requestTierTemplateID(template, index, parsed)] = template
|
|
}
|
|
}
|
|
normalized := make(map[string]RequestTierPricingTemplate, len(parsed))
|
|
index := 0
|
|
for key, template := range parsed {
|
|
name := strings.TrimSpace(key)
|
|
if name == "" {
|
|
name = requestTierTemplateID(template, index, normalized)
|
|
}
|
|
template.RequestTierPricingRule = normalizeRequestTierRule(template.RequestTierPricingRule)
|
|
if err := ValidateRequestTierRule(template.RequestTierPricingRule); err != nil {
|
|
return fmt.Errorf("%s: %w", name, err)
|
|
}
|
|
normalized[name] = template
|
|
index++
|
|
}
|
|
requestTierPricingTemplatesMap.Clear()
|
|
requestTierPricingTemplatesMap.AddAll(normalized)
|
|
return nil
|
|
}
|
|
|
|
func applyTierSegments(tokens decimal.Decimal, segments []RequestTierSegment) (decimal.Decimal, []RequestTierPricingBreakdownItem) {
|
|
if tokens.LessThanOrEqual(decimal.Zero) || len(segments) == 0 {
|
|
return tokens, nil
|
|
}
|
|
remaining := tokens
|
|
previous := decimal.Zero
|
|
result := decimal.Zero
|
|
items := make([]RequestTierPricingBreakdownItem, 0, len(segments))
|
|
for _, segment := range segments {
|
|
if remaining.LessThanOrEqual(decimal.Zero) {
|
|
break
|
|
}
|
|
var size decimal.Decimal
|
|
var to int64
|
|
if segment.UpTo == 0 {
|
|
size = remaining
|
|
to = 0
|
|
} else {
|
|
upper := decimal.NewFromInt(segment.UpTo)
|
|
if upper.LessThanOrEqual(previous) {
|
|
continue
|
|
}
|
|
capacity := upper.Sub(previous)
|
|
if remaining.GreaterThan(capacity) {
|
|
size = capacity
|
|
} else {
|
|
size = remaining
|
|
}
|
|
to = segment.UpTo
|
|
}
|
|
segResult := size.Mul(decimal.NewFromFloat(segment.Ratio))
|
|
result = result.Add(segResult)
|
|
items = append(items, RequestTierPricingBreakdownItem{
|
|
From: previous.IntPart(),
|
|
To: to,
|
|
Tokens: size.String(),
|
|
Ratio: segment.Ratio,
|
|
Result: segResult.String(),
|
|
})
|
|
remaining = remaining.Sub(size)
|
|
if segment.UpTo == 0 {
|
|
previous = previous.Add(size)
|
|
} else {
|
|
previous = decimal.NewFromInt(segment.UpTo)
|
|
}
|
|
}
|
|
if remaining.GreaterThan(decimal.Zero) {
|
|
result = result.Add(remaining)
|
|
}
|
|
return result, items
|
|
}
|
|
|
|
// ApplyTierSegmentsForType 应用阶梯倍率到单个类型
|
|
func ApplyTierSegmentsForType(tokens decimal.Decimal, tier TierSegments) decimal.Decimal {
|
|
if tokens.LessThanOrEqual(decimal.Zero) || len(tier.Segments) == 0 {
|
|
return tokens
|
|
}
|
|
remaining := tokens
|
|
previous := decimal.Zero
|
|
result := decimal.Zero
|
|
for _, segment := range tier.Segments {
|
|
if remaining.LessThanOrEqual(decimal.Zero) {
|
|
break
|
|
}
|
|
var size decimal.Decimal
|
|
if segment.UpTo == 0 {
|
|
size = remaining
|
|
} else {
|
|
upper := decimal.NewFromInt(segment.UpTo)
|
|
if upper.LessThanOrEqual(previous) {
|
|
continue
|
|
}
|
|
capacity := upper.Sub(previous)
|
|
if remaining.GreaterThan(capacity) {
|
|
size = capacity
|
|
} else {
|
|
size = remaining
|
|
}
|
|
}
|
|
segResult := size.Mul(decimal.NewFromFloat(segment.Ratio))
|
|
result = result.Add(segResult)
|
|
remaining = remaining.Sub(size)
|
|
if segment.UpTo == 0 {
|
|
previous = previous.Add(size)
|
|
} else {
|
|
previous = decimal.NewFromInt(segment.UpTo)
|
|
}
|
|
}
|
|
if remaining.GreaterThan(decimal.Zero) {
|
|
result = result.Add(remaining)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// ResolveModelTierRatio 解析模型阶梯倍率(优先使用渠道配置)
|
|
func ResolveModelTierRatio(channelID int, model string) (TierSegments, bool) {
|
|
if modelTierRatioMap == nil || channelModelTierRatioMap == nil {
|
|
return TierSegments{}, false
|
|
}
|
|
if modelTierRatioMap.Len() == 0 && channelModelTierRatioMap.Len() == 0 {
|
|
return TierSegments{}, false
|
|
}
|
|
if rule, ok := GetChannelModelTierRatio(channelID, model); ok {
|
|
return rule, true
|
|
}
|
|
return GetModelTierRatio(model)
|
|
}
|
|
|
|
// ResolveCompletionTierRatio 解析完成阶梯倍率(优先使用渠道配置)
|
|
func ResolveCompletionTierRatio(channelID int, model string) (TierSegments, bool) {
|
|
if completionTierRatioMap == nil || channelCompletionTierRatioMap == nil {
|
|
return TierSegments{}, false
|
|
}
|
|
if completionTierRatioMap.Len() == 0 && channelCompletionTierRatioMap.Len() == 0 {
|
|
return TierSegments{}, false
|
|
}
|
|
if rule, ok := GetChannelCompletionTierRatio(channelID, model); ok {
|
|
return rule, true
|
|
}
|
|
return GetCompletionTierRatio(model)
|
|
}
|
|
|
|
// ResolveCacheTierRatio 解析缓存读取阶梯倍率(优先使用渠道配置)
|
|
func ResolveCacheTierRatio(channelID int, model string) (TierSegments, bool) {
|
|
if cacheTierRatioMap == nil || channelCacheTierRatioMap == nil {
|
|
return TierSegments{}, false
|
|
}
|
|
if cacheTierRatioMap.Len() == 0 && channelCacheTierRatioMap.Len() == 0 {
|
|
return TierSegments{}, false
|
|
}
|
|
if rule, ok := GetChannelCacheTierRatio(channelID, model); ok {
|
|
return rule, true
|
|
}
|
|
return GetCacheTierRatio(model)
|
|
}
|
|
|
|
// ResolveCreateCacheTierRatio 解析缓存写入阶梯倍率(优先使用渠道配置)
|
|
func ResolveCreateCacheTierRatio(channelID int, model string) (TierSegments, bool) {
|
|
if createCacheTierRatioMap == nil || channelCreateCacheTierRatioMap == nil {
|
|
return TierSegments{}, false
|
|
}
|
|
if createCacheTierRatioMap.Len() == 0 && channelCreateCacheTierRatioMap.Len() == 0 {
|
|
return TierSegments{}, false
|
|
}
|
|
if rule, ok := GetChannelCreateCacheTierRatio(channelID, model); ok {
|
|
return rule, true
|
|
}
|
|
return GetCreateCacheTierRatio(model)
|
|
}
|
|
|
|
func ApplyRequestTierPricingDecimal(rule RequestTierPricingRule, input, output, cacheRead, cacheWrite decimal.Decimal) (decimal.Decimal, decimal.Decimal, decimal.Decimal, decimal.Decimal, RequestTierPricingBreakdown) {
|
|
rule = normalizeRequestTierRule(rule)
|
|
inAfter, inItems := applyTierSegments(input, rule.Input)
|
|
outAfter, outItems := applyTierSegments(output, rule.Output)
|
|
cacheReadAfter, cacheReadItems := applyTierSegments(cacheRead, rule.CacheRead)
|
|
cacheWriteAfter, cacheWriteItems := applyTierSegments(cacheWrite, rule.CacheWrite)
|
|
breakdown := RequestTierPricingBreakdown{
|
|
InputBefore: input.String(),
|
|
InputAfter: inAfter.String(),
|
|
OutputBefore: output.String(),
|
|
OutputAfter: outAfter.String(),
|
|
CacheReadBefore: cacheRead.String(),
|
|
CacheReadAfter: cacheReadAfter.String(),
|
|
CacheWriteBefore: cacheWrite.String(),
|
|
CacheWriteAfter: cacheWriteAfter.String(),
|
|
Details: map[string][]RequestTierPricingBreakdownItem{},
|
|
}
|
|
if len(inItems) > 0 {
|
|
breakdown.Details["input"] = inItems
|
|
}
|
|
if len(outItems) > 0 {
|
|
breakdown.Details["output"] = outItems
|
|
}
|
|
if len(cacheReadItems) > 0 {
|
|
breakdown.Details["cache_read"] = cacheReadItems
|
|
}
|
|
if len(cacheWriteItems) > 0 {
|
|
breakdown.Details["cache_write"] = cacheWriteItems
|
|
}
|
|
if len(breakdown.Details) == 0 {
|
|
breakdown.Details = nil
|
|
}
|
|
return inAfter, outAfter, cacheReadAfter, cacheWriteAfter, breakdown
|
|
}
|