Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Needs thorough testing] async model file listing #4968

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 64 additions & 26 deletions folder_paths.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import asyncio, aiofiles, threading
import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection

import aiofiles.os

supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}
Expand Down Expand Up @@ -194,23 +197,36 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

logging.debug("recursive file list on directory {}".format(directory))
dirpath: str
subdirs: list[str]
filenames: list[str]

for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)

for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue

async def proc_subdir(path: str):
dirs[path] = await aiofiles.os.path.getmtime(path)

def proc_thread():
asyncio.set_event_loop(asyncio.new_event_loop())
calls = []

async def handle(file):
if not await aiofiles.os.path.isdir(file):
relative_path = os.path.relpath(file, directory)
result.append(relative_path)
return
calls.append(proc_subdir(file))
for subdir in await aiofiles.os.listdir(file):
path = os.path.join(file, subdir)
if subdir not in excluded_dir_names:
calls.append(handle(path))
calls.append(handle(directory))

while len(calls) > 0:
future = asyncio.gather(*calls)
calls = []
asyncio.get_event_loop().run_until_complete(future)
asyncio.get_event_loop().close()

thread = threading.Thread(target=proc_thread)
thread.start()
thread.join()

logging.debug("found {} files".format(len(result)))
return result, dirs

Expand Down Expand Up @@ -263,19 +279,41 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
must_invalidate = [False]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only the first element of must_invalidate is accessed. Is there any reason why it needs to be an array?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python lets you read a variable from an enclosing scope very freely, but assigning to one from a nested function has a lot of edge cases and oddities (a plain assignment creates a new local variable unless the name is declared `nonlocal`), which hit you especially when you're doing threading stuff. But if you have a list, then the nested function only reads the variable and mutates the object it refers to, which lives on the heap, so it lets you do it. So the list is essentially being used as the equivalent of a pointer/reference.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems really hacky. Please try using threading.Event to approach this problem.

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

import aiofiles
import aiofiles.os

async def check_folder_mtime(folder: str, time_modified: float) -> bool:
    """Report whether ``folder`` has changed since its mtime was recorded.

    Args:
        folder: Path whose modification time is re-read.
        time_modified: The modification time previously recorded for ``folder``.

    Returns:
        True if the current modification time differs from ``time_modified``
        or if ``folder`` no longer exists; False otherwise.
    """
    try:
        current_mtime = await aiofiles.os.path.getmtime(folder)
    except FileNotFoundError:
        # A folder that disappeared after being recorded counts as a change;
        # letting the error escape would abort every other pending check.
        return True
    return current_mtime != time_modified

async def check_new_dirs(x: str, known_dirs: set) -> bool:
    """Report whether ``x`` is a directory that is missing from ``known_dirs``.

    Args:
        x: Path to test.
        known_dirs: Paths that are already tracked.

    Returns:
        True only when ``x`` is an existing directory and is not contained
        in ``known_dirs``.
    """
    is_directory = await aiofiles.os.path.isdir(x)
    if not is_directory:
        return False
    return x not in known_dirs

async def check_invalidation(folder_names_and_paths, folder_name):
    """Return the cached listing for ``folder_name`` if it is still valid.

    The individual checks run on a private event loop inside a worker
    thread; a ``threading.Event`` carries the verdict back to this coroutine.

    Args:
        folder_names_and_paths: Mapping from folder name to a tuple whose
            first item is the list of directories registered for that name.
        folder_name: Key of the folder group to validate.

    Returns:
        The entry stored in ``filename_list_cache`` for ``folder_name``, or
        None when there is no entry or any check reports a change.
    """
    # The cached entry lives in filename_list_cache. The second item of
    # folder_names_and_paths[folder_name] is a set of strings, not the
    # directory -> mtime mapping the checks below need.
    if folder_name not in filename_list_cache:
        return None
    out = filename_list_cache[folder_name]
    folders = folder_names_and_paths[folder_name]
    invalidation_event = threading.Event()

    async def process_checks():
        cached_dirs = out[1]
        # Build the lookup set once instead of once per registered folder.
        known_dirs = set(cached_dirs)
        tasks = [
            check_folder_mtime(path, time_modified)
            for path, time_modified in cached_dirs.items()
        ]
        tasks.extend(check_new_dirs(path, known_dirs) for path in folders[0])

        results = await asyncio.gather(*tasks)
        if any(results):
            invalidation_event.set()

    with ThreadPoolExecutor() as executor:
        future = executor.submit(lambda: asyncio.run(process_checks()))
        # Await the worker rather than calling future.result(), which would
        # block the event loop that is running this coroutine.
        # NOTE(review): since this function is itself a coroutine, the checks
        # could simply be awaited here and the worker thread dropped.
        await asyncio.wrap_future(future)

    return None if invalidation_event.is_set() else out  # Return None if invalidation is needed

# Usage: asyncio.run() drives the coroutine to completion and hands back its
# return value (the cached entry, or None when invalidation is needed).
result = asyncio.run(check_invalidation(folder_names_and_paths, folder_name))

It replaces the must_invalidate list hack with a threading.Event object, which is designed for inter-thread communication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough, swapped to Event

folders = folder_names_and_paths[folder_name]

for x in out[1]:
time_modified = out[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
async def check_folder_mtime(folder: str, time_modified: float):
if await aiofiles.os.path.getmtime(folder) != time_modified:
must_invalidate[0] = True

folders = folder_names_and_paths[folder_name]
for x in folders[0]:
if os.path.isdir(x):
async def check_new_dirs(x: str):
if await aiofiles.os.path.isdir(x):
if x not in out[1]:
return None
must_invalidate[0] = True

def proc_thread():
asyncio.set_event_loop(asyncio.new_event_loop())
calls = []

for x in out[1]:
time_modified = out[1][x]
call = check_folder_mtime(x, time_modified)
calls.append(call)

for x in folders[0]:
call = check_new_dirs(x)
calls.append(call)

future = asyncio.gather(*calls)
asyncio.get_event_loop().run_until_complete(future)
asyncio.get_event_loop().close()

thread = threading.Thread(target=proc_thread)
thread.start()
thread.join()

if must_invalidate[0]:
return None
return out

def get_filename_list(folder_name: str) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ transformers>=4.28.1
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
aiofiles
aiohttp
pyyaml
Pillow
Expand Down
Loading