Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SparseCooTensorType #62868

Merged
merged 7 commits into from
Mar 20, 2024

Conversation

risemeup1
Copy link
Contributor

PR types

New features

PR changes

Others

Description

支持 SparceCooTensorType

Other
Pcard-67164

Copy link

paddle-bot bot commented Mar 20, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@risemeup1 risemeup1 changed the title support sparsecootensortype Support SparseCooTensorType Mar 20, 2024
std::size_t operator()(const pir::DenseTensorType &obj) const {
// return
// pir::DenseTensorTypeStorage::HashValue(std::make_tuple(pir::Type(),
// pir::DDim(), pir::DataLayout(), pir::LoD(), size_t()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无用的注释可以删除一下,下同

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,收到

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest delete

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest delete

好的

using ParamKey = std::
tuple<DataType, Dim, DataLayout, DenseTensorType, DenseTensorType, bool>;
SparseCooTensorTypeStorage(DataType dtype,
Dim dims,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和phi体系下数据结构对应,添加meta_dims

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/include/core/sparse_type.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接放在 op_type.h 中吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接放在 op_type.h 中吧

好的

// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同理,直接放在 paddle/fluid/pir/dialect/operator/ir/type_storage.h

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同理,直接放在 paddle/fluid/pir/dialect/operator/ir/type_storage.h

好的


#pragma once

#include "paddle/pir/include/core/sparse_type_storage.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上移到:paddle/fluid/pir/dialect/operator/ir/op_type.h

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上移到:paddle/fluid/pir/dialect/operator/ir/op_type.h

好的

std::get<2>(key))));
// hash DenseTensorType
hash_value = pir::detail::hash_combine(
hash_value, std::hash<DenseTensorType>()(std::get<3>(key)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DenseTensorTypeStorage:: HashValue(...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DenseTensorTypeStorage:: HashValue(...)

好的

std::size_t operator()(const pir::DenseTensorType &obj) const {
// return
// pir::DenseTensorTypeStorage::HashValue(std::make_tuple(pir::Type(),
// pir::DDim(), pir::DataLayout(), pir::LoD(), size_t()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest delete

Type dtype() const;
const Dim &dims() const;
DataLayout data_layout() const;
DenseTensorType get_indices() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DenseTensorType get_indices() const;
DenseTensorType non_zero_indices() const;

同下

///
/// \brief Declare ParamKey according to parameter type.
///
using Dim = pir::DDim;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using Dim = pir::DDim;
using Dim = common::DDim;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

/// \brief Declare ParamKey according to parameter type.
///
using Dim = pir::DDim;
using DataLayout = pir::DataLayout;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

@@ -249,6 +250,39 @@ TEST(type_test, custom_type_dialect) {
EXPECT_EQ(dialect_integer1, dialect_integer2);
}

TEST(type_test, sparse_dialect) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
TEST(type_test, sparse_dialect) {
TEST(type_test, sparse_coo) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的


SparseCooTensorType SparseCooTensorType::dyn_cast_impl(Type type) {
if (type) {
if (type.type_id() == type_id()) return SparseCooTensorType(type.storage());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if内单行语句也建议用大括号包裹起来,其他地方同

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if内单行语句也建议用大括号包裹起来,其他地方同

ok

Comment on lines 82 to 85
using Type = pir::Type;
using Dim = SparseCooTensorTypeStorage::Dim;
using DataLayout = common::DataLayout;
using DenseTensorType = pir::DenseTensorType;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些是不是相当于在类SparseCooTensorType内取了这些type的别名并且还是public的啊,感觉不太好?比如你可以这样访问SparseCooTensorType::DenseTensorType,不建议这样做

Comment on lines 174 to 177
using Dim = common::DDim;
using DataLayout = common::DataLayout;
using DataType = pir::Type;
using DenseTensorType = pir::DenseTensorType;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里同

public:
using Base::Base;
using Type = pir::Type;
using Dim = SparseCooTensorTypeStorage::Dim;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using Dim = SparseCooTensorTypeStorage::Dim;
using Dim = common::Dim;

ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
pir::Type fp32_dtype = pir::Float32Type::get(ctx);
common::DDim dims = {4, 4};
common::DDim meta_ddims = {4, 1};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meta_ddims -> meta_dims

coalesced);

paddle::dialect::SparseCooTensorType sparse_coo_tensor_type =
paddle::dialect::SparseCooTensorType::dyn_cast_impl(pir_type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里一般验证方法是pir_type.dyn_cast<paddle::dialect::SparseCooTensorType>(),而不用dyn_cast_impl

paddle::dialect::SparseCooTensorType sparse_coo_tensor_type =
paddle::dialect::SparseCooTensorType::dyn_cast_impl(pir_type);

EXPECT_EQ(sparse_coo_tensor_type.isa<paddle::dialect::SparseCooTensorType>(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一般先判断isa<> 再进行dyn_cast

paddle::dialect::SparseCooTensorType sparse_coo_tensor_type =
paddle::dialect::SparseCooTensorType::dyn_cast_impl(pir_type);

EXPECT_EQ(sparse_coo_tensor_type.isa<paddle::dialect::SparseCooTensorType>(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
EXPECT_EQ(sparse_coo_tensor_type.isa<paddle::dialect::SparseCooTensorType>(),
EXPECT_EQ(pir_type.isa<paddle::dialect::SparseCooTensorType>(),

if (type.type_id() == type_id()) {
return true;
}
if (auto wrap_type = type.dyn_cast<pir::WrapTypeInterface>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SparseCooTensorType 在定义的时候就没有WrapTypeInterface,不需要这段逻辑

if (type.type_id() == type_id()) {
return SparseCooTensorType(type.storage());
}
if (auto wrap_type = type.dyn_cast<pir::WrapTypeInterface>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上


DataType dtype_;
Dim dims_;
Dim meta_dims_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meta_dims 与 dims 的关系是什么,在 SparseCooTensor 中,这两个变量的具体作用有重叠么?

Comment on lines 82 to 84
using Type = pir::Type;
using Dim = common::Dim;
using DataLayout = common::DataLayout;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否直接用比较好呢?不需要在这里对他们取别名

Copy link
Contributor

@chen2016013 chen2016013 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@risemeup1 risemeup1 merged commit cc53f1c into PaddlePaddle:develop Mar 20, 2024
29 of 30 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants