diff --git a/client.go b/client.go index d8c48741..208b4b00 100644 --- a/client.go +++ b/client.go @@ -171,7 +171,14 @@ func (c *clientConn) recvLoop(maxPacket uint32) error { } } -func (c *clientConn) dispatch(req sshfx.PacketMarshaller) (uint32, chan result, error) { +// dispatch will marshal, then dispatch the given request packet. +// Packets are written atomically to the connection. +// It returns the allocated request id (a monotonously incrementing value), +// and either a channel upon which the result will be returned, or an error. +// +// If the cancel channel has been closed before the request is dipatched, +// then dispatch will return an [fs.ErrClosed] error. +func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller) (uint32, chan result, error) { reqid := c.reqid.Add(1) header, payload, err := req.MarshalPacket(reqid, c.bufPool.Get()) @@ -180,6 +187,9 @@ func (c *clientConn) dispatch(req sshfx.PacketMarshaller) (uint32, chan result, } defer c.bufPool.Put(header) + // payload by design of the API is all but guaranteed to alias a caller-held byte slice, + // so, _do not_ put it into the bufPool. + ch, ok := c.resPool.Get() if !ok { return reqid, nil, sshfx.StatusConnectionLost @@ -188,6 +198,13 @@ func (c *clientConn) dispatch(req sshfx.PacketMarshaller) (uint32, chan result, c.mu.Lock() defer c.mu.Unlock() + select { + case <-cancel: + c.resPool.Put(ch) + return reqid, nil, fs.ErrClosed + default: + } + if c.inflight == nil { c.inflight = make(map[uint32]chan<- result) } @@ -258,8 +275,8 @@ func (c *clientConn) recv(ctx context.Context, reqid uint32, ch chan result) (*s } } -func (c *clientConn) send(ctx context.Context, req sshfx.PacketMarshaller) (*sshfx.RawPacket, error) { - reqid, ch, err := c.dispatch(req) +func (c *clientConn) send(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (*sshfx.RawPacket, error) { + reqid, ch, err := c.dispatch(cancel, req) if err != nil { return nil, err } @@ -361,11 +378,13 @@ type Client struct { exts map[string]string } -func getPacket[PKT any, P interface { - sshfx.Packet +type respPacket[PKT any] interface { *PKT -}](ctx context.Context, cl *Client, req sshfx.PacketMarshaller) (*PKT, error) { - raw, err := cl.conn.send(ctx, req) + sshfx.Packet +} + +func getPacket[PKT any, P respPacket[PKT]](ctx context.Context, cancel <-chan struct{}, cl *Client, req sshfx.PacketMarshaller) (*PKT, error) { + raw, err := cl.conn.send(ctx, cancel, req) if err != nil { return nil, err } @@ -414,8 +433,8 @@ func statusToError(status *sshfx.StatusPacket, okExpected bool) error { return status } -func (cl *Client) sendPacket(ctx context.Context, req sshfx.PacketMarshaller) error { - reqid, ch, err := cl.conn.dispatch(req) +func (cl *Client) sendPacket(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) error { + reqid, ch, err := cl.conn.dispatch(cancel, req) if err != nil { return err } @@ -444,8 +463,8 @@ func (cl *Client) recvStatus(ctx context.Context, reqid uint32, ch chan result, } } -func (cl *Client) sendRead(ctx context.Context, req *sshfx.ReadPacket, resp *sshfx.DataPacket) (int, error) { - reqid, ch, err := cl.conn.dispatch(req) +func (cl *Client) sendRead(ctx context.Context, cancel <-chan struct{}, req *sshfx.ReadPacket, resp *sshfx.DataPacket) (int, error) { + reqid, ch, err := cl.conn.dispatch(cancel, req) if err != nil { return 0, err } @@ -583,22 +602,45 @@ func (cl *Client) Close() error { return nil } +func wrapPathError(op, path string, err error) error { + if err == nil { + return nil + } + + if errors.Is(err, io.EOF) { + // Numerous odd things break if we don't return bare io.EOF errors. + return io.EOF + } + + return &fs.PathError{Op: op, Path: path, Err: err} +} + +func wrapLinkError(op, oldpath, newpath string, err error) error { + if err == nil { + return nil + } + + if errors.Is(err, io.EOF) { + // Numerous odd things break if we don't return bare io.EOF errors. + return io.EOF + } + + return &os.LinkError{Op: op, Old: oldpath, New: newpath, Err: err} +} + // Mkdir creates the specified directory. // An error will be returned if a file or directory with the specified path already exists, // or if the directory's parent folder does not exist. func (cl *Client) Mkdir(name string, perm fs.FileMode) error { - err := cl.sendPacket(context.Background(), &sshfx.MkdirPacket{ - Path: name, - Attrs: sshfx.Attributes{ - Flags: sshfx.AttrPermissions, - Permissions: sshfx.FileMode(perm.Perm()), - }, - }) - if err != nil { - return &fs.PathError{Op: "mkdir", Path: name, Err: err} - } - - return nil + return wrapPathError("mkdir", name, + cl.sendPacket(context.Background(), nil, &sshfx.MkdirPacket{ + Path: name, + Attrs: sshfx.Attributes{ + Flags: sshfx.AttrPermissions, + Permissions: sshfx.FileMode(perm.Perm()), + }, + }), + ) } // MkdirAll creates a directory named path, along with any necessary parents. @@ -611,7 +653,7 @@ func (cl *Client) MkdirAll(name string, perm fs.FileMode) error { return nil } - return &fs.PathError{Op: "mkdir", Path: name, Err: syscall.ENOTDIR} + return wrapPathError("mkdir", name, syscall.ENOTDIR) } // Slow path: make sure parent exists and then call Mkdir for name. @@ -647,14 +689,14 @@ func (cl *Client) MkdirAll(name string, perm fs.FileMode) error { func (cl *Client) Remove(name string) error { ctx := context.Background() - err := cl.sendPacket(ctx, &sshfx.RemovePacket{ + err := cl.sendPacket(ctx, nil, &sshfx.RemovePacket{ Path: name, }) if err == nil { return nil } - err1 := cl.sendPacket(ctx, &sshfx.RmdirPacket{ + err1 := cl.sendPacket(ctx, nil, &sshfx.RmdirPacket{ Path: name, }) if err1 == nil { @@ -663,7 +705,7 @@ func (cl *Client) Remove(name string) error { // Both failed: figure out which error to return. if err != err1 { - attrs, err2 := getPacket[sshfx.AttrsPacket](ctx, cl, &sshfx.StatPacket{ + attrs, err2 := getPacket[sshfx.AttrsPacket](ctx, nil, cl, &sshfx.StatPacket{ Path: name, }) if err2 != nil { @@ -675,19 +717,16 @@ func (cl *Client) Remove(name string) error { } } - return &fs.PathError{Op: "remove", Path: name, Err: err} + return wrapPathError("remove", name, err) } func (cl *Client) setstat(ctx context.Context, name string, attrs *sshfx.Attributes) error { - err := cl.sendPacket(ctx, &sshfx.SetStatPacket{ - Path: name, - Attrs: *attrs, - }) - if err != nil { - return &fs.PathError{Op: "setstat", Path: name, Err: err} - } - - return nil + return wrapPathError("setstat", name, + cl.sendPacket(ctx, nil, &sshfx.SetStatPacket{ + Path: name, + Attrs: *attrs, + }), + ) } // Truncate changes the size of the named file. @@ -749,11 +788,11 @@ func (cl *Client) Chtimes(name string, atime, mtime time.Time) error { // This is useful for converting path names containing ".." components, // or relative pathnames without a leading slash into absolute paths. func (cl *Client) RealPath(name string) (string, error) { - pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), cl, &sshfx.RealPathPacket{ + pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), nil, cl, &sshfx.RealPathPacket{ Path: name, }) if err != nil { - return "", &fs.PathError{Op: "realpath", Path: name, Err: err} + return "", wrapPathError("realpath", name, err) } return pkt.Path, nil @@ -764,11 +803,11 @@ func (cl *Client) RealPath(name string) (string, error) { // The client cannot guarantee any specific way that a server handles a relative link destination. // That is, you may receive a relative link destination, one that has been converted to an absolute path. func (cl *Client) ReadLink(name string) (string, error) { - pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), cl, &sshfx.ReadLinkPacket{ + pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), nil, cl, &sshfx.ReadLinkPacket{ Path: name, }) if err != nil { - return "", &fs.PathError{Op: "readlink", Path: name, Err: err} + return "", wrapPathError("readlink", name, err) } return pkt.Path, nil @@ -780,40 +819,31 @@ func (cl *Client) ReadLink(name string) (string, error) { // Even within the same directory, on non-Unix servers Rename is not guaranteed to be an atomic operation. func (cl *Client) Rename(oldpath, newpath string) error { if cl.hasExtension(openssh.ExtensionPOSIXRename()) { - err := cl.sendPacket(context.Background(), &openssh.POSIXRenameExtendedPacket{ - OldPath: oldpath, - NewPath: newpath, - }) - if err != nil { - return &os.LinkError{Op: "rename", Old: oldpath, New: newpath, Err: err} - } - - return nil + return wrapLinkError("rename", oldpath, newpath, + cl.sendPacket(context.Background(), nil, &openssh.POSIXRenameExtendedPacket{ + OldPath: oldpath, + NewPath: newpath, + }), + ) } - err := cl.sendPacket(context.Background(), &sshfx.RenamePacket{ - OldPath: oldpath, - NewPath: newpath, - }) - if err != nil { - return &os.LinkError{Op: "rename", Old: oldpath, New: newpath, Err: err} - } - - return nil + return wrapLinkError("rename", oldpath, newpath, + cl.sendPacket(context.Background(), nil, &sshfx.RenamePacket{ + OldPath: oldpath, + NewPath: newpath, + }), + ) } // Symlink creates newname as a symbolic link to oldname. // There is no guarantee for how a server may handle the request if oldname does not exist. func (cl *Client) Symlink(oldname, newname string) error { - err := cl.sendPacket(context.Background(), &sshfx.SymlinkPacket{ - LinkPath: newname, - TargetPath: oldname, - }) - if err != nil { - return &os.LinkError{Op: "symlink", Old: oldname, New: newname, Err: err} - } - - return nil + return wrapLinkError("symlink", oldname, newname, + cl.sendPacket(context.Background(), nil, &sshfx.SymlinkPacket{ + LinkPath: newname, + TargetPath: oldname, + }), + ) } func (cl *Client) hasExtension(ext *sshfx.ExtensionPair) bool { @@ -827,18 +857,15 @@ func (cl *Client) hasExtension(ext *sshfx.ExtensionPair) bool { // and Link returns an *fs.LinkError wrapping sshfx.StatusOpUnsupported. func (cl *Client) Link(oldname, newname string) error { if !cl.hasExtension(openssh.ExtensionHardlink()) { - return &os.LinkError{Op: "hardlink", Old: oldname, New: newname, Err: sshfx.StatusOpUnsupported} - } - - err := cl.sendPacket(context.Background(), &openssh.HardlinkExtendedPacket{ - NewPath: newname, - OldPath: oldname, - }) - if err != nil { - return &os.LinkError{Op: "hardlink", Old: oldname, New: newname, Err: err} + return wrapLinkError("hardlink", oldname, newname, sshfx.StatusOpUnsupported) } - return nil + return wrapLinkError("hardlink", oldname, newname, + cl.sendPacket(context.Background(), nil, &openssh.HardlinkExtendedPacket{ + OldPath: oldname, + NewPath: newname, + }), + ) } // Readdir reads the named directory, returning all its directory entries as [fs.FileInfo] sorted by filename. @@ -886,12 +913,14 @@ func (cl *Client) ReadDirContext(ctx context.Context, name string) ([]fs.DirEntr return fis, err } -func (cl *Client) stat(name string) (*sshfx.NameEntry, error) { - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), cl, &sshfx.StatPacket{ +// Stat returns a FileInfo describing the named file. +// If the file is a symbolic link, the returned FileInfo describes the link's target. +func (cl *Client) Stat(name string) (fs.FileInfo, error) { + pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), nil, cl, &sshfx.StatPacket{ Path: name, }) if err != nil { - return nil, &fs.PathError{Op: "stat", Path: name, Err: err} + return nil, wrapPathError("stat", name, err) } return &sshfx.NameEntry{ @@ -900,23 +929,17 @@ func (cl *Client) stat(name string) (*sshfx.NameEntry, error) { }, nil } -// Stat returns a FileInfo describing the named file. -// If the file is a symbolic link, the returned FileInfo describes the link's target. -func (cl *Client) Stat(name string) (fs.FileInfo, error) { - return cl.stat(name) -} - // LStat returns a FileInfo describing the named file. // If the file is a symbolic link, the returned FileInfo describes the symbolic link // LStat makes no attempte to follow the link. // // The description returned may have server specific caveats and special cases that cannot be covered here. func (cl *Client) LStat(name string) (fs.FileInfo, error) { - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), cl, &sshfx.LStatPacket{ + pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), nil, cl, &sshfx.LStatPacket{ Path: name, }) if err != nil { - return nil, &fs.PathError{Op: "lstat", Path: name, Err: err} + return nil, wrapPathError("lstat", name, err) } return &sshfx.NameEntry{ @@ -925,6 +948,62 @@ func (cl *Client) LStat(name string) (fs.FileInfo, error) { }, nil } +type clHandle struct { + value atomic.Pointer[string] + closed chan struct{} +} + +func (h *clHandle) init(handle string) { + h.value.Store(&handle) + h.closed = make(chan struct{}) +} + +func (h *clHandle) get() (handle string, cancel <-chan struct{}, err error) { + p := h.value.Load() + if p == nil { + return "", nil, fs.ErrClosed + } + return *p, h.closed, nil +} + +func (h *clHandle) close(cl *Client) error { + // The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`, + // it will unconditionally mark the handle as unused, + // so we need to also unconditionally mark this handle as invalid. + // By invalidating our local copy of the handle, + // we ensure that there cannot be any new erroneous use-after-close receiver methods started after this swap. + handle := h.value.Swap(nil) + if handle == nil { + return fs.ErrClosed + } + + // The atomic Swap above ensures that only one Close can ever get here. + // We could also use a mutex to guarantee exclusivity here, + // but that would block Close until all synchronized operations have completed, + // some of which could be paused indefinitely. + // + // See: https://github.com/pkg/sftp/issues/603 for more details. + + // So, we have defended now against new receiver methods starting, + // but since an outstanding method could still be holding the handle, we still need a close signal. + // Since this close HAPPENS BEFORE the sendPacket below, + // this ensures that after closing this channel, no further requests will be dispatched. + // Meaning we know that the close request below will be the final request from this clHandle. + close(h.closed) + + // One might assume we could just simply use the closed channel alone, + // but because close panics if called twice, we need a select to test if the channel is already closed, + // and since there is a window of time between such a test and the close, two goroutines can race. + // So we still need to synchronize the close operation anyways, so either atomic pointer or mutex. + + // It should be obvious, but do not pass h.closed into this sendPacket, or it will never be sent. + // Less obviously, DO NOT pipe a context through this function to the sendPacket. + // We want to ensure that even in a closed-context codepath, that the SSH_FXP_CLOSED packet is still sent. + return cl.sendPacket(context.Background(), nil, &sshfx.ClosePacket{ + Handle: *handle, + }) +} + // Dir represents an open directory handle. // // The methods of Dir are safe for concurrent use. @@ -932,8 +1011,9 @@ type Dir struct { cl *Client name string + handle clHandle + mu sync.RWMutex - handle string entries []*sshfx.NameEntry } @@ -942,51 +1022,35 @@ type Dir struct { // // The semantics of SSH_FX_OPENDIR is such that the associated file handle is in a read-only mode. func (cl *Client) OpenDir(name string) (*Dir, error) { - return cl.openDir(context.Background(), name) -} - -func (cl *Client) openDir(ctx context.Context, name string) (*Dir, error) { - pkt, err := getPacket[sshfx.HandlePacket](ctx, cl, &sshfx.OpenDirPacket{ + pkt, err := getPacket[sshfx.HandlePacket](context.Background(), nil, cl, &sshfx.OpenDirPacket{ Path: name, }) if err != nil { - return nil, &fs.PathError{Op: "opendir", Path: name, Err: err} + return nil, wrapPathError("opendir", name, err) } - return &Dir{ - cl: cl, - name: name, - handle: pkt.Handle, - }, nil + d := &Dir{ + cl: cl, + name: name, + } + + d.handle.init(pkt.Handle) + + return d, nil +} + +func (d *Dir) wrapErr(op string, err error) error { + return wrapPathError(op, d.name, err) } // Close closes the Dir, rendering it unusable for I/O. // Close will not send any request, and return an error if it has already been called. func (d *Dir) Close() error { - d.mu.Lock() - defer d.mu.Unlock() - - if d.handle == "" { - return &fs.PathError{Op: "close", Path: d.name, Err: fs.ErrClosed} + if d == nil { + return os.ErrInvalid } - // The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`, - // it will unconditionally mark the handle as unused, - // so we need to also unconditionally mark this handle as invalid. - // By invalidating our local copy of the handle, - // we ensure that there cannot be any erroneous use-after-close requests sent after Close. - - handle := d.handle - d.handle = "" - - err := d.cl.sendPacket(context.Background(), &sshfx.ClosePacket{ - Handle: handle, - }) - if err != nil { - return &fs.PathError{Op: "close", Path: d.name, Err: err} - } - - return nil + return d.wrapErr("close", d.handle.close(d.cl)) } // Name returns the name of the directory as presented to OpenDir. @@ -1015,18 +1079,19 @@ func (d *Dir) rangedir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] { d.entries = slices.Delete(d.entries, 0, len(d.entries)) for { - pkt, err := getPacket[sshfx.NamePacket](ctx, d.cl, &sshfx.ReadDirPacket{ - Handle: d.handle, + handle, closed, err := d.handle.get() + if err != nil { + yield(nil, err) + return + } + + pkt, err := getPacket[sshfx.NamePacket](ctx, closed, d.cl, &sshfx.ReadDirPacket{ + Handle: handle, }) if err != nil { // There are no remaining entries to save here, // SFTP can only return either an error or a result, never both. - if errors.Is(err, io.EOF) { - yield(nil, io.EOF) - return - } - - yield(nil, &fs.PathError{Op: "readdir", Path: d.name, Err: err}) + yield(nil, err) return } @@ -1059,13 +1124,13 @@ func (d *Dir) Readdir(n int) ([]fs.FileInfo, error) { // If n <= 0, ReaddirContext returns all the FileInfo records remaining in the directory. // When it succeeds, it returns a nil error (not io.EOF). func (d *Dir) ReaddirContext(ctx context.Context, n int) ([]fs.FileInfo, error) { + if d == nil { + return nil, os.ErrInvalid + } + d.mu.Lock() defer d.mu.Unlock() - if d.handle == "" { - return nil, &fs.PathError{Op: "readdir", Path: d.name, Err: fs.ErrClosed} - } - var ret []fs.FileInfo for ent, err := range d.rangedir(ctx) { @@ -1074,7 +1139,7 @@ func (d *Dir) ReaddirContext(ctx context.Context, n int) ([]fs.FileInfo, error) return ret, nil } - return ret, err + return ret, d.wrapErr("readdir", err) } ret = append(ret, ent) @@ -1105,13 +1170,13 @@ func (d *Dir) ReadDir(n int) ([]fs.DirEntry, error) { // If n <= 0, ReadDirContext returns all the DirEntry records remaining in the directory. // When it succeeds, it returns a nil error (not io.EOF). func (d *Dir) ReadDirContext(ctx context.Context, n int) ([]fs.DirEntry, error) { + if d == nil { + return nil, os.ErrInvalid + } + d.mu.Lock() defer d.mu.Unlock() - if d.handle == "" { - return nil, &fs.PathError{Op: "readdir", Path: d.name, Err: fs.ErrClosed} - } - var ret []fs.DirEntry for ent, err := range d.rangedir(ctx) { @@ -1120,7 +1185,7 @@ func (d *Dir) ReadDirContext(ctx context.Context, n int) ([]fs.DirEntry, error) return ret, nil } - return ret, err + return ret, d.wrapErr("readdir", err) } ret = append(ret, ent) @@ -1140,8 +1205,9 @@ type File struct { cl *Client name string + handle clHandle + mu sync.RWMutex - handle string offset int64 // current offset within remote file } @@ -1210,7 +1276,7 @@ func (cl *Client) Create(name string) (*File, error) { // Note well: since all Write operations are down through an offset-specifying operation, // the OpenFlagAppend flag is currently ignored. func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, error) { - pkt, err := getPacket[sshfx.HandlePacket](context.Background(), cl, &sshfx.OpenPacket{ + pkt, err := getPacket[sshfx.HandlePacket](context.Background(), nil, cl, &sshfx.OpenPacket{ Filename: name, PFlags: toPortableFlags(flag), Attrs: sshfx.Attributes{ @@ -1219,43 +1285,31 @@ func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, erro }, }) if err != nil { - return nil, err + return nil, wrapPathError("openfile", name, err) } - return &File{ - cl: cl, - name: name, - handle: pkt.Handle, - }, nil + f := &File{ + cl: cl, + name: name, + } + + f.handle.init(pkt.Handle) + + return f, nil +} + +func (f *File) wrapErr(op string, err error) error { + return wrapPathError(op, f.name, err) } // Close closes the File, rendering it unusable for I/O. // Close will not send any request, and return an error if it has already been called. func (f *File) Close() error { - f.mu.Lock() - defer f.mu.Unlock() - - if f.handle == "" { - return &fs.PathError{Op: "close", Path: f.name, Err: fs.ErrClosed} + if f == nil { + return fs.ErrInvalid } - // The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`, - // it will unconditionally mark the handle as unused, - // so we need to also unconditionally mark this handle as invalid. - // By invalidating our local copy of the handle, - // we ensure that there cannot be any erroneous use-after-close requests sent after Close. - - handle := f.handle - f.handle = "" - - err := f.cl.sendPacket(context.Background(), &sshfx.ClosePacket{ - Handle: handle, - }) - if err != nil { - return &fs.PathError{Op: "close", Path: f.name, Err: err} - } - - return nil + return f.wrapErr("close", f.handle.close(f.cl)) } // Name returns the name of the file as presented to Open. @@ -1266,22 +1320,21 @@ func (f *File) Name() string { } func (f *File) setstat(ctx context.Context, attrs *sshfx.Attributes) error { - f.mu.Lock() - defer f.mu.Unlock() - - if f.handle == "" { - return &fs.PathError{Op: "fsetstat", Path: f.name, Err: fs.ErrClosed} + if f == nil { + return fs.ErrInvalid } - err := f.cl.sendPacket(ctx, &sshfx.FSetStatPacket{ - Handle: f.handle, - Attrs: *attrs, - }) + handle, closed, err := f.handle.get() if err != nil { - return &fs.PathError{Op: "fsetstat", Path: f.name, Err: err} + return f.wrapErr("fsetstat", err) } - return nil + return f.wrapErr("fsetstat", + f.cl.sendPacket(ctx, closed, &sshfx.FSetStatPacket{ + Handle: handle, + Attrs: *attrs, + }), + ) } // Truncate changes the size of the file. @@ -1327,12 +1380,22 @@ func (f *File) Chtimes(atime, mtime time.Time) error { }) } -func (f *File) stat() (*sshfx.NameEntry, error) { - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), f.cl, &sshfx.FStatPacket{ - Handle: f.handle, +// Stat returns the FileInfo structure describing file. +func (f *File) Stat() (fs.FileInfo, error) { + if f == nil { + return nil, fs.ErrInvalid + } + + handle, closed, err := f.handle.get() + if err != nil { + return nil, f.wrapErr("fstat", err) + } + + pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), closed, f.cl, &sshfx.FStatPacket{ + Handle: handle, }) if err != nil { - return nil, &fs.PathError{Op: "fstat", Path: f.name, Err: err} + return nil, f.wrapErr("fstat", err) } return &sshfx.NameEntry{ @@ -1341,21 +1404,14 @@ func (f *File) stat() (*sshfx.NameEntry, error) { }, nil } -// Stat returns the FileInfo structure describing file. -func (f *File) Stat() (fs.FileInfo, error) { - f.mu.Lock() - defer f.mu.Unlock() - - if f.handle == "" { - return nil, &fs.PathError{Op: "fstat", Path: f.name, Err: fs.ErrClosed} +func (f *File) writeatFull(ctx context.Context, b []byte, off int64) (written int, err error) { + handle, closed, err := f.handle.get() + if err != nil { + return 0, err } - return f.stat() -} - -func (f *File) writeAtFull(ctx context.Context, b []byte, off int64) (written int, err error) { req := &sshfx.WritePacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(off), } @@ -1366,9 +1422,9 @@ func (f *File) writeAtFull(ctx context.Context, b []byte, off int64) (written in req.Data, b = b[:n], b[n:] - err = f.cl.sendPacket(ctx, req) + err = f.cl.sendPacket(ctx, closed, req) if err != nil { - return written, &fs.PathError{Op: "writeat", Path: f.name, Err: err} + return written, f.wrapErr("writeat", err) } req.Offset += uint64(n) @@ -1378,15 +1434,16 @@ func (f *File) writeAtFull(ctx context.Context, b []byte, off int64) (written in return written, nil } -func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, err error) { - if f.handle == "" { - return 0, &fs.PathError{Op: "writeat", Path: f.name, Err: fs.ErrClosed} - } - +func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, err error) { if len(b) <= f.cl.maxDataLen { // This should be able to be serviced with just 1 request. // So, just do it directly. - return f.writeAtFull(ctx, b, off) + return f.writeatFull(ctx, b, off) + } + + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("writeat", err) } // Split the write into multiple maxPacket sized concurrent writes bounded by maxInflight. @@ -1419,7 +1476,7 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e chunkSize := f.cl.maxDataLen req := &sshfx.WritePacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(f.offset), } @@ -1428,7 +1485,7 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e req.Data, b = b[:n], b[n:] - reqid, res, err := f.cl.conn.dispatch(req) + reqid, res, err := f.cl.conn.dispatch(closed, req) if err != nil { errCh <- rwErr{req.Offset, err} return @@ -1488,11 +1545,11 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e // * the offset of the start of the first error received in response to a write packet. // * the offset of the start of the first error received dispatching a write packet offset. // - // Either way, this should be the last successfully write offset. - written := int(int64(firstErr.off) - f.offset) + // Either way, this should be the last successfully written offset. + written := int64(firstErr.off) - f.offset f.offset = int64(firstErr.off) - return written, firstErr.err + return int(written), f.wrapErr("writeat", firstErr.err) } // We didn’t hit any errors, so we must have written all the bytes in the buffer. @@ -1506,20 +1563,25 @@ func (f *File) writeAt(ctx context.Context, b []byte, off int64) (written int, e // It returns the number of bytes written and an error, if any. // WriteAt returns a non-nil error when n != len(b). func (f *File) WriteAt(b []byte, off int64) (n int, err error) { - f.mu.RLock() - defer f.mu.RUnlock() + if f == nil { + return 0, fs.ErrInvalid + } - return f.writeAt(context.Background(), b, off) + return f.writeat(context.Background(), b, off) } // Write writes len(b) bytes from b to the File. // It returns the number of bytes written and an error, if any. // Write returns a non-nil error when n != len(b) func (f *File) Write(b []byte) (int, error) { + if f == nil { + return 0, fs.ErrInvalid + } + f.mu.Lock() defer f.mu.Unlock() - n, err := f.writeAt(context.Background(), b, f.offset) + n, err := f.writeat(context.Background(), b, f.offset) f.offset += int64(n) return n, err @@ -1531,12 +1593,23 @@ func (f *File) WriteString(s string) (n int, err error) { return f.Write(b) } -func (f *File) readFromSequential(r io.Reader) (read int64, err error) { - ctx := context.Background() +func (f *File) readFromSequential(ctx context.Context, r io.Reader) (read int64, err error) { + if f == nil { + return 0, fs.ErrInvalid + } + + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("readfrom", err) + } + + f.mu.Lock() + defer f.mu.Unlock() + b := make([]byte, f.cl.maxDataLen) req := &sshfx.WritePacket{ - Handle: f.handle, + Handle: handle, } for { @@ -1551,7 +1624,7 @@ func (f *File) readFromSequential(r io.Reader) (read int64, err error) { req.Data = b[:n] req.Offset = uint64(f.offset) - err1 := f.cl.sendPacket(ctx, req) + err1 := f.cl.sendPacket(ctx, closed, req) if err1 == nil { // Only increment file offset, if we got a sucess back. f.offset += int64(n) @@ -1562,10 +1635,10 @@ func (f *File) readFromSequential(r io.Reader) (read int64, err error) { if err != nil { if errors.Is(err, io.EOF) { - return read, nil // return nil explicitly + return read, nil // return nil instead of EOF } - return read, err + return read, f.wrapErr("readfrom", err) } } } @@ -1584,13 +1657,18 @@ func (e panicInstead) Error() string { // to maximize throughput when transferring an entire file, // especially over high-latency links. func (f *File) ReadFrom(r io.Reader) (read int64, err error) { - f.mu.Lock() - defer f.mu.Unlock() + if f == nil { + return 0, fs.ErrInvalid + } - if f.handle == "" { - return 0, fs.ErrClosed + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("readfrom", err) } + f.mu.Lock() + defer f.mu.Unlock() + ctx := context.Background() chunkSize := f.cl.maxDataLen @@ -1620,7 +1698,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { defer f.cl.conn.bufPool.Put(b) req := &sshfx.WritePacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(f.offset), } @@ -1635,7 +1713,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { read += int64(n) req.Data = b[:n] - reqid, res, err1 := f.cl.conn.dispatch(req) + reqid, res, err1 := f.cl.conn.dispatch(closed, req) if err1 == nil { // If _NO_ error occurred during dispatch. select { case workCh <- work{reqid, res, req.Offset}: @@ -1711,7 +1789,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { } // ReadFrom is defined to return the read bytes, regardless of any write errors. - return read, firstErr.err + return read, f.wrapErr("readfrom", firstErr.err) } // We didn’t hit any errors, so we must have written all the bytes that we read until EOF. @@ -1719,13 +1797,18 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { return read, nil } -// readAtFull attempts to read the whole entire length of the buffer from the file starting at the offset. +// readatFull attempts to read the whole entire length of the buffer from the file starting at the offset. // It will continue progressively reading into the buffer until it fills the whole buffer, or an error occurs. // // This is prefered over io.ReadFull, because it can reuse read and data packet allocations. -func (f *File) readAtFull(ctx context.Context, b []byte, off int64) (read int, err error) { +func (f *File) readatFull(ctx context.Context, b []byte, off int64) (read int, err error) { + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("readat", err) + } + req := &sshfx.ReadPacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(off), } @@ -1744,7 +1827,7 @@ func (f *File) readAtFull(ctx context.Context, b []byte, off int64) (read int, e // Otherwise, we would need to use unsafe.SliceData to identify a reallocation. resp.Data = slices.Clip(b[:n]) - m, err := f.cl.sendRead(ctx, req, &resp) + m, err := f.cl.sendRead(ctx, closed, req, &resp) if m > n { // OH NO! We received more data than we expected! @@ -1758,26 +1841,23 @@ func (f *File) readAtFull(ctx context.Context, b []byte, off int64) (read int, e read += m if err != nil { - if errors.Is(err, io.EOF) { - return read, io.EOF // io.Copy does not allow this to be wrapped. - } - - return read, &fs.PathError{Op: "readat", Path: f.name, Err: err} + return read, f.wrapErr("readat", err) } } return read, nil } -func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err error) { - if f.handle == "" { - return 0, &fs.PathError{Op: "readat", Path: f.name, Err: fs.ErrClosed} - } - +func (f *File) readat(ctx context.Context, b []byte, off int64) (read int, err error) { if len(b) <= f.cl.maxDataLen { // This should be able to be serviced most times with only 1 request. // So, just do it sequentially. - return f.readAtFull(ctx, b, off) + return f.readatFull(ctx, b, off) + } + + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("readat", err) } sendCtx, cancel := context.WithCancel(ctx) @@ -1812,7 +1892,7 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e chunkSize := f.cl.maxDataLen req := &sshfx.ReadPacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(off), } @@ -1821,7 +1901,7 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e req.Length = uint32(n) - reqid, res, err := f.cl.conn.dispatch(req) + reqid, res, err := f.cl.conn.dispatch(closed, req) if err != nil { errCh <- rwErr{req.Offset, err} return @@ -1853,7 +1933,7 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e var resp sshfx.DataPacket for work := range workCh { - // See readAtFull for an explanation for why we use slices.Clip here. + // See readatFull for an explanation for why we use slices.Clip here. resp.Data = slices.Clip(work.b) n, err := f.cl.recvData(ctx, work.reqid, work.res, &resp) @@ -1893,7 +1973,7 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e if firstErr.err != nil { // firstErr.err != nil if and only if firstErr.off > our starting offset. - return int(int64(firstErr.off) - off), firstErr.err + return int(int64(firstErr.off) - off), f.wrapErr("readat", firstErr.err) } // As per spec for io.ReaderAt, we return nil error if and only if we read everything. @@ -1905,20 +1985,25 @@ func (f *File) readAt(ctx context.Context, b []byte, off int64) (read int, err e // ReadAt always returns a non-nil error when n < len(b). // At the end of file, the error is io.EOF. func (f *File) ReadAt(b []byte, off int64) (int, error) { - f.mu.RLock() - defer f.mu.RUnlock() + if f == nil { + return 0, fs.ErrInvalid + } - return f.readAt(context.Background(), b, off) + return f.readat(context.Background(), b, off) } // Read reads up to len(b) bytes from the File and stores them in b. // It returns the number of bytes read and any error encountered. // At end of file, Read returns 0, io.EOF. func (f *File) Read(b []byte) (int, error) { + if f == nil { + return 0, fs.ErrInvalid + } + f.mu.Lock() defer f.mu.Unlock() - n, err := f.readAt(context.Background(), b, f.offset) + n, err := f.readat(context.Background(), b, f.offset) f.offset += int64(n) @@ -1930,11 +2015,23 @@ func (f *File) Read(b []byte) (int, error) { } func (f *File) writeToSequential(w io.Writer) (written int64, err error) { + if f == nil { + return 0, fs.ErrInvalid + } + + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("writeto", err) + } + + f.mu.Lock() + defer f.mu.Unlock() + ctx := context.Background() b := make([]byte, f.cl.maxDataLen) req := &sshfx.ReadPacket{ - Handle: f.handle, + Handle: handle, Length: uint32(len(b)), } @@ -1945,7 +2042,7 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { for { req.Offset = uint64(f.offset) - read, err := f.cl.sendRead(ctx, req, &resp) + read, err := f.cl.sendRead(ctx, closed, req, &resp) if read < 0 { panic("sftp: writeto: sendRead returned negative count") @@ -1964,10 +2061,10 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { if err != nil { if errors.Is(err, io.EOF) { - return written, nil // return nil explicitly. + return written, nil // return nil instead of EOF } - return written, &fs.PathError{Op: "readat", Path: f.name, Err: err} + return written, f.wrapErr("writeto", err) } } } @@ -1980,13 +2077,18 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { // to maximize throughput for transferring the entire file, // especially over high latency links. func (f *File) WriteTo(w io.Writer) (written int64, err error) { - f.mu.Lock() - defer f.mu.Unlock() + if f == nil { + return 0, fs.ErrInvalid + } - if f.handle == "" { - return 0, &fs.PathError{Op: "writeto", Path: f.name, Err: fs.ErrClosed} + handle, closed, err := f.handle.get() + if err != nil { + return 0, f.wrapErr("writeto", err) } + f.mu.Lock() + defer f.mu.Unlock() + ctx := context.Background() chunkSize := f.cl.maxDataLen @@ -2019,13 +2121,13 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { ctx := sendCtx // shadow ctx so we cannot accidentally reference the parent context here. req := &sshfx.ReadPacket{ - Handle: f.handle, + Handle: handle, Offset: uint64(f.offset), Length: uint32(chunkSize), } for { - reqid, res, err := f.cl.conn.dispatch(req) + reqid, res, err := f.cl.conn.dispatch(closed, req) if err != nil { writeErr = err return @@ -2084,14 +2186,14 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { if err := recvErr; err != nil { if errors.Is(err, io.EOF) { - return written, nil + return written, nil // return nil instead of EOF } - return written, &fs.PathError{Op: "readat", Path: f.name, Err: err} + return written, f.wrapErr("writeto", err) } } - return written, writeErr + return written, f.wrapErr("writeto", writeErr) } // WriteFile writes data to the named file, creating it if neccessary. @@ -2113,7 +2215,21 @@ func (cl *Client) WriteFile(name string, data []byte, perm fs.FileMode) error { // ReadFile reads the named file and returns the contents. // A successful call returns err == nil, not err == EOF. // Because ReadFile reads the whole file, it does not treat an EOF from Read as an error to be reported. +// +// Note that ReadFile will call Stat on the file to get the file size, +// in order to avoid unnecessary allocations before reading in all the data. +// Some "read once" servers will delete the file if they recceive a stat call on an open file, +// and then the download will fail. +// +// TODO(puellannivis): Before release, we should resolve this, or have knobs to prevent it. func (cl *Client) ReadFile(name string) ([]byte, error) { + // TODO(puellanivis): we should use path.Split(), OpenDir() the parent, then use the FileInfo from readdir. + // With rangedir, we could even save on collecting all of the name entries to then search through them. + // This approach should work on read-once servers, even if the directory listing would be more expensive. + // Maybe include an UseFstat(false) option again to trigger it? + // There's a chance with case-insensitive servers, that Open(name) would work, but Glob(name) would not... + // so, we might not be able to universally apply it as the default. + f, err := cl.Open(name) if err != nil { return nil, err @@ -2152,13 +2268,13 @@ const ( // In some cases, this may mark a "mailbox"-style file as successfuly read, // and the server will delete the file, and return an error for all later operations. func (f *File) Seek(offset int64, whence int) (int64, error) { + if f == nil { + return 0, fs.ErrInvalid + } + f.mu.Lock() defer f.mu.Unlock() - if f.handle == "" { - return 0, &fs.PathError{Op: "seek", Path: f.name, Err: fs.ErrClosed} - } - var abs int64 switch whence { case SeekStart: @@ -2172,19 +2288,11 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { } abs = fi.Size() + offset default: - return 0, &fs.PathError{ - Op: "seek", - Path: f.name, - Err: fmt.Errorf("%w: invalid whence: %d", fs.ErrInvalid, whence), - } + return 0, f.wrapErr("seek", fmt.Errorf("%w: invalid whence: %d", fs.ErrInvalid, whence)) } if offset < 0 { - return f.offset, &fs.PathError{ - Op: "seek", - Path: f.name, - Err: fmt.Errorf("%w: negative offset: %d", fs.ErrInvalid, offset), - } + return 0, f.wrapErr("seek", fmt.Errorf("%w: negative offset: %d", fs.ErrInvalid, whence)) } f.offset = abs @@ -2198,16 +2306,22 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { // then no request will be sent, // and Sync returns an *fs.PathError wrapping sshfx.StatusOpUnsupported. func (f *File) Sync() error { - if !f.cl.hasExtension(openssh.ExtensionFSync()) { - return &fs.PathError{Op: "fsync", Path: f.name, Err: sshfx.StatusOpUnsupported} + if f == nil { + return fs.ErrInvalid } - err := f.cl.sendPacket(context.Background(), &openssh.FSyncExtendedPacket{ - Handle: f.handle, - }) + handle, closed, err := f.handle.get() if err != nil { - return &fs.PathError{Op: "fsync", Path: f.name, Err: err} + return f.wrapErr("fsync", err) } - return nil + if !f.cl.hasExtension(openssh.ExtensionFSync()) { + return f.wrapErr("fsync", sshfx.StatusOpUnsupported) + } + + return f.wrapErr("fsync", + f.cl.sendPacket(context.Background(), closed, &openssh.FSyncExtendedPacket{ + Handle: handle, + }), + ) } diff --git a/server.go b/server.go index d6ad6204..d63157a1 100644 --- a/server.go +++ b/server.go @@ -192,8 +192,8 @@ type DirHandler interface { type wrapHandler func(ctx context.Context, req sshfx.Packet) (sshfx.Packet, error) -// handle is the intersection of FileHandler and DirHandler -type handle interface { +// srvHandle is the intersection of FileHandler and DirHandler +type srvHandle interface { io.Closer Name() string @@ -212,7 +212,7 @@ type Server struct { Debug io.Writer wg sync.WaitGroup - handles sync.Map[string, handle] + handles sync.Map[string, srvHandle] hijacks map[sshfx.PacketType]wrapHandler dataPktPool *sync.Pool[sshfx.DataPacket]