Skip to content

Commit

Permalink
Merge pull request #35 from loopholelabs/improve-context-signaling
Browse files Browse the repository at this point in the history
Improve Signaling Logic
  • Loading branch information
pojntfx authored Aug 23, 2024
2 parents ab3faed + bf110eb commit e712a91
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 41 deletions.
24 changes: 16 additions & 8 deletions cmd/drafter-mounter/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,22 +355,30 @@ func main() {
if err != nil {
panic(err)
}
defer lis.Close()
defer func() {
defer goroutineManager.CreateForegroundPanicCollector()()

closeLock.Lock()
defer closeLock.Unlock()

closed = true

closeLock.Unlock()

if err := lis.Close(); err != nil {
panic(err)
}
}()

log.Println("Serving on", lis.Addr())

l:
for {
// We use `context.Background` here because we want to distinguish between a cancellation and a successful accept
// We select between `acceptedCtx` and `ctx` on all code paths so we don't leak the context
acceptedCtx, cancelAcceptedCtx := context.WithCancel(context.Background())
defer cancelAcceptedCtx()
var (
ready = make(chan struct{})
signalReady = sync.OnceFunc(func() {
close(ready) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})
)

var conn net.Conn
goroutineManager.StartForegroundGoroutine(func(_ context.Context) {
Expand All @@ -390,7 +398,7 @@ l:
panic(err)
}

cancelAcceptedCtx()
signalReady()
})

bubbleSignals = true
Expand All @@ -403,7 +411,7 @@ l:
case <-done:
break l

case <-acceptedCtx.Done():
case <-ready:
break s
}

Expand Down
24 changes: 16 additions & 8 deletions cmd/drafter-peer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,20 +481,28 @@ func main() {
if err != nil {
panic(err)
}
defer lis.Close()
defer func() {
defer goroutineManager.CreateForegroundPanicCollector()()

closeLock.Lock()
defer closeLock.Unlock()

closed = true

closeLock.Unlock()

if err := lis.Close(); err != nil {
panic(err)
}
}()

log.Println("Serving on", lis.Addr())

// We use `context.Background` here because we want to distinguish between a cancellation and a successful accept
// We select between `acceptedCtx` and `ctx` on all code paths so we don't leak the context
acceptedCtx, cancelAcceptedCtx := context.WithCancel(context.Background())
defer cancelAcceptedCtx()
var (
ready = make(chan struct{})
signalReady = sync.OnceFunc(func() {
close(ready) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})
)

var conn net.Conn
goroutineManager.StartForegroundGoroutine(func(_ context.Context) {
Expand All @@ -514,7 +522,7 @@ func main() {
panic(err)
}

cancelAcceptedCtx()
signalReady()
})

bubbleSignals = true
Expand All @@ -536,7 +544,7 @@ func main() {

return

case <-acceptedCtx.Done():
case <-ready:
break
}

Expand Down
3 changes: 2 additions & 1 deletion os/configs/drafteros-firecracker-x86_64_pvm_defconfig
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
BR2_x86_64=y
BR2_KERNEL_HEADERS_VERSION=y
BR2_DEFAULT_KERNEL_VERSION="6.8.12"
BR2_DEFAULT_KERNEL_VERSION="6.7.0-rc6"
BR2_KERNEL_HEADERS_6_6=y
BR2_GNU_MIRROR="https://mirrors.kernel.org/gnu"
BR2_CCACHE=y
BR2_TARGET_GENERIC_HOSTNAME="drafterhost"
Expand Down
10 changes: 8 additions & 2 deletions pkg/forwarder/forward_ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"runtime"
"sync"

"github.com/coreos/go-iptables/iptables"
"github.com/loopholelabs/drafter/internal/utils"
Expand Down Expand Up @@ -252,9 +253,14 @@ func ForwardPorts(
},
)

