diff --git a/packet.go b/packet.go index d89ad997..7a0d0027 100644 --- a/packet.go +++ b/packet.go @@ -667,12 +667,13 @@ type sshFxpOpenPacket struct { ID uint32 Path string Pflags uint32 - Flags uint32 // ignored + Flags uint32 + Attrs interface{} } func (p *sshFxpOpenPacket) id() uint32 { return p.ID } -func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { +func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) { l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + 4 + 4 @@ -684,7 +685,14 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { b = marshalUint32(b, p.Pflags) b = marshalUint32(b, p.Flags) - return b, nil + payload := marshal(nil, p.Attrs) + + return b, payload, nil +} + +func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { @@ -695,9 +703,10 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { return err } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil { return err - } else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil { + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { return err } + p.Attrs = b return nil } @@ -869,13 +878,14 @@ func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) error { type sshFxpMkdirPacket struct { ID uint32 - Flags uint32 // ignored Path string + Flags uint32 + Attrs interface{} } func (p *sshFxpMkdirPacket) id() uint32 { return p.ID } -func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { +func (p *sshFxpMkdirPacket) marshalPacket() ([]byte, []byte, error) { l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + 4 // uint32 @@ -886,7 +896,14 @@ func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { b = marshalString(b, p.Path) b = marshalUint32(b, p.Flags) - return b, nil + payload := marshal(nil, p.Attrs) + + return b, payload, nil +} + +func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error { @@ -895,9 +912,10 @@ func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error { return err } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { return err - } else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil { + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { return err } + p.Attrs = b return nil } diff --git a/server.go b/server.go index 503454e6..f85f1d4a 100644 --- a/server.go +++ b/server.go @@ -19,6 +19,9 @@ import ( const ( // SftpServerWorkerCount defines the number of workers for the SFTP server SftpServerWorkerCount = 8 + + defaultFileMode = 0o644 + defaultDirMode = 0o755 ) // Server is an SSH File Transfer Protocol (sftp) server. @@ -218,8 +221,15 @@ func handlePacket(s *Server, p orderedRequest) error { rpkt = statusFromError(p.ID, err) } case *sshFxpMkdirPacket: - // TODO FIXME: ignore flags field - err := os.Mkdir(s.toLocalPath(p.Path), 0o755) + var mode os.FileMode = defaultDirMode + if p.Attrs != nil { + attrs, _ := unmarshalFileStat(p.Flags, p.Attrs.([]byte)) + if p.Flags&sshFileXferAttrPermissions != 0 { + mode = toFileMode(attrs.Mode) + } + } + + err := os.Mkdir(s.toLocalPath(p.Path), mode) rpkt = statusFromError(p.ID, err) case *sshFxpRmdirPacket: err := os.Remove(s.toLocalPath(p.Path)) @@ -458,7 +468,15 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { osFlags |= os.O_EXCL } - f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644) + var mode os.FileMode = defaultFileMode + if p.Attrs != nil { + attrs, _ := unmarshalFileStat(p.Flags, p.Attrs.([]byte)) + if p.Flags&sshFileXferAttrPermissions != 0 { + mode = toFileMode(attrs.Mode) + } + } + + f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode) if err != nil { return statusFromError(p.ID, err) }