diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0098a2d4..a8a8c55b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -12,10 +12,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest] - go: ['1.23', '1.22'] - exclude: - - os: macos-latest - go: '1.22' + go: ['1.23'] steps: - uses: actions/checkout@v4 diff --git a/client.go b/client.go index d8c48741..c86225bd 100644 --- a/client.go +++ b/client.go @@ -171,7 +171,14 @@ func (c *clientConn) recvLoop(maxPacket uint32) error { } } -func (c *clientConn) dispatch(req sshfx.PacketMarshaller) (uint32, chan result, error) { +// dispatch will marshal, then dispatch the given request packet. +// Packets are written atomically to the connection. +// It returns the allocated request id (a monotonically incrementing value), +// and either a channel upon which the result will be returned, or an error. +// +// If the cancel channel has been closed before the request is dispatched, +// then dispatch will return an [fs.ErrClosed] error. +func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller) (uint32, chan result, error) { reqid := c.reqid.Add(1) header, payload, err := req.MarshalPacket(reqid, c.bufPool.Get()) @@ -180,6 +187,9 @@ func (c *clientConn) dispatch(req sshfx.PacketMarshaller) (uint32, chan result, } defer c.bufPool.Put(header) + // payload by design of the API is all but guaranteed to alias a caller-held byte slice, + // so, _do not_ put it into the bufPool. 
+ ch, ok := c.resPool.Get() if !ok { return reqid, nil, sshfx.StatusConnectionLost @@ -188,6 +198,13 @@ func (c *clientConn) dispatch(req sshfx.PacketMarshaller) (uint32, chan result, c.mu.Lock() defer c.mu.Unlock() + select { + case <-cancel: + c.resPool.Put(ch) + return reqid, nil, fs.ErrClosed + default: + } + if c.inflight == nil { c.inflight = make(map[uint32]chan<- result) } @@ -258,8 +275,8 @@ func (c *clientConn) recv(ctx context.Context, reqid uint32, ch chan result) (*s } } -func (c *clientConn) send(ctx context.Context, req sshfx.PacketMarshaller) (*sshfx.RawPacket, error) { - reqid, ch, err := c.dispatch(req) +func (c *clientConn) send(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (*sshfx.RawPacket, error) { + reqid, ch, err := c.dispatch(cancel, req) if err != nil { return nil, err } @@ -361,17 +378,19 @@ type Client struct { exts map[string]string } -func getPacket[PKT any, P interface { - sshfx.Packet +type respPacket[PKT any] interface { *PKT -}](ctx context.Context, cl *Client, req sshfx.PacketMarshaller) (*PKT, error) { - raw, err := cl.conn.send(ctx, req) + sshfx.Packet +} + +func getPacket[RESP respPacket[PKT], PKT any](ctx context.Context, cancel <-chan struct{}, cl *Client, req sshfx.PacketMarshaller) (*PKT, error) { + raw, err := cl.conn.send(ctx, cancel, req) if err != nil { return nil, err } defer cl.conn.returnRaw(raw) - var resp P + var resp RESP switch raw.PacketType { case resp.Type(): @@ -383,18 +402,50 @@ func getPacket[PKT any, P interface { return resp, nil case sshfx.PacketTypeStatus: - var status sshfx.StatusPacket + status := new(sshfx.StatusPacket) if err := status.UnmarshalPacketBody(&raw.Data); err != nil { return nil, err } - return nil, statusToError(&status, false) + return nil, statusToError(status, false) default: return nil, fmt.Errorf("unexpected packet type: %s", raw.PacketType) } } +func (cl *Client) getHandle(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (string, 
error) { + resp, err := getPacket[*sshfx.HandlePacket](ctx, cancel, cl, req) + if err != nil { + return "", err + } + return resp.Handle, nil +} + +func (cl *Client) getNames(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) ([]*sshfx.NameEntry, error) { + resp, err := getPacket[*sshfx.NamePacket](ctx, cancel, cl, req) + if err != nil { + return nil, err + } + return resp.Entries, nil +} + +func (cl *Client) getPath(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (string, error) { + resp, err := getPacket[*sshfx.PathPseudoPacket](ctx, cancel, cl, req) + if err != nil { + return "", err + } + return resp.Path, nil +} + +func (cl *Client) getAttrs(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (*sshfx.Attributes, error) { + resp, err := getPacket[*sshfx.AttrsPacket](ctx, cancel, cl, req) + if err != nil { + return nil, err + } + return &resp.Attrs, nil +} + func statusToError(status *sshfx.StatusPacket, okExpected bool) error { switch status.StatusCode { case sshfx.StatusOK: @@ -414,17 +465,16 @@ func statusToError(status *sshfx.StatusPacket, okExpected bool) error { return status } -func (cl *Client) sendPacket(ctx context.Context, req sshfx.PacketMarshaller) error { - reqid, ch, err := cl.conn.dispatch(req) +func (cl *Client) sendPacket(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) error { + reqid, ch, err := cl.conn.dispatch(cancel, req) if err != nil { return err } - var resp sshfx.StatusPacket - return cl.recvStatus(ctx, reqid, ch, &resp) + return cl.recvStatus(ctx, reqid, ch, nil) } -func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result, resp *sshfx.StatusPacket) error { +func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result, hint *sshfx.StatusPacket) error { raw, err := cl.conn.recv(ctx, reqid, ch) if err != nil { return err @@ -433,19 +483,23 @@ func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch 
chan result, switch raw.PacketType { case sshfx.PacketTypeStatus: - if err := resp.UnmarshalPacketBody(&raw.Data); err != nil { + if hint == nil { + hint = new(sshfx.StatusPacket) + } + + if err := hint.UnmarshalPacketBody(&raw.Data); err != nil { return err } - return statusToError(resp, true) + return statusToError(hint, true) default: return fmt.Errorf("unexpected packet type: %s", raw.PacketType) } } -func (cl *Client) sendRead(ctx context.Context, req *sshfx.ReadPacket, resp *sshfx.DataPacket) (int, error) { - reqid, ch, err := cl.conn.dispatch(req) +func (cl *Client) sendRead(ctx context.Context, cancel <-chan struct{}, req *sshfx.ReadPacket, resp *sshfx.DataPacket) (int, error) { + reqid, ch, err := cl.conn.dispatch(cancel, req) if err != nil { return 0, err } @@ -583,22 +637,54 @@ func (cl *Client) Close() error { return nil } +func wrapPathError(op, path string, err error) error { + if err == nil { + return nil + } + + if errors.Is(err, io.EOF) { + // Numerous odd things break if we don't return bare io.EOF errors. + return io.EOF + } + + return &fs.PathError{Op: op, Path: path, Err: err} +} + +func valOrPathError[T any](op, path string, v T, err error) (T, error) { + if err != nil { + var z T + return z, wrapPathError(op, path, err) + } + + return v, nil +} + +func wrapLinkError(op, oldpath, newpath string, err error) error { + if err == nil { + return nil + } + + if errors.Is(err, io.EOF) { + // Numerous odd things break if we don't return bare io.EOF errors. + return io.EOF + } + + return &os.LinkError{Op: op, Old: oldpath, New: newpath, Err: err} +} + // Mkdir creates the specified directory. // An error will be returned if a file or directory with the specified path already exists, // or if the directory's parent folder does not exist. 
func (cl *Client) Mkdir(name string, perm fs.FileMode) error { - err := cl.sendPacket(context.Background(), &sshfx.MkdirPacket{ - Path: name, - Attrs: sshfx.Attributes{ - Flags: sshfx.AttrPermissions, - Permissions: sshfx.FileMode(perm.Perm()), - }, - }) - if err != nil { - return &fs.PathError{Op: "mkdir", Path: name, Err: err} - } - - return nil + return wrapPathError("mkdir", name, + cl.sendPacket(context.Background(), nil, &sshfx.MkdirPacket{ + Path: name, + Attrs: sshfx.Attributes{ + Flags: sshfx.AttrPermissions, + Permissions: sshfx.FileMode(perm.Perm()), + }, + }), + ) } // MkdirAll creates a directory named path, along with any necessary parents. @@ -611,7 +697,7 @@ func (cl *Client) MkdirAll(name string, perm fs.FileMode) error { return nil } - return &fs.PathError{Op: "mkdir", Path: name, Err: syscall.ENOTDIR} + return wrapPathError("mkdir", name, syscall.ENOTDIR) } // Slow path: make sure parent exists and then call Mkdir for name. @@ -647,47 +733,48 @@ func (cl *Client) MkdirAll(name string, perm fs.FileMode) error { func (cl *Client) Remove(name string) error { ctx := context.Background() - err := cl.sendPacket(ctx, &sshfx.RemovePacket{ + errF := cl.sendPacket(ctx, nil, &sshfx.RemovePacket{ Path: name, }) - if err == nil { + if errF == nil { return nil } - err1 := cl.sendPacket(ctx, &sshfx.RmdirPacket{ + errD := cl.sendPacket(ctx, nil, &sshfx.RmdirPacket{ Path: name, }) - if err1 == nil { + if errD == nil { return nil } // Both failed: figure out which error to return. - if err != err1 { - attrs, err2 := getPacket[sshfx.AttrsPacket](ctx, cl, &sshfx.StatPacket{ - Path: name, - }) - if err2 != nil { - err = err2 - } else { - if perm, ok := attrs.Attrs.GetPermissions(); ok && perm.IsDir() { - err = err1 - } - } - } - return &fs.PathError{Op: "remove", Path: name, Err: err} -} + if errF == errD { + // If they are the same error, then just return that. 
+ return wrapPathError("remove", name, errF) + } -func (cl *Client) setstat(ctx context.Context, name string, attrs *sshfx.Attributes) error { - err := cl.sendPacket(ctx, &sshfx.SetStatPacket{ - Path: name, - Attrs: *attrs, + attrs, err := cl.getAttrs(ctx, nil, &sshfx.StatPacket{ + Path: name, }) if err != nil { - return &fs.PathError{Op: "setstat", Path: name, Err: err} + return wrapPathError("remove", name, err) } - return nil + if perm, ok := attrs.GetPermissions(); ok && perm.IsDir() { + return wrapPathError("remove", name, errD) + } + + return wrapPathError("remove", name, errF) +} + +func (cl *Client) setstat(ctx context.Context, name string, attrs *sshfx.Attributes) error { + return wrapPathError("setstat", name, + cl.sendPacket(ctx, nil, &sshfx.SetStatPacket{ + Path: name, + Attrs: *attrs, + }), + ) } // Truncate changes the size of the named file. @@ -749,14 +836,10 @@ func (cl *Client) Chtimes(name string, atime, mtime time.Time) error { // This is useful for converting path names containing ".." components, // or relative pathnames without a leading slash into absolute paths. func (cl *Client) RealPath(name string) (string, error) { - pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), cl, &sshfx.RealPathPacket{ + path, err := cl.getPath(context.Background(), nil, &sshfx.RealPathPacket{ Path: name, }) - if err != nil { - return "", &fs.PathError{Op: "realpath", Path: name, Err: err} - } - - return pkt.Path, nil + return valOrPathError("realpath", name, path, err) } // ReadLink returns the destination of the named symbolic link. @@ -764,14 +847,10 @@ func (cl *Client) RealPath(name string) (string, error) { // The client cannot guarantee any specific way that a server handles a relative link destination. // That is, you may receive a relative link destination, one that has been converted to an absolute path. 
func (cl *Client) ReadLink(name string) (string, error) { - pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), cl, &sshfx.ReadLinkPacket{ + path, err := cl.getPath(context.Background(), nil, &sshfx.ReadLinkPacket{ Path: name, }) - if err != nil { - return "", &fs.PathError{Op: "readlink", Path: name, Err: err} - } - - return pkt.Path, nil + return valOrPathError("readlink", name, path, err) } // Rename renames (moves) oldpath to newpath. @@ -780,40 +859,31 @@ func (cl *Client) ReadLink(name string) (string, error) { // Even within the same directory, on non-Unix servers Rename is not guaranteed to be an atomic operation. func (cl *Client) Rename(oldpath, newpath string) error { if cl.hasExtension(openssh.ExtensionPOSIXRename()) { - err := cl.sendPacket(context.Background(), &openssh.POSIXRenameExtendedPacket{ - OldPath: oldpath, - NewPath: newpath, - }) - if err != nil { - return &os.LinkError{Op: "rename", Old: oldpath, New: newpath, Err: err} - } - - return nil - } - - err := cl.sendPacket(context.Background(), &sshfx.RenamePacket{ - OldPath: oldpath, - NewPath: newpath, - }) - if err != nil { - return &os.LinkError{Op: "rename", Old: oldpath, New: newpath, Err: err} + return wrapLinkError("rename", oldpath, newpath, + cl.sendPacket(context.Background(), nil, &openssh.POSIXRenameExtendedPacket{ + OldPath: oldpath, + NewPath: newpath, + }), + ) } - return nil + return wrapLinkError("rename", oldpath, newpath, + cl.sendPacket(context.Background(), nil, &sshfx.RenamePacket{ + OldPath: oldpath, + NewPath: newpath, + }), + ) } // Symlink creates newname as a symbolic link to oldname. // There is no guarantee for how a server may handle the request if oldname does not exist. 
func (cl *Client) Symlink(oldname, newname string) error { - err := cl.sendPacket(context.Background(), &sshfx.SymlinkPacket{ - LinkPath: newname, - TargetPath: oldname, - }) - if err != nil { - return &os.LinkError{Op: "symlink", Old: oldname, New: newname, Err: err} - } - - return nil + return wrapLinkError("symlink", oldname, newname, + cl.sendPacket(context.Background(), nil, &sshfx.SymlinkPacket{ + LinkPath: newname, + TargetPath: oldname, + }), + ) } func (cl *Client) hasExtension(ext *sshfx.ExtensionPair) bool { @@ -827,18 +897,15 @@ func (cl *Client) hasExtension(ext *sshfx.ExtensionPair) bool { // and Link returns an *fs.LinkError wrapping sshfx.StatusOpUnsupported. func (cl *Client) Link(oldname, newname string) error { if !cl.hasExtension(openssh.ExtensionHardlink()) { - return &os.LinkError{Op: "hardlink", Old: oldname, New: newname, Err: sshfx.StatusOpUnsupported} + return wrapLinkError("hardlink", oldname, newname, sshfx.StatusOpUnsupported) } - err := cl.sendPacket(context.Background(), &openssh.HardlinkExtendedPacket{ - NewPath: newname, - OldPath: oldname, - }) - if err != nil { - return &os.LinkError{Op: "hardlink", Old: oldname, New: newname, Err: err} - } - - return nil + return wrapLinkError("hardlink", oldname, newname, + cl.sendPacket(context.Background(), nil, &openssh.HardlinkExtendedPacket{ + OldPath: oldname, + NewPath: newname, + }), + ) } // Readdir reads the named directory, returning all its directory entries as [fs.FileInfo] sorted by filename. @@ -886,45 +953,97 @@ func (cl *Client) ReadDirContext(ctx context.Context, name string) ([]fs.DirEntr return fis, err } -func (cl *Client) stat(name string) (*sshfx.NameEntry, error) { - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), cl, &sshfx.StatPacket{ +// Stat returns a FileInfo describing the named file. +// If the file is a symbolic link, the returned FileInfo describes the link's target. 
+func (cl *Client) Stat(name string) (fs.FileInfo, error) { + attrs, err := cl.getAttrs(context.Background(), nil, &sshfx.StatPacket{ Path: name, }) if err != nil { - return nil, &fs.PathError{Op: "stat", Path: name, Err: err} + return nil, wrapPathError("stat", name, err) } return &sshfx.NameEntry{ Filename: name, - Attrs: pkt.Attrs, + Attrs: *attrs, }, nil } -// Stat returns a FileInfo describing the named file. -// If the file is a symbolic link, the returned FileInfo describes the link's target. -func (cl *Client) Stat(name string) (fs.FileInfo, error) { - return cl.stat(name) -} - // LStat returns a FileInfo describing the named file. // If the file is a symbolic link, the returned FileInfo describes the symbolic link // LStat makes no attempte to follow the link. // // The description returned may have server specific caveats and special cases that cannot be covered here. func (cl *Client) LStat(name string) (fs.FileInfo, error) { - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), cl, &sshfx.LStatPacket{ + attrs, err := cl.getAttrs(context.Background(), nil, &sshfx.LStatPacket{ Path: name, }) if err != nil { - return nil, &fs.PathError{Op: "lstat", Path: name, Err: err} + return nil, wrapPathError("lstat", name, err) } return &sshfx.NameEntry{ Filename: name, - Attrs: pkt.Attrs, + Attrs: *attrs, }, nil } +type handle struct { + value atomic.Pointer[string] + closed chan struct{} +} + +func (h *handle) init(handle string) { + h.value.Store(&handle) + h.closed = make(chan struct{}) +} + +func (h *handle) get() (handle string, cancel <-chan struct{}, err error) { + p := h.value.Load() + if p == nil { + return "", nil, fs.ErrClosed + } + return *p, h.closed, nil +} + +func (h *handle) close(cl *Client) error { + // The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`, + // it will unconditionally mark the handle as unused, + // so we need to also unconditionally mark this handle as invalid. 
+ // By invalidating our local copy of the handle, + // we ensure that there cannot be any new erroneous use-after-close receiver methods started after this swap. + handle := h.value.Swap(nil) + if handle == nil { + return fs.ErrClosed + } + + // The atomic Swap above ensures that only one Close can ever get here. + // We could also use a mutex to guarantee exclusivity here, + // but that would block Close until all synchronized operations have completed, + // some of which could be paused indefinitely. + // + // See: https://github.com/pkg/sftp/issues/603 for more details. + + // So, we have defended now against new receiver methods starting, + // but since an outstanding method could still be holding the handle, we still need a close signal. + // Since this close HAPPENS BEFORE the sendPacket below, + // this ensures that after closing this channel, no further requests will be dispatched. + // Meaning we know that the close request below will be the final request from this handle. + close(h.closed) + + // One might assume we could just simply use the closed channel alone, + // but because close panics if called twice, we need a select to test if the channel is already closed, + // and since there is a window of time between such a test and the close, two goroutines can race. + // So we still need to synchronize the close operation anyways, so either atomic pointer or mutex. + + // It should be obvious, but do not pass h.closed into this sendPacket, or it will never be sent. + // Less obviously, DO NOT pipe a context through this function to the sendPacket. + // We want to ensure that even in a closed-context codepath, that the SSH_FXP_CLOSED packet is still sent. + return cl.sendPacket(context.Background(), nil, &sshfx.ClosePacket{ + Handle: *handle, + }) +} + // Dir represents an open directory handle. // // The methods of Dir are safe for concurrent use. 
@@ -932,8 +1051,10 @@ type Dir struct { cl *Client name string - mu sync.RWMutex - handle string + handle handle + req sshfx.ReadDirPacket // save on allocations with a scratch request packet. + + mu sync.Mutex entries []*sshfx.NameEntry } @@ -942,51 +1063,38 @@ type Dir struct { // // The semantics of SSH_FX_OPENDIR is such that the associated file handle is in a read-only mode. func (cl *Client) OpenDir(name string) (*Dir, error) { - return cl.openDir(context.Background(), name) -} - -func (cl *Client) openDir(ctx context.Context, name string) (*Dir, error) { - pkt, err := getPacket[sshfx.HandlePacket](ctx, cl, &sshfx.OpenDirPacket{ + handle, err := cl.getHandle(context.Background(), nil, &sshfx.OpenDirPacket{ Path: name, }) if err != nil { - return nil, &fs.PathError{Op: "opendir", Path: name, Err: err} + return nil, wrapPathError("opendir", name, err) } - return &Dir{ - cl: cl, - name: name, - handle: pkt.Handle, - }, nil + d := &Dir{ + cl: cl, + name: name, + req: sshfx.ReadDirPacket{ + Handle: handle, + }, + } + + d.handle.init(handle) + + return d, nil +} + +func (d *Dir) wrapErr(op string, err error) error { + return wrapPathError(op, d.name, err) } // Close closes the Dir, rendering it unusable for I/O. // Close will not send any request, and return an error if it has already been called. func (d *Dir) Close() error { - d.mu.Lock() - defer d.mu.Unlock() - - if d.handle == "" { - return &fs.PathError{Op: "close", Path: d.name, Err: fs.ErrClosed} - } - - // The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`, - // it will unconditionally mark the handle as unused, - // so we need to also unconditionally mark this handle as invalid. - // By invalidating our local copy of the handle, - // we ensure that there cannot be any erroneous use-after-close requests sent after Close. 
- - handle := d.handle - d.handle = "" - - err := d.cl.sendPacket(context.Background(), &sshfx.ClosePacket{ - Handle: handle, - }) - if err != nil { - return &fs.PathError{Op: "close", Path: d.name, Err: err} + if d == nil { + return os.ErrInvalid } - return nil + return d.wrapErr("close", d.handle.close(d.cl)) } // Name returns the name of the directory as presented to OpenDir. @@ -995,48 +1103,48 @@ func (d *Dir) Name() string { } // rangedir returns an iterator over the directory entries of the directory. +// It will only ever yield either a *sshfx.NameEntry or an error, never both. +// No error will be yielded until all available name entries have been yielded. +// Only one error will be yielded per invocation. +// // We do not expose an iterator, because none has been standardized yet. -// and we do not want to accidentally implement an inconsistent API. -// However, for internal usage, we can definitely make use of this to simplify the common parts of ReadDir and Readdir. +// and we do not want to accidentally implement an API inconsistent with future standards. +// However, for internal usage, we can separate the paginated ReadDir request code from the conversion to Go entries. // -// Callers must guarantee synchronization by either holding the file lock, or holding an exclusive reference. -func (d *Dir) rangedir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] { +// Callers must guarantee synchronization by either holding the directory lock, or holding an exclusive reference. +func (d *Dir) rangedir(ctx context.Context, grow func(int)) iter.Seq2[*sshfx.NameEntry, error] { return func(yield func(v *sshfx.NameEntry, err error) bool) { - // Pull from saved entries first. - for i, ent := range d.entries { - if !yield(ent, nil) { - // Early break, delete the entries we have yielded. - d.entries = slices.Delete(d.entries, 0, i+1) - return - } - } - - // We got through all the remaining entries, delete all the entries. 
- d.entries = slices.Delete(d.entries, 0, len(d.entries)) - for { - pkt, err := getPacket[sshfx.NamePacket](ctx, d.cl, &sshfx.ReadDirPacket{ - Handle: d.handle, - }) - if err != nil { - // There are no remaining entries to save here, - // SFTP can only return either an error or a result, never both. - if errors.Is(err, io.EOF) { - yield(nil, io.EOF) + grow(len(d.entries)) + + // Pull from saved entries first. + for i, ent := range d.entries { + if !yield(ent, nil) { + // This is a break condition. + // We need to remove all entries that have been consumed, + // and that includes the one that we are currently on. + d.entries = slices.Delete(d.entries, 0, i+1) return } + } - yield(nil, &fs.PathError{Op: "readdir", Path: d.name, Err: err}) + // We got through all the remaining entries, delete all the entries. + d.entries = slices.Delete(d.entries, 0, len(d.entries)) + + _, closed, err := d.handle.get() + if err != nil { + yield(nil, err) return } - for i, entry := range pkt.Entries { - if !yield(entry, nil) { - // Early break, save the remaining entries we got for maybe later. - d.entries = append(d.entries, pkt.Entries[i+1:]...) - return - } + entries, err := d.cl.getNames(ctx, closed, &d.req) + if err != nil { + // No need to loop, SFTP can only return either an error or a result, never both. + yield(nil, err) + return } + + d.entries = entries } } } @@ -1059,22 +1167,31 @@ func (d *Dir) Readdir(n int) ([]fs.FileInfo, error) { // If n <= 0, ReaddirContext returns all the FileInfo records remaining in the directory. // When it succeeds, it returns a nil error (not io.EOF). 
func (d *Dir) ReaddirContext(ctx context.Context, n int) ([]fs.FileInfo, error) { + if d == nil { + return nil, os.ErrInvalid + } + d.mu.Lock() defer d.mu.Unlock() - if d.handle == "" { - return nil, &fs.PathError{Op: "readdir", Path: d.name, Err: fs.ErrClosed} - } - var ret []fs.FileInfo - for ent, err := range d.rangedir(ctx) { + grow := func(more int) { + if n > 0 { + // the lesser of what's coming, and how much remains. + more = min(more, n-len(ret)) + } + + ret = slices.Grow(ret, more) + } + + for ent, err := range d.rangedir(ctx, grow) { if err != nil { if errors.Is(err, io.EOF) && n <= 0 { return ret, nil } - return ret, err + return ret, d.wrapErr("readdir", err) } ret = append(ret, ent) @@ -1105,22 +1222,31 @@ func (d *Dir) ReadDir(n int) ([]fs.DirEntry, error) { // If n <= 0, ReadDirContext returns all the DirEntry records remaining in the directory. // When it succeeds, it returns a nil error (not io.EOF). func (d *Dir) ReadDirContext(ctx context.Context, n int) ([]fs.DirEntry, error) { + if d == nil { + return nil, os.ErrInvalid + } + d.mu.Lock() defer d.mu.Unlock() - if d.handle == "" { - return nil, &fs.PathError{Op: "readdir", Path: d.name, Err: fs.ErrClosed} - } - var ret []fs.DirEntry - for ent, err := range d.rangedir(ctx) { + grow := func(more int) { + if n > 0 { + // the lesser of what's coming, and how much remains. + more = min(more, n-len(ret)) + } + + ret = slices.Grow(ret, more) + } + + for ent, err := range d.rangedir(ctx, grow) { if err != nil { if errors.Is(err, io.EOF) && n <= 0 { return ret, nil } - return ret, err + return ret, d.wrapErr("readdir", err) } ret = append(ret, ent) @@ -1140,8 +1266,9 @@ type File struct { cl *Client name string + handle handle + mu sync.RWMutex - handle string offset int64 // current offset within remote file } @@ -1151,7 +1278,7 @@ const ( OpenFlagReadOnly = os.O_RDONLY OpenFlagWriteOnly = os.O_WRONLY OpenFlagReadWrite = os.O_RDWR - // The remaining values may be or’ed in to control behavior. 
+ // The remaining values may be or'ed in to control behavior. OpenFlagAppend = os.O_APPEND OpenFlagCreate = os.O_CREATE OpenFlagTruncate = os.O_TRUNC @@ -1210,7 +1337,7 @@ func (cl *Client) Create(name string) (*File, error) { // Note well: since all Write operations are down through an offset-specifying operation, // the OpenFlagAppend flag is currently ignored. func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, error) { - pkt, err := getPacket[sshfx.HandlePacket](context.Background(), cl, &sshfx.OpenPacket{ + handle, err := cl.getHandle(context.Background(), nil, &sshfx.OpenPacket{ Filename: name, PFlags: toPortableFlags(flag), Attrs: sshfx.Attributes{ @@ -1219,43 +1346,31 @@ func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, erro }, }) if err != nil { - return nil, err + return nil, wrapPathError("openfile", name, err) } - return &File{ - cl: cl, - name: name, - handle: pkt.Handle, - }, nil + f := &File{ + cl: cl, + name: name, + } + + f.handle.init(handle) + + return f, nil +} + +func (f *File) wrapErr(op string, err error) error { + return wrapPathError(op, f.name, err) } // Close closes the File, rendering it unusable for I/O. // Close will not send any request, and return an error if it has already been called. func (f *File) Close() error { - f.mu.Lock() - defer f.mu.Unlock() - - if f.handle == "" { - return &fs.PathError{Op: "close", Path: f.name, Err: fs.ErrClosed} - } - - // The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`, - // it will unconditionally mark the handle as unused, - // so we need to also unconditionally mark this handle as invalid. - // By invalidating our local copy of the handle, - // we ensure that there cannot be any erroneous use-after-close requests sent after Close. 
- - handle := f.handle - f.handle = "" - - err := f.cl.sendPacket(context.Background(), &sshfx.ClosePacket{ - Handle: handle, - }) - if err != nil { - return &fs.PathError{Op: "close", Path: f.name, Err: err} + if f == nil { + return fs.ErrInvalid } - return nil + return f.wrapErr("close", f.handle.close(f.cl)) } // Name returns the name of the file as presented to Open. @@ -1266,22 +1381,21 @@ func (f *File) Name() string { } func (f *File) setstat(ctx context.Context, attrs *sshfx.Attributes) error { - f.mu.Lock() - defer f.mu.Unlock() - - if f.handle == "" { - return &fs.PathError{Op: "fsetstat", Path: f.name, Err: fs.ErrClosed} + if f == nil { + return fs.ErrInvalid } - err := f.cl.sendPacket(ctx, &sshfx.FSetStatPacket{ - Handle: f.handle, - Attrs: *attrs, - }) + handle, closed, err := f.handle.get() if err != nil { - return &fs.PathError{Op: "fsetstat", Path: f.name, Err: err} + return f.wrapErr("fsetstat", err) } - return nil + return f.wrapErr("fsetstat", + f.cl.sendPacket(ctx, closed, &sshfx.FSetStatPacket{ + Handle: handle, + Attrs: *attrs, + }), + ) } // Truncate changes the size of the file. @@ -1327,35 +1441,38 @@ func (f *File) Chtimes(atime, mtime time.Time) error { }) } -func (f *File) stat() (*sshfx.NameEntry, error) { - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), f.cl, &sshfx.FStatPacket{ - Handle: f.handle, +// Stat returns the FileInfo structure describing file. +func (f *File) Stat() (fs.FileInfo, error) { + if f == nil { + return nil, fs.ErrInvalid + } + + handle, closed, err := f.handle.get() + if err != nil { + return nil, f.wrapErr("fstat", err) + } + + attrs, err := f.cl.getAttrs(context.Background(), closed, &sshfx.FStatPacket{ + Handle: handle, }) if err != nil { - return nil, &fs.PathError{Op: "fstat", Path: f.name, Err: err} + return nil, f.wrapErr("fstat", err) } return &sshfx.NameEntry{ Filename: f.name, - Attrs: pkt.Attrs, + Attrs: *attrs, }, nil } -// Stat returns the FileInfo structure describing file. 
-func (f *File) Stat() (fs.FileInfo, error) { - f.mu.Lock() - defer f.mu.Unlock() - - if f.handle == "" { - return nil, &fs.PathError{Op: "fstat", Path: f.name, Err: fs.ErrClosed} +func (f *File) writeatSeq(ctx context.Context, b []byte, off int64) (written int, err error) { + handle, closed, err := f.handle.get() + if err != nil { + return 0, err } - return f.stat() -} - -func (f *File) writeAtFull(ctx context.Context, b []byte, off int64) (written int, err error) { req := &sshfx.WritePacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(off), } @@ -1366,9 +1483,9 @@ func (f *File) writeAtFull(ctx context.Context, b []byte, off int64) (written in req.Data, b = b[:n], b[n:] - err = f.cl.sendPacket(ctx, req) + err = f.cl.sendPacket(ctx, closed, req) if err != nil { - return written, &fs.PathError{Op: "writeat", Path: f.name, Err: err} + return written, f.wrapErr("writeat", err) } req.Offset += uint64(n) @@ -1378,15 +1495,16 @@ func (f *File) writeAtFull(ctx context.Context, b []byte, off int64) (written in return written, nil } -func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, err error) { - if f.handle == "" { - return 0, &fs.PathError{Op: "writeat", Path: f.name, Err: fs.ErrClosed} - } - +func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, err error) { if len(b) <= f.cl.maxDataLen { // This should be able to be serviced with just 1 request. - // So, just do it directly. - return f.writeAtFull(ctx, b, off) + // So, just do it sequentially. + return f.writeatSeq(ctx, b, off) + } + + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("writeat", err) } // Split the write into multiple maxPacket sized concurrent writes bounded by maxInflight. 
@@ -1419,7 +1537,7 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e chunkSize := f.cl.maxDataLen req := &sshfx.WritePacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(f.offset), } @@ -1428,7 +1546,7 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e req.Data, b = b[:n], b[n:] - reqid, res, err := f.cl.conn.dispatch(req) + reqid, res, err := f.cl.conn.dispatch(closed, req) if err != nil { errCh <- rwErr{req.Offset, err} return @@ -1456,10 +1574,10 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e go func() { defer close(errCh) - var status sshfx.StatusPacket + statusHint := new(sshfx.StatusPacket) for work := range workCh { - err := f.cl.recvStatus(ctx, work.reqid, work.res, &status) + err := f.cl.recvStatus(ctx, work.reqid, work.res, statusHint) if err != nil { errCh <- rwErr{work.off, err} @@ -1488,14 +1606,14 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e // * the offset of the start of the first error received in response to a write packet. // * the offset of the start of the first error received dispatching a write packet offset. // - // Either way, this should be the last successfully write offset. - written := int(int64(firstErr.off) - f.offset) + // Either way, this should be the last successfully written offset. + written := int64(firstErr.off) - f.offset f.offset = int64(firstErr.off) - return written, firstErr.err + return int(written), f.wrapErr("writeat", firstErr.err) } - // We didn’t hit any errors, so we must have written all the bytes in the buffer. + // We didn't hit any errors, so we must have written all the bytes in the buffer. written = len(b) f.offset += int64(written) @@ -1506,20 +1624,25 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e // It returns the number of bytes written and an error, if any. // WriteAt returns a non-nil error when n != len(b). 
func (f *File) WriteAt(b []byte, off int64) (n int, err error) { - f.mu.RLock() - defer f.mu.RUnlock() + if f == nil { + return 0, fs.ErrInvalid + } - return f.writeAt(context.Background(), b, off) + return f.writeat(context.Background(), b, off) } // Write writes len(b) bytes from b to the File. // It returns the number of bytes written and an error, if any. // Write returns a non-nil error when n != len(b) func (f *File) Write(b []byte) (int, error) { + if f == nil { + return 0, fs.ErrInvalid + } + f.mu.Lock() defer f.mu.Unlock() - n, err := f.writeAt(context.Background(), b, f.offset) + n, err := f.writeat(context.Background(), b, f.offset) f.offset += int64(n) return n, err @@ -1531,12 +1654,23 @@ func (f *File) WriteString(s string) (n int, err error) { return f.Write(b) } -func (f *File) readFromSequential(r io.Reader) (read int64, err error) { - ctx := context.Background() +func (f *File) readFromSequential(ctx context.Context, r io.Reader) (read int64, err error) { + if f == nil { + return 0, fs.ErrInvalid + } + + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("readfrom", err) + } + + f.mu.Lock() + defer f.mu.Unlock() + b := make([]byte, f.cl.maxDataLen) req := &sshfx.WritePacket{ - Handle: f.handle, + Handle: handle, } for { @@ -1551,7 +1685,7 @@ func (f *File) readFromSequential(r io.Reader) (read int64, err error) { req.Data = b[:n] req.Offset = uint64(f.offset) - err1 := f.cl.sendPacket(ctx, req) + err1 := f.cl.sendPacket(ctx, closed, req) if err1 == nil { // Only increment file offset, if we got a sucess back. 
f.offset += int64(n) @@ -1562,10 +1696,10 @@ func (f *File) readFromSequential(r io.Reader) (read int64, err error) { if err != nil { if errors.Is(err, io.EOF) { - return read, nil // return nil explicitly + return read, nil // return nil instead of EOF } - return read, err + return read, f.wrapErr("readfrom", err) } } } @@ -1584,13 +1718,18 @@ func (e panicInstead) Error() string { // to maximize throughput when transferring an entire file, // especially over high-latency links. func (f *File) ReadFrom(r io.Reader) (read int64, err error) { - f.mu.Lock() - defer f.mu.Unlock() + if f == nil { + return 0, fs.ErrInvalid + } - if f.handle == "" { - return 0, fs.ErrClosed + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("readfrom", err) } + f.mu.Lock() + defer f.mu.Unlock() + ctx := context.Background() chunkSize := f.cl.maxDataLen @@ -1620,7 +1759,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { defer f.cl.conn.bufPool.Put(b) req := &sshfx.WritePacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(f.offset), } @@ -1635,7 +1774,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { read += int64(n) req.Data = b[:n] - reqid, res, err1 := f.cl.conn.dispatch(req) + reqid, res, err1 := f.cl.conn.dispatch(closed, req) if err1 == nil { // If _NO_ error occurred during dispatch. select { case workCh <- work{reqid, res, req.Offset}: @@ -1670,10 +1809,10 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { go func() { defer close(errCh) - var status sshfx.StatusPacket + statusHint := new(sshfx.StatusPacket) for work := range workCh { - err := f.cl.recvStatus(ctx, work.reqid, work.res, &status) + err := f.cl.recvStatus(ctx, work.reqid, work.res, statusHint) if err != nil { errCh <- rwErr{work.off, err} @@ -1711,21 +1850,24 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { } // ReadFrom is defined to return the read bytes, regardless of any write errors. 
- return read, firstErr.err + return read, f.wrapErr("readfrom", firstErr.err) } - // We didn’t hit any errors, so we must have written all the bytes that we read until EOF. + // We didn't hit any errors, so we must have written all the bytes that we read until EOF. f.offset += read return read, nil } -// readAtFull attempts to read the whole entire length of the buffer from the file starting at the offset. +// readatSeq attempts to read the whole entire length of the buffer from the file starting at the offset. // It will continue progressively reading into the buffer until it fills the whole buffer, or an error occurs. -// -// This is prefered over io.ReadFull, because it can reuse read and data packet allocations. -func (f *File) readAtFull(ctx context.Context, b []byte, off int64) (read int, err error) { +func (f *File) readatSeq(ctx context.Context, b []byte, off int64) (read int, err error) { + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("readat", err) + } + req := &sshfx.ReadPacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(off), } @@ -1744,7 +1886,7 @@ func (f *File) readAtFull(ctx context.Context, b []byte, off int64) (read int, e // Otherwise, we would need to use unsafe.SliceData to identify a reallocation. resp.Data = slices.Clip(b[:n]) - m, err := f.cl.sendRead(ctx, req, &resp) + m, err := f.cl.sendRead(ctx, closed, req, &resp) if m > n { // OH NO! We received more data than we expected! @@ -1758,26 +1900,23 @@ func (f *File) readAtFull(ctx context.Context, b []byte, off int64) (read int, e read += m if err != nil { - if errors.Is(err, io.EOF) { - return read, io.EOF // io.Copy does not allow this to be wrapped. 
- } - - return read, &fs.PathError{Op: "readat", Path: f.name, Err: err} + return read, f.wrapErr("readat", err) } } return read, nil } -func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err error) { - if f.handle == "" { - return 0, &fs.PathError{Op: "readat", Path: f.name, Err: fs.ErrClosed} - } - +func (f *File) readat(ctx context.Context, b []byte, off int64) (read int, err error) { if len(b) <= f.cl.maxDataLen { // This should be able to be serviced most times with only 1 request. // So, just do it sequentially. - return f.readAtFull(ctx, b, off) + return f.readatSeq(ctx, b, off) + } + + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("readat", err) } sendCtx, cancel := context.WithCancel(ctx) @@ -1812,7 +1951,7 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e chunkSize := f.cl.maxDataLen req := &sshfx.ReadPacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(off), } @@ -1821,7 +1960,7 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e req.Length = uint32(n) - reqid, res, err := f.cl.conn.dispatch(req) + reqid, res, err := f.cl.conn.dispatch(closed, req) if err != nil { errCh <- rwErr{req.Offset, err} return @@ -1853,7 +1992,7 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e var resp sshfx.DataPacket for work := range workCh { - // See readAtFull for an explanation for why we use slices.Clip here. + // See readatSeq for an explanation for why we use slices.Clip here. resp.Data = slices.Clip(work.b) n, err := f.cl.recvData(ctx, work.reqid, work.res, &resp) @@ -1893,7 +2032,7 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e if firstErr.err != nil { // firstErr.err != nil if and only if firstErr.off > our starting offset. 
- return int(int64(firstErr.off) - off), firstErr.err + return int(int64(firstErr.off) - off), f.wrapErr("readat", firstErr.err) } // As per spec for io.ReaderAt, we return nil error if and only if we read everything. @@ -1905,20 +2044,25 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e // ReadAt always returns a non-nil error when n < len(b). // At the end of file, the error is io.EOF. func (f *File) ReadAt(b []byte, off int64) (int, error) { - f.mu.RLock() - defer f.mu.RUnlock() + if f == nil { + return 0, fs.ErrInvalid + } - return f.readAt(context.Background(), b, off) + return f.readat(context.Background(), b, off) } // Read reads up to len(b) bytes from the File and stores them in b. // It returns the number of bytes read and any error encountered. // At end of file, Read returns 0, io.EOF. func (f *File) Read(b []byte) (int, error) { + if f == nil { + return 0, fs.ErrInvalid + } + f.mu.Lock() defer f.mu.Unlock() - n, err := f.readAt(context.Background(), b, f.offset) + n, err := f.readat(context.Background(), b, f.offset) f.offset += int64(n) @@ -1930,11 +2074,23 @@ func (f *File) Read(b []byte) (int, error) { } func (f *File) writeToSequential(w io.Writer) (written int64, err error) { + if f == nil { + return 0, fs.ErrInvalid + } + + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("writeto", err) + } + + f.mu.Lock() + defer f.mu.Unlock() + ctx := context.Background() b := make([]byte, f.cl.maxDataLen) req := &sshfx.ReadPacket{ - Handle: f.handle, + Handle: handle, Length: uint32(len(b)), } @@ -1945,7 +2101,7 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { for { req.Offset = uint64(f.offset) - read, err := f.cl.sendRead(ctx, req, &resp) + read, err := f.cl.sendRead(ctx, closed, req, &resp) if read < 0 { panic("sftp: writeto: sendRead returned negative count") @@ -1964,10 +2120,10 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { if err 
!= nil { if errors.Is(err, io.EOF) { - return written, nil // return nil explicitly. + return written, nil // return nil instead of EOF } - return written, &fs.PathError{Op: "readat", Path: f.name, Err: err} + return written, f.wrapErr("writeto", err) } } } @@ -1980,13 +2136,18 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { // to maximize throughput for transferring the entire file, // especially over high latency links. func (f *File) WriteTo(w io.Writer) (written int64, err error) { - f.mu.Lock() - defer f.mu.Unlock() + if f == nil { + return 0, fs.ErrInvalid + } - if f.handle == "" { - return 0, &fs.PathError{Op: "writeto", Path: f.name, Err: fs.ErrClosed} + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("writeto", err) } + f.mu.Lock() + defer f.mu.Unlock() + ctx := context.Background() chunkSize := f.cl.maxDataLen @@ -2019,13 +2180,13 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { ctx := sendCtx // shadow ctx so we cannot accidentally reference the parent context here. req := &sshfx.ReadPacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(f.offset), Length: uint32(chunkSize), } for { - reqid, res, err := f.cl.conn.dispatch(req) + reqid, res, err := f.cl.conn.dispatch(closed, req) if err != nil { writeErr = err return @@ -2084,14 +2245,14 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { if err := recvErr; err != nil { if errors.Is(err, io.EOF) { - return written, nil + return written, nil // return nil instead of EOF } - return written, &fs.PathError{Op: "readat", Path: f.name, Err: err} + return written, f.wrapErr("writeto", err) } } - return written, writeErr + return written, f.wrapErr("writeto", writeErr) } // WriteFile writes data to the named file, creating it if neccessary. @@ -2113,7 +2274,21 @@ func (cl *Client) WriteFile(name string, data []byte, perm fs.FileMode) error { // ReadFile reads the named file and returns the contents. 
 // A successful call returns err == nil, not err == EOF.
 // Because ReadFile reads the whole file, it does not treat an EOF from Read as an error to be reported.
+//
+// Note that ReadFile will call Stat on the file to get the file size,
+// in order to avoid unnecessary allocations before reading in all the data.
+// Some "read once" servers will delete the file if they receive a stat call on an open file,
+// and then the download will fail.
+//
+// TODO(puellanivis): Before release, we should resolve this, or have knobs to prevent it.
 func (cl *Client) ReadFile(name string) ([]byte, error) {
+	// TODO(puellanivis): we should use path.Split(), OpenDir() the parent, then use the FileInfo from readdir.
+	// With rangedir, we could even save on collecting all of the name entries to then search through them.
+	// This approach should work on read-once servers, even if the directory listing would be more expensive.
+	// Maybe include an UseFstat(false) option again to trigger it?
+	// There's a chance with case-insensitive servers, that Open(name) would work, but Glob(name) would not...
+	// so, we might not be able to universally apply it as the default.
+
 	f, err := cl.Open(name)
 	if err != nil {
 		return nil, err
 	}
@@ -2152,13 +2327,13 @@ const (
 // In some cases, this may mark a "mailbox"-style file as successfuly read,
 // and the server will delete the file, and return an error for all later operations.
 func (f *File) Seek(offset int64, whence int) (int64, error) {
+	if f == nil {
+		return 0, fs.ErrInvalid
+	}
+
 	f.mu.Lock()
 	defer f.mu.Unlock()
 
-	if f.handle == "" {
-		return 0, &fs.PathError{Op: "seek", Path: f.name, Err: fs.ErrClosed}
-	}
-
 	var abs int64
 	switch whence {
 	case SeekStart:
@@ -2172,19 +2347,11 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
 		}
 		abs = fi.Size() + offset
 	default:
-		return 0, &fs.PathError{
-			Op:   "seek",
-			Path: f.name,
-			Err:  fmt.Errorf("%w: invalid whence: %d", fs.ErrInvalid, whence),
-		}
+		return 0, f.wrapErr("seek", fmt.Errorf("%w: invalid whence: %d", fs.ErrInvalid, whence))
 	}
 
 	if offset < 0 {
-		return f.offset, &fs.PathError{
-			Op:   "seek",
-			Path: f.name,
-			Err:  fmt.Errorf("%w: negative offset: %d", fs.ErrInvalid, offset),
-		}
+		return 0, f.wrapErr("seek", fmt.Errorf("%w: negative offset: %d", fs.ErrInvalid, offset))
 	}
 
 	f.offset = abs
@@ -2198,16 +2365,22 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
 // then no request will be sent,
 // and Sync returns an *fs.PathError wrapping sshfx.StatusOpUnsupported.
func (f *File) Sync() error { - if !f.cl.hasExtension(openssh.ExtensionFSync()) { - return &fs.PathError{Op: "fsync", Path: f.name, Err: sshfx.StatusOpUnsupported} + if f == nil { + return fs.ErrInvalid } - err := f.cl.sendPacket(context.Background(), &openssh.FSyncExtendedPacket{ - Handle: f.handle, - }) + handle, closed, err := f.handle.get() if err != nil { - return &fs.PathError{Op: "fsync", Path: f.name, Err: err} + return f.wrapErr("fsync", err) } - return nil + if !f.cl.hasExtension(openssh.ExtensionFSync()) { + return f.wrapErr("fsync", sshfx.StatusOpUnsupported) + } + + return f.wrapErr("fsync", + f.cl.sendPacket(context.Background(), closed, &openssh.FSyncExtendedPacket{ + Handle: handle, + }), + ) } diff --git a/errno_plan9.go b/errno_plan9.go index a5840854..3c530658 100644 --- a/errno_plan9.go +++ b/errno_plan9.go @@ -18,7 +18,7 @@ func translateErrorString(errno syscall.ErrorString) sshfx.Status { case syscall.EACCES, syscall.EPERM: return sshfx.StatusPermissionDenied case syscall.EPLAN9: - return sshfx.StatusOPUnsupported + return sshfx.StatusOpUnsupported } return sshfx.StatusFailure diff --git a/localfs/file.go b/localfs/file.go index 69ce76ef..952845a9 100644 --- a/localfs/file.go +++ b/localfs/file.go @@ -3,6 +3,7 @@ package localfs import ( "cmp" "io/fs" + "iter" "os" "slices" "sync" @@ -21,7 +22,8 @@ type File struct { idLookup sftp.NameLookup mu sync.Mutex - dirErr error + lastErr error + lastEnt *sshfx.NameEntry entries []fs.FileInfo } @@ -43,45 +45,42 @@ func (f *File) Stat() (*sshfx.Attributes, error) { // rangedir returns an iterator over the directory entries of the directory. // It will only ever yield either a [fs.FileInfo] or an error, never both. -// No error will be yielded until all available FileInfos have been yielded, -// and thereafter the same error will be yielded indefinitely, -// however only one error will be yielded per invocation. 
-// If yield returns false, then the directory entry is considered unconsumed, -// and will be the first yield at the next call to rangedir. +// No error will be yielded until all available FileInfos have been yielded. +// Only one error will be yielded per invocation. // // We do not expose an iterator, because none has been standardized yet, // and we do not want to accidentally implement an API inconsistent with future standards. // However, for internal usage, we can separate the paginated Readdir code from the conversion to SFTP entries. // // Callers must guarantee synchronization by either holding the file lock, or holding an exclusive reference. -func (f *File) rangedir(yield func(fs.FileInfo, error) bool) { - for { - for i, entry := range f.entries { - if !yield(entry, nil) { - // This is break condition. - // As per our semantics, this means this entry has not been consumed. - // So we remove only the entries ahead of this one. - f.entries = slices.Delete(f.entries, 0, i) - return +func (f *File) rangedir(grow func(int)) iter.Seq2[fs.FileInfo, error] { + return func(yield func(fs.FileInfo, error) bool) { + for { + grow(len(f.entries)) + + for i, entry := range f.entries { + if !yield(entry, nil) { + // This is a break condition. + // We need to remove all entries that have been consumed, + // and that includes the one we are currently on. + f.entries = slices.Delete(f.entries, 0, i+1) + return + } } - } - // We have consumed all of the saved entries, so we remove everything. - f.entries = slices.Delete(f.entries, 0, len(f.entries)) + // We have consumed all of the saved entries, so we remove everything. + f.entries = slices.Delete(f.entries, 0, len(f.entries)) - if f.dirErr != nil { - // No need to try acquiring more entries, - // we’re already in the error state. 
- yield(nil, f.dirErr) - return - } + if f.lastErr != nil { + yield(nil, f.lastErr) + f.lastErr = nil + return + } - ents, err := f.Readdir(128) - if err != nil { - f.dirErr = err + // We cannot guarantee we only get entries, or an error, never both. + // So we need to just save these, and loop. + f.entries, f.lastErr = f.Readdir(128) } - - f.entries = ents } } @@ -91,8 +90,18 @@ func (f *File) ReadDir(maxDataLen uint32) (entries []*sshfx.NameEntry, err error f.mu.Lock() defer f.mu.Unlock() + if f.lastEnt != nil { + // Last ReadDir left an entry for us to include in this call. + entries = append(entries, f.lastEnt) + f.lastEnt = nil + } + + grow := func(more int) { + entries = slices.Grow(entries, more) + } + var size int - for fi, err := range f.rangedir { + for fi, err := range f.rangedir(grow) { if err != nil { if len(entries) != 0 { return entries, nil @@ -112,7 +121,10 @@ func (f *File) ReadDir(maxDataLen uint32) (entries []*sshfx.NameEntry, err error size += entry.Len() if size > int(maxDataLen) { - // rangedir will take care of starting the next range with this entry. + // This would exceed the packet data length, + // so save this one for the next call, + // and return. 
+ f.lastEnt = entry break } diff --git a/localfs/localfs_integration_test.go b/localfs/localfs_integration_test.go index dc47cb7a..77538f32 100644 --- a/localfs/localfs_integration_test.go +++ b/localfs/localfs_integration_test.go @@ -761,12 +761,12 @@ func benchHelperWriteTo(b *testing.B, length int) { target := filepath.Join(dir, "bench-writeto") remote := toRemotePath(target) - if err := os.WriteFile(remote, nil, 0644); err != nil { + if err := os.WriteFile(target, nil, 0644); err != nil { b.Fatal(err) } defer os.Remove(remote) - if err := os.Truncate(remote, int64(length)); err != nil { + if err := os.Truncate(target, int64(length)); err != nil { b.Fatal(err) } @@ -834,7 +834,7 @@ func benchHelperReadFrom(b *testing.B, length int) { target := filepath.Join(dir, "bench-readfrom") remote := toRemotePath(target) - if err := os.WriteFile(remote, nil, 0644); err != nil { + if err := os.WriteFile(target, nil, 0644); err != nil { b.Fatal(err) } defer os.Remove(remote) diff --git a/localfs/statvfs/statvfs_plan9.go b/localfs/statvfs/statvfs_plan9.go index da85aed7..bdc04f07 100644 --- a/localfs/statvfs/statvfs_plan9.go +++ b/localfs/statvfs/statvfs_plan9.go @@ -7,10 +7,10 @@ import ( "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" ) -// StatVFS stubs the OpenSSH StatVFS with an sshfx.StatusOPUnsupported Status. +// StatVFS stubs the OpenSSH StatVFS with an sshfx.StatusOpUnsupported Status. 
func StatVFS(name string) (*openssh.StatVFSExtendedReplyPacket, error) { return nil, &sshfx.StatusPacket{ - StatusCode: sshfx.StatusOPUnsupported, + StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: syscall.EPLAN9.Error(), } } diff --git a/localfs/statvfs/statvfs_stubs.go b/localfs/statvfs/statvfs_stubs.go index 73da61bc..7d66aafe 100644 --- a/localfs/statvfs/statvfs_stubs.go +++ b/localfs/statvfs/statvfs_stubs.go @@ -10,10 +10,10 @@ import ( "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" ) -// StatVFS stubs the OpenSSH StatVFS with an sshfx.StatusOPUnsupported Status. +// StatVFS stubs the OpenSSH StatVFS with an sshfx.StatusOpUnsupported Status. func StatVFS(name string) (*openssh.StatVFSExtendedReplyPacket, error) { return nil, &sshfx.StatusPacket{ - StatusCode: sshfx.StatusOPUnsupported, + StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: "not supported by " + runtime.GOOS, } } diff --git a/server.go b/server.go index d6ad6204..cd9a4c5d 100644 --- a/server.go +++ b/server.go @@ -192,8 +192,8 @@ type DirHandler interface { type wrapHandler func(ctx context.Context, req sshfx.Packet) (sshfx.Packet, error) -// handle is the intersection of FileHandler and DirHandler -type handle interface { +// commonHandle is the intersection of FileHandler and DirHandler +type commonHandle interface { io.Closer Name() string @@ -212,7 +212,7 @@ type Server struct { Debug io.Writer wg sync.WaitGroup - handles sync.Map[string, handle] + handles sync.Map[string, commonHandle] hijacks map[sshfx.PacketType]wrapHandler dataPktPool *sync.Pool[sshfx.DataPacket]