add vjp interface for reshard op.
winter-wang committed Apr 9, 2024
1 parent d9c5c8b commit bcebb45
Showing 6 changed files with 70 additions and 17 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.cc
@@ -54,14 +54,14 @@ pir::Value reshard(const pir::Value& x,
   TensorDistAttribute tensor_dist_attr =
       TensorDistAttribute::get(ctx, process_mesh, dims_mapping, partial_status);
 
-  auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build<ReShardOp>(
+  auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build<ReshardOp>(
       x, tensor_dist_attr);
   return reshard_op.result(0);
 }
 
 pir::Value reshard(const pir::Value& x,
                    const TensorDistAttribute& tensor_dist_attr) {
-  auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build<ReShardOp>(
+  auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build<ReshardOp>(
       x, tensor_dist_attr);
   return reshard_op.result(0);
 }
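For reference, a minimal usage sketch of the functional API above. The helper name and the destination placement are illustrative, and it assumes x is an existing distributed pir::Value with the dist dialect already registered:

    // Reshard x so dim 0 stays replicated and dim 1 is sharded on mesh axis 0.
    pir::Value ReshardToDst(pir::IrContext* ctx,
                            pir::Value x,
                            paddle::dialect::ProcessMeshAttribute mesh_attr) {
      auto dst_attr = paddle::dialect::TensorDistAttribute::get(
          ctx, mesh_attr, /*dims_mapping=*/{-1, 0}, /*partial_status=*/{});
      // Emits one dist_op.reshard at the builder's current insertion point.
      return paddle::dialect::reshard(x, dst_attr);
    }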
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
@@ -35,7 +35,7 @@ void DistDialect::initialize() {
                      TensorDistAttribute,
                      OperationDistAttribute>();
   RegisterTypes<DistDenseTensorType>();
-  RegisterOps<ShardTensorOp, ReShardOp>();
+  RegisterOps<ShardTensorOp, ReshardOp>();
 }
 
 void DistDialect::PrintType(pir::Type type, std::ostream &os) const {
53 changes: 47 additions & 6 deletions paddle/fluid/pir/dialect/distributed/ir/dist_op.cc
Expand Up @@ -15,6 +15,7 @@
#include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
@@ -28,7 +29,7 @@ namespace paddle {
 namespace dialect {
 
 const char* ShardTensorOp::attributes_name[1] = {"op_dist_attr"};
-const char* ReShardOp::attributes_name[1] = {"op_dist_attr"};
+const char* ReshardOp::attributes_name[1] = {"op_dist_attr"};
 
 void ShardTensorOp::VerifySig() {
   VLOG(4)
@@ -158,9 +159,49 @@ void ShardTensorOp::Build(pir::Builder& builder,
   argument.AddOutput(out_dist_tensor_type);
   ::pir::PassStopGradientsDefaultly(argument);
 }
-
-void ReShardOp::VerifySig() {
-  VLOG(4) << "Start Verifying inputs, outputs and attributes for: ReShardOp.";
+std::vector<std::vector<pir::Value>> ReshardOp::Vjp(
+    pir::Operation* op,
+    const std::vector<std::vector<pir::Value>>& inputs_,
+    const std::vector<std::vector<pir::Value>>& outputs,
+    const std::vector<std::vector<pir::Value>>& out_grads,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  VLOG(6) << "Start call vjp for reshard op.";
+  PADDLE_ENFORCE_EQ(
+      inputs_.size(),
+      1,
+      common::errors::InvalidArgument("reshard op's inputs' size should be 1"));
+  PADDLE_ENFORCE_EQ(inputs_[0].size(),
+                    1,
+                    common::errors::InvalidArgument(
+                        "reshard op's inputs[0]'s size should be 1"));
+  auto dist_type = inputs_[0][0].type().dyn_cast<DistTypeInterface>();
+
+  PADDLE_ENFORCE_NOT_NULL(
+      dist_type,
+      common::errors::InvalidArgument(
+          "Currently, reshard op's inputs type must be dist type."));
+
+  PADDLE_ENFORCE_EQ(out_grads.size(),
+                    1,
+                    common::errors::InvalidArgument(
+                        "reshard op's outputs grad size should be 1"));
+
+  PADDLE_ENFORCE_EQ(out_grads[0].size(),
+                    1,
+                    common::errors::InvalidArgument(
+                        "reshard op's outputs grad[0] size should be 1"));
+
+  auto& builder = *ApiBuilder::Instance().GetBuilder();
+
+  auto grad_op =
+      builder.Build<ReshardOp>(out_grads[0][0], dist_type.tensor_dist_attr());
+
+  VLOG(6) << "End call vjp for reshard op.";
+
+  return {std::vector<pir::Value>{grad_op->result(0)}};
+}
+void ReshardOp::VerifySig() {
+  VLOG(4) << "Start Verifying inputs, outputs and attributes for: ReshardOp.";
   VLOG(4) << "Verifying inputs:";
   {
     auto input_size = num_operands();
@@ -224,11 +265,11 @@ void ReShardOp::VerifySig() {
   VLOG(4) << "End Verifying for: ShardTensorOp.";
 }
 
-void ReShardOp::Build(pir::Builder& builder,
+void ReshardOp::Build(pir::Builder& builder,
                       pir::OperationArgument& argument,
                       pir::Value input,
                       TensorDistAttribute tensor_dist_attr) {
-  VLOG(4) << "Start build ReShardOp";
+  VLOG(4) << "Start build ReshardOp";
 
   paddle::dialect::DistDenseTensorType input_tensor_type;
   if (input.type().isa<paddle::dialect::DistDenseTensorType>()) {
@@ -276,4 +317,4 @@ void ReShardOp::Build(pir::Builder& builder,
 }  // namespace paddle
 
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp)
-IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReShardOp)
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp)
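The backward rule above mirrors the forward op: the gradient of a reshard is simply the output gradient resharded back onto the input's distribution. A hedged sketch of how a caller such as an autodiff pass might drive the new entry point; x, dst_attr, out_grad, and builder are assumed to already exist:

    auto fwd_op = builder.Build<paddle::dialect::ReshardOp>(x, dst_attr);
    std::vector<std::vector<pir::Value>> in_grads =
        paddle::dialect::ReshardOp::Vjp(fwd_op.operation(),
                                        /*inputs_=*/{{x}},
                                        /*outputs=*/{{fwd_op.result(0)}},
                                        /*out_grads=*/{{out_grad}},
                                        /*stop_gradients=*/{{false}});
    // in_grads[0][0] is produced by a fresh ReshardOp that maps out_grad
    // back onto x's TensorDistAttribute.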
13 changes: 11 additions & 2 deletions paddle/fluid/pir/dialect/distributed/ir/dist_op.h
@@ -15,6 +15,7 @@
 #pragma once
 #include <vector>
 
+#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
 #include "paddle/pir/include/core/builder.h"
 #include "paddle/pir/include/core/builtin_type.h"
 #include "paddle/pir/include/core/op_base.h"
@@ -39,7 +40,7 @@ class ShardTensorOp : public pir::Op<ShardTensorOp> {
   void VerifySig();
 };
 
-class ReShardOp : public pir::Op<ReShardOp> {
+class ReshardOp : public pir::Op<ReshardOp, VjpInterface> {
  public:
   using Op::Op;
   static const char* name() { return "dist_op.reshard"; }
@@ -49,10 +50,18 @@ class ReShardOp : public pir::Op<ReShardOp> {
                     pir::OperationArgument& argument,  // NOLINT
                     pir::Value input,
                     TensorDistAttribute tensor_dist_attr);
+
+  static std::vector<std::vector<pir::Value>> Vjp(
+      pir::Operation* op,
+      const std::vector<std::vector<pir::Value>>& inputs_,
+      const std::vector<std::vector<pir::Value>>& outputs,
+      const std::vector<std::vector<pir::Value>>& out_grads,
+      const std::vector<std::vector<bool>>& stop_gradients);
+
   void VerifySig();
 };
 }  // namespace dialect
 }  // namespace paddle
 
 IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp)
-IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ReShardOp)
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp)
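Because ReshardOp now lists VjpInterface among its interfaces, generic autodiff code can discover the backward rule without naming the op. A hedged sketch of that dispatch, assuming op is a pir::Operation* whose grad vectors were prepared by the caller and that VjpInterface exposes a Vjp method matching the static signature above:

    if (auto vjp_iface = op->dyn_cast<paddle::dialect::VjpInterface>()) {
      // For dist_op.reshard this resolves to ReshardOp::Vjp.
      auto input_grads =
          vjp_iface.Vjp(op, inputs, outputs, out_grads, stop_gradients);
    }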
3 changes: 3 additions & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -28,6 +28,7 @@
     EagerParamBase,
     Variable,
     default_main_program,
+    in_pir_mode,
 )
 from paddle.distributed.auto_parallel import Engine, strategy as auto_strategy
 from paddle.distributed.auto_parallel.interface import (
@@ -388,6 +389,8 @@ def reshard(dist_tensor, mesh, placements):
         dist_attr._set_partial_dims(partial_dims)
 
         return paddle.base.core.reshard(dist_tensor, dist_attr)
+    elif in_pir_mode():
+        return paddle._pir_ops.reshard(dist_tensor, mesh, [-1, -1], [0])
     else:
         assert isinstance(
             dist_tensor, Variable
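On the C++ side, the new PIR branch is expected to bottom out in the functional reshard API from dist_api.cc above. A hedged sketch of an equivalent builder-side call, with ctx, mesh_attr, and x assumed to exist (the fixed [-1, -1] dims_mapping mirrors the hard-coded arguments in the Python branch):

    auto dist_attr = paddle::dialect::TensorDistAttribute::get(
        ctx, mesh_attr, /*dims_mapping=*/{-1, -1}, /*partial_status=*/{});
    pir::Value out = paddle::dialect::reshard(x, dist_attr);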
12 changes: 6 additions & 6 deletions test/cpp/pir/distributed/dist_dialect_test.cc
@@ -345,8 +345,8 @@ TEST(shard_tensor_op_replicate_test, base) {
   auto dst_mesh_attr = ProcessMeshAttribute::get(ctx, dst_process_mesh);
   auto dst_tensor_dist_attr = TensorDistAttribute::get(
       ctx, dst_mesh_attr, dst_dims_mapping, partial_status);
-  paddle::dialect::ReShardOp reshard_op =
-      builder.Build<paddle::dialect::ReShardOp>(shard_op.out(),
+  paddle::dialect::ReshardOp reshard_op =
+      builder.Build<paddle::dialect::ReshardOp>(shard_op.out(),
                                                 dst_tensor_dist_attr);
 
   EXPECT_TRUE(reshard_op.result(0).type().isa<DistDenseTensorType>());
@@ -428,8 +428,8 @@ TEST(shard_tensor_op_shard_row_test, base) {
   auto dst_mesh_attr = ProcessMeshAttribute::get(ctx, dst_process_mesh);
   auto dst_tensor_dist_attr = TensorDistAttribute::get(
       ctx, dst_mesh_attr, dims_mapping, partial_status);
-  paddle::dialect::ReShardOp reshard_op =
-      builder.Build<paddle::dialect::ReShardOp>(shard_op.out(),
+  paddle::dialect::ReshardOp reshard_op =
+      builder.Build<paddle::dialect::ReshardOp>(shard_op.out(),
                                                 dst_tensor_dist_attr);
 
   EXPECT_TRUE(reshard_op.result(0).type().isa<DistDenseTensorType>());
@@ -511,8 +511,8 @@ TEST(shard_tensor_op_shard_col_test, base) {
   auto dst_mesh_attr = ProcessMeshAttribute::get(ctx, dst_process_mesh);
   auto dst_tensor_dist_attr = TensorDistAttribute::get(
       ctx, dst_mesh_attr, dst_dims_mapping, partial_status);
-  paddle::dialect::ReShardOp reshard_op =
-      builder.Build<paddle::dialect::ReShardOp>(shard_op.out(),
+  paddle::dialect::ReshardOp reshard_op =
+      builder.Build<paddle::dialect::ReshardOp>(shard_op.out(),
                                                 dst_tensor_dist_attr);
 
   EXPECT_TRUE(reshard_op.result(0).type().isa<DistDenseTensorType>());
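A natural follow-up check in the same style as the cases above, sketched here under the setup of shard_tensor_op_replicate_test; it reuses the forward output as a stand-in gradient, which suffices at the IR level since only the value's type and dist attribute matter:

    auto grads = paddle::dialect::ReshardOp::Vjp(
        reshard_op.operation(),
        /*inputs_=*/{{shard_op.out()}},
        /*outputs=*/{{reshard_op.result(0)}},
        /*out_grads=*/{{reshard_op.result(0)}},
        /*stop_gradients=*/{{false}});
    EXPECT_EQ(grads.size(), 1u);
    EXPECT_TRUE(grads[0][0].defining_op()->isa<paddle::dialect::ReshardOp>());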
