[Semi-Auto] Replace the pointers in reshape-like spmd rule with smart pointers. #59101
Merged
Changes from 2 of 4 commits in paddle/phi/infermeta/spmd_rules/dim_trans.cc:
@@ -23,8 +23,6 @@ limitations under the License. */
 namespace phi {
 namespace distributed {
 
-static std::vector<DimTrans*> all_dim_trans;
-
 DimTrans::DimTrans(Type type) : type_(type) {}
 
 DimTrans::~DimTrans() {}
@@ -35,14 +33,10 @@ void DimTrans::set_type(Type type) { type_ = type; }
 
 std::string DimTrans::to_string() { return std::string(""); }
 
-InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) {
-  input_dim_ = -1;
-  all_dim_trans.emplace_back(this);
-}
+InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) { input_dim_ = -1; }
 
 InputDim::InputDim(int64_t dim) : DimTrans(DimTrans::Type::INPUTDIM) {
   input_dim_ = dim;
-  all_dim_trans.emplace_back(this);
 }
 
 InputDim::~InputDim() {}
@@ -55,30 +49,26 @@ std::string InputDim::to_string() {
   return ("InputDim(" + std::to_string(input_dim_) + ")");
 }
 
-Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {
-  all_dim_trans.emplace_back(this);
-}
+Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {}
 
 std::string Singleton::to_string() { return "Singleton()"; }
 
-Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {
-  all_dim_trans.emplace_back(this);
-}
+Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {}
 
-Flatten::Flatten(const std::vector<DimTrans*>& dims)
+Flatten::Flatten(const std::vector<std::shared_ptr<DimTrans>>& dims)
     : DimTrans(DimTrans::Type::FLATTEN) {
   input_dims_ = dims;
-  all_dim_trans.emplace_back(this);
 }
 
 Flatten::~Flatten() {  // NOLINT
-  input_dims_.assign(input_dims_.size(), nullptr);
-  std::vector<DimTrans*>().swap(input_dims_);
+  input_dims_.clear();
 }
 
-const std::vector<DimTrans*>& Flatten::inputs() const { return input_dims_; }
+const std::vector<std::shared_ptr<DimTrans>>& Flatten::inputs() const {
+  return input_dims_;
+}
 
-void Flatten::set_inputs(const std::vector<DimTrans*>& dims) {
+void Flatten::set_inputs(const std::vector<std::shared_ptr<DimTrans>>& dims) {
   input_dims_.assign(dims.begin(), dims.end());
 }
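The simplified `Flatten::~Flatten()` above relies on `std::shared_ptr`'s automatic release: clearing the vector drops the child references, and any node no longer referenced elsewhere is destroyed. A minimal standalone sketch of that behavior (the `Node` type here is illustrative, not from this PR):

```cpp
#include <iostream>
#include <memory>
#include <vector>

struct Node {
  explicit Node(int v) : value(v) {}
  ~Node() { std::cout << "freeing " << value << "\n"; }
  int value;
};

int main() {
  std::vector<std::shared_ptr<Node>> children;
  children.push_back(std::make_shared<Node>(1));
  children.push_back(std::make_shared<Node>(2));
  // With raw pointers, each element would need an explicit `delete`.
  // With shared_ptr, clear() drops the last references and both Nodes
  // are destroyed right here.
  children.clear();
  return 0;
}
```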
@@ -93,27 +83,26 @@ std::string Flatten::to_string() {
   return ret_str + ")";
 }
 
-Split::Split() : DimTrans(DimTrans::Type::SPLIT) {
-  input_dim_trans_ = nullptr;
-  all_dim_trans.emplace_back(this);
-}
+Split::Split() : DimTrans(DimTrans::Type::SPLIT) { input_dim_trans_ = nullptr; }
 
-Split::Split(DimTrans* dim, const std::vector<int64_t>& shape, int64_t id)
+Split::Split(const std::shared_ptr<DimTrans>& dim,
+             const std::vector<int64_t>& shape,
+             int64_t id)
     : DimTrans(DimTrans::Type::SPLIT) {
   input_dim_trans_ = dim;
   split_id_ = id;
   splitted_shape_.assign(shape.begin(), shape.end());
-  all_dim_trans.emplace_back(this);
 }
 
-Split::~Split() {
-  input_dim_trans_ = nullptr;
-  std::vector<int64_t>().swap(splitted_shape_);
-}
+Split::~Split() { std::vector<int64_t>().swap(splitted_shape_); }
 
-DimTrans* Split::input() const { return input_dim_trans_; }
+const std::shared_ptr<DimTrans>& Split::input() const {
+  return input_dim_trans_;
+}
 
-void Split::set_input(DimTrans* dim) { input_dim_trans_ = dim; }
+void Split::set_input(const std::shared_ptr<DimTrans>& dim) {
+  input_dim_trans_ = dim;
+}
 
 int64_t Split::split_id() const { return split_id_; }
@@ -133,28 +122,40 @@ std::string Split::to_string() {
   return ret_str + "), " + std::to_string(split_id_) + ")";
 }
 
-DimTrans* make_flatten(const std::vector<DimTrans*>& dims) {
-  DimTrans* ptr = nullptr;
+std::shared_ptr<DimTrans> make_flatten(
+    const std::vector<std::shared_ptr<DimTrans>>& dims) {
+  std::shared_ptr<DimTrans> ptr;
   if (dims.size() == 0) {
-    ptr = new Singleton();
+    ptr = std::make_shared<Singleton>();
   } else if (dims.size() == 1) {
     ptr = dims[0];
   } else {
-    ptr = new Flatten(dims);
+    ptr = std::make_shared<Flatten>(dims);
   }
   return ptr;
 }
 
-DimTrans* make_split(DimTrans* dim,
-                     const std::vector<int64_t>& shape,
-                     int64_t id) {
-  assert(shape.size() > 0);
-  DimTrans* ptr = nullptr;
+std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans>& dim,
+                                     const std::vector<int64_t>& shape,
+                                     int64_t id) {
+  PADDLE_ENFORCE_GT(shape.size(),
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The size of the `shape` vector in `make_split` "
+                        "must be greater than 0, but received %d",
+                        shape.size()));
+  std::shared_ptr<DimTrans> ptr;
   if (shape.size() == 1) {
-    assert(id == 0);
+    PADDLE_ENFORCE_EQ(id,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "The `id` in `make_split` must be 0 when the "
+                          "size of the `shape` vector is 1, but received %d",
+                          id));
     ptr = dim;
   } else if (shape[id] == 1) {
-    ptr = new Singleton();
+    ptr = std::make_shared<Singleton>();
   } else {
     // new shape that remove 1
     std::vector<int64_t> new_shape;
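The three branches of `make_flatten` encode how a flatten degenerates: zero inputs produce a constant size-1 axis (`Singleton`), a single input is returned unchanged, and only two or more inputs allocate a real `Flatten` node. A hedged usage sketch, assuming the declarations from this file:

```cpp
// Illustrative only; InputDim, DimTrans, and make_flatten are the names
// introduced in this diff.
std::shared_ptr<DimTrans> d0 = std::make_shared<InputDim>(0);
std::shared_ptr<DimTrans> d1 = std::make_shared<InputDim>(1);

auto s = make_flatten({});        // no inputs  -> Singleton()
auto p = make_flatten({d0});      // one input  -> returns d0 itself
auto f = make_flatten({d0, d1});  // >1 inputs  -> Flatten(InputDim(0), InputDim(1))
```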
@@ -166,55 +167,48 @@ DimTrans* make_split(DimTrans* dim,
         new_shape.emplace_back(shape[i]);
       }
     }
-    ptr = new Split(dim, new_shape, idx_map[id]);
+    ptr = std::make_shared<Split>(dim, new_shape, idx_map[id]);
   }
   return ptr;
 }
 
-void CleanUp() {
-  int n = static_cast<int>(all_dim_trans.size());
-  for (int i = 0; i < n; i++) {
-    if (all_dim_trans[i]) {
-      delete all_dim_trans[i];
-      all_dim_trans[i] = nullptr;
-    }
-  }
-  std::vector<DimTrans*>().swap(all_dim_trans);
-}
-
 // Given a `dim_trans` of an output axis, get the input axis
 // whose dim mapping should be propogated to it.
 // If the returned input axis is none, the output axis's
 // dim mapping should be set to -1 (replicated). For an axis
 // that is flattened from input axes, return the leftmost
 // flattened input axis. For the split transformation,
 // only the leftmost split axis in output will return its input.
-DimTrans* GetDimTrans(DimTrans* dim_trans,
-                      std::vector<std::vector<bool>>* shardable,
-                      std::set<int64_t>* seen_dims,
-                      const std::vector<int64_t>& input_shape,
-                      const std::vector<int64_t>& mesh_shape,
-                      const std::vector<int64_t>& input_dims_mapping,
-                      const std::set<int64_t>& sharded_input_dims) {
+std::shared_ptr<DimTrans> GetDimTrans(
+    const std::shared_ptr<DimTrans>& dim_trans,
+    std::vector<std::vector<bool>>* shardable,
+    std::set<int64_t>* seen_dims,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& mesh_shape,
+    const std::vector<int64_t>& input_dims_mapping,
+    const std::set<int64_t>& sharded_input_dims) {
   DimTrans::Type type = dim_trans->type();
-  DimTrans* ret_dim_trans = nullptr;
+  std::shared_ptr<DimTrans> ret_dim_trans;
 
   if (type == DimTrans::Type::INPUTDIM) {
-    InputDim* inputdim = dynamic_cast<InputDim*>(dim_trans);
+    std::shared_ptr<InputDim> inputdim =
+        std::dynamic_pointer_cast<InputDim>(dim_trans);
     int64_t dim = inputdim->input_dim();
     seen_dims->insert(dim);
 
     if (sharded_input_dims.count(dim) > 0) {
       ret_dim_trans = dim_trans;
     }
   } else if (type == DimTrans::Type::FLATTEN) {
-    Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
-    const std::vector<DimTrans*>& inputs = flatten->inputs();
+    std::shared_ptr<Flatten> flatten =
+        std::dynamic_pointer_cast<Flatten>(dim_trans);
+    const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();
     int64_t nmesh = (*shardable)[0].size();  // NOLINT
     for (int i = 1, n = static_cast<int>(inputs.size()); i < n; i++) {
-      DimTrans* input = inputs[i];
+      std::shared_ptr<DimTrans> input = inputs[i];
       if (input->type() == DimTrans::Type::INPUTDIM) {
-        InputDim* inputdim = dynamic_cast<InputDim*>(input);
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(input);
         (*shardable)[inputdim->input_dim()].assign(nmesh, false);
       }

Review comment (on the reordered `GetDimTrans` parameter list): Parameter order: should the output parameters here be moved after the input parameters?
Author reply: Done
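The deleted `all_dim_trans` registry and `CleanUp()` pass existed only so that every `new`-ed node could be tracked and freed manually after a rule ran; with `shared_ptr`, lifetime follows the object graph and both disappear. A simplified before/after sketch of the pattern (not the PR's exact code):

```cpp
#include <memory>
#include <vector>

struct Trans {
  virtual ~Trans() = default;
};

// Before: every constructor registered `this` in a global vector, and
// callers had to remember to run a cleanup pass afterwards:
//   static std::vector<Trans*> all_trans;
//   void CleanUp() { for (auto* t : all_trans) delete t; all_trans.clear(); }

// After: the factory returns a shared_ptr; when the last owner releases
// it, the node frees itself, so no registry or CleanUp() is needed.
std::shared_ptr<Trans> MakeNode() { return std::make_shared<Trans>(); }
```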
@@ -227,22 +221,23 @@ DimTrans* GetDimTrans(DimTrans* dim_trans,
                             sharded_input_dims);
     }
 
-    DimTrans* dim0 = inputs[0];
+    std::shared_ptr<DimTrans> dim0 = inputs[0];
     if (dim0->type() == DimTrans::Type::INPUTDIM) {
-      InputDim* inputdim = dynamic_cast<InputDim*>(dim0);
+      std::shared_ptr<InputDim> inputdim =
+          std::dynamic_pointer_cast<InputDim>(dim0);
       if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
         ret_dim_trans = dim0;
       }
     }
   } else if (type == DimTrans::Type::SPLIT) {
-    Split* split = dynamic_cast<Split*>(dim_trans);
-    DimTrans* dim = GetDimTrans(split->input(),
-                                shardable,
-                                seen_dims,
-                                input_shape,
-                                mesh_shape,
-                                input_dims_mapping,
-                                sharded_input_dims);
+    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
+    std::shared_ptr<DimTrans> dim = GetDimTrans(split->input(),
+                                                shardable,
+                                                seen_dims,
+                                                input_shape,
+                                                mesh_shape,
+                                                input_dims_mapping,
+                                                sharded_input_dims);
     int64_t ret_size = split->local_splitted_shape_value();
 
     if (split->split_id() == 0) {
@@ -251,7 +246,8 @@ DimTrans* GetDimTrans(DimTrans* dim_trans,
                         DimTrans::Type::INPUTDIM,
                         phi::errors::InvalidArgument(
                             "The returned dim_trans must be INPUTDIM."));
-      InputDim* inputdim = dynamic_cast<InputDim*>(dim);
+      std::shared_ptr<InputDim> inputdim =
+          std::dynamic_pointer_cast<InputDim>(dim);
       int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
       int64_t input_axis = inputdim->input_dim();
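`std::dynamic_pointer_cast`, used throughout this diff in place of `dynamic_cast`, returns a `shared_ptr` that shares ownership with its source on success and an empty `shared_ptr` on a type mismatch. A small self-contained example of those semantics:

```cpp
#include <cassert>
#include <memory>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};
struct Other : Base {};

int main() {
  std::shared_ptr<Base> b = std::make_shared<Derived>();

  // Succeeds: the result shares ownership, so use_count rises to 2.
  std::shared_ptr<Derived> d = std::dynamic_pointer_cast<Derived>(b);
  assert(d != nullptr && b.use_count() == 2);

  // Fails cleanly: an empty shared_ptr, never a dangling raw pointer.
  std::shared_ptr<Other> o = std::dynamic_pointer_cast<Other>(b);
  assert(o == nullptr);
  return 0;
}
```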
@@ -270,25 +266,29 @@ DimTrans* GetDimTrans(DimTrans* dim_trans,
   return ret_dim_trans;
 }
 
-void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
+void GetUsedInputDim(const std::shared_ptr<DimTrans>& dim_trans,
+                     std::set<int64_t>* seen_dims) {
   if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
-    InputDim* input = dynamic_cast<InputDim*>(dim_trans);
+    std::shared_ptr<InputDim> input =
+        std::dynamic_pointer_cast<InputDim>(dim_trans);
     seen_dims->insert(input->input_dim());
   } else if (dim_trans->type() == DimTrans::Type::FLATTEN) {
-    Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
-    for (DimTrans* trans : flatten->inputs()) {
+    std::shared_ptr<Flatten> flatten =
+        std::dynamic_pointer_cast<Flatten>(dim_trans);
+    for (const std::shared_ptr<DimTrans>& trans : flatten->inputs()) {
       GetUsedInputDim(trans, seen_dims);
     }
   } else if (dim_trans->type() == DimTrans::Type::SPLIT) {
-    Split* split = dynamic_cast<Split*>(dim_trans);
+    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
     GetUsedInputDim(split->input(), seen_dims);
   } else {
     return;
   }
 }
 
 std::vector<std::vector<int64_t>> InferFromDimTrans(
-    const DistMetaTensor& input, const std::vector<DimTrans*>& dim_trans) {
+    const DistMetaTensor& input,
+    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
   std::vector<int64_t> input_shape = phi::vectorize(input.dims());
   const std::vector<int64_t>& input_dims_mapping =
       input.dist_attr().dims_mapping();
@@ -309,7 +309,7 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
                                        std::vector<bool>(nmesh, true));
 
   std::set<int64_t> seen_input_dims;
-  for (DimTrans* trans : dim_trans) {
+  for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
     GetUsedInputDim(trans, &seen_input_dims);
   }
@@ -323,15 +323,16 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
   // get the map from sharded input dimensions to output dimensions.
   std::vector<int64_t> dim_map_src2tgt(ndim, -1);
   for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
-    DimTrans* dim = GetDimTrans(dim_trans[i],
-                                &shardable,
-                                &seen_input_dims,
-                                input_shape,
-                                mesh_shape,
-                                input_dims_mapping,
-                                sharded_input_dims);
+    std::shared_ptr<DimTrans> dim = GetDimTrans(dim_trans[i],
+                                                &shardable,
+                                                &seen_input_dims,
+                                                input_shape,
+                                                mesh_shape,
+                                                input_dims_mapping,
+                                                sharded_input_dims);
     if (dim != nullptr && dim->type() == DimTrans::Type::INPUTDIM) {
-      InputDim* inputdim = dynamic_cast<InputDim*>(dim);
+      std::shared_ptr<InputDim> inputdim =
+          std::dynamic_pointer_cast<InputDim>(dim);
       dim_map_src2tgt[inputdim->input_dim()] = i;
     }
   }
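As a worked example of how these pieces compose (hypothetical driver code using the names from this diff): reshaping a tensor of shape [6, 4] to [2, 3, 4] splits input axis 0 into (2, 3) and passes input axis 1 through.

```cpp
// Output axes 0 and 1 come from splitting input axis 0; output axis 2
// is input axis 1 unchanged. Shapes here are illustrative.
std::shared_ptr<DimTrans> in0 = std::make_shared<InputDim>(0);

std::vector<std::shared_ptr<DimTrans>> dim_trans = {
    make_split(in0, {2, 3}, 0),     // leftmost split piece: may keep its shard
    make_split(in0, {2, 3}, 1),     // split_id != 0: mapped to -1 (replicated)
    std::make_shared<InputDim>(1),  // pass-through axis
};
// InferFromDimTrans(input, dim_trans) then propagates the input's
// dims_mapping along these transformations, per the GetDimTrans
// comment above.
```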
Review comment: For `shared_ptr`-typed parameters, is it necessary to pass by reference? Passing by value is generally recommended.
Author reply: Done
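For context on the exchange above: passing `shared_ptr` by `const&` avoids an atomic refcount bump when the callee only reads through the pointer, while passing by value makes the ownership copy explicit and lets callers `std::move` into it. A sketch of the two options, loosely modeled on `Split::set_input` (not the code as merged):

```cpp
#include <memory>
#include <utility>

struct DimTrans;  // forward declaration, as in the PR

struct Split {
  // By const reference: no refcount traffic at the call site; the one
  // copy happens at the assignment below.
  void set_input_ref(const std::shared_ptr<DimTrans>& dim) {
    input_dim_trans_ = dim;
  }

  // By value, as the reviewer suggests: the argument itself is the copy
  // (or move), making the ownership transfer explicit.
  void set_input_val(std::shared_ptr<DimTrans> dim) {
    input_dim_trans_ = std::move(dim);
  }

  std::shared_ptr<DimTrans> input_dim_trans_;
};
```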