ccm,ocm: add reverse proxy support for external credentials

Allow two CCM/OCM instances to share credentials when only one has a
public IP, using yamux-multiplexed reverse connections.

Three credential modes:
- Normal: URL set, reverse=false — standard HTTP proxy
- Receiver: URL empty — waits for incoming reverse connection
- Connector: URL set, reverse=true — dials out to establish connection

Extend InterfaceUpdated to services so network changes trigger
reverse connection reconnection.
This commit is contained in:
世界 2026-03-13 18:51:02 +08:00
parent 283a5aacee
commit 16aeba8ec0
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
11 changed files with 873 additions and 134 deletions

View File

@ -95,9 +95,10 @@ type CCMBalancerCredentialOptions struct {
}
type CCMExternalCredentialOptions struct {
URL string `json:"url"`
URL string `json:"url,omitempty"`
ServerOptions
Token string `json:"token"`
Reverse bool `json:"reverse,omitempty"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`

View File

@ -95,9 +95,10 @@ type OCMBalancerCredentialOptions struct {
}
type OCMExternalCredentialOptions struct {
URL string `json:"url"`
URL string `json:"url,omitempty"`
ServerOptions
Token string `json:"token"`
Reverse bool `json:"reverse,omitempty"`
Detour string `json:"detour,omitempty"`
UsagesPath string `json:"usages_path,omitempty"`
PollInterval badoption.Duration `json:"poll_interval,omitempty"`

View File

@ -51,6 +51,7 @@ type NetworkManager struct {
endpoint adapter.EndpointManager
inbound adapter.InboundManager
outbound adapter.OutboundManager
serviceManager adapter.ServiceManager
needWIFIState bool
wifiMonitor settings.WIFIMonitor
wifiState adapter.WIFIState
@ -94,6 +95,7 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, options
endpoint: service.FromContext[adapter.EndpointManager](ctx),
inbound: service.FromContext[adapter.InboundManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
serviceManager: service.FromContext[adapter.ServiceManager](ctx),
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
}
if options.DefaultNetworkStrategy != nil {
@ -475,6 +477,15 @@ func (r *NetworkManager) ResetNetwork() {
listener.InterfaceUpdated()
}
}
if r.serviceManager != nil {
for _, svc := range r.serviceManager.Services() {
listener, isListener := svc.(adapter.InterfaceUpdateListener)
if isListener {
listener.InterfaceUpdated()
}
}
}
}
func (r *NetworkManager) notifyInterfaceUpdate(defaultInterface *control.Interface, flags int) {

View File

@ -19,9 +19,14 @@ import (
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/hashicorp/yamux"
)
const reverseProxyBaseURL = "http://reverse-proxy"
type externalCredential struct {
tag string
baseURL string
@ -39,86 +44,134 @@ type externalCredential struct {
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
// Reverse proxy fields
reverse bool
reverseSession *yamux.Session
reverseAccess sync.RWMutex
reverseContext context.Context
reverseCancel context.CancelFunc
connectorDialer N.Dialer
connectorURL *url.URL
connectorTLS *stdTLS.Config
reverseService http.Handler
}
func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
parsedURL, err := url.Parse(options.URL)
if err != nil {
return nil, E.Cause(err, "parse url for credential ", tag)
}
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer for credential ", tag)
}
transport := &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.Server != "" {
serverPort := options.ServerPort
if serverPort == 0 {
portStr := parsedURL.Port()
if portStr != "" {
port, parseErr := strconv.ParseUint(portStr, 10, 16)
if parseErr == nil {
serverPort = uint16(port)
}
}
if serverPort == 0 {
if parsedURL.Scheme == "https" {
serverPort = 443
} else {
serverPort = 80
}
}
}
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
return credentialDialer.DialContext(ctx, network, destination)
}
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
if parsedURL.Scheme == "https" {
transport.TLSClientConfig = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
// Strip trailing slash
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
pollInterval := time.Duration(options.PollInterval)
if pollInterval <= 0 {
pollInterval = 30 * time.Minute
}
requestContext, cancelRequests := context.WithCancel(context.Background())
reverseContext, reverseCancel := context.WithCancel(context.Background())
cred := &externalCredential{
tag: tag,
baseURL: baseURL,
token: options.Token,
httpClient: &http.Client{Transport: transport},
pollInterval: pollInterval,
logger: logger,
requestContext: requestContext,
cancelRequests: cancelRequests,
reverse: options.Reverse,
reverseContext: reverseContext,
reverseCancel: reverseCancel,
}
if options.URL == "" {
// Receiver mode: no URL, wait for reverse connection
cred.baseURL = reverseProxyBaseURL
cred.httpClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
session := cred.getReverseSession()
if session == nil || session.IsClosed() {
return nil, E.New("reverse connection not established for ", cred.tag)
}
return session.Open()
},
},
}
} else {
// Normal or connector mode: has URL
parsedURL, err := url.Parse(options.URL)
if err != nil {
return nil, E.Cause(err, "parse url for credential ", tag)
}
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer for credential ", tag)
}
transport := &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.Server != "" {
serverPort := options.ServerPort
if serverPort == 0 {
portStr := parsedURL.Port()
if portStr != "" {
port, parseErr := strconv.ParseUint(portStr, 10, 16)
if parseErr == nil {
serverPort = uint16(port)
}
}
if serverPort == 0 {
if parsedURL.Scheme == "https" {
serverPort = 443
} else {
serverPort = 80
}
}
}
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
return credentialDialer.DialContext(ctx, network, destination)
}
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
if parsedURL.Scheme == "https" {
transport.TLSClientConfig = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
cred.baseURL = baseURL
if options.Reverse {
// Connector mode: we dial out to serve, not to proxy
cred.connectorDialer = credentialDialer
cred.connectorURL = parsedURL
if parsedURL.Scheme == "https" {
cred.connectorTLS = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
} else {
// Normal mode: standard HTTP client for proxying
cred.httpClient = &http.Client{Transport: transport}
}
}
if options.UsagesPath != "" {
@ -140,6 +193,9 @@ func (c *externalCredential) start() error {
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
}
}
if c.reverse && c.connectorURL != nil {
go c.connectorLoop()
}
return nil
}
@ -152,6 +208,14 @@ func (c *externalCredential) isExternal() bool {
}
func (c *externalCredential) isAvailable() bool {
if c.reverse && c.connectorURL != nil {
return false // connector mode: not for local proxying
}
if c.baseURL == reverseProxyBaseURL {
// receiver mode: only available when reverse connection active
session := c.getReverseSession()
return session != nil && !session.IsClosed()
}
return true
}
@ -426,6 +490,16 @@ func (c *externalCredential) httpTransport() *http.Client {
}
func (c *externalCredential) close() {
if c.reverseCancel != nil {
c.reverseCancel()
}
c.reverseAccess.Lock()
session := c.reverseSession
c.reverseSession = nil
c.reverseAccess.Unlock()
if session != nil {
session.Close()
}
if c.usageTracker != nil {
c.usageTracker.cancelPendingSave()
err := c.usageTracker.Save()
@ -434,3 +508,27 @@ func (c *externalCredential) close() {
}
}
}
func (c *externalCredential) getReverseSession() *yamux.Session {
c.reverseAccess.RLock()
defer c.reverseAccess.RUnlock()
return c.reverseSession
}
func (c *externalCredential) setReverseSession(session *yamux.Session) {
c.reverseAccess.Lock()
old := c.reverseSession
c.reverseSession = session
c.reverseAccess.Unlock()
if old != nil {
old.Close()
}
}
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
c.reverseAccess.Lock()
if c.reverseSession == session {
c.reverseSession = nil
}
c.reverseAccess.Unlock()
}

