@@ -125,25 +125,19 @@ def __init__(
         self,
         async_work: Optional[dist.Work],
         async_event: Optional[torch.cuda.Event],
-        async_stream: Optional[torch.cuda.Stream],
-        unsharded_param: torch.Tensor,
         shard_buf: torch.Tensor,
         resize_callback: Callable[[], None],
     ) -> None:
         """
         Args:
             async_work: The async reduce scatter work handle
             async_event: CUDA event to synchronize streams
-            async_stream: The communication stream
-            unsharded_param: The original unsharded parameter tensor
             shard_buf: The buffer containing the sharded result
             resize_callback: Callback to perform resize operation (called on wait())
         """
         super().__init__()
         self._async_work = async_work
         self._async_event = async_event
-        self._async_stream = async_stream
-        self._unsharded_param = unsharded_param
         self._shard_buf = shard_buf
         self._resize_callback = resize_callback
         self._completed = False
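
For context, a minimal sketch of how wait() might consume the remaining fields; the method body below is an assumption based on the usual Awaitable pattern, not part of this diff:

    def wait(self) -> None:
        # Sketch only: drain the async reduce-scatter, then resize.
        if self._completed:
            return
        if self._async_work is not None:
            # Block until the reduce-scatter collective has finished.
            self._async_work.wait()
        if self._async_event is not None:
            # Order the current stream after the recorded comm-stream event.
            self._async_event.wait()
        # shard_buf now holds the sharded result; let the callback resize storage.
        self._resize_callback()
        self._completed = True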
@@ -2689,8 +2683,6 @@ def resize_callback() -> None:
         return ReduceScatterResizeAwaitable(
             async_work=self._async_work,
             async_event=self._async_event,
-            async_stream=self._async_stream,
-            unsharded_param=self._unsharded_param,
             shard_buf=self._shard_buf,
             resize_callback=resize_callback,
         )
@@ -3748,8 +3740,6 @@ def resize_callback() -> None:
         return ReduceScatterResizeAwaitable(
             async_work=self._async_work,
             async_event=self._async_event,
-            async_stream=self._async_stream,
-            unsharded_param=self._unsharded_param,
             shard_buf=self._shard_buf,
             resize_callback=resize_callback,
         )
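
Assuming the trimmed signature above, a hypothetical call site would now look like this (work, event, buf, and cb are placeholder names, not identifiers from this change):

    awaitable = ReduceScatterResizeAwaitable(
        async_work=work,    # Optional[dist.Work] from the async collective
        async_event=event,  # Optional[torch.cuda.Event] recorded on the comm stream
        shard_buf=buf,      # tensor that receives the sharded result
        resize_callback=cb, # invoked on wait() to perform the resize
    )
    awaitable.wait()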