Skip to content

Commit 24b77b4

Browse files
committed
improve again
1 parent 3404a25 commit 24b77b4

File tree

2 files changed

+144
-51
lines changed

2 files changed

+144
-51
lines changed

nx/lib/nx.ex

+81-10
Original file line numberDiff line numberDiff line change
@@ -3268,7 +3268,7 @@ defmodule Nx do
32683268
@doc ~S"""
32693269
Split a tensor into train and test subsets.
32703270
3271-
`split` must be a float number between `0.0` and `1.0`.
3271+
`split` must be an integer greater than zero and less than the size of the tensor along the given `:axis`.
32723272
32733273
## Options
32743274
@@ -3278,7 +3278,7 @@ defmodule Nx do
32783278
32793279
Split a tensor into two separate tensors.
32803280
3281-
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), 0.8, axis: 0)
3281+
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), 2, axis: 0)
32823282
iex> train
32833283
#Nx.Tensor<
32843284
s64[2][3]
@@ -3294,6 +3294,63 @@ defmodule Nx do
32943294
[23, 4, 1]
32953295
]
32963296
>
3297+
3298+
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5], [26, 75, 3], [23, 4, 1]]), 2, axis: 1)
3299+
iex> train
3300+
#Nx.Tensor<
3301+
s64[3][2]
3302+
[
3303+
[3, 6],
3304+
[26, 75],
3305+
[23, 4]
3306+
]
3307+
>
3308+
iex> test
3309+
#Nx.Tensor<
3310+
s64[3][1]
3311+
[
3312+
[5],
3313+
[3],
3314+
[1]
3315+
]
3316+
>
3317+
3318+
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :columns)
3319+
iex> train
3320+
#Nx.Tensor<
3321+
s64[rows: 3][columns: 2]
3322+
[
3323+
[3, 6],
3324+
[26, 75],
3325+
[23, 4]
3326+
]
3327+
>
3328+
iex> test
3329+
#Nx.Tensor<
3330+
s64[rows: 3][columns: 2]
3331+
[
3332+
[5, 20],
3333+
[3, 9],
3334+
[1, 5]
3335+
]
3336+
>
3337+
3338+
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :rows)
3339+
iex> train
3340+
#Nx.Tensor<
3341+
s64[rows: 2][columns: 4]
3342+
[
3343+
[3, 6, 5, 20],
3344+
[26, 75, 3, 9]
3345+
]
3346+
>
3347+
iex> test
3348+
#Nx.Tensor<
3349+
s64[rows: 1][columns: 4]
3350+
[
3351+
[23, 4, 1, 5]
3352+
]
3353+
>
32973354
"""
32983355
@doc type: :indexed
32993356
def split(tensor, split, opts \\ [])
@@ -3302,18 +3359,32 @@ defmodule Nx do
33023359
opts = keyword!(opts, axis: 0)
33033360
axis = Keyword.fetch!(opts, :axis)
33043361

3305-
if is_float(split) and split > 0.0 and split < 1.0 do
3306-
rows = elem(shape, 0)
3307-
split_size = Kernel.floor(split * rows)
3308-
split_size = if split_size < 1, do: 1, else: split_size
3309-
remaining_size = rows - split_size
3362+
if is_integer(split) and split > 0 do
3363+
axis_values = axes(tensor)
3364+
axis_names = names(tensor)
3365+
3366+
axis =
3367+
cond do
3368+
is_integer(axis) and axis in axis_values ->
3369+
axis
3370+
3371+
is_atom(axis) and axis in axis_names ->
3372+
dimensions = Enum.zip(axis_names, axis_values)
3373+
dimensions[axis]
3374+
3375+
true ->
3376+
raise ":axis is out of tensor bounds."
3377+
end
3378+
3379+
values = elem(shape, axis)
3380+
size = values - split
33103381

33113382
{
3312-
slice_along_axis(tensor, 0, split_size, axis: axis),
3313-
slice_along_axis(tensor, split_size, remaining_size, axis: axis)
3383+
slice_along_axis(tensor, 0, split, axis: axis),
3384+
slice_along_axis(tensor, split, size, axis: axis)
33143385
}
33153386
else
3316-
raise ":split must be a float number between 0.0 and 1.0"
3387+
raise "split must be an integer greater than zero and less than the length of the tensor."
33173388
end
33183389
end
33193390

nx/test/nx_test.exs

+63-41
Original file line numberDiff line numberDiff line change
@@ -3024,63 +3024,85 @@ defmodule NxTest do
30243024
end
30253025

30263026
describe "split/2" do
3027-
test "Split list into 50% for training and 50% for testing" do
3027+
test "split is less than zero" do
30283028
tensor = Nx.iota({10, 2}, names: [:x, :y])
3029-
{train, test} = Nx.split(tensor, 0.5)
30303029

