Add job_id_position Parameter to launch_slurm_job Method #282

Open · wants to merge 2 commits into base: main
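The motivation, sketched briefly: launch_slurm_job reads the job ID back from the text that sbatch prints. A stock sbatch prints a single line ending in the ID, so the last whitespace-separated token (index -1) is correct, but on clusters where sbatch is wrapped and prints extra text after the ID, the ID can sit at a different position. The snippet below is only an illustration; the wrapped output string is hypothetical and not taken from any specific cluster.

    # Stock sbatch output: the job ID is the last token.
    plain_output = "Submitted batch job 12345"
    assert plain_output.split()[-1] == "12345"       # job_id_position = -1 (default)

    # Hypothetical wrapper output with trailing text: the ID is no longer last.
    wrapped_output = "Submitted batch job 12345 on cluster gpu"
    assert wrapped_output.split()[3] == "12345"      # job_id_position = 3
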
14 changes: 10 additions & 4 deletions src/datatrove/executor/slurm.py
@@ -61,6 +61,8 @@ class SlurmPipelineExecutor(PipelineExecutor):
depends: another SlurmPipelineExecutor that should run
before this one
depends_job_id: alternatively to the above, you can pass the job id of a dependency
job_id_position: position of the job ID in the sbatch command output.
default: -1 (the last whitespace-separated token)
logging_dir: where to save logs, stats, etc. Should be parsable into a datatrove.io.DataFolder
skip_completed: whether to skip tasks that were completed in
previous runs. default: True
@@ -99,6 +101,7 @@ def __init__(
max_array_size: int = 1001,
depends: SlurmPipelineExecutor | None = None,
depends_job_id: str | None = None,
job_id_position: int = -1,
logging_dir: DataFolderLike = None,
skip_completed: bool = True,
slurm_logs_folder: str = None,
@@ -128,6 +131,7 @@ def __init__(
self.venv_path = venv_path
self.depends = depends
self.depends_job_id = depends_job_id
self.job_id_position = job_id_position
self._sbatch_args = sbatch_args if sbatch_args else {}
self.max_array_size = max_array_size
self.max_array_launch_parallel = max_array_launch_parallel
@@ -198,7 +202,8 @@ def launch_merge_stats(self):
},
f'merge_stats {self.logging_dir.resolve_paths("stats")} '
f'-o {self.logging_dir.resolve_paths("stats.json")}',
)
),
self.job_id_position,
)

@property
@@ -275,7 +280,7 @@ def launch_job(self):
args = [f"--export=ALL,RUN_OFFSET={launched_jobs}"]
if self.dependency:
args.append(f"--dependency={self.dependency}")
self.job_id = launch_slurm_job(launch_file_contents, *args)
self.job_id = launch_slurm_job(launch_file_contents, self.job_id_position, *args)
launched_jobs += 1
logger.info(f"Slurm job launched successfully with (last) id={self.job_id}.")
self.launch_merge_stats()
@@ -353,11 +358,12 @@ def world_size(self) -> int:
return self.tasks


def launch_slurm_job(launch_file_contents, *args):
def launch_slurm_job(launch_file_contents, job_id_position, *args):
"""
Small helper function to save a sbatch script and call it.
Args:
launch_file_contents: Contents of the sbatch script
job_id_position: index of the job ID in the sbatch command output.
*args: any other arguments to pass to the sbatch command

Returns: the id of the launched slurm job
@@ -366,4 +372,4 @@ def launch_slurm_job(launch_file_contents, *args):
with tempfile.NamedTemporaryFile("w") as f:
f.write(launch_file_contents)
f.flush()
return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1]
return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[job_id_position]
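
For reference, a minimal call sketch of the updated helper (the script contents and the extra sbatch flag are placeholders, not taken from this PR):

    script = "#!/bin/bash\n#SBATCH --job-name=demo\nsrun hostname\n"
    # Stock sbatch prints "Submitted batch job <id>", so -1 selects the ID:
    job_id = launch_slurm_job(script, -1)
    # If a site wrapper appends text after the ID, pass the matching index,
    # along with any extra sbatch arguments:
    job_id = launch_slurm_job(script, 3, "--partition=cpu")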