From 6a209b97d16b34cf3acb9d40e59157371ca1e9be Mon Sep 17 00:00:00 2001 From: Winters Montagne <118546135+WintersMontagne10335@users.noreply.github.com> Date: Mon, 27 Nov 2023 16:09:08 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=205th=20No.52=E3=80=91=20?= =?UTF-8?q?=E4=B8=BA=20Paddle=20=E6=96=B0=E5=A2=9E=20squeeze=20=E5=92=8C?= =?UTF-8?q?=20unsqueeze=20=E7=9A=84=20spmd=20=E5=88=87=E5=88=86=E6=8E=A8?= =?UTF-8?q?=E5=AF=BC=E8=A7=84=E5=88=99=20(#57877)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add spmd segmentation and derivation rules for squeeze to Paddle * Add spmd segmentation derivation rule for unsqueeze to Paddle * fix bugs * fix bugs * fix bugs * fix bugs * Add unit test code * modify squeeze.cc and CMakeLists.txt * write separate rules * fix bugs * fix bugs * fix bugs * remove unsqueeze spmd rule * modified: test/auto_parallel/spmd_rules/test_squeeze_rule.py * re-run CI * fix bugs * modify pointer to smart pointer * fix bugs * fix bugs --- paddle/phi/infermeta/spmd_rules/rules.h | 5 + paddle/phi/infermeta/spmd_rules/squeeze.cc | 222 +++++++++++ paddle/phi/infermeta/spmd_rules/squeeze.h | 32 ++ test/auto_parallel/spmd_rules/CMakeLists.txt | 1 + .../spmd_rules/test_squeeze_rule.py | 353 ++++++++++++++++++ 5 files changed, 613 insertions(+) create mode 100644 paddle/phi/infermeta/spmd_rules/squeeze.cc create mode 100644 paddle/phi/infermeta/spmd_rules/squeeze.h create mode 100644 test/auto_parallel/spmd_rules/test_squeeze_rule.py diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index f8b8430a6dafe1..dd98f793ea8a53 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -33,6 +33,7 @@ limitations under the License. 
*/ #include "paddle/phi/infermeta/spmd_rules/slice.h" #include "paddle/phi/infermeta/spmd_rules/softmax.h" #include "paddle/phi/infermeta/spmd_rules/split.h" +#include "paddle/phi/infermeta/spmd_rules/squeeze.h" #include "paddle/phi/infermeta/spmd_rules/stack.h" #include "paddle/phi/infermeta/spmd_rules/transpose.h" #include "paddle/phi/infermeta/spmd_rules/triu.h" @@ -520,6 +521,10 @@ PD_REGISTER_SPMD_RULE(reshape2, PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); +// squeeze rule +PD_REGISTER_SPMD_RULE(squeeze, + PD_INFER_SPMD(phi::distributed::SqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::SqueezeInferSpmdReverse)); // flatten rule PD_REGISTER_SPMD_RULE(flatten, PD_INFER_SPMD(phi::distributed::FlattenInferSpmd), diff --git a/paddle/phi/infermeta/spmd_rules/squeeze.cc b/paddle/phi/infermeta/spmd_rules/squeeze.cc new file mode 100644 index 00000000000000..046de2e0497605 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/squeeze.cc @@ -0,0 +1,222 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/spmd_rules/squeeze.h" +#include <algorithm> +#include <numeric> + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/dim_trans.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +void MakeSqueezeDimTransWithoutAxis( + const std::vector<int64_t>& x_shape, + std::vector<int64_t>* out_shape, + std::vector<std::shared_ptr<DimTrans>>* trans) { + for (int64_t i = 0, n = static_cast<int64_t>(x_shape.size()); i < n; i++) { + if (x_shape[i] != 1) { + trans->emplace_back(std::make_shared<InputDim>(i)); + out_shape->emplace_back(x_shape[i]); + } + } +} + +void MakeSqueezeDimTransWithAxis( + const std::vector<int64_t>& x_shape, + std::vector<int64_t>* out_shape, + const std::vector<int64_t>& axis, + std::vector<std::shared_ptr<DimTrans>>* trans) { + for (int64_t i = 0, n = static_cast<int64_t>(x_shape.size()); i < n; i++) { + if (x_shape[i] == 1) { + auto it = find(axis.begin(), axis.end(), i); + if (it == axis.end()) { + trans->emplace_back(std::make_shared<Singleton>()); + out_shape->emplace_back(1); + } + } else { + trans->emplace_back(std::make_shared<InputDim>(i)); + out_shape->emplace_back(x_shape[i]); + } + } +} + +void MakeSqueezeDimTransReverseWithoutAxis( + const std::vector<int64_t>& x_shape, + std::vector<std::shared_ptr<DimTrans>>* trans) { + for (int64_t i = 0, j = 0, n = static_cast<int64_t>(x_shape.size()); i < n; + i++) { + if (x_shape[i] != 1) { + trans->emplace_back(std::make_shared<InputDim>(j++)); + } else { + trans->emplace_back(std::make_shared<Singleton>()); + } + } +} + +void MakeSqueezeDimTransReverseWithAxis( + const std::vector<int64_t>& x_shape, + const std::vector<int64_t>& out_shape, + const std::vector<int64_t>& axis, + std::vector<std::shared_ptr<DimTrans>>* trans) { + for (int64_t i = 0, j = 0, n = static_cast<int64_t>(x_shape.size()); i < n; + i++) { + if (x_shape[i] == 1) { + trans->emplace_back(std::make_shared<Singleton>()); + + auto it = find(axis.begin(), axis.end(), i); + if (it == 
axis.end()) { + j++; + } + } else { + trans->emplace_back(std::make_shared<InputDim>(j++)); + } + } +} + +SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x, + const std::vector<int64_t>& axis) { + // Step0: Verify input args based on squeeze logic + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = x_shape.size(); + auto x_dist_attr_src = x.dist_attr(); + std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping(); + + PADDLE_ENFORCE_EQ( + x_ndim, + x_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's " + "dims_mapping size [%d] are not matched.", + x_ndim, + x_dims_mapping.size())); + + // Step1: Build the transformation from + // the original shape to the target shape + + std::vector<std::shared_ptr<DimTrans>> trans; + std::vector<int64_t> out_shape; + + if (static_cast<int64_t>(axis.size()) == 0) { + MakeSqueezeDimTransWithoutAxis(x_shape, &out_shape, &trans); + } else { + std::vector<int64_t> axis_copy(axis); + for (int64_t i = 0, n = static_cast<int64_t>(axis_copy.size()); i < n; + i++) { + if (axis_copy[i] < 0) { + axis_copy[i] += x_ndim; + } + } + MakeSqueezeDimTransWithAxis(x_shape, &out_shape, axis_copy, &trans); + } + + // Step2: Infer the dims mapping of input (if reshard is + // needed) and output from the dimension transformation. + std::vector<std::vector<int64_t>> dims_mapping_vec = + InferFromDimTrans(x, trans); + + // Step3: Update the dist attributes of input + // and output with the inferred dims mapping. 
+ TensorDistAttr x_dist_attr_dst(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr out_dist_attr(x_dist_attr_src); + out_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "SqueezeInferSpmd: X shape: [" << str_join(x_shape) + << "] Out shape: [" << str_join(out_shape) << "]"; + VLOG(4) << "Transformation from input to output:"; + for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) { + VLOG(4) << "\tOut axis[" << i << "]: " << trans[i]->to_string(); + } + VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping) + << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) + << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1]) + << "]\n\n"; + + return {{x_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + const std::vector<int64_t>& axis) { + // Step0: Verify input args based on squeeze logic + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = x_shape.size(); + auto out_shape = phi::vectorize(out.dims()); + int out_ndim = out_shape.size(); + auto out_dist_attr_src = out.dist_attr(); + std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + out_ndim, + out_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's " + "dims_mapping size [%d] are not matched.", + out_ndim, + out_dims_mapping.size())); + + // Step1: Build the transformation from the output shape + // to original shape. This function infers the dims mapping + // from output to input, we first get the transformation + // from output to input so that we can infer the dims mapping + // with the map from output axes to input axes. 
+ + std::vector<std::shared_ptr<DimTrans>> trans; + + if (static_cast<int64_t>(axis.size()) == 0) { + MakeSqueezeDimTransReverseWithoutAxis(x_shape, &trans); + } else { + std::vector<int64_t> axis_copy(axis); + for (int64_t i = 0, n = static_cast<int64_t>(axis_copy.size()); i < n; + i++) { + if (axis_copy[i] < 0) { + axis_copy[i] += x_ndim; + } + } + MakeSqueezeDimTransReverseWithAxis(x_shape, out_shape, axis_copy, &trans); + } + + // Step2: Infer the dims mapping of input with + // output's dims_mapping and the transformation. + std::vector<std::vector<int64_t>> dims_mapping_vec = + InferFromDimTrans(out, trans); + + // Step3: Update the dist attributes of input + // and output with the inferred dims mapping + TensorDistAttr out_dist_attr_dst(out_dist_attr_src); + out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr x_dist_attr(x.dist_attr()); + x_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "SqueezeInferSpmdReverse: Out shape: [" << str_join(out_shape) + << "] X shape: [" << str_join(x_shape) << "]"; + VLOG(4) << "Transformation from output to input:"; + for (int64_t i = 0, n = trans.size(); i < n; i++) { + VLOG(4) << "\tX axis[" << i << "]: " << trans[i]->to_string(); + } + VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] " + << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]"; + VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n"; + + return {{x_dist_attr}, {out_dist_attr_dst}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/squeeze.h b/paddle/phi/infermeta/spmd_rules/squeeze.h new file mode 100644 index 00000000000000..b111c3272612fd --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/squeeze.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include <vector> + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x, + const std::vector<int64_t>& axis); + +SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + const std::vector<int64_t>& axis); +} // namespace distributed +} // namespace phi diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index 80207b104dd5e7..d8f7c4abe4213a 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -18,6 +18,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_default_data_parallel_rule MODULES test_default_data_parallel_rule) py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule) + py_test_modules(test_squeeze_rule MODULES test_squeeze_rule) py_test_modules(test_slice_rule MODULES test_slice_rule) py_test_modules(test_flatten_rule MODULES test_flatten_rule) py_test_modules(test_unsqueeze_rule MODULES test_unsqueeze_rule) diff --git a/test/auto_parallel/spmd_rules/test_squeeze_rule.py b/test/auto_parallel/spmd_rules/test_squeeze_rule.py new file mode 100644 index 00000000000000..1aff4012836cb2 --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_squeeze_rule.py @@ -0,0 +1,353 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from collections import OrderedDict + +from paddle.distributed.auto_parallel.static.dist_attribute import ( + DistTensorSpec, + TensorDistAttr, +) +from paddle.distributed.fleet import auto +from paddle.framework import core + + +class TestSqueezeSPMDRule(unittest.TestCase): + def setUp(self): + self.rule = core.get_phi_spmd_rule("squeeze") + + x_shape = [1, 8, 1, 16] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + self.attrs = OrderedDict() + + def test_squeeze_infer_forward(self): + # # shape: [1, 8, 1, 16] --> [8, 16] + # # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1] + # self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + # self.attrs['axis'] = [] + # result_dist_attrs = self.rule.infer_forward( + # self.x_dist_tensor_spec, self.attrs['axis'] + # ) + # infered_input_dist_attrs = result_dist_attrs[0] + # infered_output_dist_attrs = result_dist_attrs[1] + + # self.assertEqual(len(infered_input_dist_attrs), 1) + # self.assertEqual(len(infered_output_dist_attrs), 1) + # self.assertEqual( + # infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + # ) + # self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + # shape: [1, 8, 1, 16] 
--> [8, 16] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['axis'] = [0, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + # shape: [1, 8, 1, 16] --> [1, 8, 16] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [-1, 0, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [0, -1, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['axis'] = [-4] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # # shape: [1, 8, 1, 16] --> [8, 16] + # # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0] + # self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0]) + # self.attrs['axis'] = [] + # result_dist_attrs = self.rule.infer_forward( + # self.x_dist_tensor_spec, self.attrs['axis'] + # ) + # infered_input_dist_attrs = result_dist_attrs[0] + # infered_output_dist_attrs = 
result_dist_attrs[1] + + # self.assertEqual( + # infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + # ) + # self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 16] + # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0]) + self.attrs['axis'] = [0, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + + # shape: [1, 8, 1, 16] --> [1, 8, 16] + # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [-1, 1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] + # dims_mapping: [-1, 1, -1, 0] --> [-1, 1, -1, 0] [1, -1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0]) + self.attrs['axis'] = [-4] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] + # dims_mapping: [-1, 0, 1, -1] --> [-1, 0, -1, -1] [0, -1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1]) + 
self.attrs['axis'] = [0, 1] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['axis'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1]) + + def test_squeeze_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + self.output_dist_tensor_spec = DistTensorSpec( + [8, 16], output_tensor_dist_attr + ) + + # # shape: [1, 8, 1, 16] --> [8, 16] (input --> output) + # # dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output) + # self.output_dist_tensor_spec.shape = [8, 16] + # self.output_dist_tensor_spec.set_dims_mapping([0, 1]) + # self.attrs['axis'] = [] + # result_dist_attrs = self.rule.infer_backward( + # self.x_dist_tensor_spec, + # self.output_dist_tensor_spec, + # self.attrs['axis'], + # ) + # infered_input_dist_attrs = result_dist_attrs[0] + # infered_output_dist_attrs = result_dist_attrs[1] + + # self.assertEqual(len(infered_input_dist_attrs), 1) + # self.assertEqual(len(infered_output_dist_attrs), 1) + # self.assertEqual( + # infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + # ) + # self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + # shape: [1, 8, 1, 16] --> [8, 16] (input --> output) + # dims_mapping: [0, 1] --> [-1, 0, -1, 1], [0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['axis'] = [0, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + 
infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + # shape: [1, 8, 1, 16] --> [1, 8, 16] (input --> output) + # dims_mapping: [-1, 0, 1] --> [-1, 0, -1, 1], [-1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] (input --> output) + # dims_mapping: [0, -1, 1] --> [-1, 0, -1, 1], [0, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 1, 16] + self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1]) + self.attrs['axis'] = [-4] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # # shape: [1, 8, 1, 16] --> [8, 16] (input --> output) + # # dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output) + # self.output_dist_tensor_spec.shape = [8, 16] + # self.output_dist_tensor_spec.set_dims_mapping([1, 0]) + # self.attrs['axis'] = [] + # result_dist_attrs = self.rule.infer_backward( + # self.x_dist_tensor_spec, + # self.output_dist_tensor_spec, + # self.attrs['axis'], + # ) + # 
infered_input_dist_attrs = result_dist_attrs[0] + # infered_output_dist_attrs = result_dist_attrs[1] + + # self.assertEqual( + # infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + # ) + # self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 16] (input --> output) + # dims_mapping: [1, 0] --> [-1, 1, -1, 0], [1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, 0]) + self.attrs['axis'] = [0, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + + # shape: [1, 8, 1, 16] --> [1, 8, 16] (input --> output) + # dims_mapping: [-1, 1, 0] --> [-1, 1, -1, 0], [-1, 1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 8, 16] + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0]) + self.attrs['axis'] = [1, 2] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] (input --> output) + # dims_mapping: [1, -1, 0] --> [-1, 1, -1, 0], [1, -1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 1, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, -1, 0]) + self.attrs['axis'] = [-4] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + 
self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]) + + # shape: [1, 8, 1, 16] --> [8, 1, 16] (input --> output) + # dims_mapping: [1, 0, -1] --> [-1, 1, -1, -1], [1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 1, 16] + self.output_dist_tensor_spec.set_dims_mapping([1, 0, -1]) + self.attrs['axis'] = [-4] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1]) + + +if __name__ == "__main__": + unittest.main()