File tree 2 files changed +6
-21
lines changed
2 files changed +6
-21
lines changed Original file line number Diff line number Diff line change @@ -14,6 +14,8 @@ defmodule Scholar.Impute.SimpleImputer do
14
14
default: :nan ,
15
15
doc: ~S"""
16
16
The placeholder for the missing values. All occurrences of `:missing_values` will be imputed.
17
+
18
+ The default value expects there are no NaNs in the input tensor.
17
19
"""
18
20
] ,
19
21
strategy: [
@@ -72,17 +74,10 @@ defmodule Scholar.Impute.SimpleImputer do
72
74
"""
73
75
deftransform fit ( x , opts \\ [ ] ) do
74
76
opts = NimbleOptions . validate! ( opts , @ opts_schema )
75
-
76
77
input_rank = Nx . rank ( x )
77
78
78
79
if input_rank != 2 do
79
- raise ArgumentError , "Wrong input rank. Expected: 2, got: #{ inspect ( input_rank ) } "
80
- end
81
-
82
- if opts [ :missing_values ] != :nan and
83
- Nx . any ( Nx . is_nan ( x ) ) == Nx . tensor ( 1 , type: :u8 ) do
84
- raise ArgumentError ,
85
- ":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array"
80
+ raise ArgumentError , "wrong input rank. Expected: 2, got: #{ inspect ( input_rank ) } "
86
81
end
87
82
88
83
{ type , _num_bits } = x_type = Nx . type ( x )
@@ -98,7 +93,7 @@ defmodule Scholar.Impute.SimpleImputer do
98
93
{ fill_value_type , _ } = Nx . type ( opts [ :fill_value ] )
99
94
100
95
raise ArgumentError ,
101
- "Wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: #{ inspect ( fill_value_type ) } "
96
+ "wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: #{ inspect ( fill_value_type ) } "
102
97
103
98
true ->
104
99
x
Original file line number Diff line number Diff line change @@ -122,27 +122,17 @@ defmodule SimpleImputerTest do
122
122
x = Nx . tensor ( [ 1 , 2 , 2 , 3 ] )
123
123
124
124
assert_raise ArgumentError ,
125
- "Wrong input rank. Expected: 2, got: 1" ,
125
+ "wrong input rank. Expected: 2, got: 1" ,
126
126
fn ->
127
127
SimpleImputer . fit ( x , missing_values: 1 , strategy: :mode )
128
128
end
129
129
end
130
130
131
- test "Collision of nan" do
132
- x = generate_data ( )
133
-
134
- assert_raise ArgumentError ,
135
- ":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array" ,
136
- fn ->
137
- SimpleImputer . fit ( x , missing_values: 1.0 , strategy: :mode )
138
- end
139
- end
140
-
141
131
test "Wrong :fill_value type" do
142
132
x = Nx . tensor ( [ [ 1.0 , 2.0 , 2.0 , 3.0 ] ] )
143
133
144
134
assert_raise ArgumentError ,
145
- "Wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: :s" ,
135
+ "wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: :s" ,
146
136
fn ->
147
137
SimpleImputer . fit ( x ,
148
138
missing_values: 1.0 ,
You can’t perform that action at this time.
0 commit comments