Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add split function #1237

Merged
merged 9 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nx/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ erl_crash.dump

# Ignore package tarball (built via "mix hex.build").
nx-*.tar

# ASDF files
.tool-versions
132 changes: 132 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3265,6 +3265,138 @@ defmodule Nx do
end)
end

@doc ~S"""
Splits a tensor into two parts along an axis, e.g. into train and test subsets.

Returns a `{left, right}` tuple, where `left` holds the leading entries
along the axis and `right` holds the remaining ones.

`split` must be defined so that there are no empty result tensors.
This means that `split` must be:

* an integer such that `0 < split` and `split < axis_size`, in which case `left` has `split` entries along the axis
* a float such that `0.0 < split` and `ceil(axis_size * split) < axis_size`, in which case `left` has `ceil(axis_size * split)` entries along the axis

An `ArgumentError` is raised when an integer or float `split` violates these constraints.

## Options

* `:axis` - The axis along which to split the tensor. Defaults to `0`.

## Examples

All examples operate on the same tensor so that it is easier to compare different configurations.

iex> t = Nx.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
iex> {left, right} = Nx.split(t, 2, axis: 0)
iex> left
#Nx.Tensor<
s64[2][4]
[
[0, 1, 2, 3],
[4, 5, 6, 7]
]
>
iex> right
#Nx.Tensor<
s64[1][4]
[
[8, 9, 10, 11]
]
>
iex> {left, right} = Nx.split(t, 2, axis: 1)
iex> left
#Nx.Tensor<
s64[3][2]
[
[0, 1],
[4, 5],
[8, 9]
]
>
iex> right
#Nx.Tensor<
s64[3][2]
[
[2, 3],
[6, 7],
[10, 11]
]
>

iex> t = Nx.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
iex> {left, right} = Nx.split(t, 0.5, axis: 0)
iex> left
#Nx.Tensor<
s64[2][4]
[
[0, 1, 2, 3],
[4, 5, 6, 7]
]
>
iex> right
#Nx.Tensor<
s64[1][4]
[
[8, 9, 10, 11]
]
>
iex> {left, right} = Nx.split(t, 0.75, axis: 1)
iex> left
#Nx.Tensor<
s64[3][3]
[
[0, 1, 2],
[4, 5, 6],
[8, 9, 10]
]
>
iex> right
#Nx.Tensor<
s64[3][1]
[
[3],
[7],
[11]
]
>
"""
@doc type: :indexed
def split(tensor, split, opts \\ [])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tiagodavi @josevalim I've pushed a refactor to collapse both clauses into the same one, mostly to highlight that the slicing is the same, with the only change being how we calculate it.

I've also changed the examples to operate on a written-out iota instead of arbitrary values because it makes it easier to compare results.

We still need to discuss if we want to accept a negative integer split.

It's easy to do by adding the following line right after the `axis_size = ...` assignment: `split = if is_integer(split) and split < 0, do: axis_size + split, else: split`

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though I think that negative split might be confusing in the sense that we don't know if the results will be reversed or swapped in the result tuple (either or both)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@polvalente we will base it on Enum.split (which means looking up from the end).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I am fine with postponing this for now. :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, so given this, my suggestion suffices:

iex(2)> Enum.split(1..10, 10 - 1)
{[1, 2, 3, 4, 5, 6, 7, 8, 9], '\n'}
iex(3)> Enum.split(1..10, -1)
{[1, 2, 3, 4, 5, 6, 7, 8, 9], '\n'}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdyt?


# Splits `tensor` in two along `opts[:axis]` (default `0`) and returns
# `{left, right}`.
#
#   * integer `split` - `left` gets the first `split` entries of the axis
#   * float `split`   - `left` gets the first `ceil(split * axis_size)` entries
#
# Raises `ArgumentError` when either side would be empty, or when `split` is
# neither an integer nor a float.
def split(tensor, split, opts) do
  tensor = to_tensor(tensor)
  opts = keyword!(opts, axis: 0)
  axis = Keyword.fetch!(opts, :axis)

  axis = Nx.Shape.normalize_axis(tensor.shape, axis, tensor.names)
  axis_size = axis_size(tensor, axis)

  split_index =
    cond do
      is_integer(split) and split > 0 and split < axis_size ->
        split

      is_integer(split) ->
        raise ArgumentError,
              "split must be an integer greater than zero and less than the length of the given axis"

      is_float(split) ->
        # The multiplication is done only once we know `split` is a float.
        # Doing it eagerly made non-numeric input fail with ArithmeticError
        # before the catch-all clause below could report it.
        float_split_index = Kernel.ceil(split * axis_size)

        if float_split_index > 0 and float_split_index < axis_size do
          float_split_index
        else
          # NOTE(review): the message says `< 1` although the actual bound is
          # `axis_size`; the text is kept verbatim because callers and tests
          # match on it exactly.
          raise ArgumentError,
                "split must be a float such that 0 < split and ceil(split * axis_size) < 1"
        end

      true ->
        raise ArgumentError,
              "invalid split received, expected a float or an integer, got: #{inspect(split)}"
    end

  {
    slice_along_axis(tensor, 0, split_index, axis: axis),
    slice_along_axis(tensor, split_index, axis_size - split_index, axis: axis)
  }
