Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ArrayOfRagged #927

Merged
merged 8 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions k2/csrc/array_of_ragged.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Daniel Povey)
* 2022 ASLP@NWPU (authors: Hang Lyu)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "k2/csrc/array_of_ragged.h"

namespace k2 {

/*
  Builds the pointer tables and meta-information for an array of ragged
  shapes.

  Args:
    srcs:      CPU pointer to `num_srcs` source shapes; must be non-null.
               All shapes must have the same NumAxes() and compatible
               contexts (both are checked).
    num_srcs:  number of source shapes, must be > 0.

  Every table is filled on the CPU and then moved to the context of the
  sources with To(); the tot-sizes stay on the CPU because they are only
  read from host code via TotSize().
*/
ArrayOfRaggedShape::ArrayOfRaggedShape(RaggedShape *srcs, int32_t num_srcs)
    : num_srcs_(num_srcs) {
  K2_CHECK_GT(num_srcs, 0);
  K2_CHECK(srcs);

  // The first source defines the context and the number of axes.
  c_ = srcs[0].Context();
  num_axes_ = srcs[0].NumAxes();

  // All remaining sources must agree with the first one.
  for (int32_t i = 1; i < num_srcs_; ++i) {
    K2_CHECK_EQ(num_axes_, srcs[i].NumAxes());
    K2_CHECK(c_->IsCompatible(*srcs[i].Context()));
  }

  ContextPtr cpu = GetCpuContext();

  // row_splits_ / row_ids_: shape [num_axes_ - 1][num_srcs_], indexed
  // [axis - 1][src].  They are written element by element from host code
  // through the accessors, so they have to live in CPU memory while being
  // filled; only afterwards are they moved to c_.
  Array2<int32_t *> row_splits_cpu(cpu, num_axes_ - 1, num_srcs_);
  Array2<int32_t *> row_ids_cpu(cpu, num_axes_ - 1, num_srcs_);
  Array2Accessor<int32_t *> row_splits_acc = row_splits_cpu.Accessor(),
                            row_ids_acc = row_ids_cpu.Accessor();
  for (int32_t axis = 1; axis < num_axes_; ++axis) {
    for (int32_t src = 0; src < num_srcs_; ++src) {
      row_splits_acc(axis - 1, src) = srcs[src].RowSplits(axis).Data();
      row_ids_acc(axis - 1, src) = srcs[src].RowIds(axis).Data();
    }
  }
  row_splits_ = row_splits_cpu.To(c_);
  row_ids_ = row_ids_cpu.To(c_);

  // meta_row_splits_: shape [num_axes_][num_srcs_ + 1]; row `axis` is the
  // exclusive cumulative sum over the sources of TotSize(axis).
  // offsets_: shape [num_axes_ + 1][num_srcs_ + 1]; row 0 is 0,1,2,... and
  // row (axis + 1) is a copy of row `axis` of meta_row_splits_.
  Array2<int32_t> meta_row_splits_cpu(cpu, num_axes_, num_srcs_ + 1);
  Array2<int32_t> offsets_cpu(cpu, num_axes_ + 1, num_srcs_ + 1);
  Array2Accessor<int32_t> meta_row_splits_acc = meta_row_splits_cpu.Accessor(),
                          offsets_acc = offsets_cpu.Accessor();

  // tot_sizes_[axis] is the sum over the sources of TotSize(axis).
  tot_sizes_ = Array1<int32_t>(cpu, num_axes_, 0);
  int32_t *tot_sizes_data = tot_sizes_.Data();

  for (int32_t col = 0; col <= num_srcs_; ++col) {
    offsets_acc(0, col) = col;
  }
  for (int32_t axis = 0; axis < num_axes_; ++axis) {
    // The first column must be zeroed for every row (including row 0 of
    // meta_row_splits_) before the running sum below reads it.
    meta_row_splits_acc(axis, 0) = 0;
    offsets_acc(axis + 1, 0) = 0;
    for (int32_t src = 1; src <= num_srcs_; ++src) {
      meta_row_splits_acc(axis, src) =
          meta_row_splits_acc(axis, src - 1) + srcs[src - 1].TotSize(axis);
      offsets_acc(axis + 1, src) = meta_row_splits_acc(axis, src);
    }
    // The last cumulative entry is the total over all sources.
    tot_sizes_data[axis] = meta_row_splits_acc(axis, num_srcs_);
  }

  // meta_row_ids_[axis] has tot_sizes_[axis] elements, each in
  // [0, num_srcs_ - 1]: the index of the source that position belongs to.
  // This is done while the CPU copy of the meta-row-splits is still alive.
  meta_row_ids_.resize(num_axes_);
  for (int32_t axis = 0; axis < num_axes_; ++axis) {
    Array1<int32_t> ids_cpu(cpu, tot_sizes_data[axis]);
    int32_t *ids_data = ids_cpu.Data();

    int32_t pos = 0;
    for (int32_t src = 0; src < num_srcs_; ++src) {
      const int32_t end = meta_row_splits_acc(axis, src + 1);
      for (; pos < end; ++pos) {
        ids_data[pos] = src;
      }
    }
    // To() returns the converted array; it has to be stored.
    meta_row_ids_[axis] = ids_cpu.To(c_);
  }

  meta_row_splits_ = meta_row_splits_cpu.To(c_);
  offsets_ = offsets_cpu.To(c_);
}

/*
  Constructs an ArrayOfRagged<T> from `num_srcs` ragged arrays.

  Args:
    srcs:      CPU pointer to the source ragged arrays; must be non-null.
    num_srcs:  number of source arrays, must be > 0.

  `values` receives the values-data pointer of each source, `shape` is built
  from the shapes of the sources.

  NOTE(review): this is a member of a class template but it is defined in a
  source file; presumably explicit instantiations (or moving the definition
  to the header) are needed for other translation units to link - confirm.
*/
template <typename T>
ArrayOfRagged<T>::ArrayOfRagged(Ragged<T> *srcs, int32_t num_srcs) {
  // Validate before `srcs` is dereferenced for the first time.
  K2_CHECK_GT(num_srcs, 0);
  K2_CHECK(srcs);

  // The pointer table is written from host code, so fill it in CPU memory
  // and move it to the context of the sources afterwards.
  Array1<T *> values_cpu(GetCpuContext(), num_srcs);
  T **values_data = values_cpu.Data();

  // num_srcs is a run-time value, so a std::vector (not std::array) is
  // needed to hold the shapes contiguously.
  std::vector<RaggedShape> shapes;
  shapes.reserve(num_srcs);

  for (int32_t i = 0; i < num_srcs; ++i) {
    values_data[i] = srcs[i].values.Data();
    shapes.push_back(srcs[i].shape);
  }

  values = values_cpu.To(srcs[0].Context());
  shape = ArrayOfRaggedShape(shapes.data(), num_srcs);
}

} // namespace k2
229 changes: 229 additions & 0 deletions k2/csrc/array_of_ragged.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Daniel Povey)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef K2_CSRC_ARRAY_OF_RAGGED_H_
#define K2_CSRC_ARRAY_OF_RAGGED_H_

