Inconsistency in documented/expected shapes for estimators #1500
Regarding the inconsistency in shape conventions between the docstrings and what the code expects, I wonder whether it would be useful to use jaxtyping to provide shapes in type hints. These can then be rendered automatically in the documentation, so the docstrings themselves don't need to include shape information. Despite its name, jaxtyping has no JAX dependency and also supports PyTorch; GPyTorch, for example, uses it in a few places, e.g.:
While static type checkers can't check that the shape conventions are followed, runtime type checkers like beartype can do so very quickly.
If there's interest in the jaxtyping approach, I'd be happy to hack together a quick proof-of-concept.
I made a quick attempt to prototype jaxtyping+beartype in SBI, and I don't think it's a good fit for the following reasons:
It might still be worth using jaxtyping for methods that don't require multiple variadic shapes. GPyTorch takes this piecemeal approach, using jaxtyping for only a few operators. But it's not a silver bullet.
🐛 Bug Description
`ConditionalDensityEstimator` and `RatioEstimator` on `sbi` `main` use different names, orders, and shape expectations. Also, `RatioEstimator`'s documented shape expectations are incompatible with what it actually expects.
Details
While `ConditionalDensityEstimator` (`CDE`) and `RatioEstimator` (`RE`) do not share a common parent type, ideally they would still be as consistent as possible. I assume here that `x` in `RE` corresponds roughly to `input` in `CDE`, and `theta` to `condition`.
Inconsistent attributes
`CDE` uses attributes `(input_shape, condition_shape)`: sbi/sbi/neural_nets/estimators/base.py (line 135 in a1d7555)
`RE` uses attributes `(theta_shape, x_shape)`: sbi/sbi/neural_nets/ratio_estimators.py (lines 29 to 30 in a1d7555)
Their order in the constructors is reversed.
Inconsistent shapes
`CDE` documents that the shape of `input` is `(sample_dim, batch_dim, *input_shape)` and the shape of `condition` is `(batch_dim, *condition_shape)`. While `CDE` itself doesn't enforce this, at least some subclasses do, e.g. `MixedDensityEstimator`: sbi/sbi/neural_nets/estimators/mixed_density_estimator.py (lines 74 and 141 in a1d7555)
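To make the documented `CDE` convention concrete, here is a small pure-Python sketch of a check that accepts it. The helper `check_cde_shapes` is hypothetical and is not taken from sbi; it works on plain tuples, so it also accepts `torch.Size` objects.

```python
from typing import Sequence, Tuple


def check_cde_shapes(
    input_full: Sequence[int],
    condition_full: Sequence[int],
    input_shape: Sequence[int],
    condition_shape: Sequence[int],
) -> Tuple[int, int]:
    """Hypothetical check of the documented CDE convention.

    input:     (sample_dim, batch_dim, *input_shape)
    condition: (batch_dim, *condition_shape)
    Returns (sample_dim, batch_dim) or raises ValueError.
    """
    input_full, condition_full = tuple(input_full), tuple(condition_full)
    input_shape, condition_shape = tuple(input_shape), tuple(condition_shape)

    # `input` needs exactly two leading dims before its event shape.
    if len(input_full) != 2 + len(input_shape) or input_full[2:] != input_shape:
        raise ValueError(f"input {input_full} is not (sample_dim, batch_dim, *{input_shape})")
    # `condition` needs exactly one leading dim before its event shape.
    if (
        len(condition_full) != 1 + len(condition_shape)
        or condition_full[1:] != condition_shape
    ):
        raise ValueError(f"condition {condition_full} is not (batch_dim, *{condition_shape})")

    sample_dim, batch_dim = input_full[:2]
    if condition_full[0] != batch_dim:
        raise ValueError(f"batch_dim mismatch: {batch_dim} vs {condition_full[0]}")
    return sample_dim, batch_dim


print(check_cde_shapes((10, 4, 3), (4, 5), input_shape=(3,), condition_shape=(5,)))
# (10, 4)
```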
`RE` documents that the shape of `x` is `(batch_dim, *x_shape)` and that the shape of `theta` is `(sample_dim, batch_dim, *theta_shape)`. Note that the two classes differ in which of the arguments has a `sample_dim`. However, `RE` actually enforces that `x` is `(*batch_shape, *x_shape)` and `theta` is `(*batch_shape, *theta_shape)`, i.e. the two arguments share the same prefix, which is incompatible with the documented shapes: sbi/sbi/neural_nets/ratio_estimators.py (lines 126 to 128 in a1d7555)
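The incompatibility can be reproduced with a pure-Python sketch of the shared-prefix rule described above. Both helpers, `split_batch_shape` and `check_re_shapes`, are hypothetical illustrations of the stated convention, not sbi's actual implementation. Shapes built according to `RE`'s docstring fail the check, because `theta` carries an extra leading `sample_dim`.

```python
from typing import Sequence, Tuple


def split_batch_shape(full_shape: Sequence[int], event_shape: Sequence[int]) -> Tuple[int, ...]:
    """Strip a trailing event shape and return the leading batch shape."""
    full_shape, event_shape = tuple(full_shape), tuple(event_shape)
    n_batch = len(full_shape) - len(event_shape)
    # Slice with n_batch rather than -len(event_shape), which breaks for ().
    if n_batch < 0 or full_shape[n_batch:] != event_shape:
        raise ValueError(f"{full_shape} does not end with event shape {event_shape}")
    return full_shape[:n_batch]


def check_re_shapes(
    theta_full: Sequence[int],
    x_full: Sequence[int],
    theta_shape: Sequence[int],
    x_shape: Sequence[int],
) -> Tuple[int, ...]:
    """Hypothetical check of the rule RE enforces:
    theta: (*batch_shape, *theta_shape), x: (*batch_shape, *x_shape)."""
    theta_batch = split_batch_shape(theta_full, theta_shape)
    x_batch = split_batch_shape(x_full, x_shape)
    if theta_batch != x_batch:
        raise ValueError(f"batch shapes differ: {theta_batch} vs {x_batch}")
    return theta_batch


theta_shape, x_shape = (2,), (3,)

# Shared prefix: accepted, whatever the number of batch dims.
print(check_re_shapes((10, 4, 2), (10, 4, 3), theta_shape, x_shape))  # (10, 4)

# Documented shapes: theta (sample_dim, batch_dim, *theta_shape), x (batch_dim, *x_shape).
try:
    check_re_shapes((10, 4, 2), (4, 3), theta_shape, x_shape)
except ValueError as err:
    print(err)  # batch shapes differ: (10, 4) vs (4,)
```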
Inconsistent argument order in methods
While `CDE.log_prob` and `RE.unnormalized_log_ratio` are not equivalent, one would expect their order of arguments to be similar. However, the former takes the order `(input, condition)`: sbi/sbi/neural_nets/estimators/base.py (line 155 in a1d7555)
while the latter takes `(theta, x)`: sbi/sbi/neural_nets/ratio_estimators.py (line 156 in a1d7555)
📌 Additional Context
Torch distributions (and Pyro) implement both `sample` and `log_prob`, supporting arbitrary `batch_shape` and `sample_shape` (not just a single dimension). While not necessary, it would be nice if these methods supported the same shape conventions. In particular, this would simplify #1491.
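For reference, a short PyTorch example of that convention (assuming `torch` is installed): `sample(sample_shape)` returns a tensor of shape `sample_shape + batch_shape + event_shape`, and `log_prob` drops the trailing `event_shape`. Here `Independent` reinterprets the last batch dimension of a `Normal` as an event dimension.

```python
import torch
from torch.distributions import Independent, Normal

# batch_shape=(3,), event_shape=(2,): three independent 2-D Gaussians.
dist = Independent(Normal(torch.zeros(3, 2), torch.ones(3, 2)), 1)

samples = dist.sample((5, 4))  # sample_shape=(5, 4), an arbitrary number of dims
log_probs = dist.log_prob(samples)

print(dist.batch_shape, dist.event_shape)  # torch.Size([3]) torch.Size([2])
print(samples.shape)  # torch.Size([5, 4, 3, 2])
print(log_probs.shape)  # torch.Size([5, 4, 3])
```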