Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SparseCooTensorType #62868

Merged
merged 7 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ void PrintOperationImpl(pir::Operation* op,

void OperatorDialect::initialize() {
RegisterTypes<paddle::dialect::SelectedRowsType,
paddle::dialect::SparseCooTensorType,
paddle::dialect::DenseTensorArrayType>();

RegisterAttributes<paddle::dialect::IntArrayAttribute,
Expand Down
43 changes: 43 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,51 @@ DenseTensorArrayType DenseTensorArrayType::dyn_cast_impl(Type type) {
return nullptr;
}

pir::Type SparseCooTensorType::dtype() const {
  // Element type of the stored values, as recorded in the type storage.
  const auto* impl = storage();
  return impl->dtype_;
}

const common::DDim& SparseCooTensorType::dims() const {
  // Overall shape of the tensor (see SparseCooTensorTypeStorage::dims_).
  const auto* impl = storage();
  return impl->dims_;
}

const common::DDim& SparseCooTensorType::non_zero_dims() const {
  // Shape describing the non-zero portion, taken straight from storage.
  const auto* impl = storage();
  return impl->non_zero_dims_;
}

common::DataLayout SparseCooTensorType::data_layout() const {
  // Layout tag recorded when the type was constructed.
  const auto* impl = storage();
  return impl->layout_;
}

pir::DenseTensorType SparseCooTensorType::non_zero_indices() const {
  // Dense tensor type describing the COO index tensor.
  const auto* impl = storage();
  return impl->non_zero_indices_;
}

pir::DenseTensorType SparseCooTensorType::non_zero_elements() const {
  // Dense tensor type describing the stored non-zero values.
  const auto* impl = storage();
  return impl->non_zero_elements_;
}

bool SparseCooTensorType::coalesced() const {
  // Coalesced flag as captured at type-construction time.
  const auto* impl = storage();
  return impl->coalesced_;
}

bool SparseCooTensorType::classof(Type type) {
  // A null type never matches; otherwise compare the runtime type id
  // against this concrete type's id.
  return type && type.type_id() == type_id();
}

SparseCooTensorType SparseCooTensorType::dyn_cast_impl(Type type) {
  // Guard clause: null types and mismatched ids cannot be cast.
  if (!type || type.type_id() != type_id()) {
    return nullptr;
  }
  // Same type id: wrap the shared storage in the concrete handle.
  return SparseCooTensorType(type.storage());
}

} // namespace dialect
} // namespace paddle

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorArrayType)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType)
42 changes: 42 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,50 @@ class DenseTensorArrayType
static DenseTensorArrayType dyn_cast_impl(Type type);
};

///
/// \brief PIR type handle for a sparse tensor in COO (coordinate) format.
/// The handle is a lightweight wrapper over an interned
/// SparseCooTensorTypeStorage (see type_storage.h).
///
class IR_API SparseCooTensorType
    : public pir::Type::
          TypeBase<SparseCooTensorType, pir::Type, SparseCooTensorTypeStorage> {
 public:
  using Base::Base;

  /// \brief Element type of the stored values.
  pir::Type dtype() const;
  /// \brief Overall shape of the tensor.
  const common::DDim &dims() const;
  /// \brief Shape describing the non-zero portion of the tensor.
  const common::DDim &non_zero_dims() const;
  /// \brief Data layout recorded for this type.
  common::DataLayout data_layout() const;
  /// \brief Dense tensor type of the COO index tensor.
  pir::DenseTensorType non_zero_indices() const;
  /// \brief Dense tensor type of the non-zero value tensor.
  pir::DenseTensorType non_zero_elements() const;
  /// \brief Whether the type is marked as coalesced.
  bool coalesced() const;

  ///
  /// \brief Implementation of 'classof' that compares the type id of
  /// the provided value with the concrete type id.
  ///
  static bool classof(pir::Type type);

  /// \brief Checked downcast used by pir::Type::dyn_cast.
  static SparseCooTensorType dyn_cast_impl(pir::Type type);

  ///
  /// \brief Get or create the unique SparseCooTensorType instance with the
  /// given parameters in the context `ctx` (types are interned per context).
  ///
  static SparseCooTensorType get(pir::IrContext *ctx,
                                 pir::Type dtype,
                                 const common::DDim &dims,
                                 const common::DDim &non_zero_dims,
                                 common::DataLayout layout,
                                 pir::DenseTensorType non_zero_indices,
                                 pir::DenseTensorType non_zero_elements,
                                 bool coalesced = false) {
    return Base::get(ctx,
                     dtype,
                     dims,
                     non_zero_dims,
                     layout,
                     non_zero_indices,
                     non_zero_elements,
                     coalesced);
  }
};

} // namespace dialect
} // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorArrayType)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType)
120 changes: 120 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/type_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <type_traits>