View File

@ -1117,12 +1117,12 @@ func validateCCMOptions(options option.CCMServiceOptions) error {
}
}
if cred.Type == "external" {
if cred.ExternalOptions.URL == "" {
return E.New("credential ", cred.Tag, ": external credential requires url")
}
if cred.ExternalOptions.Token == "" {
return E.New("credential ", cred.Tag, ": external credential requires token")
}
if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" {
return E.New("credential ", cred.Tag, ": reverse external credential requires url")
}
}
if cred.Type == "balancer" {
switch cred.BalancerOptions.Strategy {

243
service/ccm/reverse.go Normal file
View File

@ -0,0 +1,243 @@
package ccm
import (
"bufio"
stdTLS "crypto/tls"
"errors"
"io"
"math/rand/v2"
"net"
"net/http"
"strings"
"time"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/hashicorp/yamux"
)
func reverseYamuxConfig() *yamux.Config {
config := yamux.DefaultConfig()
config.KeepAliveInterval = 15 * time.Second
config.ConnectionWriteTimeout = 10 * time.Second
config.MaxStreamWindowSize = 512 * 1024
config.LogOutput = io.Discard
return config
}
type yamuxNetListener struct {
session *yamux.Session
}
func (l *yamuxNetListener) Accept() (net.Conn, error) {
return l.session.Accept()
}
func (l *yamuxNetListener) Close() error {
return l.session.Close()
}
func (l *yamuxNetListener) Addr() net.Addr {
return l.session.Addr()
}
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "reverse-proxy" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
receiverCredential := s.findReceiverCredential(clientToken)
if receiverCredential == nil {
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
s.logger.Error("reverse connect: hijack not supported")
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
return
}
conn, bufferedReadWriter, err := hijacker.Hijack()
if err != nil {
s.logger.Error("reverse connect: hijack: ", err)
return
}
response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n"
_, err = bufferedReadWriter.WriteString(response)
if err != nil {
conn.Close()
s.logger.Error("reverse connect: write upgrade response: ", err)
return
}
err = bufferedReadWriter.Flush()
if err != nil {
conn.Close()
s.logger.Error("reverse connect: flush upgrade response: ", err)
return
}
session, err := yamux.Client(conn, reverseYamuxConfig())
if err != nil {
conn.Close()
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
return
}
receiverCredential.setReverseSession(session)
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
go func() {
<-session.CloseChan()
receiverCredential.clearReverseSession(session)
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName())
}()
}
func (s *Service) findReceiverCredential(token string) *externalCredential {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
if !ok {
continue
}
if extCred.baseURL == reverseProxyBaseURL && extCred.token == token {
return extCred
}
}
return nil
}
func (c *externalCredential) connectorLoop() {
var consecutiveFailures int
for {
select {
case <-c.reverseContext.Done():
return
default:
}
err := c.connectorConnect()
if c.reverseContext.Err() != nil {
return
}
consecutiveFailures++
backoff := connectorBackoff(consecutiveFailures)
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
select {
case <-time.After(backoff):
case <-c.reverseContext.Done():
return
}
}
}
func connectorBackoff(failures int) time.Duration {
if failures > 5 {
failures = 5
}
base := time.Second * time.Duration(1<<failures)
if base > 30*time.Second {
base = 30 * time.Second
}
jitter := time.Duration(rand.Int64N(int64(base) / 2))
return base + jitter
}
func (c *externalCredential) connectorConnect() error {
destination := c.connectorResolveDestination()
conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination)
if err != nil {
return E.Cause(err, "dial")
}
if c.connectorTLS != nil {
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
err = tlsConn.HandshakeContext(c.reverseContext)
if err != nil {
conn.Close()
return E.Cause(err, "tls handshake")
}
conn = tlsConn
}
upgradeRequest := "GET /ccm/v1/reverse HTTP/1.1\r\n" +
"Host: " + c.connectorURL.Host + "\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: reverse-proxy\r\n" +
"Authorization: Bearer " + c.token + "\r\n" +
"\r\n"
_, err = io.WriteString(conn, upgradeRequest)
if err != nil {
conn.Close()
return E.Cause(err, "write upgrade request")
}
reader := bufio.NewReader(conn)
statusLine, err := reader.ReadString('\n')
if err != nil {
conn.Close()
return E.Cause(err, "read upgrade response")
}
if !strings.HasPrefix(statusLine, "HTTP/1.1 101") {
conn.Close()
return E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine))
}
for {
line, readErr := reader.ReadString('\n')
if readErr != nil {
conn.Close()
return E.Cause(readErr, "read upgrade headers")
}
if strings.TrimSpace(line) == "" {
break
}
}
session, err := yamux.Server(conn, reverseYamuxConfig())
if err != nil {
conn.Close()
return E.Cause(err, "create yamux server")
}
defer session.Close()
c.logger.Info("reverse connection established for ", c.tag)
httpServer := &http.Server{
Handler: c.reverseService,
ReadTimeout: 0,
IdleTimeout: 120 * time.Second,
}
err = httpServer.Serve(&yamuxNetListener{session: session})
if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil {
return E.Cause(err, "serve")
}
return E.New("connection closed")
}
func (c *externalCredential) connectorResolveDestination() M.Socksaddr {
port := c.connectorURL.Port()
if port == "" {
if c.connectorURL.Scheme == "https" {
port = "443"
} else {
port = "80"
}
}
return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port))
}

