From 3aa53a572f6b9edd84a438457cdb1f5a31700196 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Mon, 13 Nov 2023 09:13:33 +0100 Subject: [PATCH] sendPacket with context --- client.go | 51 ++++++++++++++++++++++++-------------------------- conn.go | 12 +++++++++--- server_test.go | 3 ++- 3 files changed, 35 insertions(+), 31 deletions(-) diff --git a/client.go b/client.go index 8532ade0..a3b8e22b 100644 --- a/client.go +++ b/client.go @@ -333,7 +333,7 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { // The passed context can be used to cancel the operation // returning all entries listed up to the cancellation. func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, error) { - handle, err := c.opendir(p) + handle, err := c.opendir(ctx, p) if err != nil { return nil, err } @@ -341,11 +341,8 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e var entries []os.FileInfo var done = false for !done { - if err = ctx.Err(); err != nil { - return entries, err - } id := c.nextID() - typ, data, err1 := c.sendPacket(nil, &sshFxpReaddirPacket{ + typ, data, err1 := c.sendPacket(ctx, nil, &sshFxpReaddirPacket{ ID: id, Handle: handle, }) @@ -386,9 +383,9 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e return entries, err } -func (c *Client) opendir(path string) (string, error) { +func (c *Client) opendir(ctx context.Context, path string) (string, error) { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpOpendirPacket{ + typ, data, err := c.sendPacket(ctx, nil, &sshFxpOpendirPacket{ ID: id, Path: path, }) @@ -424,7 +421,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) { // If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link. func (c *Client) Lstat(p string) (os.FileInfo, error) { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpLstatPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpLstatPacket{ ID: id, Path: p, }) @@ -449,7 +446,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { // ReadLink reads the target of a symbolic link. func (c *Client) ReadLink(p string) (string, error) { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpReadlinkPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpReadlinkPacket{ ID: id, Path: p, }) @@ -478,7 +475,7 @@ func (c *Client) ReadLink(p string) (string, error) { // Link creates a hard link at 'newname', pointing at the same inode as 'oldname' func (c *Client) Link(oldname, newname string) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpHardlinkPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpHardlinkPacket{ ID: id, Oldpath: oldname, Newpath: newname, @@ -497,7 +494,7 @@ func (c *Client) Link(oldname, newname string) error { // Symlink creates a symbolic link at 'newname', pointing at target 'oldname' func (c *Client) Symlink(oldname, newname string) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpSymlinkPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSymlinkPacket{ ID: id, Linkpath: newname, Targetpath: oldname, @@ -515,7 +512,7 @@ func (c *Client) Symlink(oldname, newname string) error { func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpFsetstatPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{ ID: id, Handle: handle, Flags: flags, @@ -535,7 +532,7 @@ func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error // setstat is a convience wrapper to allow for changing of various parts of the file descriptor. func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpSetstatPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSetstatPacket{ ID: id, Path: path, Flags: flags, @@ -605,7 +602,7 @@ func (c *Client) OpenFile(path string, f int) (*File, error) { func (c *Client) open(path string, pflags uint32) (*File, error) { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpOpenPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpOpenPacket{ ID: id, Path: path, Pflags: pflags, @@ -633,7 +630,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) { // immediately after this request has been sent. func (c *Client) close(handle string) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpClosePacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpClosePacket{ ID: id, Handle: handle, }) @@ -650,7 +647,7 @@ func (c *Client) close(handle string) error { func (c *Client) stat(path string) (*FileStat, error) { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatPacket{ ID: id, Path: path, }) @@ -674,7 +671,7 @@ func (c *Client) stat(path string) (*FileStat, error) { func (c *Client) fstat(handle string) (*FileStat, error) { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFstatPacket{ ID: id, Handle: handle, }) @@ -703,7 +700,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) { func (c *Client) StatVFS(path string) (*StatVFS, error) { // send the StatVFS packet to the server id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpStatvfsPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatvfsPacket{ ID: id, Path: path, }) @@ -758,7 +755,7 @@ func (c *Client) Remove(path string) error { func (c *Client) removeFile(path string) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpRemovePacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRemovePacket{ ID: id, Filename: path, }) @@ -776,7 +773,7 @@ func (c *Client) removeFile(path string) error { // RemoveDirectory removes a directory path. func (c *Client) RemoveDirectory(path string) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpRmdirPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRmdirPacket{ ID: id, Path: path, }) @@ -794,7 +791,7 @@ func (c *Client) RemoveDirectory(path string) error { // Rename renames a file. func (c *Client) Rename(oldname, newname string) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpRenamePacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRenamePacket{ ID: id, Oldpath: oldname, Newpath: newname, @@ -814,7 +811,7 @@ func (c *Client) Rename(oldname, newname string) error { // which will replace newname if it already exists. func (c *Client) PosixRename(oldname, newname string) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpPosixRenamePacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpPosixRenamePacket{ ID: id, Oldpath: oldname, Newpath: newname, @@ -836,7 +833,7 @@ func (c *Client) PosixRename(oldname, newname string) error { // or relative pathnames without a leading slash into absolute paths. func (c *Client) RealPath(path string) (string, error) { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpRealpathPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRealpathPacket{ ID: id, Path: path, }) @@ -873,7 +870,7 @@ func (c *Client) Getwd() (string, error) { // parent folder does not exist (the method cannot create complete paths). func (c *Client) Mkdir(path string) error { id := c.nextID() - typ, data, err := c.sendPacket(nil, &sshFxpMkdirPacket{ + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpMkdirPacket{ ID: id, Path: path, }) @@ -1019,7 +1016,7 @@ func (f *File) Read(b []byte) (int, error) { func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) { for err == nil && n < len(b) { id := f.c.nextID() - typ, data, err := f.c.sendPacket(ch, &sshFxpReadPacket{ + typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpReadPacket{ ID: id, Handle: f.handle, Offset: uint64(off) + uint64(n), @@ -1487,7 +1484,7 @@ func (f *File) Write(b []byte) (int, error) { } func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) { - typ, data, err := f.c.sendPacket(ch, &sshFxpWritePacket{ + typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpWritePacket{ ID: f.c.nextID(), Handle: f.handle, Offset: uint64(off), @@ -1945,7 +1942,7 @@ func (f *File) Chmod(mode os.FileMode) error { // Sync requires the server to support the fsync@openssh.com extension. func (f *File) Sync() error { id := f.c.nextID() - typ, data, err := f.c.sendPacket(nil, &sshFxpFsyncPacket{ + typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{ ID: id, Handle: f.handle, }) diff --git a/conn.go b/conn.go index 3bb2ba15..93bc37bf 100644 --- a/conn.go +++ b/conn.go @@ -1,6 +1,7 @@ package sftp import ( + "context" "encoding" "fmt" "io" @@ -128,14 +129,19 @@ type idmarshaler interface { encoding.BinaryMarshaler } -func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) { +func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) { if cap(ch) < 1 { ch = make(chan result, 1) } c.dispatchRequest(ch, p) - s := <-ch - return s.typ, s.data, s.err + + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case s := <-ch: + return s.typ, s.data, s.err + } } // dispatchRequest should ideally only be called by race-detection tests outside of this file, diff --git a/server_test.go b/server_test.go index ae61eec6..87beece5 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package sftp import ( "bytes" + "context" "errors" "io" "os" @@ -76,7 +77,7 @@ func TestInvalidExtendedPacket(t *testing.T) { defer server.Close() badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"} - typ, data, err := client.clientConn.sendPacket(nil, badPacket) + typ, data, err := client.clientConn.sendPacket(context.Background(), nil, badPacket) if err != nil { t.Fatalf("unexpected error from sendPacket: %s", err) }