Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CRE-39] (fix): Add more guards & nil checks to WASM compute #984

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions pkg/workflows/wasm/host/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"math"
"regexp"

"strings"
"sync"
Expand Down Expand Up @@ -383,6 +384,9 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp
}

func containsCode(err error, code int) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), fmt.Sprintf("exit status %d", code))
}

Expand Down Expand Up @@ -596,21 +600,24 @@ func createLogFn(logger logger.Logger) func(caller *wasmtime.Caller, ptr int32,
args = append(args, k, v)
}

reg, _ := regexp.Compile(`[^\w]`)
sanitizedMsg := reg.ReplaceAllString(msg, " ")

switch level {
case "debug":
logger.Debugw(msg, args...)
logger.Debugw(sanitizedMsg, args...)
case "info":
logger.Infow(msg, args...)
logger.Infow(sanitizedMsg, args...)
case "warn":
logger.Warnw(msg, args...)
logger.Warnw(sanitizedMsg, args...)
case "error":
logger.Errorw(msg, args...)
logger.Errorw(sanitizedMsg, args...)
case "panic":
logger.Panicw(msg, args...)
logger.Panicw(sanitizedMsg, args...)
case "fatal":
logger.Fatalw(msg, args...)
logger.Fatalw(sanitizedMsg, args...)
default:
logger.Infow(msg, args...)
logger.Infow(sanitizedMsg, args...)
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions pkg/workflows/wasm/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,15 @@ func Test_createEmitFn(t *testing.T) {
assert.NoError(t, err)
})

t.Run("success if no labels are given", func(t *testing.T) {
hostEmit := func(respptr, resplenptr, reqptr unsafe.Pointer, reqptrlen int32) int32 {
return 0
}
runtimeEmit := createEmitFn(sdkConfig, l, hostEmit)
err := runtimeEmit(giveMsg, nil)
assert.NoError(t, err)
})

