Skip to content

Commit cb88b56

Browse files
committed
release-v0.2
1 parent 24b9e1f commit cb88b56

File tree

6 files changed

+189
-122
lines changed

6 files changed

+189
-122
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ cmake_minimum_required(VERSION 3.22.1 FATAL_ERROR)
88
set(CMAKE_CXX_STANDARD 17)
99
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1010

11+
set(CMAKE_C_FLAGS_ASAN "${CMAKE_C_FLAGS_RELWITHDEBINFO} -fsanitize=address -fno-omit-frame-pointer -fno-common" CACHE STRING "" FORCE)
12+
set(CMAKE_CXX_FLAGS_ASAN "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -fsanitize=address -fno-omit-frame-pointer -fno-common" CACHE STRING "" FORCE)
13+
set(CMAKE_EXE_LINKER_FLAGS_ASAN "${CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO} -fsanitize=address" CACHE STRING "" FORCE)
14+
set(CMAKE_SHARED_LINKER_FLAGS_ASAN "${CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO} -fsanitize=address" CACHE STRING "" FORCE)
15+
16+
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ASAN")
17+
1118
# ########################################################################################
1219
# * Download and initialize RAPIDS CMake helpers -----------------------------
1320

src/zuku/reshard.cc

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -94,30 +94,8 @@ ReshardConfigVariant DetermineScatterGatherConfig(const Processor& p,
9494
return std::nullopt;
9595
}();
9696

97-
switch (src.dims.size()) {
98-
case 0:
99-
return GetScatterGatherConfig<1, long long>(
100-
last_dim_different, src_shard, dst_shard, src, dst);
101-
case 1:
102-
return GetScatterGatherConfig<1, long long>(
103-
last_dim_different, src_shard, dst_shard, src, dst);
104-
case 2:
105-
return GetScatterGatherConfig<2, long long>(
106-
last_dim_different, src_shard, dst_shard, src, dst);
107-
case 3:
108-
return GetScatterGatherConfig<3, long long>(
109-
last_dim_different, src_shard, dst_shard, src, dst);
110-
case 4:
111-
return GetScatterGatherConfig<4, long long>(
112-
last_dim_different, src_shard, dst_shard, src, dst);
113-
#if REALM_MAX_DIM >= 5
114-
case 5:
115-
return GetScatterGatherConfig<5, long long>(
116-
last_dim_different, src_shard, dst_shard, src, dst);
117-
#endif
118-
default:
119-
throw std::runtime_error("unsupported no. dims");
120-
}
97+
return GetScatterGatherConfig(last_dim_different, src_shard, dst_shard, src,
98+
dst);
12199
}
122100
return UnsupportedConfig{};
123101
}

src/zuku/reshard_scatter_gather.cc

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,134 @@
1313
#include "type_traits.h"
1414

