Skip to content

Commit b815c59

Browse files
authored
Add CountVectorizer (#315)
1 parent 37197dd commit b815c59

File tree

2 files changed

+159
-0
lines changed

2 files changed

+159
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
defmodule Scholar.FeatureExtraction.CountVectorizer do
2+
@moduledoc """
3+
A `CountVectorizer` converts already indexed collection of text documents to a matrix of token counts.
4+
"""
5+
import Nx.Defn
6+
7+
opts_schema = [
8+
max_token_id: [
9+
type: :pos_integer,
10+
required: true,
11+
doc: ~S"""
12+
Maximum token id in the input tensor.
13+
"""
14+
]
15+
]
16+
17+
@opts_schema NimbleOptions.new!(opts_schema)
18+
19+
@doc """
20+
Generates a count matrix where each row corresponds to a document in the input corpus,
21+
and each column corresponds to a unique token in the vocabulary of the corpus.
22+
23+
The input must be a 2D tensor where:
24+
25+
* Each row represents a document.
26+
* Each document has integer values representing tokens.
27+
28+
The same number represents the same token in the vocabulary. Tokens should start from 0
29+
and be consecutive. Negative values are ignored, making them suitable for padding.
30+
31+
## Options
32+
33+
#{NimbleOptions.docs(@opts_schema)}
34+
35+
## Examples
36+
37+
iex> t = Nx.tensor([[0, 1, 2], [1, 3, 4]])
38+
iex> Scholar.FeatureExtraction.CountVectorizer.fit_transform(t, max_token_id: Scholar.FeatureExtraction.CountVectorizer.max_token_id(t))
39+
Nx.tensor([
40+
[1, 1, 1, 0, 0],
41+
[0, 1, 0, 1, 1]
42+
])
43+
44+
With padding:
45+
46+
iex> t = Nx.tensor([[0, 1, -1], [1, 3, 4]])
47+
iex> Scholar.FeatureExtraction.CountVectorizer.fit_transform(t, max_token_id: Scholar.FeatureExtraction.CountVectorizer.max_token_id(t))
48+
Nx.tensor([
49+
[1, 1, 0, 0, 0],
50+
[0, 1, 0, 1, 1]
51+
])
52+
"""
53+
deftransform fit_transform(tensor, opts \\ []) do
54+
fit_transform_n(tensor, NimbleOptions.validate!(opts, @opts_schema))
55+
end
56+
57+
@doc """
58+
Computes the max_token_id option from given tensor.
59+
60+
This function cannot be called inside `defn` (and it will raise
61+
if you try to do so).
62+
63+
## Examples
64+
65+
iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
66+
iex> Scholar.FeatureExtraction.CountVectorizer.max_token_id(t)
67+
2
68+
"""
69+
def max_token_id(tensor) do
70+
tensor |> Nx.reduce_max() |> Nx.to_number()
71+
end
72+
73+
defnp fit_transform_n(tensor, opts) do
74+
check_for_rank(tensor)
75+
counts = Nx.broadcast(0, {Nx.axis_size(tensor, 0), opts[:max_token_id] + 1})
76+
77+
{_, counts} =
78+
while {{i = 0, tensor}, counts}, Nx.less(i, Nx.axis_size(tensor, 0)) do
79+
{_, counts} =
80+
while {{j = 0, i, tensor}, counts}, Nx.less(j, Nx.axis_size(tensor, 1)) do
81+
index = tensor[i][j]
82+
83+
counts =
84+
if Nx.any(Nx.less(index, 0)),
85+
do: counts,
86+
else: Nx.indexed_add(counts, Nx.stack([i, index]), 1)
87+
88+
{{j + 1, i, tensor}, counts}
89+
end
90+
91+
{{i + 1, tensor}, counts}
92+
end
93+
94+
counts
95+
end
96+
97+
defnp check_for_rank(tensor) do
98+
if Nx.rank(tensor) != 2 do
99+
raise ArgumentError,
100+
"""
101+
expected tensor to have shape {num_documents, num_tokens}, \
102+
got tensor with shape: #{inspect(Nx.shape(tensor))}\
103+
"""
104+
end
105+
end
106+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
defmodule Scholar.Preprocessing.CountVectorizer do
2+
use Scholar.Case, async: true
3+
alias Scholar.FeatureExtraction.CountVectorizer
4+
doctest CountVectorizer
5+
6+
describe "fit_transform" do
7+
test "fit_transform test" do
8+
tesnsor = Nx.tensor([[2, 3, 0], [1, 4, 4]])
9+
10+
counts =
11+
CountVectorizer.fit_transform(tesnsor,
12+
max_token_id: CountVectorizer.max_token_id(tesnsor)
13+
)
14+
15+
expected_counts = Nx.tensor([[1, 0, 1, 1, 0], [0, 1, 0, 0, 2]])
16+
17+
assert counts == expected_counts
18+
end
19+
20+
test "fit_transform test - tensor with padding" do
21+
tensor = Nx.tensor([[2, 3, 0], [1, 4, -1]])
22+
23+
counts =
24+
CountVectorizer.fit_transform(tensor, max_token_id: CountVectorizer.max_token_id(tensor))
25+
26+
expected_counts = Nx.tensor([[1, 0, 1, 1, 0], [0, 1, 0, 0, 1]])
27+
28+
assert counts == expected_counts
29+
end
30+
end
31+
32+
describe "max_token_id" do
33+
test "max_token_id test" do
34+
tensor = Nx.tensor([[2, 3, 0], [1, 4, 4]])
35+
assert CountVectorizer.max_token_id(tensor) == 4
36+
end
37+
38+
test "max_token_id tes - tensor with padding" do
39+
tensor = Nx.tensor([[2, 3, 0], [1, 4, -1]])
40+
assert CountVectorizer.max_token_id(tensor) == 4
41+
end
42+
end
43+
44+
describe "errors" do
45+
test "wrong input rank" do
46+
assert_raise ArgumentError,
47+
"expected tensor to have shape {num_documents, num_tokens}, got tensor with shape: {3}",
48+
fn ->
49+
CountVectorizer.fit_transform(Nx.tensor([1, 2, 3]), max_token_id: 3)
50+
end
51+
end
52+
end
53+
end

0 commit comments

Comments
 (0)