The step size in `ConstantStepSize` is currently computed by dividing `t1-t0` evenly into `n_steps` steps (see #666). I understand the rationale and the advantages of this, but IMO the documentation is unclear on this point, and the results can be somewhat unexpected. Specifically:
- If `t1-t0` is not an integer multiple of `dt0`, the step size will not be equal to `dt0`, contrary to what the docs state: "use a constant step size, equal to the `dt0` argument of `diffrax.diffeqsolve`".
- When run with `dfx.SaveAt(t0=True, steps=True)`, the step from `t0` to the first solver step is not necessarily the same size as the rest of the steps. This might be unexpected for a class called `ConstantStepSize`. See the example below.
- The documentation for the `dt0` argument of `diffeqsolve` currently states: "dt0: The step size to use for the first step. If using fixed step sizes then this will also be the step size for all other steps. (Except the last one, which may be slightly smaller and clipped to t1.) If set as None then the initial step size will be determined automatically." This seems outdated: with the current implementation of `ConstantStepSize`, the step size is computed so that the last step lands exactly on `t1`, without any clipping.
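To make the first and third points concrete, here is a small pure-Python sketch contrasting two ways of building a time grid for the example's `t0 = 0.0`, `t1 = 1.05`, `dt0 = 0.1`: the behaviour described in the quoted `dt0` docstring (steps of `dt0`, last one clipped to `t1`) and an even split into `ceil((t1 - t0) / dt0)` equal steps. The helpers `clipped_grid` and `even_grid` are hypothetical, and the `ceil`-based even split is my assumption about the current behaviour, not Diffrax's actual code.

```python
import math


def clipped_grid(t0: float, t1: float, dt0: float) -> list[float]:
    """Grid described by the quoted dt0 docstring: steps of dt0, last one clipped to t1."""
    ts = [t0]
    while ts[-1] < t1:
        ts.append(min(ts[-1] + dt0, t1))  # never step past t1
    return ts


def even_grid(t0: float, t1: float, dt0: float) -> list[float]:
    """Hypothetical even split: ceil((t1 - t0) / dt0) equal steps ending exactly at t1."""
    # Assumption: the number of steps is rounded up so no step exceeds dt0.
    # Floating-point rounding in (t1 - t0) / dt0 can affect the ceil.
    n_steps = math.ceil((t1 - t0) / dt0)
    return [t0 + (t1 - t0) * (step / n_steps) for step in range(n_steps + 1)]


for name, grid in [("clipped", clipped_grid(0.0, 1.05, 0.1)),
                   ("even", even_grid(0.0, 1.05, 0.1))]:
    steps = [b - a for a, b in zip(grid, grid[1:])]
    print(name, len(steps), "steps:", [round(s, 4) for s in steps])
```

With these values both grids have eleven steps: the clipped grid takes ten steps of 0.1 followed by one of about 0.05, while the even split takes eleven steps of about 0.0955 (`1.05 / 11`), none of which equals `dt0`.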
Consider this simple example:
```python
import diffrax as dfx
import jax.numpy as jnp
import jax.random as jr

t0, t1, dt0 = 0.0, 1.05, 0.1  # t1 - t0 is not a multiple of dt0
key = jr.PRNGKey(0)


class TestModel:
    # Zero drift and zero diffusion: only the time grid matters here.
    @property
    def initial(self):
        return jnp.array(0.0)

    def drift(self, t, x, args):
        return jnp.array(0.0)

    def diffusion(self, t, x, args):
        return jnp.array(0.0)

    def terms(self, key):
        process_noise = dfx.UnsafeBrownianPath(
            shape=self.initial.shape, key=key, levy_area=dfx.SpaceTimeLevyArea
        )
        return dfx.MultiTerm(
            dfx.ODETerm(self.drift), dfx.ControlTerm(self.diffusion, process_noise)
        )


model = TestModel()
terms = model.terms(key)
sol = dfx.diffeqsolve(
    terms,
    dfx.Euler(),
    t0=t0,
    t1=t1,
    dt0=dt0,
    y0=model.initial,
    args={},
    saveat=dfx.SaveAt(t0=True, steps=True),
    adjoint=dfx.ForwardMode(),
    stepsize_controller=dfx.ConstantStepSize(),
)
# Unused save slots are padded with inf, so keep only the finite entries.
print("Timesteps from diffrax solution:", sol.ts[jnp.isfinite(sol.ys)])
# prints: Timesteps from diffrax solution: [0. 0.1 0.19090909 0.28636363 0.38181818 0.47727272 0.57272726 0.6681818 0.76363635 0.8590909 0.95454544 1.05 ]
```
As you can see, the steps are neither constant (if you include the step from `t0`) nor equal to `dt0`. Of course, for simulations with many steps these differences will be small, so this is unlikely to be a significant problem in practice, but it might confuse Diffrax beginners such as myself. Curious to hear your thoughts on this!
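The step sizes can be read directly off the printed time points with `jnp.diff`. This sketch only post-processes the values printed in the example output; it does not inspect Diffrax internals.

```python
import jax.numpy as jnp

# Time points copied from the printed output of the example (float32 values).
ts = jnp.array([0.0, 0.1, 0.19090909, 0.28636363, 0.38181818, 0.47727272,
                0.57272726, 0.6681818, 0.76363635, 0.8590909, 0.95454544, 1.05])
steps = jnp.diff(ts)  # size of each step between consecutive saved times
print(steps)
```

Of the eleven steps, the first is 0.1 (`dt0`), the second is about 0.0909, and the remaining nine are all about 0.0955 (`1.05 / 11`), so it is really the first two steps that differ from the rest.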