[Feature] Expert parallelism support #1435

Open
2 tasks done
chongli-uw opened this issue Sep 16, 2024 · 7 comments

chongli-uw commented Sep 16, 2024

Checklist

Motivation

Hi team,
First of all, thanks so much for such a great project. I am wondering whether there is a plan to support expert parallelism (EP) for MoE models.

Related resources

https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html

@Ying1123 Ying1123 added the enhancement New feature or request label Sep 16, 2024
@merrymercy
Contributor

class MixtralMoE(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}."
            )
        # Split experts equally between ranks
        self.expert_indicies = np.array_split(
            range(self.num_total_experts), self.tp_size
        )[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
        self.experts = nn.ModuleList(
            [
                (
                    MixtralMLP(
                        self.num_total_experts,
                        config.hidden_size,
                        config.intermediate_size,
                        quant_config=quant_config,
                    )
                    if idx in self.expert_indicies
                    else None
                )
                for idx in range(self.num_total_experts)
            ]
        )
        self.gate = ReplicatedLinear(
            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        final_hidden_states = None
        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]
            expert_mask = selected_experts == expert_idx
            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
            if final_hidden_states is None:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)
        return tensor_model_parallel_all_reduce(final_hidden_states)

This is an early example.
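To make the idea in the snippet above concrete, here is a minimal pure-Python sketch, not SGLang code. It assumes a single token with a scalar hidden state, and `toy_expert` is a hypothetical stand-in for an expert MLP. `split_experts` reproduces the contiguous partition that `np.array_split` yields, `route` follows the softmax, top-k, renormalize steps, and a plain `sum` over per-rank partial outputs stands in for the all-reduce.

```python
import math


def split_experts(num_experts, world_size):
    """Contiguous split in the style of np.array_split:
    the first (num_experts % world_size) ranks get one extra expert."""
    base, extra = divmod(num_experts, world_size)
    assignment, start = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)
        assignment.append(list(range(start, start + size)))
        start += size
    return assignment


def route(logits, top_k):
    """Softmax over experts, keep the top_k, renormalize kept weights to sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in top)
    return {i: probs[i] / norm for i in top}


def toy_expert(idx, x):
    # Hypothetical stand-in for an expert MLP: a fixed scalar transform.
    return (idx + 1) * x


def rank_partial(local_experts, weights, x):
    # A rank evaluates only the experts it owns; unselected experts get weight 0.
    return sum(weights.get(e, 0.0) * toy_expert(e, x) for e in local_experts)


def moe_forward(logits, x, num_experts, world_size, top_k):
    weights = route(logits, top_k)
    partials = [
        rank_partial(owned, weights, x)
        for owned in split_experts(num_experts, world_size)
    ]
    return sum(partials)  # stands in for an all-reduce (sum) across ranks
```

Under these assumptions, `moe_forward` returns the same value (up to floating-point rounding) for any `world_size` that does not exceed `num_experts`, which is the property the all-reduce in the snippet relies on.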

@liangzelang

liangzelang commented Nov 14, 2024

@merrymercy Hi, has any progress been made on this issue? The example you provided earlier uses per-expert MLP modules rather than FusedMOE. How can we enable expert parallelism for the current Mixtral/DeepSeek-v2 implementations now that they use FusedMOE? Do you have an updated example?
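For context on the question above, one generic technique when a fused kernel holds only a rank's local experts is to remap the router's global expert ids to local ids and mark the rest as not owned. The sketch below illustrates only that remapping. It is an assumption for illustration, not SGLang's or FusedMOE's actual API: the function name `local_expert_ids`, the `-1` sentinel, and the requirement that experts divide evenly across ranks are all hypothetical.

```python
def local_expert_ids(selected, rank, world_size, num_experts):
    """Map global expert ids to ids local to `rank`.

    `selected` is a list of per-token lists of global expert ids.
    Experts owned by other ranks are mapped to -1 (hypothetical sentinel),
    so a local kernel could skip them.
    """
    if num_experts % world_size != 0:
        raise ValueError("this sketch assumes experts divide evenly across ranks")
    per_rank = num_experts // world_size
    start = rank * per_rank  # first global expert id owned by this rank
    return [
        [e - start if start <= e < start + per_rank else -1 for e in row]
        for row in selected
    ]
```

For example, with 8 experts on 2 ranks, a token routed to global experts 0 and 5 would be handled as local expert 0 on rank 0 and local expert 1 on rank 1, with the other slot masked on each rank.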

@merrymercy
Contributor

Related: #1970

@liangzelang

@merrymercy As far as I can tell, #1970 is mainly about TP and DP. I also noticed that the SGLang Q4 roadmap (#1487) mentions supporting expert parallelism.

@zhyncs
Member

zhyncs commented Nov 17, 2024

@liangzelang DP has already been merged (only for DeepSeek right now) in #1970, and EP will be supported soon. cc @ispobock

@xiaobochen123
Contributor

xiaobochen123 commented Nov 21, 2024

@zhyncs Is there any support for MoE-EP yet? I have implemented MoE-EP.

@ispobock
Collaborator

@xiaobochen123 We plan to implement it with a DP + EP approach for throughput gains. DP attention is already implemented. Before we start on EP, the MoE codebase needs some updates.

I am interested in what kind of MoE-EP you implemented and which codebase you used. How large are the performance gains compared to TP?
