Skip to content

Commit

Permalink
Merge pull request #11291 from xianzhe-databricks:fix-uds-auth
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712981221
  • Loading branch information
gvisor-bot committed Jan 7, 2025
2 parents cd31d05 + c4f686f commit 7aa4c49
Show file tree
Hide file tree
Showing 12 changed files with 283 additions and 55 deletions.
5 changes: 4 additions & 1 deletion images/basic/integrationtest/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ RUN gcc -O2 -o tcp_server tcp_server.c

# Add nonprivileged regular user named "nonroot".
RUN groupadd --gid 1337 nonroot && \
useradd --uid 1337 --gid 1337 \
useradd --uid 1338 --gid 1337 \
--create-home \
--shell $(which bash) \
--password '' \
nonroot

# Copy host_connect to /home/nonroot so that "nonroot" can execute it.
RUN cp host_connect /home/nonroot/host_connect
33 changes: 26 additions & 7 deletions pkg/lisafs/client_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,32 @@ func (f *ClientFD) BindAt(ctx context.Context, sockType linux.SockType, name str
}

// Connect makes the Connect RPC.
func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType) (int, error) {
req := ConnectReq{FD: f.fd, SockType: uint32(sockType)}
var resp ConnectResp
var sockFD [1]int
ctx.UninterruptibleSleepStart(false)
err := f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.CheckedUnmarshal, sockFD[:], req.String, resp.String)
ctx.UninterruptibleSleepFinish(false)
func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType, euid UID, egid GID) (int, error) {
credsAvailable := euid != NoUID && egid != NoGID
var (
err error
sockFD [1]int
resp ConnectResp
req = ConnectReq{
FD: f.fd,
SockType: uint32(sockType),
}
)
if credsAvailable && f.client.IsSupported(ConnectWithCreds) {
reqWithCreds := ConnectWithCredsReq{
ConnectReq: req,
UID: euid,
GID: egid,
}
ctx.UninterruptibleSleepStart(false)
err = f.client.SndRcvMessage(ConnectWithCreds, uint32(reqWithCreds.SizeBytes()), reqWithCreds.MarshalUnsafe, resp.CheckedUnmarshal, sockFD[:], reqWithCreds.String, resp.String)
ctx.UninterruptibleSleepFinish(false)
} else {
ctx.UninterruptibleSleepStart(false)
err = f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.CheckedUnmarshal, sockFD[:], req.String, resp.String)
ctx.UninterruptibleSleepFinish(false)
}

if err == nil && sockFD[0] < 0 {
err = unix.EBADF
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/lisafs/fd.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,13 @@ type ControlFDImpl interface {
// On the server, Connect has a read concurrency guarantee.
Connect(sockType uint32) (int, error)

// ConnectWithCreds is a wrapper around Connect but first changes the gofer's
// euid and egid to the given uid and gid before calling Connect. It restores
// the euid and egid after Connect.
//
// On the server, ConnectWithCreds has a read concurrency guarantee.
ConnectWithCreds(sockType uint32, uid UID, gid GID) (int, error)

// BindAt creates a host unix domain socket of type sockType, bound to
// the given namt of type sockType, bound to the given name. It returns
// a ControlFD that can be used for path operations on the socket, a
Expand Down
95 changes: 63 additions & 32 deletions pkg/lisafs/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,38 +46,39 @@ const (
type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error)

var handlers = [...]RPCHandler{
Error: ErrorHandler,
Mount: MountHandler,
Channel: ChannelHandler,
FStat: FStatHandler,
SetStat: SetStatHandler,
Walk: WalkHandler,
WalkStat: WalkStatHandler,
OpenAt: OpenAtHandler,
OpenCreateAt: OpenCreateAtHandler,
Close: CloseHandler,
FSync: FSyncHandler,
PWrite: PWriteHandler,
PRead: PReadHandler,
MkdirAt: MkdirAtHandler,
MknodAt: MknodAtHandler,
SymlinkAt: SymlinkAtHandler,
LinkAt: LinkAtHandler,
FStatFS: FStatFSHandler,
FAllocate: FAllocateHandler,
ReadLinkAt: ReadLinkAtHandler,
Flush: FlushHandler,
UnlinkAt: UnlinkAtHandler,
RenameAt: RenameAtHandler,
Getdents64: Getdents64Handler,
FGetXattr: FGetXattrHandler,
FSetXattr: FSetXattrHandler,
FListXattr: FListXattrHandler,
FRemoveXattr: FRemoveXattrHandler,
Connect: ConnectHandler,
BindAt: BindAtHandler,
Listen: ListenHandler,
Accept: AcceptHandler,
Error: ErrorHandler,
Mount: MountHandler,
Channel: ChannelHandler,
FStat: FStatHandler,
SetStat: SetStatHandler,
Walk: WalkHandler,
WalkStat: WalkStatHandler,
OpenAt: OpenAtHandler,
OpenCreateAt: OpenCreateAtHandler,
Close: CloseHandler,
FSync: FSyncHandler,
PWrite: PWriteHandler,
PRead: PReadHandler,
MkdirAt: MkdirAtHandler,
MknodAt: MknodAtHandler,
SymlinkAt: SymlinkAtHandler,
LinkAt: LinkAtHandler,
FStatFS: FStatFSHandler,
FAllocate: FAllocateHandler,
ReadLinkAt: ReadLinkAtHandler,
Flush: FlushHandler,
UnlinkAt: UnlinkAtHandler,
RenameAt: RenameAtHandler,
Getdents64: Getdents64Handler,
FGetXattr: FGetXattrHandler,
FSetXattr: FSetXattrHandler,
FListXattr: FListXattrHandler,
FRemoveXattr: FRemoveXattrHandler,
Connect: ConnectHandler,
BindAt: BindAtHandler,
Listen: ListenHandler,
Accept: AcceptHandler,
ConnectWithCreds: ConnectWithCredsHandler,
}

// ErrorHandler handles Error message.
Expand Down Expand Up @@ -1069,6 +1070,36 @@ func ConnectHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32
return 0, nil
}

// ConnectWithCredsHandler handles the ConnectWithCreds RPC.
func ConnectWithCredsHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
var req ConnectWithCredsReq
if _, ok := req.CheckedUnmarshal(comm.PayloadBuf(payloadLen)); !ok {
return 0, unix.EIO
}

fd, err := c.lookupControlFD(req.FD)
if err != nil {
return 0, err
}
defer fd.DecRef(nil)
if !fd.IsSocket() {
return 0, unix.ENOTSOCK
}
var sock int
if err := fd.safelyRead(func() error {
if fd.node.isDeleted() {
return unix.EINVAL
}
sock, err = fd.impl.ConnectWithCreds(req.SockType, req.UID, req.GID)
return err
}); err != nil {
return 0, err
}

comm.DonateFD(sock)
return 0, nil
}

