Skip to content

Commit

Permalink
Prevent race condition with location reload and backups list (#5602)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdegat01 authored Feb 5, 2025
1 parent 01382e7 commit 129a37a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 17 deletions.
49 changes: 32 additions & 17 deletions supervisor/backups/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,15 @@ def load(self) -> Awaitable[None]:
async def reload(self, location: str | None | type[DEFAULT] = DEFAULT) -> bool:
"""Load exists backups."""

backups: dict[str, Backup] = {}

async def _load_backup(location_name: str | None, tar_file: Path) -> bool:
"""Load the backup."""
backup = Backup(self.coresys, tar_file, "temp", location_name)
if await backup.load():
if backup.slug in self._backups:
if backup.slug in backups:
try:
self._backups[backup.slug].consolidate(backup)
backups[backup.slug].consolidate(backup)
except BackupInvalidError as err:
_LOGGER.error(
"Ignoring backup %s in %s due to: %s",
Expand All @@ -247,23 +249,18 @@ async def _load_backup(location_name: str | None, tar_file: Path) -> bool:
return False

else:
self._backups[backup.slug] = Backup(
backups[backup.slug] = Backup(
self.coresys, tar_file, backup.slug, location_name, backup.data
)
return True

return False

# Single location refresh clears out just that part of the cache and rebuilds it
if location != DEFAULT:
locations = {location: self.backup_locations[location]}
for backup in self.list_backups:
if location in backup.all_locations:
del backup.all_locations[location]
else:
locations = self.backup_locations
self._backups = {}

locations = (
self.backup_locations
if location == DEFAULT
else {location: self.backup_locations[location]}
)
tasks = [
self.sys_create_task(_load_backup(_location, tar_file))
for _location, path in locations.items()
Expand All @@ -274,10 +271,28 @@ async def _load_backup(location_name: str | None, tar_file: Path) -> bool:
if tasks:
await asyncio.wait(tasks)

# Remove any backups with no locations from cache (only occurs in single location refresh)
if location != DEFAULT:
for backup in list(self.list_backups):
if not backup.all_locations:
# For a full reload, replace our cache with new one
if location == DEFAULT:
self._backups = backups
return True

# For a location reload, merge new cache in with existing
for backup in list(self.list_backups):
if backup.slug in backups:
try:
backup.consolidate(backups[backup.slug])
except BackupInvalidError as err:
_LOGGER.error(
"Ignoring backup %s in %s due to: %s",
backup.slug,
location,
err,
)

elif location in backup.all_locations:
if len(backup.all_locations) > 1:
del backup.all_locations[location]
else:
del self._backups[backup.slug]

return True
Expand Down
49 changes: 49 additions & 0 deletions tests/api/test_backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,3 +1300,52 @@ async def test_missing_file_removes_backup_from_cache(
# Wait for reload task to complete and confirm backup is removed
await asyncio.sleep(0)
assert not coresys.backups.list_backups


@pytest.mark.usefixtures("tmp_supervisor_data")
async def test_immediate_list_after_missing_file_restore(
api_client: TestClient, coresys: CoreSys
):
"""Test race with reload for missing file on restore does not error.

Two backup fixtures are copied into the backup folder and loaded, then one
tar file is deleted behind the manager's back. A partial restore of that
backup is expected to answer 404. While `asyncio.wait` calls made on behalf
of `BackupManager.reload` are held back by an event, `GET /backups` must
still succeed and list both backups; once the event is set, only the other
backup must remain listed.
"""
coresys.core.state = CoreState.RUNNING
# Stub the free-space lookup with a fixed value.
# NOTE(review): presumably large enough to pass a free-space check — confirm.
coresys.hardware.disk.get_disk_free_space = lambda x: 5000

backup_file = get_fixture_path("backup_example.tar")
# Keep the path of the copied tar so it can be deleted after the first reload
bad_location = Path(copy(backup_file, coresys.config.path_backup))
# Copy a second backup in so there's something to reload later
copy(get_fixture_path("backup_example_enc.tar"), coresys.config.path_backup)
await coresys.backups.reload()

# After reload, remove one of the files and confirm the cache is now out of date
# (the deleted backup is still cached with its location keyed by None)
bad_location.unlink()
assert coresys.backups.get("7fed74c8").all_locations.keys() == {None}

# The event gates the reload; the real asyncio.wait is captured before it is
# patched so the mock below can delegate to it.
event = asyncio.Event()
orig_wait = asyncio.wait

async def mock_wait(tasks: list[asyncio.Task], *args, **kwargs):
"""Mock for asyncio wait that allows force of race condition."""
# Only hold back waits whose first task wraps a coroutine defined under
# BackupManager.reload; every other wait passes straight through.
if tasks[0].get_coro().__qualname__.startswith("BackupManager.reload"):
await event.wait()
return await orig_wait(tasks, *args, **kwargs)

with patch("supervisor.backups.manager.asyncio.wait", new=mock_wait):
# Restoring the backup whose tar was deleted is expected to return 404
resp = await api_client.post(
"/backups/7fed74c8/restore/partial",
json={"location": ".local", "folders": ["ssl"]},
)
assert resp.status == 404

# Yield to the event loop once, then list backups while the reload is still
# held back by the event: the request must succeed and show both backups.
await asyncio.sleep(0)
resp = await api_client.get("/backups")
assert resp.status == 200
result = await resp.json()
assert len(result["data"]["backups"]) == 2

# Release the reload and give it time to finish; afterwards only one backup
# is expected to remain.
# NOTE(review): slug 93b462f8 is presumably backup_example_enc.tar — confirm.
event.set()
await asyncio.sleep(0.1)
resp = await api_client.get("/backups")
assert resp.status == 200
result = await resp.json()
assert len(result["data"]["backups"]) == 1
assert result["data"]["backups"][0]["slug"] == "93b462f8"

0 comments on commit 129a37a

Please sign in to comment.