@@ -1703,31 +1703,28 @@ SplitedResult SplitForwardBackward(
1703
1703
auto &backward_value_map = backward_mapper.GetMutableMap <pir::Value>();
1704
1704
int counter = forward_outputs.size ();
1705
1705
1706
- auto create_output_fn_forward = [&ctx,
1707
- &forward_value_map,
1708
- &counter,
1709
- &forward_program,
1710
- &forward_inputs,
1711
- &forward_params](const pir::Value &v) {
1712
- if (v.impl () == nullptr ) {
1713
- return ;
1714
- }
1715
- // Skip the value that already in forward_params.
1716
- if (std::find (forward_params.begin (), forward_params.end (), v) !=
1717
- forward_params.end ()) {
1718
- return ;
1719
- }
1720
- std::string shadow_output_name =
1721
- std::string (" output_" ) + std::to_string (counter);
1722
- auto op_info = ctx->GetRegisteredOpInfo (pir::ShadowOutputOp::name ());
1723
- pir::AttributeMap attribute_map = {
1724
- {" output_name" , pir::StrAttribute::get (ctx, shadow_output_name)},
1725
- };
1726
- pir::Operation *operation = pir::Operation::Create (
1727
- {forward_value_map[v]}, attribute_map, {}, op_info);
1728
- forward_program->block ()->push_back (operation);
1729
- counter += 1 ;
1730
- };
1706
+ auto create_output_fn_forward =
1707
+ [&ctx, &forward_value_map, &counter, &forward_program, &forward_params](
1708
+ const pir::Value &v) {
1709
+ if (v.impl () == nullptr ) {
1710
+ return ;
1711
+ }
1712
+ // Skip the value that already in forward_params.
1713
+ if (std::find (forward_params.begin (), forward_params.end (), v) !=
1714
+ forward_params.end ()) {
1715
+ return ;
1716
+ }
1717
+ std::string shadow_output_name =
1718
+ std::string (" output_" ) + std::to_string (counter);
1719
+ auto op_info = ctx->GetRegisteredOpInfo (pir::ShadowOutputOp::name ());
1720
+ pir::AttributeMap attribute_map = {
1721
+ {" output_name" , pir::StrAttribute::get (ctx, shadow_output_name)},
1722
+ };
1723
+ pir::Operation *operation = pir::Operation::Create (
1724
+ {forward_value_map[v]}, attribute_map, {}, op_info);
1725
+ forward_program->block ()->push_back (operation);
1726
+ counter += 1 ;
1727
+ };
1731
1728
1732
1729
auto create_output_fn_backward = [&ctx,
1733
1730
&backward_value_map,
0 commit comments