Plating Fixes After New Test Patterns #240
There were previously issues with batched i.c.s due to NumPyro distribution shape semantics. This commit changes the corresponding mechanisms for i.c.s, so that distributions are sliced properly via a tree flatten, broadcast, and tree reassembly.
Pull request overview
This PR addresses plating issues uncovered by new hierarchical continuous-time model tests, especially batched initial conditions, closure-captured plated parameters, and discretized drift-bias handling.
Changes:
- Refactors plate distribution slicing into shared inference utilities and uses it in filters/smoothers for batched initial conditions.
- Updates vector-field whitelist handling for `Discretizer`’s `cte` wrapper.
- Adds/removes regression tests and documents sharp edges around closure-captured plated parameters.
Reviewed changes
Copilot reviewed 9 out of 10 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `dynestyx/inference/plate_utils.py` | Centralizes plate-aware array/distribution slicing and treats distributions as opaque for vmap axes. |
| `dynestyx/inference/filters.py` | Rebuilds batched initial conditions per plate member during batched filtering. |
| `dynestyx/inference/smoothers.py` | Mirrors filter batched initial-condition handling for smoothing. |
| `dynestyx/simulators.py` | Reuses shared plate slicing utilities instead of local implementations. |
| `dynestyx/utils.py` | Allows whitelisted vector-field paths to match through `Discretizer`’s `cte` wrapper. |
| `dynestyx/handlers.py` | Documents closure-capture pitfalls for plated parameters. |
| `tests/test_plate_vector_initial_mean_continuous_filters.py` | Adds discretized drift-bias regression and un-xfails filter cases. |
| `tests/test_plate_vector_initial_mean_continuous_drift_diffusion.py` | Converts closure drift to an `eqx.Module` field and un-xfails related cases. |
| `tests/test_desirable_shared_initial_mean_hierarchical.py` | Fixes batched `A` construction and un-xfails hierarchical initial-mean cases. |
mattlevine22
left a comment
this looks great to me, thanks Dan!
- all xfailed tests now pass
- I didn't notice any slowdowns / performance hits when running the hierarchical notebook
- The improved hierarchical tutorial does a nice job of showing how (and why) to make an equinox class drift.
One request---in the sharp-edges section of the tutorial, I think (2) to_event(1) would warrant a more explicit DO and DON'T (I like that 1 does this). I think (3) covers it with "do MVN not Normal", but if you think there is a clear DO and DON'T, I suggest adding that as well.
I also asked copilot to review---maybe it will find a bug or something (though you probably already did some LLM checks).
mattlevine22
left a comment
Looks great, thanks!!
#237 includes a few new tests with desired use patterns for plating. This PR identifies and fixes a few different problems to close the gap:
Initial condition handling
Initial conditions were a pain point that kept causing problems. This is because of various oddities in how NumPyro distributions work. First, the `batch_shape` is a static property of the pytree, so slicing the initial condition doesn't update the batch shape accordingly. Second, `MultivariateNormal` often stores a leading singleton dimension for no apparent reason, which broadcasts shapes unintentionally.

To address this, we now branch on whether an initial condition is batched or not. If it is, we flatten the corresponding pytree, broadcast all vectors to the full plate-padded shape, and vmap afterwards.
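The flatten, broadcast, and vmap steps can be sketched roughly as below. This is only an illustration under stated assumptions, not the code in `plate_utils.py`: the initial condition is stood in for by a plain dict of arrays rather than a NumPyro distribution, and `broadcast_to_plate`, `per_member`, `M`, and `d` are hypothetical names.

```python
import jax
import jax.numpy as jnp

M, d = 4, 2  # hypothetical plate size and state dimension

# Stand-in for a batched initial condition (assumption: plain arrays,
# not a NumPyro distribution). "scale" carries a leading singleton
# dimension, mimicking the stray singleton described above.
ic_params = {
    "loc": jnp.arange(M * d, dtype=jnp.float32).reshape(M, d),
    "scale": jnp.ones((1, d)),
}


def broadcast_to_plate(params, plate_size):
    """Flatten the pytree, broadcast each leaf to (plate_size, d), reassemble."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    # Assumption: every leaf has exactly one trailing event dimension.
    full = [
        jnp.broadcast_to(leaf, (plate_size,) + leaf.shape[-1:]) for leaf in leaves
    ]
    return jax.tree_util.tree_unflatten(treedef, full)


def per_member(params):
    # Inside vmap each leaf has shape (d,): one plate member at a time.
    return params["loc"] + params["scale"]


plated = broadcast_to_plate(ic_params, M)
out = jax.vmap(per_member)(plated)  # maps over the leading plate axis
```

Broadcasting every leaf to the same leading size before `vmap` means no leaf relies on implicit singleton broadcasting inside the mapped function.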
Parameters under closure
An antipattern that was causing issues in tests was a plated parameter appearing under a Python closure. For example, a `drift` function that closes over a plated `alpha` uses the original 10-dimensional `alpha`, and is therefore not sliced properly in vmaps later on.

The correct pattern is to define such things as equinox modules, with properly-defined parameters. This fixes the failing tests in `tests/test_plate_vector_initial_mean_continuous_drift_diffusion.py`.
Discretizer with bias
Another issue was that bias variables caused issues under `Discretizer`. This is because `bias` functions as a "whitelisted" variable under plating, such that it is properly understood to be `state_dim` and not a plating dimension. But under the discretizer, the original `bias` term is stored in a corresponding `cte` object. These are now included in the whitelist so they're properly vmapped over.
Other Changes
I've also updated the tutorial notebook with the discretizer, and a "sharp bits" section explaining some choices as above. I also added some additional tests, and fixed the shape error in one of the xfailed tests.
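For readers who want to see the closure pitfall from "Parameters under closure" in isolation, here is a minimal sketch in plain JAX. It is not the test code from this PR: `drift_closure`, `Drift`, and `x0` are hypothetical names, and a `NamedTuple` is used as a stand-in for an `eqx.Module` because both are pytrees whose array fields `vmap` can slice.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

alpha = jnp.arange(10.0)  # one value per plate member (plate size 10)
x0 = jnp.ones(10)         # one scalar state per plate member


# Antipattern: alpha is captured by the closure, so vmap never sees it
# and every plate member receives the full 10-dimensional array.
def drift_closure(x):
    return -alpha * x


bad = jax.vmap(drift_closure)(x0)  # shape (10, 10), not (10,)


# Stand-in for an eqx.Module (assumption): alpha is a pytree leaf,
# so vmap slices it along the plate axis together with the state.
class Drift(NamedTuple):
    alpha: jax.Array

    def __call__(self, x):
        return -self.alpha * x


good = jax.vmap(lambda f, x: f(x))(Drift(alpha), x0)  # shape (10,)
```

The only difference between the two versions is whether `alpha` is visible to `vmap` as an input leaf, which is the reason the drift needs to be a module with declared parameters.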