diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 12073da0b..5e8361175 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -813,6 +813,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/net/stun from tailscale.com/ipn/localapi+ tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/control/controlclient+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy diff --git a/cmd/tailscaled/depaware-min.txt b/cmd/tailscaled/depaware-min.txt index 8f0c34cf1..9b9003c85 100644 --- a/cmd/tailscaled/depaware-min.txt +++ b/cmd/tailscaled/depaware-min.txt @@ -112,6 +112,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/stun from tailscale.com/net/netcheck+ tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal tailscale.com/net/tsaddr from tailscale.com/ipn+ tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ diff --git a/cmd/tailscaled/depaware-minbox.txt b/cmd/tailscaled/depaware-minbox.txt index 994310d60..90d410664 100644 --- a/cmd/tailscaled/depaware-minbox.txt +++ b/cmd/tailscaled/depaware-minbox.txt @@ -129,6 +129,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/stun from tailscale.com/net/netcheck+ tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal tailscale.com/net/tsaddr from tailscale.com/ipn+ tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 7e0e95be8..84dde50c1 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -378,6 +378,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/stun from tailscale.com/ipn/localapi+ tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy diff --git a/cmd/tsidp/depaware.txt b/cmd/tsidp/depaware.txt index cf1a4c279..a5cac442e 100644 --- a/cmd/tsidp/depaware.txt +++ b/cmd/tsidp/depaware.txt @@ -212,6 +212,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/net/stun from tailscale.com/ipn/localapi+ tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/control/controlclient+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index f211d965a..481436fbe 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -9,8 +9,6 @@ "bufio" "cmp" "context" - "crypto/sha256" - "encoding/binary" "encoding/json" "errors" "fmt" @@ -64,6 +62,7 @@ "tailscale.com/net/netns" "tailscale.com/net/netutil" "tailscale.com/net/packet" + "tailscale.com/net/traffic" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/paths" @@ -8133,42 +8132,18 @@ func suggestExitNodeUsingTrafficSteering(nb *nodeBackend, allowed set.Set[tailcf return true }) - scores := make(map[tailcfg.NodeID]int, len(nodes)) - score := func(n tailcfg.NodeView) int { - id := n.ID() - s, ok := scores[id] - if !ok { - s = 0 // score of zero means incomparable - if hi := n.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc.Valid() { - s = loc.Priority() - } - } - scores[id] = s - } - return s - } - rdvHash := makeRendezvousHasher(self.ID()) + ss := traffic.ScoresFor(self.ID(), nodes) var pick tailcfg.NodeView if len(nodes) == 1 { pick = nodes[0] } if len(nodes) > 1 { - // Find the highest scoring exit nodes. - slices.SortFunc(nodes, func(a, b tailcfg.NodeView) int { - c := cmp.Compare(score(b), score(a)) // Highest score first. - if c == 0 { - // Rendezvous hashing for reliably picking the - // same node from a list: tailscale/tailscale#16551. - return cmp.Compare(rdvHash(b.ID()), rdvHash(a.ID())) - } - return c - }) + ss.SortNodes(nodes) // TODO(sfllaw): add a temperature knob so that this client has // a chance of picking the next best option. - pick = nodes[0] + pick = nodes[0] // Pick the highest score. } nb.logf("netmap: traffic steering: exit node scores: %v", logger.ArgWriter(func(bw *bufio.Writer) { @@ -8182,7 +8157,7 @@ func suggestExitNodeUsingTrafficSteering(nb *nodeBackend, allowed set.Set[tailcf bw.WriteString(", ") } name, _, _ := strings.Cut(n.Name(), ".") - fmt.Fprintf(bw, "%d:%s", score(n), name) + fmt.Fprintf(bw, "%d:%s", ss.Score(n), name) } })) @@ -8284,19 +8259,6 @@ func longLatDistance(fromLat, fromLong, toLat, toLong float64) float64 { return earthRadiusMeters * c } -// makeRendezvousHasher returns a function that hashes a node ID to a uint64. -// https://en.wikipedia.org/wiki/Rendezvous_hashing -func makeRendezvousHasher(seed tailcfg.NodeID) func(tailcfg.NodeID) uint64 { - en := binary.BigEndian - return func(n tailcfg.NodeID) uint64 { - var b [16]byte - en.PutUint64(b[:], uint64(seed)) - en.PutUint64(b[8:], uint64(n)) - v := sha256.Sum256(b[:]) - return en.Uint64(v[:]) - } -} - const ( // unresolvedExitNodeID is a special [tailcfg.StableNodeID] value // used as an exit node ID to install a blackhole route, preventing diff --git a/net/traffic/traffic.go b/net/traffic/traffic.go new file mode 100644 index 000000000..84b3efbfb --- /dev/null +++ b/net/traffic/traffic.go @@ -0,0 +1,126 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package traffic contains helpers for evaluating traffic steering scores and +// picking appropriate nodes. +package traffic + +import ( + "cmp" + "crypto/sha256" + "encoding/binary" + "iter" + "maps" + "slices" + + "tailscale.com/tailcfg" + "tailscale.com/util/mak" +) + +// Score is a node’s traffic score, where any int could be a valid score. +// A higher traffic score suggests that the client should prefer that peer +// over one with a lower traffic score. +type Score int + +// Scores is a memoization cache for the traffic scores of the current node’s peers. +type Scores struct { + self tailcfg.NodeID + hash NodeHasher + + scores map[tailcfg.NodeID]Score +} + +// ScoresFor returns a new [Scores] cache for the current node’s ID, +// after scoring the peer nodes and adding these scores to the cache. +func ScoresFor(self tailcfg.NodeID, peers []tailcfg.NodeView) Scores { + ss := Scores{ + self: self, + hash: MakeRendezvousHasher(self), + } + ss.ScorePeers(peers) + return ss +} + +// IsValid reports whether ss has been initialized with the current node ID. +func (ss Scores) IsValid() bool { + return !ss.self.IsZero() +} + +// Score scores the given peer node and returns it after adding the score to the cache. +func (ss *Scores) Score(n tailcfg.NodeView) Score { + id := n.ID() + if s, ok := ss.scores[id]; ok { + return s + } + + var s Score + if hi := n.Hostinfo(); hi.Valid() { + if loc := hi.Location(); loc.Valid() { + s = Score(loc.Priority()) + } + } + mak.Set(&ss.scores, id, s) + return s +} + +// ScorePeers scores the peer nodes and adds these scores to the cache. +func (ss *Scores) ScorePeers(peers []tailcfg.NodeView) { + if len(peers) == 0 { + return + } + if ss.scores == nil { + ss.scores = make(map[tailcfg.NodeID]Score, len(peers)) + } + for _, n := range peers { + ss.Score(n) + } +} + +// All returns an iterator over the scores for every peer in the cache. +// The iteration order is not specified and is not guaranteed to be the same +// from one call to the next. +func (ss Scores) All() iter.Seq2[tailcfg.NodeID, Score] { + return maps.All(ss.scores) +} + +// SortNodes sorts the slice of nodes in descending order of [Scores.Score], +// using rendezvous hashing to break ties when both nodes have the same score. +// After sorting, the zeroth element is the preferred node. +func (ss Scores) SortNodes(nodes []tailcfg.NodeView) { + slices.SortFunc(nodes, func(a, b tailcfg.NodeView) int { + c := cmp.Compare(ss.Score(b), ss.Score(a)) // Highest score first. + if c == 0 { + return ss.hash.Compare(b.ID(), a.ID()) // Descending order. + } + return c + }) +} + +// NodeHasher returns a 64-bit hash of a node ID. +type NodeHasher func(tailcfg.NodeID) uint64 + +// MakeRendezvousHasher returns a function that hashes a node ID to a uint64. +// https://en.wikipedia.org/wiki/Rendezvous_hashing +func MakeRendezvousHasher(seed tailcfg.NodeID) NodeHasher { + en := binary.BigEndian + return func(n tailcfg.NodeID) uint64 { + var b [16]byte + en.PutUint64(b[:], uint64(seed)) + en.PutUint64(b[8:], uint64(n)) + v := sha256.Sum256(b[:]) + return en.Uint64(v[:]) + } +} + +// Compare compares the node ID hashes of peers a and b, using the same convention as [cmp.Compare]. +// Since h is seeded with the current node’s ID, the ordering between a and b will remain stable +// for this node; but the order may flip for when h is seeded for another node. +// This function should return zero, if and only if a and b have the same node ID. +func (h NodeHasher) Compare(a, b tailcfg.NodeID) int { + c := cmp.Compare(h(a), h(b)) + if c == 0 { + // In the unlikely event of a hash collision, compare the actual IDs. + return cmp.Compare(a, b) + } + return c +} diff --git a/net/traffic/traffic_test.go b/net/traffic/traffic_test.go new file mode 100644 index 000000000..cae4801a3 --- /dev/null +++ b/net/traffic/traffic_test.go @@ -0,0 +1,139 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package traffic_test + +import ( + "maps" + "testing" + + gocmp "github.com/google/go-cmp/cmp" + "tailscale.com/net/traffic" + "tailscale.com/tailcfg" +) + +// WantScores is a convenience alias for the type of [traffic.Score.scores]. +type wantScores = map[tailcfg.NodeID]traffic.Score + +var scoresCases = []struct { + name string + peers []*tailcfg.Node + want wantScores +}{ + { + name: "none", + peers: nil, + want: wantScores{}, + }, + { + name: "no-scores", + peers: []*tailcfg.Node{ + {ID: 37}, + {ID: 42}, + }, + want: wantScores{ + 37: 0, + 42: 0, + }, + }, + { + name: "mixed-scores", + peers: []*tailcfg.Node{ + {ID: 37}, + { + ID: 42, + Hostinfo: (&tailcfg.Hostinfo{ + Location: &tailcfg.Location{Priority: 1}, + }).View(), + }, + }, + want: wantScores{ + 37: 0, + 42: 1, + }, + }, +} + +func TestScoreOne(t *testing.T) { + for _, tc := range scoresCases { + if len(tc.peers) == 0 { + continue + } + t.Run(tc.name, func(t *testing.T) { + selfID := tailcfg.NodeID(1) + ss := traffic.ScoresFor(selfID, nil) + for _, n := range tc.peers { + want := tc.want[n.ID] + score := ss.Score(n.View()) + if score != want { + t.Errorf("initial Score for nodeid:%d: score %d, want %d", n.ID, score, want) + } + score = ss.Score(n.View()) + if score != want { + t.Errorf("subsequent Score for nodeid:%d: score %d, want %d", n.ID, score, want) + } + } + got := maps.Collect(ss.All()) + if diff := gocmp.Diff(tc.want, got); diff != "" { + t.Errorf("-want +got:\n%s", diff) + } + }) + } +} + +func TestScoreMany(t *testing.T) { + for _, tc := range scoresCases { + t.Run(tc.name, func(t *testing.T) { + selfID := tailcfg.NodeID(1) + var peers []tailcfg.NodeView + for _, n := range tc.peers { + peers = append(peers, n.View()) + } + + t.Run("ScoresFor", func(t *testing.T) { + ss := traffic.ScoresFor(selfID, peers) + got := maps.Collect(ss.All()) + if diff := gocmp.Diff(tc.want, got); diff != "" { + t.Errorf("-want +got:\n%s", diff) + } + }) + + t.Run("ScorePeers", func(t *testing.T) { + ss := traffic.ScoresFor(selfID, nil) + ss.ScorePeers(peers) + got := maps.Collect(ss.All()) + if diff := gocmp.Diff(tc.want, got); diff != "" { + t.Errorf("-want +got:\n%s", diff) + } + }) + }) + } +} + +func FuzzNodeHasherCompare(f *testing.F) { + for _, seed := range [][]uint64{ + {0, 0, 0}, + {1, 1, 1}, + {1, 10, 11}, + {1, 11, 10}, + {2, 10, 11}, + } { + selfID, aID, bID := seed[0], seed[1], seed[2] + f.Add(selfID, aID, bID) + } + f.Fuzz(func(t *testing.T, selfID, aID, bID uint64) { + t.Logf("selfID %d, aID %d, bID %d", selfID, aID, bID) + h := traffic.MakeRendezvousHasher(tailcfg.NodeID(selfID)) + a, b := tailcfg.NodeID(aID), tailcfg.NodeID(bID) + c := h.Compare(a, b) + if c == 0 && a != b { + t.Fatalf("got %d: expected different hashes because a ≠ b, ", c) + } + if cc := h.Compare(a, b); c != cc { + t.Fatalf("c %d, cc %d: expected matching values", c, cc) + } + if d := h.Compare(b, a); c != -d { + t.Fatalf("c %d, d %d: expected inverse values", c, d) + } + }) +} diff --git a/tsnet/depaware.txt b/tsnet/depaware.txt index a4eed2a13..817043896 100644 --- a/tsnet/depaware.txt +++ b/tsnet/depaware.txt @@ -208,6 +208,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/net/stun from tailscale.com/ipn/localapi+ tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/traffic from tailscale.com/ipn/ipnlocal tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/control/controlclient+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy