diff --git a/request-interfaces.go b/request-interfaces.go index da56ed40..956b7a19 100644 --- a/request-interfaces.go +++ b/request-interfaces.go @@ -149,6 +149,11 @@ type ListerAt interface { ListAt([]os.FileInfo, int64) (int, error) } +// CloserListerAt is a ListerAt that implements the Close method. +type CloserListerAt interface { + Close() error +} + // TransferError is an optional interface that readerAt and writerAt // can implement to be notified about the error causing Serve() to exit // with the request still open diff --git a/request-server_test.go b/request-server_test.go index 6825bf43..ca911b89 100644 --- a/request-server_test.go +++ b/request-server_test.go @@ -793,6 +793,36 @@ func TestRequestReaddir(t *testing.T) { checkRequestServerAllocator(t, p) } +type testListerAtCloser struct { + isClosed bool +} + +func (l *testListerAtCloser) ListAt([]os.FileInfo, int64) (int, error) { + return 0, io.EOF +} + +func (l *testListerAtCloser) Close() error { + l.isClosed = true + return nil +} + +func TestRequestServerListerAtCloser(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + handle, err := p.cli.opendir(context.Background(), "/") + require.NoError(t, err) + require.Len(t, p.svr.openRequests, 1) + req, ok := p.svr.getRequest(handle) + require.True(t, ok) + listerAt := &testListerAtCloser{} + req.setListerAt(listerAt) + assert.NotNil(t, req.state.getListerAt()) + p.cli.close(handle) + require.Len(t, p.svr.openRequests, 0) + assert.True(t, listerAt.isClosed) +} + func TestRequestStatVFS(t *testing.T) { if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { t.Skip("StatVFS is implemented on linux and darwin") diff --git a/request.go b/request.go index 57d788df..49d7c5f6 100644 --- a/request.go +++ b/request.go @@ -111,6 +111,12 @@ func (s *state) setListerAt(la ListerAt) { s.mu.Lock() defer s.mu.Unlock() + if s.listerAt != nil { + if closer, ok := s.listerAt.(CloserListerAt); ok { + closer.Close() + } + } + s.listerAt = la } @@ -121,6 +127,18 @@ func (s *state) getListerAt() ListerAt { return s.listerAt } +func (s *state) closeListerAt() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.listerAt != nil { + if closer, ok := s.listerAt.(CloserListerAt); ok { + closer.Close() + } + s.listerAt = nil + } +} + // Request contains the data and state for the incoming service request. type Request struct { // Get, Put, Setstat, Stat, Rename, Remove @@ -229,6 +247,8 @@ func (r *Request) close() error { } }() + r.state.closeListerAt() + rd, wr, rw := r.getAllReaderWriters() var err error