Skip to content

Solving problems

frontx.solve(D: Callable[[float | jax.Array | np.ndarray[Any, Any]], float | jax.Array | np.ndarray[Any, Any]], *, b: float, i: float, itol: float = 0.001, maxiter: int = 100) -> Solution

Source code in frontx/__init__.py, lines 59–111:
@eqx.filter_jit
def solve(
    D: Callable[  # noqa: N803
        [float | jax.Array | np.ndarray[Any, Any]],
        float | jax.Array | np.ndarray[Any, Any],
    ],
    *,
    b: float,
    i: float,
    itol: float = 1e-3,
    maxiter: int = 100,
) -> Solution:
    """Solve the problem defined by ``D`` with a shooting method.

    The ODE term built by ``ode(D)`` is integrated from ``t = 0`` with the
    initial state ``[b, d_dob]``. The unknown initial slope ``d_dob`` is found
    by bisection so that the first state component ends up at ``i``.

    Args:
        D: Callable evaluated on values of the first state component. It is
            also called as ``D(b)`` to size the initial bisection bracket.
            Presumably a diffusivity function; confirm against ``ode``.
        b: Value of the first state component at ``t = 0``.
        i: Value the first state component must reach.
        itol: Tolerance on reaching ``i``. It is used by the termination event
            and passed as ``atol`` to the bisection solver.
        maxiter: Maximum number of bisection steps (``max_steps``).

    Returns:
        A ``Solution`` built from ``D`` and the diffrax solution that the root
        find returns as ``aux``.
    """
    term = ode(D)
    # +1 when i > b, -1 when i < b (0 if i == b). Multiplying by it lets the
    # comparisons below be written once for both directions of travel.
    direction = jnp.sign(i - b)

    @diffrax.Event
    def event(t: float, y: jax.Array, args: object, **kwargs: object) -> jax.Array:  # noqa: ARG001
        # Stop once the slope y[1] no longer points towards i, or once y[0]
        # has crossed the (i - itol) threshold in the direction of travel.
        # NOTE(review): for direction == -1 this threshold lies *beyond* i
        # (fires when y[0] < i - itol) instead of within itol before it, as
        # in the direction == +1 case -- confirm the asymmetry is intended.
        return (direction * y[1] <= 0) | (direction * y[0] > direction * (i - itol))

    def shoot(
        d_dob: float | jax.Array,
        args: None,  # noqa: ARG001
    ) -> tuple[jax.Array, diffrax.Solution]:
        """Integrate from ``[b, d_dob]`` and return ``(residual, solution)``.

        If integration stopped because ``event`` fired, the residual is the
        final ``y[0] - i``. If it ended for any other reason, the residual is
        ``direction * inf``. That value has the same sign as a shot that
        landed past ``i``. (``throw=False`` reports failures through
        ``sol.result``.)
        """
        sol = diffrax.diffeqsolve(
            term,
            solver=diffrax.Kvaerno5(),
            t0=0,
            t1=jnp.inf,  # run until `event` fires (or the solver stops)
            dt0=None,  # let the adaptive controller choose the first step
            y0=jnp.array([b, d_dob]),
            event=event,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(t0=True, t1=True, dense=True),
            throw=False,  # failures become an infinite residual below
        )
        assert sol.ys is not None  # narrows the Optional for the type checker
        # ys[-1] is the saved state at the final time (SaveAt t1=True).
        residual = jax.lax.select(
            sol.result == diffrax.RESULTS.event_occurred,
            sol.ys[-1, 0] - i,
            direction * jnp.inf,
        )
        return residual, sol

    # Bisect on the initial slope d_dob. The starting bracket is
    # [0, (i - b) / (2 * sqrt(D(b)))]; expand_if_necessary=True allows it to
    # grow. NOTE(review): rtol=inf presumably leaves atol=itol as the
    # governing tolerance -- confirm against optimistix.Bisection.
    sol: diffrax.Solution = optx.root_find(
        shoot,
        solver=optx.Bisection(rtol=jnp.inf, atol=itol, expand_if_necessary=True),  # type: ignore[call-arg]
        y0=0,
        max_steps=maxiter,
        has_aux=True,
        options={"lower": 0, "upper": (i - b) / (2 * jnp.sqrt(D(b)))},
    ).aux

    return Solution(sol, D)