diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index d6e4d4ba01368f..7b04678b05113a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -1298,12 +1298,82 @@ bool FusedMultiTransformerOpInferSymbolicShape( // return true; // } -// bool GraphKhopSamplerOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } +bool GraphKhopSamplerOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const symbol::ShapeOrDataDimExprs &row_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const symbol::ShapeOrDataDimExprs &col_ptr_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const symbol::ShapeOrDataDimExprs &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const symbol::ShapeOrDataDimExprs &eids_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(3)); + + auto row_shape = row_shape_or_data.shape(); + auto col_ptr_shape = col_ptr_shape_or_data.shape(); + auto x_shape = x_shape_or_data.shape(); + auto eids_shape = eids_shape_or_data.shape(); + + auto GKSShapeCheck = [&](const std::vector &shape, + const std::string &tensor_name) { + if (shape.size() == 2) + infer_context->AddEqualCstr(shape[1], symbol::DimExpr(1)); + else + PADDLE_ENFORCE_EQ( + shape.size(), + 1, + common::errors::InvalidArgument( + "The %s should be 1D, when it is not 2D, but we get %d", + tensor_name, + shape.size())); + }; + + GKSShapeCheck(row_shape, "row"); + GKSShapeCheck(col_ptr_shape, "col_ptr"); + GKSShapeCheck(x_shape, "x"); + + std::vector sample_sizes = + paddle::dialect::details::GetVectorAttr(op, "sample_sizes"); + PADDLE_ENFORCE_EQ( + !sample_sizes.empty(), + true, + common::errors::InvalidArgument( + "The parameter 'sample_sizes' in GraphSampleOp must be set. " + "But received 'sample_sizes' is empty.")); + + bool return_eids = op->attribute("return_eids").data(); + if (return_eids) { + GKSShapeCheck(eids_shape, "eids"); + symbol::DimExpr out_unknown_4 = infer_context->GetNextSymName(); + infer_context->SetShapeOrDataForValue( + op->result(4), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs({out_unknown_4})}); + } else { + infer_context->SetSymbolForValueByStaticShape(op->result(4)); + } + + symbol::DimExpr out_unknown_0_1 = infer_context->GetNextSymName(); + symbol::DimExpr out_unknown_2 = infer_context->GetNextSymName(); + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs({out_unknown_0_1, 1})}); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs({out_unknown_0_1, 1})}); + infer_context->SetShapeOrDataForValue( + op->result(2), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs({out_unknown_2})}); + infer_context->SetShapeOrDataForValue( + op->result(3), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs({x_shape[0]})}); + + return true; +} // bool GraphReindexOpInferSymbolicShape(pir::Operation *op, // pir::InferSymbolicShapeContext diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 2336b0d0abbb9f..01df952f97ae2c 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -61,7 +61,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedMultiTransformer) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(GenerateProposals) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphKhopSampler) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphReindex) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(GraphSampleNeighbors) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gru) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index fef1932c633290..175193bce9a620 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -2281,7 +2281,7 @@ func : graph_khop_sampler data_type : row optional : eids - # interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : graph_sample_neighbors args : (Tensor row, Tensor colptr, Tensor x, Tensor eids, Tensor perm_buffer, int sample_size, bool return_eids, bool flag_perm_buffer)