closeInProgress := make(chan any)
var (
closeInProgress = make(chan struct{})
signalCloseInProgress = sync.OnceFunc(func() {
close(closeInProgress) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})
)
forwardedPorts.Close = func() (errs error) {
defer close(closeInProgress)
defer signalCloseInProgress()

for _, closeFuncs := range deferFuncs {
for _, closeFunc := range closeFuncs {
Expand Down
9 changes: 7 additions & 2 deletions pkg/ipc/agent_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ func StartAgentClient(
cancelLinkCtx(goroutineManager.GetErrGoroutineStopped())
}

ready := make(chan any)
var (
ready = make(chan struct{})
signalReady = sync.OnceFunc(func() {
close(ready) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})
)
// This goroutine will not leak on function return because it selects on `goroutineManager.Context().Done()`
// internally and we return a wait function
goroutineManager.StartBackgroundGoroutine(func(ctx context.Context) {
Expand All @@ -120,7 +125,7 @@ func StartAgentClient(

&rpc.RegistryHooks{
OnClientConnect: func(remoteID string) {
close(ready)
signalReady()
},
},
)
Expand Down
9 changes: 7 additions & 2 deletions pkg/ipc/agent_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ func (agentServer *AgentServer) Accept(acceptCtx context.Context, remoteCtx cont
defer goroutineManager.StopAllGoroutines()
defer goroutineManager.CreateBackgroundPanicCollector()()

ready := make(chan any)
var (
ready = make(chan struct{})
signalReady = sync.OnceFunc(func() {
close(ready) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})
)
// This goroutine will not leak on function return because it selects on `goroutineManager.Context().Done()` internally
goroutineManager.StartBackgroundGoroutine(func(ctx context.Context) {
select {
Expand Down Expand Up @@ -170,7 +175,7 @@ func (agentServer *AgentServer) Accept(acceptCtx context.Context, remoteCtx cont

&rpc.RegistryHooks{
OnClientConnect: func(remoteID string) {
close(ready)
signalReady()
},
},
)
Expand Down
19 changes: 13 additions & 6 deletions pkg/mounter/migrate_from.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,15 @@ func MigrateFromAndMount(
}

var (
allRemoteDevicesReceived = make(chan any)
allRemoteDevicesReady = make(chan any)
allRemoteDevicesReceived = make(chan struct{})
signalAllRemoteDevicesReceived = sync.OnceFunc(func() {
close(allRemoteDevicesReceived) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})

allRemoteDevicesReady = make(chan struct{})
signalAllRemoteDevicesReady = sync.OnceFunc(func() {
close(allRemoteDevicesReady) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})
)

// We don't `defer cancelProtocolCtx()` this because we cancel in the wait function
Expand Down Expand Up @@ -244,15 +251,15 @@ func MigrateFromAndMount(
case packets.EventCustom:
switch e.CustomType {
case byte(registry.EventCustomAllDevicesSent):
close(allRemoteDevicesReceived)
signalAllRemoteDevicesReceived()

if hook := hooks.OnRemoteAllDevicesReceived; hook != nil {
hook()
}

case byte(registry.EventCustomTransferAuthority):
if receivedButNotReadyRemoteDevices.Add(-1) <= 0 {
close(allRemoteDevicesReady)
signalAllRemoteDevicesReady()
}

if hook := hooks.OnRemoteDeviceAuthorityReceived; hook != nil {
Expand Down Expand Up @@ -298,15 +305,15 @@ func MigrateFromAndMount(
select {
case <-allRemoteDevicesReceived:
default:
close(allRemoteDevicesReceived)
signalAllRemoteDevicesReceived()

// We need to call the hook manually too since we would otherwise only call if we received at least one device
if hook := hooks.OnRemoteAllDevicesReceived; hook != nil {
hook()
}
}

close(allRemoteDevicesReady)
signalAllRemoteDevicesReady()

if hook := hooks.OnRemoteAllMigrationsCompleted; hook != nil {
hook()
Expand Down
4 changes: 2 additions & 2 deletions pkg/mounter/migrate_to.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ func (migratableMounter *MigratableMounter) MigrateTo(
suspendedVM bool
)

suspendedVMCh := make(chan any)
suspendedVMCh := make(chan struct{})

suspendAndMsyncVM := sync.OnceValue(func() error {
suspendedVMLock.Lock()
suspendedVM = true
suspendedVMLock.Unlock()

close(suspendedVMCh)
close(suspendedVMCh) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d

return nil
})
Expand Down
4 changes: 2 additions & 2 deletions pkg/nat/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func CreateNAT(
return nil
}

ready := make(chan any)
ready := make(chan struct{})
// This goroutine will not leak on function return because it selects on `goroutineManager.Context().Done()` internally
goroutineManager.StartBackgroundGoroutine(func(internalCtx context.Context) {
select {
Expand Down Expand Up @@ -299,7 +299,7 @@ func CreateNAT(
}
}

close(ready)
close(ready) // We can safely close() this channel since this code path will only run once

return
}
19 changes: 13 additions & 6 deletions pkg/peer/migrate_from.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,15 @@ func (peer *Peer) MigrateFrom(
}

var (
allRemoteDevicesReceived = make(chan any)
allRemoteDevicesReady = make(chan any)
allRemoteDevicesReceived = make(chan struct{})
signalAllRemoteDevicesReceived = sync.OnceFunc(func() {
close(allRemoteDevicesReceived) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})

allRemoteDevicesReady = make(chan struct{})
signalAllRemoteDevicesReady = sync.OnceFunc(func() {
close(allRemoteDevicesReady) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})
)

// We don't `defer cancelProtocolCtx()` this because we cancel in the wait function
Expand Down Expand Up @@ -268,15 +275,15 @@ func (peer *Peer) MigrateFrom(
case packets.EventCustom:
switch e.CustomType {
case byte(registry.EventCustomAllDevicesSent):
close(allRemoteDevicesReceived)
signalAllRemoteDevicesReceived()

if hook := hooks.OnRemoteAllDevicesReceived; hook != nil {
hook()
}

case byte(registry.EventCustomTransferAuthority):
if receivedButNotReadyRemoteDevices.Add(-1) <= 0 {
close(allRemoteDevicesReady)
signalAllRemoteDevicesReady()
}

if hook := hooks.OnRemoteDeviceAuthorityReceived; hook != nil {
Expand Down Expand Up @@ -322,15 +329,15 @@ func (peer *Peer) MigrateFrom(
select {
case <-allRemoteDevicesReceived:
default:
close(allRemoteDevicesReceived)
signalAllRemoteDevicesReceived()

// We need to call the hook manually too since we would otherwise only call if we received at least one device
if hook := hooks.OnRemoteAllDevicesReceived; hook != nil {
hook()
}
}

close(allRemoteDevicesReady)
signalAllRemoteDevicesReady()

if hook := hooks.OnRemoteAllMigrationsCompleted; hook != nil {
hook()
Expand Down
4 changes: 2 additions & 2 deletions pkg/peer/migrate_to.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (migratablePeer *MigratablePeer) MigrateTo(
suspendedVM bool
)

suspendedVMCh := make(chan any)
suspendedVMCh := make(chan struct{})

suspendAndMsyncVM := sync.OnceValue(func() error {
if hook := hooks.OnBeforeSuspend; hook != nil {
Expand All @@ -111,7 +111,7 @@ func (migratablePeer *MigratablePeer) MigrateTo(
suspendedVM = true
suspendedVMLock.Unlock()

close(suspendedVMCh)
close(suspendedVMCh) // We can safely close() this channel since the caller only runs once/is `sync.OnceValue`d

return nil
})
Expand Down

0 comments on commit e712a91

Please sign in to comment.