Fix jacfwd, jacrev, and grad for heterogeneous pytrees #7158

Merged (1 commit) on Oct 13, 2021

NeilGirdhar
Contributor

@NeilGirdhar NeilGirdhar commented Jul 1, 2021

Changed the behavior of jacfwd, jacrev, and grad when the input
pytree elements have heterogeneous dtypes, e.g., real and complex
elements:

  • Changed the dtypes of the pytree elements of the Jacobian produced by
    jacfwd to be those of the input tangent basis.

  • Changed the dtypes of the pytree elements of the Jacobian produced by
    jacrev to be those of the output tangent basis.

  • Changed the dtypes of the pytree elements of the primals and tangents
    produced by jacfwd and jacrev to be the same as the corresponding
    elements in the input.

Changed the behavior of the flags to jacfwd and jacrev:

  • Changed the allow_int flag to allow only integer and Boolean dtypes.
    Previously, setting this flag allowed any dtype.

Fixes #7157
Fixes #7780

Member

@froystig froystig left a comment

Not quite a review yet; I'd like to better understand the problem and the PR. To what extent does the original issue remain unsolved if the caller were to tree-map a cast-to-complex function before calling jacfwd/jacrev? Would that also work as a solution here?

As I understand it, if we cast what's passed to jacfwd (or jacrev), then what it returns might not match the output tree (or input tree) of the original function. Doing that cast ourselves automatically might cause surprise. If callers cast explicitly, they'd arguably better know what to expect. What do you think?

jax/_src/api.py Outdated
@@ -1045,8 +1030,8 @@ def jacfun(*args, **kwargs):
     jac = vmap(pullback)(_std_basis(y))
     jac = jac[0] if isinstance(argnums, int) else jac
     example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
-    jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
-    return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
+    jac_tree = tree_multimap(partial(_jacrev_unravel, y), example_args, jac)
Member

fyi, tree_map is an alias for tree_multimap these days (cf. here).

Contributor Author

Oh nice, I didn't know that! Should I switch to using tree_map exclusively in my code?

Member

You certainly could. We haven't deprecated tree_multimap, so either is fine, but tree_map is shorter!
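
For readers who have not used the multi-tree form, here is a minimal sketch. It assumes a JAX version in which `jax.tree_util.tree_map` accepts extra trees after the first one; the `params` and `updates` trees are made up for illustration.

```python
import jax.numpy as jnp
from jax import tree_util

# Two hypothetical pytrees with identical structure.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
updates = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(-0.5)}

# tree_map calls the function once per leaf position, passing the
# matching leaf from every tree.
summed = tree_util.tree_map(lambda p, u: p + u, params, updates)
```

The result has the same dict structure as the inputs, with each leaf being the sum of the two corresponding leaves.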

Contributor Author

@NeilGirdhar NeilGirdhar Jul 14, 2021

To what extent does the original issue remain unsolved if the caller were to tree-map a cast-to-complex function before calling jacfwd/jacrev?

It wouldn't make any difference because JAX already creates tangents that are complex-valued. The _std_basis function picks its dtype by applying result_type across the leaves of the input tree (via tree_reduce). So if any component is complex, it builds a basis of complex-typed tangents.
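
The dtype unification described here can be sketched as follows. This is not the body of `_std_basis`; it is an illustration using public JAX calls, and the `args` pytree (with its `precision` and `z` keys) is a made-up example.

```python
import jax.numpy as jnp
from jax import tree_util

# Hypothetical input pytree with one real leaf and one complex leaf.
args = {
    "precision": jnp.ones(2, dtype=jnp.float32),
    "z": jnp.ones(3, dtype=jnp.complex64),
}

leaves = tree_util.tree_leaves(args)

# Promote across every leaf: a single complex leaf makes the result complex.
unified_dtype = jnp.result_type(*leaves)

# One flat identity matrix in the unified dtype then serves as the
# tangent basis for all leaves, including the real-valued one.
total_size = sum(leaf.size for leaf in leaves)
flat_basis = jnp.eye(total_size, dtype=unified_dtype)
```

With this construction the real-valued `precision` leaf ends up with complex tangents, which is the surprise discussed in this comment.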

In my code, this is a problem here, for example. I want to be able to calculate the Jacobian of the gradient. The function expects self.precision to be real-valued, and must produce a real-valued output. If we cast the inputs to complex, then I need to replace all uses of self.precision with self.precision.real to discard the complex components. If I don't do that, I'll get an exception in JAX because the output will be complex, which the gradient doesn't support.

I think it's surprising that a function that's called with real-valued self.precision should be called with complex values by JAX's Jacobian functions. Also, even if the XLA compiler optimizes it away, I think it's computationally wasteful to pass in synthetic complex components that are ultimately discarded.

And this isn't just a problem with complex types. Imagine if you had a function with one input component that's np.float64 and one that's np.float32. Wouldn't it be nice if JAX's Jacobian functions passed in the same types for its calculation? Currently, the narrower type is being silently widened, and therefore the Jacobian can contain components that are wider than expected.

Member

Thanks for clarifying. If I understand correctly: we're arguably doing the surprising thing already, and either way, casting to complex can confuse a function that was expecting some strictly real-valued arguments.

@mattjj, do you recall whether we had a reason to unify the tangent types when forming the standard basis for jacfwd/jacrev?

By the way, should the tests in this PR cover the case where "the narrower type is silently widened" in a way that requires the function implementation to change, as in having to write x.real?

Contributor Author

By the way, should the tests in this PR cover the case where "the narrower type is silently widened" in a way that requires the function implementation to change, as in having to write x.real?

Good point. I will add that test.

Contributor Author

Done adding the test. Please let me know if you need me to do anything else.

@NeilGirdhar NeilGirdhar force-pushed the fix_basis branch 2 times, most recently from f90480b to 38a5610 Compare July 15, 2021 03:51
"inputs, use vjp or set allow_int to True.")
if (dtypes.issubdtype(aval.dtype, np.integer) or
dtypes.issubdtype(aval.dtype, np.bool_)):
if not allow_int:
Contributor Author

@NeilGirdhar NeilGirdhar Jul 15, 2021

I cleaned this up, but wasn't sure whether np.bool_ should be allowed or not. It was allowed in the previous version when allow_int is true--in fact, any type was allowed when allow_int is true.

Member

In the interest of changing as little as possible beyond the necessary significant change, let's keep allowing bool.
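
The agreed behavior for allow_int can be sketched as a standalone check. `check_input_dtype` and its error messages are hypothetical and written against NumPy only; this is not the code in jax/_src/api.py.

```python
import numpy as np

def check_input_dtype(dtype, allow_int=False):
    """Hypothetical validation helper, not JAX's actual function."""
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
        # Integer and Boolean inputs are accepted only on request.
        if not allow_int:
            raise TypeError(
                f"got integer or Boolean dtype {dtype}; "
                "set allow_int=True to permit it")
    elif not np.issubdtype(dtype, np.inexact):
        # Anything that is neither integer/Boolean nor float/complex is
        # rejected regardless of allow_int.
        raise TypeError(f"unsupported input dtype {dtype}")

check_input_dtype(np.float32)                 # floating point: accepted
check_input_dtype(np.int32, allow_int=True)   # integer: accepted on request
check_input_dtype(np.bool_, allow_int=True)   # Boolean: accepted on request
```

The point of the sketch is the second branch: setting allow_int no longer lets arbitrary dtypes through, it only widens the accepted set by integers and Booleans.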

@NeilGirdhar NeilGirdhar force-pushed the fix_basis branch 3 times, most recently from 701c562 to 0351172 Compare July 27, 2021 06:09
@mattjj
Collaborator

mattjj commented Jul 27, 2021

@NeilGirdhar Thanks so much for this. Sorry for being slow; the slowness was largely my fault.

@froystig explained things to me, and it makes sense!

@froystig
Member

Considering the various changes in this PR separately...

  • Changed the dtypes of the pytree elements of the Jacobian produced
    by jacfwd and jacrev to be the result type of the corresponding input
    component and output component.

What motivates this part? Is this required in order to address the original issue (#7157), or is it a proposed change along the way?

One way to understand jacfwd and jacrev is:

  • jacfwd means "form the standard basis of the input (tangent) domain and push it forward"
  • jacrev means "form the standard basis of the output (cotangent) domain and pull it back"

It is consistent then for jacfwd to match the output dtype and for jacrev to match the input dtype. If I understand correctly, this change would reverse that convention.
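
The two readings above can be sketched with public JAX primitives. This is an illustration for a single array input with one dtype, not the library's implementation; the function `f` and the point `x` are made up.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function from R^2 to R^3.
    return jnp.stack([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])

# Forward mode: push each standard basis tangent of the input through f.
def pushforward(tangent):
    _, out_tangent = jax.jvp(f, (x,), (tangent,))
    return out_tangent

in_basis = jnp.eye(x.size, dtype=x.dtype)
# Each pushed-forward basis vector becomes one column of the Jacobian.
jac_fwd = jax.vmap(pushforward, out_axes=-1)(in_basis)

# Reverse mode: pull each standard basis cotangent of the output back.
y, pullback = jax.vjp(f, x)
out_basis = jnp.eye(y.size, dtype=y.dtype)
# Each pulled-back basis vector becomes one row of the Jacobian.
(jac_rev,) = jax.vmap(pullback)(out_basis)
```

For a homogeneous dtype both constructions give the same 3x2 matrix; the discussion in this thread is about which dtype each leaf should carry once inputs and outputs differ.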

(What might have been a confusing choice of ours was exposing jacobian as an alias of jacrev in particular. See for instance #6638 (comment).)

  • Changed the dtypes of the pytree elements of the primals and tangents
    produced by jacfwd and jacrev to be the same as the corresponding
    elements in the input. This was not the case when the input pytree's
    elements had heterogeneous dtypes.

This seems to tackle the original issue inarguably, and you've made a very clear case for it more generally too. Thanks!

  • Changed the holomorphic flag of jacfwd and jacrev to no longer force all
    components of the input and output pytrees to be complex (some or all
    components can be real, for example).

Similar to the first item, is this necessary for tackling the original issue? If possible, we'd like to consider a function holomorphic only if it is a complex-valued function of complex inputs only.
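
As a small illustration of that stricter reading, the sketch below calls `jax.jacfwd` with `holomorphic=True` on a function whose input and output are both complex. The function `square` and the sample point are made up.

```python
import jax
import jax.numpy as jnp

def square(z):
    # Complex-valued function of a complex input.
    return z ** 2

z = jnp.array([1.0 + 1.0j, 2.0 - 0.5j], dtype=jnp.complex64)

# With holomorphic=True the Jacobian holds complex derivatives; for an
# elementwise square that is a diagonal matrix with entries 2 * z.
jac = jax.jacfwd(square, holomorphic=True)(z)
```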

@NeilGirdhar
Contributor Author

NeilGirdhar commented Jul 28, 2021

Thanks so much for this. Sorry for being slow; the slowness was largely my fault.

No problem at all! Thanks for taking a look.

What motivates this part? Is this required in order to address the original issue (#7157), or is it a proposed change along the way?

I think I didn't do a good enough job of explaining this bullet point. Yes, this is part of the original fix. This bullet point ensures that what comes out of the Jacobian calculation has the right dtype.

The way that both Jacobian functions were already implemented involved flattening the inputs and outputs and calculating a Jacobian array. This array must have the "result type" (widest type) of all the inputs and outputs together. Then, that Jacobian array is "unraveled" (see _jacfwd_unravel and _jacrev_unravel) into a pytree of pytrees.

So, for example, if your input to f is a dataclass C with a real-valued field a and a complex-valued field b, then the Jacobian of f produces a structure like: C(a=C(a=aa, b=ab), b=C(a=ba, b=bb)) where all of the fields are now complex.

What this change does is make aa real-valued since it is the derivative of the component a of the output of f with respect to the a component of the input. The derivative of a real number with respect to a real number should be real.
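
The nested structure is easier to see on a concrete case. The sketch below uses a plain dict in place of the dataclass C and keeps every leaf real-valued, so it shows only the pytree-of-pytrees layout and says nothing about the dtype question. The function `f` is made up.

```python
import jax
import jax.numpy as jnp

def f(c):
    # Hypothetical function whose output has the same keys as its input.
    return {"a": c["a"] * jnp.sum(c["b"]), "b": c["b"] * c["a"]}

c = {"a": jnp.array(2.0), "b": jnp.array([1.0, 3.0])}

# jac[out_key][in_key] is the block of partial derivatives of output
# leaf out_key with respect to input leaf in_key.
jac = jax.jacfwd(f)(c)

# Each block has shape out_leaf.shape + in_leaf.shape.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, jac)
```

Here `jac["a"]["a"]` plays the role of aa in the comment above, `jac["a"]["b"]` the role of ab, and so on.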

(What might have been a confusing choice of ours was exposing jacobian as an alias of jacrev in particular. See for instance #6638 (comment).)

Sure. I always write jacfwd or jacrev in my own code.

Excellent comment by the way. I agree with you. (I did try to make the errors more specific in this pull request, but I left the documentation alone.)

This seems to tackle the original issue inarguably, and you've made a very clear case for it more generally too. Thanks!

Great!

Similar to the first item, is this necessary for tackling the original issue? If possible, we'd like to consider a function holomorphic only if it is a complex-valued function of complex inputs only.

Yes, sounds good. (Edited my previous reply.)

@NeilGirdhar NeilGirdhar force-pushed the fix_basis branch 2 times, most recently from efedcf8 to 88778e7 Compare August 3, 2021 03:23
@NeilGirdhar NeilGirdhar force-pushed the fix_basis branch 2 times, most recently from 45adea9 to 6a6b8f1 Compare August 3, 2021 14:58
@NeilGirdhar NeilGirdhar changed the title Fix holomorphic jacfwd for pytrees having both real and complex elements Fix jacfwd, jacrev, and grad for heterogeneous pytrees Aug 3, 2021
@NeilGirdhar NeilGirdhar force-pushed the fix_basis branch 2 times, most recently from fe74c0c to afe147f Compare August 7, 2021 17:48
@NeilGirdhar NeilGirdhar force-pushed the fix_basis branch 2 times, most recently from 41793f1 to 95fd985 Compare September 2, 2021 08:15
Member

@froystig froystig left a comment

Thanks for keeping this rebased and for your patience here @NeilGirdhar!

Having paged this back in now, the only part I'm unclear on still is the introduced possible-downcasts, namely how this PR casts output dtypes back down to match input dtypes in jacfwd (and vice versa). Why do we need that for the original fix, and why is it correct?

Again, going back to first principles, jacfwd is the push-forward (of basis vectors for the input tangent space). I understand we want to allow for heterogeneity—that's great and I understand that part of the change—but why should the output cotangent type correspond to anything other than the output type?

The derivative of a real number with respect to a real number should be real.

This might be related to the point of confusion? There is no notion of differentiating one number with respect to another number in the push-forward terminology. There is only the output cotangent type, which is the type of cotangent values resulting from pushing forward the input tangent values (whatever their type is).

If this is a continued misunderstanding of terms then thanks for bearing with me here!
@froystig froystig linked an issue Sep 16, 2021 that may be closed by this pull request
@NeilGirdhar
Contributor Author

NeilGirdhar commented Sep 17, 2021

Having paged this back in now, the only part I'm unclear on still is the introduced possible-downcasts, namely how this PR casts output dtypes back down to match input dtypes in jacfwd (and vice versa). Why do we need that for the original fix, and why is it correct?

I guess one way to answer this is to work backwards from the test. Do you agree with the desired output for test_heterogeneous_jacfwd and test_heterogeneous_jacrev? In particular that the first component should have type float16? I discuss this further below…

Again, going back to first principles, jacfwd is the push-forward (of basis vectors for the input tangent space). I understand we want to allow for heterogeneity—that's great and I understand that part of the change—but why should the output cotangent type correspond to anything other than the output type?

The derivative of a real number with respect to a real number should be real.

This might be related to the point of confusion? There is no notion of differentiating one number with respect to another number in the push-forward terminology. There is only the output cotangent type, which is the type of cotangent values resulting from pushing forward the input tangent values (whatever their type is).

I think you're right that this is where we disagree. The Jacobian is by definition the matrix of partial derivatives of the output with respect to the input. What you're describing is a method to calculate the derivative. I think the definition of the Jacobian defines the semantics whereas the calculation method and its peculiarities shouldn't define the semantics.

The reason that the current Jacobian methods are producing output types of float32 in the mentioned tests is a peculiarity of the current implementation. What is happening is that _std_basis uses an ingenious call to np.eye (which explicitly chooses the result type of all the inputs) and then unravels the resulting matrix into the pytree of pytrees that compose the input primals/tangents or output cotangents. This has the peculiarity that all of the leaves of that pytree of pytrees have the same result type.

Instead of changing the use of np.eye, which I thought was quite clever, I thought it would be easier to make the unraveling downcast the types back to the result type of each input/output pair (rather than the result type of all the inputs).

As for why I think it should be that: since the Jacobian is by definition the derivative of the output with respect to the input, it could have equivalently been implemented using a loop for each pair of inputs. In that case, the types would be the result types of each input and output pair.
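
The eye-then-unravel approach can be sketched with NumPy alone. `unravel_basis` is a hypothetical helper, not the code in this PR, and it shows only the simplest variant of the downcast: each block of the basis is cast back to the dtype of the leaf it belongs to.

```python
import numpy as np

def unravel_basis(leaves):
    """Hypothetical helper: split a unified identity basis per leaf."""
    sizes = [leaf.size for leaf in leaves]
    # The flat basis is built in the widest dtype across all leaves.
    unified = np.result_type(*[leaf.dtype for leaf in leaves])
    flat = np.eye(sum(sizes), dtype=unified)
    # Split the columns so that each block lines up with one leaf.
    blocks = np.split(flat, np.cumsum(sizes)[:-1], axis=1)
    # Reshape each block to (basis index, *leaf.shape) and cast it back
    # to the leaf's own dtype instead of keeping the unified one.
    return [
        block.reshape((flat.shape[0],) + leaf.shape).astype(leaf.dtype)
        for block, leaf in zip(blocks, leaves)
    ]

leaves = [np.zeros(1, dtype=np.float16), np.zeros(2, dtype=np.float32)]
basis = unravel_basis(leaves)
```

Without the final `astype`, both blocks would stay float32, which is the uniform-dtype behavior described above.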

It may be helpful to add some of this explanation as comments in the code if you think that would be helpful to future readers.

@froystig
Member

froystig commented Oct 2, 2021

I guess one way to answer this is to work backwards from the test. Do you agree with the desired output for test_heterogeneous_jacfwd and test_heterogeneous_jacrev? In particular that the first component should have type float16?

Good way to put it by asking about the test. The test assertion isn't consistent with JAX's intended definition for jacfwd and jacrev. This isn't because of the heterogeneity changes, but because of the unification with the opposite input/output type.

Take the current test_heterogeneous_jac{fwd,rev}. The function f specialized to its argument has type:

(f16, f32) -> (f16, f32, f32)

and this is also the type of its Jacobian map.

Our intended meaning for jacfwd is the image of the input tangent basis through this map (cf. pushforward). The input tangent space has two standard basis vectors, so the output of jacfwd should be two columns of type (f16, f32, f32), i.e. the output tangent type.

Our intended meaning for jacrev is the pullback of the output tangent basis, meaning the transpose of its image through the transposed Jacobian map. The output tangent space has three standard basis vectors, so the output of jacrev should be three rows of type (f16, f32), i.e. the input tangent type.

Note that one ends up with matrices of different types in either case. We could have elicited that with a simpler example too: if you consider a function g :: f16 -> f32, then the result of jacfwd is f32 (the output tangent type) and the result of jacrev is f16 (the input tangent type).
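
A function with the type discussed above can be written down directly, and its leaf dtypes inspected. The sketch assumes `f` resembles the function in test_heterogeneous_jac{fwd,rev}; the exact test body is not shown in this thread. No expected dtypes are hard-coded, since they depend on the installed JAX version. Under the convention in this comment, one would expect jacfwd leaves to follow the output dtypes and jacrev leaves to follow the input dtypes.

```python
import jax
import jax.numpy as jnp

def f(args):
    # Type: (f16, f32) -> (f16, f32, f32); f16 + f32 promotes to f32.
    x, y = args
    return x, y, x + y

x = jnp.array([2.0], dtype=jnp.float16)
y = jnp.array([3.0], dtype=jnp.float32)

jac_fwd = jax.jacfwd(f)((x, y))
jac_rev = jax.jacrev(f)((x, y))

# Both results are a 3-tuple (one entry per output) of 2-tuples (one
# entry per input); only the leaf dtypes can differ between the two.
fwd_dtypes = jax.tree_util.tree_map(lambda leaf: leaf.dtype, jac_fwd)
rev_dtypes = jax.tree_util.tree_map(lambda leaf: leaf.dtype, jac_rev)
print("jacfwd leaf dtypes:", fwd_dtypes)
print("jacrev leaf dtypes:", rev_dtypes)
```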

I understand that you'd like for the meaning of jacfwd and jacrev to change to one where input/output types are unified. That would be a significant change and it's not clearly one that we want to make. We like our definitions of jacfwd and jacrev for several reasons. A relevant one here is that they're well-defined even when input/output types are arbitrarily different and may not have a sensible unification.

Meanwhile the heterogeneity support here is great. That's why I think we should make this PR just about that! If nothing else, we can always discuss whether to introduce some kind of unified_jacfwd and unified_jacrev separately. Am I correct that these changes are orthogonal?

The Jacobian is by definition the matrix of partial derivatives of the output with respect to the input. What you're describing is a method to calculate the derivative. I think the definition of the Jacobian defines the semantics whereas the calculation method and its peculiarities shouldn't define the semantics.

The jacfwd and jacrev functions are not methods that calculate the same thing (they don't!). For a function f :: a -> b and a linearization point x :: a, the Jacobian map of f at x – call it jac(f, x) – is a linear map from tangents of a to tangents of b. Being a map from one space to another, it is not an element (or subset) of either space. The functions jacfwd and jacrev are what produce (different sets of) elements from either space. The semantics I'm describing determine what they compute (not how they must compute it), and imply that those objects can differ.
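
The distinction between the Jacobian map and the elements it produces can be sketched with `jax.linearize` and `jax.vjp`. The function `f`, the point `x`, and the sample tangent and cotangent are made up for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical f :: a -> b with a = R^2 and b = R^3.
    return jnp.stack([x[0] * x[1], x[0] + x[1], x[1] ** 2])

x = jnp.array([1.0, 2.0])

# jac_map is the linear map jac(f, x): tangents of a -> tangents of b.
# It is a function, not an array living in either space.
y, jac_map = jax.linearize(f, x)
tangent = jnp.array([0.5, -1.0])
out_tangent = jac_map(tangent)

# The transposed map goes the other way: cotangents of b -> cotangents of a.
_, pullback = jax.vjp(f, x)
cotangent = jnp.array([1.0, 0.0, 0.0])
(in_cotangent,) = pullback(cotangent)
```

Applying `jac_map` to the input basis gives elements of the output tangent space, while applying `pullback` to the output basis gives elements of the input cotangent space, which is why the two results need not share a type.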

It may be helpful to add some of this explanation as comments in the code if you think that would be helpful to future readers.

It's possible that we need to better explain the current jacrev and jacfwd, since their names alone are clearly not enough. That goes along with #6638 remaining open for documentation.

@NeilGirdhar NeilGirdhar force-pushed the fix_basis branch 4 times, most recently from 8c4e121 to 056b04c Compare October 2, 2021 22:54
@NeilGirdhar
Contributor Author

I understand that you'd like for the meaning of jacfwd and jacrev to change to one where input/output types are unified.

Using "unified types" is actually what JAX is unfortunately currently doing. It seems like we were both on the same page about changing it!

Our intended meaning for jacfwd is the image of the input tangent basis through this map (cf. pushforward).
Our intended meaning for jacrev is the pullback of the output tangent basis, meaning the transpose of its image through the transposed Jacobian map.

Okay, you've convinced me. Thanks for the patient explanation. I modified the pull request to incorporate your suggestion. Would you mind looking over the tests to make sure I've got it right?

@froystig
Member

Thanks @NeilGirdhar!

Using "unified types" is actually what Jax is unfortunately currently doing. It seems like we were both on the same page about changing it!

Indeed, you're right about that. The need to change the current implementation is why we're here.

Okay, you've convinced me. Thanks for the patient explanation. I modified the pull request to incorporate your suggestion. Would you mind looking over the tests to make sure I've got it right?

The change looks good! Thank you both for the discussion and for the contribution! Let me give it an official approval...

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Oct 13, 2021
@copybara-service copybara-service bot merged commit 4267bed into jax-ml:main Oct 13, 2021
@NeilGirdhar NeilGirdhar deleted the fix_basis branch October 13, 2021 10:34
romanngg added a commit to google/neural-tangents that referenced this pull request Oct 13, 2021