Skip to content

Commit e6690be

Browse files
danyi1212 and orweis authored
Added hierarchical lock for applying data updates (#745)
* Added hierarchical lock for applying data updates * Added docs and inline comments * Added _fetch_data to handle fetching errors * Removed line reference * Fixed test * Fixed imports * Fixed types for Python 3.9 support * Removed TaskGroup use (python compat) * Fixed import * Fixed types * Update packages/opal-client/opal_client/data/updater.py Co-authored-by: Or Weis <[email protected]> * Update packages/opal-client/opal_client/data/updater.py Co-authored-by: Or Weis <[email protected]> * Fixed pre-commit --------- Co-authored-by: Or Weis <[email protected]>
1 parent 61dc17a commit e6690be

File tree

8 files changed

+755
-258
lines changed

8 files changed

+755
-258
lines changed

packages/opal-client/opal_client/callbacks/reporter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,5 @@ async def report_update_results(
8585
status=result.status,
8686
error=error_content,
8787
)
88-
except:
89-
logger.exception("Failed to execute report_update_results")
88+
except Exception as e:
89+
logger.exception(f"Failed to execute report_update_results: {e}")

packages/opal-client/opal_client/data/fetcher.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import Any, Dict, List, Optional, Tuple
2+
from typing import Any, List, Optional, Tuple
33

44
from opal_client.config import opal_client_config
55
from opal_client.policy_store.base_policy_store_client import JsonableValue
@@ -58,8 +58,8 @@ async def stop(self):
5858
await self._engine.terminate_workers()
5959

6060
async def handle_url(
61-
self, url: str, config: FetcherConfig, data: Optional[JsonableValue]
62-
):
61+
self, url: str, config: dict, data: Optional[JsonableValue]
62+
) -> Optional[JsonableValue]:
6363
"""Helper function wrapping self._engine.handle_url."""
6464
if data is not None:
6565
logger.info("Data provided inline for url: {url}", url=url)
@@ -107,7 +107,7 @@ async def handle_urls(
107107
results_with_url_and_config = [
108108
(url, config, result)
109109
for (url, config, data), result in zip(urls, results)
110-
if result is not None
110+
if result is not None # FIXME ignores None results
111111
]
112112

113113
# return results

packages/opal-client/opal_client/data/updater.py

Lines changed: 399 additions & 235 deletions
Large diffs are not rendered by default.

packages/opal-client/opal_client/tests/data_updater_test.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -196,17 +196,18 @@ async def test_data_updater(server):
196196
proc.terminate()
197197

198198
# test PATCH update event via API
199-
entries = [
200-
DataSourceEntry(
201-
url="",
202-
data=PATCH_DATA_UPDATE,
203-
dst_path="/",
204-
topics=DATA_TOPICS,
205-
save_method="PATCH",
206-
)
207-
]
208199
update = DataUpdate(
209-
reason="Test_Patch", entries=entries, callback=UpdateCallback(callbacks=[])
200+
reason="Test_Patch",
201+
entries=[
202+
DataSourceEntry(
203+
url="",
204+
data=PATCH_DATA_UPDATE,
205+
dst_path="/",
206+
topics=DATA_TOPICS,
207+
save_method="PATCH",
208+
)
209+
],
210+
callback=UpdateCallback(callbacks=[]),
210211
)
211212

212213
headers = {"content-type": "application/json"}
@@ -218,13 +219,26 @@ async def test_data_updater(server):
218219
)
219220
assert res.status_code == 200
220221
# value field is not specified for add operation should fail
221-
entries[0].data = [{"op": "add", "path": "/"}]
222222
res = requests.post(
223223
DATA_UPDATE_ROUTE,
224-
data=json.dumps(update, default=pydantic_encoder),
224+
data=json.dumps(
225+
{
226+
"reason": "Test_Patch",
227+
"entries": [
228+
{
229+
"url": "",
230+
"data": [{"op": "add", "path": "/"}],
231+
"dst_path": "/",
232+
"topics": DATA_TOPICS,
233+
"save_method": "PATCH",
234+
}
235+
],
236+
},
237+
default=pydantic_encoder,
238+
),
225239
headers=headers,
226240
)
227-
assert res.status_code == 422
241+
assert res.status_code == 422, res.text
228242

229243

230244
@pytest.mark.asyncio

packages/opal-common/opal_common/async_utils.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import asyncio
44
import sys
55
from functools import partial
6-
from typing import Any, Callable, Coroutine, List, Optional, Tuple, TypeVar
6+
from typing import Any, Callable, Coroutine, Optional, Set, Tuple, TypeVar
77

88
import loguru
9+
from loguru import logger
910

1011
if sys.version_info < (3, 10):
1112
from typing_extensions import ParamSpec
@@ -94,16 +95,40 @@ async def stop_queue_handling(self):
9495

