diff --git a/tests/hooks/test_ray_hooks.py b/tests/hooks/test_ray_hooks.py index 5b97346..95c787d 100644 --- a/tests/hooks/test_ray_hooks.py +++ b/tests/hooks/test_ray_hooks.py @@ -676,3 +676,99 @@ def test_delete_ray_cluster_success( mock_delete_daemon_set.assert_called_once() mock_delete_custom_object.assert_called_once() mock_uninstall_kuberay_operator.assert_called_once() + + @patch("ray_provider.hooks.ray.JobSubmissionClient") + def test_ray_client_exception(self, mock_job_client, ray_hook): + mock_job_client.side_effect = Exception("Connection failed") + with pytest.raises(AirflowException) as exc_info: + ray_hook.ray_client() + assert str(exc_info.value) == "Failed to create Ray JobSubmissionClient: Connection failed" + + @patch("ray_provider.hooks.ray.RayHook.get_custom_object") + @patch("ray_provider.hooks.ray.RayHook.create_custom_object") + def test_create_or_update_cluster_exception(self, mock_create, mock_get, ray_hook): + mock_get.side_effect = client.exceptions.ApiException(status=500, reason="Internal Server Error") + with pytest.raises(AirflowException) as exc_info: + ray_hook._create_or_update_cluster( + update_if_exists=False, + group="ray.io", + version="v1", + plural="rayclusters", + name="test-cluster", + namespace="default", + cluster_spec={}, + ) + assert "Error accessing Ray cluster 'test-cluster'" in str(exc_info.value) + + @patch("ray_provider.hooks.ray.RayHook.get_custom_object") + @patch("ray_provider.hooks.ray.RayHook.custom_object_client") + def test_create_or_update_cluster_update(self, mock_client, mock_get, ray_hook): + mock_get.return_value = {"metadata": {"name": "test-cluster"}} + ray_hook._create_or_update_cluster( + update_if_exists=True, + group="ray.io", + version="v1", + plural="rayclusters", + name="test-cluster", + namespace="default", + cluster_spec={"spec": {"some": "config"}}, + ) + mock_client.patch_namespaced_custom_object.assert_called_once_with( + group="ray.io", + version="v1", + namespace="default", + plural="rayclusters", + name="test-cluster", + body={"spec": {"some": "config"}}, + ) + + @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.ray.RayHook.install_kuberay_operator") + @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") + @patch("ray_provider.hooks.ray.RayHook._create_or_update_cluster") + @patch("ray_provider.hooks.ray.RayHook._setup_gpu_driver") + @patch("ray_provider.hooks.ray.RayHook._setup_load_balancer") + def test_setup_ray_cluster_exception( + self, + mock_setup_lb, + mock_setup_gpu, + mock_create_or_update, + mock_load_yaml, + mock_install_operator, + mock_validate_yaml, + ray_hook, + ): + mock_create_or_update.side_effect = Exception("Cluster creation failed") + context = {"task_instance": MagicMock()} + with pytest.raises(AirflowException) as exc_info: + ray_hook.setup_ray_cluster( + context=context, + ray_cluster_yaml="test.yaml", + kuberay_version="1.0.0", + gpu_device_plugin_yaml="gpu.yaml", + update_if_exists=False, + ) + assert "Failed to set up Ray cluster: Cluster creation failed" in str(exc_info.value) + + @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") + @patch("ray_provider.hooks.ray.RayHook.get_custom_object") + @patch("ray_provider.hooks.ray.RayHook.delete_custom_object") + @patch("ray_provider.hooks.ray.RayHook.get_daemon_set") + @patch("ray_provider.hooks.ray.RayHook.delete_daemon_set") + @patch("ray_provider.hooks.ray.RayHook.uninstall_kuberay_operator") + def test_delete_ray_cluster_exception( + self, + mock_uninstall_operator, + mock_delete_daemon_set, + mock_get_daemon_set, + mock_delete_custom_object, + mock_get_custom_object, + mock_load_yaml, + mock_validate_yaml, + ray_hook, + ): + mock_delete_custom_object.side_effect = Exception("Cluster deletion failed") + with pytest.raises(AirflowException) as exc_info: + ray_hook.delete_ray_cluster(ray_cluster_yaml="test.yaml", gpu_device_plugin_yaml="gpu.yaml") + assert "Failed to delete Ray cluster: Cluster deletion failed" in str(exc_info.value) diff --git a/tests/operators/test_ray_operators.py b/tests/operators/test_ray_operators.py index 0303cd3..a22e62e 100644 --- a/tests/operators/test_ray_operators.py +++ b/tests/operators/test_ray_operators.py @@ -346,3 +346,45 @@ def test_template_fields(self): "ray_cluster_yaml", "job_timeout_seconds", ) + + @patch("ray_provider.operators.ray.RayHook") + def test_setup_cluster_exception(self, mock_ray_hook, context): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + ray_cluster_yaml="cluster.yaml", + ) + + mock_hook = mock_ray_hook.return_value + operator.hook = mock_hook + + mock_hook.setup_ray_cluster.side_effect = Exception("Cluster setup failed") + + with pytest.raises(Exception) as exc_info: + operator._setup_cluster(context) + + assert str(exc_info.value) == "Cluster setup failed" + mock_hook.setup_ray_cluster.assert_called_once() + + @patch("ray_provider.operators.ray.RayHook") + def test_delete_cluster_exception(self, mock_ray_hook): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + ray_cluster_yaml="cluster.yaml", + ) + + mock_hook = mock_ray_hook.return_value + operator.hook = mock_hook + + mock_hook.delete_ray_cluster.side_effect = Exception("Cluster deletion failed") + + with pytest.raises(Exception) as exc_info: + operator._delete_cluster() + + assert str(exc_info.value) == "Cluster deletion failed" + mock_hook.delete_ray_cluster.assert_called_once()