fix overlap
zhangyuqin1998 committed Mar 9, 2025
1 parent 00b572e commit c38b4b4
Showing 3 changed files with 118 additions and 16 deletions.
108 changes: 96 additions & 12 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -826,6 +826,10 @@ def __init__(self, config: DeepseekV2Config):

    def forward(self, hidden_states):
        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
+        final_hidden_states = self.auxilibaryloss_and_shared_expert_compute(hidden_states, final_hidden_states, l_aux)
+        return final_hidden_states
+
+    def auxilibaryloss_and_shared_expert_compute(self, hidden_states, final_hidden_states, l_aux):
        if self.training and self.alpha > 0.0:
            l_aux = l_aux * self.alpha
            final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux)
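For reference, AddAuxiliaryLoss.apply (defined elsewhere in modeling.py) is the usual straight-through trick for attaching the router load-balancing loss: forward returns the hidden states unchanged, and backward injects a unit gradient for the loss term. A rough sketch of that pattern using Paddle's PyLayer; the class below is illustrative and may differ from the file's actual implementation:

import paddle

# Illustrative sketch of the straight-through auxiliary-loss pattern; not the
# exact class used in modeling.py.
class AddAuxiliaryLossSketch(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, loss):
        # Pass activations through untouched; remember whether the loss needs a gradient.
        ctx.loss_dtype = loss.dtype
        ctx.needs_loss_grad = not loss.stop_gradient
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient w.r.t. x is unchanged; the auxiliary loss receives a unit gradient.
        grad_loss = paddle.ones([1], dtype=ctx.loss_dtype) if ctx.needs_loss_grad else None
        return grad_output, grad_loss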
@@ -1145,6 +1149,48 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
        self.input_layernorm = DeepseekV2RMSNorm(config)
        self.post_attention_layernorm = DeepseekV2RMSNorm(config)

+    def self_attn_and_gate_compute(
+        self,
+        hidden_states: paddle.Tensor,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
+        use_cache: Optional[bool] = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
+        **kwargs,
+    ):
+        hidden_states, residual = self.self_attn_compute(
+            hidden_states,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            **kwargs,
+        )
+        probs, routing_map, l_aux, l_zloss = self.mlp.gate_compute(hidden_states)
+        return probs, routing_map, l_aux, l_zloss
+
+    def auxilibaryloss_and_shared_expert_compute(self, residual, hidden_states, expert_output, l_aux):
+        hidden_states = self.mlp.auxilibaryloss_and_shared_expert_compute(hidden_states, expert_output, l_aux)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+    def post_process_output(self, hidden_states, output_attentions, use_cache, self_attn_weights, present_key_value):
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        if type(outputs) is tuple and len(outputs) == 1:
+            outputs = outputs[0]
+
+        return outputs

    def forward(
        self,
        hidden_states: paddle.Tensor,
@@ -1170,10 +1216,6 @@ def forward(
                (see `past_key_values`).
            past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
        """
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)
@@ -1216,18 +1258,60 @@ def forward(
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        if type(outputs) is tuple and len(outputs) == 1:
-            outputs = outputs[0]
-
-        return outputs
+        return self.post_process_output(
+            hidden_states, output_attentions, use_cache, self_attn_weights, present_key_value
+        )
+
+    def self_attn_compute(
+        self,
+        hidden_states: paddle.Tensor,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
+        use_cache: Optional[bool] = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
+        **kwargs
+    ):
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        has_gradient = not hidden_states.stop_gradient
+        if (
+            self.enable_recompute
+            and self.layerwise_recompute
+            and has_gradient
+            and self.recompute_granularity == "full_attn"
+        ):
+            hidden_states, self_attn_weights, present_key_value = recompute(
+                self.self_attn,
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                past_key_value=past_key_value,
+                use_cache=use_cache,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                **kwargs,
+            )
+        else:
+            hidden_states, self_attn_weights, present_key_value = self.self_attn(
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                past_key_value=past_key_value,
+                use_cache=use_cache,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                **kwargs,
+            )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        return hidden_states, residual



class DeepseekV2MTPLayer(DeepseekV2DecoderLayer):
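Taken together, the new decoder-layer methods break the original forward into attention, routing, dispatch, expert compute, combine, and post-processing stages, which is what an overlapped schedule needs in order to interleave the communication-heavy steps with other work. A rough sketch of how the pieces chain end to end; the run_layer_stepwise driver and its defaults are hypothetical and assume the layer's mlp is a DeepseekV2MoE with the split methods shown above:

# Hypothetical driver, not part of the commit: chains the split steps so the
# result matches the layer's ordinary forward pass.
def run_layer_stepwise(layer, hidden_states, attention_mask=None, position_ids=None):
    # attention block plus post-attention layernorm; returns normed states and the residual
    hidden_states, residual = layer.self_attn_compute(
        hidden_states,
        position_ids=position_ids,
        attention_mask=attention_mask,
    )
    # router probabilities and auxiliary losses
    probs, routing_map, l_aux, l_zloss = layer.mlp.gate_compute(hidden_states)
    # all-to-all dispatch (communication a scheduler can overlap with other compute)
    dispatched, tokens_per_expert = layer.mlp.dispatch_comm(hidden_states, probs, routing_map)
    # per-expert MLP compute
    expert_output = layer.mlp.mlp_compute(dispatched, tokens_per_expert)
    # all-to-all combine (communication)
    combined = layer.mlp.combine_comm(expert_output)
    # auxiliary loss, shared experts, and the residual add
    hidden_states = layer.auxilibaryloss_and_shared_expert_compute(residual, hidden_states, combined, l_aux)
    # output_attentions=False, use_cache=False, so this just unwraps the tuple
    return layer.post_process_output(hidden_states, False, False, None, None)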
5 changes: 4 additions & 1 deletion paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -522,7 +522,10 @@ def overlapped_forward_backward(
        output_grads1,
        scaler,
    ):
-        outputs0 = module0(inputs0)
+        outputs0 = inputs0
+        for layer in module0:
+            outputs0 = layer(outputs0)

+        outputs0 = [outputs0] if isinstance(outputs0, paddle.Tensor) else outputs0


        if labels0 is not None:
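Replacing the single module0(inputs0) call with an explicit loop gives the overlapped forward/backward path per-layer granularity: the chunk's layers are stepped one at a time, so a scheduler can slot backward work for another micro-batch between them instead of waiting for the whole chunk to finish. A minimal sketch of that pattern; the function and the bwd_closures hook are hypothetical, not this file's API:

# Hypothetical illustration of per-layer interleaving; fwd_layers is an iterable
# of layers (as module0 is above) and bwd_closures supplies pending backward work.
def interleaved_schedule(fwd_layers, x, bwd_closures):
    for layer, run_backward in zip(fwd_layers, bwd_closures):
        x = layer(x)        # forward for micro-batch A on this layer
        run_backward()      # backward for micro-batch B overlapped in between
    return x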
21 changes: 18 additions & 3 deletions paddlenlp/transformers/moe_layer.py
@@ -311,12 +311,27 @@ def expert_forward(self, dispatched_input, tokens_per_expert):
        return paddle.concat(outputs, axis=0)

    def forward(self, hidden_states: paddle.Tensor):
+        probs, routing_map, l_aux, l_zloss = self.gate_compute(hidden_states)
+        dispatched_input, tokens_per_expert = self.dispatch_comm(hidden_states, probs, routing_map)
+        expert_output = self.mlp_compute(dispatched_input, tokens_per_expert)
+        output = self.combine_comm(expert_output)
+        return output, l_aux, l_zloss
+
+    def gate_compute(self, hidden_states):
        _, _, d_model = hidden_states.shape
        # reshaped_input = hidden_states.reshape([-1, d_model])
        probs, routing_map, l_aux, l_zloss = self.router(hidden_states)
-        (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
+        return probs, routing_map, l_aux, l_zloss
+
+    def dispatch_comm(self, hidden_states, probs, routing_map):
+        dispatched_input, tokens_per_expert = self.token_dispatcher.token_permutation(
            hidden_states, probs, routing_map
        )
-        expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
+        return dispatched_input, tokens_per_expert
+
+    def mlp_compute(self, dispatched_input, tokens_per_expert):
+        return self.expert_forward(dispatched_input, tokens_per_expert)
+
+    def combine_comm(self, expert_output):
        output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)
-        return output, l_aux, l_zloss
+        return output

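The refactored forward is intended to be behavior-preserving: gate_compute, dispatch_comm, mlp_compute, and combine_comm chained together should reproduce the previous fused path, while exposing the two all-to-all phases as separate calls that an overlap schedule can place around other compute. A hedged smoke test of that equivalence; the moe_layer instance, the input shape, and a deterministic router are assumptions:

import paddle

# Hypothetical check, not part of the commit: the split steps should match the
# fused forward when the router is deterministic (e.g. in eval mode).
def check_split_matches_forward(moe_layer, hidden_states):
    out_ref, l_aux_ref, l_zloss_ref = moe_layer.forward(hidden_states)

    probs, routing_map, l_aux, l_zloss = moe_layer.gate_compute(hidden_states)
    dispatched, tokens_per_expert = moe_layer.dispatch_comm(hidden_states, probs, routing_map)
    expert_output = moe_layer.mlp_compute(dispatched, tokens_per_expert)
    out = moe_layer.combine_comm(expert_output)

    assert paddle.allclose(out, out_ref).item()
    return out, l_aux, l_zloss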
