NEOPAX._transport_solvers
Modular solver interface and backends for NEOPAX transport equations. Supports multiple solver backends: time integration, root-finding, and optimization. Inspired by torax (https://github.com/google-deepmind/torax).
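Attributes
TIME_SOLVER_REGISTRY
ODE_SOLVER_BACKENDS
_RADAU_STAGE_CONFIGS
Classes
_RadauStageConfig
TransportSolver | Base class for transport solvers.
DiffraxSolver | Time solver that wraps a Diffrax integrator.
_RadauSolverConfig | Configuration fields for RADAUSolver.
_RadauStepState
_RadauStepInfo
RADAUSolver | Radau IIA implicit time solver.
_ThetaSolverConfig | Configuration fields for ThetaMethodSolver.
_ThetaNewtonSolverConfig | Configuration fields for NewtonThetaMethodSolver.
_ThetaStepInfo
ThetaMethodSolver | Theta-method time solver.
NewtonThetaMethodSolver | Theta-method time solver with Newton iterations.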
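Functions
_build_real_block_transform | Return real-valued stage transform data for an odd-stage Radau tableau.
_build_radau_iia_stage_config | Build fixed-stage Radau IIA coefficients and transformed solve data.
register_time_solver
_expects_time_argument | Return True when the signature of vector_field starts with a time argument.
_as_state_residual | Adapt a (t, y, …) vector field to a pure state residual f(y, …).
_select_solver_family_and_backend | Pick the active Diffrax backend while preserving legacy integrator-only configs.
_extract_species_from_args
_save_state_series
_save_scalar_series
_strip_state_metadata_for_solver | Hook for state cleanup before solver calls.
_electron_density_index
_extract_fixed_temperature_projection
_extract_state_regularization
_pack_transport_state_arrays | Convert a TransportState-like object into an array-only pytree.
_unpack_transport_state_arrays | Rebuild a TransportState-like object from an array-only pytree.
_restore_state_metadata | Hook for restoring state metadata after solver calls.
_project_fixed_temperature_output
_apply_quasi_neutrality_output | Apply quasi-neutrality to either a single state or a saved time series of states.
_project_state_to_quasi_neutrality
_project_packed_transport_state_arrays | Project packed solver arrays without rebuilding a full TransportState.
_project_flat_state_if_needed
_make_solver_state_transform
_get_diffrax_integrator
_finalize_custom_solver_output
_fill_saved_slots
_flat_rhs_factory
_lagged_response_hooks
_flat_rhs_with_lagged_response_factory
_solver_error_norm
_lagged_response_global_reuse_metric
_make_radau_initial_step_state
_custom_loop_active
_accepted_step_limit_reached
_run_saved_loop
_run_saved_loop_debug_walltime
_make_radau_stage_predictor
_apply_radau_lean_timestep_controller
build_time_solver | Create a time solver backend from runtime parameters/config.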
Module Contents
- NEOPAX._transport_solvers.TIME_SOLVER_REGISTRY: dict[str, Callable[..., TransportSolver]]
- NEOPAX._transport_solvers.ODE_SOLVER_BACKENDS
- NEOPAX._transport_solvers._build_real_block_transform(a: numpy.ndarray) -> tuple[numpy.ndarray, numpy.ndarray, float, numpy.ndarray]
Return real-valued stage transform data for an odd-stage Radau tableau.
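For the 3-stage Radau IIA tableau (the common odd-stage case), the inverse coefficient matrix has one real eigenvalue and one complex-conjugate pair, so a real change of basis turns it into a 1x1 block plus a 2x2 rotation-scaling block. The NumPy sketch below illustrates that idea; `radau3_matrix` and `real_block_transform` are hypothetical names, and their return layout is an assumption rather than the documented tuple returned by `_build_real_block_transform`.

```python
import numpy as np


def radau3_matrix():
    """Coefficient matrix A of the 3-stage (order-5) Radau IIA method."""
    s6 = np.sqrt(6.0)
    return np.array([
        [(88 - 7 * s6) / 360, (296 - 169 * s6) / 1800, (-2 + 3 * s6) / 225],
        [(296 + 169 * s6) / 1800, (88 + 7 * s6) / 360, (-2 - 3 * s6) / 225],
        [(16 - s6) / 36, (16 + s6) / 36, 1.0 / 9.0],
    ])


def real_block_transform(a):
    """Return (T, T_inv, gamma, (alpha, beta)) so that T_inv @ inv(a) @ T equals
    [[gamma, 0, 0], [0, alpha, beta], [0, -beta, alpha]].

    Hypothetical helper: assumes inv(a) has exactly one real eigenvalue and
    one complex-conjugate pair (true for 3-stage Radau IIA).
    """
    a_inv = np.linalg.inv(a)
    eigvals, eigvecs = np.linalg.eig(a_inv)
    real_idx = int(np.argmin(np.abs(eigvals.imag)))   # the real eigenvalue
    cplx_idx = int(np.argmax(eigvals.imag))           # pair member with beta > 0
    gamma = float(eigvals[real_idx].real)
    alpha = float(eigvals[cplx_idx].real)
    beta = float(eigvals[cplx_idx].imag)
    v_real = eigvecs[:, real_idx].real
    v_cplx = eigvecs[:, cplx_idx]
    # Columns: real eigenvector, then real and imaginary parts of the complex one.
    t = np.column_stack([v_real, v_cplx.real, v_cplx.imag])
    return t, np.linalg.inv(t), gamma, (alpha, beta)


t, t_inv, gamma, (alpha, beta) = real_block_transform(radau3_matrix())
```

With this form, each simplified Newton iteration needs one real n x n solve and one complex n x n solve instead of a single 3n x 3n solve, which is presumably what the `real_lu` and `complex_lu` factors in `_RadauStepState` hold.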
- class NEOPAX._transport_solvers._RadauStageConfig
- num_stages: int
- order: int
- embedded_order: int
- c: numpy.ndarray
- a: numpy.ndarray
- b: numpy.ndarray
- b_error: numpy.ndarray
- embedded_f0_weight: float
- has_embedded_estimator: bool
- transform: numpy.ndarray
- inv_transform: numpy.ndarray
- real_eig: float
- complex_blocks: numpy.ndarray
- NEOPAX._transport_solvers._build_radau_iia_stage_config(num_stages: int) -> _RadauStageConfig
Build fixed-stage Radau IIA coefficients and transformed solve data.
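Radau IIA coefficients for any stage count can be generated by collocation: the nodes c are the roots of P_s(2x-1) - P_{s-1}(2x-1) on [0, 1] (P_k the Legendre polynomials), and a_ij is the integral from 0 to c_i of the j-th Lagrange basis polynomial on those nodes. The sketch below builds only c, A and b; `radau_iia_tableau` is a hypothetical name, and the embedded-estimator and transform fields of `_RadauStageConfig` are not reproduced.

```python
import numpy as np
from numpy.polynomial import legendre, polynomial


def radau_iia_tableau(num_stages):
    """Return (c, A, b) for the s-stage Radau IIA method via collocation (sketch)."""
    s = num_stages
    series = np.zeros(s + 1)
    series[s], series[s - 1] = 1.0, -1.0          # P_s - P_{s-1} on [-1, 1]
    x = np.real(legendre.legroots(series))
    c = np.sort((x + 1.0) / 2.0)                  # map [-1, 1] -> [0, 1]
    a = np.zeros((s, s))
    for j in range(s):
        others = np.delete(c, j)
        # Lagrange basis polynomial l_j: 1 at c_j, 0 at every other node.
        basis = polynomial.polyfromroots(others) / np.prod(c[j] - others)
        antideriv = polynomial.polyint(basis)     # integral from 0 to t
        a[:, j] = polynomial.polyval(c, antideriv)
    b = a[-1].copy()                              # stiffly accurate since c_s = 1
    return c, a, b


c, a, b = radau_iia_tableau(3)
```

With num_stages=1 the tableau reduces to backward Euler (c = [1], A = [[1]]).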
- NEOPAX._transport_solvers._RADAU_STAGE_CONFIGS
- NEOPAX._transport_solvers.register_time_solver(name: str, builder: Callable[..., TransportSolver]) -> None
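`TIME_SOLVER_REGISTRY` maps names to builders of type `Callable[..., TransportSolver]`, and `register_time_solver(name, builder)` adds entries to it. A minimal sketch of that registry pattern follows; the stand-in base class, `build`, `FixedStepSolver`, and the overwrite-on-duplicate behavior are assumptions, not the module's code.

```python
from typing import Callable, Dict


class TransportSolver:
    """Stand-in for NEOPAX's TransportSolver base class (assumption)."""

    def solve(self, state, vector_field, *args, **kwargs):
        raise NotImplementedError


# Hypothetical registry with the documented value type Callable[..., TransportSolver].
SOLVER_REGISTRY: Dict[str, Callable[..., TransportSolver]] = {}


def register(name: str, builder: Callable[..., TransportSolver]) -> None:
    """Map a config name to a builder; in this sketch a later call overwrites."""
    SOLVER_REGISTRY[name] = builder


def build(name: str, **params) -> TransportSolver:
    """Look up a builder by name and call it with the given parameters."""
    try:
        builder = SOLVER_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(SOLVER_REGISTRY))
        raise ValueError(f"unknown time solver {name!r}; registered: {known}") from None
    return builder(**params)


class FixedStepSolver(TransportSolver):
    """Hypothetical backend used only to exercise the registry."""

    def __init__(self, dt: float = 0.01):
        self.dt = dt


register("fixed_step", FixedStepSolver)
solver = build("fixed_step", dt=0.05)
```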
- NEOPAX._transport_solvers._expects_time_argument(vector_field: Callable) -> bool
Return True when the signature of vector_field starts with a time argument.
- NEOPAX._transport_solvers._as_state_residual(vector_field: Callable) -> Callable
Adapt a (t, y, …) vector field to a pure state residual f(y, …).
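One plausible way to detect a leading time argument is to inspect the first positional parameter, after which a state residual can be built by freezing t. This sketch assumes the parameter names `t` or `time` identify the time argument and that t is frozen at 0.0; neither detail is documented for `_expects_time_argument` or `_as_state_residual`, and both function names below are hypothetical.

```python
import inspect
from typing import Callable

TIME_NAMES = ("t", "time")  # assumption: how a leading time argument is recognized


def expects_time_argument(vector_field: Callable) -> bool:
    """True if the first positional parameter is named like a time argument."""
    params = [
        p for p in inspect.signature(vector_field).parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    return bool(params) and params[0].name in TIME_NAMES


def as_state_residual(vector_field: Callable, t_fixed: float = 0.0) -> Callable:
    """Wrap f(t, y, ...) as r(y, ...) by freezing t; pass through otherwise."""
    if not expects_time_argument(vector_field):
        return vector_field

    def residual(y, *args, **kwargs):
        return vector_field(t_fixed, y, *args, **kwargs)

    return residual


def decay(t, y, rate):
    """Example autonomous vector field dy/dt = -rate * y."""
    return -rate * y


r = as_state_residual(decay)
```

A time-independent residual r(y) is what a steady-state root-finding backend needs, since it solves r(y) = 0 without reference to t.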
- NEOPAX._transport_solvers._select_solver_family_and_backend(solver_parameters: Any) -> tuple[str, str]
Pick the active Diffrax backend while preserving legacy integrator-only configs.
- NEOPAX._transport_solvers._extract_species_from_args(args: tuple[Any, ...]) -> Any
- NEOPAX._transport_solvers._save_state_series(state: Any) -> Any
- NEOPAX._transport_solvers._save_scalar_series(value: Any) -> jax.Array
- NEOPAX._transport_solvers._strip_state_metadata_for_solver(state: Any) -> Any
Hook for state cleanup before solver calls.
- NEOPAX._transport_solvers._electron_density_index(species: Any) -> int | None
- NEOPAX._transport_solvers._extract_fixed_temperature_projection(vector_field: Callable) -> tuple[Any, Any]
- NEOPAX._transport_solvers._extract_state_regularization(vector_field: Callable) -> tuple[Any, Any]
- NEOPAX._transport_solvers._pack_transport_state_arrays(state: Any, species: Any = None) -> Any
Convert a TransportState-like object into an array-only pytree.
- NEOPAX._transport_solvers._unpack_transport_state_arrays(state_like: Any, template_state: Any, species: Any = None, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) -> Any
Rebuild a TransportState-like object from an array-only pytree.
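The pack/unpack pair separates the array data a solver can trace from metadata it cannot. The sketch below shows the idea on a hypothetical `ToyTransportState` with `density`, `pressure` and `Er` arrays and a string `label`; the real field names, and how `temperature_floor` and the other keyword arguments are applied, are not documented here.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass
class ToyTransportState:
    """Hypothetical stand-in for a TransportState: arrays plus metadata."""
    density: np.ndarray
    pressure: np.ndarray
    Er: np.ndarray
    label: str = "run-0"  # metadata the solver should not see


ARRAY_FIELDS = ("density", "pressure", "Er")


def pack_state_arrays(state):
    """Keep only the array fields, as a plain dict pytree."""
    return {name: np.asarray(getattr(state, name)) for name in ARRAY_FIELDS}


def unpack_state_arrays(arrays, template_state, density_floor=None):
    """Rebuild a state from arrays, copying metadata from template_state."""
    density = arrays["density"]
    if density_floor is not None:
        density = np.maximum(density, density_floor)  # optional positivity floor
    return replace(template_state, density=density,
                   pressure=arrays["pressure"], Er=arrays["Er"])


state = ToyTransportState(np.array([1.0, -0.5]), np.array([2.0, 3.0]),
                          np.zeros(2), label="demo")
packed = pack_state_arrays(state)
restored = unpack_state_arrays(packed, state, density_floor=1e-3)
```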
- NEOPAX._transport_solvers._restore_state_metadata(state_like: Any, template_state: Any) -> Any
Hook for restoring state metadata after solver calls.
- NEOPAX._transport_solvers._project_fixed_temperature_output(state_like: Any, reference_state: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) -> Any
- NEOPAX._transport_solvers._apply_quasi_neutrality_output(state_like: Any, species: Any, reference_state: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) -> Any
Apply quasi-neutrality to either a single state or a saved time series of states.
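Quasi-neutrality requires the electron density to equal the charge-weighted sum of ion densities, n_e = sum_i Z_i n_i. The sketch below enforces it on a `(n_species, n_radial)` density array; that array layout, the function names, and identifying electrons by charge -1 (standing in for `_electron_density_index`) are assumptions.

```python
import numpy as np


def electron_index(charges):
    """Index of the species with charge -1, or None (assumed electron convention)."""
    hits = np.flatnonzero(np.asarray(charges) == -1)
    return int(hits[0]) if hits.size else None


def project_quasi_neutral(densities, charges):
    """Replace the electron row with sum_i Z_i n_i over the ion rows.

    densities: shape (n_species, n_radial); charges: length n_species.
    """
    densities = np.array(densities, dtype=float, copy=True)
    charges = np.asarray(charges, dtype=float)
    e = electron_index(charges)
    if e is None:
        return densities  # no electron species: nothing to project
    ions = np.arange(len(charges)) != e
    densities[e] = charges[ions] @ densities[ions]
    return densities


projected = project_quasi_neutral(
    [[5.0, 5.0], [1.0, 2.0], [0.5, 0.5]],  # electrons, Z=1 ions, Z=2 ions
    [-1, 1, 2],
)
```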
- NEOPAX._transport_solvers._project_state_to_quasi_neutrality(state_like: Any, species: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) -> Any
- NEOPAX._transport_solvers._project_packed_transport_state_arrays(state_like: Any, template_state: Any, species: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None) -> Any
Project packed solver arrays without rebuilding a full TransportState.
- NEOPAX._transport_solvers._project_flat_state_if_needed(flat_y: jax.Array, project_flat: Callable[[jax.Array], jax.Array] | None) -> jax.Array
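Judging by its signature, `_project_flat_state_if_needed` applies an optional projection to a flat state vector and otherwise returns it unchanged. A sketch of that behavior, inferred from the type hints rather than the source, with a hypothetical non-negativity projection as the example:

```python
from typing import Callable, Optional

import numpy as np


def project_flat_state_if_needed(flat_y: np.ndarray,
                                 project_flat: Optional[Callable] = None) -> np.ndarray:
    """Apply project_flat to flat_y when given; identity when it is None."""
    return flat_y if project_flat is None else project_flat(flat_y)


def clip_negative(y):
    """Hypothetical projection: clamp every component at zero."""
    return np.maximum(y, 0.0)


y = np.array([-1.0, 2.0])
projected = project_flat_state_if_needed(y, clip_negative)
```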
- NEOPAX._transport_solvers._make_solver_state_transform(template_state: Any, species: Any, temperature_active_mask: Any = None, fixed_temperature_profile: Any = None, density_floor: Any = None, temperature_floor: Any = None)
- NEOPAX._transport_solvers._get_diffrax_integrator(name: str) -> Callable
- class NEOPAX._transport_solvers.TransportSolver
Base class for transport solvers. All solvers must implement the solve() method.
- abstractmethod solve(state, vector_field: Callable, *args, **kwargs) -> Any
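The base-class contract is a single `solve(state, vector_field, *args, **kwargs)` method. The sketch below shows a minimal backend honoring it; `TransportSolverSketch` stands in for `TransportSolver` (assumed here to be an `abc.ABC`), and the forward-Euler `ExplicitEulerSolver` is hypothetical, not one of the module's backends.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


class TransportSolverSketch(ABC):
    """Stand-in for NEOPAX's TransportSolver (assumed ABC-based)."""

    @abstractmethod
    def solve(self, state, vector_field: Callable, *args, **kwargs) -> Any:
        ...


class ExplicitEulerSolver(TransportSolverSketch):
    """Hypothetical fixed-step forward-Euler backend, for illustration only."""

    def __init__(self, t0: float = 0.0, t1: float = 1.0, dt: float = 0.01):
        self.t0, self.t1, self.dt = t0, t1, dt

    def solve(self, state, vector_field, *args, **kwargs):
        t, y = self.t0, state
        while t < self.t1 - 1e-12:
            h = min(self.dt, self.t1 - t)  # never step past t1
            y = y + h * vector_field(t, y, *args, **kwargs)
            t += h
        return y


y_end = ExplicitEulerSolver(0.0, 1.0, 0.25).solve(1.0, lambda t, y, k: -k * y, 1.0)
```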
- class NEOPAX._transport_solvers.DiffraxSolver(integrator, t0, t1, dt, save_n=None, **integrator_kwargs)
Bases: TransportSolver
Time solver that wraps a Diffrax integrator; build_time_solver() also uses it to wrap a diffrax solver instance passed as solver_override.
- integrator: Callable
- t0: float
- t1: float
- dt: float
- integrator_kwargs: dict[str, Any]
- solve(state, vector_field: Callable, *args, **kwargs)
- NEOPAX._transport_solvers._finalize_custom_solver_output(ys_saved_flat, ts_saved, dts_saved, accepted_mask_saved, failed_mask_saved, fail_codes_saved, y_final_flat, t_final, done_f, failed_f, fail_code_f, n_steps_f, last_attempt_accepted, last_attempt_converged, last_attempt_err_norm, last_attempt_fail_code, last_attempt_diverged, last_attempt_nonfinite_stage_state, last_attempt_nonfinite_stage_residual, last_attempt_finite_f0, last_attempt_finite_z0, last_attempt_finite_initial_residual, last_attempt_newton_iter_count, last_attempt_final_residual_norm, last_attempt_final_delta_norm, last_attempt_theta_final, last_attempt_slow_contraction, last_attempt_residual_blowup, last_attempt_newton_nonfinite, unpack_flat, reference_state, species, temperature_active_mask=None, fixed_temperature_profile=None, density_floor=None, temperature_floor=None)
- NEOPAX._transport_solvers._fill_saved_slots(save_idx, save_times, t_value, flat_y, dt_value, accepted, failed, fail_code, ys, ts, dts, accs, fails, codes)
- NEOPAX._transport_solvers._flat_rhs_factory(unravel, vector_field, args, kwargs, project_flat=None)
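Implicit integrators operate on one flat vector, so a vector field over a structured state has to be wrapped. The parameter name `unravel` suggests a `jax.flatten_util.ravel_pytree`-style flatten/unflatten pair; the sketch below uses a NumPy stand-in limited to flat dicts, and applying `project_flat` to the incoming state is an assumption.

```python
import numpy as np


def ravel_dict(tree):
    """Flatten a dict of arrays into a 1-D vector plus its inverse function.

    NumPy stand-in for jax.flatten_util.ravel_pytree, limited to flat dicts.
    """
    keys = sorted(tree)
    shapes = [np.shape(tree[k]) for k in keys]
    sizes = [int(np.prod(s)) for s in shapes]
    flat = np.concatenate([np.ravel(tree[k]) for k in keys]) if keys else np.zeros(0)

    def unravel(vec):
        out, start = {}, 0
        for k, shape, size in zip(keys, shapes, sizes):
            out[k] = np.reshape(vec[start:start + size], shape)
            start += size
        return out

    return flat, unravel


def flat_rhs_factory(unravel, vector_field, args, kwargs, project_flat=None):
    """Build rhs(t, flat_y) -> flat dy/dt from a structured vector field."""

    def rhs(t, flat_y):
        if project_flat is not None:
            flat_y = project_flat(flat_y)
        dstate = vector_field(t, unravel(flat_y), *args, **kwargs)
        return ravel_dict(dstate)[0]  # same sorted key order as the state

    return rhs


state = {"density": np.array([1.0, 2.0]), "Er": np.array(0.5)}
flat, unravel = ravel_dict(state)
rhs = flat_rhs_factory(unravel, lambda t, s, k: {key: -k * v for key, v in s.items()},
                       (2.0,), {})
```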
- NEOPAX._transport_solvers._lagged_response_hooks(vector_field: Callable)
- NEOPAX._transport_solvers._flat_rhs_with_lagged_response_factory(unravel, vector_field, args, kwargs, project_flat=None)
- NEOPAX._transport_solvers._solver_error_norm(err_vec, flat_ref, flat_candidate, atol: float, rtol: float, scale_mode: str = 'max', rtol_eff=None, scale_override=None)
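The arguments `atol`, `rtol` and `scale_mode='max'` point to the weighted RMS norm common in stiff integrators, where each error component is scaled by atol + rtol * max(|y_ref|, |y_candidate|). The sketch below implements only that mode and omits `rtol_eff` and `scale_override`; it is not the module's code.

```python
import numpy as np


def solver_error_norm(err_vec, flat_ref, flat_candidate, atol, rtol, scale_mode="max"):
    """Scaled RMS norm of a local error estimate (sketch).

    Each component is divided by atol + rtol * max(|y_ref|, |y_candidate|);
    a result <= 1 means the error meets the requested tolerances.
    """
    if scale_mode != "max":
        raise NotImplementedError("only the documented default 'max' is sketched")
    size = np.maximum(np.abs(np.asarray(flat_ref)), np.abs(np.asarray(flat_candidate)))
    scale = atol + rtol * size
    ratio = np.asarray(err_vec) / scale
    return float(np.sqrt(np.mean(ratio ** 2)))


norm = solver_error_norm([1e-6, 1e-6], [1.0, 1.0], [1.0, -1.0], atol=0.0, rtol=1e-6)
```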
- NEOPAX._transport_solvers._lagged_response_global_reuse_metric(current_flat, reference_flat, atol: float, rtol: float)
- NEOPAX._transport_solvers._make_radau_initial_step_state(t0, flat_state0, base_dt, dtype, initial_rhs, num_stages, real_lu0, real_piv0, complex_lu0, complex_piv0, lagged_response_cache, lagged_response_valid, lagged_reference_y)
- NEOPAX._transport_solvers._custom_loop_active(step_state, t_final, step_idx, max_total_steps)
- NEOPAX._transport_solvers._accepted_step_limit_reached(step_state, stop_after_accepted_steps)
- NEOPAX._transport_solvers._run_saved_loop(*, step_state0, step_fn, save_n, t0, t_final, state_dim, dtype, max_total_steps, stop_after_accepted_steps=None)
- NEOPAX._transport_solvers._run_saved_loop_debug_walltime(*, step_state0, step_fn, save_n, t0, t_final, state_dim, dtype, max_total_steps, stop_after_accepted_steps=None, walltime_label='solver.attempt')
- NEOPAX._transport_solvers._make_radau_stage_predictor(f0, prev_stages, prev_dt, h_value, c, dtype, density_size=0, pressure_size=0, er_size=0, prev_theta_final=None, prev_newton_iter_count=None, predictor_mode='current')
- NEOPAX._transport_solvers._apply_radau_lean_timestep_controller(*, step_state, trial_dt, trial_y, err_norm, density_err_norm, pressure_err_norm, er_err_norm, converged, stage_history, jacobian_out, cache_valid_out, cache_dt_out, cache_age_out, real_lu_out, real_piv_out, complex_lu_out, complex_piv_out, newton_shrink, diverged_final, nonfinite_stage_state, nonfinite_stage_residual, finite_f0, finite_z0, finite_initial_residual, newton_iter_count, final_residual_norm, final_delta_norm, theta_final, slow_contraction, residual_blowup, newton_nonfinite, lagged_reused, jacobian_reused, fail_code, n_accepted, dtype, dt_min, dt_max, safety_factor, controller_alpha, min_step_factor, max_step_factor, controller_mode, use_transport_lagged_response, lagged_response_reuse_mode, lagged_response_reuse_rtol, lagged_response_reuse_atol, project_flat)
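Several controller inputs (`safety_factor`, `min_step_factor`, `max_step_factor`, `dt_min`, `dt_max`) match an elementary error-based step-size rule. The sketch below shows that rule with the `_RadauSolverConfig` defaults; the exponent 1/(q+1) with embedded order q = 3 is an assumption, and the module's extra logic (`controller_alpha`, `controller_mode`, rejection cooldowns, Jacobian and LU reuse) is not modeled.

```python
def propose_step(dt, err_norm, *, embedded_order=3, safety_factor=0.9,
                 min_step_factor=0.1, max_step_factor=5.0,
                 dt_min=1e-14, dt_max=1.0):
    """Return (accepted, next_dt) for an elementary error-based controller (sketch).

    factor = safety * err^(-1/(q+1)), clipped to [min_step_factor, max_step_factor],
    then the new step is clipped to [dt_min, dt_max].
    """
    accepted = err_norm <= 1.0
    err = max(err_norm, 1e-10)  # avoid raising 0 to a negative power
    factor = safety_factor * err ** (-1.0 / (embedded_order + 1))
    factor = min(max_step_factor, max(min_step_factor, factor))
    next_dt = min(dt_max, max(dt_min, dt * factor))
    return accepted, next_dt


accepted, next_dt = propose_step(0.01, 16.0)  # rejected: error 16x too large
```

A rejected step is retried with `next_dt`; an accepted one continues with it, so the same rule serves both cases.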
- class NEOPAX._transport_solvers._RadauSolverConfig(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, rtol: float = 1e-06, atol: float = 1e-08, max_step: float = 1.0, min_step: float = 1e-14, tol: float = 1e-08, maxiter: int = 20, error_estimator: str = 'embedded2', num_stages: int = 3, safety_factor: float = 0.9, min_step_factor: float = 0.1, max_step_factor: float = 5.0, jacobian_reuse_rtol: float = 0.1, max_jacobian_age: int = 8, rhs_mode: str = 'black_box', newton_divergence_mode: str = 'legacy', newton_residual_norm: str = 'raw', newton_tol_mode: str = 'residual', newton_fnewt_mode: str = 'tol', controller_mode: str = 'current', predictor_mode: str = 'current', lagged_response_reuse_mode: str = 'retry_only', lagged_response_reuse_rtol: float = 0.05, lagged_response_reuse_atol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, debug_stage_markers: bool = False, debug_walltime_attempts: bool = False, save_n=None)
Bases: TransportSolver
Configuration fields for RADAUSolver, which inherits them.
- t0: float
- t1: float
- dt: float
- rtol: float = 1e-06
- atol: float = 1e-08
- max_step: float = 1.0
- min_step: float = 1e-14
- tol: float = 1e-08
- maxiter: int = 20
- error_estimator: str = 'embedded2'
- num_stages: int = 3
- safety_factor: float = 0.9
- min_step_factor: float = 0.1
- max_step_factor: float = 5.0
- jacobian_reuse_rtol: float = 0.1
- max_jacobian_age: int = 8
- rhs_mode: str = 'black_box'
- newton_divergence_mode: str = 'legacy'
- newton_residual_norm: str = 'raw'
- newton_tol_mode: str = 'residual'
- newton_fnewt_mode: str = 'tol'
- controller_mode: str = 'current'
- predictor_mode: str = 'current'
- lagged_response_reuse_mode: str = 'retry_only'
- lagged_response_reuse_rtol: float = 0.05
- lagged_response_reuse_atol: float = 1e-08
- max_steps: int = 20000
- stop_after_accepted_steps: int | None = None
- n_steps: int = 0
- debug_stage_markers: bool = False
- debug_walltime_attempts: bool = False
- class NEOPAX._transport_solvers._RadauStepState
- t: Any
- y: Any
- dt: Any
- status: Any
- prev_error: Any
- prev_stages: Any
- prev_dt: Any
- recent_reject_count: Any
- regrowth_cooldown: Any
- easy_growth_streak: Any
- lagged_response_cache: Any
- lagged_response_valid: Any
- lagged_reference_y: Any
- jacobian: Any
- cache_valid: Any
- cache_dt: Any
- cache_age: Any
- real_lu: Any
- real_piv: Any
- complex_lu: Any
- complex_piv: Any
- prev_theta_final: Any
- prev_newton_iter_count: Any
- class NEOPAX._transport_solvers._RadauStepInfo
- y: Any
- t: Any
- dt: Any
- next_dt: Any = None
- growth: Any = None
- lagged_reused: Any = None
- jacobian_reused: Any = None
- accepted: Any = None
- failed: Any = None
- fail_code: Any = None
- converged: Any = None
- err_norm: Any = None
- diverged: Any = None
- nonfinite_stage_state: Any = None
- nonfinite_stage_residual: Any = None
- finite_f0: Any = None
- finite_z0: Any = None
- finite_initial_residual: Any = None
- newton_iter_count: Any = None
- final_residual_norm: Any = None
- final_delta_norm: Any = None
- theta_final: Any = None
- slow_contraction: Any = None
- residual_blowup: Any = None
- newton_nonfinite: Any = None
- class NEOPAX._transport_solvers.RADAUSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, rtol: float = 1e-06, atol: float = 1e-08, max_step: float = 1.0, min_step: float = 1e-14, tol: float = 1e-08, maxiter: int = 20, error_estimator: str = 'embedded2', num_stages: int = 3, safety_factor: float = 0.9, min_step_factor: float = 0.1, max_step_factor: float = 5.0, jacobian_reuse_rtol: float = 0.1, max_jacobian_age: int = 8, rhs_mode: str = 'black_box', newton_divergence_mode: str = 'legacy', newton_residual_norm: str = 'raw', newton_tol_mode: str = 'residual', newton_fnewt_mode: str = 'tol', controller_mode: str = 'current', predictor_mode: str = 'current', lagged_response_reuse_mode: str = 'retry_only', lagged_response_reuse_rtol: float = 0.05, lagged_response_reuse_atol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, debug_stage_markers: bool = False, debug_walltime_attempts: bool = False, save_n=None)
Bases: _RadauSolverConfig
Radau IIA implicit time solver with embedded error estimation and adaptive step-size control, configured by the fields inherited from _RadauSolverConfig.
- solve(state, vector_field: Callable, *args, **kwargs)
- class NEOPAX._transport_solvers._ThetaSolverConfig(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)
Bases: TransportSolver
Configuration fields for ThetaMethodSolver; _ThetaNewtonSolverConfig extends them.
- t0: float
- t1: float
- dt: float
- min_step: float = 1e-14
- theta_implicit: float = 1.0
- predictor_mode: str = 'linearized'
- rhs_mode: str = 'black_box'
- use_predictor_corrector: bool = False
- n_corrector_steps: int = 1
- tol: float = 1e-08
- max_steps: int = 20000
- stop_after_accepted_steps: int | None = None
- n_steps: int = 0
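A theta method advances y' = f(y) with y_{n+1} = y_n + dt [theta f(y_{n+1}) + (1 - theta) f(y_n)]. Assuming `theta_implicit` plays the role of theta here (so the default 1.0 is backward Euler), the linear case f(y) = A y reduces each step to one linear solve, as in this sketch; `theta_step_linear` is a hypothetical name.

```python
import numpy as np


def theta_step_linear(a, y, dt, theta=1.0):
    """One theta-method step for dy/dt = A y.

    Solves (I - theta*dt*A) y_next = (I + (1 - theta)*dt*A) y.
    theta=1: backward Euler; theta=0.5: Crank-Nicolson; theta=0: forward Euler.
    """
    a = np.atleast_2d(np.asarray(a, dtype=float))
    y = np.atleast_1d(np.asarray(y, dtype=float))
    eye = np.eye(a.shape[0])
    lhs = eye - theta * dt * a
    rhs = (eye + (1.0 - theta) * dt * a) @ y
    return np.linalg.solve(lhs, rhs)


y_be = theta_step_linear(-1.0, 1.0, 0.1, theta=1.0)  # 1 / 1.1
y_cn = theta_step_linear(-1.0, 1.0, 0.1, theta=0.5)  # 0.95 / 1.05
```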
- class NEOPAX._transport_solvers._ThetaNewtonSolverConfig(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, maxiter: int = 20, max_step: float | None = None, safety_factor: float = 0.9, min_step_factor: float = 0.5, max_step_factor: float = 2.0, target_nonlinear_iterations: int = 4, delta_reduction_factor: float = 0.5, tau_min: float = 0.01, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)
Bases: _ThetaSolverConfig
Extends _ThetaSolverConfig with Newton-iteration and step-size control fields for NewtonThetaMethodSolver.
- maxiter: int = 20
- max_step: float = 1.0
- safety_factor: float = 0.9
- min_step_factor: float = 0.5
- max_step_factor: float = 2.0
- target_nonlinear_iterations: int = 4
- delta_reduction_factor: float = 0.5
- tau_min: float = 0.01
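For nonlinear f, the implicit theta equation needs an iterative solve. The sketch below uses plain Newton iterations and pairs them with a hypothetical step-size rule driven by the iteration count, using the `_ThetaNewtonSolverConfig` defaults (`maxiter=20`, `target_nonlinear_iterations=4`, `safety_factor=0.9`, step factors 0.5 and 2.0). The rule, the stopping test on the Newton update, and both function names are assumptions, not the module's algorithm.

```python
import numpy as np


def theta_newton_step(f, jac, y, dt, theta=1.0, tol=1e-8, maxiter=20):
    """Solve z = y + dt*(theta*f(z) + (1-theta)*f(y)) for z by Newton's method.

    Returns (z, iterations, converged); jac(z) must return df/dz.
    """
    y = np.atleast_1d(np.asarray(y, dtype=float))
    explicit = y + (1.0 - theta) * dt * f(y)   # part fixed by the old state
    z = y.copy()                               # predictor: previous state
    eye = np.eye(y.size)
    for k in range(1, maxiter + 1):
        residual = z - explicit - theta * dt * f(z)
        delta = np.linalg.solve(eye - theta * dt * jac(z), -residual)
        z = z + delta
        if np.linalg.norm(delta) <= tol:
            return z, k, True
    return z, maxiter, False


def next_dt_from_iterations(dt, iterations, target=4, safety=0.9,
                            min_factor=0.5, max_factor=2.0):
    """Hypothetical rule: grow dt when Newton needs fewer than target iterations."""
    factor = safety * target / max(iterations, 1)
    return dt * min(max_factor, max(min_factor, factor))


# Backward-Euler step for y' = -y^2 from y = 1 with dt = 0.1.
z, iters, ok = theta_newton_step(lambda u: -u ** 2, lambda u: np.diag(-2.0 * u),
                                 1.0, 0.1)
```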
- class NEOPAX._transport_solvers._ThetaStepInfo
- y: Any
- t: Any
- dt: Any
- accepted: Any
- failed: Any
- fail_code: Any
- class NEOPAX._transport_solvers.ThetaMethodSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)
Bases: _ThetaSolverConfig
Theta-method time solver configured by the fields inherited from _ThetaSolverConfig.
- solve(state, vector_field: Callable, *args, **kwargs)
- class NEOPAX._transport_solvers.NewtonThetaMethodSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, maxiter: int = 20, max_step: float | None = None, safety_factor: float = 0.9, min_step_factor: float = 0.5, max_step_factor: float = 2.0, target_nonlinear_iterations: int = 4, delta_reduction_factor: float = 0.5, tau_min: float = 0.01, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)
Bases: _ThetaNewtonSolverConfig
Theta-method time solver that solves each implicit step with Newton iterations, configured by the fields inherited from _ThetaNewtonSolverConfig.
- solve(state, vector_field: Callable, *args, **kwargs)
- NEOPAX._transport_solvers.build_time_solver(solver_parameters: Any, solver_override: Any = None) -> TransportSolver
Create a time solver backend from runtime parameters/config.
- solver_override can be either:
  - an instance with .solve(…) (used directly), or
  - a diffrax solver instance (wrapped in DiffraxSolver).
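The `solver_override` rule can be sketched as a small dispatch function. The duck-typed `.solve` check, `resolve_solver`, and the `WrappedDiffraxLike` stand-in for `DiffraxSolver` are hypothetical; the module may instead test for a Diffrax solver type explicitly.

```python
class WrappedDiffraxLike:
    """Hypothetical wrapper standing in for DiffraxSolver."""

    def __init__(self, integrator):
        self.integrator = integrator

    def solve(self, state, vector_field, *args, **kwargs):
        raise NotImplementedError("illustration only")


def resolve_solver(solver_override, default_factory, wrap=WrappedDiffraxLike):
    """Sketch of the override rule for build_time_solver.

    Objects exposing a callable .solve are used as-is; anything else
    (e.g. a diffrax solver instance) is wrapped; None falls back to the
    solver built from the configuration.
    """
    if solver_override is None:
        return default_factory()
    if callable(getattr(solver_override, "solve", None)):
        return solver_override
    return wrap(solver_override)


class ReadySolver:
    """Hypothetical user-supplied solver that already has .solve(...)."""

    def solve(self, state, vector_field, *args, **kwargs):
        return state


ready = ReadySolver()
chosen = resolve_solver(ready, lambda: "configured default")
```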