Skip to content

Commit c495ef3

Browse files
author
pytorchbot
committed
2025-12-05 nightly release (0a2cebd)
1 parent 6c029a8 commit c495ef3

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

.github/workflows/unittest_ci.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
strategy:
3737
fail-fast: false
3838
matrix:
39-
cuda-tag: ["cu126", "cu128"]
39+
cuda-tag: ["cu126", "cu128", "cu129"]
4040
os:
4141
- linux.g5.12xlarge.nvidia.gpu
4242
python:
@@ -55,18 +55,20 @@ jobs:
5555
cuda-tag: "cu126"
5656
- is_pr: true
5757
cuda-tag: "cu128"
58+
- is_pr: true
59+
cuda-tag: "cu129"
5860
python:
5961
version: "3.9"
6062
- is_pr: true
61-
cuda-tag: "cu128"
63+
cuda-tag: "cu129"
6264
python:
6365
version: "3.10"
6466
- is_pr: true
65-
cuda-tag: "cu128"
67+
cuda-tag: "cu129"
6668
python:
6769
version: "3.11"
6870
- is_pr: true
69-
cuda-tag: "cu128"
71+
cuda-tag: "cu129"
7072
python:
7173
version: "3.12"
7274
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def test_sharding_ebc_input_validation_enabled(self, mock_jk: Mock) -> None:
253253

254254
with self.assertRaisesRegex(ValueError, "keys must be unique"):
255255
model(kjt)
256-
mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
256+
mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
257257

258258
@patch("torch._utils_internal.justknobs_check")
259259
def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
@@ -271,7 +271,7 @@ def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
271271
model(kjt)
272272
model(kjt)
273273

274-
mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
274+
mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
275275
matched_logs = list(
276276
filter(lambda s: "Validating input features..." in s, logs.output)
277277
)
@@ -294,7 +294,7 @@ def test_sharding_ebc_input_validation_disabled(self, mock_jk: Mock) -> None:
294294
except ValueError:
295295
self.fail("Input validation should not be enabled.")
296296

297-
mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
297+
mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
298298

299299
def _create_sharded_model(
300300
self, embedding_dim: int = 128, num_embeddings: int = 256

0 commit comments

Comments
 (0)