Skip to content
Merged
7 changes: 4 additions & 3 deletions src/flyte/_internal/controllers/remote/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,10 @@ def _bg_thread_target(self):
"""Target function for the controller thread that creates and manages its own event loop"""
try:
# Create a new event loop for this thread
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.set_exception_handler(flyte.errors.silence_grpc_polling_error)
with self._thread_com_lock:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.set_exception_handler(flyte.errors.silence_grpc_polling_error)
logger.debug(f"Controller thread started with new event loop: {threading.current_thread().name}")

# Create an event to signal the errors were observed in the thread's loop
Expand Down
4 changes: 3 additions & 1 deletion src/flyte/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)

from flyte._pod import PodTemplate
from flyte._utils.asyncify import run_sync_with_loop
from flyte.errors import RuntimeSystemError, RuntimeUserError

from ._cache import Cache, CacheRequest
Expand Down Expand Up @@ -493,7 +494,8 @@ async def execute(self, *args: P.args, **kwargs: P.kwargs) -> R:
if iscoroutinefunction(self.func):
v = await self.func(*args, **kwargs)
else:
v = self.func(*args, **kwargs)
v = await run_sync_with_loop(self.func, *args, **kwargs)

await self.post(v)
return v

Expand Down
103 changes: 103 additions & 0 deletions src/flyte/_utils/asyncify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import asyncio
import contextvars
import inspect
import random
import threading
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

from flyte._logging import logger

T = TypeVar("T")
P = ParamSpec("P")


async def run_sync_with_loop(
func: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
"""
Run a synchronous function from an async context with its own event loop.

This function:
- Copies the current context variables and preserves them in the sync function
- Creates a new event loop in a separate thread for the sync function
- Allows the sync function to potentially use asyncio operations
- Returns the result without blocking the calling async event loop

Args:
func: The synchronous function to run (must not be an async function)
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function

Returns:
The result of calling func(*args, **kwargs)

Raises:
TypeError: If func is an async function (coroutine function)

Example:
async def my_async_function():
result = await run_sync_with_loop(some_sync_function, arg1, arg2)
return result
"""
# Check if func is an async function
if inspect.iscoroutinefunction(func):
raise TypeError(
f"Cannot call run_sync_with_loop with async function '{func.__name__}'. "
"This utility is for running sync functions from async contexts."
)

copied_ctx = contextvars.copy_context()
execute_loop = None
execute_loop_created = threading.Event()

# Build thread name with random suffix for uniqueness
func_name = getattr(func, "__name__", "unknown")
current_thread = threading.current_thread().name
random_suffix = f"{random.getrandbits(32):08x}"
full_thread_name = f"sync-executor-{random_suffix}_from_{current_thread}"

def _sync_thread_loop_runner() -> None:
"""This method runs the event loop and should be invoked in a separate thread."""
nonlocal execute_loop
try:
execute_loop = asyncio.new_event_loop()
asyncio.set_event_loop(execute_loop)
logger.debug(f"Created event loop for function '{func_name}' in thread '{full_thread_name}'")
execute_loop_created.set()
execute_loop.run_forever()
except Exception as e:
logger.error(f"Exception in thread '{full_thread_name}' running '{func_name}': {e}", exc_info=True)
raise
finally:
if execute_loop:
logger.debug(f"Stopping event loop for function '{func_name}' in thread '{full_thread_name}'")
execute_loop.stop()
execute_loop.close()
logger.debug(f"Cleaned up event loop for function '{func_name}' in thread '{full_thread_name}'")

executor_thread = threading.Thread(
name=full_thread_name,
daemon=True,
target=_sync_thread_loop_runner,
)
logger.debug(f"Starting executor thread '{full_thread_name}' for function '{func_name}'")
executor_thread.start()

async def async_wrapper():
res = copied_ctx.run(func, *args, **kwargs)
return res

# Wait for the loop to be created in a thread to avoid blocking the current thread
await asyncio.get_event_loop().run_in_executor(None, execute_loop_created.wait)
assert execute_loop is not None
fut = asyncio.run_coroutine_threadsafe(async_wrapper(), loop=execute_loop)
async_fut = asyncio.wrap_future(fut)
result = await async_fut

return result
124 changes: 124 additions & 0 deletions tests/flyte/utils/test_asyncify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import asyncio
import contextvars
import threading

import pytest

from flyte._utils.asyncify import run_sync_with_loop

# Context variable for testing context preservation
test_context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context_var")


@pytest.mark.asyncio
async def test_basic_sync_function():
"""Test that a basic sync function can be called and returns the correct result."""

def sync_add(a: int, b: int) -> int:
return a + b

result = await run_sync_with_loop(sync_add, 5, 7)
assert result == 12


@pytest.mark.asyncio
async def test_sync_function_with_kwargs():
"""Test that kwargs are properly passed to the sync function."""

def sync_multiply(x: int, y: int, multiplier: int = 1) -> int:
return x * y * multiplier

result = await run_sync_with_loop(sync_multiply, 3, 4, multiplier=2)
assert result == 24


@pytest.mark.asyncio
async def test_context_variable_preservation():
"""Test that context variables are preserved when calling the sync function."""
test_context_var.set("test_value")

def get_context_value() -> str:
return test_context_var.get()

result = await run_sync_with_loop(get_context_value)
assert result == "test_value"


@pytest.mark.asyncio
async def test_raises_error_on_async_function():
"""Test that TypeError is raised when trying to run an async function."""

async def async_function():
return 42

with pytest.raises(TypeError) as exc_info:
await run_sync_with_loop(async_function)

assert "Cannot call run_sync_with_loop with async function" in str(exc_info.value)
assert "async_function" in str(exc_info.value)


@pytest.mark.asyncio
async def test_sync_function_has_own_event_loop():
"""Test that the sync function runs with its own event loop."""
main_loop_id = id(asyncio.get_event_loop())

def get_loop_info() -> tuple:
# Get the loop that the sync function is running in
loop = asyncio.get_event_loop()
loop_id = id(loop)
thread_name = threading.current_thread().name
return loop_id, thread_name

loop_id, thread_name = await run_sync_with_loop(get_loop_info)

# The sync function should have a different event loop than the main async function
assert loop_id != main_loop_id
# And it should be running in a different thread
assert "sync-executor" in thread_name


@pytest.mark.asyncio
async def test_thread_name_uniqueness():
"""Test that different invocations create threads with unique names."""
thread_names = []

def capture_thread_name() -> str:
name = threading.current_thread().name
thread_names.append(name)
return name

# Run multiple times
name1 = await run_sync_with_loop(capture_thread_name)
name2 = await run_sync_with_loop(capture_thread_name)

# Thread names should be different due to random suffix
assert name1 != name2
assert "sync-executor" in name1
assert "sync-executor" in name2
assert "_from_" in name1
assert "_from_" in name2


@pytest.mark.asyncio
async def test_exception_propagation():
"""Test that exceptions raised in sync functions are properly propagated."""

def sync_function_that_raises():
raise ValueError("Test error message")

with pytest.raises(ValueError) as exc_info:
await run_sync_with_loop(sync_function_that_raises)

assert "Test error message" in str(exc_info.value)


@pytest.mark.asyncio
async def test_return_complex_types():
"""Test that complex return types are properly returned."""

def sync_function_returning_dict() -> dict:
return {"key1": "value1", "key2": [1, 2, 3], "key3": {"nested": True}}

result = await run_sync_with_loop(sync_function_returning_dict)
assert result == {"key1": "value1", "key2": [1, 2, 3], "key3": {"nested": True}}
Loading