Skip to content

Commit 0264f4d

Browse files
committed
add both integer and float implementations
1 parent 24b77b4 commit 0264f4d

File tree

2 files changed

+75
-45
lines changed

2 files changed

+75
-45
lines changed

nx/lib/nx.ex

+57-45
Original file line numberDiff line numberDiff line change
@@ -3268,7 +3268,7 @@ defmodule Nx do
32683268
@doc ~S"""
32693269
Split a tensor into train and test subsets.
32703270
3271-
`split` split must be an integer greater than zero and less than the length of the tensor.
3271+
`split` split must be either an integer greater than zero and less than the length of the tensor or a float number between `0.0` and `1.0`.
32723272
32733273
## Options
32743274
@@ -3315,67 +3315,33 @@ defmodule Nx do
33153315
]
33163316
>
33173317
3318-
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :columns)
3318+
iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 9], [26, 75, 3, 20], [23, 4, 1, 56], [40, 6, 78, 94]]), 0.5, axis: 0)
33193319
iex> train
33203320
#Nx.Tensor<
3321-
s64[rows: 3][columns: 2]
3321+
s64[2][4]
33223322
[
3323-
[3, 6],
3324-
[26, 75],
3325-
[23, 4]
3326-
]
3327-
>
3328-
iex> test
3329-
#Nx.Tensor<
3330-
s64[rows: 3][columns: 2]
3331-
[
3332-
[5, 20],
3333-
[3, 9],
3334-
[1, 5]
3335-
]
3336-
>
3337-
3338-
iex>{train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :rows)
3339-
iex> train
3340-
#Nx.Tensor<
3341-
s64[rows: 2][columns: 4]
3342-
[
3343-
[3, 6, 5, 20],
3344-
[26, 75, 3, 9]
3323+
[3, 6, 5, 9],
3324+
[26, 75, 3, 20]
33453325
]
33463326
>
33473327
iex> test
33483328
#Nx.Tensor<
3349-
s64[rows: 1][columns: 4]
3329+
s64[2][4]
33503330
[
3351-
[23, 4, 1, 5]
3331+
[23, 4, 1, 56],
3332+
[40, 6, 78, 94]
33523333
]
33533334
>
33543335
"""
33553336
@doc type: :indexed
33563337
def split(tensor, split, opts \\ [])
33573338

3358-
def split(%T{shape: shape} = tensor, split, opts) do
3339+
def split(%T{shape: shape} = tensor, split, opts) when is_integer(split) do
33593340
opts = keyword!(opts, axis: 0)
33603341
axis = Keyword.fetch!(opts, :axis)
33613342

3362-
if is_integer(split) and split > 0 do
3363-
axis_values = axes(tensor)
3364-
axis_names = names(tensor)
3365-
3366-
axis =
3367-
cond do
3368-
is_integer(axis) and axis in axis_values ->
3369-
axis
3370-
3371-
is_atom(axis) and axis in axis_names ->
3372-
dimensions = Enum.zip(axis_names, axis_values)
3373-
dimensions[axis]
3374-
3375-
true ->
3376-
raise ":axis is out of tensor bounds."
3377-
end
3378-
3343+
if split > 0 do
3344+
axis = find_axis(tensor, axis)
33793345
values = elem(shape, axis)
33803346
size = values - split
33813347

@@ -3388,6 +3354,35 @@ defmodule Nx do
33883354
end
33893355
end
33903356

3357+
def split(%T{shape: shape} = tensor, split, opts) when is_float(split) do
3358+
opts = keyword!(opts, axis: 0)
3359+
axis = Keyword.fetch!(opts, :axis)
3360+
3361+
if split > 0.0 and split < 1.0 do
3362+
axis = find_axis(tensor, axis)
3363+
3364+
values = elem(shape, axis)
3365+
3366+
split_size = Kernel.ceil(split * values)
3367+
3368+
split_size =
3369+
cond do
3370+
split_size < 1 -> 1
3371+
split_size >= values -> 1
3372+
true -> split_size
3373+
end
3374+
3375+
remaining_size = values - split_size
3376+
3377+
{
3378+
slice_along_axis(tensor, 0, split_size, axis: axis),
3379+
slice_along_axis(tensor, split_size, remaining_size, axis: axis)
3380+
}
3381+
else
3382+
raise "split must be a float number between 0.0 and 1.0."
3383+
end
3384+
end
3385+
33913386
@doc """
33923387
Broadcasts `tensor` to the given `broadcast_shape`.
33933388
@@ -16184,4 +16179,21 @@ defmodule Nx do
1618416179
end)
1618516180
end
1618616181
end
16182+
16183+
defp find_axis(tensor, axis) do
16184+
axis_values = axes(tensor)
16185+
axis_names = names(tensor)
16186+
16187+
cond do
16188+
is_integer(axis) and axis in axis_values ->
16189+
axis
16190+
16191+
is_atom(axis) and axis in axis_names ->
16192+
dimensions = Enum.zip(axis_names, axis_values)
16193+
dimensions[axis]
16194+
16195+
true ->
16196+
raise ":axis is out of tensor bounds."
16197+
end
16198+
end
1618716199
end

nx/test/nx_test.exs

+18
Original file line numberDiff line numberDiff line change
@@ -3034,6 +3034,16 @@ defmodule NxTest do
30343034
end
30353035
end
30363036

3037+
test "split is a float out of bounds" do
3038+
tensor = Nx.iota({10, 2}, names: [:x, :y])
3039+
3040+
assert_raise RuntimeError,
3041+
"split must be a float number between 0.0 and 1.0.",
3042+
fn ->
3043+
Nx.split(tensor, 1.0)
3044+
end
3045+
end
3046+
30373047
test "split is greater than tensor length" do
30383048
tensor = Nx.iota({10, 2}, names: [:x, :y])
30393049

@@ -3064,6 +3074,14 @@ defmodule NxTest do
30643074
end
30653075
end
30663076

3077+
test "split into 50% for training and 50% for testing with floats on columns" do
3078+
tensor = Nx.iota({4, 4}, names: [:rows, :columns])
3079+
{train, test} = Nx.split(tensor, 0.5, axis: :columns)
3080+
3081+
assert {4, 2} == Nx.shape(train)
3082+
assert {4, 2} == Nx.shape(test)
3083+
end
3084+
30673085
test "split into 70% for training and 30% for testing along a named :axis" do
30683086
tensor = Nx.iota({100, 6}, names: [:rows, :columns])
30693087
{train, test} = Nx.split(tensor, 70, axis: :rows)

0 commit comments

Comments
 (0)