t.Run("successfully read error message when emit fails", func(t *testing.T) {
hostEmit := func(respptr, resplenptr, reqptr unsafe.Pointer, reqptrlen int32) int32 {
// marshall the protobufs
Expand Down
10 changes: 8 additions & 2 deletions pkg/workflows/wasm/runner_wasip1.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ func sendResponseFn(response *wasmpb.Response) {
os.Exit(CodeInvalidRequest)
}

ptr, ptrlen := bufferToPointerLen(pb)
ptr, ptrlen, err := bufferToPointerLen(pb)
if err != nil {
os.Exit(CodeInvalidResponse)
}
errno := sendResponse(ptr, ptrlen)
if errno != 0 {
os.Exit(CodeHostErr)
Expand All @@ -76,7 +79,10 @@ type wasmWriteSyncer struct{}

// Write is used to proxy log requests from the WASM binary back to the host
func (wws *wasmWriteSyncer) Write(p []byte) (n int, err error) {
ptr, ptrlen := bufferToPointerLen(p)
ptr, ptrlen, err := bufferToPointerLen(p)
if err != nil {
return int(ptrlen), err
}
log(ptr, ptrlen)
return int(ptrlen), nil
}
44 changes: 36 additions & 8 deletions pkg/workflows/wasm/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ func createEmitFn(
return NewEmissionError(fmt.Errorf("metadata is required to emit"))
}

if labels == nil {
labels = map[string]string{}
}
labels, err := toEmitLabels(sdkConfig.Metadata, labels)
if err != nil {
return NewEmissionError(err)
Expand All @@ -120,14 +123,23 @@ func createEmitFn(
// Prepare the request to be sent to the host memory by allocating space for the
// response and response length buffers.
respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes)
respptr, _ := bufferToPointerLen(respBuffer)
respptr, _, err := bufferToPointerLen(respBuffer)
if err != nil {
return err
}

resplenBuffer := make([]byte, uint32Size)
resplenptr, _ := bufferToPointerLen(resplenBuffer)
resplenptr, _, err := bufferToPointerLen(resplenBuffer)
if err != nil {
return err
}

// The request buffer is the wasm memory, get a pointer to the first element and the length
// of the protobuf message.
reqptr, reqptrlen := bufferToPointerLen(b)
reqptr, reqptrlen, err := bufferToPointerLen(b)
if err != nil {
return err
}

// Emit the message via the method imported from the host
errno := emit(respptr, resplenptr, reqptr, reqptrlen)
Expand Down Expand Up @@ -189,13 +201,22 @@ func createFetchFn(
if err != nil {
return sdk.FetchResponse{}, fmt.Errorf("failed to marshal fetch request: %w", err)
}
reqptr, reqptrlen := bufferToPointerLen(b)
reqptr, reqptrlen, err := bufferToPointerLen(b)
if err != nil {
return sdk.FetchResponse{}, err
}

respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes)
respptr, _ := bufferToPointerLen(respBuffer)
respptr, _, err := bufferToPointerLen(respBuffer)
if err != nil {
return sdk.FetchResponse{}, err
}

resplenBuffer := make([]byte, uint32Size)
resplenptr, _ := bufferToPointerLen(resplenBuffer)
resplenptr, _, err := bufferToPointerLen(resplenBuffer)
if err != nil {
return sdk.FetchResponse{}, err
}

errno := fetch(respptr, resplenptr, reqptr, reqptrlen)
if errno != 0 {
Expand Down Expand Up @@ -230,8 +251,11 @@ func createFetchFn(
}

// bufferToPointerLen returns a pointer to the first element of the buffer and the length of the buffer.
func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32) {
return unsafe.Pointer(&buf[0]), int32(len(buf))
func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32, error) {
if len(buf) == 0 {
return nil, 0, fmt.Errorf("buffer cannot be empty")
}
return unsafe.Pointer(&buf[0]), int32(len(buf)), nil
}

// toEmitLabels ensures that the required metadata is present in the labels map
Expand All @@ -248,6 +272,10 @@ func toEmitLabels(md *capabilities.RequestMetadata, labels map[string]string) (m
return nil, fmt.Errorf("must provide workflow owner to emit event")
}

if md.WorkflowExecutionID == "" {
return nil, fmt.Errorf("must provide workflow execution id to emit event")
}

labels[events.LabelWorkflowExecutionID] = md.WorkflowExecutionID
labels[events.LabelWorkflowOwner] = md.WorkflowOwner
labels[events.LabelWorkflowID] = md.WorkflowID
Expand Down
30 changes: 26 additions & 4 deletions pkg/workflows/wasm/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (
func Test_toEmitLabels(t *testing.T) {
t.Run("successfully transforms metadata", func(t *testing.T) {
md := &capabilities.RequestMetadata{
WorkflowID: "workflow-id",
WorkflowName: "workflow-name",
WorkflowOwner: "workflow-owner",
WorkflowID: "workflow-id",
WorkflowName: "workflow-name",
WorkflowOwner: "workflow-owner",
WorkflowExecutionID: "6e2a46e3b6ae611bdb9bcc36ed3f46bb9a30babc3aabdd4eae7f35dd9af0f244",
}
empty := make(map[string]string, 0)

Expand All @@ -24,7 +25,7 @@ func Test_toEmitLabels(t *testing.T) {
"workflow_id": "workflow-id",
"workflow_name": "workflow-name",
"workflow_owner_address": "workflow-owner",
"workflow_execution_id": "",
"workflow_execution_id": "6e2a46e3b6ae611bdb9bcc36ed3f46bb9a30babc3aabdd4eae7f35dd9af0f244",
}, gotLabels)
})

Expand Down Expand Up @@ -63,4 +64,25 @@ func Test_toEmitLabels(t *testing.T) {
assert.Error(t, err)
assert.ErrorContains(t, err, "workflow owner")
})

t.Run("fails on missing workflow execution id", func(t *testing.T) {
md := &capabilities.RequestMetadata{
WorkflowID: "workflow-id",
WorkflowName: "workflow-name",
WorkflowOwner: "workflow-owner",
}
empty := make(map[string]string, 0)

_, err := toEmitLabels(md, empty)
assert.Error(t, err)
assert.ErrorContains(t, err, "workflow execution id")
})
}

func Test_bufferToPointerLen(t *testing.T) {
t.Run("fails when no buffer", func(t *testing.T) {
_, _, err := bufferToPointerLen([]byte{})
assert.Error(t, err)
assert.ErrorContains(t, err, "buffer cannot be empty")
})
}
Loading