Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class OrchestrationStatus(Enum):
PENDING = pb.ORCHESTRATION_STATUS_PENDING
SUSPENDED = pb.ORCHESTRATION_STATUS_SUSPENDED
CANCELED = pb.ORCHESTRATION_STATUS_CANCELED
STALLED = pb.ORCHESTRATION_STATUS_STALLED

def __str__(self):
return helpers.get_orchestration_status_str(self.value)
Expand Down
2 changes: 1 addition & 1 deletion durabletask/internal/PROTO_SOURCE_COMMIT_HASH
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4b86756497d875b97f9a91051781b5711c1e4fa6
889781bbe90e6ec84ebe169978c4f2fd0df74ff0
9 changes: 9 additions & 0 deletions durabletask/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ def new_complete_orchestration_action(
)


def new_orchestrator_version_not_available_action(
id: int,
) -> pb.OrchestratorAction:
return pb.OrchestratorAction(
id=id,
orchestratorVersionNotAvailable=pb.OrchestratorVersionNotAvailableAction(),
)


def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction:
timestamp = timestamp_pb2.Timestamp()
timestamp.FromDatetime(fire_at)
Expand Down
472 changes: 245 additions & 227 deletions durabletask/internal/orchestrator_service_pb2.py

Large diffs are not rendered by default.

117 changes: 99 additions & 18 deletions durabletask/internal/orchestrator_service_pb2.pyi

Large diffs are not rendered by default.

86 changes: 86 additions & 0 deletions durabletask/internal/orchestrator_service_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ def __init__(self, channel):
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RerunWorkflowFromEventRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RerunWorkflowFromEventResponse.FromString,
_registered_method=True)
self.ListInstanceIDs = channel.unary_unary(
'/TaskHubSidecarService/ListInstanceIDs',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsResponse.FromString,
_registered_method=True)
self.GetInstanceHistory = channel.unary_unary(
'/TaskHubSidecarService/GetInstanceHistory',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryResponse.FromString,
_registered_method=True)