1515
namespace zuku {
16+
namespace {
17+
18+
template <int N, typename T>
19+
RealmShape MakeTileBoxN(const std::vector<int64_t>& index,
20+
const Sharding& sharding) {
21+
Realm::Rect<N, T> bounds;
22+
int64_t box_dim = 0;
23+
for (int64_t s = 0; s < sharding.dims.size(); ++s) {
24+
const int64_t tile_size = sharding.dims[s].size / sharding.dims[s].sharding;
25+
// skip replicated dimension
26+
if (tile_size > 0) {
27+
bounds.lo[box_dim] = index[box_dim] * tile_size;
28+
bounds.hi[box_dim] = bounds.lo[box_dim] + tile_size - 1;
29+
++box_dim;
30+
}
31+
}
32+
return bounds;
33+
}
34+
35+
} // namespace
36+
37+
RealmShape MakeTileBox(const std::vector<int64_t>& index,
38+
const Sharding& sharding) {
39+
const int64_t num_box_dims = [&] {
40+
if (sharding.IsPartiallyReplicated()) {
41+
return sharding.dims.size() - 1;
42+
}
43+
return sharding.dims.size();
44+
}();
45+
switch (num_box_dims) {
46+
case 0:
47+
return MakeTileBoxN<1, long long>(index, sharding);
48+
case 1:
49+
return MakeTileBoxN<1, long long>(index, sharding);
50+
case 2:
51+
return MakeTileBoxN<2, long long>(index, sharding);
52+
case 3:
53+
return MakeTileBoxN<3, long long>(index, sharding);
54+
case 4:
55+
return MakeTileBoxN<4, long long>(index, sharding);
56+
#if REALM_MAX_DIM >= 5
57+
case 5:
58+
return MakeTileBoxN<5, long long>(index, sharding);
59+
#endif
60+
default:
61+
throw std::runtime_error("unsupported no. dims");
62+
}
63+
}
64+
65+
ScatterGatherConfig GetScatterGatherConfig(int64_t dim,
66+
std::optional<int64_t> src_shard,
67+
std::optional<int64_t> dst_shard,
68+
const Sharding& src,
69+
const Sharding& dst) {
70+
const int64_t src_sharding = src.dims[dim].sharding;
71+
const int64_t dst_sharding = dst.dims[dim].sharding;
72+
73+
ScatterGatherConfig sg_config;
74+
75+
assert(src_sharding != dst_sharding);
76+
if (src_sharding > dst_sharding) {
77+
const int64_t num_slices_per_dest =
78+
src.dims[dim].sharding / dst.dims[dim].sharding;
79+
const int64_t slice_size = src.dims[dim].size / src.dims[dim].sharding;
80+
if (src_shard.has_value()) {
81+
std::vector<int64_t> index = ComputeTileIndex(*src_shard, src);
82+
auto bounds = MakeTileBox(index, src);
83+
const int64_t offset = *src_shard % num_slices_per_dest;
84+
index[dim] /= num_slices_per_dest;
85+
const int64_t dst_shard = ComputeTileShard(index, dst);
86+
SliceConfig config{.src_shard = *src_shard,
87+
.dst_shard = dst_shard,
88+
.bounds = std::move(bounds)};
89+
sg_config.from_source.push_back(std::move(config));
90+
}
91+
92+
if (dst_shard.has_value()) {
93+
// we are going to receive from multiple sources
94+
sg_config.to_target.reserve(num_slices_per_dest);
95+
std::vector<int64_t> index = ComputeTileIndex(*dst_shard, dst);
96+
const int64_t src_shard_start = index[dim] * num_slices_per_dest;
97+
for (int64_t slice = 0; slice < num_slices_per_dest; ++slice) {
98+
index[dim] = src_shard_start + slice;
99+
auto bounds = MakeTileBox(index, src);
100+
SliceConfig config{
101+
.src_shard = ComputeTileShard(index, src),
102+
.dst_shard = *dst_shard,
103+
.bounds = std::move(bounds),
104+
};
105+
sg_config.to_target.push_back(std::move(config));
106+
}
107+
}
108+
} else { // dst_sharding > src_sharding
109+
const int64_t num_slices_per_source =
110+
dst.dims[dim].sharding / src.dims[dim].sharding;
111+
const int64_t slice_size = dst.dims[dim].size / dst.dims[dim].sharding;
112+
if (dst_shard.has_value()) {
113+
std::vector<int64_t> index = ComputeTileIndex(*dst_shard, dst);
114+
auto bounds = MakeTileBox(index, dst);
115+
index[dim] /= num_slices_per_source;
116+
const int64_t src_shard = ComputeTileShard(index, src);
117+
SliceConfig config{
118+
.src_shard = src_shard,
119+
.dst_shard = *dst_shard,
120+
.bounds = std::move(bounds),
121+
};
122+
sg_config.to_target.push_back(std::move(config));
123+
}
124+
125+
if (src_shard.has_value()) {
126+
sg_config.from_source.reserve(num_slices_per_source);
127+
128+
std::vector<int64_t> index = ComputeTileIndex(*src_shard, src);
129+
const int64_t dest_idx_start = index[dim] * num_slices_per_source;
130+
for (int64_t slice = 0; slice < num_slices_per_source; ++slice) {
131+
index[dim] = dest_idx_start + slice;
132+
auto bounds = MakeTileBox(index, dst);
133+
SliceConfig config{
134+
.src_shard = *src_shard,
135+
.dst_shard = ComputeTileShard(index, dst),
136+
.bounds = std::move(bounds),
137+
};
138+
sg_config.from_source.push_back(std::move(config));
139+
}
140+
}
141+
}
142+
return sg_config;
143+
}
16144

17145
std::ostream& operator<<(std::ostream& os, const SliceConfig& config) {
18146
os << "SliceConfig { .src_shard = " << config.src_shard

src/zuku/reshard_scatter_gather.h

Lines changed: 4 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -33,107 +33,14 @@ struct ScatterGatherConfig {
3333

3434
std::ostream& operator<<(std::ostream& os, const ScatterGatherConfig& config);
3535

36-
template <int N, typename T>
36+
RealmShape MakeTileBox(const std::vector<int64_t>& index,
37+
const Sharding& sharding);
38+
3739
ScatterGatherConfig GetScatterGatherConfig(int64_t dim,
3840
std::optional<int64_t> src_shard,
3941
std::optional<int64_t> dst_shard,
4042
const Sharding& src,
41-
const Sharding& dst) {
42-
const int64_t src_sharding = src.dims[dim].sharding;
43-
const int64_t dst_sharding = dst.dims[dim].sharding;
44-
45-
ScatterGatherConfig sg_config;
46-
47-
assert(src_sharding != dst_sharding);
48-
if (src_sharding > dst_sharding) {
49-
const int64_t num_slices_per_dest =
50-
src.dims[dim].sharding / dst.dims[dim].sharding;
51-
const int64_t slice_size = src.dims[dim].size / src.dims[dim].sharding;
52-
if (src_shard.has_value()) {
53-
std::vector<int64_t> index = ComputeTileIndex(*src_shard, src);
54-
Realm::Rect<N, T> bounds;
55-
for (int64_t i = 0; i < N; ++i) {
56-
const int64_t tile_size = src.dims[i].size / src.dims[i].sharding;
57-
bounds.lo[i] = index[i] * tile_size;
58-
bounds.hi[i] = bounds.lo[i] + tile_size - 1;
59-
}
60-
61-
const int64_t offset = *src_shard % num_slices_per_dest;
62-
index[dim] /= num_slices_per_dest;
63-
const int64_t dst_shard = ComputeTileShard(index, dst);
64-
SliceConfig config{.src_shard = *src_shard,
65-
.dst_shard = dst_shard,
66-
.bounds = std::move(bounds)};
67-
sg_config.from_source.push_back(std::move(config));
68-
}
69-
70-
if (dst_shard.has_value()) {
71-
// we are going to receive from multiple sources
72-
sg_config.to_target.reserve(num_slices_per_dest);
73-
std::vector<int64_t> index = ComputeTileIndex(*dst_shard, dst);
74-
const int64_t src_shard_start = index[dim] * num_slices_per_dest;
75-
for (int64_t slice = 0; slice < num_slices_per_dest; ++slice) {
76-
index[dim] = src_shard_start + slice;
77-
Realm::Rect<N, T> bounds;
78-
for (int64_t i = 0; i < N; ++i) {
79-
const int64_t tile_size = src.dims[i].size / src.dims[i].sharding;
80-
bounds.lo[i] = index[i] * tile_size;
81-
bounds.hi[i] = bounds.lo[i] + tile_size - 1;
82-
}
83-
SliceConfig config{
84-
.src_shard = ComputeTileShard(index, src),
85-
.dst_shard = *dst_shard,
86-
.bounds = std::move(bounds),
87-
};
88-
sg_config.to_target.push_back(std::move(config));
89-
}
90-
}
91-
} else { // dst_sharding > src_sharding
92-
const int64_t num_slices_per_source =
93-
dst.dims[dim].sharding / src.dims[dim].sharding;
94-
const int64_t slice_size = dst.dims[dim].size / dst.dims[dim].sharding;
95-
if (dst_shard.has_value()) {
96-
std::vector<int64_t> index = ComputeTileIndex(*dst_shard, dst);
97-
Realm::Rect<N, T> bounds;
98-
for (int64_t i = 0; i < N; ++i) {
99-
const int64_t tile_size = dst.dims[i].size / dst.dims[i].sharding;
100-
bounds.lo[i] = index[i] * tile_size;
101-
bounds.hi[i] = bounds.lo[i] + tile_size - 1;
102-
}
103-
index[dim] /= num_slices_per_source;
104-
const int64_t src_shard = ComputeTileShard(index, src);
105-
SliceConfig config{
106-
.src_shard = src_shard,
107-
.dst_shard = *dst_shard,
108-
.bounds = std::move(bounds),
109-
};
110-
sg_config.to_target.push_back(std::move(config));
111-
}
112-
113-
if (src_shard.has_value()) {
114-
sg_config.from_source.reserve(num_slices_per_source);
115-
116-
std::vector<int64_t> index = ComputeTileIndex(*src_shard, src);
117-
const int64_t dest_idx_start = index[dim] * num_slices_per_source;
118-
for (int64_t slice = 0; slice < num_slices_per_source; ++slice) {
119-
index[dim] = dest_idx_start + slice;
120-
Realm::Rect<N, T> bounds;
121-
for (int64_t i = 0; i < N; ++i) {
122-
const int64_t tile_size = dst.dims[i].size / dst.dims[i].sharding;
123-
bounds.lo[i] = index[i] * tile_size;
124-
bounds.hi[i] = bounds.lo[i] + tile_size - 1;
125-
}
126-
SliceConfig config{
127-
.src_shard = *src_shard,
128-
.dst_shard = ComputeTileShard(index, dst),
129-
.bounds = std::move(bounds),
130-
};
131-
sg_config.from_source.push_back(std::move(config));
132-
}
133-
}
134-
}
135-
return sg_config;
136-
}
43+
const Sharding& dst);
13744

13845
void ReshardScatterGather(const Processor& p, const ScatterGatherConfig& config,
13946
View<zuku::ShardedArray> src,

tests/unit/reshard_config_test.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,53 @@ TEST(ReshardConfigTest, Contract3DSharding) {
439439
}
440440
}
441441

442+
TEST(ReshardConfigTest, Bad3DReshardingWithPartialReplication) {
443+
Sharding src{
444+
.dims = {{
445+
.size = 8,
446+
.sharding = 4,
447+
.permutation = 0,
448+
},
449+
{.size = 128, .sharding = 1, .permutation = 1},
450+
{.size = 1, .sharding = 1, .permutation = 2},
451+
{.size = 1, .sharding = 2, .permutation = 3}},
452+
.devices = {{.start = 0, .num_devices = 8}},
453+
};
454+
455+
Sharding dst{
456+
.dims = {{
457+
.size = 8,
458+
.sharding = 2,
459+
.permutation = 0,
460+
},
461+
{.size = 128, .sharding = 1, .permutation = 1},
462+
{.size = 1, .sharding = 1, .permutation = 2},
463+
{.size = 1, .sharding = 2, .permutation = 3}},
464+
.devices = {{.start = 0, .num_devices = 4}},
465+
};
466+
467+
for (int i = 0; i < 8; ++i) {
468+
Processor p =
469+
Processor::Create({.local = i, .global = i}, Processor::Type::TEST);
470+
auto config = DetermineShardingConfig(p, src, dst);
471+
ASSERT_TRUE(std::holds_alternative<ScatterGatherConfig>(config));
472+
const ScatterGatherConfig& sg_config =
473+
std::get<ScatterGatherConfig>(config);
474+
475+
if (i < 4) {
476+
auto tile_bounds = ComputeTileBounds(i, dst);
477+
int64_t tile_size = ComputeRealmShapeSize(tile_bounds);
478+
std::vector<int> covered(tile_size, 0);
479+
// make sure the dst is covered
480+
for (const auto& slice : sg_config.to_target) {
481+
zuku::Iterate(covered.data(), slice.bounds, tile_bounds,
482+
[](int& data, auto... indices) { data = 1; });
483+
}
484+
EXPECT_THAT(covered, Each(Eq(1)));
485+
}
486+
}
487+
}
488+
442489
} // namespace
443490
} // namespace zuku
444491

tests/unit/reshard_scatter_gather_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ TEST(ReshardScatterGather, GatherToSingleGPU) {
3535
ShardingDim{.size = 2048, .sharding = 1, .permutation = 1}},
3636
.devices = src_mesh};
3737

38-
auto sg_config = GetScatterGatherConfig<2, int64_t>(
38+
auto sg_config = GetScatterGatherConfig(
3939
/*dim=*/0, /*src_shard=*/1, /*dst_shard=*/1, source_sharding,
4040
dest_sharding);
4141

0 commit comments

Comments
 (0)