Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] cleanup of COO #69239

Merged
merged 1 commit into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,13 @@ struct ElementLT final {
const uint64_t rank;
};

/// The type of callback functions which receive an element.
template <typename V>
using ElementConsumer =
const std::function<void(const std::vector<uint64_t> &, V)> &;

/// A memory-resident sparse tensor in coordinate-scheme representation
/// (a collection of `Element`s). This data structure is used as
/// an intermediate representation; e.g., for reading sparse tensors
/// from external formats into memory, or for certain conversions between
/// different `SparseTensorStorage` formats.
/// (a collection of `Element`s). This data structure is used as an
/// intermediate representation, e.g., for reading sparse tensors from
/// external formats into memory.
template <typename V>
class SparseTensorCOO final {
public:
using const_iterator = typename std::vector<Element<V>>::const_iterator;

/// Constructs a new coordinate-scheme sparse tensor with the given
/// sizes and an optional initial storage capacity.
explicit SparseTensorCOO(const std::vector<uint64_t> &dimSizes,
Expand Down Expand Up @@ -106,7 +98,7 @@ class SparseTensorCOO final {
/// Returns the `operator<` closure object for the COO's element type.
ElementLT<V> getElementLT() const { return ElementLT<V>(getRank()); }

/// Adds an element to the tensor. This method invalidates all iterators.
/// Adds an element to the tensor.
void add(const std::vector<uint64_t> &dimCoords, V val) {
const uint64_t *base = coordinates.data();
const uint64_t size = coordinates.size();
Expand Down Expand Up @@ -135,12 +127,9 @@ class SparseTensorCOO final {
elements.push_back(addedElem);
}

const_iterator begin() const { return elements.cbegin(); }
const_iterator end() const { return elements.cend(); }

/// Sorts elements lexicographically by coordinates. If a coordinate
/// is mapped to multiple values, then the relative order of those
/// values is unspecified. This method invalidates all iterators.
/// values is unspecified.
void sort() {
if (isSorted)
return;
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
namespace mlir {
namespace sparse_tensor {

/// The type of callback functions which receive an element.
template <typename V>
using ElementConsumer =
const std::function<void(const std::vector<uint64_t> &, V)> &;

// Forward references.
template <typename V>
class SparseTensorEnumeratorBase;
Expand Down