Skip to content

Commit

Permalink
(fix): Add guards
Browse files Browse the repository at this point in the history
  • Loading branch information
justinkaseman committed Jan 7, 2025
1 parent db7919d commit dce3c30
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 15 deletions.
21 changes: 14 additions & 7 deletions pkg/workflows/wasm/host/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"math"
"regexp"

"strings"
"sync"
Expand Down Expand Up @@ -383,6 +384,9 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp
}

func containsCode(err error, code int) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), fmt.Sprintf("exit status %d", code))
}

Expand Down Expand Up @@ -596,21 +600,24 @@ func createLogFn(logger logger.Logger) func(caller *wasmtime.Caller, ptr int32,
args = append(args, k, v)
}

reg, _ := regexp.Compile(`[^\w]`)
sanitizedMsg := reg.ReplaceAllString(msg, " ")

switch level {
case "debug":
logger.Debugw(msg, args...)
logger.Debugw(sanitizedMsg, args...)
case "info":
logger.Infow(msg, args...)
logger.Infow(sanitizedMsg, args...)
case "warn":
logger.Warnw(msg, args...)
logger.Warnw(sanitizedMsg, args...)
case "error":
logger.Errorw(msg, args...)
logger.Errorw(sanitizedMsg, args...)
case "panic":
logger.Panicw(msg, args...)
logger.Panicw(sanitizedMsg, args...)
case "fatal":
logger.Fatalw(msg, args...)
logger.Fatalw(sanitizedMsg, args...)
default:
logger.Infow(msg, args...)
logger.Infow(sanitizedMsg, args...)
}
}
}
Expand Down
44 changes: 36 additions & 8 deletions pkg/workflows/wasm/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ func createEmitFn(
return NewEmissionError(fmt.Errorf("metadata is required to emit"))
}

if labels == nil {
return NewEmissionError(fmt.Errorf("labels must be provided"))
}
labels, err := toEmitLabels(sdkConfig.Metadata, labels)
if err != nil {
return NewEmissionError(err)
Expand All @@ -120,14 +123,23 @@ func createEmitFn(
// Prepare the request to be sent to the host memory by allocating space for the
// response and response length buffers.
respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes)
respptr, _ := bufferToPointerLen(respBuffer)
respptr, _, err := bufferToPointerLen(respBuffer)
if err != nil {
return err
}

resplenBuffer := make([]byte, uint32Size)
resplenptr, _ := bufferToPointerLen(resplenBuffer)
resplenptr, _, err := bufferToPointerLen(resplenBuffer)
if err != nil {
return err
}

// The request buffer is the wasm memory, get a pointer to the first element and the length
// of the protobuf message.
reqptr, reqptrlen := bufferToPointerLen(b)
reqptr, reqptrlen, err := bufferToPointerLen(b)
if err != nil {
return err
}

// Emit the message via the method imported from the host
errno := emit(respptr, resplenptr, reqptr, reqptrlen)
Expand Down Expand Up @@ -188,13 +200,22 @@ func createFetchFn(
if err != nil {
return sdk.FetchResponse{}, fmt.Errorf("failed to marshal fetch request: %w", err)
}
reqptr, reqptrlen := bufferToPointerLen(b)
reqptr, reqptrlen, err := bufferToPointerLen(b)
if err != nil {
return sdk.FetchResponse{}, err
}

respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes)
respptr, _ := bufferToPointerLen(respBuffer)
respptr, _, err := bufferToPointerLen(respBuffer)
if err != nil {
return sdk.FetchResponse{}, err
}

resplenBuffer := make([]byte, uint32Size)
resplenptr, _ := bufferToPointerLen(resplenBuffer)
resplenptr, _, err := bufferToPointerLen(resplenBuffer)
if err != nil {
return sdk.FetchResponse{}, err
}

errno := fetch(respptr, resplenptr, reqptr, reqptrlen)
if errno != 0 {
Expand Down Expand Up @@ -229,8 +250,11 @@ func createFetchFn(
}

// bufferToPointerLen returns a pointer to the first element of the buffer and the length of the buffer.
func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32) {
return unsafe.Pointer(&buf[0]), int32(len(buf))
func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32, error) {
if len(buf) == 0 {
return nil, 0, fmt.Errorf("buffer cannot be empty")
}
return unsafe.Pointer(&buf[0]), int32(len(buf)), nil
}

// toEmitLabels ensures that the required metadata is present in the labels map
Expand All @@ -247,6 +271,10 @@ func toEmitLabels(md *capabilities.RequestMetadata, labels map[string]string) (m
return nil, fmt.Errorf("must provide workflow owner to emit event")
}

if md.WorkflowExecutionID == "" {
return nil, fmt.Errorf("must provide workflow owner to emit event")
}

labels[events.LabelWorkflowExecutionID] = md.WorkflowExecutionID
labels[events.LabelWorkflowOwner] = md.WorkflowOwner
labels[events.LabelWorkflowID] = md.WorkflowID
Expand Down

0 comments on commit dce3c30

Please sign in to comment.