From 0809b45eb4ff2b108f2e5091fc7eae7609e77645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 24 Apr 2026 08:27:18 +0800 Subject: [PATCH] fix usbip control resync and darwin bounds --- .github/workflows/test.yml | 2 +- service/usbip/client_control.go | 24 +++++ service/usbip/client_darwin.go | 44 ++++++-- service/usbip/client_linux.go | 10 +- service/usbip/darwin_integration_test.go | 23 +++++ service/usbip/linux_test.go | 124 +++++++++++++++++++++++ service/usbip/server_darwin.go | 3 - service/usbip/server_linux.go | 1 - 8 files changed, 214 insertions(+), 17 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1e400be75..cdb57e6b9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,7 +50,7 @@ jobs: if: matrix.os == 'ubuntu-latest' run: | sudo apt-get update - sudo apt-get install -y usbip + sudo apt-get install -y linux-tools-common - name: Test (unix) if: matrix.os != 'windows-latest' run: go test -v -exec sudo -tags "$BUILD_TAGS" -ldflags "$LDFLAGS_SHARED" ./... diff --git a/service/usbip/client_control.go b/service/usbip/client_control.go index a7cd97056..3fb7f9e2b 100644 --- a/service/usbip/client_control.go +++ b/service/usbip/client_control.go @@ -206,3 +206,27 @@ func (c *ClientService) clearControlDeviceState() { c.remoteDevicesV2 = nil c.remoteMu.Unlock() } + +func (c *ClientService) syncRemoteStateAndResetControlState(ctx context.Context) error { + entries, err := c.fetchDevList(ctx) + if err != nil { + return err + } + c.resetControlDeviceStateFromEntries(entries) + c.applyRemoteEntries(entries) + return nil +} + +func (c *ClientService) resetControlDeviceStateFromEntries(entries []DeviceEntry) { + devices := make(map[string]DeviceInfoV2, len(entries)) + for _, entry := range entries { + device := deviceInfoV2FromEntry(entry, "", "", deviceStateAvailable, 0, "available") + if device.BusID == "" { + continue + } + devices[device.BusID] = device + } + c.remoteMu.Lock() + c.remoteDevicesV2 = devices + c.remoteMu.Unlock() +} diff --git a/service/usbip/client_darwin.go b/service/usbip/client_darwin.go index 7b3bcfdd7..f0d5363be 100644 --- a/service/usbip/client_darwin.go +++ b/service/usbip/client_darwin.go @@ -239,7 +239,12 @@ func (c *ClientService) runControlSession() error { return E.Cause(errImmediateReconnect, "control sequence jumped from ", lastSeq, " to ", frame.Sequence) } lastSeq = frame.Sequence - if err := c.syncRemoteState(); err != nil { + if extended { + err = c.syncRemoteStateAndResetControlState(c.ctx) + } else { + err = c.syncRemoteState() + } + if err != nil { return E.Cause(errImmediateReconnect, "devlist sync after change ", frame.Sequence, ": ", err) } case controlFrameDeviceSnapshot: @@ -257,10 +262,9 @@ func (c *ClientService) runControlSession() error { return E.Cause(errImmediateReconnect, "unexpected control frame ", frame.Type) } if frame.Sequence != lastSeq+1 { - if err := c.syncRemoteState(); err != nil { + if err := c.syncRemoteStateAndResetControlState(c.ctx); err != nil { return E.Cause(errImmediateReconnect, "devlist sync after sequence jump ", frame.Sequence, ": ", err) } - c.clearControlDeviceState() lastSeq = frame.Sequence continue } @@ -1054,8 +1058,8 @@ func (c *darwinVirtualController) handleControlDataTransfer(key darwinEndpointKe if err != nil { return -int32(unix.EIO), 0 } - if direction == USBIPDirIn && len(response.Buffer) > 0 { - copyToUnsafe(message.bufferPointer(), response.Buffer) + if direction == USBIPDirIn { + return c.completeSubmitInTransfer(message.bufferPointer(), response, length) } return response.Status, int(response.ActualLength) } @@ -1110,8 +1114,8 @@ func (c *darwinVirtualController) handleNormalTransfer(key darwinEndpointKey, me if err != nil { return -int32(unix.EIO), 0 } - if direction == USBIPDirIn && len(response.Buffer) > 0 { - copyToUnsafe(message.bufferPointer(), response.Buffer) + if direction == USBIPDirIn { + return c.completeSubmitInTransfer(message.bufferPointer(), response, length) } return response.Status, int(response.ActualLength) } @@ -1144,12 +1148,34 @@ func (c *darwinVirtualController) handleIsoTransfer(key darwinEndpointKey, messa if err != nil { return -int32(unix.EIO), 0 } - if direction == USBIPDirIn && len(response.Buffer) > 0 { - copyToUnsafe(message.bufferPointer(), response.Buffer) + if direction == USBIPDirIn { + return c.completeSubmitInTransfer(message.bufferPointer(), response, length) } return response.Status, int(response.ActualLength) } +func (c *darwinVirtualController) completeSubmitInTransfer(ptr unsafe.Pointer, response SubmitResponse, requestLength int) (int32, int) { + if response.ActualLength < 0 { + c.logger.Debug("RET_SUBMIT actual_length is negative: ", response.ActualLength) + c.requestClose() + return -int32(unix.EPROTO), 0 + } + actualLength := int(response.ActualLength) + if actualLength > requestLength || len(response.Buffer) > requestLength { + c.logger.Debug("RET_SUBMIT actual_length ", actualLength, " exceeds request length ", requestLength) + c.requestClose() + return -int32(unix.EOVERFLOW), 0 + } + copyLength := actualLength + if copyLength > len(response.Buffer) { + copyLength = len(response.Buffer) + } + if copyLength > 0 { + copyToUnsafe(ptr, response.Buffer[:copyLength]) + } + return response.Status, actualLength +} + func (c *darwinVirtualController) sendSubmit(command SubmitCommand) (SubmitResponse, error) { seq := c.seq.Add(1) command.Header.SeqNum = seq diff --git a/service/usbip/client_linux.go b/service/usbip/client_linux.go index 5fb7f7174..883bcd8b6 100644 --- a/service/usbip/client_linux.go +++ b/service/usbip/client_linux.go @@ -262,7 +262,12 @@ func (c *ClientService) runControlSession() error { return E.Cause(errImmediateReconnect, "control sequence jumped from ", lastSeq, " to ", frame.Sequence) } lastSeq = frame.Sequence - if err := c.syncRemoteState(); err != nil { + if extended { + err = c.syncRemoteStateAndResetControlState(c.ctx) + } else { + err = c.syncRemoteState() + } + if err != nil { return E.Cause(errImmediateReconnect, "devlist sync after change ", frame.Sequence, ": ", err) } case controlFrameDeviceSnapshot: @@ -280,10 +285,9 @@ func (c *ClientService) runControlSession() error { return E.Cause(errImmediateReconnect, "unexpected control frame ", frame.Type) } if frame.Sequence != lastSeq+1 { - if err := c.syncRemoteState(); err != nil { + if err := c.syncRemoteStateAndResetControlState(c.ctx); err != nil { return E.Cause(errImmediateReconnect, "devlist sync after sequence jump ", frame.Sequence, ": ", err) } - c.clearControlDeviceState() lastSeq = frame.Sequence continue } diff --git a/service/usbip/darwin_integration_test.go b/service/usbip/darwin_integration_test.go index f0b0eedd8..bf4af8b03 100644 --- a/service/usbip/darwin_integration_test.go +++ b/service/usbip/darwin_integration_test.go @@ -21,6 +21,7 @@ import ( M "github.com/sagernet/sing/common/metadata" "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" ) const ( @@ -176,6 +177,28 @@ func TestDarwinVirtualControllerReadsCompliantSubmitResponsePayload(t *testing.T } } +func TestDarwinSubmitInTransferRejectsOversizedPayload(t *testing.T) { + t.Parallel() + + controller := newDarwinVirtualController(context.Background(), newTestLogger(), nil, DeviceInfoTruncated{}) + buffer := []byte{0xaa, 0xbb} + + status, length := controller.completeSubmitInTransfer(unsafe.Pointer(&buffer[0]), SubmitResponse{ + Status: 0, + ActualLength: 3, + Buffer: []byte{1, 2, 3}, + }, len(buffer)) + + require.Equal(t, -int32(unix.EOVERFLOW), status) + require.Zero(t, length) + require.Equal(t, []byte{0xaa, 0xbb}, buffer) + select { + case <-controller.ctx.Done(): + default: + t.Fatal("controller context stayed active after oversized payload") + } +} + func TestWaitDarwinControllerClosesOnContextCancel(t *testing.T) { t.Parallel() diff --git a/service/usbip/linux_test.go b/service/usbip/linux_test.go index 4137da8a4..701913cab 100644 --- a/service/usbip/linux_test.go +++ b/service/usbip/linux_test.go @@ -1327,6 +1327,75 @@ func TestServerReconcileBroadcastsStatusOnlyDeviceDelta(t *testing.T) { require.True(t, netErr.Timeout()) } +func TestServerControlSnapshotPreservesPendingDelta(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + device := newTestDevice("1-1", 0x1d6b, 0x0002, "serial-1", SpeedHigh) + store := newTestDeviceStore(device) + store.setStatus("1-1", usbipStatusUsed) + + serverOps := newTestUSBIPOps(t) + serverOps.listUSBDevices = store.listUSBDevices + serverOps.readUsbipStatus = store.readUsbipStatus + serverOps.readSysfsDevice = store.readSysfsDevice + + server := &ServerService{ + ctx: ctx, + cancel: cancel, + logger: newTestLogger(), + matches: []option.USBIPDeviceMatch{{BusID: "1-1"}}, + exports: map[string]serverExport{"1-1": {busid: "1-1"}}, + controlSubs: make(map[uint64]*serverControlConn), + ops: serverOps, + } + server.refreshControlState() + serverAddr, closeServer := startDispatchServer(t, server) + defer closeServer() + + firstConn, err := net.Dial("tcp", serverAddr.String()) + require.NoError(t, err) + defer firstConn.Close() + setConnDeadline(t, firstConn) + require.NoError(t, WriteControlPreface(firstConn)) + require.NoError(t, WriteControlHello(firstConn)) + _, err = ReadControlFrame(firstConn) + require.NoError(t, err) + firstSnapshot, err := readControlMessage(firstConn) + require.NoError(t, err) + require.Equal(t, controlFrameDeviceSnapshot, firstSnapshot.Frame.Type) + + store.setStatus("1-1", usbipStatusAvailable) + + secondConn, err := net.Dial("tcp", serverAddr.String()) + require.NoError(t, err) + defer secondConn.Close() + setConnDeadline(t, secondConn) + require.NoError(t, WriteControlPreface(secondConn)) + require.NoError(t, WriteControlHello(secondConn)) + _, err = ReadControlFrame(secondConn) + require.NoError(t, err) + secondSnapshotMessage, err := readControlMessage(secondConn) + require.NoError(t, err) + require.Equal(t, controlFrameDeviceSnapshot, secondSnapshotMessage.Frame.Type) + var secondSnapshot controlDeviceSnapshot + require.NoError(t, unmarshalControlPayload(secondSnapshotMessage.Payload, &secondSnapshot)) + require.Len(t, secondSnapshot.Devices, 1) + require.Equal(t, deviceStateAvailable, secondSnapshot.Devices[0].State) + + require.NoError(t, server.reconcileAndBroadcast(true)) + changed, err := readControlMessage(firstConn) + require.NoError(t, err) + require.Equal(t, controlFrameDeviceDelta, changed.Frame.Type) + var delta controlDeviceDelta + require.NoError(t, unmarshalControlPayload(changed.Payload, &delta)) + require.Len(t, delta.Updated, 1) + require.Equal(t, "1-1", delta.Updated[0].BusID) + require.Equal(t, deviceStateAvailable, delta.Updated[0].State) +} + func TestServerControlLeaseEnablesImportExt(t *testing.T) { t.Parallel() @@ -1960,6 +2029,61 @@ func TestClientFetchDevListReturnsOnContextCancelWhileServerStalls(t *testing.T) require.NoError(t, <-serverErr) } +func TestClientSyncRemoteStateAndResetControlStateRebuildsV2Map(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + device := newTestDevice("1-1", 0x1d6b, 0x0002, "serial-1", SpeedHigh) + entry := DeviceEntry{Info: device.toProtocol(), Interfaces: device.Interfaces} + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + serverErr := make(chan error, 1) + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + serverErr <- acceptErr + return + } + defer conn.Close() + header, readErr := ReadOpHeader(conn) + if readErr != nil { + serverErr <- readErr + return + } + if header.Code != OpReqDevList { + serverErr <- fmt.Errorf("unexpected request code 0x%s", hex16(header.Code)) + return + } + serverErr <- WriteOpRepDevList(conn, []DeviceEntry{entry}) + }() + + client := &ClientService{ + ctx: ctx, + cancel: cancel, + logger: newTestLogger(), + dialer: testDialer{}, + serverAddr: M.SocksaddrFromNet(listener.Addr()), + matches: []option.USBIPDeviceMatch{{BusID: "unused"}}, + ops: newTestUSBIPOps(t), + remoteDevicesV2: map[string]DeviceInfoV2{"stale": {BusID: "stale", State: deviceStateAvailable}}, + } + + require.NoError(t, client.syncRemoteStateAndResetControlState(ctx)) + require.NoError(t, <-serverErr) + + client.remoteMu.Lock() + devices := client.remoteDevicesV2 + client.remoteMu.Unlock() + require.Len(t, devices, 1) + require.Contains(t, devices, "1-1") + require.Equal(t, deviceStateAvailable, devices["1-1"].State) + require.Equal(t, uint16(0x1d6b), devices["1-1"].VendorID) +} + func TestClientAttemptAttachRejectsUnexpectedReplyVersion(t *testing.T) { t.Parallel() diff --git a/service/usbip/server_darwin.go b/service/usbip/server_darwin.go index 83e6d5aab..b7d09aab7 100644 --- a/service/usbip/server_darwin.go +++ b/service/usbip/server_darwin.go @@ -544,9 +544,6 @@ func (s *ServerService) enqueueControlPayload(sub *serverControlConn, frame cont func (s *ServerService) enqueueControlSnapshot(sub *serverControlConn, sequence uint64) { devices := s.buildDeviceStateV2() - s.controlMu.Lock() - s.controlState = deviceInfoV2Map(devices) - s.controlMu.Unlock() s.enqueueControlPayload(sub, controlFrame{ Type: controlFrameDeviceSnapshot, Version: controlProtocolVersion, diff --git a/service/usbip/server_linux.go b/service/usbip/server_linux.go index 3009857af..5ad8a578b 100644 --- a/service/usbip/server_linux.go +++ b/service/usbip/server_linux.go @@ -718,7 +718,6 @@ func (s *ServerService) enqueueControlPayload(sub *serverControlConn, frame cont func (s *ServerService) enqueueControlSnapshot(sub *serverControlConn, sequence uint64) { devices := s.buildDeviceStateV2() - s.setControlState(deviceInfoV2Map(devices)) s.enqueueControlPayload(sub, controlFrame{ Type: controlFrameDeviceSnapshot, Version: controlProtocolVersion,