View File

@ -258,6 +258,9 @@ func (s *Service) Start(stage adapter.StartStage) error {
if err != nil {
return err
}
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
}
}
router := chi.NewRouter()
@ -318,6 +321,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if r.URL.Path == "/ccm/v1/reverse" {
s.handleReverseConnect(w, r)
return
}
if !strings.HasPrefix(r.URL.Path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
return
@ -786,6 +794,20 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use
headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64))
}
func (s *Service) InterfaceUpdated() {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
if !ok {
continue
}
if extCred.reverse && extCred.connectorURL != nil {
extCred.reverseCancel()
extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background())
go extCred.connectorLoop()
}
}
}
func (s *Service) Close() error {
err := common.Close(
common.PtrOrNil(s.httpServer),

View File

@ -21,8 +21,12 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/hashicorp/yamux"
)
const reverseProxyBaseURL = "http://reverse-proxy"
type externalCredential struct {
tag string
baseURL string
@ -41,86 +45,135 @@ type externalCredential struct {
requestContext context.Context
cancelRequests context.CancelFunc
requestAccess sync.Mutex
// Reverse proxy fields
reverse bool
reverseSession *yamux.Session
reverseAccess sync.RWMutex
reverseContext context.Context
reverseCancel context.CancelFunc
connectorDialer N.Dialer
connectorURL *url.URL
connectorTLS *stdTLS.Config
reverseService http.Handler
}
func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) {
parsedURL, err := url.Parse(options.URL)
if err != nil {
return nil, E.Cause(err, "parse url for credential ", tag)
}
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer for credential ", tag)
}
transport := &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.Server != "" {
serverPort := options.ServerPort
if serverPort == 0 {
portStr := parsedURL.Port()
if portStr != "" {
port, parseErr := strconv.ParseUint(portStr, 10, 16)
if parseErr == nil {
serverPort = uint16(port)
}
}
if serverPort == 0 {
if parsedURL.Scheme == "https" {
serverPort = 443
} else {
serverPort = 80
}
}
}
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
return credentialDialer.DialContext(ctx, network, destination)
}
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
if parsedURL.Scheme == "https" {
transport.TLSClientConfig = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
pollInterval := time.Duration(options.PollInterval)
if pollInterval <= 0 {
pollInterval = 30 * time.Minute
}
requestContext, cancelRequests := context.WithCancel(context.Background())
reverseContext, reverseCancel := context.WithCancel(context.Background())
cred := &externalCredential{
tag: tag,
baseURL: baseURL,
token: options.Token,
credDialer: credentialDialer,
httpClient: &http.Client{Transport: transport},
pollInterval: pollInterval,
logger: logger,
requestContext: requestContext,
cancelRequests: cancelRequests,
reverse: options.Reverse,
reverseContext: reverseContext,
reverseCancel: reverseCancel,
}
if options.URL == "" {
// Receiver mode: no URL, wait for reverse connection
cred.baseURL = reverseProxyBaseURL
cred.httpClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
session := cred.getReverseSession()
if session == nil || session.IsClosed() {
return nil, E.New("reverse connection not established for ", cred.tag)
}
return session.Open()
},
},
}
} else {
// Normal or connector mode: has URL
parsedURL, err := url.Parse(options.URL)
if err != nil {
return nil, E.Cause(err, "parse url for credential ", tag)
}
credentialDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer for credential ", tag)
}
transport := &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.Server != "" {
serverPort := options.ServerPort
if serverPort == 0 {
portStr := parsedURL.Port()
if portStr != "" {
port, parseErr := strconv.ParseUint(portStr, 10, 16)
if parseErr == nil {
serverPort = uint16(port)
}
}
if serverPort == 0 {
if parsedURL.Scheme == "https" {
serverPort = 443
} else {
serverPort = 80
}
}
}
destination := M.ParseSocksaddrHostPort(options.Server, serverPort)
return credentialDialer.DialContext(ctx, network, destination)
}
return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
if parsedURL.Scheme == "https" {
transport.TLSClientConfig = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
if parsedURL.Path != "" && parsedURL.Path != "/" {
baseURL += parsedURL.Path
}
if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
}
cred.baseURL = baseURL
if options.Reverse {
// Connector mode: we dial out to serve, not to proxy
cred.connectorDialer = credentialDialer
cred.connectorURL = parsedURL
if parsedURL.Scheme == "https" {
cred.connectorTLS = &stdTLS.Config{
ServerName: parsedURL.Hostname(),
RootCAs: adapter.RootPoolFromContext(ctx),
Time: ntp.TimeFuncFromContext(ctx),
}
}
} else {
// Normal mode: standard HTTP client for proxying
cred.credDialer = credentialDialer
cred.httpClient = &http.Client{Transport: transport}
}
}
if options.UsagesPath != "" {
@ -142,6 +195,9 @@ func (c *externalCredential) start() error {
c.logger.Warn("load usage statistics for ", c.tag, ": ", err)
}
}
if c.reverse && c.connectorURL != nil {
go c.connectorLoop()
}
return nil
}
@ -158,6 +214,14 @@ func (c *externalCredential) isExternal() bool {
}
func (c *externalCredential) isAvailable() bool {
if c.reverse && c.connectorURL != nil {
return false // connector mode: not for local proxying
}
if c.baseURL == reverseProxyBaseURL {
// receiver mode: only available when reverse connection active
session := c.getReverseSession()
return session != nil && !session.IsClosed()
}
return true
}
@ -461,6 +525,16 @@ func (c *externalCredential) ocmGetBaseURL() string {
}
func (c *externalCredential) close() {
if c.reverseCancel != nil {
c.reverseCancel()
}
c.reverseAccess.Lock()
session := c.reverseSession
c.reverseSession = nil
c.reverseAccess.Unlock()
if session != nil {
session.Close()
}
if c.usageTracker != nil {
c.usageTracker.cancelPendingSave()
err := c.usageTracker.Save()
@ -469,3 +543,27 @@ func (c *externalCredential) close() {
}
}
}
func (c *externalCredential) getReverseSession() *yamux.Session {
c.reverseAccess.RLock()
defer c.reverseAccess.RUnlock()
return c.reverseSession
}
func (c *externalCredential) setReverseSession(session *yamux.Session) {
c.reverseAccess.Lock()
old := c.reverseSession
c.reverseSession = session
c.reverseAccess.Unlock()
if old != nil {
old.Close()
}
}
func (c *externalCredential) clearReverseSession(session *yamux.Session) {
c.reverseAccess.Lock()
if c.reverseSession == session {
c.reverseSession = nil
}
c.reverseAccess.Unlock()
}

