-
Notifications
You must be signed in to change notification settings - Fork 28
/
extract_local.py
58 lines (40 loc) · 1.8 KB
/
extract_local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.distributed as dist
from yunchang.globals import PROCESS_GROUP
def stripe_extract_local(value, rank, world_size, rd, ud, *args, **kwargs):
    """Return the ``rank``-th striped shard of ``value`` along dim 1.

    The sequence axis (dim 1) is viewed as ``(seqlen // rd, rd)``, the two
    axes are swapped and flattened back, so positions that share the same
    index modulo ``rd`` become adjacent. The regrouped sequence is then cut
    into ``world_size`` chunks and the chunk at index ``rank`` is returned.

    Args:
        value: tensor with at least two dims, read as (batch, seqlen, ...).
        rank: index of the chunk to return.
        world_size: number of chunks the sequence axis is split into.
        rd: expected size of ``PROCESS_GROUP.RING_PG``.
        ud: expected size of ``PROCESS_GROUP.ULYSSES_PG``.
        *args, **kwargs: accepted and ignored.

    Returns:
        Tensor of shape ``[batch, seqlen // world_size, *trailing dims]``.

    Raises:
        AssertionError: if ``value`` has fewer than two dims, or ``rd`` /
            ``ud`` do not match the sizes of the ring / Ulysses groups.
    """
    # Original author's note: "ud at the highest dim".
    assert value.dim() >= 2
    batch_size, seqlen = value.shape[:2]
    trailing_dims = list(value.shape[2:])

    ring_size = dist.get_world_size(group=PROCESS_GROUP.RING_PG)
    assert ring_size == rd
    ulysses_size = dist.get_world_size(group=PROCESS_GROUP.ULYSSES_PG)
    assert ulysses_size == ud

    # (batch, seqlen // rd, rd, features): position p sits at (p // rd, p % rd).
    grouped = value.reshape(batch_size, seqlen // rd, rd, -1).contiguous()
    # Swap the two sequence sub-axes and flatten, so equal ``p % rd`` are adjacent.
    striped = grouped.transpose(1, 2).reshape(batch_size, seqlen, -1).contiguous()

    local_shard = striped.chunk(world_size, dim=1)[rank]
    return local_shard.reshape([batch_size, seqlen // world_size] + trailing_dims)
def basic_extract_local(value, rank, world_size, *args, **kwargs):
    """Return a detached copy of the ``rank``-th of ``world_size`` chunks of ``value`` along dim 1.

    Extra positional and keyword arguments are accepted and ignored.
    """
    shards = value.chunk(world_size, dim=1)
    local_shard = shards[rank]
    # detach() cuts the autograd link; clone() gives the result its own storage.
    return local_shard.detach().clone()
def zigzag_extract_local(value, rank, world_size, rd, ud, dim=1, *args, **kwargs):
    """Return this process's zigzag shard of ``value`` along ``dim``.

    ``value`` is split into ``2 * rd`` chunks along ``dim``. The chunk at the
    ring rank's index and its mirror counted from the end
    (``2 * rd - r_rank - 1``) are concatenated, the result is split into
    ``ud`` pieces, and the piece at the Ulysses rank's index is returned.

    Args:
        value: tensor with at least two dims; with the default ``dim=1`` it
            is read as (bs, seqlen, ...).
        rank: not used here; the ring and Ulysses ranks are queried from
            ``PROCESS_GROUP.RING_PG`` / ``PROCESS_GROUP.ULYSSES_PG``.
        world_size: divisor applied to the size of ``dim`` for the output
            shape (presumably ``rd * ud`` -- confirm against callers).
        rd: expected size of the ring process group.
        ud: expected size of the Ulysses process group.
        dim: axis along which the tensor is sharded.
        *args, **kwargs: accepted and ignored.

    Returns:
        A contiguous tensor shaped like ``value`` except that the size of
        ``dim`` is ``value.shape[dim] // world_size``.

    Raises:
        AssertionError: if ``value`` has fewer than two dims, or ``rd`` /
            ``ud`` do not match the sizes of the ring / Ulysses groups.
    """
    assert value.dim() >= 2

    value_chunks = value.chunk(2 * rd, dim=dim)
    r_rank = dist.get_rank(group=PROCESS_GROUP.RING_PG)
    u_rank = dist.get_rank(group=PROCESS_GROUP.ULYSSES_PG)
    assert dist.get_world_size(group=PROCESS_GROUP.RING_PG) == rd
    assert dist.get_world_size(group=PROCESS_GROUP.ULYSSES_PG) == ud

    # Pair chunk r_rank with its mirror from the end, then take the
    # Ulysses-rank slice of that pair.
    local_value = torch.cat(
        [value_chunks[r_rank], value_chunks[2 * rd - r_rank - 1]], dim=dim
    ).chunk(ud, dim=dim)[u_rank]

    # Shrink only the sharded axis. The previous code always divided axis 1,
    # which for dim != 1 could reshape into a wrong layout without raising.
    new_shape = list(value.shape)
    new_shape[dim] = value.shape[dim] // world_size
    return local_value.reshape(new_shape).contiguous()
# Lookup table from a strategy name to the function that extracts the local
# shard for it. "basic" and "basic_pytorch" share one implementation.
# NOTE(review): the stripe strategy is registered under the key "strip";
# the key is kept unchanged because callers may look it up by this string.
EXTRACT_FUNC_DICT = dict(
    basic=basic_extract_local,
    strip=stripe_extract_local,
    zigzag=zigzag_extract_local,
    basic_pytorch=basic_extract_local,
)