Add test cases, update to support torch==2.3.0
acherstyx committed May 18, 2024
1 parent 9adeef1 commit 0c2116c
Showing 4 changed files with 86 additions and 3 deletions.
18 changes: 16 additions & 2 deletions hydra_plugins/hydra_torchrun_launcher/config.py
@@ -3,25 +3,31 @@
# @Author : Yaojie Shen
# @Project : hydra-torchrun-launcher
# @File : config.py

from dataclasses import dataclass, field
from hydra.core.config_store import ConfigStore
from typing import Optional, List

from packaging.version import Version
from importlib.metadata import version

from hydra.core.config_store import ConfigStore


@dataclass
class TorchDistributedLauncherConf:
    _target_: str = "hydra_plugins.hydra_torchrun_launcher.distributed_launcher.TorchDistributedLauncher"

    # Worker/node size related arguments.
    nnodes: str = "1:1"
    nproc_per_node: str = "1"

    # Rendezvous related arguments
    rdzv_backend: str = "static"
    rdzv_endpoint: str = ""
    rdzv_id: str = "none"
    rdzv_conf: str = ""
    standalone: bool = False

    # User-code launch related arguments.
    max_restarts: int = 0
    monitor_interval: int = 5
    start_method: str = "spawn"
@@ -33,16 +39,24 @@ class TorchDistributedLauncherConf:
    redirects: str = "0"
    tee: str = "0"

    # Backwards compatible parameters with caffe2.distributed.launch.
    node_rank: int = 0
    master_addr: str = "127.0.0.1"
    master_port: int = 29500
    local_addr: Optional[str] = None

    # Positional arguments.
    # These should not be set by the user.
    training_script: str = ""
    training_script_args: List[str] = field(default_factory=list)


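# For torch>=2.3, extend the config with the additional torchrun options below.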
if Version(version("torch")) >= Version("2.3"):
    @dataclass
    class TorchDistributedLauncherConf(TorchDistributedLauncherConf):
        logs_specs: Optional[str] = None
        local_ranks_filter: str = ""

ConfigStore.instance().store(
    group="hydra/launcher", name="torchrun", node=TorchDistributedLauncherConf
)
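
Importing this config module registers the launcher in the hydra/launcher group under the name "torchrun" (the ConfigStore.instance().store call above). A minimal usage sketch follows; the script name my_app.py and the specific overrides are illustrative assumptions, not part of this commit:

# my_app.py -- minimal sketch, assuming this plugin is installed in the environment.
# A launcher plugin is applied in multirun mode, e.g.:
#   python my_app.py --multirun hydra/launcher=torchrun \
#       hydra.launcher.nproc_per_node=2 hydra.launcher.rdzv_backend=c10d
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_name=None)
def main(cfg: DictConfig) -> None:
    # Each worker process spawned by the launcher runs this function; rank information
    # is exposed through environment variables such as RANK and LOCAL_RANK.
    ...


if __name__ == "__main__":
    main()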
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "hydra-torchrun-launcher"
version = "0.1.1"
version = "0.1.2"
description = "Hydra torchrun launcher"
readme = "README.md"
requires-python = ">=3.8"
5 changes: 5 additions & 0 deletions test/__init__.py
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
# @Time : 5/15/24
# @Author : Yaojie Shen
# @Project : hydra-torchrun-launcher
# @File : __init__.py
64 changes: 64 additions & 0 deletions test/test_hydra_torchrun_launcher.py
@@ -0,0 +1,64 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from hydra.core.plugins import Plugins
from hydra.plugins.launcher import Launcher
from hydra.test_utils.launcher_common_tests import (
    IntegrationTestSuite,
    LauncherTestSuite,
)
from pytest import mark

from hydra_plugins.hydra_torchrun_launcher.distributed_launcher import (
    TorchDistributedLauncher,
)


def test_discovery() -> None:
    # Tests that this plugin can be discovered via the plugins subsystem when looking for Launchers
    assert TorchDistributedLauncher.__name__ in [
        x.__name__ for x in Plugins.instance().discover(Launcher)
    ]


@mark.parametrize("launcher_name, overrides", [("torchrun", [])])
class TestTorchDistributedLauncher(LauncherTestSuite):
    """
    Run the Launcher test suite on this launcher.
    """

    pass


@mark.parametrize(
    "task_launcher_cfg, extra_flags",
    [
        (
            {},
            [
                "-m",
                "hydra/job_logging=hydra_debug",
                "hydra/job_logging=disabled",
                "hydra/launcher=torchrun",
            ],
        )
    ],
)
class TestTorchDistributedLauncherIntegration(IntegrationTestSuite):
    """
    Run this launcher through the integration test suite.
    """

    pass

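The two classes above inherit their actual test cases from Hydra's shared LauncherTestSuite and IntegrationTestSuite; the parametrize decorators only supply the launcher name and the overrides to run with. A minimal sketch for invoking the new test module directly (the file name run_tests.py is an illustrative assumption, and the plugin is assumed to be installed, e.g. via pip install -e .):

# run_tests.py -- minimal sketch for running the new test module.
import sys

import pytest

if __name__ == "__main__":
    # -v lists each parametrized case contributed by the shared Hydra test suites.
    sys.exit(pytest.main(["-v", "test/test_hydra_torchrun_launcher.py"]))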