diff --git a/cli-plugins/socket/socket_test.go b/cli-plugins/socket/socket_test.go index 7d786abfc6..7938766b88 100644 --- a/cli-plugins/socket/socket_test.go +++ b/cli-plugins/socket/socket_test.go @@ -178,24 +178,39 @@ func TestConnectAndWait(t *testing.T) { // TODO: this test cannot be executed with `t.Parallel()`, due to // relying on goroutine numbers to ensure correct behaviour t.Run("connect goroutine exits after EOF", func(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() srv, err := NewPluginServer(nil) assert.NilError(t, err, "failed to setup server") defer srv.Close() t.Setenv(EnvKey, srv.Addr().String()) + + runtime.Gosched() numGoroutines := runtime.NumGoroutine() ConnectAndWait(func() {}) - assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1) + + runtime.Gosched() + poll.WaitOn(t, func(t poll.LogT) poll.Result { + // +1 goroutine for the poll.WaitOn + // +1 goroutine for the connect goroutine + if runtime.NumGoroutine() < numGoroutines+1+1 { + return poll.Continue("waiting for connect goroutine to spawn") + } + return poll.Success() + }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(500*time.Millisecond)) srv.Close() + runtime.Gosched() poll.WaitOn(t, func(t poll.LogT) poll.Result { + // +1 goroutine for the poll.WaitOn if runtime.NumGoroutine() > numGoroutines+1 { return poll.Continue("waiting for connect goroutine to exit") } return poll.Success() - }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) + }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(500*time.Millisecond)) }) }