Inconsistency in documented/expected shapes for estimators #1500

Open
sethaxen opened this issue Mar 20, 2025 · 3 comments
Labels
bug Something isn't working

Comments

@sethaxen (Contributor)

🐛 Bug Description

ConditionalDensityEstimator and RatioEstimator on sbi main use different argument names, argument orders, and shape expectations. Also, RatioEstimator's documented shape expectations are incompatible with what it actually enforces.

Details

While ConditionalDensityEstimator (CDE) and RatioEstimator (RE) do not share a common parent type, ideally they would still be as consistent as possible. I assume here that x in RE and input in CDE are roughly the same, and that theta in RE and condition in CDE are roughly the same, respectively.

Inconsistent attributes

CDE uses attributes (input_shape, condition_shape):

self, net: nn.Module, input_shape: torch.Size, condition_shape: torch.Size

RE uses attributes (theta_shape, x_shape):

theta_shape: torch.Size | tuple[int, ...],
x_shape: torch.Size | tuple[int, ...],

Their order in the constructors is reversed.

Inconsistent shapes

CDE documents that the shape of the input is (sample_dim, batch_dim, *input_shape) and the shape of condition is (batch_dim, *condition_shape). While it doesn't enforce this, at least some children do, e.g. MixedDensityEstimator:

batch_dim = condition.shape[0]

input_sample_dim, input_batch_dim = input.shape[:2]
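
For concreteness, here's a minimal sketch (not sbi code) of tensors following the documented CDE convention; the shapes are made up for illustration:

```python
import torch

# Hypothetical event shapes, purely for illustration
input_shape, condition_shape = torch.Size([2]), torch.Size([3])
sample_dim, batch_dim = 10, 5

# Documented CDE convention:
#   input:     (sample_dim, batch_dim, *input_shape)
#   condition: (batch_dim, *condition_shape)
inputs = torch.randn(sample_dim, batch_dim, *input_shape)  # (10, 5, 2)
condition = torch.randn(batch_dim, *condition_shape)       # (5, 3)
```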

RE documents that the shape of x is (batch_dim, *x_shape) and that the shape of theta is (sample_dim, batch_dim, *theta_shape). Note that the two classes differ in which of the arguments has a sample_dim. However, RE actually enforces that x is (*batch_shape, *x_shape) and theta is (*batch_shape, *theta_shape), i.e. the two arguments share the same prefix, which is incompatible with the documented shapes:

theta_prefix = theta.shape[: -len(self.theta_shape)]
x_prefix = x.shape[: -len(self.x_shape)]
if theta_prefix != x_prefix:
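
A minimal sketch (again not sbi code) of why the documented RE shapes trip this check:

```python
import torch

# Hypothetical event shapes, purely for illustration
theta_shape, x_shape = torch.Size([2]), torch.Size([3])
sample_dim, batch_dim = 10, 5

# Documented shapes: theta carries a leading sample_dim, x does not
theta = torch.randn(sample_dim, batch_dim, *theta_shape)
x = torch.randn(batch_dim, *x_shape)

theta_prefix = theta.shape[: -len(theta_shape)]  # torch.Size([10, 5])
x_prefix = x.shape[: -len(x_shape)]              # torch.Size([5])
assert theta_prefix != x_prefix  # the enforced check would therefore raise
```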

Inconsistent argument order in methods

While CDE.log_prob and RE.unnormalized_log_ratio are not equivalent, one would expect their order of arguments to be similar. However, the former takes the order (input, condition):

def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:

while the latter takes (theta, x):

def unnormalized_log_ratio(self, theta: Tensor, x: Tensor) -> Tensor:

📌 Additional Context

Torch distributions (and Pyro) implement both sample and log_prob, supporting arbitrary batch_shape and sample_shape (not just a single dimension). While not necessary, it would be nice if these methods supported the same shape conventions. This would in particular simplify #1491.
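
For reference, a minimal sketch of the torch.distributions convention being referred to (an arbitrary sample_shape on top of a batch_shape):

```python
import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(5), torch.ones(5))  # batch_shape = (5,)
samples = dist.sample(sample_shape=(7, 3))    # shape (7, 3, 5)
log_probs = dist.log_prob(samples)            # shape (7, 3, 5)
```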

sethaxen added the bug label on Mar 20, 2025
@sethaxen (Contributor, Author)

Regarding the inconsistency in shape conventions between the docstrings and what the code expects, I wonder if it would be useful to use jaxtyping to provide shapes in type hints. These can then be rendered automatically in the documentation, and the docstrings themselves don't need to include shape information. Contrary to its name, jaxtyping has no JAX dependency and also supports torch.Tensor hints.

GPyTorch, for example, uses it in a few places, e.g.:
https://github.com/cornellius-gp/gpytorch/blob/b017b9c3fe4de526f7a2243ce12ce2305862c90b/gpytorch/variational/nearest_neighbor_variational_strategy.py#L177-L184

While static type checkers can't check that the shape conventions are followed, runtime type checkers like beartype can do so very quickly. sbi could either depend on one of these type checkers or simply run one in the test suite; e.g. for beartype, see:
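
A minimal sketch of what this could look like (not sbi's actual API; it assumes jaxtyping's `jaxtyped(typechecker=...)` wrapper, available in recent jaxtyping versions):

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor


@jaxtyped(typechecker=beartype)
def unnormalized_log_ratio(
    theta: Float[Tensor, "sample_dim batch_dim theta_dim"],
    x: Float[Tensor, "batch_dim x_dim"],
) -> Float[Tensor, "sample_dim batch_dim"]:
    # Dummy body; only the shape annotations matter for this sketch.
    return theta.sum(-1) + x.sum(-1)


unnormalized_log_ratio(torch.randn(7, 5, 2), torch.randn(5, 3))  # passes
# unnormalized_log_ratio(torch.randn(5, 2), torch.randn(5, 3))   # would raise at call time
```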

@sethaxen (Contributor, Author)

If there's interest in the jaxtyping approach, I'd be happy to hack together a quick proof-of-concept.

@sethaxen (Contributor, Author)

I made a quick attempt to prototype jaxtyping+beartype in SBI, and I don't think it's a good fit for the following reasons:

  • jaxtyping currently doesn't support multiple variadic shapes for a single tensor: e.g. batch_dim *x_shape is fine, but *batch_shape *x_shape is not (see the sketch at the end of this comment). Because string interpolation is allowed, one could, e.g. for RatioEstimator, interpolate self.x_shape into the annotation string, but I think one would actually need f"*batch_shape {' '.join(map(str, self.x_shape))}", which would just be ugly in the documentation.
  • beartype successfully caught type mismatches, but the errors raised were not very descriptive. It might make sense to have the runtime checking in the test suite, but I wouldn't rely on it to be the main shape checking for user-provided inputs. Here's an example error:
E   beartype.roar.BeartypeCallHintReturnViolation: Method sbi.neural_nets.ratio_estimators.RatioEstimator.combine_theta_and_x() return "tensor([[ 1.1846,  0.7002,  1.0900,  0.6601],
E           [-1.4753, -1.7527,  1.0900,  0.660...]])" violates type hint <class 'jaxtyping.Float[Tensor, 'sample_dim batch_dim combined_event_dim']'>, as this array has 2 dimensions, not the 3 expected by the type hint.

It might still be worth using jaxtyping for methods that don't require multiple variadic shapes; GPyTorch takes this piecemeal approach, using jaxtyping for a few operators. But it's not a silver bullet.
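
To illustrate the variadic limitation mentioned above (a sketch, not sbi code):

```python
from jaxtyping import Float
from torch import Tensor

# A single variadic axis per annotation is supported ...
BatchedX = Float[Tensor, "batch_dim *x_shape"]

# ... but two variadic axes in one annotation are not, so the
# (*batch_shape, *x_shape) convention that RatioEstimator actually
# enforces cannot be expressed directly:
# Float[Tensor, "*batch_shape *x_shape"]  # rejected by jaxtyping
```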
