@@ -3268,7 +3268,7 @@ defmodule Nx do
3268
3268
@ doc ~S"""
3269
3269
Split a tensor into train and test subsets.
3270
3270
3271
- `split` split must be an integer greater than zero and less than the length of the tensor.
3271
+ `split` split must be either an integer greater than zero and less than the length of the tensor or a float number between `0.0` and `1.0` .
3272
3272
3273
3273
## Options
3274
3274
@@ -3315,67 +3315,33 @@ defmodule Nx do
3315
3315
]
3316
3316
>
3317
3317
3318
- iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20 ], [26, 75, 3, 9 ], [23, 4, 1, 5]], names: [:rows, :columns] ), 2 , axis: :columns )
3318
+ iex> {train, test} = Nx.split(Nx.tensor([[3, 6, 5, 9 ], [26, 75, 3, 20 ], [23, 4, 1, 56], [40, 6, 78, 94]] ), 0.5 , axis: 0 )
3319
3319
iex> train
3320
3320
#Nx.Tensor<
3321
- s64[rows: 3][columns: 2 ]
3321
+ s64[2][4 ]
3322
3322
[
3323
- [3, 6],
3324
- [26, 75],
3325
- [23, 4]
3326
- ]
3327
- >
3328
- iex> test
3329
- #Nx.Tensor<
3330
- s64[rows: 3][columns: 2]
3331
- [
3332
- [5, 20],
3333
- [3, 9],
3334
- [1, 5]
3335
- ]
3336
- >
3337
-
3338
- iex>{train, test} = Nx.split(Nx.tensor([[3, 6, 5, 20], [26, 75, 3, 9], [23, 4, 1, 5]], names: [:rows, :columns]), 2, axis: :rows)
3339
- iex> train
3340
- #Nx.Tensor<
3341
- s64[rows: 2][columns: 4]
3342
- [
3343
- [3, 6, 5, 20],
3344
- [26, 75, 3, 9]
3323
+ [3, 6, 5, 9],
3324
+ [26, 75, 3, 20]
3345
3325
]
3346
3326
>
3347
3327
iex> test
3348
3328
#Nx.Tensor<
3349
- s64[rows: 1][columns: 4]
3329
+ s64[2][ 4]
3350
3330
[
3351
- [23, 4, 1, 5]
3331
+ [23, 4, 1, 56],
3332
+ [40, 6, 78, 94]
3352
3333
]
3353
3334
>
3354
3335
"""
3355
3336
@ doc type: :indexed
3356
3337
def split ( tensor , split , opts \\ [ ] )
3357
3338
3358
- def split ( % T { shape: shape } = tensor , split , opts ) do
3339
+ def split ( % T { shape: shape } = tensor , split , opts ) when is_integer ( split ) do
3359
3340
opts = keyword! ( opts , axis: 0 )
3360
3341
axis = Keyword . fetch! ( opts , :axis )
3361
3342
3362
- if is_integer ( split ) and split > 0 do
3363
- axis_values = axes ( tensor )
3364
- axis_names = names ( tensor )
3365
-
3366
- axis =
3367
- cond do
3368
- is_integer ( axis ) and axis in axis_values ->
3369
- axis
3370
-
3371
- is_atom ( axis ) and axis in axis_names ->
3372
- dimensions = Enum . zip ( axis_names , axis_values )
3373
- dimensions [ axis ]
3374
-
3375
- true ->
3376
- raise ":axis is out of tensor bounds."
3377
- end
3378
-
3343
+ if split > 0 do
3344
+ axis = find_axis ( tensor , axis )
3379
3345
values = elem ( shape , axis )
3380
3346
size = values - split
3381
3347
@@ -3388,6 +3354,35 @@ defmodule Nx do
3388
3354
end
3389
3355
end
3390
3356
3357
+ def split ( % T { shape: shape } = tensor , split , opts ) when is_float ( split ) do
3358
+ opts = keyword! ( opts , axis: 0 )
3359
+ axis = Keyword . fetch! ( opts , :axis )
3360
+
3361
+ if split > 0.0 and split < 1.0 do
3362
+ axis = find_axis ( tensor , axis )
3363
+
3364
+ values = elem ( shape , axis )
3365
+
3366
+ split_size = Kernel . ceil ( split * values )
3367
+
3368
+ split_size =
3369
+ cond do
3370
+ split_size < 1 -> 1
3371
+ split_size >= values -> 1
3372
+ true -> split_size
3373
+ end
3374
+
3375
+ remaining_size = values - split_size
3376
+
3377
+ {
3378
+ slice_along_axis ( tensor , 0 , split_size , axis: axis ) ,
3379
+ slice_along_axis ( tensor , split_size , remaining_size , axis: axis )
3380
+ }
3381
+ else
3382
+ raise "split must be a float number between 0.0 and 1.0."
3383
+ end
3384
+ end
3385
+
3391
3386
@ doc """
3392
3387
Broadcasts `tensor` to the given `broadcast_shape`.
3393
3388
@@ -16184,4 +16179,21 @@ defmodule Nx do
16184
16179
end )
16185
16180
end
16186
16181
end
16182
+
16183
+ defp find_axis ( tensor , axis ) do
16184
+ axis_values = axes ( tensor )
16185
+ axis_names = names ( tensor )
16186
+
16187
+ cond do
16188
+ is_integer ( axis ) and axis in axis_values ->
16189
+ axis
16190
+
16191
+ is_atom ( axis ) and axis in axis_names ->
16192
+ dimensions = Enum . zip ( axis_names , axis_values )
16193
+ dimensions [ axis ]
16194
+
16195
+ true ->
16196
+ raise ":axis is out of tensor bounds."
16197
+ end
16198
+ end
16187
16199
end
0 commit comments