Skip to content

Commit

Permalink
Sharp bits: add note on subnormal flush-to-zero
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 21, 2025
1 parent 66037d1 commit b3cc38b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
13 changes: 13 additions & 0 deletions docs/notebooks/Common_Gotchas_in_JAX.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,19 @@
"\n",
" ```\n",
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
"- When operating on [Subnormal](https://en.wikipedia.org/wiki/Subnormal_number)\n",
" floating point numbers, JAX operations use flush-to-zero semantics on some\n",
" backends. For example:\n",
" ```python\n",
" >>> import jax.numpy as jnp\n",
" >>> subnormal = jnp.float32(1E-45)\n",
" >>> subnormal # subnormals are representable\n",
" Array(1.e-45, dtype=float32)\n",
" >>> subnormal + 0 # but are flushed to zero within operations\n",
" Array(0., dtype=float32)\n",
" ```\n",
" The detailed operation semantics for subnormal values will generally\n",
" vary depending on the backend.\n",
"\n",
"## 🔪 Sharp bits covered in tutorials\n",
"- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.\n",
Expand Down
13 changes: 13 additions & 0 deletions docs/notebooks/Common_Gotchas_in_JAX.md
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,19 @@ Many such cases are discussed in detail in the sections above; here we list seve

```
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
- When operating on [Subnormal](https://en.wikipedia.org/wiki/Subnormal_number)
floating point numbers, JAX operations use flush-to-zero semantics on some
backends. For example:
```python
>>> import jax.numpy as jnp
>>> subnormal = jnp.float32(1E-45)
>>> subnormal # subnormals are representable
Array(1.e-45, dtype=float32)
>>> subnormal + 0 # but are flushed to zero within operations
Array(0., dtype=float32)
```
The detailed operation semantics for subnormal values will generally
vary depending on the backend.

## 🔪 Sharp bits covered in tutorials
- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.
Expand Down

0 comments on commit b3cc38b

Please sign in to comment.