From 686114a447f1ce54ee8c9e67c981a4376c9bbe2e Mon Sep 17 00:00:00 2001 From: risemeup1 <515586620@qq.com> Date: Wed, 20 Mar 2024 03:00:10 +0000 Subject: [PATCH 1/7] support sparsecootensortype --- .../pir/dialect/operator/ir/op_dialect.cc | 2 + paddle/pir/include/core/builtin_type.h | 13 ++ paddle/pir/include/core/sparse_type.h | 83 +++++++++++ paddle/pir/include/core/sparse_type_storage.h | 139 ++++++++++++++++++ paddle/pir/src/core/sparse_type.cc | 64 ++++++++ test/cpp/pir/core/type_test.cc | 34 +++++ 6 files changed, 335 insertions(+) create mode 100644 paddle/pir/include/core/sparse_type.h create mode 100644 paddle/pir/include/core/sparse_type_storage.h create mode 100644 paddle/pir/src/core/sparse_type.cc diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 3d3ef1efb354b6..e9305acb66ff1e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -28,6 +28,7 @@ #include "paddle/pir/include/core/builtin_type_interfaces.h" #include "paddle/pir/include/core/interface_value.h" #include "paddle/pir/include/core/ir_printer.h" +#include "paddle/pir/include/core/sparse_type.h" #include "paddle/pir/include/core/utils.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" @@ -279,6 +280,7 @@ void PrintOperationImpl(pir::Operation* op, void OperatorDialect::initialize() { RegisterTypes(); RegisterAttributes +struct hash { + std::size_t operator()(const pir::DenseTensorType &obj) const { + // return + // pir::DenseTensorTypeStorage::HashValue(std::make_tuple(pir::Type(), + // pir::DDim(), pir::DataLayout(), pir::LoD(), size_t())); + return pir::DenseTensorTypeStorage::HashValue(std::tuple( + obj.dtype(), obj.dims(), obj.data_layout(), obj.lod(), obj.offset())); + } +}; +} // namespace std diff --git a/paddle/pir/include/core/sparse_type.h b/paddle/pir/include/core/sparse_type.h new file mode 100644 index 00000000000000..6cf1bb95d301f7 --- /dev/null +++ b/paddle/pir/include/core/sparse_type.h @@ -0,0 +1,83 @@ + +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/include/core/sparse_type_storage.h" +#include "paddle/pir/include/core/type.h" +namespace paddle { +namespace dialect { +/// +/// \brief Define built-in parameterless types. +/// +/// NOTE(zhangbo9674): If you need to directly +/// cache the object of this built-in type in IrContext, please overload the get +/// method, and construct and cache the object in IrContext. For the specific +/// implementation method, please refer to Float16Type. +/// +/// The built-in type object get method is as follows: +/// \code{cpp} +/// pir::IrContext *ctx = pir::IrContext::Instance(); +/// Type fp32 = Float32Type::get(ctx); +/// \endcode +/// + +// NOTE(dev): Currently Int8 are not considered as a cached member +// in IrContextImpl because it is not widely used. +class IR_API SparseCooTensorType + : public pir::Type:: + TypeBase { + public: + using Base::Base; + using Type = pir::Type; + using Dim = SparseCooTensorTypeStorage::Dim; + using DataLayout = pir::DataLayout; + using DenseTensorType = pir::DenseTensorType; + + Type dtype() const; + const Dim &dims() const; + DataLayout data_layout() const; + DenseTensorType get_indices() const; + DenseTensorType get_elements() const; + bool get_coalesced() const; + + /// + /// \brief Implementation of 'classof' that compares the type id of + /// the provided value with the concrete type id. + /// + static bool classof(Type type); + + static SparseCooTensorType dyn_cast_impl(Type type); + + static SparseCooTensorType get(pir::IrContext *ctx, + Type dtype, + const Dim &dims, + DataLayout layout, + DenseTensorType non_zero_indices, + DenseTensorType non_zero_elements, + bool coalesced = false) { + return Base::get(ctx, + dtype, + dims, + layout, + non_zero_indices, + non_zero_elements, + coalesced); + } +}; +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType) diff --git a/paddle/pir/include/core/sparse_type_storage.h b/paddle/pir/include/core/sparse_type_storage.h new file mode 100644 index 00000000000000..01b523f3c02538 --- /dev/null +++ b/paddle/pir/include/core/sparse_type_storage.h @@ -0,0 +1,139 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/common/ddim.h" +#include "paddle/common/dim.h" +#include "paddle/common/hash_funcs.h" +#include "paddle/common/layout.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/core/type.h" +#include "paddle/pir/include/core/type_base.h" +#include "paddle/pir/include/core/utils.h" + +namespace paddle { +namespace dialect { +/// +/// \brief Define Parametric TypeStorage for SparseCooTensorType. +/// +/// NOTE(risemeup1): The derived TypeStorage class needs to implement the +/// following methods: (1)declare ParamKey, (2)define Construction method, +/// (3)define HashValue method, (4)overload operator==. +/// + +struct SparseCooTensorTypeStorage : public pir::TypeStorage { + /// + /// \brief Declare ParamKey according to parameter type. + /// + using Dim = pir::DDim; + using DataLayout = pir::DataLayout; + using DataType = pir::Type; + using DenseTensorType = pir::DenseTensorType; + using ParamKey = std:: + tuple; + SparseCooTensorTypeStorage(DataType dtype, + Dim dims, + DataLayout layout, + DenseTensorType non_zero_indices, + DenseTensorType non_zero_elements, + bool coalesced = false) + : dtype_(dtype), + dims_(dims), + layout_(layout), + non_zero_indices_(non_zero_indices), + non_zero_elements_(non_zero_elements), + coalesced_(coalesced) {} + + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static SparseCooTensorTypeStorage* Construct(const ParamKey& key) { + return new SparseCooTensorTypeStorage(std::get<0>(key), + std::get<1>(key), + std::get<2>(key), + std::get<3>(key), + std::get<4>(key), + std::get<5>(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey& key) { + std::size_t hash_value = 0; + // hash dtype + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<0>(key))); + // hash dims + hash_value = pir::detail::hash_combine(hash_value, + std::hash()(std::get<1>(key))); + // hash layout + hash_value = pir::detail::hash_combine( + hash_value, + std::hash::type>()( + static_cast::type>( + std::get<2>(key)))); + // hash DenseTensorType + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<3>(key))); + // hash DenseTensorType + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<4>(key))); + + // hash coalesced + hash_value = pir::detail::hash_combine(hash_value, + std::hash()(std::get<5>(key))); + + return hash_value; + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return ParamKey(dtype_, + dims_, + layout_, + non_zero_indices_, + non_zero_elements_, + coalesced_) == key; + } + + ParamKey GetAsKey() const { + return ParamKey(dtype_, + dims_, + layout_, + non_zero_indices_, + non_zero_elements_, + coalesced_); + } + + /// + /// \brief SparseCooTensorTypeStorage include six parameters: dims, dtype, + /// layout, non_zero_indices_, non_zero_elements_,coalesced_. + /// + + DataType dtype_; + Dim dims_; + DataLayout layout_{DataLayout::NCHW}; + DenseTensorType non_zero_indices_; + DenseTensorType non_zero_elements_; + bool coalesced_ = false; +}; +} // namespace dialect +} // namespace paddle diff --git a/paddle/pir/src/core/sparse_type.cc b/paddle/pir/src/core/sparse_type.cc new file mode 100644 index 00000000000000..4e39a84ebf910b --- /dev/null +++ b/paddle/pir/src/core/sparse_type.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/include/core/sparse_type.h" + +namespace paddle { +namespace dialect { +pir::Type SparseCooTensorType::dtype() const { return storage()->dtype_; } + +const SparseCooTensorType::Dim& SparseCooTensorType::dims() const { + return storage()->dims_; +} + +DataLayout SparseCooTensorType::data_layout() const { + return storage()->layout_; +} + +pir::DenseTensorType SparseCooTensorType::get_indices() const { + return storage()->non_zero_indices_; +} + +pir::DenseTensorType SparseCooTensorType::get_elements() const { + return storage()->non_zero_elements_; +} + +bool SparseCooTensorType::get_coalesced() const { + return storage()->coalesced_; +} + +bool SparseCooTensorType::classof(Type type) { + if (type) { + if (type.type_id() == type_id()) return true; + if (auto wrap_type = type.dyn_cast()) { + return classof(wrap_type.prim_type()); + } + } + return false; +} + +SparseCooTensorType SparseCooTensorType::dyn_cast_impl(Type type) { + if (type) { + if (type.type_id() == type_id()) return SparseCooTensorType(type.storage()); + if (auto wrap_type = type.dyn_cast()) { + return dyn_cast_impl(wrap_type.prim_type()); + } + } + return nullptr; +} + +} // namespace dialect +} // namespace paddle + +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType) diff --git a/test/cpp/pir/core/type_test.cc b/test/cpp/pir/core/type_test.cc index 9a7f70b779191a..93fd7d9e5d9be3 100644 --- a/test/cpp/pir/core/type_test.cc +++ b/test/cpp/pir/core/type_test.cc @@ -21,6 +21,7 @@ #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/dialect.h" #include "paddle/pir/include/core/ir_context.h" +#include "paddle/pir/include/core/sparse_type.h" #include "paddle/pir/include/core/type.h" #include "paddle/pir/include/core/type_base.h" #include "paddle/pir/include/core/type_name.h" @@ -249,6 +250,39 @@ TEST(type_test, custom_type_dialect) { EXPECT_EQ(dialect_integer1, dialect_integer2); } +TEST(type_test, sparse_dialect) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + common::DDim dims = {4, 4}; + common::DataLayout data_layout = common::DataLayout::NCHW; + pir::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + pir::DenseTensorType none_zero_indices = pir::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset); + pir::DenseTensorType none_zero_elements = pir::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset); + bool coalesced = false; + pir::Type pir_type = + paddle::dialect::SparseCooTensorType::get(ctx, + fp32_dtype, + dims, + data_layout, + none_zero_indices, + none_zero_elements, + coalesced); + + paddle::dialect::SparseCooTensorType sparse_coo_tensor_type = + paddle::dialect::SparseCooTensorType::dyn_cast_impl(pir_type); + + EXPECT_EQ(sparse_coo_tensor_type.isa(), + true); + EXPECT_EQ(sparse_coo_tensor_type.dims(), dims); + EXPECT_EQ(sparse_coo_tensor_type.data_layout(), data_layout); + EXPECT_EQ(sparse_coo_tensor_type.get_indices(), none_zero_indices); + EXPECT_EQ(sparse_coo_tensor_type.get_elements(), none_zero_elements); +} + TEST(type_test, pd_op_dialect) { pir::IrContext *ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); From b13414a53fdee507cd87e297ae6fa762889202f3 Mon Sep 17 00:00:00 2001 From: risemeup1 <515586620@qq.com> Date: Wed, 20 Mar 2024 07:59:37 +0000 Subject: [PATCH 2/7] support sparsecootensortype --- .../pir/dialect/operator/ir/op_dialect.cc | 1 - .../fluid/pir/dialect/operator/ir/op_type.cc | 47 ++++++ .../fluid/pir/dialect/operator/ir/op_type.h | 46 ++++++ .../pir/dialect/operator/ir/type_storage.h | 124 ++++++++++++++++ paddle/pir/include/core/builtin_type.h | 13 -- paddle/pir/include/core/sparse_type.h | 83 ----------- paddle/pir/include/core/sparse_type_storage.h | 139 ------------------ paddle/pir/src/core/sparse_type.cc | 64 -------- test/cpp/pir/core/type_test.cc | 10 +- 9 files changed, 223 insertions(+), 304 deletions(-) delete mode 100644 paddle/pir/include/core/sparse_type.h delete mode 100644 paddle/pir/include/core/sparse_type_storage.h delete mode 100644 paddle/pir/src/core/sparse_type.cc diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index e9305acb66ff1e..579a5f81e9c8e2 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -28,7 +28,6 @@ #include "paddle/pir/include/core/builtin_type_interfaces.h" #include "paddle/pir/include/core/interface_value.h" #include "paddle/pir/include/core/ir_printer.h" -#include "paddle/pir/include/core/sparse_type.h" #include "paddle/pir/include/core/utils.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.cc b/paddle/fluid/pir/dialect/operator/ir/op_type.cc index 3e3902a86376e8..ed8afd48f41427 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.cc @@ -78,8 +78,55 @@ DenseTensorArrayType DenseTensorArrayType::dyn_cast_impl(Type type) { return nullptr; } +pir::Type SparseCooTensorType::dtype() const { return storage()->dtype_; } + +const SparseCooTensorType::Dim& SparseCooTensorType::dims() const { + return storage()->dims_; +} + +const SparseCooTensorType::Dim& SparseCooTensorType::meta_dims() const { + return storage()->meta_dims_; +} + +DataLayout SparseCooTensorType::data_layout() const { + return storage()->layout_; +} + +pir::DenseTensorType SparseCooTensorType::non_zero_indices() const { + return storage()->non_zero_indices_; +} + +pir::DenseTensorType SparseCooTensorType::non_zero_elements() const { + return storage()->non_zero_elements_; +} + +bool SparseCooTensorType::coalesced() const { return storage()->coalesced_; } + +bool SparseCooTensorType::classof(Type type) { + if (type) { + if (type.type_id() == type_id()) return true; + if (auto wrap_type = type.dyn_cast()) { + return classof(wrap_type.prim_type()); + } + } + return false; +} + +SparseCooTensorType SparseCooTensorType::dyn_cast_impl(Type type) { + if (type) { + if (type.type_id() == type_id()) { + return SparseCooTensorType(type.storage()); + } + if (auto wrap_type = type.dyn_cast()) { + return dyn_cast_impl(wrap_type.prim_type()); + } + } + return nullptr; +} + } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorArrayType) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.h b/paddle/fluid/pir/dialect/operator/ir/op_type.h index 4cc68b6d9fd7ae..9202f049b525c8 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.h @@ -74,8 +74,54 @@ class DenseTensorArrayType static DenseTensorArrayType dyn_cast_impl(Type type); }; +class IR_API SparseCooTensorType + : public pir::Type:: + TypeBase { + public: + using Base::Base; + using Type = pir::Type; + using Dim = SparseCooTensorTypeStorage::Dim; + using DataLayout = common::DataLayout; + using DenseTensorType = pir::DenseTensorType; + + Type dtype() const; + const Dim &dims() const; + const Dim &meta_dims() const; + DataLayout data_layout() const; + DenseTensorType non_zero_indices() const; + DenseTensorType non_zero_elements() const; + bool coalesced() const; + + /// + /// \brief Implementation of 'classof' that compares the type id of + /// the provided value with the concrete type id. + /// + static bool classof(Type type); + + static SparseCooTensorType dyn_cast_impl(Type type); + + static SparseCooTensorType get(pir::IrContext *ctx, + Type dtype, + const Dim &dims, + const Dim &meta_dims, + DataLayout layout, + DenseTensorType non_zero_indices, + DenseTensorType non_zero_elements, + bool coalesced = false) { + return Base::get(ctx, + dtype, + dims, + meta_dims, + layout, + non_zero_indices, + non_zero_elements, + coalesced); + } +}; + } // namespace dialect } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorArrayType) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType) diff --git a/paddle/fluid/pir/dialect/operator/ir/type_storage.h b/paddle/fluid/pir/dialect/operator/ir/type_storage.h index 375bef9799d6c7..3cebaf01aadc32 100644 --- a/paddle/fluid/pir/dialect/operator/ir/type_storage.h +++ b/paddle/fluid/pir/dialect/operator/ir/type_storage.h @@ -17,6 +17,7 @@ #include #include "paddle/phi/core/tensor_meta.h" +#include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/builtin_type_storage.h" #include "paddle/pir/include/core/type.h" #include "paddle/pir/include/core/type_base.h" @@ -166,5 +167,128 @@ struct DenseTensorArrayTypeStorage : public pir::TypeStorage { phi::DataLayout layout_; }; +struct SparseCooTensorTypeStorage : public pir::TypeStorage { + /// + /// \brief Declare ParamKey according to parameter type. + /// + using Dim = common::DDim; + using DataLayout = common::DataLayout; + using DataType = pir::Type; + using DenseTensorType = pir::DenseTensorType; + using ParamKey = std::tuple; + SparseCooTensorTypeStorage(DataType dtype, + Dim dims, + Dim meta_dims, + DataLayout layout, + DenseTensorType non_zero_indices, + DenseTensorType non_zero_elements, + bool coalesced = false) + : dtype_(dtype), + dims_(dims), + meta_dims_(meta_dims), + layout_(layout), + non_zero_indices_(non_zero_indices), + non_zero_elements_(non_zero_elements), + coalesced_(coalesced) {} + + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static SparseCooTensorTypeStorage* Construct(const ParamKey& key) { + return new SparseCooTensorTypeStorage(std::get<0>(key), + std::get<1>(key), + std::get<2>(key), + std::get<3>(key), + std::get<4>(key), + std::get<5>(key), + std::get<6>(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey& key) { + std::size_t hash_value = 0; + // hash dtype + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<0>(key))); + // hash dims + hash_value = pir::detail::hash_combine(hash_value, + std::hash()(std::get<1>(key))); + // hash meta_dims + hash_value = pir::detail::hash_combine(hash_value, + std::hash()(std::get<2>(key))); + // hash layout + hash_value = pir::detail::hash_combine( + hash_value, + std::hash::type>()( + static_cast::type>( + std::get<3>(key)))); + // hash DenseTensorType + auto tuple1 = std::make_tuple(std::get<4>(key).dtype(), + std::get<4>(key).dims(), + std::get<4>(key).data_layout(), + std::get<4>(key).lod(), + std::get<4>(key).offset()); + hash_value = pir::detail::hash_combine( + hash_value, DenseTensorTypeStorage::HashValue(tuple1)); + // hash DenseTensorType + auto tuple2 = std::make_tuple(std::get<5>(key).dtype(), + std::get<5>(key).dims(), + std::get<5>(key).data_layout(), + std::get<5>(key).lod(), + std::get<5>(key).offset()); + hash_value = pir::detail::hash_combine( + hash_value, DenseTensorTypeStorage::HashValue(tuple2)); + // hash coalesced + hash_value = pir::detail::hash_combine(hash_value, + std::hash()(std::get<6>(key))); + + return hash_value; + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return ParamKey(dtype_, + dims_, + meta_dims_, + layout_, + non_zero_indices_, + non_zero_elements_, + coalesced_) == key; + } + + ParamKey GetAsKey() const { + return ParamKey(dtype_, + dims_, + meta_dims_, + layout_, + non_zero_indices_, + non_zero_elements_, + coalesced_); + } + + /// + /// \brief SparseCooTensorTypeStorage include six parameters: dims, dtype, + /// layout, non_zero_indices_, non_zero_elements_,coalesced_. + /// + + DataType dtype_; + Dim dims_; + Dim meta_dims_; + DataLayout layout_{DataLayout::NCHW}; + DenseTensorType non_zero_indices_; + DenseTensorType non_zero_elements_; + bool coalesced_ = false; +}; } // namespace dialect } // namespace paddle diff --git a/paddle/pir/include/core/builtin_type.h b/paddle/pir/include/core/builtin_type.h index 455bbfdff0a5f5..144b62bb9753e4 100644 --- a/paddle/pir/include/core/builtin_type.h +++ b/paddle/pir/include/core/builtin_type.h @@ -129,16 +129,3 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::IndexType) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Complex64Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Complex128Type) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::DenseTensorType) - -namespace std { -template <> -struct hash { - std::size_t operator()(const pir::DenseTensorType &obj) const { - // return - // pir::DenseTensorTypeStorage::HashValue(std::make_tuple(pir::Type(), - // pir::DDim(), pir::DataLayout(), pir::LoD(), size_t())); - return pir::DenseTensorTypeStorage::HashValue(std::tuple( - obj.dtype(), obj.dims(), obj.data_layout(), obj.lod(), obj.offset())); - } -}; -} // namespace std diff --git a/paddle/pir/include/core/sparse_type.h b/paddle/pir/include/core/sparse_type.h deleted file mode 100644 index 6cf1bb95d301f7..00000000000000 --- a/paddle/pir/include/core/sparse_type.h +++ /dev/null @@ -1,83 +0,0 @@ - -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pir/include/core/sparse_type_storage.h" -#include "paddle/pir/include/core/type.h" -namespace paddle { -namespace dialect { -/// -/// \brief Define built-in parameterless types. -/// -/// NOTE(zhangbo9674): If you need to directly -/// cache the object of this built-in type in IrContext, please overload the get -/// method, and construct and cache the object in IrContext. For the specific -/// implementation method, please refer to Float16Type. -/// -/// The built-in type object get method is as follows: -/// \code{cpp} -/// pir::IrContext *ctx = pir::IrContext::Instance(); -/// Type fp32 = Float32Type::get(ctx); -/// \endcode -/// - -// NOTE(dev): Currently Int8 are not considered as a cached member -// in IrContextImpl because it is not widely used. -class IR_API SparseCooTensorType - : public pir::Type:: - TypeBase { - public: - using Base::Base; - using Type = pir::Type; - using Dim = SparseCooTensorTypeStorage::Dim; - using DataLayout = pir::DataLayout; - using DenseTensorType = pir::DenseTensorType; - - Type dtype() const; - const Dim &dims() const; - DataLayout data_layout() const; - DenseTensorType get_indices() const; - DenseTensorType get_elements() const; - bool get_coalesced() const; - - /// - /// \brief Implementation of 'classof' that compares the type id of - /// the provided value with the concrete type id. - /// - static bool classof(Type type); - - static SparseCooTensorType dyn_cast_impl(Type type); - - static SparseCooTensorType get(pir::IrContext *ctx, - Type dtype, - const Dim &dims, - DataLayout layout, - DenseTensorType non_zero_indices, - DenseTensorType non_zero_elements, - bool coalesced = false) { - return Base::get(ctx, - dtype, - dims, - layout, - non_zero_indices, - non_zero_elements, - coalesced); - } -}; -} // namespace dialect -} // namespace paddle - -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType) diff --git a/paddle/pir/include/core/sparse_type_storage.h b/paddle/pir/include/core/sparse_type_storage.h deleted file mode 100644 index 01b523f3c02538..00000000000000 --- a/paddle/pir/include/core/sparse_type_storage.h +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/common/ddim.h" -#include "paddle/common/dim.h" -#include "paddle/common/hash_funcs.h" -#include "paddle/common/layout.h" -#include "paddle/pir/include/core/builtin_type.h" -#include "paddle/pir/include/core/type.h" -#include "paddle/pir/include/core/type_base.h" -#include "paddle/pir/include/core/utils.h" - -namespace paddle { -namespace dialect { -/// -/// \brief Define Parametric TypeStorage for SparseCooTensorType. -/// -/// NOTE(risemeup1): The derived TypeStorage class needs to implement the -/// following methods: (1)declare ParamKey, (2)define Construction method, -/// (3)define HashValue method, (4)overload operator==. -/// - -struct SparseCooTensorTypeStorage : public pir::TypeStorage { - /// - /// \brief Declare ParamKey according to parameter type. - /// - using Dim = pir::DDim; - using DataLayout = pir::DataLayout; - using DataType = pir::Type; - using DenseTensorType = pir::DenseTensorType; - using ParamKey = std:: - tuple; - SparseCooTensorTypeStorage(DataType dtype, - Dim dims, - DataLayout layout, - DenseTensorType non_zero_indices, - DenseTensorType non_zero_elements, - bool coalesced = false) - : dtype_(dtype), - dims_(dims), - layout_(layout), - non_zero_indices_(non_zero_indices), - non_zero_elements_(non_zero_elements), - coalesced_(coalesced) {} - - /// - /// \brief Each derived TypeStorage must define a Construct method, which - /// StorageManager uses to construct a derived TypeStorage. - /// - static SparseCooTensorTypeStorage* Construct(const ParamKey& key) { - return new SparseCooTensorTypeStorage(std::get<0>(key), - std::get<1>(key), - std::get<2>(key), - std::get<3>(key), - std::get<4>(key), - std::get<5>(key)); - } - - /// - /// \brief Each derived TypeStorage must provide a HashValue method. - /// - static std::size_t HashValue(const ParamKey& key) { - std::size_t hash_value = 0; - // hash dtype - hash_value = pir::detail::hash_combine( - hash_value, std::hash()(std::get<0>(key))); - // hash dims - hash_value = pir::detail::hash_combine(hash_value, - std::hash()(std::get<1>(key))); - // hash layout - hash_value = pir::detail::hash_combine( - hash_value, - std::hash::type>()( - static_cast::type>( - std::get<2>(key)))); - // hash DenseTensorType - hash_value = pir::detail::hash_combine( - hash_value, std::hash()(std::get<3>(key))); - // hash DenseTensorType - hash_value = pir::detail::hash_combine( - hash_value, std::hash()(std::get<4>(key))); - - // hash coalesced - hash_value = pir::detail::hash_combine(hash_value, - std::hash()(std::get<5>(key))); - - return hash_value; - } - - /// - /// \brief Each derived TypeStorage needs to overload operator==. - /// - bool operator==(const ParamKey& key) const { - return ParamKey(dtype_, - dims_, - layout_, - non_zero_indices_, - non_zero_elements_, - coalesced_) == key; - } - - ParamKey GetAsKey() const { - return ParamKey(dtype_, - dims_, - layout_, - non_zero_indices_, - non_zero_elements_, - coalesced_); - } - - /// - /// \brief SparseCooTensorTypeStorage include six parameters: dims, dtype, - /// layout, non_zero_indices_, non_zero_elements_,coalesced_. - /// - - DataType dtype_; - Dim dims_; - DataLayout layout_{DataLayout::NCHW}; - DenseTensorType non_zero_indices_; - DenseTensorType non_zero_elements_; - bool coalesced_ = false; -}; -} // namespace dialect -} // namespace paddle diff --git a/paddle/pir/src/core/sparse_type.cc b/paddle/pir/src/core/sparse_type.cc deleted file mode 100644 index 4e39a84ebf910b..00000000000000 --- a/paddle/pir/src/core/sparse_type.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pir/include/core/sparse_type.h" - -namespace paddle { -namespace dialect { -pir::Type SparseCooTensorType::dtype() const { return storage()->dtype_; } - -const SparseCooTensorType::Dim& SparseCooTensorType::dims() const { - return storage()->dims_; -} - -DataLayout SparseCooTensorType::data_layout() const { - return storage()->layout_; -} - -pir::DenseTensorType SparseCooTensorType::get_indices() const { - return storage()->non_zero_indices_; -} - -pir::DenseTensorType SparseCooTensorType::get_elements() const { - return storage()->non_zero_elements_; -} - -bool SparseCooTensorType::get_coalesced() const { - return storage()->coalesced_; -} - -bool SparseCooTensorType::classof(Type type) { - if (type) { - if (type.type_id() == type_id()) return true; - if (auto wrap_type = type.dyn_cast()) { - return classof(wrap_type.prim_type()); - } - } - return false; -} - -SparseCooTensorType SparseCooTensorType::dyn_cast_impl(Type type) { - if (type) { - if (type.type_id() == type_id()) return SparseCooTensorType(type.storage()); - if (auto wrap_type = type.dyn_cast()) { - return dyn_cast_impl(wrap_type.prim_type()); - } - } - return nullptr; -} - -} // namespace dialect -} // namespace paddle - -IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType) diff --git a/test/cpp/pir/core/type_test.cc b/test/cpp/pir/core/type_test.cc index 93fd7d9e5d9be3..c54533f8e7c18d 100644 --- a/test/cpp/pir/core/type_test.cc +++ b/test/cpp/pir/core/type_test.cc @@ -21,7 +21,6 @@ #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/dialect.h" #include "paddle/pir/include/core/ir_context.h" -#include "paddle/pir/include/core/sparse_type.h" #include "paddle/pir/include/core/type.h" #include "paddle/pir/include/core/type_base.h" #include "paddle/pir/include/core/type_name.h" @@ -250,11 +249,12 @@ TEST(type_test, custom_type_dialect) { EXPECT_EQ(dialect_integer1, dialect_integer2); } -TEST(type_test, sparse_dialect) { +TEST(type_test, sparse_coo) { pir::IrContext *ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); pir::Type fp32_dtype = pir::Float32Type::get(ctx); common::DDim dims = {4, 4}; + common::DDim meta_ddims = {4, 1}; common::DataLayout data_layout = common::DataLayout::NCHW; pir::LoD lod = {{0, 1, 2}}; size_t offset = 0; @@ -267,6 +267,7 @@ TEST(type_test, sparse_dialect) { paddle::dialect::SparseCooTensorType::get(ctx, fp32_dtype, dims, + meta_ddims, data_layout, none_zero_indices, none_zero_elements, @@ -278,9 +279,10 @@ TEST(type_test, sparse_dialect) { EXPECT_EQ(sparse_coo_tensor_type.isa(), true); EXPECT_EQ(sparse_coo_tensor_type.dims(), dims); + EXPECT_EQ(sparse_coo_tensor_type.meta_dims(), meta_ddims); EXPECT_EQ(sparse_coo_tensor_type.data_layout(), data_layout); - EXPECT_EQ(sparse_coo_tensor_type.get_indices(), none_zero_indices); - EXPECT_EQ(sparse_coo_tensor_type.get_elements(), none_zero_elements); + EXPECT_EQ(sparse_coo_tensor_type.non_zero_indices(), none_zero_indices); + EXPECT_EQ(sparse_coo_tensor_type.non_zero_elements(), none_zero_elements); } TEST(type_test, pd_op_dialect) { From 1f2a2721912d496e8294c3951be70f5b7109ffde Mon Sep 17 00:00:00 2001 From: risemeup1 <515586620@qq.com> Date: Wed, 20 Mar 2024 08:03:01 +0000 Subject: [PATCH 3/7] support sparsecootensortype --- paddle/fluid/pir/dialect/operator/ir/op_type.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.cc b/paddle/fluid/pir/dialect/operator/ir/op_type.cc index ed8afd48f41427..b8aa6e5c161570 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.cc @@ -104,7 +104,9 @@ bool SparseCooTensorType::coalesced() const { return storage()->coalesced_; } bool SparseCooTensorType::classof(Type type) { if (type) { - if (type.type_id() == type_id()) return true; + if (type.type_id() == type_id()) { + return true; + } if (auto wrap_type = type.dyn_cast()) { return classof(wrap_type.prim_type()); } From 69dda1f75658e8cbd807b6e73e60a9867dd23bfb Mon Sep 17 00:00:00 2001 From: risemeup1 <515586620@qq.com> Date: Wed, 20 Mar 2024 08:38:51 +0000 Subject: [PATCH 4/7] support sparsecootensortype --- paddle/fluid/pir/dialect/operator/ir/op_type.h | 11 +++++------ paddle/fluid/pir/dialect/operator/ir/type_storage.h | 13 ++++++------- test/cpp/pir/core/type_test.cc | 5 ++--- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.h b/paddle/fluid/pir/dialect/operator/ir/op_type.h index 9202f049b525c8..b6a6f2f3d7f2b5 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.h @@ -80,16 +80,15 @@ class IR_API SparseCooTensorType public: using Base::Base; using Type = pir::Type; - using Dim = SparseCooTensorTypeStorage::Dim; + using Dim = common::Dim; using DataLayout = common::DataLayout; - using DenseTensorType = pir::DenseTensorType; Type dtype() const; const Dim &dims() const; const Dim &meta_dims() const; DataLayout data_layout() const; - DenseTensorType non_zero_indices() const; - DenseTensorType non_zero_elements() const; + pir::DenseTensorType non_zero_indices() const; + pir::DenseTensorType non_zero_elements() const; bool coalesced() const; /// @@ -105,8 +104,8 @@ class IR_API SparseCooTensorType const Dim &dims, const Dim &meta_dims, DataLayout layout, - DenseTensorType non_zero_indices, - DenseTensorType non_zero_elements, + pir::DenseTensorType non_zero_indices, + pir::DenseTensorType non_zero_elements, bool coalesced = false) { return Base::get(ctx, dtype, diff --git a/paddle/fluid/pir/dialect/operator/ir/type_storage.h b/paddle/fluid/pir/dialect/operator/ir/type_storage.h index 3cebaf01aadc32..3b59fec232f4b9 100644 --- a/paddle/fluid/pir/dialect/operator/ir/type_storage.h +++ b/paddle/fluid/pir/dialect/operator/ir/type_storage.h @@ -174,20 +174,19 @@ struct SparseCooTensorTypeStorage : public pir::TypeStorage { using Dim = common::DDim; using DataLayout = common::DataLayout; using DataType = pir::Type; - using DenseTensorType = pir::DenseTensorType; using ParamKey = std::tuple; SparseCooTensorTypeStorage(DataType dtype, Dim dims, Dim meta_dims, DataLayout layout, - DenseTensorType non_zero_indices, - DenseTensorType non_zero_elements, + pir::DenseTensorType non_zero_indices, + pir::DenseTensorType non_zero_elements, bool coalesced = false) : dtype_(dtype), dims_(dims), @@ -286,8 +285,8 @@ struct SparseCooTensorTypeStorage : public pir::TypeStorage { Dim dims_; Dim meta_dims_; DataLayout layout_{DataLayout::NCHW}; - DenseTensorType non_zero_indices_; - DenseTensorType non_zero_elements_; + pir::DenseTensorType non_zero_indices_; + pir::DenseTensorType non_zero_elements_; bool coalesced_ = false; }; } // namespace dialect diff --git a/test/cpp/pir/core/type_test.cc b/test/cpp/pir/core/type_test.cc index c54533f8e7c18d..eed279b09ecd46 100644 --- a/test/cpp/pir/core/type_test.cc +++ b/test/cpp/pir/core/type_test.cc @@ -274,10 +274,9 @@ TEST(type_test, sparse_coo) { coalesced); paddle::dialect::SparseCooTensorType sparse_coo_tensor_type = - paddle::dialect::SparseCooTensorType::dyn_cast_impl(pir_type); + paddle::dialect::SparseCooTensorType::dyn_cast(pir_type); + EXPECT_EQ(pir_type.isa(), true); - EXPECT_EQ(sparse_coo_tensor_type.isa(), - true); EXPECT_EQ(sparse_coo_tensor_type.dims(), dims); EXPECT_EQ(sparse_coo_tensor_type.meta_dims(), meta_ddims); EXPECT_EQ(sparse_coo_tensor_type.data_layout(), data_layout); From 0b5df15f71758b27d55722fbf1e2d658f5f30d04 Mon Sep 17 00:00:00 2001 From: risemeup1 <515586620@qq.com> Date: Wed, 20 Mar 2024 09:58:50 +0000 Subject: [PATCH 5/7] support sparsecootensortype --- .../fluid/pir/dialect/operator/ir/op_type.cc | 14 ++---- .../fluid/pir/dialect/operator/ir/op_type.h | 30 ++++++------- .../pir/dialect/operator/ir/type_storage.h | 43 +++++++++---------- test/cpp/pir/core/type_test.cc | 10 ++--- 4 files changed, 44 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.cc b/paddle/fluid/pir/dialect/operator/ir/op_type.cc index b8aa6e5c161570..7972941ea2985b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.cc @@ -80,15 +80,15 @@ DenseTensorArrayType DenseTensorArrayType::dyn_cast_impl(Type type) { pir::Type SparseCooTensorType::dtype() const { return storage()->dtype_; } -const SparseCooTensorType::Dim& SparseCooTensorType::dims() const { +const common::DDim& SparseCooTensorType::dims() const { return storage()->dims_; } -const SparseCooTensorType::Dim& SparseCooTensorType::meta_dims() const { - return storage()->meta_dims_; +const common::DDim& SparseCooTensorType::non_zero_dims() const { + return storage()->non_zero_dims_; } -DataLayout SparseCooTensorType::data_layout() const { +common::DataLayout SparseCooTensorType::data_layout() const { return storage()->layout_; } @@ -107,9 +107,6 @@ bool SparseCooTensorType::classof(Type type) { if (type.type_id() == type_id()) { return true; } - if (auto wrap_type = type.dyn_cast()) { - return classof(wrap_type.prim_type()); - } } return false; } @@ -119,9 +116,6 @@ SparseCooTensorType SparseCooTensorType::dyn_cast_impl(Type type) { if (type.type_id() == type_id()) { return SparseCooTensorType(type.storage()); } - if (auto wrap_type = type.dyn_cast()) { - return dyn_cast_impl(wrap_type.prim_type()); - } } return nullptr; } diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.h b/paddle/fluid/pir/dialect/operator/ir/op_type.h index b6a6f2f3d7f2b5..3f6103b7cd1eec 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.h @@ -79,14 +79,14 @@ class IR_API SparseCooTensorType TypeBase { public: using Base::Base; - using Type = pir::Type; - using Dim = common::Dim; - using DataLayout = common::DataLayout; - - Type dtype() const; - const Dim &dims() const; - const Dim &meta_dims() const; - DataLayout data_layout() const; + // using Type = pir::Type; + // using Dim = common::DDim; + // using DataLayout = common::DataLayout; + + pir::Type dtype() const; + const common::DDim &dims() const; + const common::DDim &non_zero_dims() const; + common::DataLayout data_layout() const; pir::DenseTensorType non_zero_indices() const; pir::DenseTensorType non_zero_elements() const; bool coalesced() const; @@ -95,22 +95,22 @@ class IR_API SparseCooTensorType /// \brief Implementation of 'classof' that compares the type id of /// the provided value with the concrete type id. /// - static bool classof(Type type); + static bool classof(pir::Type type); - static SparseCooTensorType dyn_cast_impl(Type type); + static SparseCooTensorType dyn_cast_impl(pir::Type type); static SparseCooTensorType get(pir::IrContext *ctx, - Type dtype, - const Dim &dims, - const Dim &meta_dims, - DataLayout layout, + pir::Type dtype, + const common::DDim &dims, + const common::DDim &non_zero_dims, + common::DataLayout layout, pir::DenseTensorType non_zero_indices, pir::DenseTensorType non_zero_elements, bool coalesced = false) { return Base::get(ctx, dtype, dims, - meta_dims, + non_zero_dims, layout, non_zero_indices, non_zero_elements, diff --git a/paddle/fluid/pir/dialect/operator/ir/type_storage.h b/paddle/fluid/pir/dialect/operator/ir/type_storage.h index 3b59fec232f4b9..686058ce3acf94 100644 --- a/paddle/fluid/pir/dialect/operator/ir/type_storage.h +++ b/paddle/fluid/pir/dialect/operator/ir/type_storage.h @@ -171,26 +171,23 @@ struct SparseCooTensorTypeStorage : public pir::TypeStorage { /// /// \brief Declare ParamKey according to parameter type. /// - using Dim = common::DDim; - using DataLayout = common::DataLayout; - using DataType = pir::Type; - using ParamKey = std::tuple; - SparseCooTensorTypeStorage(DataType dtype, - Dim dims, - Dim meta_dims, - DataLayout layout, + SparseCooTensorTypeStorage(pir::Type dtype, + common::DDim dims, + common::DDim non_zero_dims, + common::DataLayout layout, pir::DenseTensorType non_zero_indices, pir::DenseTensorType non_zero_elements, bool coalesced = false) : dtype_(dtype), dims_(dims), - meta_dims_(meta_dims), + non_zero_dims_(non_zero_dims), layout_(layout), non_zero_indices_(non_zero_indices), non_zero_elements_(non_zero_elements), @@ -219,11 +216,11 @@ struct SparseCooTensorTypeStorage : public pir::TypeStorage { hash_value = pir::detail::hash_combine( hash_value, std::hash()(std::get<0>(key))); // hash dims - hash_value = pir::detail::hash_combine(hash_value, - std::hash()(std::get<1>(key))); - // hash meta_dims - hash_value = pir::detail::hash_combine(hash_value, - std::hash()(std::get<2>(key))); + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<1>(key))); + // hash non_zero_dims + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<2>(key))); // hash layout hash_value = pir::detail::hash_combine( hash_value, @@ -259,7 +256,7 @@ struct SparseCooTensorTypeStorage : public pir::TypeStorage { bool operator==(const ParamKey& key) const { return ParamKey(dtype_, dims_, - meta_dims_, + non_zero_dims_, layout_, non_zero_indices_, non_zero_elements_, @@ -269,7 +266,7 @@ struct SparseCooTensorTypeStorage : public pir::TypeStorage { ParamKey GetAsKey() const { return ParamKey(dtype_, dims_, - meta_dims_, + non_zero_dims_, layout_, non_zero_indices_, non_zero_elements_, @@ -281,10 +278,10 @@ struct SparseCooTensorTypeStorage : public pir::TypeStorage { /// layout, non_zero_indices_, non_zero_elements_,coalesced_. /// - DataType dtype_; - Dim dims_; - Dim meta_dims_; - DataLayout layout_{DataLayout::NCHW}; + pir::Type dtype_; + common::DDim dims_; + common::DDim non_zero_dims_; + common::DataLayout layout_{DataLayout::NCHW}; pir::DenseTensorType non_zero_indices_; pir::DenseTensorType non_zero_elements_; bool coalesced_ = false; diff --git a/test/cpp/pir/core/type_test.cc b/test/cpp/pir/core/type_test.cc index eed279b09ecd46..d0c77719fd7b77 100644 --- a/test/cpp/pir/core/type_test.cc +++ b/test/cpp/pir/core/type_test.cc @@ -254,7 +254,7 @@ TEST(type_test, sparse_coo) { ctx->GetOrRegisterDialect(); pir::Type fp32_dtype = pir::Float32Type::get(ctx); common::DDim dims = {4, 4}; - common::DDim meta_ddims = {4, 1}; + common::DDim non_zero_dims = {4, 1}; common::DataLayout data_layout = common::DataLayout::NCHW; pir::LoD lod = {{0, 1, 2}}; size_t offset = 0; @@ -273,12 +273,12 @@ TEST(type_test, sparse_coo) { none_zero_elements, coalesced); - paddle::dialect::SparseCooTensorType sparse_coo_tensor_type = - paddle::dialect::SparseCooTensorType::dyn_cast(pir_type); EXPECT_EQ(pir_type.isa(), true); - + paddle::dialect::SparseCooTensorType sparse_coo_tensor_type = + pir_type.dyn_cast(); + EXPECT_EQ(sparse_coo_tensor_type, true); EXPECT_EQ(sparse_coo_tensor_type.dims(), dims); - EXPECT_EQ(sparse_coo_tensor_type.meta_dims(), meta_ddims); + EXPECT_EQ(sparse_coo_tensor_type.non_zero_dims(), meta_ddims); EXPECT_EQ(sparse_coo_tensor_type.data_layout(), data_layout); EXPECT_EQ(sparse_coo_tensor_type.non_zero_indices(), none_zero_indices); EXPECT_EQ(sparse_coo_tensor_type.non_zero_elements(), none_zero_elements); From 6cf5f663618ffeeadfb07b6666f5e7f6a1bc5151 Mon Sep 17 00:00:00 2001 From: risemeup1 <515586620@qq.com> Date: Wed, 20 Mar 2024 11:19:55 +0000 Subject: [PATCH 6/7] support sparsecootensortype --- test/cpp/pir/core/type_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/cpp/pir/core/type_test.cc b/test/cpp/pir/core/type_test.cc index d0c77719fd7b77..f8a52a3d162dc2 100644 --- a/test/cpp/pir/core/type_test.cc +++ b/test/cpp/pir/core/type_test.cc @@ -267,7 +267,7 @@ TEST(type_test, sparse_coo) { paddle::dialect::SparseCooTensorType::get(ctx, fp32_dtype, dims, - meta_ddims, + non_zero_dims, data_layout, none_zero_indices, none_zero_elements, @@ -276,12 +276,12 @@ TEST(type_test, sparse_coo) { EXPECT_EQ(pir_type.isa(), true); paddle::dialect::SparseCooTensorType sparse_coo_tensor_type = pir_type.dyn_cast(); - EXPECT_EQ(sparse_coo_tensor_type, true); EXPECT_EQ(sparse_coo_tensor_type.dims(), dims); - EXPECT_EQ(sparse_coo_tensor_type.non_zero_dims(), meta_ddims); + EXPECT_EQ(sparse_coo_tensor_type.non_zero_dims(), non_zero_dims); EXPECT_EQ(sparse_coo_tensor_type.data_layout(), data_layout); EXPECT_EQ(sparse_coo_tensor_type.non_zero_indices(), none_zero_indices); EXPECT_EQ(sparse_coo_tensor_type.non_zero_elements(), none_zero_elements); + EXPECT_EQ(sparse_coo_tensor_type.coalesced(), coalesced); } TEST(type_test, pd_op_dialect) { From b40d85e7be03246536174109f9510ffa8a9fad47 Mon Sep 17 00:00:00 2001 From: risemeup1 <515586620@qq.com> Date: Wed, 20 Mar 2024 11:24:31 +0000 Subject: [PATCH 7/7] support sparsecootensortype --- paddle/fluid/pir/dialect/operator/ir/op_type.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.h b/paddle/fluid/pir/dialect/operator/ir/op_type.h index 3f6103b7cd1eec..5f881067a25319 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.h @@ -79,9 +79,6 @@ class IR_API SparseCooTensorType TypeBase { public: using Base::Base; - // using Type = pir::Type; - // using Dim = common::DDim; - // using DataLayout = common::DataLayout; pir::Type dtype() const; const common::DDim &dims() const;