Skip to content

Plating Fixes After New Test Patterns#240

Merged
mattlevine22 merged 10 commits into
mainfrom
dw-update-dists-under-plates
May 29, 2026
Merged

Plating Fixes After New Test Patterns#240
mattlevine22 merged 10 commits into
mainfrom
dw-update-dists-under-plates

Conversation

@DanWaxman
Copy link
Copy Markdown
Collaborator

#237 includes a few new tests with desired use patterns for plating. This PR identifies and fixes a few different problems to close the gap:

Initial condition handling

Initial conditions were a pain point that kept causing problems. This is because of various oddities in how numpyro distributions work; for one, the batch_shape is a static property of the pytree, so slicing the initial condition doesn't update the batch shape accordingly. Second was that MultivariateNormal often stores a leading singleton dimension for no apparent reason, which broadcasts shapes unintentionally.

To address this, we now branch according to if an initial condition is batched or not. If it is, we flatten the corresponding pytree, broadcast all vectors to the full plate-padded shape, and vmap afterwards.

Parameters under closure

An antipattern that was causing issues in tests was that a plated parameter appeared under a Python closure. For example, something like

with dsx.plate("M", 10):
    alpha = numpyro.sample(...)
    def drift(x, u, t): foo(alpha)

This uses the original 10-dimensional alpha in the closure of drift, and is therefore not sliced properly in vmaps later on.

The correct pattern is to define such things as equinox modules, with properly-defined parameters. This fixes the failing tests in tests/test_plate_vector_initial_mean_continuous_drift_diffusion.py.

Discretizer with bias

Another issue was that bias variables caused issues under Discretizer. This is because bias functions as a "whitelisted" variable under plating, such that it is properly understood to be state_dim and not a plating dimension. But under discretizer, the original bias term is stored in a corresponding cte object. These are now included in the whitelist so they're properly vmapped over.

Other Changes

I've also updated the tutorial notebook with the discretizer, and a "sharp bits" section explaining some choices as above. I also added some additional tests, and fixed the shape error in one of the xfailed tests.

DanWaxman added 7 commits May 28, 2026 13:29
There were previously issues with batched i.c.s due to NumPyro distribution shape sematics. This commit changes the corresponding mechanisms for i.c.s, os that distributions are sliced properly according to a tree unflatten, broadcast, and tree reassembly.
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses plating issues uncovered by new hierarchical continuous-time model tests, especially batched initial conditions, closure-captured plated parameters, and discretized drift-bias handling.

Changes:

  • Refactors plate distribution slicing into shared inference utilities and uses it in filters/smoothers for batched initial conditions.
  • Updates vector-field whitelist handling for Discretizer’s cte wrapper.
  • Adds/removes regression tests and documents sharp edges around closure-captured plated parameters.

Reviewed changes

Copilot reviewed 9 out of 10 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
dynestyx/inference/plate_utils.py Centralizes plate-aware array/distribution slicing and treats distributions as opaque for vmap axes.
dynestyx/inference/filters.py Rebuilds batched initial conditions per plate member during batched filtering.
dynestyx/inference/smoothers.py Mirrors filter batched initial-condition handling for smoothing.
dynestyx/simulators.py Reuses shared plate slicing utilities instead of local implementations.
dynestyx/utils.py Allows whitelisted vector-field paths to match through Discretizer’s cte wrapper.
dynestyx/handlers.py Documents closure-capture pitfalls for plated parameters.
tests/test_plate_vector_initial_mean_continuous_filters.py Adds discretized drift-bias regression and un-xfails filter cases.
tests/test_plate_vector_initial_mean_continuous_drift_diffusion.py Converts closure drift to an eqx.Module field and un-xfails related cases.
tests/test_desirable_shared_initial_mean_hierarchical.py Fixes batched A construction and un-xfails hierarchical initial-mean cases.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread dynestyx/inference/smoothers.py
Copy link
Copy Markdown
Collaborator

@mattlevine22 mattlevine22 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks great to me, thanks Dan!

  • all xfailed tests now pass
  • I didn't notice any slowdowns / performance hits when running the hierarchical notebook
  • The improved hierarchical tutorial does a nice job of showing how (and why) to make an equinox class drift.

One request---in the sharp-edges section of the tutorial, I think (2) to_event(1) would warrant a more explicit DO and DON'T (I like that 1 does this). I think (3) covers it with "do MVN not Normal", but if you think there is a clear DO and DON'T, I suggest adding that as well.

I also asked copilot to review---maybe it will find a bug or something (though you probably already did some LLM checks).

@DanWaxman DanWaxman requested a review from mattlevine22 May 29, 2026 14:02
Copy link
Copy Markdown
Collaborator

@mattlevine22 mattlevine22 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!!

@mattlevine22 mattlevine22 merged commit 6c3b87e into main May 29, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants