NEOPAX._model_api

Classes

ModelCapabilities

ModelValidationContext

Functions

make_validation_context(→ ModelValidationContext)

Build a small default validation context for user model registration.

_as_jax_array(→ jax.Array)

_assert_jax_compatible_pytree(→ None)

_assert_broadcastable(→ None)

validate_transport_flux_output(→ None)

validate_source_output(→ None)

_jax_shape_smoke_test(→ None)

_jax_jit_smoke_test(→ None)

_jax_autodiff_smoke_test(→ None)

_jax_vmap_smoke_test(→ None)

validate_transport_flux_builder(→ None)

validate_source_model_builder(→ None)

transport_model(name, registry_fn, **register_kwargs)

source_model(name, registry_fn, **register_kwargs)

Module Contents

class NEOPAX._model_api.ModelCapabilities
jit_safe: bool = True
autodiff_safe: bool = False
vmap_safe: bool = False
local_evaluator: bool = False
face_fluxes: bool = False
class NEOPAX._model_api.ModelValidationContext
builder_kwargs: dict[str, Any]
state: Any
species: Any | None = None
geometry: Any | None = None
face_state: Any | None = None
NEOPAX._model_api.make_validation_context(*, builder_kwargs: dict[str, Any] | None = None, n_species: int = 2, n_radial: int = 8, density_value: float = 1.0, temperature_value: float = 2.0, er_value: float = 0.0, species: Any | None = None, geometry: Any | None = None, include_face_state: bool = True) → ModelValidationContext

Build a small default validation context for user model registration.

This helper is intentionally lightweight and avoids depending on geometry builders or larger runtime objects. It is meant for registration-time smoke tests, not for physical validation.

NEOPAX._model_api._as_jax_array(value: Any) → jax.Array
NEOPAX._model_api._assert_jax_compatible_pytree(value: Any, *, where: str) → None
NEOPAX._model_api._assert_broadcastable(value: Any, target_shape: tuple[int, ...], *, where: str) → None
NEOPAX._model_api.validate_transport_flux_output(output: Any, state: Any, *, where: str = 'transport model output') → None
NEOPAX._model_api.validate_source_output(output: Any, state: Any, *, where: str = 'source model output') → None
NEOPAX._model_api._jax_shape_smoke_test(callable_obj: Callable[..., Any], *args: Any, where: str) → None
NEOPAX._model_api._jax_jit_smoke_test(callable_obj: Callable[..., Any], *args: Any, where: str) → None
NEOPAX._model_api._jax_autodiff_smoke_test(callable_obj: Callable[..., Any], state: Any, *, where: str) → None
NEOPAX._model_api._jax_vmap_smoke_test(callable_obj: Callable[..., Any], state: Any, *, where: str) → None
NEOPAX._model_api.validate_transport_flux_builder(builder: Callable[..., Any], context: ModelValidationContext, *, capabilities: ModelCapabilities | None = None, name: str = 'transport model') → None
NEOPAX._model_api.validate_source_model_builder(builder: Callable[..., Any], context: ModelValidationContext, *, name: str = 'source model') → None
NEOPAX._model_api.transport_model(name: str, registry_fn: Callable[..., None], **register_kwargs)
NEOPAX._model_api.source_model(name: str, registry_fn: Callable[..., None], **register_kwargs)