Skip to content

Commit c024c5b

Browse files
committed
Fixes to simple imputer
1 parent 6c02f4f commit c024c5b

File tree

2 files changed

+6
-21
lines changed

2 files changed

+6
-21
lines changed

lib/scholar/impute/simple_imputer.ex

+4-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ defmodule Scholar.Impute.SimpleImputer do
1414
default: :nan,
1515
doc: ~S"""
1616
The placeholder for the missing values. All occurrences of `:missing_values` will be imputed.
17+
18+
The default value expects there are no NaNs in the input tensor.
1719
"""
1820
],
1921
strategy: [
@@ -72,17 +74,10 @@ defmodule Scholar.Impute.SimpleImputer do
7274
"""
7375
deftransform fit(x, opts \\ []) do
7476
opts = NimbleOptions.validate!(opts, @opts_schema)
75-
7677
input_rank = Nx.rank(x)
7778

7879
if input_rank != 2 do
79-
raise ArgumentError, "Wrong input rank. Expected: 2, got: #{inspect(input_rank)}"
80-
end
81-
82-
if opts[:missing_values] != :nan and
83-
Nx.any(Nx.is_nan(x)) == Nx.tensor(1, type: :u8) do
84-
raise ArgumentError,
85-
":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array"
80+
raise ArgumentError, "wrong input rank. Expected: 2, got: #{inspect(input_rank)}"
8681
end
8782

8883
{type, _num_bits} = x_type = Nx.type(x)
@@ -98,7 +93,7 @@ defmodule Scholar.Impute.SimpleImputer do
9893
{fill_value_type, _} = Nx.type(opts[:fill_value])
9994

10095
raise ArgumentError,
101-
"Wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: #{inspect(fill_value_type)}"
96+
"wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: #{inspect(fill_value_type)}"
10297

10398
true ->
10499
x

test/scholar/impute/simple_imputer_test.exs

+2-12
Original file line numberDiff line numberDiff line change
@@ -122,27 +122,17 @@ defmodule SimpleImputerTest do
122122
x = Nx.tensor([1, 2, 2, 3])
123123

124124
assert_raise ArgumentError,
125-
"Wrong input rank. Expected: 2, got: 1",
125+
"wrong input rank. Expected: 2, got: 1",
126126
fn ->
127127
SimpleImputer.fit(x, missing_values: 1, strategy: :mode)
128128
end
129129
end
130130

131-
test "Collision of nan" do
132-
x = generate_data()
133-
134-
assert_raise ArgumentError,
135-
":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array",
136-
fn ->
137-
SimpleImputer.fit(x, missing_values: 1.0, strategy: :mode)
138-
end
139-
end
140-
141131
test "Wrong :fill_value type" do
142132
x = Nx.tensor([[1.0, 2.0, 2.0, 3.0]])
143133

144134
assert_raise ArgumentError,
145-
"Wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: :s",
135+
"wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: :s",
146136
fn ->
147137
SimpleImputer.fit(x,
148138
missing_values: 1.0,

0 commit comments

Comments
 (0)