class TaskHubSidecarServiceServicer(object):
Expand Down Expand Up @@ -360,6 +370,18 @@ def RerunWorkflowFromEvent(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def ListInstanceIDs(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetInstanceHistory(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -498,6 +520,16 @@ def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RerunWorkflowFromEventRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RerunWorkflowFromEventResponse.SerializeToString,
),
'ListInstanceIDs': grpc.unary_unary_rpc_method_handler(
servicer.ListInstanceIDs,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsResponse.SerializeToString,
),
'GetInstanceHistory': grpc.unary_unary_rpc_method_handler(
servicer.GetInstanceHistory,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'TaskHubSidecarService', rpc_method_handlers)
Expand Down Expand Up @@ -1237,3 +1269,57 @@ def RerunWorkflowFromEvent(request,
timeout,
metadata,
_registered_method=True)

@staticmethod
def ListInstanceIDs(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/TaskHubSidecarService/ListInstanceIDs',
durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsRequest.SerializeToString,
durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIDsResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def GetInstanceHistory(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/TaskHubSidecarService/GetInstanceHistory',
durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryRequest.SerializeToString,
durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceHistoryResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
16 changes: 16 additions & 0 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,22 @@ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
"""
pass

@abstractmethod
def is_patched(self, patch_name: str) -> bool:
"""Check if the given patch name can be applied to the orchestration.

Parameters
----------
patch_name : str
The name of the patch to check.

Returns
-------
bool
True if the given patch name can be applied to the orchestration, False otherwise.
"""
pass


class FailureDetails:
def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
Expand Down
122 changes: 112 additions & 10 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput")

class VersionNotRegisteredException(Exception):
pass

class ConcurrencyOptions:
"""Configuration options for controlling concurrency of different work item types and the thread pool size.
Expand Down Expand Up @@ -74,30 +76,58 @@ def __init__(

class _Registry:
orchestrators: dict[str, task.Orchestrator]
versioned_orchestrators: dict[str, dict[str, task.Orchestrator]]
latest_versioned_orchestrators_version_name: dict[str, str]
activities: dict[str, task.Activity]

def __init__(self):
self.orchestrators = {}
self.versioned_orchestrators = {}
self.latest_versioned_orchestrators_version_name = {}
self.activities = {}

def add_orchestrator(self, fn: task.Orchestrator) -> str:
def add_orchestrator(self, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False) -> str:
if fn is None:
raise ValueError("An orchestrator function argument is required.")

name = task.get_name(fn)
self.add_named_orchestrator(name, fn)
self.add_named_orchestrator(name, fn, version_name, is_latest)
return name

def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None:
def add_named_orchestrator(self, name: str, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False) -> None:
if not name:
raise ValueError("A non-empty orchestrator name is required.")

if version_name is None:
if name in self.orchestrators:
raise ValueError(f"A '{name}' orchestrator already exists.")
self.orchestrators[name] = fn
else:
if name not in self.versioned_orchestrators:
self.versioned_orchestrators[name] = {}
if version_name in self.versioned_orchestrators[name]:
raise ValueError(f"The version '{version_name}' of '{name}' orchestrator already exists.")
self.versioned_orchestrators[name][version_name] = fn
if is_latest:
self.latest_versioned_orchestrators_version_name[name] = version_name

def get_orchestrator(self, name: str, version_name: Optional[str] = None) -> Optional[tuple[task.Orchestrator, str]]:
if name in self.orchestrators:
raise ValueError(f"A '{name}' orchestrator already exists.")
return self.orchestrators.get(name), None

self.orchestrators[name] = fn
if name in self.versioned_orchestrators:
if version_name:
version_to_use = version_name
elif name in self.latest_versioned_orchestrators_version_name:
version_to_use = self.latest_versioned_orchestrators_version_name[name]
else:
return None, None

if version_to_use not in self.versioned_orchestrators[name]:
raise VersionNotRegisteredException
return self.versioned_orchestrators[name].get(version_to_use), version_to_use

def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]:
return self.orchestrators.get(name)
return None, None

def add_activity(self, fn: task.Activity) -> str:
if fn is None:
Expand Down Expand Up @@ -540,11 +570,22 @@ def _execute_orchestrator(
try:
executor = _OrchestrationExecutor(self._registry, self._logger)
result = executor.execute(req.instanceId, req.pastEvents, req.newEvents)

version = None
if result.version_name:
version = version or pb.OrchestrationVersion()
version.name = result.version_name
if result.patches:
version = version or pb.OrchestrationVersion()
version.patches.extend(result.patches)


res = pb.OrchestratorResponse(
instanceId=req.instanceId,
actions=result.actions,
customStatus=ph.get_string_value(result.encoded_custom_status),
completionToken=completionToken,
version=version,
)
except Exception as ex:
self._logger.exception(
Expand Down Expand Up @@ -629,6 +670,11 @@ def __init__(self, instance_id: str):
self._new_input: Optional[Any] = None
self._save_events = False
self._encoded_custom_status: Optional[str] = None
self._orchestrator_started_version: Optional[pb.OrchestrationVersion] = None
self._version_name: Optional[str] = None
self._history_patches: dict[str, bool] = {}
self._applied_patches: dict[str, bool] = {}
self._encountered_patches: list[str] = []

def run(self, generator: Generator[task.Task, Any, Any]):
self._generator = generator
Expand Down Expand Up @@ -705,6 +751,14 @@ def set_failed(self, ex: Exception):
)
self._pending_actions[action.id] = action


def set_version_not_registered(self):
self._pending_actions.clear()
self._completion_status = pb.ORCHESTRATION_STATUS_STALLED
action = ph.new_orchestrator_version_not_available_action(self.next_sequence_number())
self._pending_actions[action.id] = action


def set_continued_as_new(self, new_input: Any, save_events: bool):
if self._is_complete:
return
Expand Down Expand Up @@ -916,13 +970,38 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
self.set_continued_as_new(new_input, save_events)


def is_patched(self, patch_name: str) -> bool:
is_patched = self._is_patched(patch_name)
if is_patched:
self._encountered_patches.append(patch_name)
return is_patched

def _is_patched(self, patch_name: str) -> bool:
if patch_name in self._applied_patches:
return self._applied_patches[patch_name]
if patch_name in self._history_patches:
self._applied_patches[patch_name] = True
return True

if self._is_replaying:
self._applied_patches[patch_name] = False
return False

self._applied_patches[patch_name] = True
return True


class ExecutionResults:
actions: list[pb.OrchestratorAction]
encoded_custom_status: Optional[str]
version_name: Optional[str]
patches: Optional[list[str]]

def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]):
def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str], version_name: Optional[str] = None, patches: Optional[list[str]] = None):
self.actions = actions
self.encoded_custom_status = encoded_custom_status
self.version_name = version_name
self.patches = patches


class _OrchestrationExecutor:
Expand Down Expand Up @@ -965,6 +1044,8 @@ def execute(
for new_event in new_events:
self.process_event(ctx, new_event)

except VersionNotRegisteredException:
ctx.set_version_not_registered()
except Exception as ex:
# Unhandled exceptions fail the orchestration
ctx.set_failed(ex)
Expand All @@ -989,7 +1070,12 @@ def execute(
self._logger.debug(
f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}"
)
return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status)
return ExecutionResults(
actions=actions,
encoded_custom_status=ctx._encoded_custom_status,
version_name=getattr(ctx, '_version_name', None),
patches=ctx._encountered_patches
)

def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None:
if self._is_suspended and _is_suspendable(event):
Expand All @@ -1001,19 +1087,32 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
try:
if event.HasField("orchestratorStarted"):
ctx.current_utc_datetime = event.timestamp.ToDatetime()
ctx._orchestrator_started_version = event.orchestratorStarted.version
elif event.HasField("executionStarted"):
if event.router.targetAppID:
ctx._app_id = event.router.targetAppID
else:
ctx._app_id = event.router.sourceAppID

if ctx._orchestrator_started_version and ctx._orchestrator_started_version.patches:
ctx._history_patches = {patch: True for patch in ctx._orchestrator_started_version.patches}

version_name = None
if ctx._orchestrator_started_version and ctx._orchestrator_started_version.name:
version_name = ctx._orchestrator_started_version.name


# TODO: Check if we already started the orchestration
fn = self._registry.get_orchestrator(event.executionStarted.name)
fn, version_used = self._registry.get_orchestrator(event.executionStarted.name, version_name=version_name)

if fn is None:
raise OrchestratorNotRegisteredError(
f"A '{event.executionStarted.name}' orchestrator was not registered."
)

if version_used is not None:
ctx._version_name = version_used

# deserialize the input, if any
input = None
if (
Expand Down Expand Up @@ -1280,6 +1379,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
pb.ORCHESTRATION_STATUS_TERMINATED,
is_result_encoded=True,
)
elif event.HasField("executionStalled"):
# Nothing to do
pass
else:
eventType = event.WhichOneof("eventType")
raise task.OrchestrationStateError(
Expand Down
Loading