Fix jacfwd, jacrev, and grad for heterogeneous pytrees #7158
Not quite a review yet; I'd like to better understand the problem and the PR. To what extent does the original issue remain unsolved if the caller were to tree-map a cast-to-complex function before calling `jacfwd`/`jacrev`? Would that also work as a solution here?
As I understand it, if we cast what's passed to `jacfwd` (or `jacrev`), then what it returns might not match the output tree (or input tree) of the original function. Doing that cast ourselves automatically might cause surprise. If callers cast explicitly, they'd arguably better know what to expect. What do you think?
jax/_src/api.py

    @@ -1045,8 +1030,8 @@ def jacfun(*args, **kwargs):
         jac = vmap(pullback)(_std_basis(y))
         jac = jac[0] if isinstance(argnums, int) else jac
         example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    -    jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
    -    return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
    +    jac_tree = tree_multimap(partial(_jacrev_unravel, y), example_args, jac)

fyi, `tree_map` is an alias for `tree_multimap` these days (cf. here).
Oh nice, I didn't know that! Should I switch to using `tree_map` exclusively in my code?
You certainly could. We haven't deprecated `tree_multimap`, so either is fine, but `tree_map` is shorter!
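
As a minimal sketch of the multi-tree form mentioned here, assuming a JAX release in which `jax.tree_util.tree_map` accepts further trees after the first one. The two example trees and their names are made up for illustration.

```python
import jax

# Two hypothetical pytrees with the same structure.
params = {"w": 1.0, "b": (2.0, 3.0)}
updates = {"w": 0.5, "b": (0.25, 0.125)}

# tree_map applies the function leaf by leaf across both trees,
# which is the use case tree_multimap was originally named for.
summed = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
print(summed)
```

The result keeps the structure of the first tree, so `summed` is again a dict holding a float and a tuple of floats.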
> To what extent does the original issue remain unsolved if the caller were to tree-map a cast-to-complex function before calling `jacfwd`/`jacrev`?

It wouldn't make any difference because JAX already creates tangents that are complex-valued. The `_std_basis` function chooses its dtype by applying `tree_reduce` with `result_type` over the input tree. So if there are any complex components, it chooses a basis of tangents of complex type.
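
The promotion described here can be sketched with NumPy alone. This is an illustration of the reduction idea, not JAX's internal code; `unified_dtype` is a hypothetical helper name.

```python
import functools
import numpy as np

def unified_dtype(leaf_dtypes):
    # Hypothetical helper: fold NumPy's promotion rule over all leaf dtypes,
    # so a single complex leaf is enough to make the result complex.
    return functools.reduce(np.promote_types, leaf_dtypes)

mixed = unified_dtype([np.float32, np.float32, np.complex64])
all_real = unified_dtype([np.float32, np.float32])
print(mixed, all_real)  # complex64 float32
```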
In my code, this is a problem here, for example. I want to be able to calculate the Jacobian of the gradient. The function expects `self.precision` to be real-valued, and must produce a real-valued output. If we cast the inputs to complex, then I need to replace all uses of `self.precision` with `self.precision.real` to discard the complex components. If I don't do that, I'll get an exception in JAX because the output will be complex, which the gradient doesn't support.
I think it's surprising that a function that's called with real-valued `self.precision` should be called with complex values by JAX's Jacobian functions. Also, even if the XLA compiler optimizes it away, I think it's computationally wasteful to pass in synthetic complex components that are ultimately discarded.
And this isn't just a problem with complex types. Imagine if you had a function with one input component that's `np.float64` and one that's `np.float32`. Wouldn't it be nice if JAX's Jacobian functions passed in the same types for its calculation? Currently, the narrower type is being silently widened, and therefore the Jacobian can contain components that are wider than expected.
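
To illustrate the widening, here is a NumPy-only sketch under the assumption that a single basis matrix is built for all input leaves and then split per leaf. The leaves and variable names are hypothetical, and this is not the code in `jax/_src/api.py`.

```python
import numpy as np

# Hypothetical input leaves with different floating-point widths.
leaves = [np.zeros(2, dtype=np.float64), np.zeros(3, dtype=np.float32)]

# A single basis matrix needs one dtype, so the widest one wins.
unified = np.result_type(*leaves)
basis = np.eye(sum(leaf.size for leaf in leaves), dtype=unified)

# Splitting the unified basis per leaf keeps the widened dtype ...
split_points = np.cumsum([leaf.size for leaf in leaves])[:-1]
widened = np.split(basis, split_points, axis=1)

# ... whereas casting each piece back restores the dtype of its leaf.
matched = [part.astype(leaf.dtype) for part, leaf in zip(widened, leaves)]

print([str(part.dtype) for part in widened])  # ['float64', 'float64']
print([str(part.dtype) for part in matched])  # ['float64', 'float32']
```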
Thanks for clarifying. If I understand correctly: we're arguably doing the surprising thing already, and either way, casting to complex can confuse a function that was expecting some strictly real-valued arguments.
@mattjj, do you recall whether we had a reason to unify the tangent types when forming the standard basis for `jacfwd`/`jacrev`?
By the way, should the tests in this PR cover the case where "the narrower type is silently widened" in a way that requires the function implementation to change, as in having to write `x.real`?
> By the way, should the tests in this PR cover the case where "the narrower type is silently widened" in a way that requires the function implementation to change, as in having to write `x.real`?

Good point. I will add that test.
Done adding the test. Please let me know if you need me to do anything else.
                          "inputs, use vjp or set allow_int to True.")
      if (dtypes.issubdtype(aval.dtype, np.integer) or
          dtypes.issubdtype(aval.dtype, np.bool_)):
        if not allow_int:

I cleaned this up, but wasn't sure whether `np.bool_` should be allowed or not. It was allowed in the previous version when `allow_int` is true; in fact, any type was allowed when `allow_int` is true.
In the interest of changing as little as possible beyond the necessary significant change, let's keep allowing bool.
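
A rough sketch of the check being discussed, written against NumPy dtypes only. `check_input_dtype` is a hypothetical stand-in, and the final `elif` branch is an assumption about how other dtypes are rejected; it is not taken from the diff.

```python
import numpy as np

def check_input_dtype(dtype, allow_int):
    """Hypothetical stand-in for the input check; not JAX's actual code."""
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
        # Integer and Boolean inputs pass only when the caller opts in.
        if not allow_int:
            raise TypeError(f"{dtype} input requires allow_int=True")
    elif not np.issubdtype(dtype, np.inexact):
        # Assumed behavior: anything that is neither int/bool nor
        # float/complex is rejected regardless of allow_int.
        raise TypeError(f"unsupported input dtype {dtype}")

check_input_dtype(np.float32, allow_int=False)  # accepted
check_input_dtype(np.bool_, allow_int=True)     # accepted: bool stays allowed
```

With this shape, `allow_int=True` widens the accepted set by integers and Booleans only, rather than switching the check off.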
@NeilGirdhar Thanks so much for this. Sorry for being slow; the slowness was largely my fault. @froystig explained things to me, and it makes sense!
Considering the various changes in this PR separately...
What motivates this part? Is this required in order to address the original issue (#7157), or is it a proposed change along the way? One way to understand
It is consistent then for (What might have been a confusing choice of ours was exposing
This seems to tackle the original issue inarguably, and you've made a very clear case for it more generally too. Thanks!
Similar to the first item, is this necessary for tackling the original issue? If possible, we'd like to consider a function holomorphic only if it is a complex-valued function of complex inputs only.
No problem at all! Thanks for taking a look.
I think I didn't do a good enough job of explaining this bullet point. Yes, this is part of the original fix. This bullet point ensures that what comes out of the Jacobian calculation has the right dtype. The way that both Jacobian functions were already implemented involved flattening the inputs and outputs and calculating a Jacobian array. This array must have the "result type" (widest type) of all the inputs and outputs together. Then, that Jacobian array is "unraveled" (see So, for example, if your input to What this change does is make
Sure. I always write Excellent comment by the way. I agree with you. (I did try to make the errors more specific in this pull request, but I left the documentation alone.)
Great!
Yes, sounds good. (Edited my previous reply.)
Thanks for keeping this rebased and for your patience here @NeilGirdhar!
Having paged this back in now, the only part I'm unclear on still is the introduced possible-downcasts, namely how this PR casts output dtypes back down to match input dtypes in jacfwd (and vice versa). Why do we need that for the original fix, and why is it correct?
Again, going back to first principles, `jacfwd` is the push-forward (of basis vectors for the input tangent space). I understand we want to allow for heterogeneity (that's great, and I understand that part of the change), but why should the output cotangent type correspond to anything other than the output type?
> The derivative of a real number with respect to a real number should be real.

This might be related to the point of confusion? There is no notion of differentiating one number with respect to another number in the push-forward terminology. There is only the output cotangent type, which is the type of cotangent values resulting from pushing forward the input tangent values (whatever their type is).
If this is a continued misunderstanding of terms then thanks for bearing with me here!
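
To make the push-forward framing from the review comment concrete, here is a small sketch using `jax.jvp`. The function `f` and its inputs are made up for illustration. The tangent that comes out has the dtype of the output, whatever the dtype of the input tangent.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical real-to-complex function: real input, complex output.
    return x * (1.0 + 2.0j)

x = jnp.float32(1.5)       # primal input
x_dot = jnp.float32(1.0)   # input tangent, same dtype as the input

# jvp pushes the input tangent forward through f.
y, y_dot = jax.jvp(f, (x,), (x_dot,))
print(x_dot.dtype, y.dtype, y_dot.dtype)
```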
I guess one way to answer this is to work backwards from the test. Do you agree with the desired output for
I think you're right that this is where we disagree. The Jacobian is by definition the matrix of partial derivatives of the output with respect to the input. What you're describing is a method to calculate the derivative. I think the definition of the Jacobian defines the semantics whereas the calculation method and its peculiarities shouldn't define the semantics. The reason that the current Jacobian methods are producing output types of Instead of changing the use of As for why I think it should be that: since the Jacobian is by definition the derivative of the output with respect to the input, it could have equivalently been implemented using a loop for each pair of inputs. In that case, the types would be the result types of each input and output pair. It may be helpful to add some of this explanation as comments in the code if you think that would be helpful to future readers.
Good way to put it by asking about the test. The test assertion isn't consistent with JAX's intended definition for Take the current
and this is also the type of its Jacobian map. Our intended meaning for Our intended meaning for Note that one ends up with matrices of different types in either case. We could have elicited that with a simpler example too: if you consider a function I understand that you'd like for the meaning of Meanwhile the heterogeneity support here is great. That's why I think we should make this PR just about that! If nothing else, we can always discuss whether to introduce some kind of
The
It's possible that we need to better explain the current
Changed the behavior of `jacfwd`, `jacrev`, and `grad` when the input pytree elements have heterogeneous dtypes, e.g., real and complex elements:

* Changed the dtypes of the pytree elements of the Jacobian produced by `jacfwd` to be those of the input tangent basis.
* Changed the dtypes of the pytree elements of the Jacobian produced by `jacrev` to be those of the output tangent basis.
* Changed the dtypes of the pytree elements of the primals and tangents produced by `jacfwd` and `jacrev` to be the same as the corresponding elements in the input.

Changed the behavior of the flags to `jacfwd` and `jacrev`:

* Changed the `allow_int` flag to allow only integer and Boolean dtypes. Previously, this flag allowed all other types.
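
One way to inspect the resulting behavior is to run both Jacobian functions on inputs of different widths and print the dtype of every leaf. The sketch below uses `float16` and `float32` because both are available without enabling 64-bit mode. The function `f` and its inputs are made up for illustration, and no particular dtypes are claimed here since they depend on the installed JAX version.

```python
import jax
import jax.numpy as jnp

# Hypothetical inputs with heterogeneous dtypes.
x = jnp.array([2.0], dtype=jnp.float16)
y = jnp.array([3.0], dtype=jnp.float32)

def f(args):
    a, b = args
    return a + b  # dtype promotion happens inside the function

# Each call returns one Jacobian leaf per input leaf.
jac_fwd = jax.jacfwd(f)((x, y))
jac_rev = jax.jacrev(f)((x, y))

for name, jac in (("jacfwd", jac_fwd), ("jacrev", jac_rev)):
    print(name, [(leaf.dtype, leaf.shape) for leaf in jac])
```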
Using "unified types" is actually what JAX is unfortunately currently doing. It seems like we were both on the same page about changing it!
Okay, you've convinced me. Thanks for the patient explanation. I modified the pull request to incorporate your suggestion. Would you mind looking over the tests to make sure I've got it right?
Thanks @NeilGirdhar!
Indeed, you're right about that. The need to change the current implementation is why we're here.
The change looks good! Thank you both for the discussion and for the contribution! Let me give it an official approval...
PiperOrigin-RevId: 402832009
Changed the behavior of `jacfwd`, `jacrev`, and `grad` when the input pytree elements have heterogeneous dtypes, e.g., real and complex elements:

* Changed the dtypes of the pytree elements of the Jacobian produced by `jacfwd` to be those of the input tangent basis.
* Changed the dtypes of the pytree elements of the Jacobian produced by `jacrev` to be those of the output tangent basis.
* Changed the dtypes of the pytree elements of the primals and tangents produced by `jacfwd` and `jacrev` to be the same as the corresponding elements in the input.

Changed the behavior of the flags to `jacfwd` and `jacrev`:

* Changed the `allow_int` flag to allow only integer and Boolean dtypes. Previously, this flag allowed all other types.

Fixes #7157
Fixes #7780