From c21eb80c9dc369ce6185bb6f243bde6b59c9bacb Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Sat, 6 Sep 2025 16:29:38 +0800 Subject: [PATCH 1/8] [Feature] support rl_tp_degree --- fastdeploy/model_executor/layers/embeddings.py | 1 + fastdeploy/model_executor/layers/linear.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 377ff19bb5..5c6a6d1418 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -77,6 +77,7 @@ def __init__( ) if self.world_size > 1: set_weight_attrs(self.embeddings.weight, {"output_dim": False}) + self.embeddings.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size else: # column cut embedding self.embeddings = nn.Embedding( diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 2c7f9aef34..1efc2afbc8 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -243,6 +243,9 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: return linear_out + def set_rl_tp_depree(self, tp_degree): + self.weight.rl_tp_degree = tp_degree + class ReplicatedLinear(LinearBase): """ @@ -361,6 +364,9 @@ def __init__( _set_var_distributed(self.bias, split_axis=1) set_weight_attrs(self.bias, {"output_dim": True}) + # set_rl_tp_depree + self.set_rl_tp_depree(fd_config.parallel_config.tensor_parallel_size) + class MergedColumnParallelLinear(ColumnParallelLinear): """ @@ -755,6 +761,9 @@ def __init__( self.reduce_results = reduce_results + # set_rl_tp_depree + self.set_rl_tp_depree(fd_config.parallel_config.tensor_parallel_size) + def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: if self.fd_config.quant_config: out = self.quant_method.apply(self, x) From df7398fe7aeb4366eddbe04bdb53cc08489fe351 Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Sat, 6 Sep 2025 21:40:29 +0800 Subject: [PATCH 2/8] add rl_tp_degree in lmhead --- fastdeploy/model_executor/layers/lm_head.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index a62e46d61d..771589a0ca 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -116,6 +116,7 @@ def __init__( if self.nranks > 1: set_weight_attrs(self.linear.weight, {"output_dim": False}) + self.linear.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]): """ From 6496ba787863fe70fd9a2932455867ae9db060f3 Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Sat, 6 Sep 2025 21:55:03 +0800 Subject: [PATCH 3/8] add rl_tp_degree in bias --- fastdeploy/model_executor/layers/linear.py | 12 +++++------- fastdeploy/model_executor/layers/lm_head.py | 1 + 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 1efc2afbc8..ba0a70b212 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -243,9 +243,6 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: return linear_out - def set_rl_tp_depree(self, tp_degree): - self.weight.rl_tp_degree = tp_degree - class ReplicatedLinear(LinearBase): """ @@ -364,8 +361,9 @@ def __init__( 
_set_var_distributed(self.bias, split_axis=1) set_weight_attrs(self.bias, {"output_dim": True}) - # set_rl_tp_depree - self.set_rl_tp_depree(fd_config.parallel_config.tensor_parallel_size) + # set_rl_tp_degree + self.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + self.bias.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size class MergedColumnParallelLinear(ColumnParallelLinear): @@ -761,8 +759,8 @@ def __init__( self.reduce_results = reduce_results - # set_rl_tp_depree - self.set_rl_tp_depree(fd_config.parallel_config.tensor_parallel_size) + # set_rl_tp_degree + self.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: if self.fd_config.quant_config: diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index 771589a0ca..2e9748e8b7 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -94,6 +94,7 @@ def __init__( "model_format": self.fd_config.model_config.model_format, }, ) + self.linear.bias.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size if self.nranks > 1: set_weight_attrs(self.linear.weight, {"output_dim": True}) else: From 83b6b14b3bcb7ae288e4978b54f889a44e572de2 Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Sat, 6 Sep 2025 22:38:51 +0800 Subject: [PATCH 4/8] fix split_axis=0 in bias --- fastdeploy/model_executor/layers/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index ba0a70b212..93ec4f950f 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -358,7 +358,7 @@ def __init__( if self.nranks > 0: if self.with_bias: # col parallel - _set_var_distributed(self.bias, split_axis=1) + _set_var_distributed(self.bias, split_axis=0) set_weight_attrs(self.bias, {"output_dim": True}) # set_rl_tp_degree From e257b64cba4ddcafd7bbed19187102df9d63ef90 Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Sat, 6 Sep 2025 22:43:06 +0800 Subject: [PATCH 5/8] fix split_axis in weight --- fastdeploy/model_executor/layers/linear.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 93ec4f950f..c000c54146 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -356,6 +356,7 @@ def __init__( ) if self.nranks > 0: + _set_var_distributed(self.weight, split_axis=-1) if self.with_bias: # col parallel _set_var_distributed(self.bias, split_axis=0) @@ -747,6 +748,7 @@ def __init__( model_format=fd_config.model_config.model_format, ) if self.nranks > 0: + _set_var_distributed(self.weight, split_axis=0) if self.with_bias: # col parallel _set_var_distributed(self.bias, split_axis=0) From 1dfeeedbfba6e8981d5a25015fcb2c116c1315ea Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Sun, 7 Sep 2025 15:13:14 +0800 Subject: [PATCH 6/8] fix bias rl_tp_degree --- fastdeploy/model_executor/layers/linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index c000c54146..16e53633fc 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -364,7 +364,8 @@ def __init__( # 
set_rl_tp_degree self.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size - self.bias.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + if self.with_bias: + self.bias.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size class MergedColumnParallelLinear(ColumnParallelLinear): From 4663a22814ffc66927aa5c3006180656e9cb584b Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Sun, 7 Sep 2025 16:39:58 +0800 Subject: [PATCH 7/8] fix bias rl_tp_degree --- fastdeploy/model_executor/layers/lm_head.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index 2e9748e8b7..6402427ccf 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -94,7 +94,8 @@ def __init__( "model_format": self.fd_config.model_config.model_format, }, ) - self.linear.bias.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + if self.bias_key is not None: + self.linear.bias.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size if self.nranks > 1: set_weight_attrs(self.linear.weight, {"output_dim": True}) else: From 51b34847efd2cc20fc1c05f517a240ce7978518e Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Sun, 7 Sep 2025 22:05:51 +0800 Subject: [PATCH 8/8] change attr to dict --- fastdeploy/model_executor/layers/embeddings.py | 6 +++++- fastdeploy/model_executor/layers/linear.py | 12 +++++++++--- fastdeploy/model_executor/layers/lm_head.py | 10 ++++++++-- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 5c6a6d1418..0ac2d0d70a 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -77,7 +77,11 @@ def __init__( ) if self.world_size > 1: set_weight_attrs(self.embeddings.weight, {"output_dim": False}) - self.embeddings.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + set_weight_attrs( + self.embeddings.weight, + {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}, + ) + else: # column cut embedding self.embeddings = nn.Embedding( diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 16e53633fc..0d079c90ce 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -363,9 +363,13 @@ def __init__( set_weight_attrs(self.bias, {"output_dim": True}) # set_rl_tp_degree - self.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + set_weight_attrs( + self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}} + ) if self.with_bias: - self.bias.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + set_weight_attrs( + self.bias, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}} + ) class MergedColumnParallelLinear(ColumnParallelLinear): @@ -763,7 +767,9 @@ def __init__( self.reduce_results = reduce_results # set_rl_tp_degree - self.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + set_weight_attrs( + self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}} + ) def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: if self.fd_config.quant_config: diff --git a/fastdeploy/model_executor/layers/lm_head.py 
b/fastdeploy/model_executor/layers/lm_head.py index 6402427ccf..b9dc06ab01 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -95,7 +95,11 @@ def __init__( }, ) if self.bias_key is not None: - self.linear.bias.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + set_weight_attrs( + self.linear.bias, + {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}, + ) + if self.nranks > 1: set_weight_attrs(self.linear.weight, {"output_dim": True}) else: @@ -118,7 +122,9 @@ def __init__( if self.nranks > 1: set_weight_attrs(self.linear.weight, {"output_dim": False}) - self.linear.weight.rl_tp_degree = fd_config.parallel_config.tensor_parallel_size + set_weight_attrs( + self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}} + ) def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]): """
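
Note on the end state of this series: after PATCH 8/8, every tensor-parallel
parameter carries its TP degree in a dict attribute,
weight.rl_need_attr["rl_tp_degree"], attached via set_weight_attrs, alongside
the existing "output_dim" flag and the split_axis registered through
_set_var_distributed. Below is a minimal sketch of how an RL-side consumer
could use these attributes to merge per-rank shards back into a full weight.
Only the attribute names (rl_need_attr, rl_tp_degree, output_dim) come from
this series; the helper name gather_full_weight, its inputs, and the merge
logic are hypothetical illustration, not FastDeploy's actual sync code:

    import paddle

    def gather_full_weight(shards, weight):
        """Merge per-rank shards of a TP parameter (hypothetical helper).

        `shards` is a list holding each rank's local shard of the parameter;
        `weight` is the local parameter carrying the attributes set above.
        """
        rl_need_attr = getattr(weight, "rl_need_attr", None)
        if rl_need_attr is None:
            # Replicated parameter: every rank already holds the full tensor.
            return shards[0]
        tp_degree = rl_need_attr["rl_tp_degree"]
        assert len(shards) == tp_degree, "expected one shard per TP rank"
        # output_dim=True -> column parallel: shards concatenate along the
        # last axis (split_axis=-1 for the weight); output_dim=False -> row
        # parallel: shards concatenate along axis 0 (split_axis=0), matching
        # the split_axis fixes in PATCH 4/8 and PATCH 5/8. For 1-D biases
        # the two cases coincide, since axis -1 is axis 0.
        axis = -1 if getattr(weight, "output_dim", True) else 0
        return paddle.concat(shards, axis=axis)

Keeping the TP degree inside a single rl_need_attr dict (PATCH 8/8) rather
than as a bare attribute on the tensor means further RL-side metadata can be
added later without growing the parameter's attribute surface.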