Skip to content

Commit

Permalink
Extend isResharding to allow DID loop split. (#3421)
Browse files Browse the repository at this point in the history
For #2563 

Host latency benchmarks are neutral: 

```
$ pytest-benchmark compare 0003 0004 --group-by=name

------------------------------------------------------------------------------- benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='compile']": 2 tests --------------------------------------------------------------------------------
Name (time in ms)                                                                        Min                 Max                Mean            StdDev              Median               IQR            Outliers     OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='compile'] (0003_96a8efb)     170.7932 (1.0)      175.0024 (1.0)      173.0403 (1.0)      1.2632 (1.0)      173.0455 (1.0)      1.7885 (1.0)           4;0  5.7790 (1.0)          10           1
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='compile'] (0004_24e1b90)     170.9341 (1.00)     175.2137 (1.00)     173.4435 (1.00)     1.4717 (1.17)     173.6018 (1.00)     1.9277 (1.08)          3;0  5.7656 (1.00)         10           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------------- benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='dynamic']": 2 tests -----------------------------------------------------------------------------------
Name (time in us)                                                                        Min                 Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='dynamic'] (0003_96a8efb)     114.1470 (1.0)      157.7900 (1.27)     124.2067 (1.05)     17.6455 (5.83)     115.8945 (1.0)      3.3760 (1.0)           2;2        8.0511 (0.96)         10           1
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='dynamic'] (0004_24e1b90)     114.8980 (1.01)     123.9660 (1.0)      118.6434 (1.0)       3.0273 (1.0)      118.6260 (1.02)     5.5200 (1.64)          3;0        8.4286 (1.0)          10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------------------- benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='steady']": 2 tests --------------------------------------------------------------------------------
Name (time in us)                                                                      Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='steady'] (0003_96a8efb)     39.5550 (1.0)      50.2060 (1.13)     42.4156 (1.02)     3.4449 (2.72)     40.8575 (1.0)      4.3180 (4.40)          2;0       23.5762 (0.98)         10           1
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='steady'] (0004_24e1b90)     39.9650 (1.01)     44.5350 (1.0)      41.3924 (1.0)      1.2683 (1.0)      41.1580 (1.01)     0.9820 (1.0)           2;1       24.1590 (1.0)          10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------- benchmark "test_many_segment_benchmark[host_bench_mode='compile']": 2 tests --------------------------------------------------------------------------------
Name (time in ms)                                                              Min                 Max                Mean            StdDev              Median               IQR            Outliers     OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_many_segment_benchmark[host_bench_mode='compile'] (0003_96a8efb)     208.3781 (1.0)      213.6772 (1.0)      211.2823 (1.0)      1.3894 (1.0)      211.2441 (1.0)      0.9576 (1.0)           2;2  4.7330 (1.0)          10           1
test_many_segment_benchmark[host_bench_mode='compile'] (0004_24e1b90)     212.2643 (1.02)     224.4728 (1.05)     216.4410 (1.02)     3.3363 (2.40)     216.1582 (1.02)     2.8121 (2.94)          2;1  4.6202 (0.98)         10           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------------- benchmark "test_many_segment_benchmark[host_bench_mode='dynamic']": 2 tests ------------------------------------------------------------------------------------
Name (time in us)                                                              Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_many_segment_benchmark[host_bench_mode='dynamic'] (0003_96a8efb)     500.2300 (1.0)      626.6390 (1.11)     548.1104 (1.03)     54.3243 (2.19)     510.5340 (1.0)      97.5850 (1.95)          3;0        1.8244 (0.97)         10           1
test_many_segment_benchmark[host_bench_mode='dynamic'] (0004_24e1b90)     500.3700 (1.00)     563.3790 (1.0)      530.3696 (1.0)      24.7773 (1.0)      530.2765 (1.04)     50.0760 (1.0)           5;0        1.8855 (1.0)          10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------------- benchmark "test_many_segment_benchmark[host_bench_mode='steady']": 2 tests ----------------------------------------------------------------------------------
Name (time in us)                                                             Min                 Max                Mean            StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_many_segment_benchmark[host_bench_mode='steady'] (0003_96a8efb)     174.8920 (1.0)      193.9480 (1.00)     179.0197 (1.0)      5.6300 (1.08)     176.5750 (1.0)      3.8570 (1.0)           1;1        5.5860 (1.0)          10           1
test_many_segment_benchmark[host_bench_mode='steady'] (0004_24e1b90)     175.0420 (1.00)     193.0060 (1.0)      181.0727 (1.01)     5.2312 (1.0)      180.9635 (1.02)     5.4300 (1.41)          2;1        5.5226 (0.99)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=16]": 2 tests --------------------------------------------------------------------------------
Name (time in ms)                                                                            Min                 Max                Mean            StdDev              Median               IQR            Outliers     OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=16] (0003_96a8efb)     274.7113 (1.0)      278.1071 (1.0)      276.5300 (1.0)      1.1733 (1.0)      276.5168 (1.0)      2.3361 (1.73)          5;0  3.6162 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=16] (0004_24e1b90)     274.8315 (1.00)     280.0985 (1.01)     277.2161 (1.00)     1.3838 (1.18)     276.9688 (1.00)     1.3489 (1.0)           2;1  3.6073 (1.00)         10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------ benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=2]": 2 tests ------------------------------------------------------------------------------
Name (time in ms)                                                                          Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=2] (0004_24e1b90)     68.5291 (1.0)      69.5794 (1.0)      69.0360 (1.0)      0.4227 (1.0)      69.0748 (1.0)      0.7587 (1.0)           3;0  14.4852 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=2] (0003_96a8efb)     70.4157 (1.03)     72.0312 (1.04)     71.4535 (1.04)     0.5711 (1.35)     71.5552 (1.04)     0.9860 (1.30)          2;0  13.9951 (0.97)         10           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------ benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=4]": 2 tests ------------------------------------------------------------------------------
Name (time in ms)                                                                          Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=4] (0004_24e1b90)     91.7454 (1.0)      93.9708 (1.0)      93.2835 (1.0)      0.7899 (1.0)      93.5256 (1.0)      0.8339 (1.0)           2;1  10.7200 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=4] (0003_96a8efb)     94.9294 (1.03)     97.1430 (1.03)     95.9882 (1.03)     0.8550 (1.08)     95.9445 (1.03)     1.5375 (1.84)          4;0  10.4179 (0.97)         10           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=8]": 2 tests -------------------------------------------------------------------------------
Name (time in ms)                                                                           Min                 Max                Mean            StdDev              Median               IQR            Outliers     OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=8] (0004_24e1b90)     147.1260 (1.0)      148.6477 (1.0)      147.7856 (1.0)      0.5382 (1.0)      147.7964 (1.0)      0.9269 (1.0)           3;0  6.7666 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=8] (0003_96a8efb)     148.2082 (1.01)     151.0851 (1.02)     149.7864 (1.01)     0.8852 (1.64)     149.8612 (1.01)     0.9669 (1.04)          3;0  6.6762 (0.99)         10           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16]": 2 tests ----------------------------------------------------------------------------------
Name (time in us)                                                                           Min                 Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16] (0003_96a8efb)     73.6900 (1.0)      114.9290 (1.39)     86.4882 (1.13)     13.3314 (4.96)     81.0935 (1.07)     18.4040 (9.83)          2;0       11.5623 (0.89)         10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16] (0004_24e1b90)     73.8800 (1.00)      82.7570 (1.0)      76.5963 (1.0)       2.6876 (1.0)      75.8490 (1.0)       1.8720 (1.0)           3;1       13.0555 (1.0)          10           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                          Min                Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2] (0004_24e1b90)     63.3200 (1.0)      72.0470 (1.0)      68.0081 (1.0)       2.8500 (1.0)      67.4285 (1.0)       4.7200 (1.0)           4;0       14.7041 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2] (0003_96a8efb)     65.8250 (1.04)     92.5460 (1.28)     74.4402 (1.09)     10.3170 (3.62)     69.5170 (1.03)     18.0350 (3.82)          3;0       13.4336 (0.91)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                          Min                Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4] (0004_24e1b90)     64.0220 (1.0)      78.1180 (1.0)      68.6502 (1.0)       3.6925 (1.0)      68.1590 (1.0)       0.9820 (1.0)           2;3       14.5666 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4] (0003_96a8efb)     67.1480 (1.05)     96.8140 (1.24)     77.4761 (1.13)     10.4931 (2.84)     73.0485 (1.07)     18.8960 (19.24)         3;0       12.9072 (0.89)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                          Min                Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8] (0004_24e1b90)     68.3790 (1.0)      76.8860 (1.0)      71.7751 (1.0)       2.7013 (1.0)      71.4705 (1.0)       2.7150 (1.0)           4;1       13.9324 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8] (0003_96a8efb)     70.9040 (1.04)     98.3670 (1.28)     80.6828 (1.12)     11.6885 (4.33)     75.1170 (1.05)     23.9160 (8.81)          3;0       12.3942 (0.89)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=16]": 2 tests --------------------------------------------------------------------------------
Name (time in us)                                                                          Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=16] (0003_96a8efb)     23.9860 (1.0)      33.7140 (1.11)     26.1107 (1.01)     2.8895 (1.61)     24.8575 (1.0)      1.8430 (2.42)          1;1       38.2985 (0.99)         10           1
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=16] (0004_24e1b90)     24.7370 (1.03)     30.4480 (1.0)      25.8039 (1.0)      1.7968 (1.0)      25.0880 (1.01)     0.7610 (1.0)           1;2       38.7538 (1.0)          10           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=2]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                         Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=2] (0003_96a8efb)     23.9560 (1.0)      32.4720 (1.09)     26.1016 (1.01)     2.9284 (1.72)     24.9875 (1.0)      1.4230 (1.48)          2;2       38.3118 (0.99)         10           1
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=2] (0004_24e1b90)     24.4160 (1.02)     29.8470 (1.0)      25.7449 (1.0)      1.7051 (1.0)      25.1225 (1.01)     0.9620 (1.0)           2;2       38.8426 (1.0)          10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=4]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                         Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=4] (0003_96a8efb)     23.3650 (1.0)      30.4380 (1.08)     25.1518 (1.0)      1.9443 (1.76)     24.6120 (1.0)      0.7120 (1.13)          1;1       39.7586 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=4] (0004_24e1b90)     24.2360 (1.04)     28.1530 (1.0)      25.2177 (1.00)     1.1044 (1.0)      24.8370 (1.01)     0.6320 (1.0)           1;1       39.6547 (1.00)         10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=8]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                         Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=8] (0003_96a8efb)     24.1860 (1.0)      33.0630 (1.17)     25.4774 (1.0)      2.7047 (2.00)     24.5365 (1.0)      1.0620 (1.0)           1;1       39.2505 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=8] (0004_24e1b90)     24.2860 (1.00)     28.3440 (1.0)      25.6967 (1.01)     1.3499 (1.0)      25.1020 (1.02)     2.0040 (1.89)          3;0       38.9155 (0.99)         10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
  • Loading branch information
wujingyue authored Nov 25, 2024
1 parent 96a8efb commit c4a0335
Show file tree
Hide file tree
Showing 9 changed files with 366 additions and 66 deletions.
2 changes: 1 addition & 1 deletion csrc/id_model/id_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ StatefulInliningInfo buildStatefulInliningInfo(
// Map all iteration domains
// Always contain root mappings (otherwise they could have been forwarded in
// broadcast)
// IdMappingMode::AlmostExact
// IdMappingMode::ALMOSTEXACT
// Forward through broadcast axes, but not through to a non-broadcast axis
// i.e. id{b1*i0}, id{i0} are mapped
// id{i1*i0}, id{i0} are not mapped (this part is the difference from
Expand Down
4 changes: 4 additions & 0 deletions csrc/multidevice/device_mesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class DeviceMesh final {
return vector_ == other.vector();
}

bool operator!=(const DeviceMesh& other) const {
return vector_ != other.vector();
}

private:
void setDevices(std::vector<DeviceIdxType> devices);

Expand Down
1 change: 0 additions & 1 deletion csrc/multidevice/lower_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ std::vector<Communication*> lowerCommunication(Expr* c) {
c);
auto* input_tv = c->input(0)->as<TensorView>();
auto* output_tv = c->output(0)->as<TensorView>();
at::Tensor dummy;

const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
Expand Down
185 changes: 158 additions & 27 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges

bool isSharded(const TensorView* tv) {
bool is_sharded = false;
const auto& logical_ids = TensorDomain::noReductions(tv->getLogicalDomain());
const auto& loop_ids = TensorDomain::noReductions(tv->getLoopDomain());
for (auto i : c10::irange(loop_ids.size())) {
if (!loop_ids[i]->isDeviceDim()) {
for (IterDomain* id : TensorDomain::noReductions(tv->getLoopDomain())) {
if (!id->isDeviceDim()) {
continue;
}

Expand All @@ -118,14 +116,6 @@ bool isSharded(const TensorView* tv) {
!is_sharded,
"Multiple IterDomains parallelized on DIDx in TensorView ",
tv);

// Currently do not support split/merge on a device dimension.
NVF_ERROR(
std::find(logical_ids.begin(), logical_ids.end(), loop_ids[i]) !=
logical_ids.end(),
"Cannot parallelize DIDx on a split/merge axis ",
loop_ids[i]);

is_sharded = true;
}
return is_sharded;
Expand All @@ -138,42 +128,179 @@ int64_t numDeviceDims(const TensorView* tv) {
[](IterDomain* id) { return id->isDeviceDim(); });
}

namespace {
// Collect device-parallel IterDomains in `loop_domain` and return them as a
// ParallelType-to-IterDomain map.
std::unordered_map<ParallelType, IterDomain*> mapParallelTypeToId(
const std::vector<IterDomain*>& loop_domain) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
parallel_type_to_id.reserve(kParallelTypeDIDs.size());
for (IterDomain* loop_id : loop_domain) {
const ParallelType parallel_type = loop_id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

NVF_ERROR(
parallel_type_to_id.try_emplace(parallel_type, loop_id).second,
"Found multiple loop IterDomains with the same parallel type (",
parallel_type,
"): ",
toDelimitedString(loop_domain));
}
return parallel_type_to_id;
}

std::vector<IterDomain*> getInputsInTargetDomain(
IterDomain* loop_id,
const std::vector<IterDomain*>& target_domain) {
const std::vector<Val*> inputs_as_vals = IterVisitor::getInputsTo(
{loop_id}, {target_domain.begin(), target_domain.end()});

std::vector<IterDomain*> inputs_as_iter_domains;
inputs_as_iter_domains.reserve(inputs_as_vals.size());
std::transform(
inputs_as_vals.begin(),
inputs_as_vals.end(),
std::back_inserter(inputs_as_iter_domains),
[](Val* val) { return val->as<IterDomain>(); });
return inputs_as_iter_domains;
}

bool overlaps(
const std::vector<IterDomain*>& a,
const std::unordered_set<IterDomain*>& b) {
return std::any_of(
a.begin(), a.end(), [&](IterDomain* id) { return b.count(id); });
}

} // namespace

bool haveDifferentShardings(
const TensorView* producer,
const TensorView* consumer) {
const TensorView* consumer,
const IdModel& id_model) {
// cpu scalars are not required to have a mesh
if (producer->isCpuScalar() || consumer->isCpuScalar()) {
return false;
}

// exit early in the unsharded case for performance
if (!producer->hasDeviceMesh() && !consumer->hasDeviceMesh()) {
return false;
}

// If device mesh are different, the Expr is resharding
if (!(producer->getDeviceMesh() == consumer->getDeviceMesh())) {
if (producer->getDeviceMesh() != consumer->getDeviceMesh()) {
return true;
}
// Create a map between producer's and consumer's IterDomains. We iterate
// over producer's iterdomain and compare sharding type with consumer's
// iterdomain

// The rest of this function tries to do the following: for each pair of
// logical-domain-mapped IterDomains (i.e. those mapped by
// PairwiseLogicalDomainMap), check if they are sharded consistently. If not,
// returns true. For example,
//
// a: iDIDx{M}, iK
// b: iK, iDIDy{N}
// c = matmul(a, b): iDIDx{M}, iDIDy{N}
//
// haveDifferentShardings(a, c) only cares about iM, which is
// logical-domain-mapped, but not iK or iN, which are not
// logical-domain-mapped.
//
// One challenge is that DID parallelization doesn't always
// happen on the root/logical IterDomains. For example, a root/logical
// IterDomain may be outer-split by the number of devices, and only the outer
// split gets parallelized on DID.
//
// logical: iM
// loop: iDIDx{D}, iM/D
//
// Therefore, we collect all the loop IterDomains that depend on the
// logical-domain-mapped IterDomains, and check if they are DID-parallelized
// consistently.
const std::unordered_map<IterDomain*, IterDomain*>& p2c =
PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer();
for (auto p_id : producer->getLogicalDomain()) {
const auto i = p2c.find(p_id);
std::unordered_set<IterDomain*> mapped_p_logical_ids;
mapped_p_logical_ids.reserve(p2c.size());
std::unordered_set<IterDomain*> mapped_c_root_ids;
mapped_c_root_ids.reserve(p2c.size());
for (IterDomain* p_logical_id : producer->getLogicalDomain()) {
const auto i = p2c.find(p_logical_id);
if (i == p2c.end()) {
// This happens e.g. when `p_id` is squeezed or is a product of a
// reduction. Even if `p_id` is parallelized on DID, the dimension is
// size-1 and doesn't trigger resharding.
// This happens e.g. when `p_logical_id` is squeezed or is a product of a
// reduction. Even if `p_logical_id` is parallelized on DID, the
// dimension is size-1 and doesn't trigger resharding.
continue;
}
mapped_p_logical_ids.insert(p_logical_id);
mapped_c_root_ids.insert(i->second);
}

// In practice, only loop IterDomains can be parallelized, and no two loop
// IterDomains in a TensorView can have the same parallel type. Therefore, we
// do the check in reverse order for efficiency and simplicity:
// 1. For each DID parallel type, find the loop IterDomain in producer and the
// one in consumer that have the type.
// 2. Find what IterDomains they come from in producer's logical or
// consumer's root domain. If that input IterDomain is not
// logical-domain-mapped, treat the loop IterDomain as not existing -- it is
// parallelized but just not a concern for this producer-consumer pair.
// 3. Check if the two loop IterDomains are almost-exactly mapped in the
// IdModel.
std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id =
mapParallelTypeToId(producer->getLoopDomain());
std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id =
mapParallelTypeToId(consumer->getLoopDomain());

for (const auto parallel_type : kParallelTypeDIDs) {
IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type);
if (p_loop_id != nullptr) {
auto p_inputs =
getInputsInTargetDomain(p_loop_id, producer->getLogicalDomain());
if (!overlaps(p_inputs, mapped_p_logical_ids)) {
p_loop_id = nullptr;
}
}

IterDomain* c_loop_id = getOrDefault(c_parallel_type_to_id, parallel_type);
if (c_loop_id != nullptr) {
auto c_inputs =
getInputsInTargetDomain(c_loop_id, consumer->getMaybeRootDomain());
if (!overlaps(c_inputs, mapped_c_root_ids)) {
c_loop_id = nullptr;
}
}

auto is_mapped_in_id_model =
[](IterDomain* a, IterDomain* b, const IdModel& id_model) -> bool {
if (a == nullptr && b == nullptr) {
return true;
}

auto c_id = i->second;
if (p_id->getParallelType() != c_id->getParallelType() &&
(p_id->isDeviceDim() || c_id->isDeviceDim())) {
// Mismatch found
if (a == nullptr || b == nullptr) {
return false;
}

// Going between bDIDx{1} and iDIDx{N} doesn't trigger resharding, but
// would be flagged by ALMOSTEXACT as a false positive.
if (id_model.idGraph(IdMappingMode::BROADCAST)
.disjointValSets()
.strictAreMapped(a, b)) {
return true;
}

// Check ALMOSTEXACT so iDIDx{N}*b{1} and iDIDx{N} are mapped.
return id_model.idGraph(IdMappingMode::ALMOSTEXACT)
.disjointValSets()
.strictAreMapped(a, b);
};

if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model)) {
return true;
}
}

return false;
}

Expand All @@ -184,12 +311,16 @@ bool isResharding(const Expr* expr) {
return false;
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
IdModel id_model({const_cast<Expr*>(expr)}, {}, false, false);
id_model.buildAlmostExactGraph();
id_model.buildBroadcastGraph();
// We don't use getTvsWithDifferentSharding because it creates a computeAtMap,
// which is too costly
for (auto* input : ir_utils::filterByType<TensorView>(expr->inputs())) {
for (auto* output : ir_utils::filterByType<TensorView>(expr->outputs())) {
// exit early in the unsharded case for performance
if (haveDifferentShardings(input, output)) {
if (haveDifferentShardings(input, output, id_model)) {
return true;
}
}
Expand Down
4 changes: 3 additions & 1 deletion csrc/multidevice/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <compute_at_map.h>
#include <fusion.h>
#include <id_model/id_model.h>
#include <ir/interface_nodes.h>
#include <multidevice/multidevice.h>
#include <visibility.h>
Expand Down Expand Up @@ -81,7 +82,8 @@ bool isResharding(const Expr* expr);
// producer/consumer relationship between the arguments.
bool haveDifferentShardings(
const TensorView* producer,
const TensorView* consumer);
const TensorView* consumer,
const IdModel& id_model);

// Returns whether a resharding expr reshards an inner axis
bool isInnerResharding(Expr* expr);
Expand Down
23 changes: 13 additions & 10 deletions csrc/preseg_passes/insert_reshardings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <fusion.h>
#include <ir/base_nodes.h>
#include <ir/interface_nodes.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <multidevice/lower_communication.h>
#include <multidevice/utils.h>
Expand Down Expand Up @@ -38,14 +39,10 @@ void insertReshardingsBefore(Fusion* fusion) {

// Verify that multi-output expression requires no resharding.
if (expr->outputs().size() > 1) {
for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
NVF_CHECK(
!haveDifferentShardings(input, output),
"Cannot handle resharding a multi-output expression ",
expr->toString());
}
}
NVF_CHECK(
!isResharding(expr),
"Cannot handle resharding a multi-output expression: ",
expr);
continue;
}

Expand All @@ -55,8 +52,11 @@ void insertReshardingsBefore(Fusion* fusion) {
auto output = expr->output(0)->as<TensorView>();

std::unordered_set<TensorView*> inputs;
IdModel id_model({expr}, {}, false, false);
id_model.buildAlmostExactGraph();
id_model.buildBroadcastGraph();
for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
if (haveDifferentShardings(input, output)) {
if (haveDifferentShardings(input, output, id_model)) {
inputs.insert(input);
}
}
Expand Down Expand Up @@ -95,8 +95,11 @@ void insertReshardingsAfter(Fusion* fusion) {
auto output = expr->output(0)->as<TensorView>();

std::unordered_set<TensorView*> inputs;
IdModel id_model({expr}, {}, false, false);
id_model.buildAlmostExactGraph();
id_model.buildBroadcastGraph();
for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
if (haveDifferentShardings(input, output)) {
if (haveDifferentShardings(input, output, id_model)) {
inputs.insert(input);
}
}
Expand Down
3 changes: 3 additions & 0 deletions csrc/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,9 @@ static constexpr std::array<ParallelType, 3> kParallelTypeTIDs = {
ParallelType::TIDy,
ParallelType::TIDz};

static constexpr std::array<ParallelType, 1> kParallelTypeDIDs = {
ParallelType::DIDx};

enum class MemoryType { Local, Shared, Global };

// Symbolic: Undetermined between Iteration or Broadcast
Expand Down
Loading

0 comments on commit c4a0335

Please sign in to comment.