3031-
assert Nx.tensor(
3032-
[
3033-
[0, 1],
3034-
[2, 3],
3035-
[4, 5],
3036-
[6, 7],
3037-
[8, 9]
3038-
],
3039-
names: [:x, :y]
3040-
) == train
3041-
3042-
assert Nx.tensor(
3043-
[
3044-
[10, 11],
3045-
[12, 13],
3046-
[14, 15],
3047-
[16, 17],
3048-
[18, 19]
3049-
],
3050-
names: [:x, :y]
3051-
) == test
3030+
assert_raise RuntimeError,
3031+
"split must be an integer greater than zero and less than the length of the tensor.",
3032+
fn ->
3033+
Nx.split(tensor, -1)
3034+
end
3035+
end
3036+
3037+
test "split is greater than tensor length" do
3038+
tensor = Nx.iota({10, 2}, names: [:x, :y])
3039+
3040+
assert_raise ArgumentError,
3041+
"length at axis 1 must be less than axis size of 2, got: 3",
3042+
fn ->
3043+
Nx.split(tensor, 3, axis: 1)
3044+
end
3045+
end
3046+
3047+
test "axis is out of tensor bounds" do
3048+
tensor = Nx.iota({10, 2}, names: [:x, :y])
3049+
3050+
assert_raise RuntimeError,
3051+
":axis is out of tensor bounds.",
3052+
fn ->
3053+
Nx.split(tensor, 2, axis: 2)
3054+
end
3055+
end
3056+
3057+
test "named axis is out of tensor bounds" do
3058+
tensor = Nx.iota({10, 2}, names: [:x, :y])
3059+
3060+
assert_raise RuntimeError,
3061+
":axis is out of tensor bounds.",
3062+
fn ->
3063+
Nx.split(tensor, 2, axis: :z)
3064+
end
30523065
end
30533066

3054-
test "Split into 70% for training and 30% for testing" do
3055-
tensor = Nx.iota({100, 6})
3056-
{train, test} = Nx.split(tensor, 0.7)
3067+
test "split into 70% for training and 30% for testing along a named :axis" do
3068+
tensor = Nx.iota({100, 6}, names: [:rows, :columns])
3069+
{train, test} = Nx.split(tensor, 70, axis: :rows)
30573070

3058-
assert length(Nx.to_list(train)) == 70
3059-
assert length(Nx.to_list(test)) == 30
3071+
assert {70, 6} == Nx.shape(train)
3072+
assert {30, 6} == Nx.shape(test)
30603073
end
30613074

3062-
test "Split into 75% for training and 25% for testing" do
3075+
test "split into 90% for training and 10% for testing along a named :axis" do
3076+
tensor = Nx.iota({2, 100}, names: [:rows, :columns])
3077+
{train, test} = Nx.split(tensor, 90, axis: :columns)
3078+
3079+
assert {2, 90} == Nx.shape(train)
3080+
assert {2, 10} == Nx.shape(test)
3081+
end
3082+
3083+
test "split into 50% for training and 50% for testing along the :axis 1" do
30633084
tensor = Nx.iota({100, 10})
3064-
{train, test} = Nx.split(tensor, 0.75)
3085+
{train, test} = Nx.split(tensor, 5, axis: 1)
30653086

3066-
assert length(Nx.to_list(train)) == 75
3067-
assert length(Nx.to_list(test)) == 25
3087+
assert {100, 5} == Nx.shape(train)
3088+
assert {100, 5} == Nx.shape(test)
30683089
end
30693090

3070-
test "Split into 61% for training and 39% for testing" do
3091+
test "split into 61% for training and 39% for testing" do
30713092
tensor = Nx.iota({100, 10})
3072-
{train, test} = Nx.split(tensor, 0.61)
3093+
{train, test} = Nx.split(tensor, 61)
30733094

3074-
assert length(Nx.to_list(train)) == 61
3075-
assert length(Nx.to_list(test)) == 39
3095+
assert {61, 10} == Nx.shape(train)
3096+
assert {39, 10} == Nx.shape(test)
30763097
end
30773098

3078-
test "Split into 60% for training and 40% for testing with unbalanced data" do
3079-
tensor = Nx.iota({73, 4})
3080-
{train, test} = Nx.split(tensor, 0.61)
3099+
test "split into 60% for training and 40% for testing with unbalanced data" do
3100+
tensor = Nx.iota({99, 4})
3101+
3102+
{train, test} = Nx.split(tensor, 60)
30813103

3082-
assert length(Nx.to_list(train)) == 44
3083-
assert length(Nx.to_list(test)) == 29
3104+
assert {60, 4} == Nx.shape(train)
3105+
assert {39, 4} == Nx.shape(test)
30843106
end
30853107
end
30863108
end

0 commit comments

Comments
 (0)