Skip to content

Commit 2239c39

Browse files
authored
Fix Stack (#925)
* return the correct layer * unskip the test
1 parent 0feefc7 commit 2239c39

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

k2/csrc/ragged_test.cu

+1-4
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,7 @@ TEST(RaggedShapeOpsTest, UnstackRandom) {
265265
for (size_t i = 0; i < out.size(); ++i) {
266266
out_ptr.emplace_back(&(out[i]));
267267
}
268-
// There is a bug in `Stack` for stacking a shape itself,
269-
// not urgent, so skipping here.
270-
// TODO: Remove this line when the bug fixed.
271-
if (out.size() == 1) continue;
268+
272269
auto dest = Stack(axis, out.size(), out_ptr.data());
273270
dest = RemoveEmptyLists(dest, axis);
274271

k2/csrc/ragged_utils.cu

+3-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ RaggedShape IntersperseRaggedLayer(int32_t layer,
9292
if (merge_map)
9393
*(reinterpret_cast<Array1<int32_t>*>(merge_map)) =
9494
Range(src[0]->Context(), src[0]->TotSize(layer + 1), 0);
95-
return *src[0];
95+
std::vector<RaggedShapeLayer> layers;
96+
layers.emplace_back(src[0]->Layers()[layer]);
97+
return RaggedShape(layers);
9698
}
9799

98100
std::vector<int32_t*> row_splits_ptrs_vec(num_srcs);

0 commit comments

Comments
 (0)