-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support SparseCooTensorType #62868
Support SparseCooTensorType #62868
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
std::size_t operator()(const pir::DenseTensorType &obj) const { | ||
// return | ||
// pir::DenseTensorTypeStorage::HashValue(std::make_tuple(pir::Type(), | ||
// pir::DDim(), pir::DataLayout(), pir::LoD(), size_t())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
无用的注释可以删除一下,下同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,收到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggest delete
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggest delete
好的
using ParamKey = std:: | ||
tuple<DataType, Dim, DataLayout, DenseTensorType, DenseTensorType, bool>; | ||
SparseCooTensorTypeStorage(DataType dtype, | ||
Dim dims, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和phi体系下数据结构对应,添加meta_dims
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
paddle/pir/src/core/sparse_type.cc
Outdated
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/pir/include/core/sparse_type.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接放在 op_type.h 中吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接放在 op_type.h 中吧
好的
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同理,直接放在 paddle/fluid/pir/dialect/operator/ir/type_storage.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同理,直接放在 paddle/fluid/pir/dialect/operator/ir/type_storage.h
好的
|
||
#pragma once | ||
|
||
#include "paddle/pir/include/core/sparse_type_storage.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上移到:paddle/fluid/pir/dialect/operator/ir/op_type.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上移到:paddle/fluid/pir/dialect/operator/ir/op_type.h
好的
std::get<2>(key)))); | ||
// hash DenseTensorType | ||
hash_value = pir::detail::hash_combine( | ||
hash_value, std::hash<DenseTensorType>()(std::get<3>(key))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DenseTensorTypeStorage:: HashValue(...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DenseTensorTypeStorage:: HashValue(...)
好的
std::size_t operator()(const pir::DenseTensorType &obj) const { | ||
// return | ||
// pir::DenseTensorTypeStorage::HashValue(std::make_tuple(pir::Type(), | ||
// pir::DDim(), pir::DataLayout(), pir::LoD(), size_t())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggest delete
Type dtype() const; | ||
const Dim &dims() const; | ||
DataLayout data_layout() const; | ||
DenseTensorType get_indices() const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DenseTensorType get_indices() const; | |
DenseTensorType non_zero_indices() const; |
同下
/// | ||
/// \brief Declare ParamKey according to parameter type. | ||
/// | ||
using Dim = pir::DDim; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using Dim = pir::DDim; | |
using Dim = common::DDim; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
/// \brief Declare ParamKey according to parameter type. | ||
/// | ||
using Dim = pir::DDim; | ||
using DataLayout = pir::DataLayout; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
test/cpp/pir/core/type_test.cc
Outdated
@@ -249,6 +250,39 @@ TEST(type_test, custom_type_dialect) { | |||
EXPECT_EQ(dialect_integer1, dialect_integer2); | |||
} | |||
|
|||
TEST(type_test, sparse_dialect) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TEST(type_test, sparse_dialect) { | |
TEST(type_test, sparse_coo) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
paddle/pir/src/core/sparse_type.cc
Outdated
|
||
SparseCooTensorType SparseCooTensorType::dyn_cast_impl(Type type) { | ||
if (type) { | ||
if (type.type_id() == type_id()) return SparseCooTensorType(type.storage()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if内单行语句也建议用大括号包裹起来,其他地方同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if内单行语句也建议用大括号包裹起来,其他地方同
ok
using Type = pir::Type; | ||
using Dim = SparseCooTensorTypeStorage::Dim; | ||
using DataLayout = common::DataLayout; | ||
using DenseTensorType = pir::DenseTensorType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些是不是相当于在类SparseCooTensorType内取了这些type的别名并且还是public的啊,感觉不太好?比如你可以这样访问SparseCooTensorType::DenseTensorType,不建议这样做
using Dim = common::DDim; | ||
using DataLayout = common::DataLayout; | ||
using DataType = pir::Type; | ||
using DenseTensorType = pir::DenseTensorType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里同
public: | ||
using Base::Base; | ||
using Type = pir::Type; | ||
using Dim = SparseCooTensorTypeStorage::Dim; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using Dim = SparseCooTensorTypeStorage::Dim; | |
using Dim = common::Dim; |
test/cpp/pir/core/type_test.cc
Outdated
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); | ||
pir::Type fp32_dtype = pir::Float32Type::get(ctx); | ||
common::DDim dims = {4, 4}; | ||
common::DDim meta_ddims = {4, 1}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
meta_ddims -> meta_dims
test/cpp/pir/core/type_test.cc
Outdated
coalesced); | ||
|
||
paddle::dialect::SparseCooTensorType sparse_coo_tensor_type = | ||
paddle::dialect::SparseCooTensorType::dyn_cast_impl(pir_type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里一般验证方法是pir_type.dyn_cast<paddle::dialect::SparseCooTensorType>(),而不用dyn_cast_impl
test/cpp/pir/core/type_test.cc
Outdated
paddle::dialect::SparseCooTensorType sparse_coo_tensor_type = | ||
paddle::dialect::SparseCooTensorType::dyn_cast_impl(pir_type); | ||
|
||
EXPECT_EQ(sparse_coo_tensor_type.isa<paddle::dialect::SparseCooTensorType>(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一般先判断isa<> 再进行dyn_cast
test/cpp/pir/core/type_test.cc
Outdated
paddle::dialect::SparseCooTensorType sparse_coo_tensor_type = | ||
paddle::dialect::SparseCooTensorType::dyn_cast_impl(pir_type); | ||
|
||
EXPECT_EQ(sparse_coo_tensor_type.isa<paddle::dialect::SparseCooTensorType>(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EXPECT_EQ(sparse_coo_tensor_type.isa<paddle::dialect::SparseCooTensorType>(), | |
EXPECT_EQ(pir_type.isa<paddle::dialect::SparseCooTensorType>(), |
if (type.type_id() == type_id()) { | ||
return true; | ||
} | ||
if (auto wrap_type = type.dyn_cast<pir::WrapTypeInterface>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SparseCooTensorType 在定义的时候就没有WrapTypeInterface,不需要这段逻辑
if (type.type_id() == type_id()) { | ||
return SparseCooTensorType(type.storage()); | ||
} | ||
if (auto wrap_type = type.dyn_cast<pir::WrapTypeInterface>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
|
||
DataType dtype_; | ||
Dim dims_; | ||
Dim meta_dims_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
meta_dims 与 dims 的关系是什么,在 SparseCooTensor 中,这两个变量的具体作用有重叠么?
using Type = pir::Type; | ||
using Dim = common::Dim; | ||
using DataLayout = common::DataLayout; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否直接用比较好呢?不需要在这里对他们取别名
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
Others
Description
支持 SparceCooTensorType
Other
Pcard-67164