Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
WintersMontagne10335 committed Oct 28, 2023
1 parent 1fbb2cf commit ea1e6fc
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 86 deletions.
39 changes: 19 additions & 20 deletions paddle/phi/infermeta/spmd_rules/squeeze.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ void MakeSqueezeDimTransWithAxis(const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& axis,
std::vector<DimTrans*>* trans) {
for (int64_t i = 0, n = static_cast<int64_t>(x_shape.size()); i < n; i++) {
trans->emplace_back(new InputDim(i));
out_shape->emplace_back(x_shape[i]);
}

for (int64_t i = 0, n = static_cast<int64_t>(axis.size()); i < n; i++) {
if (x_shape[axis[i]] == 1) {
trans->erase(trans->begin() + axis[i]);
out_shape->erase(out_shape->begin() + axis[i]);
if (x_shape[i] == 1) {
auto it = find(axis.begin(), axis.end(), i);
if (it == axis.end()) {
trans->emplace_back(new Singleton());
out_shape->emplace_back(1);
}
} else {
trans->emplace_back(new InputDim(i));
out_shape->emplace_back(x_shape[i]);
}
}
}
Expand All @@ -73,21 +74,21 @@ void MakeSqueezeDimTransReverseWithAxis(const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& axis,
std::vector<DimTrans*>* trans) {
for (int64_t i = 0, n = static_cast<int64_t>(out_shape.size()); i < n; i++) {
trans->emplace_back(new InputDim(i));
}
for (int64_t i = 0, j = 0, n = static_cast<int64_t>(x_shape.size()); i < n;
i++) {
if (x_shape[i] == 1) {
trans->emplace_back(new Singleton());

for (int64_t i = 0, n = static_cast<int64_t>(axis.size()); i < n; i++) {
if (x_shape[axis[i]] == 1) {
trans->emplace(trans->begin() + axis[i], new Singleton());
auto it = find(axis.begin(), axis.end(), i);
if (it == axis.end()) {
j++;
}
} else {
trans->emplace_back(new InputDim(j++));
}
}
}

bool SqueezeCompare(const int64_t& a, const int64_t& b) { return a > b; }

bool SqueezeReverseCompare(const int64_t& a, const int64_t& b) { return a < b; }

SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
const std::vector<int64_t>& axis) {
// Step0: Verify input args based on squeeze logic
Expand Down Expand Up @@ -120,7 +121,6 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
axis_copy[i] += x_ndim;
}
}
std::sort(axis_copy.begin(), axis_copy.end(), SqueezeCompare);
MakeSqueezeDimTransWithAxis(x_shape, &out_shape, axis_copy, &trans);
}

Expand Down Expand Up @@ -189,7 +189,6 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
axis_copy[i] += x_ndim;
}
}
std::sort(axis_copy.begin(), axis_copy.end(), SqueezeReverseCompare);
MakeSqueezeDimTransReverseWithAxis(x_shape, out_shape, axis_copy, &trans);
}

Expand Down
165 changes: 99 additions & 66 deletions test/auto_parallel/spmd_rules/test_squeeze_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,22 @@ def setUp(self):
self.attrs = OrderedDict()

def test_squeeze_infer_forward(self):
# # shape: [1, 8, 1, 16] --> [8, 16]
# # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1]
# self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
# self.attrs['axis'] = []
# result_dist_attrs = self.rule.infer_forward(
# self.x_dist_tensor_spec, self.attrs['axis']
# )
# infered_input_dist_attrs = result_dist_attrs[0]
# infered_output_dist_attrs = result_dist_attrs[1]

# self.assertEqual(len(infered_input_dist_attrs), 1)
# self.assertEqual(len(infered_output_dist_attrs), 1)
# self.assertEqual(
# infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
# )
# self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])
# shape: [1, 8, 1, 16] --> [8, 16]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
self.attrs['axis'] = []
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])

