Skip to content

Commit f642711

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
remove unnecessary params from reduce scatter awaitable (#3594)
Summary: tsia, cleaning up awaitable API Reviewed By: spmex Differential Revision: D88206952
1 parent 3a1d5f3 commit f642711

File tree

1 file changed

+0
-10
lines changed

1 file changed

+0
-10
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,25 +125,19 @@ def __init__(
125125
self,
126126
async_work: Optional[dist.Work],
127127
async_event: Optional[torch.cuda.Event],
128-
async_stream: Optional[torch.cuda.Stream],
129-
unsharded_param: torch.Tensor,
130128
shard_buf: torch.Tensor,
131129
resize_callback: Callable[[], None],
132130
) -> None:
133131
"""
134132
Args:
135133
async_work: The async reduce scatter work handle
136134
async_event: CUDA event to synchronize streams
137-
async_stream: The communication stream
138-
unsharded_param: The original unsharded parameter tensor
139135
shard_buf: The buffer containing the sharded result
140136
resize_callback: Callback to perform resize operation (called on wait())
141137
"""
142138
super().__init__()
143139
self._async_work = async_work
144140
self._async_event = async_event
145-
self._async_stream = async_stream
146-
self._unsharded_param = unsharded_param
147141
self._shard_buf = shard_buf
148142
self._resize_callback = resize_callback
149143
self._completed = False
@@ -2689,8 +2683,6 @@ def resize_callback() -> None:
26892683
return ReduceScatterResizeAwaitable(
26902684
async_work=self._async_work,
26912685
async_event=self._async_event,
2692-
async_stream=self._async_stream,
2693-
unsharded_param=self._unsharded_param,
26942686
shard_buf=self._shard_buf,
26952687
resize_callback=resize_callback,
26962688
)
@@ -3748,8 +3740,6 @@ def resize_callback() -> None:
37483740
return ReduceScatterResizeAwaitable(
37493741
async_work=self._async_work,
37503742
async_event=self._async_event,
3751-
async_stream=self._async_stream,
3752-
unsharded_param=self._unsharded_param,
37533743
shard_buf=self._shard_buf,
37543744
resize_callback=resize_callback,
37553745
)

0 commit comments

Comments
 (0)