From 98fad757bf40bb3f75b34e968af5326da70815ec Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 15 Nov 2024 15:18:57 +0000 Subject: [PATCH] refactor getPacket to separate per-response-packet-types that even unpack the relevant field that we're actually interested in --- client.go | 106 ++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 71 insertions(+), 35 deletions(-) diff --git a/client.go b/client.go index 72a96911..9e435001 100644 --- a/client.go +++ b/client.go @@ -383,14 +383,14 @@ type respPacket[PKT any] interface { sshfx.Packet } -func getPacket[PKT any, P respPacket[PKT]](ctx context.Context, cancel <-chan struct{}, cl *Client, req sshfx.PacketMarshaller) (*PKT, error) { +func getPacket[RESP respPacket[PKT], PKT any](ctx context.Context, cancel <-chan struct{}, cl *Client, req sshfx.PacketMarshaller) (*PKT, error) { raw, err := cl.conn.send(ctx, cancel, req) if err != nil { return nil, err } defer cl.conn.returnRaw(raw) - var resp P + var resp RESP switch raw.PacketType { case resp.Type(): @@ -414,6 +414,38 @@ func getPacket[PKT any, P respPacket[PKT]](ctx context.Context, cancel <-chan st } } +func (cl *Client) getPath(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (string, error) { + resp, err := getPacket[*sshfx.PathPseudoPacket](ctx, cancel, cl, req) + if err != nil { + return "", err + } + return resp.Path, nil +} + +func (cl *Client) getHandle(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (string, error) { + resp, err := getPacket[*sshfx.HandlePacket](ctx, cancel, cl, req) + if err != nil { + return "", err + } + return resp.Handle, nil +} + +func (cl *Client) getNames(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) ([]*sshfx.NameEntry, error) { + resp, err := getPacket[*sshfx.NamePacket](ctx, cancel, cl, req) + if err != nil { + return nil, err + } + return resp.Entries, nil +} + +func (cl *Client) getAttrs(ctx context.Context, cancel <-chan struct{}, req sshfx.PacketMarshaller) (*sshfx.Attributes, error) { + resp, err := getPacket[*sshfx.AttrsPacket](ctx, cancel, cl, req) + if err != nil { + return nil, err + } + return &resp.Attrs, nil +} + func statusToError(status *sshfx.StatusPacket, okExpected bool) error { switch status.StatusCode { case sshfx.StatusOK: @@ -689,35 +721,39 @@ func (cl *Client) MkdirAll(name string, perm fs.FileMode) error { func (cl *Client) Remove(name string) error { ctx := context.Background() - err := cl.sendPacket(ctx, nil, &sshfx.RemovePacket{ + errFile := cl.sendPacket(ctx, nil, &sshfx.RemovePacket{ Path: name, }) - if err == nil { + if errFile == nil { return nil } - err1 := cl.sendPacket(ctx, nil, &sshfx.RmdirPacket{ + errDir := cl.sendPacket(ctx, nil, &sshfx.RmdirPacket{ Path: name, }) - if err1 == nil { + if errDir == nil { return nil } // Both failed: figure out which error to return. - if err != err1 { - attrs, err2 := getPacket[sshfx.AttrsPacket](ctx, nil, cl, &sshfx.StatPacket{ - Path: name, - }) - if err2 != nil { - err = err2 - } else { - if perm, ok := attrs.Attrs.GetPermissions(); ok && perm.IsDir() { - err = err1 - } - } + + if errFile == errDir { + // If they are the same error, then just return that. + return wrapPathError("remove", name, errFile) + } + + attrs, err := cl.getAttrs(ctx, nil, &sshfx.StatPacket{ + Path: name, + }) + if err != nil { + return wrapPathError("remove", name, err) + } + + if perm, ok := attrs.GetPermissions(); ok && perm.IsDir() { + return wrapPathError("remove", name, errDir) } - return wrapPathError("remove", name, err) + return wrapPathError("remove", name, errFile) } func (cl *Client) setstat(ctx context.Context, name string, attrs *sshfx.Attributes) error { @@ -788,14 +824,14 @@ func (cl *Client) Chtimes(name string, atime, mtime time.Time) error { // This is useful for converting path names containing ".." components, // or relative pathnames without a leading slash into absolute paths. func (cl *Client) RealPath(name string) (string, error) { - pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), nil, cl, &sshfx.RealPathPacket{ + path, err := cl.getPath(context.Background(), nil, &sshfx.RealPathPacket{ Path: name, }) if err != nil { return "", wrapPathError("realpath", name, err) } - return pkt.Path, nil + return path, nil } // ReadLink returns the destination of the named symbolic link. @@ -803,14 +839,14 @@ func (cl *Client) RealPath(name string) (string, error) { // The client cannot guarantee any specific way that a server handles a relative link destination. // That is, you may receive a relative link destination, one that has been converted to an absolute path. func (cl *Client) ReadLink(name string) (string, error) { - pkt, err := getPacket[sshfx.PathPseudoPacket](context.Background(), nil, cl, &sshfx.ReadLinkPacket{ + path, err := cl.getPath(context.Background(), nil, &sshfx.ReadLinkPacket{ Path: name, }) if err != nil { return "", wrapPathError("readlink", name, err) } - return pkt.Path, nil + return path, nil } // Rename renames (moves) oldpath to newpath. @@ -916,7 +952,7 @@ func (cl *Client) ReadDirContext(ctx context.Context, name string) ([]fs.DirEntr // Stat returns a FileInfo describing the named file. // If the file is a symbolic link, the returned FileInfo describes the link's target. func (cl *Client) Stat(name string) (fs.FileInfo, error) { - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), nil, cl, &sshfx.StatPacket{ + attrs, err := cl.getAttrs(context.Background(), nil, &sshfx.StatPacket{ Path: name, }) if err != nil { @@ -925,7 +961,7 @@ func (cl *Client) Stat(name string) (fs.FileInfo, error) { return &sshfx.NameEntry{ Filename: name, - Attrs: pkt.Attrs, + Attrs: *attrs, }, nil } @@ -935,7 +971,7 @@ func (cl *Client) Stat(name string) (fs.FileInfo, error) { // // The description returned may have server specific caveats and special cases that cannot be covered here. func (cl *Client) LStat(name string) (fs.FileInfo, error) { - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), nil, cl, &sshfx.LStatPacket{ + attrs, err := cl.getAttrs(context.Background(), nil, &sshfx.LStatPacket{ Path: name, }) if err != nil { @@ -944,7 +980,7 @@ func (cl *Client) LStat(name string) (fs.FileInfo, error) { return &sshfx.NameEntry{ Filename: name, - Attrs: pkt.Attrs, + Attrs: *attrs, }, nil } @@ -1022,7 +1058,7 @@ type Dir struct { // // The semantics of SSH_FX_OPENDIR is such that the associated file handle is in a read-only mode. func (cl *Client) OpenDir(name string) (*Dir, error) { - pkt, err := getPacket[sshfx.HandlePacket](context.Background(), nil, cl, &sshfx.OpenDirPacket{ + handle, err := cl.getHandle(context.Background(), nil, &sshfx.OpenDirPacket{ Path: name, }) if err != nil { @@ -1034,7 +1070,7 @@ func (cl *Client) OpenDir(name string) (*Dir, error) { name: name, } - d.handle.init(pkt.Handle) + d.handle.init(handle) return d, nil } @@ -1085,7 +1121,7 @@ func (d *Dir) rangedir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] { return } - pkt, err := getPacket[sshfx.NamePacket](ctx, closed, d.cl, &sshfx.ReadDirPacket{ + entries, err := d.cl.getNames(ctx, closed, &sshfx.ReadDirPacket{ Handle: handle, }) if err != nil { @@ -1095,10 +1131,10 @@ func (d *Dir) rangedir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] { return } - for i, entry := range pkt.Entries { + for i, entry := range entries { if !yield(entry, nil) { // Early break, save the remaining entries we got for maybe later. - d.entries = append(d.entries, pkt.Entries[i+1:]...) + d.entries = append(d.entries, entries[i+1:]...) return } } @@ -1276,7 +1312,7 @@ func (cl *Client) Create(name string) (*File, error) { // Note well: since all Write operations are down through an offset-specifying operation, // the OpenFlagAppend flag is currently ignored. func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, error) { - pkt, err := getPacket[sshfx.HandlePacket](context.Background(), nil, cl, &sshfx.OpenPacket{ + handle, err := cl.getHandle(context.Background(), nil, &sshfx.OpenPacket{ Filename: name, PFlags: toPortableFlags(flag), Attrs: sshfx.Attributes{ @@ -1293,7 +1329,7 @@ func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, erro name: name, } - f.handle.init(pkt.Handle) + f.handle.init(handle) return f, nil } @@ -1391,7 +1427,7 @@ func (f *File) Stat() (fs.FileInfo, error) { return nil, f.wrapErr("fstat", err) } - pkt, err := getPacket[sshfx.AttrsPacket](context.Background(), closed, f.cl, &sshfx.FStatPacket{ + attrs, err := f.cl.getAttrs(context.Background(), closed, &sshfx.FStatPacket{ Handle: handle, }) if err != nil { @@ -1400,7 +1436,7 @@ func (f *File) Stat() (fs.FileInfo, error) { return &sshfx.NameEntry{ Filename: f.name, - Attrs: pkt.Attrs, + Attrs: *attrs, }, nil }