Skip to content

Commit 36734fb

Browse files
authored
Revert "Remove pybind dependencies from RaggedArc. (#842)"
This reverts commit daa98e7.
1 parent daa98e7 commit 36734fb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+1757
-2065
lines changed

.github/workflows/build-doc.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ jobs:
8888
- name: Build doc
8989
shell: bash
9090
run: |
91-
export PYTHONPATH=$PWD/k2/torch/python:$PWD/build/lib:$PYTHONPATH
91+
export PYTHONPATH=$PWD/k2/python:$PWD/build/lib:$PYTHONPATH
9292
echo "PYTHONPATH: $PYTHONPATH"
9393
cd docs
9494
python3 -m pip install -r ./requirements.txt

.github/workflows/run-tests-cpu.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,12 @@ jobs:
105105
- name: Display Build Information
106106
shell: bash
107107
run: |
108-
export PYTHONPATH=$PWD/k2/torch/python:$PWD/build/lib:$PYTHONPATH
108+
export PYTHONPATH=$PWD/k2/python:$PWD/build/lib:$PYTHONPATH
109109
python3 -m k2.version
110110
111111
- name: Run Tests
112112
shell: bash
113113
run: |
114-
export PYTHONPATH=$PWD/k2/torch/python:$PWD/build/lib:$PYTHONPATH
115114
cd build
116115
ctest --output-on-failure
117116
# default log level is INFO

.github/workflows/run-tests.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,12 @@ jobs:
118118
- name: Display Build Information
119119
shell: bash
120120
run: |
121-
export PYTHONPATH=$PWD/k2/torch/python:$PWD/build/lib:$PYTHONPATH
121+
export PYTHONPATH=$PWD/k2/python:$PWD/build/lib:$PYTHONPATH
122122
python3 -m k2.version
123123
124124
- name: Run Tests
125125
shell: bash
126126
run: |
127-
export PYTHONPATH=$PWD/k2/torch/python:$PWD/build/lib:$PYTHONPATH
128127
cd build
129128
ctest --output-on-failure
130129
# default log level is INFO

k2/CMakeLists.txt

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,2 @@
11
add_subdirectory(csrc)
2-
if(K2_USE_PYTORCH)
3-
add_subdirectory(torch)
4-
else()
5-
message(FATAL_ERROR "Please select a framework.")
6-
endif()
2+
add_subdirectory(python)

k2/torch/csrc/CMakeLists.txt k2/python/csrc/CMakeLists.txt

+22-15
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,35 @@
22
include(transform)
33

44
# please keep the list sorted
5-
set(torch_srcs
6-
fsa_class.cu
7-
ops.cu
8-
ragged_any.cu
9-
torch_utils.cu
5+
set(k2_srcs
6+
k2.cu
7+
torch.cu
8+
version.cu
109
)
1110

12-
if(NOT K2_WITH_CUDA)
13-
transform(OUTPUT_VARIABLE torch_srcs SRCS ${torch_srcs})
11+
if(K2_USE_PYTORCH)
12+
add_subdirectory(torch)
13+
set(k2_srcs ${k2_srcs} ${torch_srcs})
14+
else()
15+
message(FATAL_ERROR "Please select a framework.")
1416
endif()
1517

16-
add_library(k2_torch ${torch_srcs})
17-
target_link_libraries(k2_torch PUBLIC context)
18-
target_link_libraries(k2_torch PUBLIC fsa)
19-
target_include_directories(k2_torch PUBLIC ${CMAKE_SOURCE_DIR})
20-
target_include_directories(k2_torch PUBLIC ${CMAKE_BINARY_DIR})
18+
if(NOT K2_WITH_CUDA)
19+
transform(OUTPUT_VARIABLE k2_srcs SRCS ${k2_srcs})
20+
endif()
2121

22+
pybind11_add_module(_k2 ${k2_srcs} SHARED)
23+
target_link_libraries(_k2 PRIVATE context)
24+
target_link_libraries(_k2 PRIVATE fsa)
25+
target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR})
26+
target_include_directories(_k2 PRIVATE ${CMAKE_BINARY_DIR})
27+
set_property(TARGET _k2 PROPERTY CXX_VISIBILITY_PRESET "default")
2228

2329
#---------------------------- Test torch CUDA sources ----------------------------
2430

