tokenFactory/service/http_client.go

213 lines
6.2 KiB
Go

package service
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/system_setting"
"golang.org/x/net/proxy"
)
// httpClient is the shared relay client assigned by InitHttpClient.
var httpClient *http.Client

// ossHTTPClient is the OSS upload client; ossHTTPOnce makes its construction
// in InitOssHttpClient happen at most once.
var (
	ossHTTPClient *http.Client
	ossHTTPOnce   sync.Once
)

// proxyClients caches one client per proxy URL string. Every access to the
// map goes through proxyClientLock.
var (
	proxyClientLock sync.Mutex
	proxyClients    = map[string]*http.Client{}
)
// checkRedirect is the CheckRedirect hook shared by every client in this file.
//
// Each redirect target is validated against the current fetch settings (SSRF
// protection, private-IP policy, domain/IP filter lists, allowed ports). A
// rejected target is reported as "redirect to <url> blocked: <cause>", with the
// cause wrapped so callers can inspect it via errors.Is / errors.As. Redirect
// chains are cut off after 10 hops.
//
// net/http strips the Authorization header when a redirect leaves the original
// domain, which makes the upstream TokenFactory platform receive an
// unauthenticated request and answer 401. To keep such redirects working, the
// Authorization header of the first request is copied onto the redirected
// request when it has none. The one exception is an https -> http downgrade:
// re-attaching the credential there would send it in cleartext.
func checkRedirect(req *http.Request, via []*http.Request) error {
	fetchSetting := system_setting.GetFetchSetting()
	urlStr := req.URL.String()
	if err := common.ValidateURLWithFetchSetting(urlStr, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
		return fmt.Errorf("redirect to %s blocked: %w", urlStr, err)
	}
	if len(via) >= 10 {
		return fmt.Errorf("stopped after 10 redirects")
	}

	// Nothing to restore: this is not a redirect, or net/http kept the header.
	if len(via) == 0 || req.Header.Get("Authorization") != "" {
		return nil
	}
	first := via[0]
	origAuth := first.Header.Get("Authorization")
	if origAuth == "" {
		return nil
	}
	// url.Parse lower-cases the scheme, so a plain comparison is sufficient.
	if first.URL.Scheme == "https" && req.URL.Scheme == "http" {
		return nil
	}
	req.Header.Set("Authorization", origAuth)
	return nil
}
// InitHttpClient builds the shared relay client later returned by GetHttpClient.
//
// The transport honours the HTTP_PROXY / HTTPS_PROXY / NO_PROXY environment
// variables and uses the relay idle-connection limits from common. It skips
// TLS verification only when common.TLSInsecureSkipVerify is set. A non-zero
// common.RelayTimeout (seconds) becomes the client's overall timeout. Redirects
// are vetted by checkRedirect.
func InitHttpClient() {
	tr := &http.Transport{
		Proxy:               http.ProxyFromEnvironment, // Support HTTP_PROXY, HTTPS_PROXY, NO_PROXY env vars
		ForceAttemptHTTP2:   true,
		MaxIdleConns:        common.RelayMaxIdleConns,
		MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
	}
	if common.TLSInsecureSkipVerify {
		tr.TLSClientConfig = common.InsecureTLSConfig
	}

	client := &http.Client{Transport: tr, CheckRedirect: checkRedirect}
	if common.RelayTimeout != 0 {
		client.Timeout = time.Duration(common.RelayTimeout) * time.Second
	}
	httpClient = client
}
// GetHttpClient returns the shared relay client assigned by InitHttpClient.
// It is nil if InitHttpClient has not run yet.
func GetHttpClient() *http.Client {
	client := httpClient
	return client
}
// InitOssHttpClient initialises the HTTP client dedicated to OSS uploads.
//
// The client has its own connection pool, separate from the relay client, and
// a 45s idle-connection timeout. This lowers the chance of reusing a
// connection the peer has already closed. HTTP/2 is not forced.
//
// The overall timeout is common.RelayTimeout seconds when positive, and 120s
// otherwise. The function is idempotent: the construction runs only once,
// guarded by ossHTTPOnce.
func InitOssHttpClient() {
	ossHTTPOnce.Do(func() {
		clientTimeout := 120 * time.Second
		if common.RelayTimeout > 0 {
			clientTimeout = time.Duration(common.RelayTimeout) * time.Second
		}

		tr := &http.Transport{
			Proxy:               http.ProxyFromEnvironment,
			ForceAttemptHTTP2:   false, // explicit: OSS uploads stay on HTTP/1.1
			MaxIdleConns:        16,
			MaxIdleConnsPerHost: 8,
			IdleConnTimeout:     45 * time.Second,
			TLSHandshakeTimeout: 15 * time.Second,
		}
		if common.TLSInsecureSkipVerify {
			tr.TLSClientConfig = common.InsecureTLSConfig
		}

		ossHTTPClient = &http.Client{
			Transport:     tr,
			Timeout:       clientTimeout,
			CheckRedirect: checkRedirect,
		}
	})
}
// GetOssHttpClient returns the OSS upload client. It builds the client on
// first use by calling InitOssHttpClient; later calls are cheap because that
// initialisation runs only once.
func GetOssHttpClient() *http.Client {
	InitOssHttpClient()
	client := ossHTTPClient
	return client
}
// GetHttpClientWithProxy returns the default client or a proxy-enabled one when proxyURL is provided.
func GetHttpClientWithProxy(proxyURL string) (*http.Client, error) {
if proxyURL == "" {
return GetHttpClient(), nil
}
return NewProxyHttpClient(proxyURL)
}
// ResetProxyClientCache discards every cached proxy client, so the next
// NewProxyHttpClient call for a given proxy URL builds a fresh one. Before
// dropping them, it closes the idle connections of each discarded
// *http.Transport.
func ResetProxyClientCache() {
	proxyClientLock.Lock()
	defer proxyClientLock.Unlock()

	for key := range proxyClients {
		tr, isTransport := proxyClients[key].Transport.(*http.Transport)
		if !isTransport || tr == nil {
			continue
		}
		tr.CloseIdleConnections()
	}
	proxyClients = make(map[string]*http.Client)
}
// NewProxyHttpClient returns an HTTP client whose traffic goes through proxyURL.
//
// For an empty proxyURL it returns the shared relay client from GetHttpClient,
// or http.DefaultClient if that is still nil. Otherwise it builds one client
// per distinct proxyURL string and caches it until ResetProxyClientCache is
// called. Supported schemes are http, https, socks5 and socks5h. Any other
// scheme is an error, as is a URL that fails to parse or a SOCKS5 dialer that
// cannot be created.
//
// Lookup, construction and insertion all happen under a single hold of
// proxyClientLock. Construction is purely in-memory (no network I/O), so the
// hold is short. It guarantees that concurrent callers for the same proxyURL
// share one client. Without it, each caller would build its own transport and
// the last writer would silently orphan the others' connection pools.
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
	if proxyURL == "" {
		if client := GetHttpClient(); client != nil {
			return client, nil
		}
		return http.DefaultClient, nil
	}

	proxyClientLock.Lock()
	defer proxyClientLock.Unlock()
	if client, ok := proxyClients[proxyURL]; ok {
		return client, nil
	}

	parsedURL, err := url.Parse(proxyURL)
	if err != nil {
		return nil, err
	}

	// Settings shared by every proxy transport; the switch below only adds the
	// scheme-specific routing.
	transport := &http.Transport{
		MaxIdleConns:        common.RelayMaxIdleConns,
		MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
		ForceAttemptHTTP2:   true,
	}
	switch parsedURL.Scheme {
	case "http", "https":
		transport.Proxy = http.ProxyURL(parsedURL)
	case "socks5", "socks5h":
		// Credentials embedded in the URL as user[:password]@host.
		var auth *proxy.Auth
		if parsedURL.User != nil {
			auth = &proxy.Auth{User: parsedURL.User.Username()}
			if password, ok := parsedURL.User.Password(); ok {
				auth.Password = password
			}
		}
		// proxy.SOCKS5 with "tcp" sends every TCP connection, including the
		// target host-name lookup, through the proxy, so socks5 behaves like
		// socks5h.
		dialer, dialErr := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct)
		if dialErr != nil {
			return nil, dialErr
		}
		// Prefer the context-aware dial so request cancellation and deadlines
		// also abort a SOCKS dial or handshake that is still in progress.
		ctxDialer, canUseCtx := dialer.(proxy.ContextDialer)
		transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			if canUseCtx {
				return ctxDialer.DialContext(ctx, network, addr)
			}
			return dialer.Dial(network, addr)
		}
	default:
		return nil, fmt.Errorf("unsupported proxy scheme: %s, must be http, https, socks5 or socks5h", parsedURL.Scheme)
	}
	if common.TLSInsecureSkipVerify {
		transport.TLSClientConfig = common.InsecureTLSConfig
	}

	client := &http.Client{
		Transport:     transport,
		CheckRedirect: checkRedirect,
		// A zero RelayTimeout yields a zero Timeout, i.e. no overall timeout.
		Timeout: time.Duration(common.RelayTimeout) * time.Second,
	}
	proxyClients[proxyURL] = client
	return client, nil
}