# shape: [1, 8, 1, 16] --> [8, 16]
# dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1]
Expand Down Expand Up @@ -99,20 +99,20 @@ def test_squeeze_infer_forward(self):
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])

# # shape: [1, 8, 1, 16] --> [8, 16]
# # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0]
# self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0])
# self.attrs['axis'] = []
# result_dist_attrs = self.rule.infer_forward(
# self.x_dist_tensor_spec, self.attrs['axis']
# )
# infered_input_dist_attrs = result_dist_attrs[0]
# infered_output_dist_attrs = result_dist_attrs[1]

# self.assertEqual(
# infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
# )
# self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])
# shape: [1, 8, 1, 16] --> [8, 16]
# dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0]
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0])
self.attrs['axis'] = []
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])

# shape: [1, 8, 1, 16] --> [8, 16]
# dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0]
Expand Down Expand Up @@ -159,6 +159,21 @@ def test_squeeze_infer_forward(self):
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0])

# shape: [1, 8, 1, 16] --> [8, 1, 16]
# dims_mapping: [-1, 0, 1, -1] --> [-1, 0, -1, -1] [0, -1, -1]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1])
self.attrs['axis'] = [0, 1]
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])

def test_squeeze_infer_backward(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

Expand All @@ -169,25 +184,25 @@ def test_squeeze_infer_backward(self):
[8, 16], output_tensor_dist_attr
)

# # shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# # dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output)
# self.output_dist_tensor_spec.shape = [8, 16]
# self.output_dist_tensor_spec.set_dims_mapping([0, 1])
# self.attrs['axis'] = []
# result_dist_attrs = self.rule.infer_backward(
# self.x_dist_tensor_spec,
# self.output_dist_tensor_spec,
# self.attrs['axis'],
# )
# infered_input_dist_attrs = result_dist_attrs[0]
# infered_output_dist_attrs = result_dist_attrs[1]

# self.assertEqual(len(infered_input_dist_attrs), 1)
# self.assertEqual(len(infered_output_dist_attrs), 1)
# self.assertEqual(
# infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
# )
# self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])
# shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 16]
self.output_dist_tensor_spec.set_dims_mapping([0, 1])
self.attrs['axis'] = []
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1])

# shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output)
Expand Down Expand Up @@ -243,23 +258,23 @@ def test_squeeze_infer_backward(self):
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])

# # shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# # dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output)
# self.output_dist_tensor_spec.shape = [8, 16]
# self.output_dist_tensor_spec.set_dims_mapping([1, 0])
# self.attrs['axis'] = []
# result_dist_attrs = self.rule.infer_backward(
# self.x_dist_tensor_spec,
# self.output_dist_tensor_spec,
# self.attrs['axis'],
# )
# infered_input_dist_attrs = result_dist_attrs[0]
# infered_output_dist_attrs = result_dist_attrs[1]

# self.assertEqual(
# infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
# )
# self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])
# shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 16]
self.output_dist_tensor_spec.set_dims_mapping([1, 0])
self.attrs['axis'] = []
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])

# shape: [1, 8, 1, 16] --> [8, 16] (input --> output)
# dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output)
Expand Down Expand Up @@ -315,6 +330,24 @@ def test_squeeze_infer_backward(self):
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0])

# shape: [1, 8, 1, 16] --> [8, 1, 16] (input --> output)
# dims_mapping: [1, 0, -1] --> [-1, 1, -1, -1], [1, -1, -1] (output --> input, output)
self.output_dist_tensor_spec.shape = [8, 1, 16]
self.output_dist_tensor_spec.set_dims_mapping([1, 0, -1])
self.attrs['axis'] = [-4]
result_dist_attrs = self.rule.infer_backward(
self.x_dist_tensor_spec,
self.output_dist_tensor_spec,
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])


if __name__ == "__main__":
unittest.main()

0 comments on commit ea1e6fc

Please sign in to comment.