2531
# please sort the source files alphabetically
2632
set(torch_cuda_test_srcs
27-
fsa_class_test.cu
33+
torch/v2/ragged_arc_test.cu
2834
)
2935
if(NOT K2_WITH_CUDA)
3036
transform(OUTPUT_VARIABLE torch_cuda_test_srcs SRCS ${torch_cuda_test_srcs})
@@ -37,9 +43,10 @@ function(torch_add_cuda_test source)
3743
set_target_properties(${target_name} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
3844
target_link_libraries(${target_name}
3945
PRIVATE
40-
k2_torch
46+
_k2
47+
context
48+
fsa
4149
gtest
42-
gtest_main
4350
)
4451

4552
# NOTE: We set the working directory here so that

k2/python/csrc/k2.cu

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/**
2+
* @brief python wrappers for k2.
3+
*
4+
* @copyright
5+
* Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
6+
*
7+
* @copyright
8+
* See LICENSE for clarification regarding multiple authors
9+
*
10+
* Licensed under the Apache License, Version 2.0 (the "License");
11+
* you may not use this file except in compliance with the License.
12+
* You may obtain a copy of the License at
13+
*
14+
* http://www.apache.org/licenses/LICENSE-2.0
15+
*
16+
* Unless required by applicable law or agreed to in writing, software
17+
* distributed under the License is distributed on an "AS IS" BASIS,
18+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19+
* See the License for the specific language governing permissions and
20+
* limitations under the License.
21+
*/
22+
23+
#include "k2/python/csrc/k2.h"
24+
#include "k2/python/csrc/torch.h"
25+
#include "k2/python/csrc/version.h"
26+
27+
PYBIND11_MODULE(_k2, m) {
28+
m.doc() = "pybind11 binding of k2";
29+
// _k2 depends on torch, we should import torch before importing _k2.
30+
py::module_::import("torch");
31+
PybindVersion(m);
32+
PybindTorch(m);
33+
}

k2/torch/python/csrc/k2.h k2/python/csrc/k2.h

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/**
2-
* @brief python wrapper for k2 2.0
2+
* @brief python wrappers for k2.
33
*
44
* @copyright
5-
* Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
5+
* Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
66
*
77
* @copyright
88
* See LICENSE for clarification regarding multiple authors
@@ -20,14 +20,12 @@
2020
* limitations under the License.
2121
*/
2222

23-
#ifndef K2_TORCH_PYTHON_CSRC_K2_H_
24-
#define K2_TORCH_PYTHON_CSRC_K2_H_
23+
#ifndef K2_PYTHON_CSRC_K2_H_
24+
#define K2_PYTHON_CSRC_K2_H_
2525

26-
#include "k2/csrc/log.h"
2726
#include "pybind11/pybind11.h"
2827
#include "pybind11/stl.h"
29-
#include "torch/extension.h"
3028

3129
namespace py = pybind11;
3230

33-
#endif // K2_TORCH_PYTHON_CSRC_K2_H_
31+
#endif // K2_PYTHON_CSRC_K2_H_

k2/python/csrc/torch.cu

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/**
2+
* @brief Everything related to PyTorch for k2 Python wrappers.
3+
*
4+
* @copyright
5+
* Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
6+
*
7+
* @copyright
8+
* See LICENSE for clarification regarding multiple authors
9+
*
10+
* Licensed under the Apache License, Version 2.0 (the "License");
11+
* you may not use this file except in compliance with the License.
12+
* You may obtain a copy of the License at
13+
*
14+
* http://www.apache.org/licenses/LICENSE-2.0
15+
*
16+
* Unless required by applicable law or agreed to in writing, software
17+
* distributed under the License is distributed on an "AS IS" BASIS,
18+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19+
* See the License for the specific language governing permissions and
20+
* limitations under the License.
21+
*/
22+
23+
#include "k2/python/csrc/torch.h"
24+
25+
#if defined(K2_USE_PYTORCH)
26+
27+
#include "k2/python/csrc/torch/arc.h"
28+
#include "k2/python/csrc/torch/discounted_cum_sum.h"
29+
#include "k2/python/csrc/torch/fsa.h"
30+
#include "k2/python/csrc/torch/fsa_algo.h"
31+
#include "k2/python/csrc/torch/nbest.h"
32+
#include "k2/python/csrc/torch/ragged.h"
33+
#include "k2/python/csrc/torch/ragged_ops.h"
34+
#include "k2/python/csrc/torch/v2/k2.h"
35+
36+
void PybindTorch(py::module &m) {
37+
PybindArc(m);
38+
PybindDiscountedCumSum(m);
39+
PybindFsa(m);
40+
PybindFsaAlgo(m);
41+
PybindNbest(m);
42+
PybindRagged(m);
43+
PybindRaggedOps(m);
44+
45+
k2::PybindV2(m);
46+
}
47+
48+
#else
49+
50+
void PybindTorch(py::module &) {}
51+
52+
#endif

k2/torch/python/csrc/torch.h k2/python/csrc/torch.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
* limitations under the License.
2121
*/
2222

23-
#ifndef K2_TORCH_PYTHON_CSRC_TORCH_H_
24-
#define K2_TORCH_PYTHON_CSRC_TORCH_H_
23+
#ifndef K2_PYTHON_CSRC_TORCH_H_
24+
#define K2_PYTHON_CSRC_TORCH_H_
2525

2626
#include "k2/csrc/log.h"
2727
#include "torch/extension.h"
@@ -102,4 +102,4 @@ struct type_caster<torch::ScalarType> {
102102

103103
void PybindTorch(py::module &m);
104104

105-
#endif // K2_TORCH_PYTHON_CSRC_TORCH_H_
105+
#endif // K2_PYTHON_CSRC_TORCH_H_

k2/torch/python/csrc/arc.cu k2/python/csrc/torch/arc.cu

+4-3
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424

2525
#include "k2/csrc/device_guard.h"
2626
#include "k2/csrc/fsa.h"
27-
#include "k2/torch/csrc/torch_utils.h"
28-
#include "k2/torch/python/csrc/arc.h"
27+
#include "k2/python/csrc/torch/arc.h"
28+
#include "k2/python/csrc/torch/torch_util.h"
2929

3030
namespace k2 {
3131

@@ -101,5 +101,6 @@ static void PybindArcImpl(py::module &m) {
101101
py::arg("tensor"));
102102
}
103103

104-
void PybindArc(py::module &m) { PybindArcImpl(m); }
105104
} // namespace k2
105+
106+
void PybindArc(py::module &m) { k2::PybindArcImpl(m); }

k2/torch/python/csrc/arc.h k2/python/csrc/torch/arc.h

+5-6
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
* limitations under the License.
2121
*/
2222

23-
#ifndef K2_TORCH_PYTHON_CSRC_ARC_H_
24-
#define K2_TORCH_PYTHON_CSRC_ARC_H_
23+
#ifndef K2_PYTHON_CSRC_TORCH_ARC_H_
24+
#define K2_PYTHON_CSRC_TORCH_ARC_H_
25+
26+
#include "k2/python/csrc/torch.h"
2527

26-
#include "k2/torch/python/csrc/torch.h"
27-
namespace k2 {
2828
void PybindArc(py::module &m);
29-
} // namespace k2
3029

31-
#endif // K2_TORCH_PYTHON_CSRC_TORCH_ARC_H_
30+
#endif // K2_PYTHON_CSRC_TORCH_ARC_H_

k2/torch/python/csrc/discounted_cum_sum.cu k2/python/csrc/torch/discounted_cum_sum.cu

+5-5
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
#include "k2/csrc/macros.h"
2525
#include "k2/csrc/nvtx.h"
2626
#include "k2/csrc/tensor_ops.h"
27-
#include "k2/torch/csrc/torch_utils.h"
28-
#include "k2/torch/python/csrc/discounted_cum_sum.h"
27+
#include "k2/python/csrc/torch/discounted_cum_sum.h"
28+
#include "k2/python/csrc/torch/torch_util.h"
2929

3030
namespace k2 {
3131

@@ -46,9 +46,11 @@ static void DiscountedCumSumWrapper(torch::Tensor x, torch::Tensor gamma,
4646
DiscountedCumSum(x_k2, gamma_k2, &y_k2);
4747
}
4848

49+
} // namespace k2
50+
4951
void PybindDiscountedCumSum(py::module &m) {
5052
// note it supports only 1-D and 2-D tensors.
51-
m.def("discounted_cum_sum", &DiscountedCumSumWrapper, py::arg("x"),
53+
m.def("discounted_cum_sum", &k2::DiscountedCumSumWrapper, py::arg("x"),
5254
py::arg("gamma"), py::arg("y"), py::arg("flip") = false,
5355
R"(
5456
Args:
@@ -68,5 +70,3 @@ void PybindDiscountedCumSum(py::module &m) {
6870
If true, the time sequence is reversed..
6971
)");
7072
}
71-
72-
} // namespace k2

k2/torch/python/csrc/discounted_cum_sum.h k2/python/csrc/torch/discounted_cum_sum.h

+5-6
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
* limitations under the License.
2121
*/
2222

23-
#ifndef K2_TORCH_PYTHON_CSRC_DISCOUNTED_CUM_SUM_H_
24-
#define K2_TORCH_PYTHON_CSRC_DISCOUNTED_CUM_SUM_H_
23+
#ifndef K2_PYTHON_CSRC_TORCH_DISCOUNTED_CUM_SUM_H_
24+
#define K2_PYTHON_CSRC_TORCH_DISCOUNTED_CUM_SUM_H_
25+
26+
#include "k2/python/csrc/torch.h"
2527

26-
#include "k2/torch/python/csrc/torch.h"
27-
namespace k2 {
2828
void PybindDiscountedCumSum(py::module &m);
29-
} // namespace k2
3029

31-
#endif // K2_TORCH_PYTHON_CSRC_DISCOUNTED_CUM_SUM_H_
30+
#endif // K2_PYTHON_CSRC_TORCH_DISCOUNTED_CUM_SUM_H_

k2/torch/python/csrc/nbest.cu k2/python/csrc/torch/nbest.cu

+13-12
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@
2828
#include "k2/csrc/nbest.h"
2929
#include "k2/csrc/nvtx.h"
3030
#include "k2/csrc/tensor_ops.h"
31-
#include "k2/torch/csrc/ragged_any.h"
32-
#include "k2/torch/csrc/torch_utils.h"
33-
#include "k2/torch/python/csrc/doc/nbest.h"
34-
#include "k2/torch/python/csrc/nbest.h"
31+
#include "k2/python/csrc/torch/nbest.h"
32+
#include "k2/python/csrc/torch/torch_util.h"
33+
#include "k2/python/csrc/torch/v2/ragged_any.h"
3534

3635
namespace k2 {
3736

@@ -48,16 +47,18 @@ static void PybindGetBestMatchingStats(py::module &m) {
4847
Array1<int32_t> counts_array = FromTorch<int32_t>(counts);
4948
Array1<float> mean, var;
5049
Array1<int32_t> counts_out, ngram_order;
51-
GetBestMatchingStats(tokens, scores_array, counts_array, eos, min_token,
52-
max_token, max_order, &mean, &var, &counts_out,
53-
&ngram_order);
54-
return std::make_tuple(ToTorch(mean), ToTorch(var), ToTorch(counts_out),
55-
ToTorch(ngram_order));
50+
GetBestMatchingStats(tokens, scores_array, counts_array,
51+
eos, min_token, max_token, max_order,
52+
&mean, &var, &counts_out, &ngram_order);
53+
return std::make_tuple(ToTorch(mean), ToTorch(var),
54+
ToTorch(counts_out), ToTorch(ngram_order));
5655
},
5756
py::arg("tokens"), py::arg("scores"), py::arg("counts"), py::arg("eos"),
58-
py::arg("min_token"), py::arg("max_token"), py::arg("max_order"),
59-
kNbestGetBestMatchingStatsDoc);
57+
py::arg("min_token"), py::arg("max_token"), py::arg("max_order"));
6058
}
6159

62-
void PybindNbest(py::module &m) { PybindGetBestMatchingStats(m); }
6360
} // namespace k2
61+
62+
void PybindNbest(py::module &m) {
63+
k2::PybindGetBestMatchingStats(m);
64+
}

k2/torch/python/csrc/nbest.h k2/python/csrc/torch/nbest.h

+5-6
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
* limitations under the License.
2121
*/
2222

23-
#ifndef K2_TORCH_PYTHON_CSRC_NBEST_H_
24-
#define K2_TORCH_PYTHON_CSRC_NBEST_H_
23+
#ifndef K2_PYTHON_CSRC_TORCH_NBEST_H_
24+
#define K2_PYTHON_CSRC_TORCH_NBEST_H_
25+
26+
#include "k2/python/csrc/torch.h"
2527

26-
#include "k2/torch/python/csrc/torch.h"
27-
namespace k2 {
2828
void PybindNbest(py::module &m);
29-
} // namespace k2
3029

31-
#endif // K2_TORCH_PYTHON_CSRC_NBEST_H_
30+
#endif // K2_PYTHON_CSRC_TORCH_NBEST_H_

0 commit comments

Comments
 (0)