diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index a78db5acc..5e08f0275 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -472,13 +472,14 @@ func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool { c.mu.Lock() goodState := c.loggedIn && c.inMapPoll ndu, canDelta := c.observer.(NetmapDeltaUpdater) + mapCtx := c.mapCtx c.mu.Unlock() if !goodState || !canDelta { return false } - ctx, cancel := context.WithTimeout(c.mapCtx, 2*time.Second) + ctx, cancel := context.WithTimeout(mapCtx, 2*time.Second) defer cancel() ch := make(chan bool, 1) @@ -508,11 +509,12 @@ func (mrs mapRoutineState) UpdatePacketFilter(rules views.Slice[tailcfg.FilterRu c.mu.Lock() goodState := c.loggedIn && c.inMapPoll pfu, ok := c.observer.(PacketFilterUpdater) + mapCtx := c.mapCtx c.mu.Unlock() if !goodState || !ok { return false } - ctx, cancel := context.WithTimeout(c.mapCtx, 2*time.Second) + ctx, cancel := context.WithTimeout(mapCtx, 2*time.Second) defer cancel() ch := make(chan bool, 1) c.observerQueue.Add(func() { @@ -536,11 +538,12 @@ func (mrs mapRoutineState) UpdateUserProfiles(profiles map[tailcfg.UserID]tailcf c.mu.Lock() goodState := c.loggedIn && c.inMapPoll upu, ok := c.observer.(UserProfileUpdater) + mapCtx := c.mapCtx c.mu.Unlock() if !goodState || !ok { return false } - ctx, cancel := context.WithTimeout(c.mapCtx, 2*time.Second) + ctx, cancel := context.WithTimeout(mapCtx, 2*time.Second) defer cancel() ch := make(chan bool, 1) c.observerQueue.Add(func() { @@ -561,13 +564,14 @@ func (mrs mapRoutineState) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPubl c.mu.Lock() goodState := c.loggedIn && c.inMapPoll dun, ok := c.observer.(patchDiscoKeyer) + mapCtx := c.mapCtx c.mu.Unlock() if !goodState || !ok { return } - ctx, cancel := context.WithTimeout(c.mapCtx, 2*time.Second) + ctx, cancel := context.WithTimeout(mapCtx, 2*time.Second) defer cancel() c.observerQueue.RunSync(ctx, func() { diff --git a/control/controlclient/auto_test.go b/control/controlclient/auto_test.go new file mode 100644 index 000000000..1c7e89521 --- /dev/null +++ b/control/controlclient/auto_test.go @@ -0,0 +1,66 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "context" + "sync" + "testing" + "time" + + "tailscale.com/tailcfg" +) + +type userProfileUpdateObserver struct{} + +func (userProfileUpdateObserver) SetControlClientStatus(Client, Status) {} + +func (userProfileUpdateObserver) UpdateUserProfiles(map[tailcfg.UserID]tailcfg.UserProfileView) bool { + return true +} + +func TestMapRoutineStateUpdateUserProfilesConcurrentCancelMapCtx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := &Auto{ + logf: func(string, ...any) {}, + observer: userProfileUpdateObserver{}, + mapCtx: ctx, + mapCancel: cancel, + loggedIn: true, + inMapPoll: true, + } + mrs := mapRoutineState{c: c} + + start := make(chan struct{}) + var wg sync.WaitGroup + for range 4 { + wg.Go(func() { + <-start + for range 2000 { + c.mu.Lock() + c.cancelMapCtxLocked() + c.mu.Unlock() + } + }) + } + for range 4 { + wg.Go(func() { + <-start + for range 2000 { + mrs.UpdateUserProfiles(nil) + } + }) + } + + close(start) + wg.Wait() + + waitCtx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + if err := c.observerQueue.Wait(waitCtx); err != nil { + t.Fatal(err) + } + c.observerQueue.Shutdown() + c.mapCancel() +}