Add support for Zero3 FSDP #3753
Description
In #3740, we added support for FullyShardedDataParallel, but limited the implementation to Zero2, not Zero3. Zero3 substantially reduces memory usage compared with Zero2 while keeping speed in line with vanilla DDP.
We have already added support for this (via manual calls to `wrap`) within the Transformer modules (see the sketch after the list below), but we still cannot support Zero3. The main issue is that Zero3 assumes every worker calls forward exactly the same number of times, and it performs a parameter transfer during that forward (moving the sharded parameters to each worker just in time). ParlAI cannot provide this guarantee because:
- During validation, each worker sees a variable number of examples. This is okay in itself, but it causes a hang if any worker ends up with extra batches.
- During generation, workers perform a variable number of forwards due to variable sequence lengths. Everything stays in sync for a while, but if one worker ends the run needing more generation steps than the others, we get hangs.
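
For reference, here is a minimal sketch of the kind of manual wrapping referred to above, assuming Fairscale's `enable_wrap`/`wrap` API and `FullyShardedDataParallel`; the layer class, function name, and hyperparameters are placeholders, not ParlAI's actual transformer code.

```python
# Hypothetical helper for illustration only; ParlAI's real transformer modules differ.
import torch.nn as nn
from fairscale.nn.wrap import enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def build_encoder_layers(num_layers: int, dim: int, heads: int) -> nn.ModuleList:
    # reshard_after_forward=True gives Zero3 semantics: each layer's shards are
    # gathered just in time for its forward and released again afterwards, which
    # is why every rank must run the same number of forwards.
    with enable_wrap(wrapper_cls=FSDP, reshard_after_forward=True):
        return nn.ModuleList(
            wrap(nn.TransformerEncoderLayer(d_model=dim, nhead=heads))
            for _ in range(num_layers)
        )
```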
It seems far too difficult (and ugly) to try to force this equality in worlds.py or in our dataset sharding. So our best bet going forward is to implement something like vanilla DDP's .join() context manager. It would work roughly as follows (a sketch follows the list):
- Every worker, in forward, tries to synchronize a `True` boolean saying "Am I doing a true forward?"
- Upon `__exit__` of the context, workers enter an infinite loop where they sync a `False` boolean. As long as any worker is providing a `True` value, they all participate in a dummy batch forward.
- When all workers agree on the `False` boolean, we can end the infinite loop.
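
A rough sketch of what such a join-style context could look like, assuming a `torch.distributed` process group is already initialized; `fsdp_join`, `sync_flag`, and `make_dummy_batch` are hypothetical names, not existing ParlAI, PyTorch, or Fairscale APIs.

```python
# Sketch only: the flag protocol below illustrates the idea, not a final design.
import contextlib

import torch
import torch.distributed as dist


@contextlib.contextmanager
def fsdp_join(model, make_dummy_batch):
    """Keep ranks that run out of real batches participating in collectives."""
    device = next(model.parameters()).device

    def sync_flag(doing_real_forward: bool) -> bool:
        # Each rank contributes 1 ("true forward") or 0; the MAX across ranks
        # is True as long as *any* rank still has real work to do.
        flag = torch.tensor([1 if doing_real_forward else 0], device=device)
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        return bool(flag.item())

    try:
        # Callers invoke sync_flag(True) immediately before each real forward.
        yield sync_flag
    finally:
        # On __exit__, keep syncing a False flag; while any other rank reports
        # True, run a dummy forward so Zero3's parameter gathers stay aligned.
        while sync_flag(False):
            model(make_dummy_batch())
```

In this sketch, a rank with real work calls `sync_flag(True)` before each genuine forward inside the `with fsdp_join(...)` block, and ranks that finish early are kept in the collective via dummy forwards until everyone agrees they are done.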
This feature makes the most sense to implement upstream in Fairscale, and then integrate it into ParlAI.