View File

@ -1176,12 +1176,12 @@ func validateOCMOptions(options option.OCMServiceOptions) error {
}
}
if cred.Type == "external" {
if cred.ExternalOptions.URL == "" {
return E.New("credential ", cred.Tag, ": external credential requires url")
}
if cred.ExternalOptions.Token == "" {
return E.New("credential ", cred.Tag, ": external credential requires token")
}
if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" {
return E.New("credential ", cred.Tag, ": reverse external credential requires url")
}
}
if cred.Type == "balancer" {
switch cred.BalancerOptions.Strategy {

243
service/ocm/reverse.go Normal file
View File

@ -0,0 +1,243 @@
package ocm
import (
"bufio"
stdTLS "crypto/tls"
"errors"
"io"
"math/rand/v2"
"net"
"net/http"
"strings"
"time"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/hashicorp/yamux"
)
func reverseYamuxConfig() *yamux.Config {
config := yamux.DefaultConfig()
config.KeepAliveInterval = 15 * time.Second
config.ConnectionWriteTimeout = 10 * time.Second
config.MaxStreamWindowSize = 512 * 1024
config.LogOutput = io.Discard
return config
}
type yamuxNetListener struct {
session *yamux.Session
}
func (l *yamuxNetListener) Accept() (net.Conn, error) {
return l.session.Accept()
}
func (l *yamuxNetListener) Close() error {
return l.session.Close()
}
func (l *yamuxNetListener) Addr() net.Addr {
return l.session.Addr()
}
func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "reverse-proxy" {
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header")
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
receiverCredential := s.findReceiverCredential(clientToken)
if receiverCredential == nil {
s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token")
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
s.logger.Error("reverse connect: hijack not supported")
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported")
return
}
conn, bufferedReadWriter, err := hijacker.Hijack()
if err != nil {
s.logger.Error("reverse connect: hijack: ", err)
return
}
response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n"
_, err = bufferedReadWriter.WriteString(response)
if err != nil {
conn.Close()
s.logger.Error("reverse connect: write upgrade response: ", err)
return
}
err = bufferedReadWriter.Flush()
if err != nil {
conn.Close()
s.logger.Error("reverse connect: flush upgrade response: ", err)
return
}
session, err := yamux.Client(conn, reverseYamuxConfig())
if err != nil {
conn.Close()
s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err)
return
}
receiverCredential.setReverseSession(session)
s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr)
go func() {
<-session.CloseChan()
receiverCredential.clearReverseSession(session)
s.logger.Warn("reverse connection lost for ", receiverCredential.tagName())
}()
}
func (s *Service) findReceiverCredential(token string) *externalCredential {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
if !ok {
continue
}
if extCred.baseURL == reverseProxyBaseURL && extCred.token == token {
return extCred
}
}
return nil
}
func (c *externalCredential) connectorLoop() {
var consecutiveFailures int
for {
select {
case <-c.reverseContext.Done():
return
default:
}
err := c.connectorConnect()
if c.reverseContext.Err() != nil {
return
}
consecutiveFailures++
backoff := connectorBackoff(consecutiveFailures)
c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff)
select {
case <-time.After(backoff):
case <-c.reverseContext.Done():
return
}
}
}
func connectorBackoff(failures int) time.Duration {
if failures > 5 {
failures = 5
}
base := time.Second * time.Duration(1<<failures)
if base > 30*time.Second {
base = 30 * time.Second
}
jitter := time.Duration(rand.Int64N(int64(base) / 2))
return base + jitter
}
func (c *externalCredential) connectorConnect() error {
destination := c.connectorResolveDestination()
conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination)
if err != nil {
return E.Cause(err, "dial")
}
if c.connectorTLS != nil {
tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone())
err = tlsConn.HandshakeContext(c.reverseContext)
if err != nil {
conn.Close()
return E.Cause(err, "tls handshake")
}
conn = tlsConn
}
upgradeRequest := "GET /ocm/v1/reverse HTTP/1.1\r\n" +
"Host: " + c.connectorURL.Host + "\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: reverse-proxy\r\n" +
"Authorization: Bearer " + c.token + "\r\n" +
"\r\n"
_, err = io.WriteString(conn, upgradeRequest)
if err != nil {
conn.Close()
return E.Cause(err, "write upgrade request")
}
reader := bufio.NewReader(conn)
statusLine, err := reader.ReadString('\n')
if err != nil {
conn.Close()
return E.Cause(err, "read upgrade response")
}
if !strings.HasPrefix(statusLine, "HTTP/1.1 101") {
conn.Close()
return E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine))
}
for {
line, readErr := reader.ReadString('\n')
if readErr != nil {
conn.Close()
return E.Cause(readErr, "read upgrade headers")
}
if strings.TrimSpace(line) == "" {
break
}
}
session, err := yamux.Server(conn, reverseYamuxConfig())
if err != nil {
conn.Close()
return E.Cause(err, "create yamux server")
}
defer session.Close()
c.logger.Info("reverse connection established for ", c.tag)
httpServer := &http.Server{
Handler: c.reverseService,
ReadTimeout: 0,
IdleTimeout: 120 * time.Second,
}
err = httpServer.Serve(&yamuxNetListener{session: session})
if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil {
return E.Cause(err, "serve")
}
return E.New("connection closed")
}
func (c *externalCredential) connectorResolveDestination() M.Socksaddr {
port := c.connectorURL.Port()
if port == "" {
if c.connectorURL.Scheme == "https" {
port = "443"
} else {
port = "80"
}
}
return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port))
}

View File

@ -305,6 +305,9 @@ func (s *Service) Start(stage adapter.StartStage) error {
cred.setOnBecameUnusable(func() {
s.interruptWebSocketSessionsForCredential(tag)
})
if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil {
extCred.reverseService = s
}
}
if len(s.options.Credentials) > 0 {
err := validateOCMCompositeCredentialModes(s.options, s.providers)
@ -364,6 +367,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if r.URL.Path == "/ocm/v1/reverse" {
s.handleReverseConnect(w, r)
return
}
path := r.URL.Path
if !strings.HasPrefix(path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
@ -860,6 +868,20 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use
headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64))
}
func (s *Service) InterfaceUpdated() {
for _, cred := range s.allCredentials {
extCred, ok := cred.(*externalCredential)
if !ok {
continue
}
if extCred.reverse && extCred.connectorURL != nil {
extCred.reverseCancel()
extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background())
go extCred.connectorLoop()
}
}
}
func (s *Service) Close() error {
webSocketSessions := s.startWebSocketShutdown()