class InterpolatedSolution(eqx.Module, AbstractSolution):
    """Solution represented by piecewise cubic Hermite (PCHIP) interpolants.

    A forward interpolant ``theta(o)`` is built from the caller's raw samples
    and backs ``__call__`` and ``sorptivity``.  A second, inverse interpolant
    ``o(theta)`` is built from the samples plus optional anchor points and is
    used to evaluate ``D``.

    Args:
        o: Sample positions (presumably values of the Boltzmann variable,
            given the ``boltzmannmethod`` decorator -- confirm).
        theta: Sampled values at ``o``.
        b: Optional boundary value. When given, the point ``(0, b)`` is
            prepended to the samples used for the inverse interpolant.
        i: Optional initial value. When given, the point ``(o[-1] + 1, i)`` is
            appended to the samples used for the inverse interpolant and for
            ``oi``. When omitted, the last sample of ``theta`` is used as ``i``.
    """

    # Forward interpolant theta(o) over the raw samples (no b/i anchors).
    _sol: PchipInterpolator
    # Derivative d o / d theta of the inverse interpolant.
    _do_dtheta: PchipInterpolator
    # Antiderivative (with respect to theta) of the inverse interpolant.
    _Iodtheta: PchipInterpolator
    # _Iodtheta evaluated at i; subtracting it makes the integral start at i.
    _c: float
    # Last o of the (possibly extended) sample set; exposed as ``oi``.
    _oi: float

    def __init__(
        self,
        o: jax.Array | np.ndarray[Any, Any],
        theta: jax.Array | np.ndarray[Any, Any],
        /,
        *,
        b: float | None = None,
        i: float | None = None,
    ) -> None:
        # The forward interpolant only ever sees the caller's raw samples.
        self._sol = PchipInterpolator(x=o, y=theta, check=False)

        # Working copies for the inverse interpolant, optionally anchored at
        # either end.
        o_pts, theta_pts = o, theta
        if b is not None:
            o_pts = jnp.insert(o_pts, 0, 0)
            theta_pts = jnp.insert(theta_pts, 0, b)
        if i is None:
            i = theta_pts[-1]  # type: ignore[assignment]
        else:
            # Place the initial value one unit past the last sample.
            o_pts = jnp.append(o_pts, o_pts[-1] + 1)
            theta_pts = jnp.append(theta_pts, i)
        self._oi = o_pts[-1]  # type: ignore[assignment]

        # jnp.unique sorts theta ascending and drops repeated values, keeping
        # the index of each value's first occurrence; this yields a strictly
        # increasing abscissa for the inverse interpolant.
        theta_pts, first = jnp.unique(theta_pts, return_index=True)
        o_pts = o_pts[first]
        o_of_theta = PchipInterpolator(
            x=theta_pts, y=o_pts, extrapolate=False, check=False
        )
        self._do_dtheta = o_of_theta.derivative()
        self._Iodtheta = o_of_theta.antiderivative()
        # Integration constant so that D integrates o dtheta starting at i.
        self._c = self._Iodtheta(i)

    @boltzmannmethod
    def __call__(
        self,
        o: float | jax.Array | np.ndarray[Any, Any],
    ) -> float | jax.Array | np.ndarray[Any, Any]:
        """Evaluate the forward interpolant ``theta(o)``."""
        return self._sol(o)  # type: ignore[no-any-return]

    def D(  # noqa: N802
        self,
        theta: float | jax.Array | np.ndarray[Any, Any],
        /,
    ) -> float | jax.Array | np.ndarray[Any, Any]:
        """Return ``-(do/dtheta) * (integral of o dtheta from i to theta) / 2``.

        Both factors come from the inverse interpolant, which was built with
        extrapolation disabled, so ``theta`` outside the sampled range is not
        extrapolated. The result is passed through ``jnp.squeeze``.
        """
        integral = self._Iodtheta(theta) - self._c
        slope = self._do_dtheta(theta)
        return jnp.squeeze(-(slope * integral) / 2)

    def sorptivity(
        self, o: float | jax.Array | np.ndarray[Any, Any] = 0
    ) -> float | jax.Array | np.ndarray[Any, Any]:
        """Return the integral of ``(theta - i)`` from ``o`` up to ``oi``.

        Computed from the antiderivative of the forward interpolant.
        ``self.i`` is not defined in this class (presumably supplied by
        ``AbstractSolution`` -- confirm).
        """
        antideriv = self._sol.antiderivative()
        upper = self._oi
        area = antideriv(upper) - antideriv(o)
        return area - self.i * (upper - o)  # type: ignore[no-any-return]

    @property
    def oi(self) -> float:
        """Last ``o`` of the extended sample set.

        This is ``o[-1] + 1`` when ``i`` was passed to the constructor.
        """
        return self._oi