net/traffic,ipn/ipnlocal: extract traffic steering utilities (#19682)

The traffic package contains helpers for evaluating traffic steering
scores and picking appropriate nodes. These were extracted from
ipnlocal.suggestExitNodeUsingTrafficSteering so they can be reused by
the new routecheck package to probe exit nodes in priority order.

Updates #17366
Updates tailscale/corp#33033

Signed-off-by: Simon Law <sfllaw@tailscale.com>
This commit is contained in:
Simon Law 2026-05-21 08:28:27 -07:00 committed by GitHub
parent dbe92f98b5
commit 7ebca58042
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 276 additions and 43 deletions

View File

@ -813,6 +813,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/net/stun from tailscale.com/ipn/localapi+
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial
tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal
tailscale.com/net/tsaddr from tailscale.com/client/web+
tailscale.com/net/tsdial from tailscale.com/control/controlclient+
💣 tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy

View File

@ -112,6 +112,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/net/stun from tailscale.com/net/netcheck+
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial
tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal
tailscale.com/net/tsaddr from tailscale.com/ipn+
tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+
tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+

View File

@ -129,6 +129,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/net/stun from tailscale.com/net/netcheck+
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial
tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal
tailscale.com/net/tsaddr from tailscale.com/ipn+
tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+
tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+

View File

@ -378,6 +378,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/net/stun from tailscale.com/ipn/localapi+
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial
tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal
tailscale.com/net/tsaddr from tailscale.com/client/web+
tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+
💣 tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy

View File

@ -212,6 +212,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar
tailscale.com/net/stun from tailscale.com/ipn/localapi+
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial
tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal
tailscale.com/net/tsaddr from tailscale.com/client/web+
tailscale.com/net/tsdial from tailscale.com/control/controlclient+
💣 tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy

View File

@ -9,8 +9,6 @@
"bufio"
"cmp"
"context"
"crypto/sha256"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
@ -64,6 +62,7 @@
"tailscale.com/net/netns"
"tailscale.com/net/netutil"
"tailscale.com/net/packet"
"tailscale.com/net/traffic"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tsdial"
"tailscale.com/paths"
@ -8133,42 +8132,18 @@ func suggestExitNodeUsingTrafficSteering(nb *nodeBackend, allowed set.Set[tailcf
return true
})
scores := make(map[tailcfg.NodeID]int, len(nodes))
score := func(n tailcfg.NodeView) int {
id := n.ID()
s, ok := scores[id]
if !ok {
s = 0 // score of zero means incomparable
if hi := n.Hostinfo(); hi.Valid() {
if loc := hi.Location(); loc.Valid() {
s = loc.Priority()
}
}
scores[id] = s
}
return s
}
rdvHash := makeRendezvousHasher(self.ID())
ss := traffic.ScoresFor(self.ID(), nodes)
var pick tailcfg.NodeView
if len(nodes) == 1 {
pick = nodes[0]
}
if len(nodes) > 1 {
// Find the highest scoring exit nodes.
slices.SortFunc(nodes, func(a, b tailcfg.NodeView) int {
c := cmp.Compare(score(b), score(a)) // Highest score first.
if c == 0 {
// Rendezvous hashing for reliably picking the
// same node from a list: tailscale/tailscale#16551.
return cmp.Compare(rdvHash(b.ID()), rdvHash(a.ID()))
}
return c
})
ss.SortNodes(nodes)
// TODO(sfllaw): add a temperature knob so that this client has
// a chance of picking the next best option.
pick = nodes[0]
pick = nodes[0] // Pick the highest score.
}
nb.logf("netmap: traffic steering: exit node scores: %v", logger.ArgWriter(func(bw *bufio.Writer) {
@ -8182,7 +8157,7 @@ func suggestExitNodeUsingTrafficSteering(nb *nodeBackend, allowed set.Set[tailcf
bw.WriteString(", ")
}
name, _, _ := strings.Cut(n.Name(), ".")
fmt.Fprintf(bw, "%d:%s", score(n), name)
fmt.Fprintf(bw, "%d:%s", ss.Score(n), name)
}
}))
@ -8284,19 +8259,6 @@ func longLatDistance(fromLat, fromLong, toLat, toLong float64) float64 {
return earthRadiusMeters * c
}
// makeRendezvousHasher returns a function that hashes a node ID to a uint64.
// https://en.wikipedia.org/wiki/Rendezvous_hashing
func makeRendezvousHasher(seed tailcfg.NodeID) func(tailcfg.NodeID) uint64 {
en := binary.BigEndian
return func(n tailcfg.NodeID) uint64 {
var b [16]byte
en.PutUint64(b[:], uint64(seed))
en.PutUint64(b[8:], uint64(n))
v := sha256.Sum256(b[:])
return en.Uint64(v[:])
}
}
const (
// unresolvedExitNodeID is a special [tailcfg.StableNodeID] value
// used as an exit node ID to install a blackhole route, preventing

126
net/traffic/traffic.go Normal file
View File

@ -0,0 +1,126 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
// Package traffic contains helpers for evaluating traffic steering scores and
// picking appropriate nodes.
package traffic
import (
"cmp"
"crypto/sha256"
"encoding/binary"
"iter"
"maps"
"slices"
"tailscale.com/tailcfg"
"tailscale.com/util/mak"
)
// Score is a nodes traffic score, where any int could be a valid score.
// A higher traffic score suggests that the client should prefer that peer
// over one with a lower traffic score.
type Score int
// Scores is a memoization cache for the traffic scores of the current nodes peers.
type Scores struct {
self tailcfg.NodeID
hash NodeHasher
scores map[tailcfg.NodeID]Score
}
// ScoresFor returns a new [Scores] cache for the current nodes ID,
// after scoring the peer nodes and adding these scores to the cache.
func ScoresFor(self tailcfg.NodeID, peers []tailcfg.NodeView) Scores {
ss := Scores{
self: self,
hash: MakeRendezvousHasher(self),
}
ss.ScorePeers(peers)
return ss
}
// IsValid reports whether ss has been initialized with the current node ID.
func (ss Scores) IsValid() bool {
return !ss.self.IsZero()
}
// Score scores the given peer node and returns it after adding the score to the cache.
func (ss *Scores) Score(n tailcfg.NodeView) Score {
id := n.ID()
if s, ok := ss.scores[id]; ok {
return s
}
var s Score
if hi := n.Hostinfo(); hi.Valid() {
if loc := hi.Location(); loc.Valid() {
s = Score(loc.Priority())
}
}
mak.Set(&ss.scores, id, s)
return s
}
// ScorePeers scores the peer nodes and adds these scores to the cache.
func (ss *Scores) ScorePeers(peers []tailcfg.NodeView) {
if len(peers) == 0 {
return
}
if ss.scores == nil {
ss.scores = make(map[tailcfg.NodeID]Score, len(peers))
}
for _, n := range peers {
ss.Score(n)
}
}
// All returns an iterator over the scores for every peer in the cache.
// The iteration order is not specified and is not guaranteed to be the same
// from one call to the next.
func (ss Scores) All() iter.Seq2[tailcfg.NodeID, Score] {
return maps.All(ss.scores)
}
// SortNodes sorts the slice of nodes in descending order of [Scores.Score],
// using rendezvous hashing to break ties when both nodes have the same score.
// After sorting, the zeroth element is the preferred node.
func (ss Scores) SortNodes(nodes []tailcfg.NodeView) {
slices.SortFunc(nodes, func(a, b tailcfg.NodeView) int {
c := cmp.Compare(ss.Score(b), ss.Score(a)) // Highest score first.
if c == 0 {
return ss.hash.Compare(b.ID(), a.ID()) // Descending order.
}
return c
})
}
// NodeHasher returns a 64-bit hash of a node ID.
type NodeHasher func(tailcfg.NodeID) uint64
// MakeRendezvousHasher returns a function that hashes a node ID to a uint64.
// https://en.wikipedia.org/wiki/Rendezvous_hashing
func MakeRendezvousHasher(seed tailcfg.NodeID) NodeHasher {
en := binary.BigEndian
return func(n tailcfg.NodeID) uint64 {
var b [16]byte
en.PutUint64(b[:], uint64(seed))
en.PutUint64(b[8:], uint64(n))
v := sha256.Sum256(b[:])
return en.Uint64(v[:])
}
}
// Compare compares the node ID hashes of peers a and b, using the same convention as [cmp.Compare].
// Since h is seeded with the current nodes ID, the ordering between a and b will remain stable
// for this node; but the order may flip for when h is seeded for another node.
// This function should return zero, if and only if a and b have the same node ID.
func (h NodeHasher) Compare(a, b tailcfg.NodeID) int {
c := cmp.Compare(h(a), h(b))
if c == 0 {
// In the unlikely event of a hash collision, compare the actual IDs.
return cmp.Compare(a, b)
}
return c
}

139
net/traffic/traffic_test.go Normal file
View File

@ -0,0 +1,139 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package traffic_test
import (
"maps"
"testing"
gocmp "github.com/google/go-cmp/cmp"
"tailscale.com/net/traffic"
"tailscale.com/tailcfg"
)
// WantScores is a convenience alias for the type of [traffic.Score.scores].
type wantScores = map[tailcfg.NodeID]traffic.Score
var scoresCases = []struct {
name string
peers []*tailcfg.Node
want wantScores
}{
{
name: "none",
peers: nil,
want: wantScores{},
},
{
name: "no-scores",
peers: []*tailcfg.Node{
{ID: 37},
{ID: 42},
},
want: wantScores{
37: 0,
42: 0,
},
},
{
name: "mixed-scores",
peers: []*tailcfg.Node{
{ID: 37},
{
ID: 42,
Hostinfo: (&tailcfg.Hostinfo{
Location: &tailcfg.Location{Priority: 1},
}).View(),
},
},
want: wantScores{
37: 0,
42: 1,
},
},
}
func TestScoreOne(t *testing.T) {
for _, tc := range scoresCases {
if len(tc.peers) == 0 {
continue
}
t.Run(tc.name, func(t *testing.T) {
selfID := tailcfg.NodeID(1)
ss := traffic.ScoresFor(selfID, nil)
for _, n := range tc.peers {
want := tc.want[n.ID]
score := ss.Score(n.View())
if score != want {
t.Errorf("initial Score for nodeid:%d: score %d, want %d", n.ID, score, want)
}
score = ss.Score(n.View())
if score != want {
t.Errorf("subsequent Score for nodeid:%d: score %d, want %d", n.ID, score, want)
}
}
got := maps.Collect(ss.All())
if diff := gocmp.Diff(tc.want, got); diff != "" {
t.Errorf("-want +got:\n%s", diff)
}
})
}
}
func TestScoreMany(t *testing.T) {
for _, tc := range scoresCases {
t.Run(tc.name, func(t *testing.T) {
selfID := tailcfg.NodeID(1)
var peers []tailcfg.NodeView
for _, n := range tc.peers {
peers = append(peers, n.View())
}
t.Run("ScoresFor", func(t *testing.T) {
ss := traffic.ScoresFor(selfID, peers)
got := maps.Collect(ss.All())
if diff := gocmp.Diff(tc.want, got); diff != "" {
t.Errorf("-want +got:\n%s", diff)
}
})
t.Run("ScorePeers", func(t *testing.T) {
ss := traffic.ScoresFor(selfID, nil)
ss.ScorePeers(peers)
got := maps.Collect(ss.All())
if diff := gocmp.Diff(tc.want, got); diff != "" {
t.Errorf("-want +got:\n%s", diff)
}
})
})
}
}
func FuzzNodeHasherCompare(f *testing.F) {
for _, seed := range [][]uint64{
{0, 0, 0},
{1, 1, 1},
{1, 10, 11},
{1, 11, 10},
{2, 10, 11},
} {
selfID, aID, bID := seed[0], seed[1], seed[2]
f.Add(selfID, aID, bID)
}
f.Fuzz(func(t *testing.T, selfID, aID, bID uint64) {
t.Logf("selfID %d, aID %d, bID %d", selfID, aID, bID)
h := traffic.MakeRendezvousHasher(tailcfg.NodeID(selfID))
a, b := tailcfg.NodeID(aID), tailcfg.NodeID(bID)
c := h.Compare(a, b)
if c == 0 && a != b {
t.Fatalf("got %d: expected different hashes because a ≠ b, ", c)
}
if cc := h.Compare(a, b); c != cc {
t.Fatalf("c %d, cc %d: expected matching values", c, cc)
}
if d := h.Compare(b, a); c != -d {
t.Fatalf("c %d, d %d: expected inverse values", c, d)
}
})
}

View File

@ -208,6 +208,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware)
tailscale.com/net/stun from tailscale.com/ipn/localapi+
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial
tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal
tailscale.com/net/tsaddr from tailscale.com/client/web+
tailscale.com/net/tsdial from tailscale.com/control/controlclient+
💣 tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy