[Semi-Auto] Replace the pointers in reshape-like spmd rule with smart pointers. #59101

Merged · 4 commits · Nov 23, 2023
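In outline, the change replaces the raw `DimTrans*` API in the reshape-like SPMD rules with `std::shared_ptr<DimTrans>`: the global `all_dim_trans` registry and its manual `CleanUp()` pass are deleted, the `make_flatten`/`make_split` factories now return `std::make_shared<...>` results, downcasts move from `dynamic_cast` to `std::dynamic_pointer_cast`, and the bare `assert`s become `PADDLE_ENFORCE_*` checks. The sketch below illustrates only the ownership pattern; the `Node`/`MakeNode` names are hypothetical stand-ins, not the Paddle classes.

```cpp
// A minimal sketch (hypothetical `Node` type, not the Paddle classes) of the
// ownership change: raw pointers registered in a global list and released by
// a manual CleanUp() pass are replaced with std::shared_ptr, which releases
// each object automatically once its last owner goes away.
#include <memory>
#include <vector>

struct Node {
  virtual ~Node() = default;
};

// Before: every `new Node` had to be tracked and deleted by hand, e.g.
//   static std::vector<Node*> all_nodes;
//   void CleanUp() { for (Node* n : all_nodes) delete n; all_nodes.clear(); }

// After: factories hand out shared_ptr, so no registry or CleanUp is needed.
std::shared_ptr<Node> MakeNode() { return std::make_shared<Node>(); }

int main() {
  std::vector<std::shared_ptr<Node>> trans;
  trans.push_back(MakeNode());
  return 0;  // `trans` destroys its nodes here; nothing leaks
}
```

With shared ownership, a node referenced by several transformations stays alive exactly as long as something still points at it, which is what the deleted global registry was approximating by hand.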
191 changes: 96 additions & 95 deletions paddle/phi/infermeta/spmd_rules/dim_trans.cc
@@ -23,8 +23,6 @@ limitations under the License. */
 namespace phi {
 namespace distributed {
 
-static std::vector<DimTrans*> all_dim_trans;
-
 DimTrans::DimTrans(Type type) : type_(type) {}
 
 DimTrans::~DimTrans() {}
@@ -35,14 +33,10 @@ void DimTrans::set_type(Type type) { type_ = type; }
 
 std::string DimTrans::to_string() { return std::string(""); }
 
-InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) {
-  input_dim_ = -1;
-  all_dim_trans.emplace_back(this);
-}
+InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) { input_dim_ = -1; }
 
 InputDim::InputDim(int64_t dim) : DimTrans(DimTrans::Type::INPUTDIM) {
   input_dim_ = dim;
-  all_dim_trans.emplace_back(this);
 }
 
 InputDim::~InputDim() {}
@@ -55,30 +49,26 @@ std::string InputDim::to_string() {
   return ("InputDim(" + std::to_string(input_dim_) + ")");
 }
 
-Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {
-  all_dim_trans.emplace_back(this);
-}
+Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {}
 
 std::string Singleton::to_string() { return "Singleton()"; }
 
-Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {
-  all_dim_trans.emplace_back(this);
-}
+Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {}
 
-Flatten::Flatten(const std::vector<DimTrans*>& dims)
+Flatten::Flatten(const std::vector<std::shared_ptr<DimTrans>>& dims)
     : DimTrans(DimTrans::Type::FLATTEN) {
   input_dims_ = dims;
-  all_dim_trans.emplace_back(this);
 }
 
 Flatten::~Flatten() {  // NOLINT
-  input_dims_.assign(input_dims_.size(), nullptr);
-  std::vector<DimTrans*>().swap(input_dims_);
+  input_dims_.clear();
 }
 
-const std::vector<DimTrans*>& Flatten::inputs() const { return input_dims_; }
+const std::vector<std::shared_ptr<DimTrans>>& Flatten::inputs() const {
+  return input_dims_;
+}
 
-void Flatten::set_inputs(const std::vector<DimTrans*>& dims) {
+void Flatten::set_inputs(const std::vector<std::shared_ptr<DimTrans>>& dims) {
   input_dims_.assign(dims.begin(), dims.end());
 }
 
@@ -93,27 +83,26 @@ std::string Flatten::to_string() {
   return ret_str + ")";
 }
 
-Split::Split() : DimTrans(DimTrans::Type::SPLIT) {
-  input_dim_trans_ = nullptr;
-  all_dim_trans.emplace_back(this);
-}
+Split::Split() : DimTrans(DimTrans::Type::SPLIT) { input_dim_trans_ = nullptr; }
 
-Split::Split(DimTrans* dim, const std::vector<int64_t>& shape, int64_t id)
+Split::Split(const std::shared_ptr<DimTrans> dim,
+             const std::vector<int64_t>& shape,
+             int64_t id)
     : DimTrans(DimTrans::Type::SPLIT) {
   input_dim_trans_ = dim;
   split_id_ = id;
   splitted_shape_.assign(shape.begin(), shape.end());
-  all_dim_trans.emplace_back(this);
 }
 
-Split::~Split() {
-  input_dim_trans_ = nullptr;
-  std::vector<int64_t>().swap(splitted_shape_);
-}
+Split::~Split() { std::vector<int64_t>().swap(splitted_shape_); }
 
-DimTrans* Split::input() const { return input_dim_trans_; }
+const std::shared_ptr<DimTrans>& Split::input() const {
+  return input_dim_trans_;
+}
 
-void Split::set_input(DimTrans* dim) { input_dim_trans_ = dim; }
+void Split::set_input(const std::shared_ptr<DimTrans> dim) {
+  input_dim_trans_ = dim;
+}
 
 int64_t Split::split_id() const { return split_id_; }
 
@@ -133,28 +122,40 @@ std::string Split::to_string() {
   return ret_str + "), " + std::to_string(split_id_) + ")";
 }
 
-DimTrans* make_flatten(const std::vector<DimTrans*>& dims) {
-  DimTrans* ptr = nullptr;
+std::shared_ptr<DimTrans> make_flatten(
+    const std::vector<std::shared_ptr<DimTrans>>& dims) {
+  std::shared_ptr<DimTrans> ptr;
   if (dims.size() == 0) {
-    ptr = new Singleton();
+    ptr = std::make_shared<Singleton>();
   } else if (dims.size() == 1) {
     ptr = dims[0];
   } else {
-    ptr = new Flatten(dims);
+    ptr = std::make_shared<Flatten>(dims);
   }
   return ptr;
 }
 
-DimTrans* make_split(DimTrans* dim,
-                     const std::vector<int64_t>& shape,
-                     int64_t id) {
-  assert(shape.size() > 0);
-  DimTrans* ptr = nullptr;
+std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans> dim,
+                                     const std::vector<int64_t>& shape,
+                                     int64_t id) {
+  PADDLE_ENFORCE_GT(shape.size(),
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The size of the `shape` vector in `make_split` "
+                        "must be greater than 0, but received %d",
+                        shape.size()));
+  std::shared_ptr<DimTrans> ptr;
   if (shape.size() == 1) {
-    assert(id == 0);
+    PADDLE_ENFORCE_EQ(id,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "The `id` in `make_split` must be 0 when the "
+                          "size of the `shape` vector is 1, but received %d",
+                          id));
     ptr = dim;
   } else if (shape[id] == 1) {
-    ptr = new Singleton();
+    ptr = std::make_shared<Singleton>();
   } else {
     // new shape with the size-1 dimensions removed
     std::vector<int64_t> new_shape;
@@ -166,83 +167,77 @@ DimTrans* make_split(DimTrans* dim,
         new_shape.emplace_back(shape[i]);
       }
     }
-    ptr = new Split(dim, new_shape, idx_map[id]);
+    ptr = std::make_shared<Split>(dim, new_shape, idx_map[id]);
   }
   return ptr;
 }
 
-void CleanUp() {
-  int n = static_cast<int>(all_dim_trans.size());
-  for (int i = 0; i < n; i++) {
-    if (all_dim_trans[i]) {
-      delete all_dim_trans[i];
-      all_dim_trans[i] = nullptr;
-    }
-  }
-  std::vector<DimTrans*>().swap(all_dim_trans);
-}
-
 // Given a `dim_trans` of an output axis, get the input axis
 // whose dim mapping should be propagated to it.
 // If the returned input axis is none, the output axis's
 // dim mapping should be set to -1 (replicated). For an axis
 // that is flattened from input axes, return the leftmost
 // flattened input axis. For the split transformation,
 // only the leftmost split axis in output will return its input.
-DimTrans* GetDimTrans(DimTrans* dim_trans,
-                      std::vector<std::vector<bool>>* shardable,
-                      std::set<int64_t>* seen_dims,
-                      const std::vector<int64_t>& input_shape,
-                      const std::vector<int64_t>& mesh_shape,
-                      const std::vector<int64_t>& input_dims_mapping,
-                      const std::set<int64_t>& sharded_input_dims) {
+std::shared_ptr<DimTrans> GetDimTrans(
+    const std::shared_ptr<DimTrans> dim_trans,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& mesh_shape,
+    const std::vector<int64_t>& input_dims_mapping,
+    const std::set<int64_t>& sharded_input_dims,
+    std::vector<std::vector<bool>>* shardable,
+    std::set<int64_t>* seen_dims) {
   DimTrans::Type type = dim_trans->type();
-  DimTrans* ret_dim_trans = nullptr;
+  std::shared_ptr<DimTrans> ret_dim_trans;
 
   if (type == DimTrans::Type::INPUTDIM) {
-    InputDim* inputdim = dynamic_cast<InputDim*>(dim_trans);
+    std::shared_ptr<InputDim> inputdim =
+        std::dynamic_pointer_cast<InputDim>(dim_trans);
     int64_t dim = inputdim->input_dim();
     seen_dims->insert(dim);
 
    if (sharded_input_dims.count(dim) > 0) {
      ret_dim_trans = dim_trans;
    }
  } else if (type == DimTrans::Type::FLATTEN) {
-    Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
-    const std::vector<DimTrans*>& inputs = flatten->inputs();
+    std::shared_ptr<Flatten> flatten =
+        std::dynamic_pointer_cast<Flatten>(dim_trans);
+    const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();
    int64_t nmesh = (*shardable)[0].size();  // NOLINT
    for (int i = 1, n = static_cast<int>(inputs.size()); i < n; i++) {
-      DimTrans* input = inputs[i];
+      std::shared_ptr<DimTrans> input = inputs[i];
      if (input->type() == DimTrans::Type::INPUTDIM) {
-        InputDim* inputdim = dynamic_cast<InputDim*>(input);
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(input);
        (*shardable)[inputdim->input_dim()].assign(nmesh, false);
      }
 
      GetDimTrans(input,
-                  shardable,
-                  seen_dims,
                  input_shape,
                  mesh_shape,
                  input_dims_mapping,
-                  sharded_input_dims);
+                  sharded_input_dims,
+                  shardable,
+                  seen_dims);
    }
 
-    DimTrans* dim0 = inputs[0];
+    std::shared_ptr<DimTrans> dim0 = inputs[0];
    if (dim0->type() == DimTrans::Type::INPUTDIM) {
-      InputDim* inputdim = dynamic_cast<InputDim*>(dim0);
+      std::shared_ptr<InputDim> inputdim =
+          std::dynamic_pointer_cast<InputDim>(dim0);
      if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
        ret_dim_trans = dim0;
      }
    }
  } else if (type == DimTrans::Type::SPLIT) {
-    Split* split = dynamic_cast<Split*>(dim_trans);
-    DimTrans* dim = GetDimTrans(split->input(),
-                                shardable,
-                                seen_dims,
-                                input_shape,
-                                mesh_shape,
-                                input_dims_mapping,
-                                sharded_input_dims);
+    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
+    std::shared_ptr<DimTrans> dim = GetDimTrans(split->input(),
+                                                input_shape,
+                                                mesh_shape,
+                                                input_dims_mapping,
+                                                sharded_input_dims,
+                                                shardable,
+                                                seen_dims);
    int64_t ret_size = split->local_splitted_shape_value();
 
    if (split->split_id() == 0) {
@@ -251,7 +246,8 @@ DimTrans* GetDimTrans(DimTrans* dim_trans,
                         DimTrans::Type::INPUTDIM,
                         phi::errors::InvalidArgument(
                             "The returned dim_trans must be INPUTDIM."));
-      InputDim* inputdim = dynamic_cast<InputDim*>(dim);
+      std::shared_ptr<InputDim> inputdim =
+          std::dynamic_pointer_cast<InputDim>(dim);
       int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
       int64_t input_axis = inputdim->input_dim();
 
@@ -270,25 +266,29 @@ DimTrans* GetDimTrans(DimTrans* dim_trans,
   return ret_dim_trans;
 }
 
-void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
+void GetUsedInputDim(const std::shared_ptr<DimTrans> dim_trans,
+                     std::set<int64_t>* seen_dims) {
   if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
-    InputDim* input = dynamic_cast<InputDim*>(dim_trans);
+    std::shared_ptr<InputDim> input =
+        std::dynamic_pointer_cast<InputDim>(dim_trans);
     seen_dims->insert(input->input_dim());
   } else if (dim_trans->type() == DimTrans::Type::FLATTEN) {
-    Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
-    for (DimTrans* trans : flatten->inputs()) {
+    std::shared_ptr<Flatten> flatten =
+        std::dynamic_pointer_cast<Flatten>(dim_trans);
+    for (const std::shared_ptr<DimTrans>& trans : flatten->inputs()) {
      GetUsedInputDim(trans, seen_dims);
    }
  } else if (dim_trans->type() == DimTrans::Type::SPLIT) {
-    Split* split = dynamic_cast<Split*>(dim_trans);
+    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
    GetUsedInputDim(split->input(), seen_dims);
  } else {
    return;
  }
 }
 
 std::vector<std::vector<int64_t>> InferFromDimTrans(
-    const DistMetaTensor& input, const std::vector<DimTrans*>& dim_trans) {
+    const DistMetaTensor& input,
+    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
   std::vector<int64_t> input_shape = phi::vectorize(input.dims());
   const std::vector<int64_t>& input_dims_mapping =
       input.dist_attr().dims_mapping();
@@ -309,7 +309,7 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
                                       std::vector<bool>(nmesh, true));
 
   std::set<int64_t> seen_input_dims;
-  for (DimTrans* trans : dim_trans) {
+  for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
     GetUsedInputDim(trans, &seen_input_dims);
   }
 
@@ -323,15 +323,16 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
   // get the map from sharded input dimensions to output dimensions.
   std::vector<int64_t> dim_map_src2tgt(ndim, -1);
   for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
-    DimTrans* dim = GetDimTrans(dim_trans[i],
-                                &shardable,
-                                &seen_input_dims,
-                                input_shape,
-                                mesh_shape,
-                                input_dims_mapping,
-                                sharded_input_dims);
+    std::shared_ptr<DimTrans> dim = GetDimTrans(dim_trans[i],
+                                                input_shape,
+                                                mesh_shape,
+                                                input_dims_mapping,
+                                                sharded_input_dims,
+                                                &shardable,
+                                                &seen_input_dims);
     if (dim != nullptr && dim->type() == DimTrans::Type::INPUTDIM) {
-      InputDim* inputdim = dynamic_cast<InputDim*>(dim);
+      std::shared_ptr<InputDim> inputdim =
+          std::dynamic_pointer_cast<InputDim>(dim);
       dim_map_src2tgt[inputdim->input_dim()] = i;
     }
   }
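A recurring edit above is swapping `dynamic_cast<Derived*>(raw)` for `std::dynamic_pointer_cast<Derived>(shared)`: a plain `dynamic_cast` on the stored pointer would yield a raw pointer outside the reference count. A self-contained sketch of the pattern, using stand-in `Base`/`Derived` types rather than the actual `DimTrans` hierarchy:

```cpp
// Stand-in types for illustration only; not the Paddle DimTrans hierarchy.
#include <cassert>
#include <cstdint>
#include <memory>

struct Base {
  virtual ~Base() = default;
};
struct Derived : Base {
  int64_t value = -1;
};

int main() {
  std::shared_ptr<Base> base = std::make_shared<Derived>();
  // Shares the same control block as `base`; empty if the dynamic type
  // does not match, mirroring dynamic_cast's nullptr result.
  std::shared_ptr<Derived> derived = std::dynamic_pointer_cast<Derived>(base);
  assert(derived != nullptr);
  assert(base.use_count() == 2);  // both pointers co-own one object
  return 0;
}
```

Like `dynamic_cast`, the result is empty when the dynamic type does not match, so the `inputdim->...` dereferences in the diff still rely on the preceding `type()` checks.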
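The diff also upgrades the `assert` calls in `make_split` to `PADDLE_ENFORCE_GT`/`PADDLE_ENFORCE_EQ`. Unlike `assert`, which vanishes under `NDEBUG` in release builds, these checks always execute and raise a formatted `InvalidArgument` error (the real macros also attach source-location context). A rough plain-C++ approximation, for illustration only:

```cpp
// Hand-rolled stand-in for an ENFORCE-style check; illustration only.
#include <cstddef>
#include <sstream>
#include <stdexcept>

template <typename T>
void EnforceGt(const T& value, const T& bound, const char* what) {
  if (!(value > bound)) {
    std::ostringstream msg;
    msg << what << ": expected > " << bound << ", but received " << value;
    // Unlike assert, this runs in release builds too.
    throw std::invalid_argument(msg.str());
  }
}

int main() {
  EnforceGt<std::size_t>(3, 0, "The size of the `shape` vector in make_split");
  return 0;
}
```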