From 16aeba8ec0a8febbbf1106710375af153575deb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 13 Mar 2026 18:51:02 +0800 Subject: [PATCH] ccm,ocm: add reverse proxy support for external credentials MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow two CCM/OCM instances to share credentials when only one has a public IP, using yamux-multiplexed reverse connections. Three credential modes: - Normal: URL set, reverse=false — standard HTTP proxy - Receiver: URL empty — waits for incoming reverse connection - Connector: URL set, reverse=true — dials out to establish connection Extend InterfaceUpdated to services so network changes trigger reverse connection reconnection. --- option/ccm.go | 3 +- option/ocm.go | 3 +- route/network.go | 11 ++ service/ccm/credential_external.go | 224 ++++++++++++++++++-------- service/ccm/credential_state.go | 6 +- service/ccm/reverse.go | 243 +++++++++++++++++++++++++++++ service/ccm/service.go | 22 +++ service/ocm/credential_external.go | 224 ++++++++++++++++++-------- service/ocm/credential_state.go | 6 +- service/ocm/reverse.go | 243 +++++++++++++++++++++++++++++ service/ocm/service.go | 22 +++ 11 files changed, 873 insertions(+), 134 deletions(-) create mode 100644 service/ccm/reverse.go create mode 100644 service/ocm/reverse.go diff --git a/option/ccm.go b/option/ccm.go index 6846dfccb..ae80cc64b 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -95,9 +95,10 @@ type CCMBalancerCredentialOptions struct { } type CCMExternalCredentialOptions struct { - URL string `json:"url"` + URL string `json:"url,omitempty"` ServerOptions Token string `json:"token"` + Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` diff --git a/option/ocm.go b/option/ocm.go index 4d495ff27..20cafee12 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -95,9 +95,10 @@ type OCMBalancerCredentialOptions struct { } type OCMExternalCredentialOptions struct { - URL string `json:"url"` + URL string `json:"url,omitempty"` ServerOptions Token string `json:"token"` + Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` diff --git a/route/network.go b/route/network.go index b8eefdc06..3f0cf57ca 100644 --- a/route/network.go +++ b/route/network.go @@ -51,6 +51,7 @@ type NetworkManager struct { endpoint adapter.EndpointManager inbound adapter.InboundManager outbound adapter.OutboundManager + serviceManager adapter.ServiceManager needWIFIState bool wifiMonitor settings.WIFIMonitor wifiState adapter.WIFIState @@ -94,6 +95,7 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, options endpoint: service.FromContext[adapter.EndpointManager](ctx), inbound: service.FromContext[adapter.InboundManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx), + serviceManager: service.FromContext[adapter.ServiceManager](ctx), needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), } if options.DefaultNetworkStrategy != nil { @@ -475,6 +477,15 @@ func (r *NetworkManager) ResetNetwork() { listener.InterfaceUpdated() } } + + if r.serviceManager != nil { + for _, svc := range r.serviceManager.Services() { + listener, isListener := svc.(adapter.InterfaceUpdateListener) + if isListener { + listener.InterfaceUpdated() + } + } + } } func (r *NetworkManager) notifyInterfaceUpdate(defaultInterface *control.Interface, flags int) { diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index f6560a2e6..e8e53c181 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -19,9 +19,14 @@ import ( "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" + + "github.com/hashicorp/yamux" ) +const reverseProxyBaseURL = "http://reverse-proxy" + type externalCredential struct { tag string baseURL string @@ -39,86 +44,134 @@ type externalCredential struct { requestContext context.Context cancelRequests context.CancelFunc requestAccess sync.Mutex + + // Reverse proxy fields + reverse bool + reverseSession *yamux.Session + reverseAccess sync.RWMutex + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler } func newExternalCredential(ctx context.Context, tag string, options option.CCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { - parsedURL, err := url.Parse(options.URL) - if err != nil { - return nil, E.Cause(err, "parse url for credential ", tag) - } - - credentialDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer for credential ", tag) - } - - transport := &http.Transport{ - ForceAttemptHTTP2: true, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if options.Server != "" { - serverPort := options.ServerPort - if serverPort == 0 { - portStr := parsedURL.Port() - if portStr != "" { - port, parseErr := strconv.ParseUint(portStr, 10, 16) - if parseErr == nil { - serverPort = uint16(port) - } - } - if serverPort == 0 { - if parsedURL.Scheme == "https" { - serverPort = 443 - } else { - serverPort = 80 - } - } - } - destination := M.ParseSocksaddrHostPort(options.Server, serverPort) - return credentialDialer.DialContext(ctx, network, destination) - } - return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - } - - if parsedURL.Scheme == "https" { - transport.TLSClientConfig = &stdTLS.Config{ - ServerName: parsedURL.Hostname(), - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - } - } - - baseURL := parsedURL.Scheme + "://" + parsedURL.Host - if parsedURL.Path != "" && parsedURL.Path != "/" { - baseURL += parsedURL.Path - } - // Strip trailing slash - if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { - baseURL = baseURL[:len(baseURL)-1] - } - pollInterval := time.Duration(options.PollInterval) if pollInterval <= 0 { pollInterval = 30 * time.Minute } requestContext, cancelRequests := context.WithCancel(context.Background()) + reverseContext, reverseCancel := context.WithCancel(context.Background()) cred := &externalCredential{ tag: tag, - baseURL: baseURL, token: options.Token, - httpClient: &http.Client{Transport: transport}, pollInterval: pollInterval, logger: logger, requestContext: requestContext, cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, + } + + if options.URL == "" { + // Receiver mode: no URL, wait for reverse connection + cred.baseURL = reverseProxyBaseURL + cred.httpClient = &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: false, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + session := cred.getReverseSession() + if session == nil || session.IsClosed() { + return nil, E.New("reverse connection not established for ", cred.tag) + } + return session.Open() + }, + }, + } + } else { + // Normal or connector mode: has URL + parsedURL, err := url.Parse(options.URL) + if err != nil { + return nil, E.Cause(err, "parse url for credential ", tag) + } + + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + + transport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if options.Server != "" { + serverPort := options.ServerPort + if serverPort == 0 { + portStr := parsedURL.Port() + if portStr != "" { + port, parseErr := strconv.ParseUint(portStr, 10, 16) + if parseErr == nil { + serverPort = uint16(port) + } + } + if serverPort == 0 { + if parsedURL.Scheme == "https" { + serverPort = 443 + } else { + serverPort = 80 + } + } + } + destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + return credentialDialer.DialContext(ctx, network, destination) + } + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + } + + if parsedURL.Scheme == "https" { + transport.TLSClientConfig = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + + cred.baseURL = baseURL + + if options.Reverse { + // Connector mode: we dial out to serve, not to proxy + cred.connectorDialer = credentialDialer + cred.connectorURL = parsedURL + if parsedURL.Scheme == "https" { + cred.connectorTLS = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + } else { + // Normal mode: standard HTTP client for proxying + cred.httpClient = &http.Client{Transport: transport} + } } if options.UsagesPath != "" { @@ -140,6 +193,9 @@ func (c *externalCredential) start() error { c.logger.Warn("load usage statistics for ", c.tag, ": ", err) } } + if c.reverse && c.connectorURL != nil { + go c.connectorLoop() + } return nil } @@ -152,6 +208,14 @@ func (c *externalCredential) isExternal() bool { } func (c *externalCredential) isAvailable() bool { + if c.reverse && c.connectorURL != nil { + return false // connector mode: not for local proxying + } + if c.baseURL == reverseProxyBaseURL { + // receiver mode: only available when reverse connection active + session := c.getReverseSession() + return session != nil && !session.IsClosed() + } return true } @@ -426,6 +490,16 @@ func (c *externalCredential) httpTransport() *http.Client { } func (c *externalCredential) close() { + if c.reverseCancel != nil { + c.reverseCancel() + } + c.reverseAccess.Lock() + session := c.reverseSession + c.reverseSession = nil + c.reverseAccess.Unlock() + if session != nil { + session.Close() + } if c.usageTracker != nil { c.usageTracker.cancelPendingSave() err := c.usageTracker.Save() @@ -434,3 +508,27 @@ func (c *externalCredential) close() { } } } + +func (c *externalCredential) getReverseSession() *yamux.Session { + c.reverseAccess.RLock() + defer c.reverseAccess.RUnlock() + return c.reverseSession +} + +func (c *externalCredential) setReverseSession(session *yamux.Session) { + c.reverseAccess.Lock() + old := c.reverseSession + c.reverseSession = session + c.reverseAccess.Unlock() + if old != nil { + old.Close() + } +} + +func (c *externalCredential) clearReverseSession(session *yamux.Session) { + c.reverseAccess.Lock() + if c.reverseSession == session { + c.reverseSession = nil + } + c.reverseAccess.Unlock() +} diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index fbde0e8ac..673af5c2e 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -1117,12 +1117,12 @@ func validateCCMOptions(options option.CCMServiceOptions) error { } } if cred.Type == "external" { - if cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": external credential requires url") - } if cred.ExternalOptions.Token == "" { return E.New("credential ", cred.Tag, ": external credential requires token") } + if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": reverse external credential requires url") + } } if cred.Type == "balancer" { switch cred.BalancerOptions.Strategy { diff --git a/service/ccm/reverse.go b/service/ccm/reverse.go new file mode 100644 index 000000000..571c8c55a --- /dev/null +++ b/service/ccm/reverse.go @@ -0,0 +1,243 @@ +package ccm + +import ( + "bufio" + stdTLS "crypto/tls" + "errors" + "io" + "math/rand/v2" + "net" + "net/http" + "strings" + "time" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "github.com/hashicorp/yamux" +) + +func reverseYamuxConfig() *yamux.Config { + config := yamux.DefaultConfig() + config.KeepAliveInterval = 15 * time.Second + config.ConnectionWriteTimeout = 10 * time.Second + config.MaxStreamWindowSize = 512 * 1024 + config.LogOutput = io.Discard + return config +} + +type yamuxNetListener struct { + session *yamux.Session +} + +func (l *yamuxNetListener) Accept() (net.Conn, error) { + return l.session.Accept() +} + +func (l *yamuxNetListener) Close() error { + return l.session.Close() +} + +func (l *yamuxNetListener) Addr() net.Addr { + return l.session.Addr() +} + +func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") != "reverse-proxy" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + + receiverCredential := s.findReceiverCredential(clientToken) + if receiverCredential == nil { + s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + s.logger.Error("reverse connect: hijack not supported") + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") + return + } + + conn, bufferedReadWriter, err := hijacker.Hijack() + if err != nil { + s.logger.Error("reverse connect: hijack: ", err) + return + } + + response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n" + _, err = bufferedReadWriter.WriteString(response) + if err != nil { + conn.Close() + s.logger.Error("reverse connect: write upgrade response: ", err) + return + } + err = bufferedReadWriter.Flush() + if err != nil { + conn.Close() + s.logger.Error("reverse connect: flush upgrade response: ", err) + return + } + + session, err := yamux.Client(conn, reverseYamuxConfig()) + if err != nil { + conn.Close() + s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) + return + } + + receiverCredential.setReverseSession(session) + s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) + + go func() { + <-session.CloseChan() + receiverCredential.clearReverseSession(session) + s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) + }() +} + +func (s *Service) findReceiverCredential(token string) *externalCredential { + for _, cred := range s.allCredentials { + extCred, ok := cred.(*externalCredential) + if !ok { + continue + } + if extCred.baseURL == reverseProxyBaseURL && extCred.token == token { + return extCred + } + } + return nil +} + +func (c *externalCredential) connectorLoop() { + var consecutiveFailures int + for { + select { + case <-c.reverseContext.Done(): + return + default: + } + + err := c.connectorConnect() + if c.reverseContext.Err() != nil { + return + } + consecutiveFailures++ + backoff := connectorBackoff(consecutiveFailures) + c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) + select { + case <-time.After(backoff): + case <-c.reverseContext.Done(): + return + } + } +} + +func connectorBackoff(failures int) time.Duration { + if failures > 5 { + failures = 5 + } + base := time.Second * time.Duration(1< 30*time.Second { + base = 30 * time.Second + } + jitter := time.Duration(rand.Int64N(int64(base) / 2)) + return base + jitter +} + +func (c *externalCredential) connectorConnect() error { + destination := c.connectorResolveDestination() + conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) + if err != nil { + return E.Cause(err, "dial") + } + + if c.connectorTLS != nil { + tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone()) + err = tlsConn.HandshakeContext(c.reverseContext) + if err != nil { + conn.Close() + return E.Cause(err, "tls handshake") + } + conn = tlsConn + } + + upgradeRequest := "GET /ccm/v1/reverse HTTP/1.1\r\n" + + "Host: " + c.connectorURL.Host + "\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: reverse-proxy\r\n" + + "Authorization: Bearer " + c.token + "\r\n" + + "\r\n" + _, err = io.WriteString(conn, upgradeRequest) + if err != nil { + conn.Close() + return E.Cause(err, "write upgrade request") + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + conn.Close() + return E.Cause(err, "read upgrade response") + } + if !strings.HasPrefix(statusLine, "HTTP/1.1 101") { + conn.Close() + return E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine)) + } + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + conn.Close() + return E.Cause(readErr, "read upgrade headers") + } + if strings.TrimSpace(line) == "" { + break + } + } + + session, err := yamux.Server(conn, reverseYamuxConfig()) + if err != nil { + conn.Close() + return E.Cause(err, "create yamux server") + } + defer session.Close() + + c.logger.Info("reverse connection established for ", c.tag) + + httpServer := &http.Server{ + Handler: c.reverseService, + ReadTimeout: 0, + IdleTimeout: 120 * time.Second, + } + err = httpServer.Serve(&yamuxNetListener{session: session}) + if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { + return E.Cause(err, "serve") + } + return E.New("connection closed") +} + +func (c *externalCredential) connectorResolveDestination() M.Socksaddr { + port := c.connectorURL.Port() + if port == "" { + if c.connectorURL.Scheme == "https" { + port = "443" + } else { + port = "80" + } + } + return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port)) +} diff --git a/service/ccm/service.go b/service/ccm/service.go index 4bd24b176..69697b5c0 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -258,6 +258,9 @@ func (s *Service) Start(stage adapter.StartStage) error { if err != nil { return err } + if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s + } } router := chi.NewRouter() @@ -318,6 +321,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if r.URL.Path == "/ccm/v1/reverse" { + s.handleReverseConnect(w, r) + return + } + if !strings.HasPrefix(r.URL.Path, "/v1/") { writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found") return @@ -786,6 +794,20 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) } +func (s *Service) InterfaceUpdated() { + for _, cred := range s.allCredentials { + extCred, ok := cred.(*externalCredential) + if !ok { + continue + } + if extCred.reverse && extCred.connectorURL != nil { + extCred.reverseCancel() + extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) + go extCred.connectorLoop() + } + } +} + func (s *Service) Close() error { err := common.Close( common.PtrOrNil(s.httpServer), diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index ca9664f1e..8226d6366 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -21,8 +21,12 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" + + "github.com/hashicorp/yamux" ) +const reverseProxyBaseURL = "http://reverse-proxy" + type externalCredential struct { tag string baseURL string @@ -41,86 +45,135 @@ type externalCredential struct { requestContext context.Context cancelRequests context.CancelFunc requestAccess sync.Mutex + + // Reverse proxy fields + reverse bool + reverseSession *yamux.Session + reverseAccess sync.RWMutex + reverseContext context.Context + reverseCancel context.CancelFunc + connectorDialer N.Dialer + connectorURL *url.URL + connectorTLS *stdTLS.Config + reverseService http.Handler } func newExternalCredential(ctx context.Context, tag string, options option.OCMExternalCredentialOptions, logger log.ContextLogger) (*externalCredential, error) { - parsedURL, err := url.Parse(options.URL) - if err != nil { - return nil, E.Cause(err, "parse url for credential ", tag) - } - - credentialDialer, err := dialer.NewWithOptions(dialer.Options{ - Context: ctx, - Options: option.DialerOptions{ - Detour: options.Detour, - }, - RemoteIsDomain: true, - }) - if err != nil { - return nil, E.Cause(err, "create dialer for credential ", tag) - } - - transport := &http.Transport{ - ForceAttemptHTTP2: true, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if options.Server != "" { - serverPort := options.ServerPort - if serverPort == 0 { - portStr := parsedURL.Port() - if portStr != "" { - port, parseErr := strconv.ParseUint(portStr, 10, 16) - if parseErr == nil { - serverPort = uint16(port) - } - } - if serverPort == 0 { - if parsedURL.Scheme == "https" { - serverPort = 443 - } else { - serverPort = 80 - } - } - } - destination := M.ParseSocksaddrHostPort(options.Server, serverPort) - return credentialDialer.DialContext(ctx, network, destination) - } - return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - } - - if parsedURL.Scheme == "https" { - transport.TLSClientConfig = &stdTLS.Config{ - ServerName: parsedURL.Hostname(), - RootCAs: adapter.RootPoolFromContext(ctx), - Time: ntp.TimeFuncFromContext(ctx), - } - } - - baseURL := parsedURL.Scheme + "://" + parsedURL.Host - if parsedURL.Path != "" && parsedURL.Path != "/" { - baseURL += parsedURL.Path - } - if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { - baseURL = baseURL[:len(baseURL)-1] - } - pollInterval := time.Duration(options.PollInterval) if pollInterval <= 0 { pollInterval = 30 * time.Minute } requestContext, cancelRequests := context.WithCancel(context.Background()) + reverseContext, reverseCancel := context.WithCancel(context.Background()) cred := &externalCredential{ tag: tag, - baseURL: baseURL, token: options.Token, - credDialer: credentialDialer, - httpClient: &http.Client{Transport: transport}, pollInterval: pollInterval, logger: logger, requestContext: requestContext, cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, + } + + if options.URL == "" { + // Receiver mode: no URL, wait for reverse connection + cred.baseURL = reverseProxyBaseURL + cred.httpClient = &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: false, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + session := cred.getReverseSession() + if session == nil || session.IsClosed() { + return nil, E.New("reverse connection not established for ", cred.tag) + } + return session.Open() + }, + }, + } + } else { + // Normal or connector mode: has URL + parsedURL, err := url.Parse(options.URL) + if err != nil { + return nil, E.Cause(err, "parse url for credential ", tag) + } + + credentialDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create dialer for credential ", tag) + } + + transport := &http.Transport{ + ForceAttemptHTTP2: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if options.Server != "" { + serverPort := options.ServerPort + if serverPort == 0 { + portStr := parsedURL.Port() + if portStr != "" { + port, parseErr := strconv.ParseUint(portStr, 10, 16) + if parseErr == nil { + serverPort = uint16(port) + } + } + if serverPort == 0 { + if parsedURL.Scheme == "https" { + serverPort = 443 + } else { + serverPort = 80 + } + } + } + destination := M.ParseSocksaddrHostPort(options.Server, serverPort) + return credentialDialer.DialContext(ctx, network, destination) + } + return credentialDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + } + + if parsedURL.Scheme == "https" { + transport.TLSClientConfig = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + if parsedURL.Path != "" && parsedURL.Path != "/" { + baseURL += parsedURL.Path + } + if len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' { + baseURL = baseURL[:len(baseURL)-1] + } + + cred.baseURL = baseURL + + if options.Reverse { + // Connector mode: we dial out to serve, not to proxy + cred.connectorDialer = credentialDialer + cred.connectorURL = parsedURL + if parsedURL.Scheme == "https" { + cred.connectorTLS = &stdTLS.Config{ + ServerName: parsedURL.Hostname(), + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + } + } + } else { + // Normal mode: standard HTTP client for proxying + cred.credDialer = credentialDialer + cred.httpClient = &http.Client{Transport: transport} + } } if options.UsagesPath != "" { @@ -142,6 +195,9 @@ func (c *externalCredential) start() error { c.logger.Warn("load usage statistics for ", c.tag, ": ", err) } } + if c.reverse && c.connectorURL != nil { + go c.connectorLoop() + } return nil } @@ -158,6 +214,14 @@ func (c *externalCredential) isExternal() bool { } func (c *externalCredential) isAvailable() bool { + if c.reverse && c.connectorURL != nil { + return false // connector mode: not for local proxying + } + if c.baseURL == reverseProxyBaseURL { + // receiver mode: only available when reverse connection active + session := c.getReverseSession() + return session != nil && !session.IsClosed() + } return true } @@ -461,6 +525,16 @@ func (c *externalCredential) ocmGetBaseURL() string { } func (c *externalCredential) close() { + if c.reverseCancel != nil { + c.reverseCancel() + } + c.reverseAccess.Lock() + session := c.reverseSession + c.reverseSession = nil + c.reverseAccess.Unlock() + if session != nil { + session.Close() + } if c.usageTracker != nil { c.usageTracker.cancelPendingSave() err := c.usageTracker.Save() @@ -469,3 +543,27 @@ func (c *externalCredential) close() { } } } + +func (c *externalCredential) getReverseSession() *yamux.Session { + c.reverseAccess.RLock() + defer c.reverseAccess.RUnlock() + return c.reverseSession +} + +func (c *externalCredential) setReverseSession(session *yamux.Session) { + c.reverseAccess.Lock() + old := c.reverseSession + c.reverseSession = session + c.reverseAccess.Unlock() + if old != nil { + old.Close() + } +} + +func (c *externalCredential) clearReverseSession(session *yamux.Session) { + c.reverseAccess.Lock() + if c.reverseSession == session { + c.reverseSession = nil + } + c.reverseAccess.Unlock() +} diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 92745492d..b663632af 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -1176,12 +1176,12 @@ func validateOCMOptions(options option.OCMServiceOptions) error { } } if cred.Type == "external" { - if cred.ExternalOptions.URL == "" { - return E.New("credential ", cred.Tag, ": external credential requires url") - } if cred.ExternalOptions.Token == "" { return E.New("credential ", cred.Tag, ": external credential requires token") } + if cred.ExternalOptions.Reverse && cred.ExternalOptions.URL == "" { + return E.New("credential ", cred.Tag, ": reverse external credential requires url") + } } if cred.Type == "balancer" { switch cred.BalancerOptions.Strategy { diff --git a/service/ocm/reverse.go b/service/ocm/reverse.go new file mode 100644 index 000000000..b02a20222 --- /dev/null +++ b/service/ocm/reverse.go @@ -0,0 +1,243 @@ +package ocm + +import ( + "bufio" + stdTLS "crypto/tls" + "errors" + "io" + "math/rand/v2" + "net" + "net/http" + "strings" + "time" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "github.com/hashicorp/yamux" +) + +func reverseYamuxConfig() *yamux.Config { + config := yamux.DefaultConfig() + config.KeepAliveInterval = 15 * time.Second + config.ConnectionWriteTimeout = 10 * time.Second + config.MaxStreamWindowSize = 512 * 1024 + config.LogOutput = io.Discard + return config +} + +type yamuxNetListener struct { + session *yamux.Session +} + +func (l *yamuxNetListener) Accept() (net.Conn, error) { + return l.session.Accept() +} + +func (l *yamuxNetListener) Close() error { + return l.session.Close() +} + +func (l *yamuxNetListener) Addr() net.Addr { + return l.session.Addr() +} + +func (s *Service) handleReverseConnect(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") != "reverse-proxy" { + writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error", "missing Upgrade header") + return + } + + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key") + return + } + clientToken := strings.TrimPrefix(authHeader, "Bearer ") + if clientToken == authHeader { + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format") + return + } + + receiverCredential := s.findReceiverCredential(clientToken) + if receiverCredential == nil { + s.logger.Warn("reverse connect failed from ", r.RemoteAddr, ": no matching receiver credential") + writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid reverse token") + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + s.logger.Error("reverse connect: hijack not supported") + writeJSONError(w, r, http.StatusInternalServerError, "api_error", "hijack not supported") + return + } + + conn, bufferedReadWriter, err := hijacker.Hijack() + if err != nil { + s.logger.Error("reverse connect: hijack: ", err) + return + } + + response := "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: reverse-proxy\r\n\r\n" + _, err = bufferedReadWriter.WriteString(response) + if err != nil { + conn.Close() + s.logger.Error("reverse connect: write upgrade response: ", err) + return + } + err = bufferedReadWriter.Flush() + if err != nil { + conn.Close() + s.logger.Error("reverse connect: flush upgrade response: ", err) + return + } + + session, err := yamux.Client(conn, reverseYamuxConfig()) + if err != nil { + conn.Close() + s.logger.Error("reverse connect: create yamux client for ", receiverCredential.tagName(), ": ", err) + return + } + + receiverCredential.setReverseSession(session) + s.logger.Info("reverse connection established for ", receiverCredential.tagName(), " from ", r.RemoteAddr) + + go func() { + <-session.CloseChan() + receiverCredential.clearReverseSession(session) + s.logger.Warn("reverse connection lost for ", receiverCredential.tagName()) + }() +} + +func (s *Service) findReceiverCredential(token string) *externalCredential { + for _, cred := range s.allCredentials { + extCred, ok := cred.(*externalCredential) + if !ok { + continue + } + if extCred.baseURL == reverseProxyBaseURL && extCred.token == token { + return extCred + } + } + return nil +} + +func (c *externalCredential) connectorLoop() { + var consecutiveFailures int + for { + select { + case <-c.reverseContext.Done(): + return + default: + } + + err := c.connectorConnect() + if c.reverseContext.Err() != nil { + return + } + consecutiveFailures++ + backoff := connectorBackoff(consecutiveFailures) + c.logger.Warn("reverse connection for ", c.tag, " lost: ", err, ", reconnecting in ", backoff) + select { + case <-time.After(backoff): + case <-c.reverseContext.Done(): + return + } + } +} + +func connectorBackoff(failures int) time.Duration { + if failures > 5 { + failures = 5 + } + base := time.Second * time.Duration(1< 30*time.Second { + base = 30 * time.Second + } + jitter := time.Duration(rand.Int64N(int64(base) / 2)) + return base + jitter +} + +func (c *externalCredential) connectorConnect() error { + destination := c.connectorResolveDestination() + conn, err := c.connectorDialer.DialContext(c.reverseContext, "tcp", destination) + if err != nil { + return E.Cause(err, "dial") + } + + if c.connectorTLS != nil { + tlsConn := stdTLS.Client(conn, c.connectorTLS.Clone()) + err = tlsConn.HandshakeContext(c.reverseContext) + if err != nil { + conn.Close() + return E.Cause(err, "tls handshake") + } + conn = tlsConn + } + + upgradeRequest := "GET /ocm/v1/reverse HTTP/1.1\r\n" + + "Host: " + c.connectorURL.Host + "\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: reverse-proxy\r\n" + + "Authorization: Bearer " + c.token + "\r\n" + + "\r\n" + _, err = io.WriteString(conn, upgradeRequest) + if err != nil { + conn.Close() + return E.Cause(err, "write upgrade request") + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + conn.Close() + return E.Cause(err, "read upgrade response") + } + if !strings.HasPrefix(statusLine, "HTTP/1.1 101") { + conn.Close() + return E.New("unexpected upgrade response: ", strings.TrimSpace(statusLine)) + } + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + conn.Close() + return E.Cause(readErr, "read upgrade headers") + } + if strings.TrimSpace(line) == "" { + break + } + } + + session, err := yamux.Server(conn, reverseYamuxConfig()) + if err != nil { + conn.Close() + return E.Cause(err, "create yamux server") + } + defer session.Close() + + c.logger.Info("reverse connection established for ", c.tag) + + httpServer := &http.Server{ + Handler: c.reverseService, + ReadTimeout: 0, + IdleTimeout: 120 * time.Second, + } + err = httpServer.Serve(&yamuxNetListener{session: session}) + if err != nil && !errors.Is(err, http.ErrServerClosed) && c.reverseContext.Err() == nil { + return E.Cause(err, "serve") + } + return E.New("connection closed") +} + +func (c *externalCredential) connectorResolveDestination() M.Socksaddr { + port := c.connectorURL.Port() + if port == "" { + if c.connectorURL.Scheme == "https" { + port = "443" + } else { + port = "80" + } + } + return M.ParseSocksaddr(net.JoinHostPort(c.connectorURL.Hostname(), port)) +} diff --git a/service/ocm/service.go b/service/ocm/service.go index 74fa776d8..50a44db89 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -305,6 +305,9 @@ func (s *Service) Start(stage adapter.StartStage) error { cred.setOnBecameUnusable(func() { s.interruptWebSocketSessionsForCredential(tag) }) + if extCred, ok := cred.(*externalCredential); ok && extCred.reverse && extCred.connectorURL != nil { + extCred.reverseService = s + } } if len(s.options.Credentials) > 0 { err := validateOCMCompositeCredentialModes(s.options, s.providers) @@ -364,6 +367,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if r.URL.Path == "/ocm/v1/reverse" { + s.handleReverseConnect(w, r) + return + } + path := r.URL.Path if !strings.HasPrefix(path, "/v1/") { writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/") @@ -860,6 +868,20 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) } +func (s *Service) InterfaceUpdated() { + for _, cred := range s.allCredentials { + extCred, ok := cred.(*externalCredential) + if !ok { + continue + } + if extCred.reverse && extCred.connectorURL != nil { + extCred.reverseCancel() + extCred.reverseContext, extCred.reverseCancel = context.WithCancel(context.Background()) + go extCred.connectorLoop() + } + } +} + func (s *Service) Close() error { webSocketSessions := s.startWebSocketShutdown()