Skip to content

Commit 6350742

Browse files
ChaiBapchyaJose Luis Contreras
authored and
Jose Luis Contreras
committed
[MXNET-1173] Debug operators - isfinite, isinf and isnan (apache#12967)
* is_finite and is_inf implementation for front-end python api debug operator * updated unit-tests * updated test cases and incorporated is_nan function * solved index out of bounds issue and added comments * simplified abs function call and added isnan to contrib.py and all debug ops to doc * changed dimensions, added regular number, assert_equal instead of almost, removed ctx and added data.abs
1 parent 879206b commit 6350742

File tree

3 files changed

+128
-1
lines changed

3 files changed

+128
-1
lines changed

docs/api/python/ndarray/contrib.md

+3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
5555
foreach
5656
while_loop
5757
cond
58+
isinf
59+
isfinite
60+
isnan
5861
index_copy
5962
getnnz
6063
```

python/mxnet/ndarray/contrib.py

+83-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name
2020
"""Contrib NDArray API of MXNet."""
2121
import math
22+
import numpy as np
2223
from ..context import current_context
2324
from ..random import uniform
2425
from ..base import _as_list
@@ -28,7 +29,7 @@
2829
except ImportError:
2930
pass
3031

31-
__all__ = ["rand_zipfian", "foreach", "while_loop", "cond"]
32+
__all__ = ["rand_zipfian", "foreach", "while_loop", "cond", "isinf", "isfinite", "isnan"]
3233

3334
# pylint: disable=line-too-long
3435
def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
@@ -460,3 +461,84 @@ def _to_python_scalar(inputs, type_, name):
460461
return then_func()
461462
else:
462463
return else_func()
464+
465+
def isinf(data):
466+
"""Performs an element-wise check to determine if the NDArray contains an infinite element
467+
or not.
468+
469+
470+
Parameters
471+
----------
472+
input : NDArray
473+
An N-D NDArray.
474+
475+
Returns
476+
-------
477+
output: NDArray
478+
The output NDarray, with same shape as input, where 1 indicates the array element is
479+
equal to positive or negative infinity and 0 otherwise.
480+
481+
Examples
482+
--------
483+
>>> data = mx.nd.array([np.inf, -np.inf, np.NINF, -1])
484+
>>> output = mx.nd.contrib.isinf(data)
485+
>>> output
486+
[1. 1. 1. 0.]
487+
<NDArray 4 @cpu(0)>
488+
"""
489+
return data.abs() == np.inf
490+
491+
def isfinite(data):
492+
"""Performs an element-wise check to determine if the NDArray contains an infinite element
493+
or not.
494+
495+
496+
Parameters
497+
----------
498+
input : NDArray
499+
An N-D NDArray.
500+
501+
Returns
502+
-------
503+
output: NDArray
504+
The output NDarray, with same shape as input, where 1 indicates the array element is
505+
finite i.e. not equal to positive or negative infinity and 0 in places where it is
506+
positive or negative infinity.
507+
508+
Examples
509+
--------
510+
>>> data = mx.nd.array([np.inf, -np.inf, np.NINF, -1])
511+
>>> output = mx.nd.contrib.isfinite(data)
512+
>>> output
513+
[0. 0. 0. 1.]
514+
<NDArray 4 @cpu(0)>
515+
"""
516+
is_data_not_nan = data == data
517+
is_data_not_infinite = data.abs() != np.inf
518+
return ndarray.logical_and(is_data_not_infinite, is_data_not_nan)
519+
520+
def isnan(data):
521+
"""Performs an element-wise check to determine if the NDArray contains a NaN element
522+
or not.
523+
524+
525+
Parameters
526+
----------
527+
input : NDArray
528+
An N-D NDArray.
529+
530+
Returns
531+
-------
532+
output: NDArray
533+
The output NDarray, with same shape as input, where 1 indicates the array element is
534+
NaN i.e. Not a Number and 0 otherwise.
535+
536+
Examples
537+
--------
538+
>>> data = mx.nd.array([np.nan, -1])
539+
>>> output = mx.nd.contrib.isnan(data)
540+
>>> output
541+
[1. 0.]
542+
<NDArray 2 @cpu(0)>
543+
"""
544+
return data != data

tests/python/unittest/test_ndarray.py

+42
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,48 @@ def test_dlpack():
15061506
mx.test_utils.assert_almost_equal(a_np, d_np)
15071507
mx.test_utils.assert_almost_equal(a_np, e_np)
15081508

1509+
@with_seed()
1510+
def test_ndarray_is_inf():
1511+
random_dimensions = np.random.randint(2, 5)
1512+
random_shape = [np.random.randint(2, 5) for i in range(random_dimensions)]
1513+
data = mxnet.test_utils.rand_ndarray(random_shape,'default')
1514+
data[0][0] = np.inf
1515+
data[0][1] = -np.inf
1516+
data[1][0] = np.nan
1517+
data[1][1] = 5
1518+
output = mx.nd.contrib.isinf(data)
1519+
expected_output = np.isinf(data.asnumpy())
1520+
np.testing.assert_equal(output.asnumpy(), expected_output.astype(int))
1521+
# astype since numpy functions default return type is boolean array instead of int
1522+
1523+
@with_seed()
1524+
def test_ndarray_is_finite():
1525+
random_dimensions = np.random.randint(2, 5)
1526+
random_shape = [np.random.randint(2, 5) for i in range(random_dimensions)]
1527+
data = mxnet.test_utils.rand_ndarray(random_shape,'default')
1528+
data[0][0] = np.inf
1529+
data[0][1] = -np.inf
1530+
data[1][0] = np.nan
1531+
data[1][1] = 5
1532+
output = mx.nd.contrib.isfinite(data)
1533+
expected_output = np.isfinite(data.asnumpy())
1534+
np.testing.assert_equal(output.asnumpy(), expected_output.astype(int))
1535+
# astype since numpy functions default return type is boolean array instead of int
1536+
1537+
@with_seed()
1538+
def test_ndarray_is_nan():
1539+
random_dimensions = np.random.randint(2, 5)
1540+
random_shape = [np.random.randint(2, 5) for i in range(random_dimensions)]
1541+
data = mxnet.test_utils.rand_ndarray(random_shape,'default')
1542+
data[0][0] = np.inf
1543+
data[0][1] = -np.inf
1544+
data[1][0] = np.nan
1545+
data[1][1] = 5
1546+
output = mx.nd.contrib.isnan(data)
1547+
expected_output = np.isnan(data.asnumpy())
1548+
np.testing.assert_equal(output.asnumpy(), expected_output.astype(int))
1549+
# astype since numpy functions default return type is boolean array instead of int
1550+
15091551
if __name__ == '__main__':
15101552
import nose
15111553
nose.runmodule()

0 commit comments

Comments
 (0)