diff --git a/src/flyte/_internal/controllers/remote/_core.py b/src/flyte/_internal/controllers/remote/_core.py index 3c06c55f8..c3dfb3899 100644 --- a/src/flyte/_internal/controllers/remote/_core.py +++ b/src/flyte/_internal/controllers/remote/_core.py @@ -213,9 +213,10 @@ def _bg_thread_target(self): """Target function for the controller thread that creates and manages its own event loop""" try: # Create a new event loop for this thread - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - self._loop.set_exception_handler(flyte.errors.silence_grpc_polling_error) + with self._thread_com_lock: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.set_exception_handler(flyte.errors.silence_grpc_polling_error) logger.debug(f"Controller thread started with new event loop: {threading.current_thread().name}") # Create an event to signal the errors were observed in the thread's loop diff --git a/src/flyte/_task.py b/src/flyte/_task.py index 521b75b93..31c2ed8c6 100644 --- a/src/flyte/_task.py +++ b/src/flyte/_task.py @@ -24,6 +24,7 @@ ) from flyte._pod import PodTemplate +from flyte._utils.asyncify import run_sync_with_loop from flyte.errors import RuntimeSystemError, RuntimeUserError from ._cache import Cache, CacheRequest @@ -493,7 +494,8 @@ async def execute(self, *args: P.args, **kwargs: P.kwargs) -> R: if iscoroutinefunction(self.func): v = await self.func(*args, **kwargs) else: - v = self.func(*args, **kwargs) + v = await run_sync_with_loop(self.func, *args, **kwargs) + await self.post(v) return v diff --git a/src/flyte/_utils/asyncify.py b/src/flyte/_utils/asyncify.py new file mode 100644 index 000000000..1d850ef57 --- /dev/null +++ b/src/flyte/_utils/asyncify.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import asyncio +import contextvars +import inspect +import random +import threading +from typing import Callable, TypeVar + +from typing_extensions import ParamSpec + +from flyte._logging 
import logger + +T = TypeVar("T") +P = ParamSpec("P") + + +async def run_sync_with_loop( + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: + """ + Run a synchronous function from an async context with its own event loop. + + This function: + - Copies the current context variables and preserves them in the sync function + - Creates a new event loop in a separate thread for the sync function + - Allows the sync function to potentially use asyncio operations + - Returns the result without blocking the calling async event loop + + Args: + func: The synchronous function to run (must not be an async function) + *args: Positional arguments to pass to the function + **kwargs: Keyword arguments to pass to the function + + Returns: + The result of calling func(*args, **kwargs) + + Raises: + TypeError: If func is an async function (coroutine function) + + Example: + async def my_async_function(): + result = await run_sync_with_loop(some_sync_function, arg1, arg2) + return result + """ + # Check if func is an async function + if inspect.iscoroutinefunction(func): + raise TypeError( + f"Cannot call run_sync_with_loop with async function '{func.__name__}'. " + "This utility is for running sync functions from async contexts." 
+ ) + + copied_ctx = contextvars.copy_context() + execute_loop = None + execute_loop_created = threading.Event() + + # Build thread name with random suffix for uniqueness + func_name = getattr(func, "__name__", "unknown") + current_thread = threading.current_thread().name + random_suffix = f"{random.getrandbits(32):08x}" + full_thread_name = f"sync-executor-{random_suffix}_from_{current_thread}" + + def _sync_thread_loop_runner() -> None: + """This method runs the event loop and should be invoked in a separate thread.""" + nonlocal execute_loop + try: + execute_loop = asyncio.new_event_loop() + asyncio.set_event_loop(execute_loop) + logger.debug(f"Created event loop for function '{func_name}' in thread '{full_thread_name}'") + execute_loop_created.set() + execute_loop.run_forever() + except Exception as e: + logger.error(f"Exception in thread '{full_thread_name}' running '{func_name}': {e}", exc_info=True) + raise + finally: + if execute_loop: + logger.debug(f"Stopping event loop for function '{func_name}' in thread '{full_thread_name}'") + execute_loop.stop() + execute_loop.close() + logger.debug(f"Cleaned up event loop for function '{func_name}' in thread '{full_thread_name}'") + + executor_thread = threading.Thread( + name=full_thread_name, + daemon=True, + target=_sync_thread_loop_runner, + ) + logger.debug(f"Starting executor thread '{full_thread_name}' for function '{func_name}'") + executor_thread.start() + + async def async_wrapper(): + res = copied_ctx.run(func, *args, **kwargs) + return res + + # Wait for the loop to be created in a thread to avoid blocking the current thread + await asyncio.get_event_loop().run_in_executor(None, execute_loop_created.wait) + assert execute_loop is not None + fut = asyncio.run_coroutine_threadsafe(async_wrapper(), loop=execute_loop) + async_fut = asyncio.wrap_future(fut) + result = await async_fut + + return result diff --git a/tests/flyte/utils/test_asyncify.py b/tests/flyte/utils/test_asyncify.py new file mode 100644 
# tests/flyte/utils/test_asyncify.py
import asyncio
import contextvars
import threading

