diff --git a/Makefile b/Makefile index c744193..307d1cd 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ OCI_IMAGE_URI ?= docker://valkey/valkey:latest OCI_IMAGE_ARCHITECTURE ?= amd64 OCI_IMAGE_HOSTNAME ?= drafterguest -OS_URL ?= https://buildroot.org/downloads/buildroot-2024.08-rc1.tar.gz +OS_URL ?= https://buildroot.org/downloads/buildroot-2024.08.tar.gz OS_DEFCONFIG ?= drafteros-firecracker-x86_64_defconfig OS_BR2_EXTERNAL ?= ../../os diff --git a/cmd/drafter-agent/main.go b/cmd/drafter-agent/main.go index 4a58c1b..fd64453 100644 --- a/cmd/drafter-agent/main.go +++ b/cmd/drafter-agent/main.go @@ -55,7 +55,9 @@ func main() { cancel() }() - agentClient := ipc.NewAgentClient( + agentClient := ipc.NewAgentClient[struct{}]( + struct{}{}, + func(ctx context.Context) error { log.Println("Running pre-suspend command") @@ -95,7 +97,7 @@ func main() { dialCtx, cancelDialCtx := context.WithTimeout(goroutineManager.Context(), *vsockTimeout) defer cancelDialCtx() - connectedAgentClient, err := ipc.StartAgentClient( + connectedAgentClient, err := ipc.StartAgentClient[*ipc.AgentClientLocal[struct{}], struct{}]( dialCtx, goroutineManager.Context(), @@ -103,6 +105,7 @@ func main() { uint32(*vsockPort), agentClient, + ipc.StartAgentClientHooks[struct{}]{}, ) if err != nil { return err diff --git a/cmd/drafter-peer/main.go b/cmd/drafter-peer/main.go index 1e74389..075085d 100644 --- a/cmd/drafter-peer/main.go +++ b/cmd/drafter-peer/main.go @@ -16,6 +16,7 @@ import ( "sync" "time" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/mounter" "github.com/loopholelabs/drafter/pkg/packager" "github.com/loopholelabs/drafter/pkg/peer" @@ -277,7 +278,7 @@ func main() { writers = []io.Writer{conn} } - p, err := peer.StartPeer( + p, err := peer.StartPeer[struct{}, ipc.AgentServerRemote[struct{}]]( goroutineManager.Context(), context.Background(), // Never give up on rescue operations @@ -328,9 +329,9 @@ func main() { } }) - migrateFromDevices := []peer.MigrateFromDevice{} + migrateFromDevices := []peer.MigrateFromDevice[struct{}, ipc.AgentServerRemote[struct{}], struct{}]{} for _, device := range devices { - migrateFromDevices = append(migrateFromDevices, peer.MigrateFromDevice{ + migrateFromDevices = append(migrateFromDevices, peer.MigrateFromDevice[struct{}, ipc.AgentServerRemote[struct{}], struct{}]{ Name: device.Name, Base: device.Base, @@ -419,6 +420,9 @@ func main() { *resumeTimeout, *rescueTimeout, + struct{}{}, + ipc.AgentServerAcceptHooks[ipc.AgentServerRemote[struct{}], struct{}]{}, + runner.SnapshotLoadConfiguration{ ExperimentalMapPrivate: *experimentalMapPrivate, diff --git a/cmd/drafter-runner/main.go b/cmd/drafter-runner/main.go index d91014a..cf1db04 100644 --- a/cmd/drafter-runner/main.go +++ b/cmd/drafter-runner/main.go @@ -13,6 +13,7 @@ import ( "syscall" "time" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/packager" "github.com/loopholelabs/drafter/pkg/peer" "github.com/loopholelabs/drafter/pkg/runner" @@ -174,7 +175,7 @@ func main() { cancel() }() - r, err := runner.StartRunner( + r, err := runner.StartRunner[struct{}, ipc.AgentServerRemote[struct{}]]( goroutineManager.Context(), context.Background(), // Never give up on rescue operations @@ -303,6 +304,9 @@ func main() { *rescueTimeout, packageConfig.AgentVSockPort, + struct{}{}, + ipc.AgentServerAcceptHooks[ipc.AgentServerRemote[struct{}], struct{}]{}, + runner.SnapshotLoadConfiguration{ ExperimentalMapPrivate: *experimentalMapPrivate, diff --git a/go.mod b/go.mod index fdb6757..842faac 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/loopholelabs/goroutine-manager v0.1.1 github.com/loopholelabs/silo v0.0.8 github.com/metal-stack/go-ipam v1.14.0 - github.com/pojntfx/panrpc/go v0.0.0-20240816011753-7169be8c89fb + github.com/pojntfx/panrpc/go v0.0.0-20240913062914-ea5ef6b07692 github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netns v0.0.4 golang.org/x/sys v0.24.0 diff --git a/go.sum b/go.sum index 24e7fa0..8eccc07 100644 --- a/go.sum +++ b/go.sum @@ -169,8 +169,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pojntfx/panrpc/go v0.0.0-20240816011753-7169be8c89fb h1:2vU/ZbsJ1uZhR2BCjc3mIfzR9GVuAWHSIUrw8jnLK9g= -github.com/pojntfx/panrpc/go v0.0.0-20240816011753-7169be8c89fb/go.mod h1:G9YawT9jiXDf6z7WEv7KMIca+A41mjm8KL8AfBxN37U= +github.com/pojntfx/panrpc/go v0.0.0-20240913062914-ea5ef6b07692 h1:kiSksMNOL9fQQLen6hrOIC/Mgxcts6mAxcZBunOfCuA= +github.com/pojntfx/panrpc/go v0.0.0-20240913062914-ea5ef6b07692/go.mod h1:G9YawT9jiXDf6z7WEv7KMIca+A41mjm8KL8AfBxN37U= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0y4= diff --git a/internal/remotes/agent.go b/internal/remotes/agent.go deleted file mode 100644 index 6f0a84a..0000000 --- a/internal/remotes/agent.go +++ /dev/null @@ -1,8 +0,0 @@ -package remotes - -import "context" - -type AgentRemote struct { - BeforeSuspend func(ctx context.Context) error - AfterResume func(ctx context.Context) error -} diff --git a/internal/vsock/dialer.go b/internal/vsock/dialer.go index 29c8dee..bcf1ea8 100644 --- a/internal/vsock/dialer.go +++ b/internal/vsock/dialer.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "sync" "github.com/loopholelabs/goroutine-manager/pkg/manager" "golang.org/x/sys/unix" @@ -35,10 +36,14 @@ func DialContext( panic(errors.Join(ErrVSockSocketCreationFailed, err)) } + var cLock sync.Mutex goroutineManager.StartForegroundGoroutine(func(_ context.Context) { <-goroutineManager.Context().Done() // Non-happy path; context was cancelled before `connect()` completed + cLock.Lock() + defer cLock.Unlock() + if c == nil { if err := unix.Shutdown(fd, unix.SHUT_RDWR); err != nil { // Always close the file descriptor even if shutdown fails @@ -70,6 +75,9 @@ func DialContext( panic(errors.Join(ErrVSockConnectFailed, err)) } + cLock.Lock() + defer cLock.Unlock() + c = &conn{fd} return diff --git a/pkg/ipc/agent_client.go b/pkg/ipc/agent_client.go index d663182..5017f94 100644 --- a/pkg/ipc/agent_client.go +++ b/pkg/ipc/agent_client.go @@ -19,44 +19,63 @@ var ( ErrAgentContextCancelled = errors.New("agent context cancelled") ) -type AgentClient struct { +// The RPCs the agent server can call on this client +// See https://github.com/pojntfx/panrpc/tree/main?tab=readme-ov-file#5-calling-the-clients-rpcs-from-the-server +type AgentClientLocal[G any] struct { + GuestService G + beforeSuspend func(ctx context.Context) error afterResume func(ctx context.Context) error } -func NewAgentClient( +// The RPCs this client can call on the agent server +// See https://github.com/pojntfx/panrpc/tree/main?tab=readme-ov-file#4-calling-the-servers-rpcs-from-the-client +type AgentClientRemote any + +func NewAgentClient[G any]( + guestService G, + beforeSuspend func(ctx context.Context) error, afterResume func(ctx context.Context) error, -) *AgentClient { - return &AgentClient{ +) *AgentClientLocal[G] { + return &AgentClientLocal[G]{ + GuestService: guestService, + beforeSuspend: beforeSuspend, afterResume: afterResume, } } -func (l *AgentClient) BeforeSuspend(ctx context.Context) error { +func (l *AgentClientLocal[G]) BeforeSuspend(ctx context.Context) error { return l.beforeSuspend(ctx) } -func (l *AgentClient) AfterResume(ctx context.Context) error { +func (l *AgentClientLocal[G]) AfterResume(ctx context.Context) error { return l.afterResume(ctx) } -type ConnectedAgentClient struct { +type ConnectedAgentClient[L *AgentClientLocal[G], R AgentClientRemote, G any] struct { + Remote R + Wait func() error Close func() } -func StartAgentClient( +type StartAgentClientHooks[R AgentClientRemote] struct { + OnAfterRegistrySetup func(forRemotes func(cb func(remoteID string, remote R) error) error) error +} + +func StartAgentClient[L *AgentClientLocal[G], R AgentClientRemote, G any]( dialCtx context.Context, remoteCtx context.Context, vsockCID uint32, vsockPort uint32, - agentClient *AgentClient, -) (connectedAgentClient *ConnectedAgentClient, errs error) { - connectedAgentClient = &ConnectedAgentClient{ + agentClientLocal L, + hooks StartAgentClientHooks[R], +) (connectedAgentClient *ConnectedAgentClient[L, R, G], errs error) { + connectedAgentClient = &ConnectedAgentClient[L, R, G]{ Wait: func() error { return nil }, @@ -120,8 +139,8 @@ func StartAgentClient( } }) - registry := rpc.NewRegistry[struct{}, json.RawMessage]( - agentClient, + registry := rpc.NewRegistry[R, json.RawMessage]( + agentClientLocal, &rpc.RegistryHooks{ OnClientConnect: func(remoteID string) { @@ -130,8 +149,12 @@ func StartAgentClient( }, ) + if hook := hooks.OnAfterRegistrySetup; hook != nil { + hook(registry.ForRemotes) + } + connectedAgentClient.Wait = sync.OnceValue(func() error { - defer conn.Close() // We ignore errors here since we might interrupt a network connection + // We don't `defer conn.Close` here since Firecracker handles resetting active VSock connections for us defer cancelLinkCtx(nil) encoder := json.NewEncoder(conn) @@ -199,5 +222,19 @@ func StartAgentClient( break } + found := false + if err := registry.ForRemotes(func(remoteID string, r R) error { + connectedAgentClient.Remote = r + found = true + + return nil + }); err != nil { + panic(err) + } + + if !found { + panic(ErrNoRemoteFound) + } + return } diff --git a/pkg/ipc/agent_server.go b/pkg/ipc/agent_server.go index 5f94362..9289151 100644 --- a/pkg/ipc/agent_server.go +++ b/pkg/ipc/agent_server.go @@ -9,7 +9,6 @@ import ( "os" "sync" - "github.com/loopholelabs/drafter/internal/remotes" "github.com/loopholelabs/goroutine-manager/pkg/manager" "github.com/pojntfx/panrpc/go/pkg/rpc" ) @@ -23,7 +22,20 @@ var ( ErrCouldNotLinkRegistry = errors.New("could not link registry") ) -type AgentServer struct { +// The RPCs the agent client can call on this server +// See https://github.com/pojntfx/panrpc/tree/main?tab=readme-ov-file#5-calling-the-clients-rpcs-from-the-server +type AgentServerLocal any + +// The RPCs this server can call on the agent client +// See https://github.com/pojntfx/panrpc/tree/main?tab=readme-ov-file#4-calling-the-servers-rpcs-from-the-client +type AgentServerRemote[G any] struct { + GuestService G + + BeforeSuspend func(ctx context.Context) error + AfterResume func(ctx context.Context) error +} + +type AgentServer[L AgentServerLocal, R AgentServerRemote[G], G any] struct { VSockPath string Close func() @@ -32,18 +44,24 @@ type AgentServer struct { closed bool closeLock sync.Mutex + + agentServerLocal L } -func StartAgentServer( +func StartAgentServer[L AgentServerLocal, R AgentServerRemote[G], G any]( vsockPath string, vsockPort uint32, + + agentServerLocal L, ) ( - agentServer *AgentServer, + agentServer *AgentServer[L, R, G], err error, ) { - agentServer = &AgentServer{ + agentServer = &AgentServer[L, R, G]{ Close: func() {}, + + agentServerLocal: agentServerLocal, } agentServer.VSockPath = fmt.Sprintf("%s_%d", vsockPath, vsockPort) @@ -68,15 +86,24 @@ func StartAgentServer( return } -type AcceptingAgentServer struct { - Remote remotes.AgentRemote +type AgentServerAcceptHooks[R AgentServerRemote[G], G any] struct { + OnAfterRegistrySetup func(forRemotes func(cb func(remoteID string, remote R) error) error) error +} + +type AcceptingAgentServer[L AgentServerLocal, R AgentServerRemote[G], G any] struct { + Remote R Wait func() error Close func() error } -func (agentServer *AgentServer) Accept(acceptCtx context.Context, remoteCtx context.Context) (acceptingAgentServer *AcceptingAgentServer, errs error) { - acceptingAgentServer = &AcceptingAgentServer{ +func (agentServer *AgentServer[L, R, G]) Accept( + acceptCtx context.Context, + remoteCtx context.Context, + + hooks AgentServerAcceptHooks[R, G], +) (acceptingAgentServer *AcceptingAgentServer[L, R, G], errs error) { + acceptingAgentServer = &AcceptingAgentServer[L, R, G]{ Wait: func() error { return nil }, @@ -170,8 +197,8 @@ func (agentServer *AgentServer) Accept(acceptCtx context.Context, remoteCtx cont } }) - registry := rpc.NewRegistry[remotes.AgentRemote, json.RawMessage]( - &struct{}{}, + registry := rpc.NewRegistry[R, json.RawMessage]( + agentServer.agentServerLocal, &rpc.RegistryHooks{ OnClientConnect: func(remoteID string) { @@ -180,8 +207,12 @@ func (agentServer *AgentServer) Accept(acceptCtx context.Context, remoteCtx cont }, ) + if hook := hooks.OnAfterRegistrySetup; hook != nil { + hook(registry.ForRemotes) + } + acceptingAgentServer.Wait = sync.OnceValue(func() error { - defer conn.Close() // We ignore errors here since we might interrupt a network connection + // We don't `defer conn.Close` here since Firecracker handles resetting active VSock connections for us defer cancelLinkCtx(nil) encoder := json.NewEncoder(conn) @@ -246,7 +277,7 @@ func (agentServer *AgentServer) Accept(acceptCtx context.Context, remoteCtx cont } found := false - if err := registry.ForRemotes(func(remoteID string, r remotes.AgentRemote) error { + if err := registry.ForRemotes(func(remoteID string, r R) error { acceptingAgentServer.Remote = r found = true diff --git a/pkg/ipc/liveness_client.go b/pkg/ipc/liveness_client.go index ce00809..ab42aeb 100644 --- a/pkg/ipc/liveness_client.go +++ b/pkg/ipc/liveness_client.go @@ -17,14 +17,11 @@ func SendLivenessPing( vsockCID uint32, vsockPort uint32, ) error { - conn, err := vsock.DialContext(ctx, vsockCID, vsockPort) - if err != nil { + if _, err := vsock.DialContext(ctx, vsockCID, vsockPort); err != nil { return errors.Join(ErrCouldNotDialLivenessVSockConnection, err) } - if err := conn.Close(); err != nil { - return errors.Join(ErrCouldNotCloseLivenessVSockConnection, err) - } + // We don't `conn.Close` here since Firecracker handles resetting active VSock connections for us return nil } diff --git a/pkg/peer/make_migratable.go b/pkg/peer/make_migratable.go index 2655b57..d1a6714 100644 --- a/pkg/peer/make_migratable.go +++ b/pkg/peer/make_migratable.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/loopholelabs/drafter/internal/utils" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/mounter" "github.com/loopholelabs/drafter/pkg/runner" "github.com/loopholelabs/goroutine-manager/pkg/manager" @@ -14,21 +15,23 @@ import ( "github.com/loopholelabs/silo/pkg/storage/volatilitymonitor" ) -type ResumedPeer struct { +type ResumedPeer[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any] struct { + Remote R + Wait func() error Close func() error - resumedRunner *runner.ResumedRunner + resumedRunner *runner.ResumedRunner[L, R, G] stage2Inputs []migrateFromStage } -func (resumedPeer *ResumedPeer) MakeMigratable( +func (resumedPeer *ResumedPeer[L, R, G]) MakeMigratable( ctx context.Context, devices []mounter.MakeMigratableDevice, -) (migratablePeer *MigratablePeer, errs error) { - migratablePeer = &MigratablePeer{ +) (migratablePeer *MigratablePeer[L, R, G], errs error) { + migratablePeer = &MigratablePeer[L, R, G]{ Close: func() {}, resumedPeer: resumedPeer, diff --git a/pkg/peer/migrate_from.go b/pkg/peer/migrate_from.go index de476bb..0dc71d5 100644 --- a/pkg/peer/migrate_from.go +++ b/pkg/peer/migrate_from.go @@ -14,6 +14,7 @@ import ( "syscall" "github.com/loopholelabs/drafter/internal/utils" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/mounter" "github.com/loopholelabs/drafter/pkg/registry" "github.com/loopholelabs/drafter/pkg/snapshotter" @@ -28,7 +29,7 @@ import ( "golang.org/x/sys/unix" ) -type MigrateFromDevice struct { +type MigrateFromDevice[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any] struct { Name string `json:"name"` Base string `json:"base"` @@ -40,21 +41,21 @@ type MigrateFromDevice struct { Shared bool `json:"shared"` } -func (peer *Peer) MigrateFrom( +func (peer *Peer[L, R, G]) MigrateFrom( ctx context.Context, - devices []MigrateFromDevice, + devices []MigrateFromDevice[L, R, G], readers []io.Reader, writers []io.Writer, hooks mounter.MigrateFromHooks, ) ( - migratedPeer *MigratedPeer, + migratedPeer *MigratedPeer[L, R, G], errs error, ) { - migratedPeer = &MigratedPeer{ + migratedPeer = &MigratedPeer[L, R, G]{ Wait: func() error { return nil }, @@ -410,7 +411,7 @@ func (peer *Peer) MigrateFrom( break } - stage1Inputs := []MigrateFromDevice{} + stage1Inputs := []MigrateFromDevice[L, R, G]{} for _, input := range devices { if slices.ContainsFunc( migratedPeer.stage2Inputs, @@ -430,7 +431,7 @@ func (peer *Peer) MigrateFrom( _, deferFuncs, err := utils.ConcurrentMap( stage1Inputs, - func(index int, input MigrateFromDevice, _ *struct{}, addDefer func(deferFunc func() error)) error { + func(index int, input MigrateFromDevice[L, R, G], _ *struct{}, addDefer func(deferFunc func() error)) error { if hook := hooks.OnLocalDeviceRequested; hook != nil { hook(uint32(index), input.Name) } diff --git a/pkg/peer/migrate_to.go b/pkg/peer/migrate_to.go index 561df47..27d6dd6 100644 --- a/pkg/peer/migrate_to.go +++ b/pkg/peer/migrate_to.go @@ -9,6 +9,7 @@ import ( "time" "github.com/loopholelabs/drafter/internal/utils" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/mounter" "github.com/loopholelabs/drafter/pkg/packager" "github.com/loopholelabs/drafter/pkg/registry" @@ -37,15 +38,15 @@ type MigrateToHooks struct { OnAllMigrationsCompleted func() } -type MigratablePeer struct { +type MigratablePeer[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any] struct { Close func() - resumedPeer *ResumedPeer + resumedPeer *ResumedPeer[L, R, G] stage4Inputs []makeMigratableDeviceStage - resumedRunner *runner.ResumedRunner + resumedRunner *runner.ResumedRunner[L, R, G] } -func (migratablePeer *MigratablePeer) MigrateTo( +func (migratablePeer *MigratablePeer[L, R, G]) MigrateTo( ctx context.Context, devices []mounter.MigrateToDevice, diff --git a/pkg/peer/resume.go b/pkg/peer/resume.go index 8e799f3..573df21 100644 --- a/pkg/peer/resume.go +++ b/pkg/peer/resume.go @@ -8,30 +8,34 @@ import ( "strings" "time" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/packager" "github.com/loopholelabs/drafter/pkg/runner" "github.com/loopholelabs/drafter/pkg/snapshotter" ) -type MigratedPeer struct { +type MigratedPeer[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any] struct { Wait func() error Close func() error - devices []MigrateFromDevice - runner *runner.Runner + devices []MigrateFromDevice[L, R, G] + runner *runner.Runner[L, R, G] stage2Inputs []migrateFromStage } -func (migratedPeer *MigratedPeer) Resume( +func (migratedPeer *MigratedPeer[L, R, G]) Resume( ctx context.Context, resumeTimeout, rescueTimeout time.Duration, + agentServerLocal L, + agentServerHooks ipc.AgentServerAcceptHooks[R, G], + snapshotLoadConfiguration runner.SnapshotLoadConfiguration, -) (resumedPeer *ResumedPeer, errs error) { - resumedPeer = &ResumedPeer{ +) (resumedPeer *ResumedPeer[L, R, G], errs error) { + resumedPeer = &ResumedPeer[L, R, G]{ Wait: func() error { return nil }, @@ -73,11 +77,15 @@ func (migratedPeer *MigratedPeer) Resume( rescueTimeout, packageConfig.AgentVSockPort, + agentServerLocal, + agentServerHooks, + snapshotLoadConfiguration, ) if err != nil { return nil, errors.Join(ErrCouldNotResumeRunner, err) } + resumedPeer.Remote = resumedPeer.resumedRunner.Remote resumedPeer.Wait = resumedPeer.resumedRunner.Wait resumedPeer.Close = resumedPeer.resumedRunner.Close diff --git a/pkg/peer/start.go b/pkg/peer/start.go index fdcfc58..e410739 100644 --- a/pkg/peer/start.go +++ b/pkg/peer/start.go @@ -4,12 +4,13 @@ import ( "context" "errors" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/runner" "github.com/loopholelabs/drafter/pkg/snapshotter" "github.com/loopholelabs/goroutine-manager/pkg/manager" ) -type Peer struct { +type Peer[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any] struct { VMPath string VMPid int @@ -18,10 +19,10 @@ type Peer struct { hypervisorCtx context.Context - runner *runner.Runner + runner *runner.Runner[L, R, G] } -func StartPeer( +func StartPeer[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any]( hypervisorCtx context.Context, rescueCtx context.Context, @@ -30,11 +31,11 @@ func StartPeer( stateName string, memoryName string, ) ( - peer *Peer, + peer *Peer[L, R, G], errs error, ) { - peer = &Peer{ + peer = &Peer[L, R, G]{ hypervisorCtx: hypervisorCtx, Wait: func() error { @@ -55,7 +56,7 @@ func StartPeer( defer goroutineManager.CreateBackgroundPanicCollector()() var err error - peer.runner, err = runner.StartRunner( + peer.runner, err = runner.StartRunner[L, R]( hypervisorCtx, rescueCtx, diff --git a/pkg/peer/suspend.go b/pkg/peer/suspend.go index 322bdb0..de69ba4 100644 --- a/pkg/peer/suspend.go +++ b/pkg/peer/suspend.go @@ -5,7 +5,7 @@ import ( "time" ) -func (resumedPeer *ResumedPeer) SuspendAndCloseAgentServer(ctx context.Context, resumeTimeout time.Duration) error { +func (resumedPeer *ResumedPeer[L, R, G]) SuspendAndCloseAgentServer(ctx context.Context, resumeTimeout time.Duration) error { return resumedPeer.resumedRunner.SuspendAndCloseAgentServer( ctx, diff --git a/pkg/runner/msync.go b/pkg/runner/msync.go index 92755a5..beef62a 100644 --- a/pkg/runner/msync.go +++ b/pkg/runner/msync.go @@ -8,7 +8,7 @@ import ( "github.com/loopholelabs/drafter/pkg/snapshotter" ) -func (resumedRunner *ResumedRunner) Msync(ctx context.Context) error { +func (resumedRunner *ResumedRunner[L, R, G]) Msync(ctx context.Context) error { if !resumedRunner.snapshotLoadConfiguration.ExperimentalMapPrivate { if err := firecracker.CreateSnapshot( ctx, diff --git a/pkg/runner/resume.go b/pkg/runner/resume.go index 907d712..8ac738b 100644 --- a/pkg/runner/resume.go +++ b/pkg/runner/resume.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "time" + "unsafe" "github.com/lithammer/shortuuid/v4" "github.com/loopholelabs/drafter/internal/firecracker" @@ -17,34 +18,39 @@ import ( "github.com/loopholelabs/goroutine-manager/pkg/manager" ) -type ResumedRunner struct { +type ResumedRunner[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any] struct { + Remote R + Wait func() error Close func() error snapshotLoadConfiguration SnapshotLoadConfiguration - runner *Runner + runner *Runner[L, R, G] - agent *ipc.AgentServer - acceptingAgent *ipc.AcceptingAgentServer + agent *ipc.AgentServer[L, R, G] + acceptingAgent *ipc.AcceptingAgentServer[L, R, G] createSnapshot func(ctx context.Context) error } -func (runner *Runner) Resume( +func (runner *Runner[L, R, G]) Resume( ctx context.Context, resumeTimeout time.Duration, rescueTimeout time.Duration, agentVSockPort uint32, + agentServerLocal L, + agentServerHooks ipc.AgentServerAcceptHooks[R, G], + snapshotLoadConfiguration SnapshotLoadConfiguration, ) ( - resumedRunner *ResumedRunner, + resumedRunner *ResumedRunner[L, R, G], errs error, ) { - resumedRunner = &ResumedRunner{ + resumedRunner = &ResumedRunner[L, R, G]{ Wait: func() error { return nil }, Close: func() error { return nil }, @@ -191,9 +197,11 @@ func (runner *Runner) Resume( }) var err error - resumedRunner.agent, err = ipc.StartAgentServer( + resumedRunner.agent, err = ipc.StartAgentServer[L, R]( filepath.Join(runner.server.VMPath, snapshotter.VSockName), uint32(agentVSockPort), + + agentServerLocal, ) if err != nil { panic(errors.Join(snapshotter.ErrCouldNotStartAgentServer, err)) @@ -228,10 +236,16 @@ func (runner *Runner) Resume( suspendOnPanicWithError = true - resumedRunner.acceptingAgent, err = resumedRunner.agent.Accept(resumeSnapshotAndAcceptCtx, ctx) + resumedRunner.acceptingAgent, err = resumedRunner.agent.Accept( + resumeSnapshotAndAcceptCtx, + ctx, + + agentServerHooks, + ) if err != nil { panic(errors.Join(ErrCouldNotAcceptAgent, err)) } + resumedRunner.Remote = resumedRunner.acceptingAgent.Remote } // We intentionally don't call `wg.Add` and `wg.Done` here since we return the process's wait method @@ -261,7 +275,11 @@ func (runner *Runner) Resume( afterResumeCtx, cancelAfterResumeCtx := context.WithTimeout(goroutineManager.Context(), resumeTimeout) defer cancelAfterResumeCtx() - if err := resumedRunner.acceptingAgent.Remote.AfterResume(afterResumeCtx); err != nil { + // This is a safe type cast because R is constrained by ipc.AgentServerRemote, so this specific AfterResume field + // must be defined or there will be a compile-time error. + // The Go Generics system can't catch this here however, it can only catch it once the type is concrete, so we need to manually cast. + remote := *(*ipc.AgentServerRemote[G])(unsafe.Pointer(&resumedRunner.acceptingAgent.Remote)) + if err := remote.AfterResume(afterResumeCtx); err != nil { panic(errors.Join(ErrCouldNotCallAfterResumeRPC, err)) } } diff --git a/pkg/runner/start.go b/pkg/runner/start.go index e014abe..23ddf08 100644 --- a/pkg/runner/start.go +++ b/pkg/runner/start.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/loopholelabs/drafter/internal/firecracker" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/snapshotter" "github.com/loopholelabs/goroutine-manager/pkg/manager" ) @@ -21,7 +22,7 @@ type SnapshotLoadConfiguration struct { ExperimentalMapPrivateMemoryOutput string } -type Runner struct { +type Runner[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any] struct { VMPath string VMPid int @@ -41,7 +42,7 @@ type Runner struct { rescueCtx context.Context } -func StartRunner( +func StartRunner[L ipc.AgentServerLocal, R ipc.AgentServerRemote[G], G any]( hypervisorCtx context.Context, rescueCtx context.Context, @@ -50,11 +51,11 @@ func StartRunner( stateName string, memoryName string, ) ( - runner *Runner, + runner *Runner[L, R, G], errs error, ) { - runner = &Runner{ + runner = &Runner[L, R, G]{ Wait: func() error { return nil }, Close: func() error { return nil }, diff --git a/pkg/runner/suspend.go b/pkg/runner/suspend.go index 36d1df0..2fd00d3 100644 --- a/pkg/runner/suspend.go +++ b/pkg/runner/suspend.go @@ -4,15 +4,21 @@ import ( "context" "errors" "time" + "unsafe" + "github.com/loopholelabs/drafter/pkg/ipc" "github.com/loopholelabs/drafter/pkg/snapshotter" ) -func (resumedRunner *ResumedRunner) SuspendAndCloseAgentServer(ctx context.Context, suspendTimeout time.Duration) error { +func (resumedRunner *ResumedRunner[L, R, G]) SuspendAndCloseAgentServer(ctx context.Context, suspendTimeout time.Duration) error { suspendCtx, cancelSuspendCtx := context.WithTimeout(ctx, suspendTimeout) defer cancelSuspendCtx() - if err := resumedRunner.acceptingAgent.Remote.BeforeSuspend(suspendCtx); err != nil { + // This is a safe type cast because R is constrained by ipc.AgentServerRemote, so this specific BeforeSuspend field + // must be defined or there will be a compile-time error. + // The Go Generics system can't catch this here however, it can only catch it once the type is concrete, so we need to manually cast. + remote := *(*ipc.AgentServerRemote[G])(unsafe.Pointer(&resumedRunner.acceptingAgent.Remote)) + if err := remote.BeforeSuspend(suspendCtx); err != nil { return errors.Join(ErrCouldNotCallBeforeSuspendRPC, err) } diff --git a/pkg/snapshotter/create.go b/pkg/snapshotter/create.go index edf46bf..e59ae71 100644 --- a/pkg/snapshotter/create.go +++ b/pkg/snapshotter/create.go @@ -141,9 +141,11 @@ func CreateSnapshot( panic(errors.Join(ErrCouldNotChownLivenessServerVSock, err)) } - agent, err := ipc.StartAgentServer( + agent, err := ipc.StartAgentServer[struct{}, ipc.AgentServerRemote[struct{}]]( filepath.Join(server.VMPath, VSockName), uint32(agentConfiguration.AgentVSockPort), + + struct{}{}, ) if err != nil { panic(errors.Join(ErrCouldNotStartAgentServer, err)) @@ -247,12 +249,17 @@ func CreateSnapshot( } } - var acceptingAgent *ipc.AcceptingAgentServer + var acceptingAgent *ipc.AcceptingAgentServer[struct{}, ipc.AgentServerRemote[struct{}], struct{}] { acceptCtx, cancel := context.WithTimeout(goroutineManager.Context(), agentConfiguration.ResumeTimeout) defer cancel() - acceptingAgent, err = agent.Accept(acceptCtx, goroutineManager.Context()) + acceptingAgent, err = agent.Accept( + acceptCtx, + goroutineManager.Context(), + + ipc.AgentServerAcceptHooks[ipc.AgentServerRemote[struct{}], struct{}]{}, + ) if err != nil { panic(errors.Join(ErrCouldNotAcceptAgentConnection, err)) }