#include "paddle/phi/core/tensor_meta.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/builtin_type_storage.h"
#include "paddle/pir/include/core/type.h"
#include "paddle/pir/include/core/type_base.h"
Expand Down Expand Up @@ -166,5 +167,124 @@ struct DenseTensorArrayTypeStorage : public pir::TypeStorage {
phi::DataLayout layout_;
};

///
/// \brief TypeStorage for SparseCooTensorType. Holds the seven parameters
/// that uniquely identify a COO sparse tensor type; instances are interned
/// by the StorageManager, so HashValue/operator== must agree exactly.
///
struct SparseCooTensorTypeStorage : public pir::TypeStorage {
  ///
  /// \brief Declare ParamKey according to parameter type.
  ///
  using ParamKey = std::tuple<pir::Type,
                              common::DDim,
                              common::DDim,
                              common::DataLayout,
                              pir::DenseTensorType,
                              pir::DenseTensorType,
                              bool>;
  SparseCooTensorTypeStorage(pir::Type dtype,
                             common::DDim dims,
                             common::DDim non_zero_dims,
                             common::DataLayout layout,
                             pir::DenseTensorType non_zero_indices,
                             pir::DenseTensorType non_zero_elements,
                             bool coalesced = false)
      : dtype_(dtype),
        dims_(dims),
        non_zero_dims_(non_zero_dims),
        layout_(layout),
        non_zero_indices_(non_zero_indices),
        non_zero_elements_(non_zero_elements),
        coalesced_(coalesced) {}

  ///
  /// \brief Each derived TypeStorage must define a Construct method, which
  /// StorageManager uses to construct a derived TypeStorage.
  ///
  static SparseCooTensorTypeStorage* Construct(const ParamKey& key) {
    return new SparseCooTensorTypeStorage(std::get<0>(key),
                                          std::get<1>(key),
                                          std::get<2>(key),
                                          std::get<3>(key),
                                          std::get<4>(key),
                                          std::get<5>(key),
                                          std::get<6>(key));
  }

  ///
  /// \brief Each derived TypeStorage must provide a HashValue method.
  /// Combines every ParamKey field so that equal keys hash equally.
  ///
  static std::size_t HashValue(const ParamKey& key) {
    std::size_t hash_value = 0;
    // hash dtype
    hash_value = pir::detail::hash_combine(
        hash_value, std::hash<pir::Type>()(std::get<0>(key)));
    // hash dims
    hash_value = pir::detail::hash_combine(
        hash_value, std::hash<common::DDim>()(std::get<1>(key)));
    // hash non_zero_dims
    hash_value = pir::detail::hash_combine(
        hash_value, std::hash<common::DDim>()(std::get<2>(key)));
    // hash layout (hash the underlying integral value of the enum)
    hash_value = pir::detail::hash_combine(
        hash_value,
        std::hash<std::underlying_type<DataLayout>::type>()(
            static_cast<std::underlying_type<DataLayout>::type>(
                std::get<3>(key))));
    // hash non_zero_indices by decomposing the DenseTensorType into the
    // parameter tuple its own storage hashes.
    auto tuple1 = std::make_tuple(std::get<4>(key).dtype(),
                                  std::get<4>(key).dims(),
                                  std::get<4>(key).data_layout(),
                                  std::get<4>(key).lod(),
                                  std::get<4>(key).offset());
    hash_value = pir::detail::hash_combine(
        hash_value, DenseTensorTypeStorage::HashValue(tuple1));
    // hash non_zero_elements the same way.
    auto tuple2 = std::make_tuple(std::get<5>(key).dtype(),
                                  std::get<5>(key).dims(),
                                  std::get<5>(key).data_layout(),
                                  std::get<5>(key).lod(),
                                  std::get<5>(key).offset());
    hash_value = pir::detail::hash_combine(
        hash_value, DenseTensorTypeStorage::HashValue(tuple2));
    // hash coalesced
    hash_value = pir::detail::hash_combine(hash_value,
                                           std::hash<bool>()(std::get<6>(key)));

    return hash_value;
  }

