From 34108eec3b03c0ce8b0dfb904e876a61988d8cec Mon Sep 17 00:00:00 2001 From: fernandoalonso Date: Wed, 16 Oct 2024 15:48:13 -0600 Subject: [PATCH] Fix: add validation --- go/genkit/flow.go | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/go/genkit/flow.go b/go/genkit/flow.go index e07b7e76e..b0f20bee8 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -477,6 +477,20 @@ func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamin return state, nil } +func isInputMissing(input any) bool { + if input == nil { + return true + } + v := reflect.ValueOf(input) + switch v.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface, reflect.Chan, reflect.Func: + return v.IsNil() + default: + // For other types like structs, zero value might be a valid input. + return false + } +} + // execute performs one flow execution. // Using its flowState argument as a starting point, it runs the flow function until // it finishes or is interrupted. @@ -511,13 +525,21 @@ func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In traceID := rootSpanContext.TraceID().String() exec.TraceIDs = append(exec.TraceIDs, traceID) // TODO: Save rootSpanContext in the state. - if reflect.ValueOf(input).IsZero() { - if state == nil || reflect.ValueOf(state.Input).IsZero() { - return base.Zero[Out](), fmt.Errorf("input is missing and cannot be retrieved from state") + if isInputMissing(input) { + if state == nil { + return base.Zero[Out](), errors.New("input is missing and state is nil") + } + if isInputMissing(state.Input) { + return base.Zero[Out](), errors.New("input is missing and state.Input is also empty") } input = state.Input - // TODO: convert input to string - //tracing.SetCustomMetadataAttr(ctx, "flow:input", string(input)) + + // Convert input to JSON string for tracing metadata + bytes, err := json.Marshal(input) + if err != nil { + return base.Zero[Out](), fmt.Errorf("failed to marshal input for tracing: %w", err) + } + tracing.SetCustomMetadataAttr(ctx, "input", string(bytes)) } start := time.Now() var err error