From dce3c303614e63aa0a778c9cb8ece8a76cfe924f Mon Sep 17 00:00:00 2001 From: Justin Kaseman Date: Tue, 7 Jan 2025 15:55:27 -0800 Subject: [PATCH] (fix): Add guards --- pkg/workflows/wasm/host/module.go | 21 ++++++++++----- pkg/workflows/wasm/sdk.go | 44 +++++++++++++++++++++++++------ 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 601b69632..84cfffeba 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "math" + "regexp" "strings" "sync" @@ -383,6 +384,9 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp } func containsCode(err error, code int) bool { + if err == nil { + return false + } return strings.Contains(err.Error(), fmt.Sprintf("exit status %d", code)) } @@ -596,21 +600,24 @@ func createLogFn(logger logger.Logger) func(caller *wasmtime.Caller, ptr int32, args = append(args, k, v) } + reg, _ := regexp.Compile(`[^\w]`) + sanitizedMsg := reg.ReplaceAllString(msg, " ") + switch level { case "debug": - logger.Debugw(msg, args...) + logger.Debugw(sanitizedMsg, args...) case "info": - logger.Infow(msg, args...) + logger.Infow(sanitizedMsg, args...) case "warn": - logger.Warnw(msg, args...) + logger.Warnw(sanitizedMsg, args...) case "error": - logger.Errorw(msg, args...) + logger.Errorw(sanitizedMsg, args...) case "panic": - logger.Panicw(msg, args...) + logger.Panicw(sanitizedMsg, args...) case "fatal": - logger.Fatalw(msg, args...) + logger.Fatalw(sanitizedMsg, args...) default: - logger.Infow(msg, args...) + logger.Infow(sanitizedMsg, args...) } } } diff --git a/pkg/workflows/wasm/sdk.go b/pkg/workflows/wasm/sdk.go index 5c33a0fb1..611dd7660 100644 --- a/pkg/workflows/wasm/sdk.go +++ b/pkg/workflows/wasm/sdk.go @@ -97,6 +97,9 @@ func createEmitFn( return NewEmissionError(fmt.Errorf("metadata is required to emit")) } + if labels == nil { + return NewEmissionError(fmt.Errorf("labels must be provided")) + } labels, err := toEmitLabels(sdkConfig.Metadata, labels) if err != nil { return NewEmissionError(err) @@ -120,14 +123,23 @@ func createEmitFn( // Prepare the request to be sent to the host memory by allocating space for the // response and response length buffers. respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes) - respptr, _ := bufferToPointerLen(respBuffer) + respptr, _, err := bufferToPointerLen(respBuffer) + if err != nil { + return err + } resplenBuffer := make([]byte, uint32Size) - resplenptr, _ := bufferToPointerLen(resplenBuffer) + resplenptr, _, err := bufferToPointerLen(resplenBuffer) + if err != nil { + return err + } // The request buffer is the wasm memory, get a pointer to the first element and the length // of the protobuf message. - reqptr, reqptrlen := bufferToPointerLen(b) + reqptr, reqptrlen, err := bufferToPointerLen(b) + if err != nil { + return err + } // Emit the message via the method imported from the host errno := emit(respptr, resplenptr, reqptr, reqptrlen) @@ -188,13 +200,22 @@ func createFetchFn( if err != nil { return sdk.FetchResponse{}, fmt.Errorf("failed to marshal fetch request: %w", err) } - reqptr, reqptrlen := bufferToPointerLen(b) + reqptr, reqptrlen, err := bufferToPointerLen(b) + if err != nil { + return sdk.FetchResponse{}, err + } respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes) - respptr, _ := bufferToPointerLen(respBuffer) + respptr, _, err := bufferToPointerLen(respBuffer) + if err != nil { + return sdk.FetchResponse{}, err + } resplenBuffer := make([]byte, uint32Size) - resplenptr, _ := bufferToPointerLen(resplenBuffer) + resplenptr, _, err := bufferToPointerLen(resplenBuffer) + if err != nil { + return sdk.FetchResponse{}, err + } errno := fetch(respptr, resplenptr, reqptr, reqptrlen) if errno != 0 { @@ -229,8 +250,11 @@ func createFetchFn( } // bufferToPointerLen returns a pointer to the first element of the buffer and the length of the buffer. -func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32) { - return unsafe.Pointer(&buf[0]), int32(len(buf)) +func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32, error) { + if len(buf) == 0 { + return nil, 0, fmt.Errorf("buffer cannot be empty") + } + return unsafe.Pointer(&buf[0]), int32(len(buf)), nil } // toEmitLabels ensures that the required metadata is present in the labels map @@ -247,6 +271,10 @@ func toEmitLabels(md *capabilities.RequestMetadata, labels map[string]string) (m return nil, fmt.Errorf("must provide workflow owner to emit event") } + if md.WorkflowExecutionID == "" { + return nil, fmt.Errorf("must provide workflow owner to emit event") + } + labels[events.LabelWorkflowExecutionID] = md.WorkflowExecutionID labels[events.LabelWorkflowOwner] = md.WorkflowOwner labels[events.LabelWorkflowID] = md.WorkflowID