[Semi-Auto] Replace the pointers in reshape-like spmd rule with smart pointers. (#59101)

* use smart pointers in reshape-like spmd rules (a minimal sketch of the ownership change follows below)

* remove unused code and comments

* pass shared_ptr arguments by value (const std::shared_ptr<DimTrans>) instead of by reference in function argument lists

* reorder the arguments of GetDimTrans so that the read-only inputs come first and the mutated out-parameters (shardable, seen_dims) come last
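
The ownership change in a nutshell, as a minimal standalone C++ sketch (not the Paddle code itself; `Node` stands in for `DimTrans`, and all names are illustrative): constructors no longer register `this` in a global registry that a manual CleanUp() pass must delete; factories return std::shared_ptr, so nodes free themselves when the last reference goes away.

// Minimal sketch of the ownership model before and after (illustrative).
#include <memory>
#include <vector>

struct Node {  // stand-in for DimTrans
  virtual ~Node() = default;
};

// Before: raw pointers plus a global registry and a manual cleanup pass.
//   static std::vector<Node*> all_nodes;
//   Node* make_node() { Node* p = new Node(); all_nodes.push_back(p); return p; }
//   void CleanUp() { for (Node* p : all_nodes) delete p; all_nodes.clear(); }

// After: factories return shared_ptr; ownership is reference-counted, and
// both the global registry and CleanUp() disappear.
std::shared_ptr<Node> make_node() { return std::make_shared<Node>(); }

int main() {
  std::vector<std::shared_ptr<Node>> nodes;
  nodes.push_back(make_node());
  nodes.push_back(make_node());
  return 0;  // nodes are freed automatically when the vector is destroyed
}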
pkuzyc authored Nov 23, 2023
1 parent 723f2d3 commit 4499bdd
Showing 5 changed files with 154 additions and 162 deletions.
paddle/phi/infermeta/spmd_rules/dim_trans.cc (96 additions, 95 deletions)
@@ -23,8 +23,6 @@ limitations under the License. */
 namespace phi {
 namespace distributed {
 
-static std::vector<DimTrans*> all_dim_trans;
-
 DimTrans::DimTrans(Type type) : type_(type) {}
 
 DimTrans::~DimTrans() {}
@@ -35,14 +33,10 @@ void DimTrans::set_type(Type type) { type_ = type; }
 
 std::string DimTrans::to_string() { return std::string(""); }
 
-InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) {
-  input_dim_ = -1;
-  all_dim_trans.emplace_back(this);
-}
+InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) { input_dim_ = -1; }
 
 InputDim::InputDim(int64_t dim) : DimTrans(DimTrans::Type::INPUTDIM) {
   input_dim_ = dim;
-  all_dim_trans.emplace_back(this);
 }
 
 InputDim::~InputDim() {}
@@ -55,30 +49,26 @@ std::string InputDim::to_string() {
   return ("InputDim(" + std::to_string(input_dim_) + ")");
 }
 
-Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {
-  all_dim_trans.emplace_back(this);
-}
+Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {}
 
 std::string Singleton::to_string() { return "Singleton()"; }
 
-Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {
-  all_dim_trans.emplace_back(this);
-}
+Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {}
 
-Flatten::Flatten(const std::vector<DimTrans*>& dims)
+Flatten::Flatten(const std::vector<std::shared_ptr<DimTrans>>& dims)
     : DimTrans(DimTrans::Type::FLATTEN) {
   input_dims_ = dims;
-  all_dim_trans.emplace_back(this);
 }
 
 Flatten::~Flatten() {  // NOLINT
-  input_dims_.assign(input_dims_.size(), nullptr);
-  std::vector<DimTrans*>().swap(input_dims_);
+  input_dims_.clear();
 }
 
-const std::vector<DimTrans*>& Flatten::inputs() const { return input_dims_; }
+const std::vector<std::shared_ptr<DimTrans>>& Flatten::inputs() const {
+  return input_dims_;
+}
 
-void Flatten::set_inputs(const std::vector<DimTrans*>& dims) {
+void Flatten::set_inputs(const std::vector<std::shared_ptr<DimTrans>>& dims) {
   input_dims_.assign(dims.begin(), dims.end());
 }

@@ -93,27 +83,26 @@ std::string Flatten::to_string() {
   return ret_str + ")";
 }
 
-Split::Split() : DimTrans(DimTrans::Type::SPLIT) {
-  input_dim_trans_ = nullptr;
-  all_dim_trans.emplace_back(this);
-}
+Split::Split() : DimTrans(DimTrans::Type::SPLIT) { input_dim_trans_ = nullptr; }
 
-Split::Split(DimTrans* dim, const std::vector<int64_t>& shape, int64_t id)
+Split::Split(const std::shared_ptr<DimTrans> dim,
+             const std::vector<int64_t>& shape,
+             int64_t id)
     : DimTrans(DimTrans::Type::SPLIT) {
   input_dim_trans_ = dim;
   split_id_ = id;
   splitted_shape_.assign(shape.begin(), shape.end());
-  all_dim_trans.emplace_back(this);
 }
 
-Split::~Split() {
-  input_dim_trans_ = nullptr;
-  std::vector<int64_t>().swap(splitted_shape_);
-}
+Split::~Split() { std::vector<int64_t>().swap(splitted_shape_); }
 
-DimTrans* Split::input() const { return input_dim_trans_; }
+const std::shared_ptr<DimTrans>& Split::input() const {
+  return input_dim_trans_;
+}
 
-void Split::set_input(DimTrans* dim) { input_dim_trans_ = dim; }
+void Split::set_input(const std::shared_ptr<DimTrans> dim) {
+  input_dim_trans_ = dim;
+}
 
 int64_t Split::split_id() const { return split_id_; }

@@ -133,28 +122,40 @@ std::string Split::to_string() {
   return ret_str + "), " + std::to_string(split_id_) + ")";
 }
 
-DimTrans* make_flatten(const std::vector<DimTrans*>& dims) {
-  DimTrans* ptr = nullptr;
+std::shared_ptr<DimTrans> make_flatten(
+    const std::vector<std::shared_ptr<DimTrans>>& dims) {
+  std::shared_ptr<DimTrans> ptr;
   if (dims.size() == 0) {
-    ptr = new Singleton();
+    ptr = std::make_shared<Singleton>();
   } else if (dims.size() == 1) {
     ptr = dims[0];
   } else {
-    ptr = new Flatten(dims);
+    ptr = std::make_shared<Flatten>(dims);
   }
   return ptr;
 }
 
-DimTrans* make_split(DimTrans* dim,
-                     const std::vector<int64_t>& shape,
-                     int64_t id) {
-  assert(shape.size() > 0);
-  DimTrans* ptr = nullptr;
+std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans> dim,
+                                     const std::vector<int64_t>& shape,
+                                     int64_t id) {
+  PADDLE_ENFORCE_GT(shape.size(),
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The size of the `shape` vector in `make_split` "
+                        "must be greater than 0, but received %d",
+                        shape.size()));
+  std::shared_ptr<DimTrans> ptr;
   if (shape.size() == 1) {
-    assert(id == 0);
+    PADDLE_ENFORCE_EQ(id,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "The `id` in `make_split` must be 0 when the "
+                          "size of the `shape` vector is 1, but received %d",
+                          id));
     ptr = dim;
   } else if (shape[id] == 1) {
-    ptr = new Singleton();
+    ptr = std::make_shared<Singleton>();
   } else {
     // new shape that remove 1
     std::vector<int64_t> new_shape;
@@ -166,83 +167,77 @@ DimTrans* make_split(DimTrans* dim,
         new_shape.emplace_back(shape[i]);
       }
     }
-    ptr = new Split(dim, new_shape, idx_map[id]);
+    ptr = std::make_shared<Split>(dim, new_shape, idx_map[id]);
   }
   return ptr;
 }
 
-void CleanUp() {
-  int n = static_cast<int>(all_dim_trans.size());
-  for (int i = 0; i < n; i++) {
-    if (all_dim_trans[i]) {
-      delete all_dim_trans[i];
-      all_dim_trans[i] = nullptr;
-    }
-  }
-  std::vector<DimTrans*>().swap(all_dim_trans);
-}
-
 // Given a `dim_trans` of an output axis, get the input axis
 // whose dim mapping should be propogated to it.
 // If the returned input axis is none, the output axis's
 // dim mapping should be set to -1 (replicated). For an axis
 // that is flattened from input axes, return the leftmost
 // flattened input axis. For the split transformation,
 // only the leftmost split axis in output will return its input.
-DimTrans* GetDimTrans(DimTrans* dim_trans,
-                      std::vector<std::vector<bool>>* shardable,
-                      std::set<int64_t>* seen_dims,
-                      const std::vector<int64_t>& input_shape,
-                      const std::vector<int64_t>& mesh_shape,
-                      const std::vector<int64_t>& input_dims_mapping,
-                      const std::set<int64_t>& sharded_input_dims) {
+std::shared_ptr<DimTrans> GetDimTrans(
+    const std::shared_ptr<DimTrans> dim_trans,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& mesh_shape,
+    const std::vector<int64_t>& input_dims_mapping,
+    const std::set<int64_t>& sharded_input_dims,
+    std::vector<std::vector<bool>>* shardable,
+    std::set<int64_t>* seen_dims) {
   DimTrans::Type type = dim_trans->type();
-  DimTrans* ret_dim_trans = nullptr;
+  std::shared_ptr<DimTrans> ret_dim_trans;
 
   if (type == DimTrans::Type::INPUTDIM) {
-    InputDim* inputdim = dynamic_cast<InputDim*>(dim_trans);
+    std::shared_ptr<InputDim> inputdim =
+        std::dynamic_pointer_cast<InputDim>(dim_trans);
    int64_t dim = inputdim->input_dim();
    seen_dims->insert(dim);
 
    if (sharded_input_dims.count(dim) > 0) {
      ret_dim_trans = dim_trans;
    }
  } else if (type == DimTrans::Type::FLATTEN) {
-    Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
-    const std::vector<DimTrans*>& inputs = flatten->inputs();
+    std::shared_ptr<Flatten> flatten =
+        std::dynamic_pointer_cast<Flatten>(dim_trans);
+    const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();
    int64_t nmesh = (*shardable)[0].size();  // NOLINT
    for (int i = 1, n = static_cast<int>(inputs.size()); i < n; i++) {
-      DimTrans* input = inputs[i];
+      std::shared_ptr<DimTrans> input = inputs[i];
      if (input->type() == DimTrans::Type::INPUTDIM) {
-        InputDim* inputdim = dynamic_cast<InputDim*>(input);
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(input);
        (*shardable)[inputdim->input_dim()].assign(nmesh, false);
      }
 
      GetDimTrans(input,
-                  shardable,
-                  seen_dims,
                  input_shape,
                  mesh_shape,
                  input_dims_mapping,
-                  sharded_input_dims);
+                  sharded_input_dims,
+                  shardable,
+                  seen_dims);
    }
 
-    DimTrans* dim0 = inputs[0];
+    std::shared_ptr<DimTrans> dim0 = inputs[0];
    if (dim0->type() == DimTrans::Type::INPUTDIM) {
-      InputDim* inputdim = dynamic_cast<InputDim*>(dim0);
+      std::shared_ptr<InputDim> inputdim =
+          std::dynamic_pointer_cast<InputDim>(dim0);
      if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
        ret_dim_trans = dim0;
      }
    }
  } else if (type == DimTrans::Type::SPLIT) {
-    Split* split = dynamic_cast<Split*>(dim_trans);
-    DimTrans* dim = GetDimTrans(split->input(),
-                                shardable,
-                                seen_dims,
-                                input_shape,
-                                mesh_shape,
-                                input_dims_mapping,
-                                sharded_input_dims);
+    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
+    std::shared_ptr<DimTrans> dim = GetDimTrans(split->input(),
+                                                input_shape,
+                                                mesh_shape,
+                                                input_dims_mapping,
+                                                sharded_input_dims,
+                                                shardable,
+                                                seen_dims);
    int64_t ret_size = split->local_splitted_shape_value();
 
    if (split->split_id() == 0) {
@@ -251,7 +246,8 @@ DimTrans* GetDimTrans(DimTrans* dim_trans,
                         DimTrans::Type::INPUTDIM,
                         phi::errors::InvalidArgument(
                             "The returned dim_trans must be INPUTDIM."));
-      InputDim* inputdim = dynamic_cast<InputDim*>(dim);
+      std::shared_ptr<InputDim> inputdim =
+          std::dynamic_pointer_cast<InputDim>(dim);
      int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
      int64_t input_axis = inputdim->input_dim();

@@ -270,25 +266,29 @@ DimTrans* GetDimTrans(DimTrans* dim_trans,
   return ret_dim_trans;
 }
 
-void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
+void GetUsedInputDim(const std::shared_ptr<DimTrans> dim_trans,
+                     std::set<int64_t>* seen_dims) {
   if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
-    InputDim* input = dynamic_cast<InputDim*>(dim_trans);
+    std::shared_ptr<InputDim> input =
+        std::dynamic_pointer_cast<InputDim>(dim_trans);
    seen_dims->insert(input->input_dim());
  } else if (dim_trans->type() == DimTrans::Type::FLATTEN) {
-    Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
-    for (DimTrans* trans : flatten->inputs()) {
+    std::shared_ptr<Flatten> flatten =
+        std::dynamic_pointer_cast<Flatten>(dim_trans);
+    for (const std::shared_ptr<DimTrans>& trans : flatten->inputs()) {
      GetUsedInputDim(trans, seen_dims);
    }
  } else if (dim_trans->type() == DimTrans::Type::SPLIT) {
-    Split* split = dynamic_cast<Split*>(dim_trans);
+    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
    GetUsedInputDim(split->input(), seen_dims);
  } else {
    return;
  }
}
 
 std::vector<std::vector<int64_t>> InferFromDimTrans(
-    const DistMetaTensor& input, const std::vector<DimTrans*>& dim_trans) {
+    const DistMetaTensor& input,
+    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
   std::vector<int64_t> input_shape = phi::vectorize(input.dims());
   const std::vector<int64_t>& input_dims_mapping =
       input.dist_attr().dims_mapping();
@@ -309,7 +309,7 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
       std::vector<bool>(nmesh, true));
 
   std::set<int64_t> seen_input_dims;
-  for (DimTrans* trans : dim_trans) {
+  for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
     GetUsedInputDim(trans, &seen_input_dims);
   }

Expand All @@ -323,15 +323,16 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
// get the map from sharded input dimensions to output dimensions.
std::vector<int64_t> dim_map_src2tgt(ndim, -1);
for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
DimTrans* dim = GetDimTrans(dim_trans[i],
&shardable,
&seen_input_dims,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims);
std::shared_ptr<DimTrans> dim = GetDimTrans(dim_trans[i],
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims,
&shardable,
&seen_input_dims);
if (dim != nullptr && dim->type() == DimTrans::Type::INPUTDIM) {
InputDim* inputdim = dynamic_cast<InputDim*>(dim);
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(dim);
dim_map_src2tgt[inputdim->input_dim()] = i;
}
}
(Diffs for the remaining 4 changed files are not shown.)
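
For orientation, here is how the new factory API might be driven by a reshape-like rule, assuming the declarations in the diff above; the tensor shapes and variable names are hypothetical, not taken from the actual call sites. Note that both Split nodes share the same InputDim node, whose lifetime is now tied to its references rather than to a global CleanUp() call.

// Hypothetical dim_trans construction for reshaping [36, 8] -> [6, 6, 8],
// using the shared_ptr factories from this diff (call sites are illustrative).
std::shared_ptr<DimTrans> in0 = std::make_shared<InputDim>(0);
std::vector<std::shared_ptr<DimTrans>> dim_trans = {
    make_split(in0, {6, 6}, 0),    // output axis 0: piece 0 of the split
    make_split(in0, {6, 6}, 1),    // output axis 1: piece 1 of the split
    std::make_shared<InputDim>(1)  // output axis 2: input axis 1 unchanged
};
// Per the comment on GetDimTrans, InferFromDimTrans(input, dim_trans) then
// propagates a sharding on input axis 0 only to output axis 0 (the leftmost
// split piece) and marks output axis 1 as replicated (-1).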