end

@doc """
Broadcasts `tensor` to the given `broadcast_shape`.

Expand Down
101 changes: 101 additions & 0 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3022,4 +3022,105 @@ defmodule NxTest do
assert_all_close(zeros, Nx.imag(x_ifft), atol: 1.0e-8)
end
end

# Covers argument validation (split bounds, axis lookup) and the resulting
# shapes of Nx.split for integer and float splits.
describe "split/2" do
  test "split is less than zero" do
    t = Nx.iota({10, 2}, names: [:x, :y])

    expected_message =
      "split must be an integer greater than zero and less than the length of the given axis"

    assert_raise ArgumentError, expected_message, fn -> Nx.split(t, -1) end
  end

  test "split is a float out of bounds" do
    t = Nx.iota({10, 2}, names: [:x, :y])

    expected_message =
      "split must be a float such that 0 < split and ceil(split * axis_size) < 1"

    assert_raise ArgumentError, expected_message, fn -> Nx.split(t, 1.0) end
  end

  test "split is greater than tensor length" do
    t = Nx.iota({10, 2}, names: [:x, :y])

    expected_message =
      "split must be an integer greater than zero and less than the length of the given axis"

    # Axis 1 only has 2 entries, so 3 is out of range.
    assert_raise ArgumentError, expected_message, fn -> Nx.split(t, 3, axis: 1) end
  end

  test "axis is out of tensor bounds" do
    t = Nx.iota({10, 2}, names: [:x, :y])
    expected_message = "given axis (2) invalid for shape with rank 2"

    assert_raise ArgumentError, expected_message, fn -> Nx.split(t, 2, axis: 2) end
  end

  test "named axis is invalid" do
    t = Nx.iota({10, 2}, names: [:x, :y])
    expected_message = "name :z not found in tensor with names [:x, :y]"

    assert_raise ArgumentError, expected_message, fn -> Nx.split(t, 2, axis: :z) end
  end

  test "split into 50% for training and 50% for testing with floats on columns" do
    {left, right} =
      {4, 4}
      |> Nx.iota(names: [:rows, :columns])
      |> Nx.split(0.5, axis: :columns)

    assert Nx.shape(left) == {4, 2}
    assert Nx.shape(right) == {4, 2}
  end

  test "split into 70% for training and 30% for testing along a named :axis" do
    {left, right} =
      {100, 6}
      |> Nx.iota(names: [:rows, :columns])
      |> Nx.split(70, axis: :rows)

    assert Nx.shape(left) == {70, 6}
    assert Nx.shape(right) == {30, 6}
  end

  test "split into 90% for training and 10% for testing along a named :axis" do
    {left, right} =
      {2, 100}
      |> Nx.iota(names: [:rows, :columns])
      |> Nx.split(90, axis: :columns)

    assert Nx.shape(left) == {2, 90}
    assert Nx.shape(right) == {2, 10}
  end

  test "split into 50% for training and 50% for testing along the :axis 1" do
    {left, right} =
      {100, 10}
      |> Nx.iota()
      |> Nx.split(5, axis: 1)

    assert Nx.shape(left) == {100, 5}
    assert Nx.shape(right) == {100, 5}
  end

  test "split into 61% for training and 39% for testing" do
    {left, right} =
      {100, 10}
      |> Nx.iota()
      |> Nx.split(61)

    assert Nx.shape(left) == {61, 10}
    assert Nx.shape(right) == {39, 10}
  end

  test "split into 60% for training and 40% for testing with unbalanced data" do
    {left, right} =
      {99, 4}
      |> Nx.iota()
      |> Nx.split(60)

    assert Nx.shape(left) == {60, 4}
    assert Nx.shape(right) == {39, 4}
  end
end
end