NEOPAX._transport_solvers

Modular solver interface and backends for NEOPAX transport equations. Supports multiple solver backends: time integration, root-finding, and optimization. Inspired by torax (https://github.com/google-deepmind/torax).

Attributes

TIME_SOLVER_REGISTRY

ODE_SOLVER_BACKENDS

_RADAU_STAGE_CONFIGS

Classes

_RadauStageConfig

TransportSolver

Base class for transport solvers.

DiffraxSolver

Transport solver that wraps a Diffrax integrator.

_RadauSolverConfig

Configuration fields for RADAUSolver: tolerances, step-size control, and Newton-iteration settings.

_RadauStepState

_RadauStepInfo

RADAUSolver

Adaptive Radau IIA implicit Runge-Kutta transport solver.

_ThetaSolverConfig

Configuration fields for ThetaMethodSolver.

_ThetaNewtonSolverConfig

Configuration fields for NewtonThetaMethodSolver (adds Newton-iteration and step-size control settings).

_ThetaStepState

_ThetaStepInfo

ThetaMethodSolver

Theta-method transport solver.

NewtonThetaMethodSolver

Theta-method transport solver with Newton iteration and adaptive step size.

Functions

_build_real_block_transform(→ tuple[numpy.ndarray, ...)

Return real-valued stage transform data for an odd-stage Radau tableau.

_build_radau_iia_stage_config(→ _RadauStageConfig)

Build fixed-stage Radau IIA coefficients and transformed solve data.

register_time_solver(→ None)

_expects_time_argument(→ bool)

Return True when vector_field signature starts with a time argument.

_as_state_residual(→ Callable)

Adapt a (t, y, ...) vector field to a pure state residual f(y, ...).

_select_solver_family_and_backend(→ tuple[str, str])

Pick the active Diffrax backend while preserving legacy integrator-only configs.

_extract_species_from_args(→ Any)

_save_state_series(→ Any)

_save_scalar_series(→ jax.Array)

_strip_state_metadata_for_solver(→ Any)

Hook for state cleanup before solver calls.

_electron_density_index(→ int | None)

_extract_fixed_temperature_projection(→ tuple[Any, Any])

_extract_state_regularization(→ tuple[Any, Any])

_pack_transport_state_arrays(→ Any)

Convert a TransportState-like object into an array-only pytree.

_unpack_transport_state_arrays(→ Any)

Rebuild a TransportState-like object from an array-only pytree.

_restore_state_metadata(→ Any)

Hook for restoring state metadata after solver calls.

_project_fixed_temperature_output(→ Any)

_apply_quasi_neutrality_output(→ Any)

Apply quasi-neutrality to either a single state or a saved time-series of states.

_project_state_to_quasi_neutrality(→ Any)

_project_packed_transport_state_arrays(→ Any)

Project packed solver arrays without rebuilding a full TransportState.

_project_flat_state_if_needed(→ jax.Array)

_make_solver_state_transform(template_state, species)

_get_diffrax_integrator(→ Callable)

_finalize_custom_solver_output(ys_saved_flat, ...[, ...])

_fill_saved_slots(save_idx, save_times, t_value, ...)

_flat_rhs_factory(unravel, vector_field, args, kwargs)

_lagged_response_hooks(vector_field)

_flat_rhs_with_lagged_response_factory(unravel, ...[, ...])

_solver_error_norm(err_vec, flat_ref, flat_candidate, ...)

_lagged_response_global_reuse_metric(current_flat, ...)

_make_radau_initial_step_state(t0, flat_state0, ...)

_custom_loop_active(step_state, t_final, step_idx, ...)

_accepted_step_limit_reached(step_state, ...)

_run_saved_loop(*, step_state0, step_fn, save_n, t0, ...)

_run_saved_loop_debug_walltime(*, step_state0, ...[, ...])

_make_radau_stage_predictor(f0, prev_stages, prev_dt, ...)

_apply_radau_lean_timestep_controller(*, step_state, ...)

build_time_solver(→ TransportSolver)

Create a time solver backend from runtime parameters/config.

Module Contents

NEOPAX._transport_solvers.TIME_SOLVER_REGISTRY: dict[str, Callable[..., TransportSolver]]
NEOPAX._transport_solvers.ODE_SOLVER_BACKENDS
NEOPAX._transport_solvers._build_real_block_transform(a: numpy.ndarray) → tuple[numpy.ndarray, numpy.ndarray, float, numpy.ndarray]

Return real-valued stage transform data for an odd-stage Radau tableau.
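The sketch below illustrates what "real-valued stage transform data" can look like, following the classic RADAU5 approach of Hairer and Wanner: the inverse Radau matrix is brought to real block-diagonal form, with one 1×1 block for the real eigenvalue and one 2×2 block per complex-conjugate pair. Whether NEOPAX transforms `a` or `inv(a)`, and how it orders the columns, are assumptions here. The function name `real_block_transform` is hypothetical; the return tuple mirrors the documented `(ndarray, ndarray, float, ndarray)` shape.

```python
import numpy as np


def real_block_transform(a):
    """Hypothetical sketch: real block-diagonalisation of inv(a) for an odd-stage tableau."""
    a_inv = np.linalg.inv(a)
    eigvals, eigvecs = np.linalg.eig(a_inv)
    # An odd-stage Radau IIA matrix has exactly one real eigenvalue.
    real_idx = int(np.argmin(np.abs(eigvals.imag)))
    real_eig = float(eigvals[real_idx].real)
    columns = [eigvecs[:, real_idx].real]
    blocks = []
    # Keep one member (positive imaginary part) of each conjugate pair.
    for k in np.flatnonzero(eigvals.imag > 1e-12):
        alpha, beta = eigvals[k].real, eigvals[k].imag
        # For v = u + i w with A v = (alpha + i beta) v, the real pair [u, w]
        # satisfies A [u, w] = [u, w] @ [[alpha, beta], [-beta, alpha]].
        columns.extend([eigvecs[:, k].real, eigvecs[:, k].imag])
        blocks.append([[alpha, beta], [-beta, alpha]])
    transform = np.column_stack(columns)
    inv_transform = np.linalg.inv(transform)
    return transform, inv_transform, real_eig, np.array(blocks)


# Standard 3-stage Radau IIA coefficient matrix.
s6 = np.sqrt(6.0)
a3 = np.array([
    [(88 - 7 * s6) / 360, (296 - 169 * s6) / 1800, (-2 + 3 * s6) / 225],
    [(296 + 169 * s6) / 1800, (88 + 7 * s6) / 360, (-2 - 3 * s6) / 225],
    [(16 - s6) / 36, (16 + s6) / 36, 1 / 9],
])
T, T_inv, gamma, blocks = real_block_transform(a3)
```

With this form, the coupled stage system splits into one real linear solve and one complex (or real 2×2 block) solve per conjugate pair. That split would explain why the solver caches separate `real_lu`/`complex_lu` factorisations.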

class NEOPAX._transport_solvers._RadauStageConfig
num_stages: int
order: int
embedded_order: int
c: numpy.ndarray
a: numpy.ndarray
b: numpy.ndarray
b_error: numpy.ndarray
embedded_f0_weight: float
has_embedded_estimator: bool
transform: numpy.ndarray
inv_transform: numpy.ndarray
real_eig: float
complex_blocks: numpy.ndarray
NEOPAX._transport_solvers._build_radau_iia_stage_config(num_stages: int) → _RadauStageConfig

Build fixed-stage Radau IIA coefficients and transformed solve data.

NEOPAX._transport_solvers._RADAU_STAGE_CONFIGS
NEOPAX._transport_solvers.register_time_solver(name: str, builder: Callable[..., TransportSolver]) → None
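`register_time_solver` stores a builder in `TIME_SOLVER_REGISTRY` (a `dict[str, Callable[..., TransportSolver]]`). The sketch below shows the name-to-builder registry pattern that this pairing suggests, using stand-in names (`REGISTRY`, `register`, `build`, `ToySolver`) that are hypothetical. Whether NEOPAX normalises names, rejects duplicates, or raises a particular error for unknown names is not documented here.

```python
from dataclasses import dataclass
from typing import Callable, Dict

# Hypothetical stand-in for TIME_SOLVER_REGISTRY.
REGISTRY: Dict[str, Callable[..., object]] = {}


def register(name: str, builder: Callable[..., object]) -> None:
    """Store a builder under a name (sketch of register_time_solver)."""
    REGISTRY[name] = builder


def build(name: str, **params):
    """Look up a registered builder and call it with solver parameters."""
    try:
        builder = REGISTRY[name]
    except KeyError:
        raise ValueError(f"unknown solver {name!r}; registered: {sorted(REGISTRY)}") from None
    return builder(**params)


@dataclass
class ToySolver:
    dt: float = 0.01


register("toy", ToySolver)
solver = build("toy", dt=0.1)
```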
NEOPAX._transport_solvers._expects_time_argument(vector_field: Callable) → bool

Return True when vector_field signature starts with a time argument.
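One way to check whether a signature "starts with a time argument" is to inspect its first positional parameter. The sketch below assumes the check is name-based (`t` or `time`). That heuristic, and the function name `expects_time_argument`, are hypothetical; the real helper may use a different rule.

```python
import inspect


def expects_time_argument(vector_field) -> bool:
    """Hypothetical sketch: True if the first positional parameter looks like time."""
    params = [
        p for p in inspect.signature(vector_field).parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    # Name-based heuristic (assumption): accept "t" or "time" as the leading argument.
    return bool(params) and params[0].name in ("t", "time")


def rhs_with_time(t, y, args):
    return y


def rhs_without_time(y, args):
    return y
```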

NEOPAX._transport_solvers._as_state_residual(vector_field: Callable) → Callable

Adapt a (t, y, …) vector field to a pure state residual f(y, …).
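The adapter pattern is a closure that fixes the time argument, so that root-finding or optimisation backends see a function of the state alone. In this sketch the fixed time (`t_fixed=0.0`) and the name `as_state_residual` are assumptions; the real helper's choice of `t` is not documented here.

```python
def as_state_residual(vector_field, t_fixed=0.0):
    """Hypothetical sketch: turn f(t, y, ...) into r(y, ...) by fixing t."""
    def residual(y, *args, **kwargs):
        return vector_field(t_fixed, y, *args, **kwargs)
    return residual


def decay(t, y, k):
    return -k * y + t


residual = as_state_residual(decay)
```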

NEOPAX._transport_solvers._select_solver_family_and_backend(solver_parameters: Any) → tuple[str, str]

Pick the active Diffrax backend while preserving legacy integrator-only configs.

NEOPAX._transport_solvers._extract_species_from_args(args: tuple[Any, ...]) → Any
NEOPAX._transport_solvers._save_state_series(state: Any) → Any
NEOPAX._transport_solvers._save_scalar_series(value: Any) → jax.Array
NEOPAX._transport_solvers._strip_state_metadata_for_solver(state: Any) → Any

Hook for state cleanup before solver calls.

NEOPAX._transport_solvers._electron_density_index(species: Any) → int | None
NEOPAX._transport_solvers._extract_fixed_temperature_projection(vector_field: Callable) → tuple[Any, Any]
NEOPAX._transport_solvers._extract_state_regularization(vector_field: Callable) → tuple[Any, Any]
NEOPAX._transport_solvers._pack_transport_state_arrays(state: Any, species: Any = None) → Any

Convert a TransportState-like object into an array-only pytree.

NEOPAX._transport_solvers._unpack_transport_state_arrays(state_like: Any, template_state: Any, species: Any = None, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) → Any

Rebuild a TransportState-like object from an array-only pytree.

NEOPAX._transport_solvers._restore_state_metadata(state_like: Any, template_state: Any) → Any

Hook for restoring state metadata after solver calls.

NEOPAX._transport_solvers._project_fixed_temperature_output(state_like: Any, reference_state: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) → Any
NEOPAX._transport_solvers._apply_quasi_neutrality_output(state_like: Any, species: Any, reference_state: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) → Any

Apply quasi-neutrality to either a single state or a saved time-series of states.
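Quasi-neutrality requires the electron density to equal the charge-weighted ion density, n_e = Σ_i Z_i n_i. The sketch below enforces this by overwriting the electron row of a `(n_species, n_radial)` density array. It then applies the same projection to a saved time series with `jax.vmap` over the leading axis. The array layout, the per-species `charge` vector, and the function name are hypothetical; in NEOPAX the electron row would come from `_electron_density_index(species)`.

```python
import jax
import jax.numpy as jnp


def project_quasi_neutrality(density, charge, electron_index, density_floor=None):
    """Hypothetical sketch: set n_e = sum over ions of Z_i * n_i.

    density: (n_species, n_radial); charge: (n_species,), electron entry -1.
    """
    if density_floor is not None:
        # Floor densities first so that n_e is built from floored ion densities.
        density = jnp.maximum(density, density_floor)
    is_ion = jnp.arange(density.shape[0]) != electron_index
    weighted = jnp.where(is_ion[:, None], charge[:, None] * density, 0.0)
    n_e = jnp.sum(weighted, axis=0)
    return density.at[electron_index].set(n_e)


charge = jnp.array([-1.0, 1.0, 2.0])
density = jnp.array([[1.0, 1.0], [0.5, 0.2], [0.1, 0.3]])
state_out = project_quasi_neutrality(density, charge, electron_index=0)
# A saved time series has a leading time axis; vmap applies the same projection per slice.
series_out = jax.vmap(lambda d: project_quasi_neutrality(d, charge, 0))(
    jnp.stack([density, 2.0 * density])
)
```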

NEOPAX._transport_solvers._project_state_to_quasi_neutrality(state_like: Any, species: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) → Any
NEOPAX._transport_solvers._project_packed_transport_state_arrays(state_like: Any, template_state: Any, species: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) → Any

Project packed solver arrays without rebuilding a full TransportState.

NEOPAX._transport_solvers._project_flat_state_if_needed(flat_y: jax.Array, project_flat: Callable[[jax.Array], jax.Array] | None) → jax.Array
NEOPAX._transport_solvers._make_solver_state_transform(template_state: Any, species: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None)
NEOPAX._transport_solvers._get_diffrax_integrator(name: str) → Callable
class NEOPAX._transport_solvers.TransportSolver

Base class for transport solvers. All solvers must implement the solve() method.

abstractmethod solve(state, vector_field: Callable, *args, **kwargs) → Any
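A new backend only needs to implement `solve(state, vector_field, *args, **kwargs)`. The sketch below defines a local stand-in for `TransportSolver` (so it runs without NEOPAX) and a hypothetical fixed-step forward-Euler subclass. It assumes `vector_field` has the `(t, y, *args)` form and that `state` is any JAX pytree. It is illustrative only and not one of the shipped backends.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable

import jax
import jax.numpy as jnp


class TransportSolverSketch(ABC):
    """Local stand-in for NEOPAX's TransportSolver base class."""

    @abstractmethod
    def solve(self, state, vector_field: Callable, *args, **kwargs) -> Any:
        ...


class ForwardEulerSolver(TransportSolverSketch):
    """Hypothetical fixed-step explicit Euler backend."""

    def __init__(self, t0=0.0, t1=1.0, dt=0.01):
        self.t0, self.t1, self.dt = t0, t1, dt

    def solve(self, state, vector_field, *args, **kwargs):
        n_steps = int(round((self.t1 - self.t0) / self.dt))

        def body(i, y):
            t = self.t0 + i * self.dt
            dy = vector_field(t, y, *args, **kwargs)
            # Update every leaf of the state pytree: y <- y + dt * f(t, y).
            return jax.tree_util.tree_map(lambda yi, fi: yi + self.dt * fi, y, dy)

        return jax.lax.fori_loop(0, n_steps, body, state)


solver = ForwardEulerSolver(t0=0.0, t1=1.0, dt=0.01)
y_end = solver.solve(jnp.asarray(1.0), lambda t, y, k: -k * y, 1.0)
```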
class NEOPAX._transport_solvers.DiffraxSolver(integrator, t0, t1, dt, save_n=None, **integrator_kwargs)

Bases: TransportSolver

Transport solver that wraps a Diffrax integrator. Like every TransportSolver, it implements solve().

integrator: Callable
t0: float
t1: float
dt: float
integrator_kwargs: dict[str, Any]
solve(state, vector_field: Callable, *args, **kwargs)
NEOPAX._transport_solvers._finalize_custom_solver_output(ys_saved_flat, ts_saved, dts_saved, accepted_mask_saved, failed_mask_saved, fail_codes_saved, y_final_flat, t_final, done_f, failed_f, fail_code_f, n_steps_f, last_attempt_accepted, last_attempt_converged, last_attempt_err_norm, last_attempt_fail_code, last_attempt_diverged, last_attempt_nonfinite_stage_state, last_attempt_nonfinite_stage_residual, last_attempt_finite_f0, last_attempt_finite_z0, last_attempt_finite_initial_residual, last_attempt_newton_iter_count, last_attempt_final_residual_norm, last_attempt_final_delta_norm, last_attempt_theta_final, last_attempt_slow_contraction, last_attempt_residual_blowup, last_attempt_newton_nonfinite, unpack_flat, reference_state, species, temperature_active_mask=None, fixed_temperature_profile=None, density_floor=None, temperature_floor=None)
NEOPAX._transport_solvers._fill_saved_slots(save_idx, save_times, t_value, flat_y, dt_value, accepted, failed, fail_code, ys, ts, dts, accs, fails, codes)
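Inside a jitted custom step loop, output slots can be filled without Python branching. After each step, every slot from the current pointer `save_idx` onward whose save time has been reached is written, and the pointer advances by the number of slots filled. This is a simplified, hypothetical sketch (`fill_saved_slots` with fewer arguments than the real helper). It stores the state at `t_value` rather than interpolating to the exact save time, which is an assumption about the real behaviour.

```python
import jax.numpy as jnp


def fill_saved_slots(save_idx, save_times, t_value, flat_y, ys):
    """Hypothetical sketch: write flat_y into all pending slots with save_time <= t_value."""
    slots = jnp.arange(save_times.shape[0])
    # Pending slots (index >= save_idx) whose target time has been passed.
    mask = (slots >= save_idx) & (save_times <= t_value)
    ys = jnp.where(mask[:, None], flat_y[None, :], ys)
    return save_idx + jnp.sum(mask), ys


save_times = jnp.linspace(0.0, 1.0, 5)
next_idx, ys = fill_saved_slots(1, save_times, 0.6, jnp.array([1.0, 2.0]), jnp.zeros((5, 2)))
```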
NEOPAX._transport_solvers._flat_rhs_factory(unravel, vector_field, args, kwargs, project_flat=None)
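Custom implicit steppers work on flat vectors, while the physics works on structured states. The `unravel` argument suggests the pairing produced by `jax.flatten_util.ravel_pytree`. The sketch below is a hypothetical `flat_rhs_factory` that unravels, evaluates the structured vector field, and re-flattens. Applying `project_flat` to the input before unravelling is an assumption, and `ToyState` is an illustrative stand-in for a TransportState.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


class ToyState(NamedTuple):
    density: jnp.ndarray
    pressure: jnp.ndarray


def flat_rhs_factory(unravel, vector_field, args, kwargs, project_flat=None):
    """Hypothetical sketch: wrap f(t, state, ...) as a flat-vector RHS."""
    def flat_rhs(t, flat_y):
        if project_flat is not None:
            flat_y = project_flat(flat_y)  # e.g. apply floors or constraints
        dstate = vector_field(t, unravel(flat_y), *args, **kwargs)
        flat_dy, _ = ravel_pytree(dstate)
        return flat_dy
    return flat_rhs


def decay(t, s, rate):
    return ToyState(-rate * s.density, -rate * s.pressure)


state0 = ToyState(density=jnp.ones(3), pressure=2.0 * jnp.ones(3))
flat0, unravel = ravel_pytree(state0)
rhs = flat_rhs_factory(unravel, decay, (0.5,), {})
```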
NEOPAX._transport_solvers._lagged_response_hooks(vector_field: Callable)
NEOPAX._transport_solvers._flat_rhs_with_lagged_response_factory(unravel, vector_field, args, kwargs, project_flat=None)
NEOPAX._transport_solvers._solver_error_norm(err_vec, flat_ref, flat_candidate, atol: float, rtol: float, scale_mode: str = 'max', rtol_eff=None, scale_override=None)
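Adaptive steppers usually compare an error estimate against a mixed absolute/relative scale. The sketch below uses the Hairer-Wanner-style RMS norm with scale = atol + rtol·max(|y_ref|, |y_candidate|). Reading `scale_mode='max'` as this max-based scale is an assumption, and the `rtol_eff` and `scale_override` options are not modelled. The function name is hypothetical.

```python
import jax.numpy as jnp


def scaled_error_norm(err_vec, y_ref, y_candidate, atol, rtol):
    """Hypothetical sketch: RMS of err / (atol + rtol * max(|y_ref|, |y_candidate|))."""
    scale = atol + rtol * jnp.maximum(jnp.abs(y_ref), jnp.abs(y_candidate))
    return jnp.sqrt(jnp.mean((err_vec / scale) ** 2))
```

A step is typically accepted when this norm is at most 1.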
NEOPAX._transport_solvers._lagged_response_global_reuse_metric(current_flat, reference_flat, atol: float, rtol: float)
NEOPAX._transport_solvers._make_radau_initial_step_state(t0, flat_state0, base_dt, dtype, initial_rhs, num_stages, real_lu0, real_piv0, complex_lu0, complex_piv0, lagged_response_cache, lagged_response_valid, lagged_reference_y)
NEOPAX._transport_solvers._custom_loop_active(step_state, t_final, step_idx, max_total_steps)
NEOPAX._transport_solvers._accepted_step_limit_reached(step_state, stop_after_accepted_steps)
NEOPAX._transport_solvers._run_saved_loop(*, step_state0, step_fn, save_n, t0, t_final, state_dim, dtype, max_total_steps, stop_after_accepted_steps=None)
NEOPAX._transport_solvers._run_saved_loop_debug_walltime(*, step_state0, step_fn, save_n, t0, t_final, state_dim, dtype, max_total_steps, stop_after_accepted_steps=None, walltime_label='solver.attempt')
NEOPAX._transport_solvers._make_radau_stage_predictor(f0, prev_stages, prev_dt, h_value, c, dtype, density_size=0, pressure_size=0, er_size=0, prev_theta_final=None, prev_newton_iter_count=None, predictor_mode='current')
NEOPAX._transport_solvers._apply_radau_lean_timestep_controller(*, step_state, trial_dt, trial_y, err_norm, density_err_norm, pressure_err_norm, er_err_norm, converged, stage_history, jacobian_out, cache_valid_out, cache_dt_out, cache_age_out, real_lu_out, real_piv_out, complex_lu_out, complex_piv_out, newton_shrink, diverged_final, nonfinite_stage_state, nonfinite_stage_residual, finite_f0, finite_z0, finite_initial_residual, newton_iter_count, final_residual_norm, final_delta_norm, theta_final, slow_contraction, residual_blowup, newton_nonfinite, lagged_reused, jacobian_reused, fail_code, n_accepted, dtype, dt_min, dt_max, safety_factor, controller_alpha, min_step_factor, max_step_factor, controller_mode, use_transport_lagged_response, lagged_response_reuse_mode, lagged_response_reuse_rtol, lagged_response_reuse_atol, project_flat)
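The parameters `safety_factor`, `min_step_factor`, `max_step_factor`, `dt_min` and `dt_max` match the ingredients of a standard error-based step-size controller. The sketch below uses the textbook rule dt_new = dt · clip(safety · err^(-1/(q+1)), min_factor, max_factor), with q the order of the error estimate. It does not reproduce `controller_alpha`, `controller_mode`, or the Jacobian and lagged-response cache logic, whose semantics are not documented here. The function name `propose_step` is hypothetical.

```python
import jax.numpy as jnp


def propose_step(dt, err_norm, order, safety_factor=0.9, min_step_factor=0.1,
                 max_step_factor=5.0, dt_min=1e-14, dt_max=1.0):
    """Hypothetical sketch of an elementary error-based step-size controller."""
    err = jnp.maximum(err_norm, 1e-10)  # guard against division by zero on exact steps
    factor = safety_factor * err ** (-1.0 / (order + 1))
    factor = jnp.clip(factor, min_step_factor, max_step_factor)
    new_dt = jnp.clip(dt * factor, dt_min, dt_max)
    accepted = err_norm <= 1.0
    return new_dt, accepted
```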
class NEOPAX._transport_solvers._RadauSolverConfig(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, rtol: float = 1e-06, atol: float = 1e-08, max_step: float = 1.0, min_step: float = 1e-14, tol: float = 1e-08, maxiter: int = 20, error_estimator: str = 'embedded2', num_stages: int = 3, safety_factor: float = 0.9, min_step_factor: float = 0.1, max_step_factor: float = 5.0, jacobian_reuse_rtol: float = 0.1, max_jacobian_age: int = 8, rhs_mode: str = 'black_box', newton_divergence_mode: str = 'legacy', newton_residual_norm: str = 'raw', newton_tol_mode: str = 'residual', newton_fnewt_mode: str = 'tol', controller_mode: str = 'current', predictor_mode: str = 'current', lagged_response_reuse_mode: str = 'retry_only', lagged_response_reuse_rtol: float = 0.05, lagged_response_reuse_atol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, debug_stage_markers: bool = False, debug_walltime_attempts: bool = False, save_n=None)

Bases: TransportSolver

Configuration fields for RADAUSolver (time span, tolerances, step-size control, Newton-iteration, and debugging options). RADAUSolver subclasses it and implements solve().

t0: float
t1: float
dt: float
rtol: float = 1e-06
atol: float = 1e-08
max_step: float = 1.0
min_step: float = 1e-14
tol: float = 1e-08
maxiter: int = 20
error_estimator: str = 'embedded2'
num_stages: int = 3
safety_factor: float = 0.9
min_step_factor: float = 0.1
max_step_factor: float = 5.0
jacobian_reuse_rtol: float = 0.1
max_jacobian_age: int = 8
rhs_mode: str = 'black_box'
newton_divergence_mode: str = 'legacy'
newton_residual_norm: str = 'raw'
newton_tol_mode: str = 'residual'
newton_fnewt_mode: str = 'tol'
controller_mode: str = 'current'
predictor_mode: str = 'current'
lagged_response_reuse_mode: str = 'retry_only'
lagged_response_reuse_rtol: float = 0.05
lagged_response_reuse_atol: float = 1e-08
max_steps: int = 20000
stop_after_accepted_steps: int | None = None
n_steps: int = 0
debug_stage_markers: bool = False
debug_walltime_attempts: bool = False
class NEOPAX._transport_solvers._RadauStepState
t: Any
y: Any
dt: Any
status: Any
prev_error: Any
prev_stages: Any
prev_dt: Any
recent_reject_count: Any
regrowth_cooldown: Any
easy_growth_streak: Any
lagged_response_cache: Any
lagged_response_valid: Any
lagged_reference_y: Any
jacobian: Any
cache_valid: Any
cache_dt: Any
cache_age: Any
real_lu: Any
real_piv: Any
complex_lu: Any
complex_piv: Any
prev_theta_final: Any
prev_newton_iter_count: Any
class NEOPAX._transport_solvers._RadauStepInfo
y: Any
t: Any
dt: Any
next_dt: Any = None
growth: Any = None
lagged_reused: Any = None
jacobian_reused: Any = None
accepted: Any = None
failed: Any = None
fail_code: Any = None
converged: Any = None
err_norm: Any = None
diverged: Any = None
nonfinite_stage_state: Any = None
nonfinite_stage_residual: Any = None
finite_f0: Any = None
finite_z0: Any = None
finite_initial_residual: Any = None
newton_iter_count: Any = None
final_residual_norm: Any = None
final_delta_norm: Any = None
theta_final: Any = None
slow_contraction: Any = None
residual_blowup: Any = None
newton_nonfinite: Any = None
class NEOPAX._transport_solvers.RADAUSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, rtol: float = 1e-06, atol: float = 1e-08, max_step: float = 1.0, min_step: float = 1e-14, tol: float = 1e-08, maxiter: int = 20, error_estimator: str = 'embedded2', num_stages: int = 3, safety_factor: float = 0.9, min_step_factor: float = 0.1, max_step_factor: float = 5.0, jacobian_reuse_rtol: float = 0.1, max_jacobian_age: int = 8, rhs_mode: str = 'black_box', newton_divergence_mode: str = 'legacy', newton_residual_norm: str = 'raw', newton_tol_mode: str = 'residual', newton_fnewt_mode: str = 'tol', controller_mode: str = 'current', predictor_mode: str = 'current', lagged_response_reuse_mode: str = 'retry_only', lagged_response_reuse_rtol: float = 0.05, lagged_response_reuse_atol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, debug_stage_markers: bool = False, debug_walltime_attempts: bool = False, save_n=None)

Bases: _RadauSolverConfig

Adaptive Radau IIA implicit Runge-Kutta transport solver, configured by the _RadauSolverConfig fields. Implements solve().

solve(state, vector_field: Callable, *args, **kwargs)
class NEOPAX._transport_solvers._ThetaSolverConfig(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)

Bases: TransportSolver

Configuration fields for ThetaMethodSolver, including the implicitness parameter theta_implicit. ThetaMethodSolver subclasses it and implements solve().

t0: float
t1: float
dt: float
min_step: float = 1e-14
theta_implicit: float = 1.0
predictor_mode: str = 'linearized'
rhs_mode: str = 'black_box'
use_predictor_corrector: bool = False
n_corrector_steps: int = 1
tol: float = 1e-08
max_steps: int = 20000
stop_after_accepted_steps: int | None = None
n_steps: int = 0
class NEOPAX._transport_solvers._ThetaNewtonSolverConfig(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, maxiter: int = 20, max_step: float | None = None, safety_factor: float = 0.9, min_step_factor: float = 0.5, max_step_factor: float = 2.0, target_nonlinear_iterations: int = 4, delta_reduction_factor: float = 0.5, tau_min: float = 0.01, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)

Bases: _ThetaSolverConfig

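Configuration fields for NewtonThetaMethodSolver. Extends _ThetaSolverConfig with Newton-iteration settings (maxiter, delta_reduction_factor, tau_min) and step-size control settings. NewtonThetaMethodSolver subclasses it and implements solve().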

maxiter: int = 20
max_step: float = 1.0
safety_factor: float = 0.9
min_step_factor: float = 0.5
max_step_factor: float = 2.0
target_nonlinear_iterations: int = 4
delta_reduction_factor: float = 0.5
tau_min: float = 0.01
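The field `target_nonlinear_iterations`, together with `safety_factor`, `min_step_factor`, `max_step_factor` and `max_step`, suggests a step-size rule driven by Newton iteration counts. Under that rule, the step grows when Newton converges in fewer iterations than the target and shrinks when it needs more. The sketch below is one plausible version of such a rule. Both the formula and the name `adapt_dt_from_iterations` are assumptions, not the documented behaviour of NewtonThetaMethodSolver.

```python
def adapt_dt_from_iterations(dt, newton_iters, target_nonlinear_iterations=4,
                             safety_factor=0.9, min_step_factor=0.5,
                             max_step_factor=2.0, max_step=1.0):
    """Hypothetical sketch: scale dt by safety * target / actual Newton iterations."""
    factor = safety_factor * target_nonlinear_iterations / max(newton_iters, 1)
    factor = min(max(factor, min_step_factor), max_step_factor)
    return min(dt * factor, max_step)
```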
class NEOPAX._transport_solvers._ThetaStepState
t: Any
y: Any
dt: Any
status: Any
class NEOPAX._transport_solvers._ThetaStepInfo
y: Any
t: Any
dt: Any
accepted: Any
failed: Any
fail_code: Any
class NEOPAX._transport_solvers.ThetaMethodSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)

Bases: _ThetaSolverConfig

Theta-method transport solver, configured by the _ThetaSolverConfig fields. Implements solve().

solve(state, vector_field: Callable, *args, **kwargs)
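The theta method advances y' = f(y) by solving y₁ = y₀ + Δt[θ f(y₁) + (1 − θ) f(y₀)]. θ = 1 gives backward Euler and θ = 0.5 gives Crank-Nicolson. The default `predictor_mode='linearized'` suggests a linearly implicit variant, sketched below: a single Newton step from y₀ gives y₁ = y₀ + (I − θΔt J)⁻¹ Δt f(y₀), where J = ∂f/∂y. That this is what ThetaMethodSolver computes is an assumption, and `linearized_theta_step` is a hypothetical name. For linear f the result coincides with the exact theta step, which the test checks.

```python
import jax
import jax.numpy as jnp


def linearized_theta_step(f, y0, dt, theta=1.0):
    """Hypothetical sketch: one Newton step of the theta-method residual from y0."""
    f0 = f(y0)
    jac = jax.jacfwd(f)(y0)  # J = df/dy at y0
    lhs = jnp.eye(y0.shape[0]) - theta * dt * jac
    return y0 + jnp.linalg.solve(lhs, dt * f0)


A = jnp.diag(jnp.array([-2.0, -1.0]))
y1 = linearized_theta_step(lambda y: A @ y, jnp.array([1.0, 1.0]), 0.1, theta=0.5)
```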
class NEOPAX._transport_solvers.NewtonThetaMethodSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, maxiter: int = 20, max_step: float | None = None, safety_factor: float = 0.9, min_step_factor: float = 0.5, max_step_factor: float = 2.0, target_nonlinear_iterations: int = 4, delta_reduction_factor: float = 0.5, tau_min: float = 0.01, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)

Bases: _ThetaNewtonSolverConfig

Theta-method transport solver that solves each implicit step by Newton iteration, configured by the _ThetaNewtonSolverConfig fields. Implements solve().

solve(state, vector_field: Callable, *args, **kwargs)
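The sketch below shows a fully implicit theta step solved by damped Newton iteration on the residual G(y₁) = y₁ − y₀ − Δt[θ f(y₁) + (1 − θ) f(y₀)]. The damping backtracks the step length τ by `delta_reduction_factor` until the residual norm decreases, never going below `tau_min`. Reading those two fields as a backtracking line search is an assumption, and `newton_theta_step` is a hypothetical name. The Python loops make it non-jittable, so it is for illustration only.

```python
import jax
import jax.numpy as jnp


def newton_theta_step(f, y0, dt, theta=1.0, tol=1e-8, maxiter=20,
                      delta_reduction_factor=0.5, tau_min=0.01):
    """Hypothetical sketch: damped Newton solve of the theta-method residual."""
    f0 = f(y0)

    def residual(y1):
        return y1 - y0 - dt * (theta * f(y1) + (1.0 - theta) * f0)

    y = y0
    for it in range(maxiter):
        r = residual(y)
        r_norm = jnp.linalg.norm(r)
        if r_norm < tol:
            return y, it, True
        delta = jnp.linalg.solve(jax.jacfwd(residual)(y), -r)
        # Backtrack until the residual norm decreases or tau drops below tau_min.
        tau = 1.0
        while tau >= tau_min and jnp.linalg.norm(residual(y + tau * delta)) >= r_norm:
            tau *= delta_reduction_factor
        y = y + max(tau, tau_min) * delta
    return y, maxiter, bool(jnp.linalg.norm(residual(y)) < tol)


# Backward Euler (theta = 1) on y' = -y**3; a loose tol suits default float32.
y1, iters, converged = newton_theta_step(lambda y: -y ** 3, jnp.array([1.0]), 0.1, tol=1e-5)
```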
NEOPAX._transport_solvers.build_time_solver(solver_parameters: Any, solver_override: Any = None) → TransportSolver

Create a time solver backend from runtime parameters/config.

solver_override can be either:
  • an instance with .solve(…) (used directly), or
  • a diffrax solver instance (wrapped in DiffraxSolver).
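The two override cases can be told apart by duck typing: objects with a callable `solve` are used as-is, and anything else is wrapped. This sketch assumes Diffrax solver objects expose no `solve` method and uses a hypothetical `WrappedDiffraxSketch` stand-in for DiffraxSolver. The real function may perform additional type checks.

```python
class WrappedDiffraxSketch:
    """Hypothetical stand-in for DiffraxSolver: holds a wrapped integrator."""

    def __init__(self, integrator, **kwargs):
        self.integrator = integrator
        self.kwargs = kwargs

    def solve(self, state, vector_field, *args, **kwargs):
        raise NotImplementedError("illustration only")


def resolve_override(solver_override, **wrap_kwargs):
    """Use objects with .solve(...) directly; wrap anything else."""
    if callable(getattr(solver_override, "solve", None)):
        return solver_override
    return WrappedDiffraxSketch(solver_override, **wrap_kwargs)


class CustomSolver:
    def solve(self, state, vector_field, *args, **kwargs):
        return state


class FakeDiffraxSolver:
    pass
```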