Skip to content

Commit cca7a54

Browse files
pkufool and csukuangfj authored
Construct RaggedArc from unary function tensor (#30)
Construct RaggedArc from unary function tensor (#30)

* Construct RaggedArc from unary function tensor
* Move fsa_from_unary_ragged and fsa_from_binary_tensor to C++
* Add unit test to from unary function; add more functions to fsa
* Remove some rubbish code
* Add more unit tests and docs
* Remove the unused code
* Fix review comments, propagate attributes in To()
* Change the argument type from RaggedAny to Ragged<int32_t> in autograd function
* Delete declaration for template function
* Apply suggestions from code review

  Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
* Fix documentation errors

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
1 parent cbff6a1 commit cca7a54

29 files changed

+1704
-433
lines changed

k2/csrc/fsa.h

+7-9
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,12 @@ struct DenseFsaVec {
187187
std::ostream &operator<<(std::ostream &os, const DenseFsaVec &dfsavec);
188188

189189
/*
190-
Create an FSA from a Tensor. The Tensor t is expected to be an N by 4 tensor of
191-
int32_t, where N is the number of arcs (the format is src_state, dest_state,
192-
symbol, cost). The cost is not really an int32_t, it is a float. This code
193-
will print an error message and output 'true' to 'error', and return an empty
194-
FSA (with no states or arcs) if t was not interpretable as a valid FSA.
195-
These requirements for a valid FSA are:
190+
Create an FSA from a Tensor. The Tensor t is expected to be an N by 4 tensor
191+
of int32_t, where N is the number of arcs (the format is src_state,
192+
dest_state, symbol, cost). The cost is not really an int32_t, it is a float.
193+
This code will print an error message and output 'true' to 'error', and return
194+
an empty FSA (with no states or arcs) if t was not interpretable as a valid
195+
FSA. These requirements for a valid FSA are:
196196
197197
- src_state values on the arcs must be non-decreasing
198198
- all arcs with -1 as the label must be to a single state (call this
@@ -333,9 +333,7 @@ FsaVec FsaVecFromTensor(Tensor &t, bool *error);
333333
refer to a part of the `values` array of
334334
the input `vec`.
335335
*/
336-
inline Fsa GetFsaVecElement(FsaVec &vec, int32_t i) {
337-
return vec.Index(0, i);
338-
}
336+
inline Fsa GetFsaVecElement(FsaVec &vec, int32_t i) { return vec.Index(0, i); }
339337

340338
/*
341339
Create an FsaVec from a list of Fsas. Caution: Fsa and FsaVec are really

k2/python/csrc/CMakeLists.txt

+41
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,44 @@ target_link_libraries(_k2 PRIVATE context)
2424
target_link_libraries(_k2 PRIVATE fsa)
2525
target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR})
2626
target_include_directories(_k2 PRIVATE ${CMAKE_BINARY_DIR})
27+
set_property(TARGET _k2 PROPERTY CXX_VISIBILITY_PRESET "default")
28+
29+
#---------------------------- Test torch CUDA sources ----------------------------
30+
31+
# please sort the source files alphabetically
32+
set(torch_cuda_test_srcs
33+
torch/v2/ragged_arc_test.cu
34+
)
35+
if(NOT K2_WITH_CUDA)
36+
transform(OUTPUT_VARIABLE torch_cuda_test_srcs SRCS ${torch_cuda_test_srcs})
37+
endif()
38+
39+
# utility function to add gtest
40+
function(torch_add_cuda_test source)
41+
get_filename_component(target_name ${source} NAME_WE)
42+
add_executable(${target_name} "${source}")
43+
set_target_properties(${target_name} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
44+
target_link_libraries(${target_name}
45+
PRIVATE
46+
_k2
47+
context
48+
fsa
49+
gtest
50+
)
51+
52+
# NOTE: We set the working directory here so that
53+
# it works also on windows. The reason is that
54+
# the required DLLs are inside ${TORCH_DIR}/lib
55+
# and they can be found by the exe if the current
56+
# working directory is ${TORCH_DIR}\lib
57+
add_test(NAME "Test.Cuda.${target_name}"
58+
COMMAND
59+
$<TARGET_FILE:${target_name}>
60+
WORKING_DIRECTORY ${TORCH_DIR}/lib
61+
)
62+
endfunction()
63+
64+
foreach(source IN LISTS torch_cuda_test_srcs)
65+
torch_add_cuda_test(${source})
66+
endforeach()
67+

k2/python/csrc/k2.cu

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
PYBIND11_MODULE(_k2, m) {
2828
m.doc() = "pybind11 binding of k2";
29+
// _k2 depends on torch, we should import torch before importing _k2.
30+
py::module_::import("torch");
2931
PybindVersion(m);
3032
PybindTorch(m);
3133
}

k2/python/csrc/torch.cu

-4
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
#include "k2/python/csrc/torch/discounted_cum_sum.h"
2929
#include "k2/python/csrc/torch/fsa.h"
3030
#include "k2/python/csrc/torch/fsa_algo.h"
31-
#include "k2/python/csrc/torch/index_add.h"
32-
#include "k2/python/csrc/torch/index_select.h"
3331
#include "k2/python/csrc/torch/nbest.h"
3432
#include "k2/python/csrc/torch/ragged.h"
3533
#include "k2/python/csrc/torch/ragged_ops.h"
@@ -40,8 +38,6 @@ void PybindTorch(py::module &m) {
4038
PybindDiscountedCumSum(m);
4139
PybindFsa(m);
4240
PybindFsaAlgo(m);
43-
PybindIndexAdd(m);
44-
PybindIndexSelect(m);
4541
PybindNbest(m);
4642
PybindRagged(m);
4743
PybindRaggedOps(m);

k2/python/csrc/torch.h

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
#define K2_PYTHON_CSRC_TORCH_H_
2525

2626
#include "k2/csrc/log.h"
27-
#include "k2/python/csrc/torch.h"
2827
#include "torch/extension.h"
2928

3029
namespace pybind11 {

k2/python/csrc/torch/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ set(torch_srcs
44
discounted_cum_sum.cu
55
fsa.cu
66
fsa_algo.cu
7-
index_add.cu
8-
index_select.cu
97
nbest.cu
108
ragged.cu
119
ragged_ops.cu
@@ -15,6 +13,8 @@ set(torch_srcs
1513
v2/doc/doc.cu
1614
v2/fsa.cu
1715
v2/k2.cu
16+
v2/k2_ops.cu
17+
v2/ops.cu
1818
v2/ragged_any.cu
1919
v2/ragged_arc.cu
2020
v2/ragged_shape.cu

k2/python/csrc/torch/index_add.cu

-71
This file was deleted.

k2/python/csrc/torch/index_add.h

-35
This file was deleted.

k2/python/csrc/torch/torch_util.cu

+2-19
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,6 @@
2424

2525
namespace k2 {
2626

27-
torch::DeviceType ToTorchDeviceType(DeviceType type) {
28-
switch (type) {
29-
case kCuda:
30-
return torch::kCUDA;
31-
case kCpu:
32-
return torch::kCPU;
33-
case kUnk: // fall-through
34-
default:
35-
K2_LOG(FATAL) << "kUnk is not supported!";
36-
return torch::kCPU; // unreachable code
37-
}
38-
}
39-
4027
DeviceType FromTorchDeviceType(const torch::DeviceType &type) {
4128
switch (type) {
4229
case torch::kCUDA:
@@ -86,9 +73,7 @@ torch::ScalarType ScalarTypeFromDtype(Dtype dtype) {
8673

8774
template <>
8875
torch::Tensor ToTorch(Array1<Arc> &array) {
89-
auto device_type = ToTorchDeviceType(array.Context()->GetDeviceType());
90-
int32_t device_id = array.Context()->GetDeviceId();
91-
auto device = torch::Device(device_type, device_id);
76+
auto device = GetDevice(array.Context());
9277
auto scalar_type = ToScalarType<int32_t>::value;
9378
// an Arc has 4 members
9479
K2_STATIC_ASSERT(sizeof(Arc) == 4 * sizeof(int32_t));
@@ -134,9 +119,7 @@ Tensor FromTorch(torch::Tensor tensor, TensorTag) {
134119
return Tensor(dtype, shape, region, 0);
135120
}
136121
torch::Tensor ToTorch(Tensor &tensor) {
137-
auto device_type = ToTorchDeviceType(tensor.Context()->GetDeviceType());
138-
int32_t device_id = tensor.Context()->GetDeviceId();
139-
auto device = torch::Device(device_type, device_id);
122+
auto device = GetDevice(tensor.Context());
140123
auto scalar_type = ScalarTypeFromDtype(tensor.GetDtype());
141124
auto options = torch::device(device).dtype(scalar_type);
142125

k2/python/csrc/torch/torch_util.h

+24-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,18 @@ namespace k2 {
3838
3939
@return torch::kCUDA or torch.kCPU.
4040
*/
41-
torch::DeviceType ToTorchDeviceType(DeviceType type);
41+
inline torch::DeviceType ToTorchDeviceType(DeviceType type) {
42+
switch (type) {
43+
case kCuda:
44+
return torch::kCUDA;
45+
case kCpu:
46+
return torch::kCPU;
47+
case kUnk: // fall-through
48+
default:
49+
K2_LOG(FATAL) << "kUnk is not supported!";
50+
return torch::kCPU; // unreachable code
51+
}
52+
}
4253

4354
/* Convert torch::DeviceType to k2::DeviceType.
4455
Abort on failure.
@@ -252,6 +263,18 @@ PyClass To(PyClass &pyclass, py::object device) {
252263
*/
253264
ContextPtr GetContext(torch::Device device);
254265

266+
/** Create a torch device from a k2 context.
267+
268+
@param [in] context It must be a CPU or a CUDA context.
269+
270+
@return Return a CPU or a GPU device depending on the given context.
271+
*/
272+
inline torch::Device GetDevice(ContextPtr context) {
273+
auto device_type = ToTorchDeviceType(context->GetDeviceType());
274+
int32_t device_id = context->GetDeviceId();
275+
return torch::Device(device_type, device_id);
276+
}
277+
255278
inline ContextPtr GetContext(torch::Tensor tensor) {
256279
return GetContext(tensor.device());
257280
}

0 commit comments

Comments (0)