Skip to content

Commit

Permalink
sendPacket with context
Browse files Browse the repository at this point in the history
  • Loading branch information
ungerik committed Nov 13, 2023
1 parent 273341d commit 3aa53a5
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 31 deletions.
51 changes: 24 additions & 27 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,16 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
// The passed context can be used to cancel the operation
// returning all entries listed up to the cancellation.
func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, error) {
handle, err := c.opendir(p)
handle, err := c.opendir(ctx, p)
if err != nil {
return nil, err
}
defer c.close(handle) // this has to defer earlier than the lock below
var entries []os.FileInfo
var done = false
for !done {
if err = ctx.Err(); err != nil {
return entries, err
}
id := c.nextID()
typ, data, err1 := c.sendPacket(nil, &sshFxpReaddirPacket{
typ, data, err1 := c.sendPacket(ctx, nil, &sshFxpReaddirPacket{
ID: id,
Handle: handle,
})
Expand Down Expand Up @@ -386,9 +383,9 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e
return entries, err
}

func (c *Client) opendir(path string) (string, error) {
func (c *Client) opendir(ctx context.Context, path string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpOpendirPacket{
typ, data, err := c.sendPacket(ctx, nil, &sshFxpOpendirPacket{
ID: id,
Path: path,
})
Expand Down Expand Up @@ -424,7 +421,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) {
// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link.
func (c *Client) Lstat(p string) (os.FileInfo, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpLstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpLstatPacket{
ID: id,
Path: p,
})
Expand All @@ -449,7 +446,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
// ReadLink reads the target of a symbolic link.
func (c *Client) ReadLink(p string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpReadlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpReadlinkPacket{
ID: id,
Path: p,
})
Expand Down Expand Up @@ -478,7 +475,7 @@ func (c *Client) ReadLink(p string) (string, error) {
// Link creates a hard link at 'newname', pointing at the same inode as 'oldname'
func (c *Client) Link(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpHardlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpHardlinkPacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
Expand All @@ -497,7 +494,7 @@ func (c *Client) Link(oldname, newname string) error {
// Symlink creates a symbolic link at 'newname', pointing at target 'oldname'
func (c *Client) Symlink(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpSymlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSymlinkPacket{
ID: id,
Linkpath: newname,
Targetpath: oldname,
Expand All @@ -515,7 +512,7 @@ func (c *Client) Symlink(oldname, newname string) error {

func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFsetstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
ID: id,
Handle: handle,
Flags: flags,
Expand All @@ -535,7 +532,7 @@ func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error
// setstat is a convience wrapper to allow for changing of various parts of the file descriptor.
func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpSetstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSetstatPacket{
ID: id,
Path: path,
Flags: flags,
Expand Down Expand Up @@ -605,7 +602,7 @@ func (c *Client) OpenFile(path string, f int) (*File, error) {

func (c *Client) open(path string, pflags uint32) (*File, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpOpenPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpOpenPacket{
ID: id,
Path: path,
Pflags: pflags,
Expand Down Expand Up @@ -633,7 +630,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) {
// immediately after this request has been sent.
func (c *Client) close(handle string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpClosePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpClosePacket{
ID: id,
Handle: handle,
})
Expand All @@ -650,7 +647,7 @@ func (c *Client) close(handle string) error {

func (c *Client) stat(path string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatPacket{
ID: id,
Path: path,
})
Expand All @@ -674,7 +671,7 @@ func (c *Client) stat(path string) (*FileStat, error) {

func (c *Client) fstat(handle string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFstatPacket{
ID: id,
Handle: handle,
})
Expand Down Expand Up @@ -703,7 +700,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
func (c *Client) StatVFS(path string) (*StatVFS, error) {
// send the StatVFS packet to the server
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpStatvfsPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatvfsPacket{
ID: id,
Path: path,
})
Expand Down Expand Up @@ -758,7 +755,7 @@ func (c *Client) Remove(path string) error {

func (c *Client) removeFile(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRemovePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRemovePacket{
ID: id,
Filename: path,
})
Expand All @@ -776,7 +773,7 @@ func (c *Client) removeFile(path string) error {
// RemoveDirectory removes a directory path.
func (c *Client) RemoveDirectory(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRmdirPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRmdirPacket{
ID: id,
Path: path,
})
Expand All @@ -794,7 +791,7 @@ func (c *Client) RemoveDirectory(path string) error {
// Rename renames a file.
func (c *Client) Rename(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRenamePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRenamePacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
Expand All @@ -814,7 +811,7 @@ func (c *Client) Rename(oldname, newname string) error {
// which will replace newname if it already exists.
func (c *Client) PosixRename(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpPosixRenamePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpPosixRenamePacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
Expand All @@ -836,7 +833,7 @@ func (c *Client) PosixRename(oldname, newname string) error {
// or relative pathnames without a leading slash into absolute paths.
func (c *Client) RealPath(path string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRealpathPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRealpathPacket{
ID: id,
Path: path,
})
Expand Down Expand Up @@ -873,7 +870,7 @@ func (c *Client) Getwd() (string, error) {
// parent folder does not exist (the method cannot create complete paths).
func (c *Client) Mkdir(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpMkdirPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpMkdirPacket{
ID: id,
Path: path,
})
Expand Down Expand Up @@ -1019,7 +1016,7 @@ func (f *File) Read(b []byte) (int, error) {
func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) {
for err == nil && n < len(b) {
id := f.c.nextID()
typ, data, err := f.c.sendPacket(ch, &sshFxpReadPacket{
typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpReadPacket{
ID: id,
Handle: f.handle,
Offset: uint64(off) + uint64(n),
Expand Down Expand Up @@ -1487,7 +1484,7 @@ func (f *File) Write(b []byte) (int, error) {
}

func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) {
typ, data, err := f.c.sendPacket(ch, &sshFxpWritePacket{
typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpWritePacket{
ID: f.c.nextID(),
Handle: f.handle,
Offset: uint64(off),
Expand Down Expand Up @@ -1945,7 +1942,7 @@ func (f *File) Chmod(mode os.FileMode) error {
// Sync requires the server to support the [email protected] extension.
func (f *File) Sync() error {
id := f.c.nextID()
typ, data, err := f.c.sendPacket(nil, &sshFxpFsyncPacket{
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id,
Handle: f.handle,
})
Expand Down
12 changes: 9 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sftp

import (
"context"
"encoding"
"fmt"
"io"
Expand Down Expand Up @@ -128,14 +129,19 @@ type idmarshaler interface {
encoding.BinaryMarshaler
}

func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) {
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
if cap(ch) < 1 {
ch = make(chan result, 1)
}

c.dispatchRequest(ch, p)
s := <-ch
return s.typ, s.data, s.err

select {
case <-ctx.Done():
return 0, nil, ctx.Err()
case s := <-ch:
return s.typ, s.data, s.err
}
}

// dispatchRequest should ideally only be called by race-detection tests outside of this file,
Expand Down
3 changes: 2 additions & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sftp

import (
"bytes"
"context"
"errors"
"io"
"os"
Expand Down Expand Up @@ -76,7 +77,7 @@ func TestInvalidExtendedPacket(t *testing.T) {
defer server.Close()

badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"}
typ, data, err := client.clientConn.sendPacket(nil, badPacket)
typ, data, err := client.clientConn.sendPacket(context.Background(), nil, badPacket)
if err != nil {
t.Fatalf("unexpected error from sendPacket: %s", err)
}
Expand Down

0 comments on commit 3aa53a5

Please sign in to comment.