
Commit 6c029a8

Author: pytorchbot
Message: 2025-12-04 nightly release (83288ce)
Parent: 449dd75

8 files changed: +432 -30 lines

.github/workflows/unittest_ci.yml (1 addition, 1 deletion)

@@ -32,7 +32,7 @@ on:
       - test
 
 jobs:
-  build_test:
+  unittest_ci_gpu:
     strategy:
       fail-fast: false
       matrix:

.github/workflows/unittest_ci_cpu.yml (2 additions, 2 deletions)

@@ -32,7 +32,7 @@ on:
       - test
 
 jobs:
-  build_test:
+  unittest_ci_cpu:
     strategy:
       fail-fast: false
       matrix:
@@ -65,7 +65,7 @@ jobs:
       contents: read
     with:
       runner: ${{ matrix.os }}
-      timeout: 15
+      timeout: 20
       script: |
         ldd --version
         conda create -y --name build_binary python=${{ matrix.python.version }}
New file (40 additions, 0 deletions)

# this is a very basic sparse data dist config
# runs on 2 ranks, showing traces with reasonable workloads
RunOptions:
  world_size: 2
  batch_size: 16384
  num_batches: 10
  num_benchmarks: 1
  num_profiles: 1
  sharding_type: table_wise
  profile_dir: "."
  name: "sparse_data_dist_base"
  # export_stacks: True # enable this to export stack traces
PipelineConfig:
  pipeline: "sparse"
ModelInputConfig:
  feature_pooling_avg: 30
  use_variable_batch: True
EmbeddingTablesConfig:
  num_unweighted_features: 90
  num_weighted_features: 80
  embedding_feature_dim: 256
  additional_tables:
    - - name: FP16_table
        embedding_dim: 512
        num_embeddings: 100_000
        feature_names: ["additional_0_0"]
        data_type: FP16
      - name: large_table
        embedding_dim: 2048
        num_embeddings: 1_000_000
        feature_names: ["additional_0_1"]
    - []
    - - name: skipped_table
        embedding_dim: 128
        num_embeddings: 100_000
        feature_names: ["additional_2_1"]
PlannerConfig:
  additional_constraints:
    large_table:
      sharding_types: [column_wise]
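The new benchmark config is plain YAML, so it can be sanity-checked outside the benchmark harness. Below is a minimal sketch that parses it with PyYAML and reads a few fields; the file name sparse_data_dist_base.yaml is an assumption (this diff view does not show it), and the torchrec benchmark runner that actually consumes the config is not part of this excerpt.

    # Minimal sketch: parse the new benchmark config with PyYAML and inspect it.
    # The file path below is hypothetical; the diff view does not show the name.
    import yaml

    with open("sparse_data_dist_base.yaml") as f:  # hypothetical path
        cfg = yaml.safe_load(f)

    # Top-level keys mirror the section names in the YAML above.
    print(cfg["RunOptions"]["world_size"])                # 2
    print(cfg["ModelInputConfig"]["use_variable_batch"])  # True

    # additional_tables parses as a list of groups of table specs:
    # the first group holds FP16_table and large_table, the second is empty.
    print(cfg["EmbeddingTablesConfig"]["additional_tables"][0][0]["name"])  # FP16_table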

torchrec/distributed/test_utils/input_config.py (26 additions, 2 deletions)

@@ -13,7 +13,7 @@
 import torch
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 
-from .model_input import ModelInput
+from .model_input import ModelInput, VariableBatchModelInput
 
 
 @dataclass
@@ -30,6 +30,7 @@ class ModelInputConfig:
     long_kjt_offsets: bool = True
     long_kjt_lengths: bool = True
     pin_memory: bool = True
+    use_variable_batch: bool = False
 
     def generate_batches(
         self,
@@ -47,6 +48,29 @@ def generate_batches(
         """
         device = torch.device(self.device) if self.device is not None else None
 
+        if self.use_variable_batch:
+            return [
+                VariableBatchModelInput.generate(
+                    batch_size=self.batch_size,
+                    num_float_features=self.num_float_features,
+                    tables=tables,
+                    weighted_tables=weighted_tables,
+                    use_offsets=self.use_offsets,
+                    indices_dtype=(
+                        torch.int64 if self.long_kjt_indices else torch.int32
+                    ),
+                    offsets_dtype=(
+                        torch.int64 if self.long_kjt_offsets else torch.int32
+                    ),
+                    lengths_dtype=(
+                        torch.int64 if self.long_kjt_lengths else torch.int32
+                    ),
+                    device=device,
+                    pin_memory=self.pin_memory,
+                )
+                for _ in range(self.num_batches)
+            ]
+
         return [
             ModelInput.generate(
                 batch_size=self.batch_size,
@@ -61,5 +85,5 @@ def generate_batches(
                 lengths_dtype=(torch.int64 if self.long_kjt_lengths else torch.int32),
                 pin_memory=self.pin_memory,
             )
-            for batch_size in range(self.num_batches)
+            for _ in range(self.num_batches)
         ]
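The effect of the new use_variable_batch flag follows from the call sites above: generate_batches now dispatches to VariableBatchModelInput.generate instead of ModelInput.generate, reusing the same dtype, device, and pin_memory options. A minimal usage sketch follows; it assumes ModelInputConfig exposes batch_size, num_batches, and num_float_features as constructor fields (only some fields appear in this excerpt) and that generate_batches accepts the table lists as arguments, as the calls above suggest.

    # Minimal sketch under the assumptions noted above: generate variable-batch
    # inputs for two unweighted tables.
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.distributed.test_utils.input_config import ModelInputConfig

    tables = [
        EmbeddingBagConfig(
            name=f"table_{i}",
            embedding_dim=256,
            num_embeddings=100_000,
            feature_names=[f"feature_{i}"],
        )
        for i in range(2)
    ]

    config = ModelInputConfig(
        batch_size=16384,
        num_batches=10,
        num_float_features=10,
        use_variable_batch=True,  # new flag added in this commit; defaults to False
    )

    # Returns num_batches VariableBatchModelInput objects when
    # use_variable_batch is True, otherwise ModelInput objects.
    batches = config.generate_batches(tables=tables, weighted_tables=[])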