  ///
  /// \brief Each derived TypeStorage needs to overload operator==.
  /// Compares all seven stored fields against the key.
  ///
  bool operator==(const ParamKey& key) const {
    return ParamKey(dtype_,
                    dims_,
                    non_zero_dims_,
                    layout_,
                    non_zero_indices_,
                    non_zero_elements_,
                    coalesced_) == key;
  }

  // Rebuilds the ParamKey from the stored fields (inverse of Construct).
  ParamKey GetAsKey() const {
    return ParamKey(dtype_,
                    dims_,
                    non_zero_dims_,
                    layout_,
                    non_zero_indices_,
                    non_zero_elements_,
                    coalesced_);
  }

  ///
  /// \brief SparseCooTensorTypeStorage includes seven parameters: dtype_,
  /// dims_, non_zero_dims_, layout_, non_zero_indices_, non_zero_elements_,
  /// coalesced_.
  ///

  pir::Type dtype_;
  common::DDim dims_;
  common::DDim non_zero_dims_;
  common::DataLayout layout_{DataLayout::NCHW};
  pir::DenseTensorType non_zero_indices_;
  pir::DenseTensorType non_zero_elements_;
  bool coalesced_ = false;
};
} // namespace dialect
} // namespace paddle
35 changes: 35 additions & 0 deletions test/cpp/pir/core/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,41 @@ TEST(type_test, custom_type_dialect) {
EXPECT_EQ(dialect_integer1, dialect_integer2);
}

// Verifies that SparseCooTensorType can be created, recognized via isa /
// dyn_cast, and that every accessor round-trips the construction parameters.
TEST(type_test, sparse_coo) {
  pir::IrContext *ctx = pir::IrContext::Instance();
  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
  pir::Type fp32_dtype = pir::Float32Type::get(ctx);
  common::DDim dims = {4, 4};
  common::DDim non_zero_dims = {4, 1};
  common::DataLayout data_layout = common::DataLayout::NCHW;
  pir::LoD lod = {{0, 1, 2}};
  size_t offset = 0;
  // Local names follow the accessor spelling (non_zero_*), previously
  // misspelled "none_zero_*".
  pir::DenseTensorType non_zero_indices = pir::DenseTensorType::get(
      ctx, fp32_dtype, dims, data_layout, lod, offset);
  pir::DenseTensorType non_zero_elements = pir::DenseTensorType::get(
      ctx, fp32_dtype, dims, data_layout, lod, offset);
  bool coalesced = false;
  pir::Type pir_type =
      paddle::dialect::SparseCooTensorType::get(ctx,
                                                fp32_dtype,
                                                dims,
                                                non_zero_dims,
                                                data_layout,
                                                non_zero_indices,
                                                non_zero_elements,
                                                coalesced);

  EXPECT_EQ(pir_type.isa<paddle::dialect::SparseCooTensorType>(), true);
  paddle::dialect::SparseCooTensorType sparse_coo_tensor_type =
      pir_type.dyn_cast<paddle::dialect::SparseCooTensorType>();
  // Check every accessor, including dtype(), which was previously untested.
  EXPECT_EQ(sparse_coo_tensor_type.dtype(), fp32_dtype);
  EXPECT_EQ(sparse_coo_tensor_type.dims(), dims);
  EXPECT_EQ(sparse_coo_tensor_type.non_zero_dims(), non_zero_dims);
  EXPECT_EQ(sparse_coo_tensor_type.data_layout(), data_layout);
  EXPECT_EQ(sparse_coo_tensor_type.non_zero_indices(), non_zero_indices);
  EXPECT_EQ(sparse_coo_tensor_type.non_zero_elements(), non_zero_elements);
  EXPECT_EQ(sparse_coo_tensor_type.coalesced(), coalesced);
}

TEST(type_test, pd_op_dialect) {
pir::IrContext *ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
Expand Down