// BindAtHandler handles the BindAt RPC.
func BindAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
var req BindAtReq
Expand Down
19 changes: 19 additions & 0 deletions pkg/lisafs/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ const (

// Accept is analogous to accept4(2).
Accept MID = 31

// ConnectWithCreds is analogous to connect(2) but it asks the server
// to connect with the provided effective uid/gid.
ConnectWithCreds MID = 32
)

const (
Expand Down Expand Up @@ -1318,6 +1322,21 @@ func (*ConnectResp) String() string {
return "ConnectResp{}"
}

// ConnectWithCredsReq is used to make a ConnectWithCreds request. The response is also ConnectResp.
//
// +marshal boundCheck
type ConnectWithCredsReq struct {
ConnectReq
// UID and GID are used to specify the credentials to connect with.
UID UID
GID GID
}

// String implements fmt.Stringer.String.
func (c *ConnectWithCredsReq) String() string {
return fmt.Sprintf("ConnectWithCredsReq{FD: %d, SockType: %d, UID: %d, GID: %d}", c.FD, c.SockType, c.UID, c.GID)
}

// BindAtReq is used to make BindAt requests.
type BindAtReq struct {
createCommon
Expand Down
11 changes: 9 additions & 2 deletions pkg/sentry/fsimpl/gofer/dentry_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,18 @@ func (d *dentry) allocate(ctx context.Context, mode, offset, length uint64) erro
// - !d.isSynthetic().
// - fs.renameMu is locked.
func (d *dentry) connect(ctx context.Context, sockType linux.SockType) (int, error) {
creds := auth.CredentialsOrNilFromContext(ctx)
euid := lisafs.NoUID
egid := lisafs.NoGID
if creds != nil {
euid = lisafs.UID(creds.EffectiveKUID)
egid = lisafs.GID(creds.EffectiveKGID)
}
switch dt := d.impl.(type) {
case *lisafsDentry:
return dt.controlFD.Connect(ctx, sockType)
return dt.controlFD.Connect(ctx, sockType, euid, egid)
case *directfsDentry:
return dt.connect(ctx, sockType)
return dt.connect(ctx, sockType, euid, egid)
default:
panic("unknown dentry implementation")
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sentry/fsimpl/gofer/directfs_dentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,13 +603,13 @@ func (d *directfsDentry) getDirentsLocked(recordDirent func(name string, key ino
}

// Precondition: fs.renameMu is locked.
func (d *directfsDentry) connect(ctx context.Context, sockType linux.SockType) (int, error) {
func (d *directfsDentry) connect(ctx context.Context, sockType linux.SockType, euid lisafs.UID, egid lisafs.GID) (int, error) {
// There are no filesystems mounted in the sandbox process's mount namespace.
// So we can't perform absolute path traversals. So fallback to using lisafs.
if err := d.ensureLisafsControlFD(ctx); err != nil {
return -1, err
}
return d.controlFDLisa.Connect(ctx, sockType)
return d.controlFDLisa.Connect(ctx, sockType, euid, egid)
}

func (d *directfsDentry) readlink() (string, error) {
Expand Down
9 changes: 9 additions & 0 deletions pkg/sentry/kernel/auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ func CredentialsFromContext(ctx context.Context) *Credentials {
return NewAnonymousCredentials()
}

// CredentialsOrNilFromContext returns a copy of the Credentials used by ctx,
// or nil if ctx does not have Credentials.
func CredentialsOrNilFromContext(ctx context.Context) *Credentials {
if v := ctx.Value(CtxCredentials); v != nil {
return v.(*Credentials)
}
return nil
}

// ThreadGroupIDFromContext returns the current thread group ID when ctx
// represents a task context.
func ThreadGroupIDFromContext(ctx context.Context) (tgid int32, ok bool) {
Expand Down
31 changes: 28 additions & 3 deletions runsc/cmd/gofer.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ var caps = []string{
"CAP_SYS_CHROOT",
}

var udsOpenCaps = []string{
"CAP_SETUID",
"CAP_SETGID",
}

// goferCaps is the minimal set of capabilities needed by the Gofer to operate
// on files.
var goferCaps = &specs.LinuxCapabilities{
Expand All @@ -61,6 +66,12 @@ var goferCaps = &specs.LinuxCapabilities{
Permitted: caps,
}

var goferUdsOpenCaps = &specs.LinuxCapabilities{
Bounding: udsOpenCaps,
Effective: udsOpenCaps,
Permitted: udsOpenCaps,
}

// goferSyncFDs contains file descriptors that are used for synchronization
// of the Gofer startup process against other processes.
type goferSyncFDs struct {
Expand Down Expand Up @@ -181,7 +192,11 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...any) subcomm
overrides["apply-caps"] = "false"
overrides["setup-root"] = "false"
args := prepareArgs(g.Name(), f, overrides)
util.Fatalf("setCapsAndCallSelf(%v, %v): %v", args, goferCaps, setCapsAndCallSelf(args, goferCaps))
capsToApply := goferCaps
if conf.GetHostUDS().AllowOpen() {
capsToApply = specutils.MergeCapabilities(capsToApply, goferUdsOpenCaps)
}
util.Fatalf("setCapsAndCallSelf(%v, %v): %v", args, capsToApply, setCapsAndCallSelf(args, capsToApply))
panic("unreachable")
}

Expand Down Expand Up @@ -253,6 +268,12 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...any) subcomm
}
log.Infof("Process chroot'd to %q", root)

ruid := unix.Getuid()
euid := unix.Geteuid()
rgid := unix.Getgid()
egid := unix.Getegid()
log.Debugf("Process running as uid=%d euid=%d gid=%d egid=%d", ruid, euid, rgid, egid)

// Initialize filters.
opts := filter.Options{
UDSOpenEnabled: conf.GetHostUDS().AllowOpen(),
Expand All @@ -265,7 +286,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...any) subcomm
util.Fatalf("installing seccomp filters: %v", err)
}

return g.serve(spec, conf, root)
return g.serve(spec, conf, root, ruid, euid, rgid, egid)
}

func newSocket(ioFD int) *unet.Socket {
Expand All @@ -276,7 +297,7 @@ func newSocket(ioFD int) *unet.Socket {
return socket
}

func (g *Gofer) serve(spec *specs.Spec, conf *config.Config, root string) subcommands.ExitStatus {
func (g *Gofer) serve(spec *specs.Spec, conf *config.Config, root string, ruid int, euid int, rgid int, egid int) subcommands.ExitStatus {
type connectionConfig struct {
sock *unet.Socket
mountPath string
Expand All @@ -289,6 +310,10 @@ func (g *Gofer) serve(spec *specs.Spec, conf *config.Config, root string) subcom
HostUDS: conf.GetHostUDS(),
HostFifo: conf.HostFifo,
DonateMountPointFD: conf.DirectFS,
RUID: ruid,
EUID: euid,
RGID: rgid,
EGID: egid,
})

ioFDs := g.ioFDs
Expand Down
4 changes: 3 additions & 1 deletion runsc/fsgofer/filter/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ var udsCommonSyscalls = seccomp.MakeSyscallRules(map[uintptr]seccomp.SyscallRule
})

var udsOpenSyscalls = seccomp.MakeSyscallRules(map[uintptr]seccomp.SyscallRule{
unix.SYS_CONNECT: seccomp.MatchAll{},
unix.SYS_CONNECT: seccomp.MatchAll{},
unix.SYS_SETREUID: seccomp.MatchAll{},
unix.SYS_SETREGID: seccomp.MatchAll{},
})

var udsCreateSyscalls = seccomp.MakeSyscallRules(map[uintptr]seccomp.SyscallRule{
Expand Down
Loading

0 comments on commit 7aa4c49

Please sign in to comment.