import pytest

from flyte._utils.asyncify import run_sync_with_loop

# ContextVar set by the test coroutine and read back inside the wrapped sync function
test_context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context_var")


@pytest.mark.asyncio
async def test_basic_sync_function():
    """Positional arguments reach the sync function and its return value comes back."""

    def sync_add(a: int, b: int) -> int:
        return a + b

    assert await run_sync_with_loop(sync_add, 5, 7) == 12


@pytest.mark.asyncio
async def test_sync_function_with_kwargs():
    """Keyword arguments are forwarded alongside positional ones."""

    def sync_multiply(x: int, y: int, multiplier: int = 1) -> int:
        return x * y * multiplier

    product = await run_sync_with_loop(sync_multiply, 3, 4, multiplier=2)
    assert product == 24


@pytest.mark.asyncio
async def test_context_variable_preservation():
    """A ContextVar set by the caller is visible inside the sync function."""
    test_context_var.set("test_value")

    def get_context_value() -> str:
        return test_context_var.get()

    assert await run_sync_with_loop(get_context_value) == "test_value"


@pytest.mark.asyncio
async def test_raises_error_on_async_function():
    """Passing a coroutine function is rejected with a descriptive TypeError."""

    async def async_function():
        return 42

    with pytest.raises(TypeError) as raised:
        await run_sync_with_loop(async_function)

    message = str(raised.value)
    assert "Cannot call run_sync_with_loop with async function" in message
    assert "async_function" in message


@pytest.mark.asyncio
async def test_sync_function_has_own_event_loop():
    """The sync function sees a different loop, on a dedicated executor thread."""
    caller_loop_id = id(asyncio.get_event_loop())

    def get_loop_info() -> tuple:
        # Report the identity of the loop and thread the sync function observes
        return id(asyncio.get_event_loop()), threading.current_thread().name

    worker_loop_id, worker_thread_name = await run_sync_with_loop(get_loop_info)

    # Different loop object than the one driving this test coroutine
    assert worker_loop_id != caller_loop_id
    # ...and executed on one of the utility's own threads
    assert "sync-executor" in worker_thread_name


@pytest.mark.asyncio
async def test_thread_name_uniqueness():
    """Two consecutive invocations run on threads with distinct names."""
    thread_names = []

    def capture_thread_name() -> str:
        name = threading.current_thread().name
        thread_names.append(name)
        return name

    # Invoke twice, one after the other
    first = await run_sync_with_loop(capture_thread_name)
    second = await run_sync_with_loop(capture_thread_name)

    # The random suffix must make the names differ
    assert first != second
    for observed in (first, second):
        assert "sync-executor" in observed
    for observed in (first, second):
        assert "_from_" in observed


@pytest.mark.asyncio
async def test_exception_propagation():
    """An exception raised by the sync function surfaces unchanged to the awaiter."""

    def sync_function_that_raises():
        raise ValueError("Test error message")

    with pytest.raises(ValueError) as raised:
        await run_sync_with_loop(sync_function_that_raises)

    assert "Test error message" in str(raised.value)


@pytest.mark.asyncio
async def test_return_complex_types():
    """Nested container return values are passed back intact."""
    expected = {"key1": "value1", "key2": [1, 2, 3], "key3": {"nested": True}}

    def sync_function_returning_dict() -> dict:
        return {"key1": "value1", "key2": [1, 2, 3], "key3": {"nested": True}}

    assert await run_sync_with_loop(sync_function_returning_dict) == expected