Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CRE-39] (fix): Add more guards & nil checks to WASM compute #984

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions pkg/workflows/wasm/host/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"math"
"regexp"

"strings"
"sync"
Expand Down Expand Up @@ -383,6 +384,9 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp
}

func containsCode(err error, code int) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), fmt.Sprintf("exit status %d", code))
}

Expand Down Expand Up @@ -596,21 +600,24 @@ func createLogFn(logger logger.Logger) func(caller *wasmtime.Caller, ptr int32,
args = append(args, k, v)
}

reg, _ := regexp.Compile(`[^\w]`)
sanitizedMsg := reg.ReplaceAllString(msg, " ")

switch level {
case "debug":
logger.Debugw(msg, args...)
logger.Debugw(sanitizedMsg, args...)
case "info":
logger.Infow(msg, args...)
logger.Infow(sanitizedMsg, args...)
case "warn":
logger.Warnw(msg, args...)
logger.Warnw(sanitizedMsg, args...)
case "error":
logger.Errorw(msg, args...)
logger.Errorw(sanitizedMsg, args...)
case "panic":
logger.Panicw(msg, args...)
logger.Panicw(sanitizedMsg, args...)
case "fatal":
logger.Fatalw(msg, args...)
logger.Fatalw(sanitizedMsg, args...)
default:
logger.Infow(msg, args...)
logger.Infow(sanitizedMsg, args...)
}
}
}
Expand Down
44 changes: 36 additions & 8 deletions pkg/workflows/wasm/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ func createEmitFn(
return NewEmissionError(fmt.Errorf("metadata is required to emit"))
}

if labels == nil {
return NewEmissionError(fmt.Errorf("labels must be provided"))
}
labels, err := toEmitLabels(sdkConfig.Metadata, labels)
if err != nil {
return NewEmissionError(err)
Expand All @@ -120,14 +123,23 @@ func createEmitFn(
// Prepare the request to be sent to the host memory by allocating space for the
// response and response length buffers.
respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes)
respptr, _ := bufferToPointerLen(respBuffer)
respptr, _, err := bufferToPointerLen(respBuffer)
if err != nil {
return err
}

resplenBuffer := make([]byte, uint32Size)
resplenptr, _ := bufferToPointerLen(resplenBuffer)
resplenptr, _, err := bufferToPointerLen(resplenBuffer)
if err != nil {
return err
}

// The request buffer is the wasm memory, get a pointer to the first element and the length
// of the protobuf message.
reqptr, reqptrlen := bufferToPointerLen(b)
reqptr, reqptrlen, err := bufferToPointerLen(b)
if err != nil {
return err
}

// Emit the message via the method imported from the host
errno := emit(respptr, resplenptr, reqptr, reqptrlen)
Expand Down Expand Up @@ -188,13 +200,22 @@ func createFetchFn(
if err != nil {
return sdk.FetchResponse{}, fmt.Errorf("failed to marshal fetch request: %w", err)
}
reqptr, reqptrlen := bufferToPointerLen(b)
reqptr, reqptrlen, err := bufferToPointerLen(b)
if err != nil {
return sdk.FetchResponse{}, err
}

respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes)
respptr, _ := bufferToPointerLen(respBuffer)
respptr, _, err := bufferToPointerLen(respBuffer)
if err != nil {
return sdk.FetchResponse{}, err
}

resplenBuffer := make([]byte, uint32Size)
resplenptr, _ := bufferToPointerLen(resplenBuffer)
resplenptr, _, err := bufferToPointerLen(resplenBuffer)
if err != nil {
return sdk.FetchResponse{}, err
}

errno := fetch(respptr, resplenptr, reqptr, reqptrlen)
if errno != 0 {
Expand Down Expand Up @@ -229,8 +250,11 @@ func createFetchFn(
}

// bufferToPointerLen returns a pointer to the first element of the buffer and the length of the buffer.
func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32) {
return unsafe.Pointer(&buf[0]), int32(len(buf))
func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32, error) {
if len(buf) == 0 {
return nil, 0, fmt.Errorf("buffer cannot be empty")
}
return unsafe.Pointer(&buf[0]), int32(len(buf)), nil
}

// toEmitLabels ensures that the required metadata is present in the labels map
Expand All @@ -247,6 +271,10 @@ func toEmitLabels(md *capabilities.RequestMetadata, labels map[string]string) (m
return nil, fmt.Errorf("must provide workflow owner to emit event")
}

if md.WorkflowExecutionID == "" {
return nil, fmt.Errorf("must provide workflow owner to emit event")
}

labels[events.LabelWorkflowExecutionID] = md.WorkflowExecutionID
labels[events.LabelWorkflowOwner] = md.WorkflowOwner
labels[events.LabelWorkflowID] = md.WorkflowID
Expand Down
Loading