@eqx.filter_jit
def solve(
    D: Callable[  # noqa: N803
        [float | jax.Array | np.ndarray[Any, Any]],
        float | jax.Array | np.ndarray[Any, Any],
    ],
    *,
    b: float,
    i: float,
    itol: float = 1e-3,
    maxiter: int = 100,
) -> Solution:
    """Solve the problem defined by ``D`` by shooting on the initial slope.

    The ODE term built by ``ode(D)`` is integrated from ``t = 0`` with initial
    state ``[b, d_dob]``. The second component ``d_dob`` (the initial slope,
    going by its name) is adjusted by bisection until the first state
    component ends up at ``i``.

    Args:
        D: Function evaluated on values of the first state component (scalar
            or array). It is passed to ``ode`` and ``Solution``, and ``D(b)``
            is used to size the initial slope bracket. ``D(b)`` must be
            positive for that bracket to be finite.
        b: Value of the first state component at ``t = 0``.
        i: Target value for the first state component.
        itol: Absolute tolerance. An integration stops once the first
            component is within ``itol`` of ``i`` (or past it). The same value
            is the bisection's absolute tolerance.
        maxiter: Passed as ``max_steps`` to the bisection root find.

    Returns:
        ``Solution`` built from ``D`` and the diffrax solution returned as the
        root find's auxiliary output (i.e. from the accepted shot).
    """
    term = ode(D)

    # +1 when the first component must increase from ``b`` to ``i``, -1 when it
    # must decrease. NOTE(review): ``i == b`` yields 0 here and a zero-width
    # initial bracket below; that case is not handled specially.
    direction = jnp.sign(i - b)

    @diffrax.Event
    def event(t: float, y: jax.Array, args: object, **kwargs: object) -> jax.Array:  # noqa: ARG001
        # Terminate the integration when either
        #   * the slope no longer points toward ``i`` (the shot falls short), or
        #   * the first component is within ``itol`` of ``i``, or beyond it.
        # The window is measured along ``direction`` so that it lies on the
        # near side of ``i`` for both increasing and decreasing profiles.
        return (direction * y[1] <= 0) | (direction * (y[0] - i) > -itol)

    def shoot(
        d_dob: float | jax.Array,
        args: None,  # noqa: ARG001
    ) -> tuple[jax.Array, diffrax.Solution]:
        """Integrate from ``[b, d_dob]`` and return ``(residual, solution)``.

        When the event fired, the residual is the final first component minus
        ``i``. Otherwise (``throw=False`` keeps diffrax from raising), the
        residual is ``direction * inf``. That gives it the sign an overshoot
        past ``i`` would have, so a failed shot is treated like an overshoot.
        """
        shot = diffrax.diffeqsolve(
            term,
            solver=diffrax.Kvaerno5(),
            t0=0,
            t1=jnp.inf,  # open-ended: only the event stops the integration
            dt0=None,
            y0=jnp.array([b, d_dob]),
            event=event,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(t0=True, t1=True, dense=True),
            throw=False,  # failures are encoded in the residual instead
        )
        # Narrows the Optional type for the type checker; SaveAt saves ys here.
        assert shot.ys is not None
        residual = jax.lax.select(
            shot.result == diffrax.RESULTS.event_occurred,
            shot.ys[-1, 0] - i,
            direction * jnp.inf,
        )
        return residual, shot

    sol: diffrax.Solution = optx.root_find(
        shoot,
        # Presumably only the absolute tolerance ``itol`` is meant to matter,
        # hence ``rtol=inf``.
        solver=optx.Bisection(rtol=jnp.inf, atol=itol, expand_if_necessary=True),  # type: ignore[call-arg]
        y0=0,
        max_steps=maxiter,
        has_aux=True,
        # Initial bracket for the slope. ``expand_if_necessary`` allows the
        # bisection to widen it when it does not bracket a root.
        options={"lower": 0, "upper": (i - b) / (2 * jnp.sqrt(D(b)))},
    ).aux
    return Solution(sol, D)