Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Apr 6, 2024
1 parent 91e4455 commit 5c47556
Show file tree
Hide file tree
Showing 20 changed files with 839 additions and 441 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -880,10 +880,13 @@ bool Transpose_OpInferSymbolicShape(

bool SqueezeOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
IR_ENFORCE(op->num_operands() == 2,
"SqueezeOpInferSymbolicShape ONLY support num_operands() == 2 "
"now, but got %d operands",
op->num_operands());
PADDLE_ENFORCE_EQ(
op->num_operands(),
2,
phi::errors::InvalidArgument(
"SqueezeOpInferSymbolicShape ONLY support num_operands() == 2 "
"now, but got %d operands",
op->num_operands()));

auto x_shape_or_data =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
Expand Down Expand Up @@ -993,10 +996,13 @@ bool UniqueConsecutiveOpInferSymbolicShape(

bool UnsqueezeOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
IR_ENFORCE(op->num_operands() == 2,
"UnsqueezeOp InferSymbolicShape ONLY support num_operands() == 2 "
"now, but got %d operands",
op->num_operands());
PADDLE_ENFORCE_EQ(
op->num_operands(),
2,
phi::errors::InvalidArgument(
"UnsqueezeOp InferSymbolicShape ONLY support num_operands() == 2 "
"now, but got %d operands",
op->num_operands()));

auto x_shape_or_data =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
Expand Down
16 changes: 12 additions & 4 deletions paddle/fluid/pir/dialect/operator/ir/api_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@ namespace dialect {
ApiBuilder::ApiBuilder()
: ctx_(pir::IrContext::Instance()),
builder_(std::make_shared<pir::Builder>(ctx_)) {
IR_ENFORCE(builder_ != nullptr, "api builder construct error!");
PADDLE_ENFORCE_NE(
builder_,
nullptr,
phi::errors::InvalidArgument("api builder construct error!"));
}

void ApiBuilder::SetProgram(pir::Program* program) {
IR_ENFORCE(program != nullptr, "argument of program is nullptr");
PADDLE_ENFORCE_NE(
program,
nullptr,
phi::errors::InvalidArgument("argument of program is nullptr"));
builder_->SetInsertionPointToBlockEnd(program->block());
}

Expand All @@ -50,8 +56,10 @@ void ApiBuilder::SetParameter(const std::string& name,
}

void ApiBuilder::LoadInsertionPoint() {
IR_ENFORCE(!insertion_point_stack_.empty(),
"insertion_point_stack_ is empty.");
PADDLE_ENFORCE_EQ(
!insertion_point_stack_.empty(),
true,
phi::errors::InvalidArgument("insertion_point_stack_ is empty."));
builder_->set_insertion_point(insertion_point_stack_.top());
insertion_point_stack_.pop();
}
Expand Down
Loading

0 comments on commit 5c47556

Please sign in to comment.