Skip to content

Commit

Permalink
fix: separated _update_cluster_fn from _process_operation_fn to displ…
Browse files Browse the repository at this point in the history
…ay deployment operations
  • Loading branch information
SteBaum committed Aug 2, 2024
1 parent de8296e commit 9584116
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 25 deletions.
19 changes: 15 additions & 4 deletions tdp/cli/commands/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def deploy(
).run(planned_deployment, force_stale_update=force_stale_update)

if dry:
for operation_rec, process_operation_fn in deployment_iterator:
for (
operation_rec,
process_operation_fn,
update_cluster_fn,
) in deployment_iterator:
if process_operation_fn:
process_operation_fn()
click.echo(
Expand All @@ -96,13 +100,20 @@ def deploy(
# deployment and operations records are mutated by the iterator so we need to
# commit them before iterating and at each iteration
dao.session.commit() # Update operation status to RUNNING
for operation_rec, process_operation_fn in deployment_iterator:
for (
operation_rec,
process_operation_fn,
update_cluster_fn,
) in deployment_iterator:
dao.session.commit() # Update deployment and current operation status to RUNNING and next operations to PENDING
if process_operation_fn and (cluster_status_logs := process_operation_fn()):
if process_operation_fn:
process_operation_fn()
click.echo(
f"Operation {operation_rec.operation} is {operation_rec.state} {'for hosts: ' + operation_rec.host if operation_rec.host is not None else ''}"
)
dao.session.add_all(cluster_status_logs)
if update_cluster_fn:
if cluster_status_logs := update_cluster_fn():
dao.session.add_all(cluster_status_logs)
dao.session.commit() # Update operation status to SUCCESS, FAILURE or HELD

if deployment_iterator.deployment.state != DeploymentStateEnum.SUCCESS:
Expand Down
35 changes: 25 additions & 10 deletions tdp/core/deployment/deployment_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@

logger = logging.getLogger(__name__)

ProcessOperationFn = Callable[[], Optional[list[SCHStatusLogModel]]]
ProcessOperationFn = Callable[[], None]
UpdateClusterFn = Callable[[], Optional[list[SCHStatusLogModel]]]


def _group_hosts_by_operation(
Expand Down Expand Up @@ -58,7 +59,11 @@ def _group_hosts_by_operation(
return operation_to_hosts_set


class DeploymentIterator(Iterator[tuple[OperationModel, Optional[ProcessOperationFn]]]):
class DeploymentIterator(
Iterator[
tuple[OperationModel, Optional[ProcessOperationFn], Optional[UpdateClusterFn]]
]
):
"""Iterator that runs an operation at each iteration.
Attributes:
Expand Down Expand Up @@ -113,19 +118,27 @@ def __init__(

def __next__(
self,
) -> tuple[OperationModel, Optional[ProcessOperationFn]]:
) -> tuple[OperationModel, Optional[ProcessOperationFn], Optional[UpdateClusterFn]]:
try:
while True:
operation_rec = next(self._iter)

# Return early if deployment failed
if self.deployment.state == DeploymentStateEnum.FAILURE:
operation_rec.state = OperationStateEnum.HELD
return operation_rec, None
return operation_rec, None, None

operation_rec.state = OperationStateEnum.RUNNING

return operation_rec, partial(self._process_operation_fn, operation_rec)
# Get service version number
operation = self._collections.operations[operation_rec.operation]
version = self._cluster_variables[operation.service_name].version

return (
operation_rec,
partial(self._process_operation_fn, operation_rec),
partial(self._update_cluster_fn, operation_rec, version),
)
# StopIteration is a "normal" exception raised when the iteration has stopped
except StopIteration as e:
self.deployment.end_time = datetime.utcnow()
Expand All @@ -138,9 +151,7 @@ def __next__(
self.deployment.state = DeploymentStateEnum.FAILURE
raise e

def _process_operation_fn(
self, operation_rec: OperationModel
) -> Optional[list[SCHStatusLogModel]]:
def _process_operation_fn(self, operation_rec: OperationModel) -> None:

operation = self._collections.operations[operation_rec.operation]

Expand All @@ -158,7 +169,11 @@ def _process_operation_fn(
# Return early as status is not updated
return

# ===== Update the cluster status if success =====
def _update_cluster_fn(
self, operation_rec: OperationModel, version: str
) -> Optional[list[SCHStatusLogModel]]:

operation = self._collections.operations[operation_rec.operation]

# Skip sleep operation
if operation.name == OPERATION_SLEEP_NAME:
Expand Down Expand Up @@ -206,7 +221,7 @@ def _process_operation_fn(
sch_status_log = self._cluster_status.update_hosted_entity(
create_hosted_entity(entity_name, host),
action_name=operation.action_name,
version=self._cluster_variables[operation.service_name].version,
version=version,
can_update_stale=can_update_stale,
)
if sch_status_log:
Expand Down
50 changes: 39 additions & 11 deletions tests/unit/core/deployment/test_deployment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,11 @@ def test_deployment_plan_is_success(
"""Nominal case, runs a deployment with full DAG."""
deployment_iterator = mock_deployment_runner.run(DeploymentModel.from_dag(mock_dag))

for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS
assert len(deployment_iterator.deployment.operations) == 8
Expand All @@ -112,10 +114,15 @@ def test_deployment_plan_with_filter_is_success(
)
deployment_iterator = mock_deployment_runner.run(deployment)

for i, (op, process_operation_fn) in enumerate(deployment_iterator):
for i, (op, process_operation_fn, update_cluster_fn) in enumerate(
deployment_iterator
):
if process_operation_fn:
process_operation_fn()

if update_cluster_fn:
update_cluster_fn()

assert deployment.operations[i].state == OperationStateEnum.SUCCESS

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS
Expand All @@ -133,6 +140,9 @@ def test_noop_deployment_plan_is_success(
process_operation_fn = _[1]
if process_operation_fn:
process_operation_fn()
update_cluster_fn = _[2]
if update_cluster_fn:
update_cluster_fn()
assert deployment.operations[i].state == OperationStateEnum.SUCCESS

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS
Expand All @@ -146,9 +156,11 @@ def test_failed_operation_stops(
deployment = DeploymentModel.from_dag(mock_dag, targets=["serv_init"])
deployment_iterator = mock_deployment_runner_failing.run(deployment)

for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()
assert deployment_iterator.deployment.state == DeploymentStateEnum.FAILURE
assert len(deployment_iterator.deployment.operations) == 8

Expand All @@ -160,9 +172,11 @@ def test_service_log_is_emitted(
deployment = DeploymentModel.from_dag(mock_dag, targets=["serv_init"])
deployment_iterator = mock_deployment_runner.run(deployment)

for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS

Expand All @@ -176,9 +190,11 @@ def test_service_log_is_not_emitted(
)
deployment_iterator = mock_deployment_runner.run(deployment)

for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS

Expand All @@ -192,9 +208,11 @@ def test_service_log_only_noop_is_emitted(
)
deployment_iterator = mock_deployment_runner.run(deployment)

for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS

Expand All @@ -208,9 +226,11 @@ def test_service_log_not_emitted_when_config_start_wrong_order(
)
deployment_iterator = mock_deployment_runner.run(deployment)

for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS

Expand All @@ -224,9 +244,11 @@ def test_service_log_emitted_once_with_start_and_restart(
)
deployment_iterator = mock_deployment_runner.run(deployment)

for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS

Expand All @@ -246,9 +268,11 @@ def test_service_log_emitted_once_with_multiple_config_and_start_on_same_compone
)
deployment_iterator = mock_deployment_runner.run(deployment)

for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert deployment_iterator.deployment.state == DeploymentStateEnum.SUCCESS

Expand All @@ -261,19 +285,23 @@ def test_deployment_dag_is_resumed(
):
deployment = DeploymentModel.from_dag(mock_dag, targets=["serv_init"])
deployment_iterator = mock_deployment_runner_failing.run(deployment)
for op, process_operation_fn in deployment_iterator:
for op, process_operation_fn, update_cluster_fn in deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert deployment_iterator.deployment.state == DeploymentStateEnum.FAILURE

resume_log = DeploymentModel.from_failed_deployment(
mock_collections, deployment_iterator.deployment
)
resume_deployment_iterator = mock_deployment_runner.run(resume_log)
for op, process_operation_fn in resume_deployment_iterator:
for op, process_operation_fn, update_cluster_fn in resume_deployment_iterator:
if process_operation_fn:
process_operation_fn()
if update_cluster_fn:
update_cluster_fn()

assert (
resume_deployment_iterator.deployment.deployment_type
Expand Down

0 comments on commit 9584116

Please sign in to comment.