diff --git a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingWorkflowOutboundCallsInterceptor.java b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingWorkflowOutboundCallsInterceptor.java index 5f7f2ccbc..70a458f97 100644 --- a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingWorkflowOutboundCallsInterceptor.java +++ b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingWorkflowOutboundCallsInterceptor.java @@ -30,6 +30,7 @@ import io.temporal.workflow.Workflow; import io.temporal.workflow.WorkflowInfo; import io.temporal.workflow.unsafe.WorkflowUnsafe; +import java.util.concurrent.Executor; public class OpenTracingWorkflowOutboundCallsInterceptor extends WorkflowOutboundCallsInterceptorBase { @@ -37,6 +38,28 @@ public class OpenTracingWorkflowOutboundCallsInterceptor private final Tracer tracer; private final ContextAccessor contextAccessor; + private class SpanTransferringExecutor implements Executor { + private final Executor passthrough; + private final Span capturedSpan; + + public SpanTransferringExecutor(Executor passthrough) { + this.passthrough = passthrough; + capturedSpan = tracer.scopeManager().activeSpan(); + } + + public void execute(Runnable r) { + if (capturedSpan != null) { + // if we captured a span during construction, we should transfer it to the calling context + // as the new activespan + try (Scope ignored = tracer.scopeManager().activate(capturedSpan)) { + passthrough.execute(r); + } + } else { + passthrough.execute(r); + } + } + } + public OpenTracingWorkflowOutboundCallsInterceptor( WorkflowOutboundCallsInterceptor next, OpenTracingOptions options, @@ -178,6 +201,12 @@ public Object newChildThread(Runnable runnable, boolean detached, String name) { return super.newChildThread(wrappedRunnable, detached, name); } + @Override + public Executor newCallbackExecutor() { + Executor passthrough = super.newCallbackExecutor(); + return new SpanTransferringExecutor(passthrough); + } + private Tracer.SpanBuilder createActivityStartSpanBuilder(String activityName) { WorkflowInfo workflowInfo = Workflow.getInfo(); return spanFactory.createActivityStartSpan( diff --git a/temporal-opentracing/src/test/java/io/temporal/opentracing/CallbackContextTest.java b/temporal-opentracing/src/test/java/io/temporal/opentracing/CallbackContextTest.java new file mode 100644 index 000000000..05f8d6d21 --- /dev/null +++ b/temporal-opentracing/src/test/java/io/temporal/opentracing/CallbackContextTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this material except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.temporal.opentracing; + +import static org.junit.Assert.assertEquals; + +import io.opentracing.Scope; +import io.opentracing.Span; +import io.opentracing.mock.MockSpan; +import io.opentracing.mock.MockTracer; +import io.opentracing.util.ThreadLocalScopeManager; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.activity.ActivityOptions; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.client.WorkflowOptions; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.worker.WorkerFactoryOptions; +import io.temporal.workflow.*; +import java.time.Duration; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; + +public class CallbackContextTest { + + private static final MockTracer mockTracer = + new MockTracer(new ThreadLocalScopeManager(), MockTracer.Propagator.TEXT_MAP); + + private final OpenTracingOptions OT_OPTIONS = + OpenTracingOptions.newBuilder().setTracer(mockTracer).build(); + + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setWorkflowClientOptions( + WorkflowClientOptions.newBuilder() + .setInterceptors(new OpenTracingClientInterceptor(OT_OPTIONS)) + .validateAndBuildWithDefaults()) + .setWorkerFactoryOptions( + WorkerFactoryOptions.newBuilder() + .setWorkerInterceptors(new OpenTracingWorkerInterceptor(OT_OPTIONS)) + .validateAndBuildWithDefaults()) + .setWorkflowTypes(WorkflowImpl.class) + .setActivityImplementations(new ActivityImpl()) + .build(); + + @After + public void tearDown() { + mockTracer.reset(); + } + + @ActivityInterface + public interface TestActivity { + @ActivityMethod + boolean activity(); + } + + @WorkflowInterface + public interface TestWorkflow { + @WorkflowMethod + String workflow(String input); + } + + public static class ActivityImpl implements TestActivity { + @Override + public boolean activity() { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return true; + } + } + + public static class WorkflowImpl implements TestWorkflow { + private final TestActivity activity = + Workflow.newActivityStub( + TestActivity.class, + ActivityOptions.newBuilder() + .setStartToCloseTimeout(Duration.ofMinutes(1)) + .validateAndBuildWithDefaults()); + + @Override + public String workflow(String input) { + return Async.function(activity::activity) + .thenCompose( + (r) -> { + Span activeSpan = mockTracer.activeSpan(); + return Workflow.newPromise( + activeSpan != null ? activeSpan.context().toTraceId() : "not-found"); + }) + .get(); + } + } + + @Test + public void testCallbackContext() { + MockSpan span = mockTracer.buildSpan("ClientFunction").start(); + + WorkflowClient client = testWorkflowRule.getWorkflowClient(); + try (Scope scope = mockTracer.scopeManager().activate(span)) { + TestWorkflow workflow = + client.newWorkflowStub( + TestWorkflow.class, + WorkflowOptions.newBuilder() + .setTaskQueue(testWorkflowRule.getTaskQueue()) + .validateBuildWithDefaults()); + assertEquals(span.context().toTraceId(), workflow.workflow("input")); + } finally { + span.finish(); + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowOutboundCallsInterceptor.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowOutboundCallsInterceptor.java index 5645eeebb..3c0803bee 100644 --- a/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowOutboundCallsInterceptor.java +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowOutboundCallsInterceptor.java @@ -31,6 +31,7 @@ import java.lang.reflect.Type; import java.time.Duration; import java.util.*; +import java.util.concurrent.Executor; import java.util.function.BiPredicate; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -798,5 +799,17 @@ R mutableSideEffect( */ Object newChildThread(Runnable runnable, boolean detached, String name); + /** + * Intercepts the point where a new callback is being prepared for deferment and allows + * interceptors to provide an wrapped execution environment for running the callback at a later + * time. + * + *

The executor's execute() function _must_ fully execute the provided Runnable within the + * caller's thread or determinism guarantees could be violated. + * + * @return created Executor + */ + Executor newCallbackExecutor(); + long currentTimeMillis(); } diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowOutboundCallsInterceptorBase.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowOutboundCallsInterceptorBase.java index e6d6e9db6..5ef6f108d 100644 --- a/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowOutboundCallsInterceptorBase.java +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowOutboundCallsInterceptorBase.java @@ -30,6 +30,7 @@ import java.util.Map; import java.util.Random; import java.util.UUID; +import java.util.concurrent.Executor; import java.util.function.BiPredicate; import java.util.function.Supplier; @@ -184,6 +185,11 @@ public Object newChildThread(Runnable runnable, boolean detached, String name) { return next.newChildThread(runnable, detached, name); } + @Override + public Executor newCallbackExecutor() { + return next.newCallbackExecutor(); + } + @Override public long currentTimeMillis() { return next.currentTimeMillis(); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/CompletablePromiseImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/CompletablePromiseImpl.java index c12c9c125..1b7e5bc06 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/CompletablePromiseImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/CompletablePromiseImpl.java @@ -29,6 +29,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -40,6 +41,7 @@ class CompletablePromiseImpl implements CompletablePromise { private final List handlers = new ArrayList<>(); private final DeterministicRunnerImpl runner; private boolean registeredWithRunner; + private Executor callbackExecutor; @SuppressWarnings("unchecked") static Promise promiseAnyOf(Promise[] promises) { @@ -62,7 +64,10 @@ static Promise promiseAnyOf(Iterable> promises) { } CompletablePromiseImpl() { - runner = DeterministicRunnerImpl.currentThreadInternal().getRunner(); + WorkflowThread workflowThread = DeterministicRunnerImpl.currentThreadInternal(); + runner = workflowThread.getRunner(); + callbackExecutor = + workflowThread.getWorkflowContext().getWorkflowOutboundInterceptor().newCallbackExecutor(); } @Override @@ -275,9 +280,14 @@ private Promise then(Functions.Proc1> proc) { * @return true if there were any handlers invoked */ private boolean invokeHandlers() { - for (Functions.Proc handler : handlers) { - handler.apply(); - } + // execute synchronously to this thread, but under the context established in the constructor + callbackExecutor.execute( + () -> { + for (Functions.Proc handler : handlers) { + handler.apply(); + } + }); + return !handlers.isEmpty(); } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java index e95deb3d0..94a82f163 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/SyncWorkflowContext.java @@ -78,6 +78,7 @@ import java.time.Duration; import java.time.Instant; import java.util.*; +import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiPredicate; @@ -412,6 +413,12 @@ public WorkflowMetadata getWorkflowMetadata() { return workflowMetadata.build(); } + private class DirectExecutor implements Executor { + public void execute(Runnable r) { + r.run(); + } + } + private class ActivityCallback { private final CompletablePromise> result = Workflow.newPromise(); @@ -1419,6 +1426,11 @@ public Object newChildThread(Runnable runnable, boolean detached, String name) { return runner.newWorkflowThread(runnable, detached, name); } + @Override + public Executor newCallbackExecutor() { + return new DirectExecutor(); + } + @Override public long currentTimeMillis() { return replayContext.currentTimeMillis(); diff --git a/temporal-testing/src/main/java/io/temporal/testing/TestActivityEnvironmentInternal.java b/temporal-testing/src/main/java/io/temporal/testing/TestActivityEnvironmentInternal.java index 1f13e16e8..0a1a0387f 100644 --- a/temporal-testing/src/main/java/io/temporal/testing/TestActivityEnvironmentInternal.java +++ b/temporal-testing/src/main/java/io/temporal/testing/TestActivityEnvironmentInternal.java @@ -497,6 +497,11 @@ public Object newChildThread(Runnable runnable, boolean detached, String name) { throw new UnsupportedOperationException("not implemented"); } + @Override + public Executor newCallbackExecutor() { + throw new UnsupportedOperationException("not implemented"); + } + @Override public long currentTimeMillis() { throw new UnsupportedOperationException("not implemented"); diff --git a/temporal-testing/src/main/java/io/temporal/testing/internal/TracingWorkerInterceptor.java b/temporal-testing/src/main/java/io/temporal/testing/internal/TracingWorkerInterceptor.java index 93eb971dd..03cf7445d 100644 --- a/temporal-testing/src/main/java/io/temporal/testing/internal/TracingWorkerInterceptor.java +++ b/temporal-testing/src/main/java/io/temporal/testing/internal/TracingWorkerInterceptor.java @@ -36,6 +36,7 @@ import java.lang.reflect.Type; import java.time.Duration; import java.util.*; +import java.util.concurrent.Executor; import java.util.function.BiPredicate; import java.util.function.Supplier; import javax.annotation.Nonnull; @@ -433,6 +434,15 @@ public Object newChildThread(Runnable runnable, boolean detached, String name) { return next.newChildThread(runnable, detached, name); } + @Override + public Executor newCallbackExecutor() { + if (!WorkflowUnsafe.isReplaying()) { + trace.add("newCallbackExecutor "); + } + + return next.newCallbackExecutor(); + } + @Override public long currentTimeMillis() { if (!WorkflowUnsafe.isReplaying()) {