#include <string>
#include <utility>
#include <vector>

#include "k2/csrc/array.h"
#include "k2/csrc/context.h"
#include "k2/csrc/log.h"

namespace k2 {

/*
ArrayOfRagged<T> is a 1-dimensional array of Ragged<T>.
It is intended for situations where you want to do some operations on
arrays of ragged arrays, without explicitly concatenating them (e.g. to
save time). This is a fairly low-level interface, intended to
be used mostly by CUDA/C++ implementation code. It is a convenience
wrapper that saves you the trouble of creating arrays of pointers.
*/


/*
ArrayOfRaggedShape is a convenience function that gives you easy access
to pointers-of-pointers for an array of ragged shapes.
*/
class ArrayOfRaggedShape {
public:

// Default constructor.
ArrayOfRaggedShape() = default;

/*
Constructor.
Args:
srcs: pointers to the source shapes, a CPU pointer
num_srcs: the number of source shapes. All shapes must have the
same NumAxes() and must be on the same device.

TODO: we'll likely, later, add optional args which dictate which of
the MetaRowSplits() and MetaRowIds() are to be pre-populated; this should
enable us to save kernels by combining certain operations across the
axes.

*/
ArrayOfRaggedShape(RaggedShape *srcs,
int32_t num_srcs);


int32_t NumSrcs() const { return num_srcs_; }
int32_t NumAxes() const { return num_axes_; }

// Returns device-accessible array of row-splits for the individual shapes,
// indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this
// Array2 is [NumAxes() - 1][NumSrcs()].
Array2<int32_t*> *RowSplits() { return row_splits_; }

// Returns device-accessible vector of row-splits for a particular
// axis, indexed by 0 <= src < num_srcs.
int32_t **RowSplits(int32_t axis) {
K2_CHECK_LT(static_cast<uint32_t>(axis),
static_cast<uint32_t>num_axes_);
return row_splits_.Row(axis - 1).Data();
}

// Returns device-accessible array of row-ids for the individual shapes
// indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this
// Array2 is [NumAxes() - 1][NumSrcs()].
Array2<int32_t*> *RowIds() { return row_ids_; }


// Returns device-accessible vector of row-splits for a particular
// axis, indexed by 0 <= src < num_srcs.
int32_t **RowIds(int32_t axis) {
K2_CHECK_LT(static_cast<uint32_t>(axis),
static_cast<uint32_t>num_axes_);
return row_ids_.Row(axis - 1).Data();
}


/* Return the total size on this axis, which is the sum of the TotSize() of
the individual shapes. Requires 0 <= axis < NumAxes() and
for axis=0 the returned value is the same as Dim0().
*/
int32_t TotSize(int32_t axis) const {
K2_CHECK_LT(static_cast<uint32_t>(axis),
static_cast<uint32_t>num_axes_);
return tot_sizes_[axis];
}

// equivalent to TotSize(0).
int32_t Dim0() const { return TotSize(0); }


/* Return the device-accessible meta-row-splits, which is the cumulative sum,
along the src axis, of the tot-sizes of the individual arrays.
This Array2 is of shape [NumAxes()][NumSrcs() + 1], indexed [axis][src];
caution, the indexing is different from RowSplits(), there is no offset.
Also, the meta_row_splits0 is a thing, unlike with regular row-splits
which start from 1.

Caution: the lengths of the arrays pointed to by the elements of this
Array2 (which contains pointers!) are of course all different, and
these lengths are currently only available

Implementation note: we can probably just populate this on CPU and transfer
to GPU, this will be faster than invoking an extra kernel in normal cases
when the NumSrcs() is small. [Also: see GetRowInfoMulti()].
*/
Array2<int32_t> MetaRowSplits() { return meta_row_splits_; }

// could POSSIBLY add this so this code could be used in functions like Stack().
// would be like MetaRowSplits but with an extra 1st row containing 0,1,2,...
// We could perhaps create it with 1 extra initial row so this is always
// convenient to output.
Array2<int32_t> Offsets() { return offsets_; }

/*
Returns the meta-row-splits for a particular axis, with 0 <= axis < NumAxes();
this is the cumulative sum of the TotSize(axis) for all of the sources,
with MetaRowSplits(axis).Dim() == NumSrcs() + 1.

Note: in ragged_opts.cu we refer to this as composed_row_splits
*/
Array1<int32_t> MetaRowSplits(int32_t axis) {
K2_CHECK_LT(static_cast<uint32_t>(axis),
static_cast<uint32_t>num_axes_);
return meta_row_splits_.Row(axis);
}

/* Return the device-accessible meta-row-ids, which are the row-ids corresponding
to MetaRowSplits(); this tells us, for indexes into the appended/concatenated
array, which source array they belong to, i.e. elements are in [0,NumSrcs()-1].

This cannot be an Array2 because unlike the MetaRowSplits(), all the row-ids
arrays are of different lengths.

Note: in ragged_ops.cu we refer to this as composed_row_ids.
*/
Array1<int32_t*> MetaRowIds() {
Array1<int32_t*> ans(c_, num_axes_);
int32_t* *ans_data = ans.Data();
K2_EVAL(
c_, num_axes_, lambda_set_meta_row_ids, (int32_t axis)->void {
ans_data[axis] = meta_row_ids_[axis].Data();
});
return ans;
}

/*
Returns the meta-row-ids for a particular axis, with 0 <= axis < NumAxes();
this is the row-ids corresponding to MetaRowSplits(axis), and its elements
gives, for indexes into the concatentated shape (concatenated on axis 0),m
which source they come from. E.g. element 100 of MetaRowIds(2)
would tell us which source an idx012 with value 100 into axis 2 of
concatenated array would come from.
*/
Array1<int32_t> MetaRowIds(int32_t axis) {
K2_CHECK_LT(static_cast<uint32_t>(axis),
static_cast<uint32_t>num_axes_);
return meta_row_ids_[axis];
}

private:
ContextPtr c_;
int32_t num_srcs_;
int32_t num_axes_;

// Its shape is [num_axes_ - 1][num_srcs_]
Array2<int32_t*> row_splits_;
Array2<int32_t*> row_ids_;

// Its dimension is num_axes_
Array1<int32_t> tot_sizes_;

// Its shape is [num_axes_][num_srcs_ + 1]
Array2<int32_t> meta_row_splits_;
// Its shape is [num_axes_ + 1][num_srcs_ + 1]
Array2<int32_t> offsets_;
// The length of vector is num_axes_
std::vector<Array1<int32_t> > meta_row_ids_;
};



/*
  ArrayOfRagged<T> bundles an ArrayOfRaggedShape with an array holding the
  values pointer of each source ragged array.

  NOTE(review): the non-default constructor is only declared here;
  presumably its definition has to be visible to, or explicitly instantiated
  for, every T that is used - confirm.
*/
template <typename T>
struct ArrayOfRagged {
  // Shape information for all of the source arrays.
  ArrayOfRaggedShape shape;

  // Array of the individual values pointers of the source arrays, indexed by
  // source, i.e. values[src] with 0 <= src < NumSrcs().
  Array1<T*> values;

  // Number of source arrays, i.e. the number of pointers in `values`.
  int32_t NumSrcs() const { return values.Dim(); }

  // Default constructor will not leave this a valid ArrayOfRagged object,
  // you shouldn't do anything with it.  Both members will be initialized with
  // default constructors.
  ArrayOfRagged() = default;

  /*
    Constructor.
      Args:
        srcs:      pointer to the source ragged arrays
        num_srcs:  the number of source ragged arrays
  */
  ArrayOfRagged(Ragged<T> *srcs, int32_t num_srcs);
};



} // namespace k2

#endif // K2_CSRC_ARRAY_OF_RAGGED_H_