9596
class TasksPool:
    """Owns a dynamic set of background asyncio tasks.

    Completed tasks remove themselves from the pool via a done callback;
    ``shutdown`` stops intake and waits for (or cancels) whatever is still
    pending.
    """

    def __init__(self):
        # Currently pending tasks; finished tasks remove themselves
        # through _cleanup_task.
        self._tasks: Set[asyncio.Task] = set()
        # Becomes False once shutdown() starts; guards add_task().
        self._running = True

    def _cleanup_task(self, done_task):
        # Done callback: drop the finished task so the pool only tracks
        # pending work (also lets the task object be garbage-collected).
        self._tasks.remove(done_task)

    def add_task(self, f):
        """Schedule coroutine *f* as a task tracked by this pool.

        :raises RuntimeError: if the pool has already been shut down.
        """
        if not self._running:
            raise RuntimeError("TasksPool is already shutdown")
        t = asyncio.create_task(f)
        self._tasks.add(t)
        t.add_done_callback(self._cleanup_task)

    async def shutdown(self, force: bool = False):
        """Stop accepting new tasks and wait for pending ones to finish.

        :param force: If True, cancel all tasks immediately.
        """
        self._running = False
        if force:
            for t in self._tasks:
                t.cancel()

        # gather() snapshots the task set up-front, so tasks removing
        # themselves via the done callback during the await is safe.
        results = await asyncio.gather(
            *self._tasks,
            return_exceptions=True,
        )
        for result in results:
            if isinstance(result, Exception):
                logger.exception(
                    "Error on task during shutdown of TasksPool: {result}",
                    result=result,
                )
131+
107132

108133
async def repeated_call(
109134
func: Coroutine,

packages/opal-common/opal_common/fetcher/engine/fetching_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ async def queue_url(
124124
self,
125125
url: str,
126126
callback: Coroutine,
127-
config: Union[FetcherConfig, dict] = None,
127+
config: Union[FetcherConfig, dict, None] = None,
128128
fetcher="HttpFetchProvider",
129129
) -> FetchEvent:
130130
"""Simplified default fetching handler for queuing a fetch task.
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import asyncio
2+
from contextlib import asynccontextmanager
3+
from typing import Set
4+
5+
6+
class HierarchicalLock:
7+
"""A hierarchical lock for asyncio.
8+
9+
- If a path is locked, no ancestor or descendant path can be locked.
10+
- Conversely, if a child path is locked, the parent path cannot be locked
11+
until all child paths are released.
12+
"""
13+
14+
def __init__(self):
15+
# locked_paths: set of currently locked string paths
16+
self._locked_paths: Set[str] = set()
17+
# Map of tasks to their acquired locks for re-entrant protection
18+
self._task_locks: dict[asyncio.Task, Set[str]] = {}
19+
# Internal lock for synchronizing access to locked_paths
20+
self._lock = asyncio.Lock()
21+
# Condition to wake up tasks when a path is released
22+
self._cond = asyncio.Condition(self._lock)
23+
24+
@staticmethod
25+
def _is_conflicting(p1: str, p2: str) -> bool:
26+
"""Check if two paths conflict with each other."""
27+
return p1 == p2 or p1.startswith(p2) or p2.startswith(p1)
28+
29+
async def acquire(self, path: str):
30+
"""Acquire the lock for the given hierarchical path.
31+
32+
If an ancestor or descendant path is locked, this will wait
33+
until it is released.
34+
"""
35+
task = asyncio.current_task()
36+
if task is None:
37+
raise RuntimeError("acquire() must be called from within a task.")
38+
39+
async with self._lock:
40+
# Prevent re-entrant locking by the same task
41+
if path in self._task_locks.get(task, set()):
42+
raise RuntimeError(f"Task {task} cannot re-acquire lock on '{path}'.")
43+
44+
# Wait until there is no conflict with existing locked paths
45+
while any(self._is_conflicting(path, lp) for lp in self._locked_paths):
46+
await self._cond.wait()
47+
48+
# Acquire the path
49+
self._locked_paths.add(path)
50+
if task not in self._task_locks:
51+
self._task_locks[task] = set()
52+
self._task_locks[task].add(path)
53+
54+
async def release(self, path: str):
55+
"""Release the lock for the given path and notify waiting tasks."""
56+
task = asyncio.current_task()
57+
if task is None:
58+
raise RuntimeError("release() must be called from within a task.")
59+
60+
async with self._lock:
61+
if path not in self._locked_paths:
62+
raise RuntimeError(f"Cannot release path '{path}' that is not locked.")
63+
64+
if path not in self._task_locks.get(task, set()):
65+
raise RuntimeError(
66+
f"Task {task} cannot release lock on '{path}' it does not hold."
67+
)
68+
69+
# Remove the path from locked paths and task locks
70+
self._locked_paths.remove(path)
71+
self._task_locks[task].remove(path)
72+
if not self._task_locks[task]:
73+
del self._task_locks[task]
74+
75+
# Notify all tasks that something was released
76+
self._cond.notify_all()
77+
78+
@asynccontextmanager
79+
async def lock(self, path: str) -> "HierarchicalLock":
80+
"""Acquire the lock for the given path and return a context manager."""
81+
await self.acquire(path)
82+
try:
83+
yield self
84+
finally:
85+
await self.release(path)

0 commit comments

Comments
 (0)