|
13 | 13 | #include "type_traits.h" |
14 | 14 |
|
15 | 15 | namespace zuku { |
| 16 | +namespace { |
| 17 | + |
| 18 | +template <int N, typename T> |
| 19 | +RealmShape MakeTileBoxN(const std::vector<int64_t>& index, |
| 20 | + const Sharding& sharding) { |
| 21 | + Realm::Rect<N, T> bounds; |
| 22 | + int64_t box_dim = 0; |
| 23 | + for (int64_t s = 0; s < sharding.dims.size(); ++s) { |
| 24 | + const int64_t tile_size = sharding.dims[s].size / sharding.dims[s].sharding; |
| 25 | + // skip replicated dimension |
| 26 | + if (tile_size > 0) { |
| 27 | + bounds.lo[box_dim] = index[box_dim] * tile_size; |
| 28 | + bounds.hi[box_dim] = bounds.lo[box_dim] + tile_size - 1; |
| 29 | + ++box_dim; |
| 30 | + } |
| 31 | + } |
| 32 | + return bounds; |
| 33 | +} |
| 34 | + |
| 35 | +} // namespace |
| 36 | + |
| 37 | +RealmShape MakeTileBox(const std::vector<int64_t>& index, |
| 38 | + const Sharding& sharding) { |
| 39 | + const int64_t num_box_dims = [&] { |
| 40 | + if (sharding.IsPartiallyReplicated()) { |
| 41 | + return sharding.dims.size() - 1; |
| 42 | + } |
| 43 | + return sharding.dims.size(); |
| 44 | + }(); |
| 45 | + switch (num_box_dims) { |
| 46 | + case 0: |
| 47 | + return MakeTileBoxN<1, long long>(index, sharding); |
| 48 | + case 1: |
| 49 | + return MakeTileBoxN<1, long long>(index, sharding); |
| 50 | + case 2: |
| 51 | + return MakeTileBoxN<2, long long>(index, sharding); |
| 52 | + case 3: |
| 53 | + return MakeTileBoxN<3, long long>(index, sharding); |
| 54 | + case 4: |
| 55 | + return MakeTileBoxN<4, long long>(index, sharding); |
| 56 | +#if REALM_MAX_DIM >= 5 |
| 57 | + case 5: |
| 58 | + return MakeTileBoxN<5, long long>(index, sharding); |
| 59 | +#endif |
| 60 | + default: |
| 61 | + throw std::runtime_error("unsupported no. dims"); |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +ScatterGatherConfig GetScatterGatherConfig(int64_t dim, |
| 66 | + std::optional<int64_t> src_shard, |
| 67 | + std::optional<int64_t> dst_shard, |
| 68 | + const Sharding& src, |
| 69 | + const Sharding& dst) { |
| 70 | + const int64_t src_sharding = src.dims[dim].sharding; |
| 71 | + const int64_t dst_sharding = dst.dims[dim].sharding; |
| 72 | + |
| 73 | + ScatterGatherConfig sg_config; |
| 74 | + |
| 75 | + assert(src_sharding != dst_sharding); |
| 76 | + if (src_sharding > dst_sharding) { |
| 77 | + const int64_t num_slices_per_dest = |
| 78 | + src.dims[dim].sharding / dst.dims[dim].sharding; |
| 79 | + const int64_t slice_size = src.dims[dim].size / src.dims[dim].sharding; |
| 80 | + if (src_shard.has_value()) { |
| 81 | + std::vector<int64_t> index = ComputeTileIndex(*src_shard, src); |
| 82 | + auto bounds = MakeTileBox(index, src); |
| 83 | + const int64_t offset = *src_shard % num_slices_per_dest; |
| 84 | + index[dim] /= num_slices_per_dest; |
| 85 | + const int64_t dst_shard = ComputeTileShard(index, dst); |
| 86 | + SliceConfig config{.src_shard = *src_shard, |
| 87 | + .dst_shard = dst_shard, |
| 88 | + .bounds = std::move(bounds)}; |
| 89 | + sg_config.from_source.push_back(std::move(config)); |
| 90 | + } |
| 91 | + |
| 92 | + if (dst_shard.has_value()) { |
| 93 | + // we are going to receive from multiple sources |
| 94 | + sg_config.to_target.reserve(num_slices_per_dest); |
| 95 | + std::vector<int64_t> index = ComputeTileIndex(*dst_shard, dst); |
| 96 | + const int64_t src_shard_start = index[dim] * num_slices_per_dest; |
| 97 | + for (int64_t slice = 0; slice < num_slices_per_dest; ++slice) { |
| 98 | + index[dim] = src_shard_start + slice; |
| 99 | + auto bounds = MakeTileBox(index, src); |
| 100 | + SliceConfig config{ |
| 101 | + .src_shard = ComputeTileShard(index, src), |
| 102 | + .dst_shard = *dst_shard, |
| 103 | + .bounds = std::move(bounds), |
| 104 | + }; |
| 105 | + sg_config.to_target.push_back(std::move(config)); |
| 106 | + } |
| 107 | + } |
| 108 | + } else { // dst_sharding > src_sharding |
| 109 | + const int64_t num_slices_per_source = |
| 110 | + dst.dims[dim].sharding / src.dims[dim].sharding; |
| 111 | + const int64_t slice_size = dst.dims[dim].size / dst.dims[dim].sharding; |
| 112 | + if (dst_shard.has_value()) { |
| 113 | + std::vector<int64_t> index = ComputeTileIndex(*dst_shard, dst); |
| 114 | + auto bounds = MakeTileBox(index, dst); |
| 115 | + index[dim] /= num_slices_per_source; |
| 116 | + const int64_t src_shard = ComputeTileShard(index, src); |
| 117 | + SliceConfig config{ |
| 118 | + .src_shard = src_shard, |
| 119 | + .dst_shard = *dst_shard, |
| 120 | + .bounds = std::move(bounds), |
| 121 | + }; |
| 122 | + sg_config.to_target.push_back(std::move(config)); |
| 123 | + } |
| 124 | + |
| 125 | + if (src_shard.has_value()) { |
| 126 | + sg_config.from_source.reserve(num_slices_per_source); |
| 127 | + |
| 128 | + std::vector<int64_t> index = ComputeTileIndex(*src_shard, src); |
| 129 | + const int64_t dest_idx_start = index[dim] * num_slices_per_source; |
| 130 | + for (int64_t slice = 0; slice < num_slices_per_source; ++slice) { |
| 131 | + index[dim] = dest_idx_start + slice; |
| 132 | + auto bounds = MakeTileBox(index, dst); |
| 133 | + SliceConfig config{ |
| 134 | + .src_shard = *src_shard, |
| 135 | + .dst_shard = ComputeTileShard(index, dst), |
| 136 | + .bounds = std::move(bounds), |
| 137 | + }; |
| 138 | + sg_config.from_source.push_back(std::move(config)); |
| 139 | + } |
| 140 | + } |
| 141 | + } |
| 142 | + return sg_config; |
| 143 | +} |
16 | 144 |
|
17 | 145 | std::ostream& operator<<(std::ostream& os, const SliceConfig& config) { |
18 | 146 | os << "SliceConfig { .src_shard = " << config.src_shard |
|
0 commit comments