diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go index 2068e4618..9d552c9ad 100644 --- a/internal/proxy/fuse_test.go +++ b/internal/proxy/fuse_test.go @@ -43,7 +43,7 @@ func randTmpDir(t interface { // newTestClient is a convenience function for testing that creates a // proxy.Client and starts it. The returned cleanup function is also a // convenience. Callers may choose to ignore it and manually close the client. -func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, func()) { +func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, chan error, func()) { conf := &proxy.Config{FUSEDir: fuseDir, FUSETempDir: fuseTempDir} c, err := proxy.NewClient(context.Background(), d, testLogger, conf) if err != nil { @@ -51,13 +51,21 @@ func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) } ready := make(chan struct{}) - go c.Serve(context.Background(), func() { close(ready) }) + servErrCh := make(chan error) + go func() { + servErr := c.Serve(context.Background(), func() { close(ready) }) + select { + case servErrCh <- servErr: + default: + // exit background thread + } + }() select { case <-ready: case <-time.Tick(5 * time.Second): t.Fatal("failed to Serve") } - return c, func() { + return c, servErrCh, func() { if cErr := c.Close(); cErr != nil { t.Logf("failed to close client: %v", cErr) } @@ -70,7 +78,7 @@ func TestFUSEREADME(t *testing.T) { } dir := randTmpDir(t) d := &fakeDialer{} - _, cleanup := newTestClient(t, d, dir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, d, dir, randTmpDir(t)) fi, err := os.Stat(dir) if err != nil { @@ -161,7 +169,7 @@ func TestFUSEDialInstance(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { d := &fakeDialer{} - _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir) + _, _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir) defer cleanup() conn := tryDialUnix(t, tc.socketPath) @@ -185,13 +193,66 @@ func TestFUSEDialInstance(t *testing.T) { }) } } +func TestFUSEAcceptErrorReturnedFromServe(t *testing.T) { + if testing.Short() { + t.Skip("skipping fuse tests in short mode.") + } + + fuseDir := randTmpDir(t) + fuseTempDir := randTmpDir(t) + socketPath := filepath.Join(fuseDir, "proj:region:mysql") + + // Create a new client + d := &fakeDialer{} + c, servErrCh, cleanup := newTestClient(t, d, fuseDir, fuseTempDir) + defer cleanup() + + // Attempt a successful connection to the client + conn := tryDialUnix(t, socketPath) + defer conn.Close() + + // Ensure that the client actually fully connected. + // This solves a race condition in the test that is only present on + // the Ubuntu-Latest platform. + var got []string + for i := 0; i < 10; i++ { + got = d.dialedInstances() + if len(got) == 1 { + break + } + time.Sleep(100 * time.Millisecond) + } + if len(got) != 1 { + t.Fatalf("dialed instances len: want = 1, got = %v", got) + } + + // Explicitly close the dialer. This will close all the unix sockets, forcing + // the unix socket accept goroutine to exit with an error + c.Close() + + // Check that Client.Serve() returned a non-nil error + for i := 0; i < 10; i++ { + select { + case servErr := <-servErrCh: + if servErr == nil { + t.Fatal("got nil, want non-nil error returned by Client.Serve()") + } + return + default: + time.Sleep(100 * time.Millisecond) + continue + } + } + t.Fatal("No error thrown by Client.Serve()") + +} func TestFUSEReadDir(t *testing.T) { if testing.Short() { t.Skip("skipping fuse tests in short mode.") } fuseDir := randTmpDir(t) - _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t)) defer cleanup() // Initiate a connection so the FUSE server will list it in the dir entries. @@ -221,7 +282,7 @@ func TestFUSEErrors(t *testing.T) { } ctx := context.Background() d := &fakeDialer{} - c, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t)) + c, _, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t)) // Simulate FUSE file access by invoking Lookup directly to control // how the socket cache is populated. @@ -261,7 +322,7 @@ func TestFUSEWithBadInstanceName(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) defer cleanup() _, dialErr := net.Dial("unix", filepath.Join(fuseDir, "notvalid")) @@ -280,7 +341,7 @@ func TestFUSECheckConnections(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - c, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + c, _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) defer cleanup() // first establish a connection to "register" it with the proxy @@ -315,7 +376,7 @@ func TestFUSEClose(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - c, _ := newTestClient(t, d, fuseDir, randTmpDir(t)) + c, _, _ := newTestClient(t, d, fuseDir, randTmpDir(t)) // first establish a connection to "register" it with the proxy conn := tryDialUnix(t, filepath.Join(fuseDir, "proj:reg:mysql")) diff --git a/internal/proxy/proxy_other.go b/internal/proxy/proxy_other.go index bac097ca4..2569e5a27 100644 --- a/internal/proxy/proxy_other.go +++ b/internal/proxy/proxy_other.go @@ -68,6 +68,7 @@ type fuseMount struct { fuseServerMu *sync.Mutex fuseServer *fuse.Server fuseWg *sync.WaitGroup + fuseExitCh chan error // Inode adds support for FUSE operations. fs.Inode @@ -131,10 +132,20 @@ func (c *Client) Lookup(ctx context.Context, instance string, _ *fuse.EntryOut) defer c.fuseWg.Done() sErr := c.serveSocketMount(ctx, s) if sErr != nil { - c.fuseMu.Lock() c.logger.Debugf("could not serve socket for instance %q: %v", instance, sErr) + c.fuseMu.Lock() + defer c.fuseMu.Unlock() delete(c.fuseSockets, instance) - c.fuseMu.Unlock() + select { + // Best effort attempt to send error. + // If this send fails, it means the reading goroutine has + // already pulled a value out of the channel and is no longer + // reading any more values. In other words, we report only the + // first error. + case c.fuseExitCh <- sErr: + default: + return + } } }() @@ -165,10 +176,27 @@ func (c *Client) serveFuse(ctx context.Context, notify func()) error { } c.fuseServerMu.Lock() c.fuseServer = srv + c.fuseExitCh = make(chan error) + + // When the context is canceled, put the context cancel error into the + // exit chanel + go func() { + <-ctx.Done() + select { + case c.fuseExitCh <- ctx.Err(): + default: + } + }() + c.fuseServerMu.Unlock() notify() - <-ctx.Done() - return ctx.Err() + select { + case err = <-c.fuseExitCh: + return err + default: + } + return nil + } func (c *Client) fuseMounts() []*socketMount {