Skip to content

Commit

Permalink
fix backward
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmovic committed Mar 20, 2024
1 parent 7aca619 commit 26fda63
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
15 changes: 15 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ void AngleGradInferMeta(const MetaTensor& x,
UnchangedInferMeta(x, x_grad);
}

void BatchFCGradInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
const MetaTensor& out_grad,
MetaTensor* input_grad,
MetaTensor* w_grad,
MetaTensor* bias_grad) {
input_grad->set_dims(input.dims());
input_grad->set_dtype(input.dtype());
w_grad->set_dims(w.dims());
w_grad->set_dtype(w.dtype());
bias_grad->set_dims(bias.dims());
bias_grad->set_dtype(bias.dtype());
}

void BilinearGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ void AngleGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
MetaTensor* x_grad);

void BatchFCGradInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
const MetaTensor& out_grad,
MetaTensor* intput_grad,
MetaTensor* w_grad,
MetaTensor* bias_grad);

void BilinearGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
Expand Down
15 changes: 0 additions & 15 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,21 +187,6 @@ void BatchFCInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}

void BatchFCGradInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
const MetaTensor& out_grad,
MetaTensor* input_grad,
MetaTensor* w_grad,
MetaTensor* bias_grad) {
input_grad->set_dims(input.dims());
input_grad->set_dtype(input.dtype());
w_grad->set_dims(w.dims());
w_grad->set_dtype(w.dtype());
bias_grad->set_dims(bias.dims());
bias_grad->set_dtype(bias.dtype());
}

void BoxCoderInferMeta(const MetaTensor& prior_box,
const MetaTensor& prior_box_var,
const MetaTensor& target_box,
Expand Down
8 changes: 0 additions & 8 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ void BatchFCInferMeta(const MetaTensor& input,
const MetaTensor& bias,
MetaTensor* out);

void BatchFCGradInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
const MetaTensor& out_grad,
MetaTensor* intput_grad,
MetaTensor* w_grad,
MetaTensor* bias_grad);

void BoxCoderInferMeta(const MetaTensor& prior_box,
const MetaTensor& prior_box_var,
const MetaTensor& target_box,
Expand Down

0 comments on commit 26fda63

Please sign in to comment.