@@ -122,6 +122,9 @@ type VisitFailuresOptions struct {
122122 // Context is the same for every call of a visit, callers should not store it.
123123 // Visitor is free to mutate the passed failure struct.
124124 Visitor func(*VisitFailuresContext, *failure.Failure) (error)
125+ // Will be called for each Any encountered. If not set, the default is to recurse into the Any
126+ // object, unmarshal it, visit, and re-marshal it always (even if there are no changes).
127+ WellKnownAnyVisitor func(*VisitFailuresContext, *anypb.Any) error
125128}
126129
127130// VisitFailures calls the options.Visitor function for every Failure proto within msg.
@@ -162,6 +165,25 @@ func NewFailureVisitorInterceptor(options FailureVisitorInterceptorOptions) (grp
162165 }, nil
163166}
164167
168+ func (o *VisitFailuresOptions) defaultWellKnownAnyVisitor(ctx *VisitFailuresContext, p *anypb.Any) error {
169+ child, err := p.UnmarshalNew()
170+ if err != nil {
171+ return fmt.Errorf("failed to unmarshal any: %w", err)
172+ }
173+ // We choose to visit and re-marshal always instead of cloning, visiting,
174+ // and checking if anything changed before re-marshaling. It is assumed the
175+ // clone + equality check is not much cheaper than re-marshal.
176+ if err := visitFailures(ctx, o, child); err != nil {
177+ return err
178+ }
179+ // Confirmed this replaces both Any fields on non-error, there is nothing
180+ // left over
181+ if err := p.MarshalFrom(child); err != nil {
182+ return fmt.Errorf("failed to marshal any: %w", err)
183+ }
184+ return nil
185+ }
186+
165187func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, p *anypb.Any) error {
166188 child, err := p.UnmarshalNew()
167189 if err != nil {
@@ -299,6 +321,20 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj
299321 if o == nil { continue }
300322 if err := options.Visitor(ctx, o); err != nil { return err }
301323 if err := visitFailures(ctx, options, o.GetCause()); err != nil { return err }
324+ case *anypb.Any:
325+ if o == nil {
326+ continue
327+ }
328+ visitor := options.WellKnownAnyVisitor
329+ if visitor == nil {
330+ visitor = options.defaultWellKnownAnyVisitor
331+ }
332+ ctx.Parent = o
333+ err := visitor(ctx, o)
334+ ctx.Parent = nil
335+ if err != nil {
336+ return err
337+ }
302338{{range $type, $record := .FailureTypes}}
303339 {{if $record.Slice}}
304340 case []{{$type}}:
@@ -508,17 +544,19 @@ func generateInterceptor(cfg config) error {
508544 if err != nil {
509545 return err
510546 }
511- // For the purposes of payloads, we also consider the Any well known type as
547+
548+ failureTypes , err := lookupTypes ("go.temporal.io/api/failure/v1" , []string {"Failure" })
549+ if err != nil {
550+ return err
551+ }
552+
553+ // For the purposes of payloads and failures, we also consider the Any well known type as
512554 // possible
513555 if anyTypes , err := lookupTypes ("google.golang.org/protobuf/types/known/anypb" , []string {"Any" }); err != nil {
514556 return err
515557 } else {
516558 payloadTypes = append (payloadTypes , anyTypes ... )
517- }
518-
519- failureTypes , err := lookupTypes ("go.temporal.io/api/failure/v1" , []string {"Failure" })
520- if err != nil {
521- return err
559+ failureTypes = append (failureTypes , anyTypes ... )
522560 }
523561
524562 // UnimplementedWorkflowServiceServer is auto-generated via our API package
@@ -542,6 +580,11 @@ func generateInterceptor(cfg config) error {
542580 }
543581 workflowExecutions := types .NewPointer (exportTypes [0 ])
544582
583+ updateTypes , err := lookupTypes ("go.temporal.io/api/update/v1" , []string {"Acceptance" , "Rejection" , "Response" })
584+ if err != nil {
585+ return err
586+ }
587+
545588 payloadRecords := map [string ]* TypeRecord {}
546589 failureRecords := map [string ]* TypeRecord {}
547590
@@ -572,6 +615,11 @@ func generateInterceptor(cfg config) error {
572615 walk (payloadTypes , workflowExecutions , & payloadRecords , true )
573616 walk (failureTypes , workflowExecutions , & failureRecords , false )
574617
618+ for _ , ut := range updateTypes {
619+ walk (payloadTypes , types .NewPointer (ut ), & payloadRecords , true )
620+ walk (failureTypes , types .NewPointer (ut ), & failureRecords , false )
621+ }
622+
575623 payloadRecords = pruneRecords (payloadRecords )
576624 failureRecords = pruneRecords (failureRecords )
577625
0 commit comments