NEOPAX

Public package API for NEOPAX.

The package root intentionally exposes a curated compatibility surface instead of re-exporting every internal helper via wildcard imports.


Attributes

__version__

__version_tuple__

version

version_tuple

Classes

RunResult

Structured return object for direct API runs.

Monoenergetic

Monoenergetic database.

Solver_Parameters

Species

JAX-compatible species container for an arbitrary number of species.

BoundaryConditionModel

Left/right radial BC for 1D arrays with optional species-wise values.

DirichletBC

NeumannBC

RobinBC

ModelCapabilities

ModelValidationContext

CombinedSourceModel

Composite of several non-conservative source models.

TransportState

JAX-compatible transport state for an arbitrary number of species.

ComposedEquationSystem

DensityEquation

Transport equation for species densities.

ElectricFieldEquation

Transport equation for the radial electric field Er.

TemperatureEquation

Transport equation for species temperatures.

AnalyticalTurbulentTransportModel

Analytical diffusive turbulent transport flux model.

CombinedTransportFluxModel

Composite of neoclassical, turbulent, and classical transport flux models.

FluxesRFileTransportModel

Transport flux model driven by flux profiles read from file.

NTXExactLijRuntimeSupport

NTXExactLijRuntimeTransportModel

NTX-based transport flux model that computes L_ij coefficients at runtime.

NTXRuntimeScanChannels

NTXRuntimeScanTransportModel

NTX-based transport flux model built from a runtime scan over rho_scan, nu_v_scan, and er_tilde_scan.

PowerAnalyticalTurbulentTransportModel

Analytical power-scaled turbulent transport flux model.

ZeroTransportModel

Transport flux model that returns zero fluxes.

DiffraxSolver

Transport solver that wraps a diffrax integrator.

NewtonThetaMethodSolver

Theta-method transport solver with Newton iterations.

RADAUSolver

Implicit Radau transport solver with adaptive time stepping.

ThetaMethodSolver

Theta-method transport solver.

Functions

prepare_config(→ dict[str, Any])

Load and override a NEOPAX config without executing it.

run(→ RunResult)

Convenience entry point for direct NEOPAX execution.

get_Neoclassical_Fluxes(species, energy_grid, ...[, ...])

get_Neoclassical_Fluxes_Faces(species, energy_grid, ...)

get_Neoclassical_Fluxes_With_Momentum_Correction(...)

make_validation_context(→ ModelValidationContext)

Build a small default validation context for user model registration.

build_source_models_from_config(→ dict[str, ...)

Build density/temperature source callables from TOML-style config.

get_source_model(→ SourceModelBase)

register_source_model(→ None)

source_model(name, **register_kwargs)

get_Turbulent_Fluxes_Analytical(species, grid, ...[, ...])

Analytical diffusive turbulent flux model.

get_Turbulent_Fluxes_PowerOverN(species, ...[, ...])

Analytical power-scaled turbulent transport with coefficients ~ P^0.75 / N_e.

build_equation_system(config, species, field, flux_model)

Build the list of equation instances to evolve using prebuilt runtime objects.

build_equation_system_from_config(config, species)

Backward-compatible wrapper that builds the required runtime objects from config before delegating to build_equation_system.

build_ntx_exact_lij_runtime_support(...)

build_ntx_exact_lij_runtime_transport_model(species, ...)

build_ntx_runtime_scan_channels(→ NTXRuntimeScanChannels)

build_fluxes_r_file_transport_model(species, geometry, *)

build_ntx_runtime_scan_transport_model(species, ...[, ...])

build_transport_flux_model(→ CombinedTransportFluxModel)

Build the composed transport model from explicit model instances.

get_transport_flux_model_capabilities(...)

get_transport_flux_model(→ Callable[Ellipsis, ...)

register_transport_flux_model(→ None)

transport_flux_model(name, **register_kwargs)

build_time_solver(→ TransportSolver)

Create a time solver backend from runtime parameters/config.

load_config(*args, **kwargs)

run_config(*args, **kwargs)

run_config_path(*args, **kwargs)

Package Contents

class NEOPAX.RunResult

Structured return object for direct API runs.

mode: str
config: dict[str, Any]
raw_result: Any
final_state: Any = None
saved_states: Any = None
time_grid: Any = None
saved_step_sizes: Any = None
accepted_mask: Any = None
failed_mask: Any = None
fail_codes: Any = None
n_steps: Any = None
done: Any = None
failed: Any = None
fail_code: Any = None
final_time: Any = None
rho: Any = None
output_dir: pathlib.Path | None = None
NEOPAX.prepare_config(config_or_path: dict[str, Any] | str | pathlib.Path, *, mode: str | None = None, device: str | None = None, vmec_file: str | None = None, boozer_file: str | None = None, n_radial: int | None = None, n_x: int | None = None, backend: str | None = None, dt: float | None = None, t_final: float | None = None, output_dir: str | None = None, set_values: list[str] | None = None) → dict[str, Any]

Load and override a NEOPAX config without executing it.

NEOPAX.run(config_or_path: dict[str, Any] | str | pathlib.Path, *, mode: str | None = None, device: str | None = None, vmec_file: str | None = None, boozer_file: str | None = None, n_radial: int | None = None, n_x: int | None = None, backend: str | None = None, dt: float | None = None, t_final: float | None = None, output_dir: str | None = None, set_values: list[str] | None = None) → RunResult

Convenience entry point for direct NEOPAX execution.

This keeps the Python API explicit and usable from scripts or larger JAX workflows, while sharing the same override mapping as the CLI.

NEOPAX.__version__: str
NEOPAX.__version_tuple__: tuple[int | str, ...]
NEOPAX.version: str
NEOPAX.version_tuple: tuple[int | str, ...]
class NEOPAX.Monoenergetic(a_b: float, rho: jaxtyping.Float[jaxtyping.Array, ...], nu_log: jaxtyping.Float[jaxtyping.Array, ...], Er_list: jaxtyping.Float[jaxtyping.Array, ...], D11_log: jaxtyping.Float[jaxtyping.Array, ...], D13: jaxtyping.Float[jaxtyping.Array, ...], D33: jaxtyping.Float[jaxtyping.Array, ...], **kwargs)

Monoenergetic database.

a_b: float
D11_lower_limit: float
Er_lower_limit: float
Er_lower_limit_log: float
low_limit_r: float
r1_lim: float
rmn2_lim: float
r1: float
r2: float
r3: float
rnm3: float
rnm2: float
rnm1: float
rho: jaxtyping.Float[jaxtyping.Array, ...]
nu_log: jaxtyping.Float[jaxtyping.Array, ...]
Er_list: jaxtyping.Float[jaxtyping.Array, ...]
D11_log: jaxtyping.Float[jaxtyping.Array, ...]
D13: jaxtyping.Float[jaxtyping.Array, ...]
D33: jaxtyping.Float[jaxtyping.Array, ...]
classmethod read_ntx(a_b, ntx_file)

Construct a Monoenergetic database from an NTX file.

Parameters:

ntx_file (path-like) – Path to the NTX file.

NEOPAX.get_Neoclassical_Fluxes(species, energy_grid, geometry, database, Er, temperature, density, density_right_constraint=None, density_right_grad_constraint=None, temperature_right_constraint=None, temperature_right_grad_constraint=None, collisionality_model='default')
NEOPAX.get_Neoclassical_Fluxes_Faces(species, energy_grid, geometry, database, Er_faces, temperature_faces, density_faces, dndr_faces, dTdr_faces, collisionality_model='default')
NEOPAX.get_Neoclassical_Fluxes_With_Momentum_Correction(species, grid, field, database, Er, temperature, density, density_right_constraint=None, density_right_grad_constraint=None, temperature_right_constraint=None, temperature_right_grad_constraint=None)
class NEOPAX.Solver_Parameters(t0=None, t_final=None, dt=None, ts_list=None, rtol=None, atol=None, momentum_correction_flag=None, DEr=None, Er_relax=None, er_mode=None, use_ap_er_preconditioner=None, evolve_Er=None, evolve_density=None, evolve_temperature=None, density_floor=None, temperature_floor=None, integrator=None, transport_solver_family=None, transport_solver_backend=None, nonlinear_solver_tol=None, nonlinear_solver_maxiter=None, anderson_history=None, theta_implicit=None, use_predictor_corrector=None, n_corrector_steps=None, theta_ptc_enabled=None, theta_ptc_dt_min_factor=None, theta_ptc_dt_max_factor=None, theta_ptc_growth=None, theta_ptc_shrink=None, theta_line_search_enabled=None, theta_line_search_contraction=None, theta_line_search_min_alpha=None, theta_line_search_c=None, theta_max_step_retries=None, theta_linear_solver=None, theta_gmres_tol=None, theta_gmres_maxiter=None, theta_trust_region_enabled=None, theta_trust_radius=None, theta_homotopy_steps=None, theta_differentiable_mode=None, er_ambipolar_scan_min=None, er_ambipolar_scan_max=None, er_ambipolar_n_scan=None, er_ambipolar_tol=None, er_ambipolar_maxiter=None, er_ambipolar_n_coarse=None, er_ambipolar_n_fine=None, er_ambipolar_method=None, neoclassical_transport_model=None, turbulent_transport_model=None, chi_temperature=None, chi_density=None, on_OmegaC=None)
momentum_correction_flag: int
integrator: str
transport_solver_family: str
transport_solver_backend: str
nonlinear_solver_tol: float
nonlinear_solver_maxiter: int
anderson_history: int
theta_implicit: float
use_predictor_corrector: bool
n_corrector_steps: int
theta_ptc_enabled: bool
theta_ptc_dt_min_factor: float
theta_ptc_dt_max_factor: float
theta_ptc_growth: float
theta_ptc_shrink: float
theta_line_search_enabled: bool
theta_line_search_contraction: float
theta_line_search_min_alpha: float
theta_line_search_c: float
theta_max_step_retries: int
theta_linear_solver: str
theta_gmres_tol: float
theta_gmres_maxiter: int
theta_trust_region_enabled: bool
theta_trust_radius: float
theta_homotopy_steps: int
theta_differentiable_mode: bool
er_ambipolar_scan_min: float
er_ambipolar_scan_max: float
er_ambipolar_n_scan: int
er_ambipolar_tol: float
er_ambipolar_maxiter: int
er_ambipolar_n_coarse: int
er_ambipolar_n_fine: int
er_ambipolar_method: str
neoclassical_transport_model: str
turbulent_transport_model: str
t0: float
t_final: float
dt: float
ts_list: jaxtyping.Float[jaxtyping.Array, ...]
rtol: float
atol: float
DEr: float
Er_relax: float
er_mode: str
use_ap_er_preconditioner: bool
on_OmegaC: float
evolve_Er: bool
evolve_density: jaxtyping.Float[jaxtyping.Array, ...]
evolve_temperature: jaxtyping.Float[jaxtyping.Array, ...]
density_floor: float | jaxtyping.Array | None
temperature_floor: float | jaxtyping.Array | None
chi_temperature: jaxtyping.Float[jaxtyping.Array, ...]
chi_density: jaxtyping.Float[jaxtyping.Array, ...]
class NEOPAX.Species

JAX-compatible species container for an arbitrary number of species. Its array-valued fields are JAX arrays, so the container supports differentiation and vmap.

number_species: int
species_indices: jaxtyping.Array
mass_mp: jaxtyping.Float[jaxtyping.Array, ...]
charge_qp: jaxtyping.Float[jaxtyping.Array, ...]
names: tuple[str, ...] = ()
is_frozen: jaxtyping.Array = None
property charge
property mass
property species_idx: dict

Return a mapping from species name to index.

property ion_indices: tuple

Return a tuple of all species indices except the electron (‘e’).
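As a sketch of the documented behaviour of species_idx and ion_indices, assuming names holds one label per species and the electron is labelled 'e'. SpeciesSketch is a hypothetical stand-in, not NEOPAX code:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpeciesSketch:
    """Hypothetical stand-in for NEOPAX.Species (names only)."""

    names: tuple[str, ...]

    @property
    def species_idx(self) -> dict[str, int]:
        # Map each species name to its position in `names`.
        return {name: i for i, name in enumerate(self.names)}

    @property
    def ion_indices(self) -> tuple[int, ...]:
        # Every species index except the electron, assumed to be named "e".
        return tuple(i for i, name in enumerate(self.names) if name != "e")


species = SpeciesSketch(names=("e", "D", "T"))
print(species.species_idx)  # {'e': 0, 'D': 1, 'T': 2}
print(species.ion_indices)  # (1, 2)
```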

class NEOPAX.BoundaryConditionModel

Left/right radial BC for 1D arrays with optional species-wise values.

dr: float
left_type: str = 'dirichlet'
right_type: str = 'dirichlet'
left_value: jax.numpy.ndarray | None = None
right_value: jax.numpy.ndarray | None = None
left_gradient: jax.numpy.ndarray | None = None
right_gradient: jax.numpy.ndarray | None = None
left_decay_length: jax.numpy.ndarray | None = None
right_decay_length: jax.numpy.ndarray | None = None
reference_profile: jax.numpy.ndarray | None = None
reference_profiles: jax.numpy.ndarray | None = None
static _as_jnp_or_none(value, species_names=None)
static _pick_for_row(value, row_index: int)
_infer_gradient(arr: jax.numpy.ndarray) → tuple[jax.numpy.ndarray, jax.numpy.ndarray]
_infer_log_gradient_coeff(boundary_value: jax.numpy.ndarray, boundary_grad: jax.numpy.ndarray) → jax.numpy.ndarray
_infer_decay_length(boundary_value: jax.numpy.ndarray, boundary_grad: jax.numpy.ndarray) → jax.numpy.ndarray
_apply_ghost_row(arr: jax.numpy.ndarray, row_index: int, reference_row: jax.numpy.ndarray | None) → jax.numpy.ndarray
apply_ghost(arr: jax.numpy.ndarray) → jax.numpy.ndarray
apply_ghost_all(arr2d: jax.numpy.ndarray) → jax.numpy.ndarray
class NEOPAX.DirichletBC(axis_value, edge_value)

Bases: BoundaryCondition

axis_value
edge_value
apply(flux, state=None, axis=True, edge=True, **kwargs)
class NEOPAX.NeumannBC(grad_axis, grad_edge, dr)

Bases: BoundaryCondition

grad_axis
grad_edge
dr
apply(flux, state=None, axis=True, edge=True, **kwargs)
class NEOPAX.RobinBC(alpha_axis, beta_axis, gamma_axis, alpha_edge, beta_edge, gamma_edge, dr)

Bases: BoundaryCondition

alpha_axis
beta_axis
gamma_axis
alpha_edge
beta_edge
gamma_edge
dr
apply(flux, state, axis=True, edge=True, **kwargs)
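DirichletBC, NeumannBC and RobinBC are documented here only by their parameters. As a generic illustration, the sketch below reads the Robin parameters as alpha*u + beta*du/dr = gamma and shows how Dirichlet and Neumann conditions fall out as special cases. Both the meaning of alpha/beta/gamma and the ghost-cell discretisation are assumptions, not the NEOPAX implementation, which applies conditions to fluxes:

```python
# Assumed convention (not taken from NEOPAX): cell-centred 1D grid whose left
# boundary face sits halfway between a ghost cell u_g and the first interior
# cell u_0, with centre spacing dr, so that
#   u_face ~ (u_g + u_0) / 2   and   du/dr|_face ~ (u_0 - u_g) / dr.


def ghost_robin_left(u0: float, alpha: float, beta: float, gamma: float, dr: float) -> float:
    """Ghost value enforcing alpha*u + beta*du/dr = gamma at the left face."""
    denom = 0.5 * alpha - beta / dr
    if denom == 0.0:
        raise ValueError("alpha/2 == beta/dr leaves the ghost value undetermined")
    return (gamma - u0 * (0.5 * alpha + beta / dr)) / denom


def ghost_dirichlet_left(u0: float, value: float, dr: float) -> float:
    # Dirichlet is the Robin case alpha=1, beta=0, gamma=value.
    return ghost_robin_left(u0, 1.0, 0.0, value, dr)


def ghost_neumann_left(u0: float, grad: float, dr: float) -> float:
    # Neumann is the Robin case alpha=0, beta=1, gamma=grad.
    return ghost_robin_left(u0, 0.0, 1.0, grad, dr)


print(ghost_dirichlet_left(u0=3.0, value=2.0, dr=0.1))  # 1.0  (= 2*2 - 3)
print(ghost_neumann_left(u0=3.0, grad=-5.0, dr=0.1))    # 3.5  (= 3 - (-5)*0.1)
```

The right boundary follows the same algebra with the sign of the one-sided gradient reversed.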
class NEOPAX.ModelCapabilities
jit_safe: bool = True
autodiff_safe: bool = False
vmap_safe: bool = False
local_evaluator: bool = False
face_fluxes: bool = False
class NEOPAX.ModelValidationContext
builder_kwargs: dict[str, Any]
state: Any
species: Any | None = None
geometry: Any | None = None
face_state: Any | None = None
NEOPAX.make_validation_context(*, builder_kwargs: dict[str, Any] | None = None, n_species: int = 2, n_radial: int = 8, density_value: float = 1.0, temperature_value: float = 2.0, er_value: float = 0.0, species: Any | None = None, geometry: Any | None = None, include_face_state: bool = True) → ModelValidationContext

Build a small default validation context for user model registration.

This helper is intentionally lightweight and avoids depending on geometry builders or larger runtime objects. It is meant for registration-time smoke tests, not for physical validation.

class NEOPAX.CombinedSourceModel

Bases: SourceModelBase

Composite of several non-conservative source models, stored in sources.

sources: tuple[SourceModelBase, ...] = ()
classmethod from_names(names: list[str] | tuple[str, ...], *, params_cfg: dict[str, Any] | None = None, species: Any | None = None, cfg: dict[str, Any] | None = None) → CombinedSourceModel
with_added_sources(*sources: SourceModelBase) → CombinedSourceModel
with_replaced_sources(sources: tuple[SourceModelBase, ...]) → CombinedSourceModel
__call__(state: Any)
build_lagged_response(state: Any, **kwargs)
evaluate_with_lagged_response(state: Any, lagged_response: Any, **kwargs)
NEOPAX.build_source_models_from_config(cfg: dict[str, Any], species: Any | None = None) → dict[str, SourceModelBase] | None

Build density/temperature source callables from TOML-style config.

Supported schema:

    [sources]
    density = ["name1", "name2"]
    temperature = ["name3"]

    [sources.parameters]
    name3 = {some_kw = 1.0}

NEOPAX.get_source_model(name: str, **kwargs) → SourceModelBase
NEOPAX.register_source_model(name: str, builder: Callable[..., SourceModelBase], *, validate: bool = False, validation_context: NEOPAX._model_api.ModelValidationContext | None = None) → None
NEOPAX.source_model(name: str, **register_kwargs)
class NEOPAX.TransportState

JAX-compatible transport state for an arbitrary number of species. All fields are JAX arrays for differentiability and vmap support.

density: jaxtyping.Float[jaxtyping.Array, ...]
pressure: jaxtyping.Float[jaxtyping.Array, ...]
Er: jaxtyping.Float[jaxtyping.Array, ...]
property temperature
NEOPAX.get_Turbulent_Fluxes_Analytical(species, grid, chi_temperature, chi_density, temperature, density, field, density_right_constraint=None, density_right_grad_constraint=None, temperature_right_constraint=None, temperature_right_grad_constraint=None)

Analytical diffusive turbulent flux model.

Per species:

    Gamma_a = -chi_density[a] * dn_a/dr
    Q_a     = -n_a * chi_temperature[a] * dT_a/dr
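A minimal pure-Python sketch of these two formulas on a toy grid. The function name is hypothetical, and the finite-difference scheme (central in the interior, one-sided at the ends) is an illustrative choice; NEOPAX's own radial discretisation and its right-boundary constraints are not reproduced here.

```python
def analytical_turbulent_fluxes(r, density, temperature, chi_density, chi_temperature):
    """Per species a: Gamma_a = -chi_density[a] dn_a/dr, Q_a = -n_a chi_temperature[a] dT_a/dr.

    density and temperature are lists of per-species radial profiles on grid r.
    """

    def ddr(f):
        # Central differences inside, one-sided differences at both ends.
        n = len(f)
        out = [0.0] * n
        for i in range(n):
            if i == 0:
                out[i] = (f[1] - f[0]) / (r[1] - r[0])
            elif i == n - 1:
                out[i] = (f[-1] - f[-2]) / (r[-1] - r[-2])
            else:
                out[i] = (f[i + 1] - f[i - 1]) / (r[i + 1] - r[i - 1])
        return out

    gamma, q = [], []
    for a, (n_a, t_a) in enumerate(zip(density, temperature)):
        dn_dr, dT_dr = ddr(n_a), ddr(t_a)
        gamma.append([-chi_density[a] * g for g in dn_dr])
        q.append([-n * chi_temperature[a] * g for n, g in zip(n_a, dT_dr)])
    return gamma, q


# Linear profiles: dn/dr = -0.5 and dT/dr = -1 everywhere.
gamma, q = analytical_turbulent_fluxes(
    [0.0, 0.5, 1.0], [[1.0, 0.75, 0.5]], [[2.0, 1.5, 1.0]],
    chi_density=[0.2], chi_temperature=[1.0],
)
print(gamma)  # [[0.1, 0.1, 0.1]]
print(q)      # [[1.0, 0.75, 0.5]]
```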

NEOPAX.get_Turbulent_Fluxes_PowerOverN(species, chi_temperature, chi_density, total_power_mw, temperature, density, field, density_right_constraint=None, density_right_grad_constraint=None, temperature_right_constraint=None, temperature_right_grad_constraint=None)

Analytical power-scaled turbulent transport with coefficients ~ P^0.75 / N_e.
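The scaling can be illustrated with a simple multiplicative factor applied to base diffusivities. This is a hypothetical sketch of the documented proportionality only: the normalisation and the exact definition of N_e used by NEOPAX are not stated here, so n_e below is just a positive scalar.

```python
def power_over_n_factor(total_power_mw: float, n_e: float, exponent: float = 0.75) -> float:
    """Return P**exponent / N_e, the assumed multiplicative factor on base diffusivities."""
    if total_power_mw < 0.0 or n_e <= 0.0:
        raise ValueError("expected P >= 0 and N_e > 0")
    return total_power_mw ** exponent / n_e


def scaled_diffusivities(chi_base, total_power_mw, n_e):
    """Apply the same P^0.75 / N_e factor to every species' base coefficient."""
    factor = power_over_n_factor(total_power_mw, n_e)
    return [chi * factor for chi in chi_base]


# Doubling the power raises the coefficients by 2**0.75 (about 1.68);
# doubling N_e halves them.
chi_1 = scaled_diffusivities([1.0, 0.5], total_power_mw=4.0, n_e=2.0)
chi_2 = scaled_diffusivities([1.0, 0.5], total_power_mw=8.0, n_e=2.0)
print(chi_2[0] / chi_1[0])
```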

class NEOPAX.ComposedEquationSystem
equations: tuple
density_equation: object | None = None
temperature_equation: object | None = None
er_equation: object | None = None
species: object | None = None
shared_flux_model: object | None = None
density_floor: object
temperature_floor: object
temperature_active_mask: object | None = None
fixed_temperature_profile: object | None = None
er_bc_model: object | None = None
_prepare_working_state(state)
_resolve_equations()
build_lagged_response(state)
evaluate_with_lagged_response(t, state, runtime, lagged_response)
_evaluate_state(state, lagged_response=None)
__call__(t, state, runtime)

Evaluate all equations on state and return a TransportState with the same structure as state. All three fields are always returned; a field with no corresponding equation is filled with zero arrays of the correct shape. When electrons are present, the RHS is evaluated on a quasi-neutral working state, but electron density is kept out of the solved density subsystem. This follows the NTSS-style pattern: independent ion and impurity density rows are evolved, and electron density is reconstructed algebraically for the working state and for accepted/output states.
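Assuming the algebraic reconstruction is ordinary quasi-neutrality, n_e = sum_i Z_i n_i over ion and impurity species, with charges such as those in Species.charge_qp and the electron row labelled 'e', a working-state density could be rebuilt as in this pure-Python sketch (function names are hypothetical):

```python
def electron_density_from_quasineutrality(ion_densities, ion_charges):
    """n_e(r) = sum_i Z_i * n_i(r), summed over ion/impurity rows."""
    n_radial = len(ion_densities[0])
    return [
        sum(z * n_i[k] for z, n_i in zip(ion_charges, ion_densities))
        for k in range(n_radial)
    ]


def quasi_neutral_working_density(density, charges, names):
    """Copy the density rows, replacing the electron row ('e') by quasi-neutrality."""
    ion_rows = [i for i, name in enumerate(names) if name != "e"]
    n_e = electron_density_from_quasineutrality(
        [density[i] for i in ion_rows], [charges[i] for i in ion_rows]
    )
    return [n_e if name == "e" else list(row) for row, name in zip(density, names)]


names = ("e", "D", "He")
charges = (-1.0, 1.0, 2.0)
density = [[0.0, 0.0], [1.0, 0.5], [0.25, 0.5]]  # electron row is a placeholder
print(quasi_neutral_working_density(density, charges, names)[0])  # [1.5, 1.5]
```

Only the ion rows are independent unknowns; the electron row is derived, which is why it can stay out of the solved density subsystem.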

vector_field(t, y, args)

Torax-style vector field for JAX ODE solvers with signature (t, y, args) -> dy/dt, where y is the state and args[0] is the runtime dict.
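The (t, y, args) calling convention is the one diffrax expects for the function wrapped by diffrax.ODETerm. A toy pure-Python sketch of a field with that signature, consumed by a fixed-step forward Euler loop (the "rate" runtime entry and both functions are hypothetical):

```python
import math


def vector_field(t, y, args):
    """(t, y, args) -> dy/dt for y' = -rate * y; args[0] is a runtime dict."""
    runtime = args[0]
    return [-runtime["rate"] * yi for yi in y]


def forward_euler(f, y0, t0, t1, dt, args):
    """Fixed-step forward Euler, only to show how the callback is consumed."""
    y, t = list(y0), t0
    for _ in range(round((t1 - t0) / dt)):
        dydt = f(t, y, args)
        y = [yi + dt * di for yi, di in zip(y, dydt)]
        t += dt
    return y


runtime = {"rate": 0.5}  # hypothetical runtime entry
y_end = forward_euler(vector_field, [1.0, 2.0], 0.0, 1.0, 1e-3, (runtime,))
print(y_end[0], math.exp(-0.5))  # close to each other (first-order error)
```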

class NEOPAX.DensityEquation

Bases: EquationBase

Transport equation for species densities. Implements the __call__ method required by EquationBase.

dr_cells: jax.Array
Vprime: jax.Array
Vprime_half: jax.Array
flux_model: callable
flux_faces_builder: callable
active_species_mask: jax.Array
independent_density_mask: jax.Array
face_flux_builder: callable = None
density_bc_model: object = None
particle_flux_reconstruction: str = 'closure_face_flux'
particle_face_closure_mode: str = 'reconstructed'
source_model: callable = None
species: object = None
name: str = 'density'
_mode_requests_face_fluxes(mode_value)
_use_model_face_particle_fluxes()
enforce_dirichlet_boundary_rhs(state, density_rhs)
debug_components(state, fluxes=None, source_outputs=None)
__call__(state, fluxes=None, source_outputs=None)
class NEOPAX.ElectricFieldEquation

Bases: EquationBase

Transport equation for the radial electric field Er. Implements the __call__ method required by EquationBase.

dr_cells: jax.Array
Vprime: jax.Array
Vprime_half: jax.Array
flux_model: callable
species_mass: jax.Array
charge_qp: jax.Array
permitivity_prefactor: jax.Array
gamma_faces_builder: callable
er_diffusive_flux_builder: callable
er_bc_model: object = None
source_mode: str = 'ambipolar_local'
permitivity_mode: str = 'neopax_local'
Er_relax: float = 1.0
DEr: float = 1.0
boundary_mode: str = 'standard'
ntss_B0_mid: float = 0.0
ntss_psfactor_mid: float = 1.0
ntss_density_indices: jax.Array = None
name: str = 'Er'
_charge_flux_from_gamma(Gamma)
_er_diffusion(Er)
_charge_flux_and_ambi_term(state, Gamma, plasma_permitivity)
debug_components(state, fluxes=None)
__call__(state, fluxes=None)
enforce_dirichlet_boundary_rhs(state, er_rhs)
ap_linear_split(state)

Return diagonal linearization and explicit source for optional AP preconditioning. Uses only attributes set at construction and the current state.
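ap_linear_split is documented only as returning a diagonal linearisation and an explicit source. One common way to use such a split (assumed here, not confirmed as NEOPAX's AP scheme) is a linearly implicit update that treats the stiff diagonal part implicitly. The sketch compares it with a fully explicit step on a stiff relaxation problem:

```python
def linearly_implicit_step(y, diag, source, dt):
    """Advance y' ~= diag * y + source by one step, treating diag implicitly.

    Rearranging y_new = y + dt * (diag * y_new + source) gives
    y_new = (y + dt * source) / (1 - dt * diag), elementwise.
    """
    return [(yi + dt * si) / (1.0 - dt * di) for yi, di, si in zip(y, diag, source)]


def forward_euler_step(y, diag, source, dt):
    """Fully explicit step, for comparison."""
    return [yi + dt * (di * yi + si) for yi, di, si in zip(y, diag, source)]


# Stiff relaxation towards y* = 2: diag = -1e6, source = 2e6.
y0, diag, source, dt = [0.0], [-1.0e6], [2.0e6], 1.0
print(linearly_implicit_step(y0, diag, source, dt))  # ~[1.999998]
print(forward_euler_step(y0, diag, source, dt))      # [2000000.0], far from 2
```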

class NEOPAX.TemperatureEquation

Bases: EquationBase

Transport equation for species temperatures. Implements the __call__ method required by EquationBase.

dr_cells: jax.Array
Vprime: jax.Array
Vprime_half: jax.Array
flux_model: callable
flux_faces_builder: callable
temperature_ghost_builder: callable
charge_qp: jax.Array
active_species_mask: jax.Array
face_flux_builder: callable = None
temperature_bc_model: object = None
convection_reconstruction: str = 'tvd_mc'
heat_flux_reconstruction: str = 'tvd_mc'
include_neo_convection: bool = True
include_turbulent_convection: bool = True
include_classical_convection: bool = True
include_work_term: bool = True
source_model: callable = None
species: object = None
name: str = 'temperature'
_mode_requests_face_fluxes(mode_value)
_use_model_face_heat_fluxes()
_use_model_face_particle_fluxes()
enforce_dirichlet_boundary_rhs(state, density_rhs, pressure_rhs)
debug_components(state, fluxes=None, source_outputs=None)
__call__(state, fluxes=None, source_outputs=None)
NEOPAX.build_equation_system(config, species, field, flux_model, source_models=None, solver_cfg=None, boundary_models=None)

Build the list of equation instances to evolve using prebuilt runtime objects. This avoids rebuilding geometry, databases, and flux models inside the equation builder and keeps compile closures smaller.

NEOPAX.build_equation_system_from_config(config, species)

Backward-compatible wrapper that builds the required runtime objects from config before delegating to build_equation_system.

class NEOPAX.AnalyticalTurbulentTransportModel

Bases: TransportFluxModelBase

Analytical diffusive turbulent transport flux model. Output dict keys:

  • Gamma: particle flux

  • Q: heat flux

  • Upar: parallel flow

species: Any
grid: Any
chi_t: Any
chi_n: Any
field: Any
with_transport_coeffs(*, chi_t=None, chi_n=None) → AnalyticalTurbulentTransportModel
__call__(state) → dict
build_local_particle_flux_evaluator(state)
evaluate_face_fluxes(state, face_state, **kwargs)
build_lagged_response(state, **kwargs)
evaluate_with_lagged_response(state, lagged_response, **kwargs)
class NEOPAX.CombinedTransportFluxModel

Bases: TransportFluxModelBase

Composite transport flux model combining neoclassical, turbulent, and classical models. Output dict keys:

  • Gamma: particle flux

  • Q: heat flux

  • Upar: parallel flow

neoclassical_model: TransportFluxModelBase
turbulent_model: TransportFluxModelBase
classical_model: TransportFluxModelBase
include_turbulent_particle_flux: bool = True
static _zero_like_flux(reference, fallback=0)
__call__(state, *args, **kwargs) → dict
build_local_particle_flux_evaluator(state)
evaluate_face_fluxes(state, face_state, **kwargs)
build_lagged_response(state, **kwargs)
evaluate_with_lagged_response(state, lagged_response, **kwargs)
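A plausible composition rule, given the three component models and the include_turbulent_particle_flux flag, is a channel-by-channel sum of the output dicts. The sketch below is hypothetical: it uses flat lists instead of per-species JAX arrays, and assumes that missing keys contribute zero and that the flag only drops the turbulent Gamma.

```python
def combine_flux_dicts(neo, turb, classical=None, *, include_turbulent_particle_flux=True):
    """Sum Gamma/Q/Upar channel by channel across component flux dicts."""
    parts = [("neo", neo), ("turb", turb)]
    if classical is not None:
        parts.append(("classical", classical))
    n_points = len(neo["Gamma"])
    combined = {}
    for key in ("Gamma", "Q", "Upar"):
        total = [0.0] * n_points  # zero fallback for channels nobody provides
        for label, fluxes in parts:
            if key == "Gamma" and label == "turb" and not include_turbulent_particle_flux:
                continue  # drop the turbulent particle flux when requested
            if key in fluxes:
                total = [a + b for a, b in zip(total, fluxes[key])]
        combined[key] = total
    return combined


neo = {"Gamma": [1.0, 2.0], "Q": [0.5, 0.5], "Upar": [0.1, 0.1]}
turb = {"Gamma": [0.5, 0.5], "Q": [1.0, 1.0]}  # no parallel-flow contribution
print(combine_flux_dicts(neo, turb))
# {'Gamma': [1.5, 2.5], 'Q': [1.5, 1.5], 'Upar': [0.1, 0.1]}
print(combine_flux_dicts(neo, turb, include_turbulent_particle_flux=False)["Gamma"])
# [1.0, 2.0]
```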
class NEOPAX.FluxesRFileTransportModel

Bases: TransportFluxModelBase

Transport flux model driven by flux profiles read from file. Output dict keys:

  • Gamma: particle flux

  • Q: heat flux

  • Upar: parallel flow

species: Any
geometry: Any
r_data: Any
gamma_data: Any = None
q_data: Any = None
upar_data: Any = None
profile_location: str = 'cell_centered'
q_scale: float = 1.0
with_q_scale(q_scale: float) → FluxesRFileTransportModel
_interp_species_profile(data, target_r)
_normalize_profile_location()
_data_on_cell_grid(data)
_data_on_face_grid(data)
__call__(state) → dict
build_local_particle_flux_evaluator(state)
evaluate_face_fluxes(state, face_state, **kwargs)
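The with_* methods on these models are documented as returning an instance of the same class, which points to an immutable, functional-update style that suits JAX pytrees. A minimal sketch of that pattern with dataclasses.replace, using a hypothetical FileFluxModelSketch in place of FluxesRFileTransportModel (NEOPAX's own implementation may differ):

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FileFluxModelSketch:
    """Hypothetical stand-in holding a heat-flux profile and a scale factor."""

    q_data: tuple[float, ...]
    q_scale: float = 1.0

    def with_q_scale(self, q_scale: float) -> "FileFluxModelSketch":
        # Return a modified copy instead of mutating self, so existing
        # references (e.g. captured in a jitted closure) keep their old value.
        return replace(self, q_scale=q_scale)

    def __call__(self) -> dict:
        return {"Q": [self.q_scale * q for q in self.q_data]}


base = FileFluxModelSketch(q_data=(1.0, 2.0))
doubled = base.with_q_scale(2.0)
print(base()["Q"], doubled()["Q"])  # [1.0, 2.0] [2.0, 4.0]
```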
class NEOPAX.NTXExactLijRuntimeSupport
center_channels: NTXRuntimeScanChannels
face_channels: NTXRuntimeScanChannels
center_prepared: Any
face_prepared: Any
grid: Any
class NEOPAX.NTXExactLijRuntimeTransportModel

Bases: TransportFluxModelBase

NTX-based transport flux model that computes L_ij coefficients at runtime. Output dict keys:

  • Gamma: particle flux

  • Q: heat flux

  • Upar: parallel flow

species: Any
energy_grid: Any
geometry: Any
vmec_file: str | None
boozer_file: str | None
n_theta: int = 25
n_zeta: int = 25
n_xi: int = 64
surface_backend: str = 'vmec'
face_response_mode: str = 'face_local_response'
radial_batch_size: int | None = None
radial_batch_mode: str = 'simple'
scan_batch_size: int | None = None
response_anchor_count: int | None = None
use_remat: bool = False
er_v_floor: float | None = None
collisionality_model: str = 'default'
bc_density: Any = None
bc_temperature: Any = None
support: NTXExactLijRuntimeSupport | None = None
_rho_center_face()
_static_support() → NTXExactLijRuntimeSupport
with_static_support() → NTXExactLijRuntimeTransportModel
with_transport_resolution(*, n_theta=None, n_zeta=None, n_xi=None) → NTXExactLijRuntimeTransportModel
with_face_response_mode(face_response_mode: str) → NTXExactLijRuntimeTransportModel
with_radial_batch_size(radial_batch_size: int | None) → NTXExactLijRuntimeTransportModel
with_radial_batch_mode(radial_batch_mode: str | None) → NTXExactLijRuntimeTransportModel
with_scan_batch_size(scan_batch_size: int | None) → NTXExactLijRuntimeTransportModel
with_response_anchor_count(response_anchor_count: int | None) → NTXExactLijRuntimeTransportModel
with_use_remat(use_remat: bool) → NTXExactLijRuntimeTransportModel
with_er_v_floor(er_v_floor: float | None) → NTXExactLijRuntimeTransportModel
static _normalize_radial_batch_mode(radial_batch_mode: str | None) → str
_map_radius_axis_hybrid(fn, radius_indices)
_map_radius_axis(fn, radius_indices)
_map_radius_axis_unbatched(fn, radius_indices)
_response_anchor_indices(n_radius: int) → jax.Array
_interpolate_anchor_values(anchor_indices, anchor_values, target_rho)
_regularize_axis_radius0(values_by_radius, radius_coordinates)
_map_radius_axis_regularized_at_axis0(fn, radius_indices, radius_coordinates, *, unbatched: bool = False)
_regularize_center_fluxes_axis0(gamma, q, upar)
_log_nu_star_from_nu_hat(nu_hat_a)
_local_scan_inputs(*, drds_value, species_index: int, er_value, temperature_local, density_local, vthermal_local, collisionality_kind)
_lij_from_coefficient_scan(coeff_scan, *, drds_value, species_index: int, vth_a)
_transport_moments_from_coefficient_scan(coeff_scan, *, drds_value)
_transport_moments_from_inputs_impl(prepared, nu_hat_a, epsi_hat_a, *, drds_value)
_transport_moments_from_inputs(prepared, nu_hat_a, epsi_hat_a, *, drds_value)
_lij_from_transport_moments(transport_moments, *, species_index: int, vth_a)
_batched_lij_from_transport_moments(transport_moments, v_thermal)
_solve_coefficient_scan_prepared_impl(prepared, nu_hat_a, epsi_hat_a)
_solve_coefficient_scan_prepared(prepared, nu_hat_a, epsi_hat_a)
_coefficient_scan_from_inputs(prepared, nu_hat_a, epsi_hat_a)
_solve_lij_prepared_local_impl(prepared, *, drds_value, species_index: int, er_value, temperature_local, density_local, vthermal_local, collisionality_kind)
_solve_lij_prepared_local(prepared, *, drds_value, species_index: int, er_value, temperature_local, density_local, vthermal_local, collisionality_kind)
_build_coefficient_response_local(prepared, *, drds_value, species_index: int, er_value, temperature_local, density_local, vthermal_local, collisionality_kind)
_build_interpolated_moment_response_local(prepared, *, drds_value, species_index: int, er_value, temperature_local, density_local, vthermal_local, collisionality_kind)
_lij_center(Er, temperature, density)
_lij_faces(Er_faces, temperature_faces, density_faces)
_assemble_center_fluxes(Er, temperature, density, lij, n_right, n_right_grad, t_right, t_right_grad)
_cell_centered_flux_to_faces_centered(flux)
__call__(state) → dict
build_lagged_response(state, **kwargs)
evaluate_with_lagged_response(state, lagged_response, **kwargs)
build_local_particle_flux_evaluator(state)
evaluate_face_fluxes(state, face_state, **kwargs)
class NEOPAX.NTXRuntimeScanChannels
rho: Any
a_b: float
psia: float
b00: Any
r00: Any
boozer_i: Any
boozer_g: Any
iota: Any
drds: Any
dr_tildedr: Any
dr_tildeds: Any
fac_reference_to_sfincs_11: Any
fac_reference_to_sfincs_31: Any
fac_reference_to_sfincs_33: Any
fac_sfincs_to_dkes_11: Any
fac_sfincs_to_dkes_31: Any
fac_sfincs_to_dkes_33: Any
fac_dkes_to_d11star: Any
fac_dkes_to_d31star: Any
fac_dkes_to_d33star: Any
classmethod from_mapping(rho, channels: dict[str, jax.Array | float]) → NTXRuntimeScanChannels
as_mapping() → dict[str, jax.Array | float]
class NEOPAX.NTXRuntimeScanTransportModel

Bases: TransportFluxModelBase

NTX-based transport flux model built from a runtime scan over rho_scan, nu_v_scan, and er_tilde_scan. Output dict keys:

  • Gamma: particle flux

  • Q: heat flux

  • Upar: parallel flow

species: Any
energy_grid: Any
geometry: Any
vmec_file: str | None
boozer_file: str | None
rho_scan: Any
nu_v_scan: Any
er_tilde_scan: Any
n_theta: int = 25
n_zeta: int = 25
n_xi: int = 64
surface_backend: str = 'vmec'
source_name: str = 'ntx_scan_runtime'
collisionality_model: str = 'default'
bc_density: Any = None
bc_temperature: Any = None
channels: NTXRuntimeScanChannels | None = None
database: Any = None
_scan_axes() → tuple[jax.Array, jax.Array, jax.Array]
_static_channels() → NTXRuntimeScanChannels
_surface_loader(ntx)
_build_runtime_database()
with_static_channels() → NTXRuntimeScanTransportModel
with_scan_inputs(*, rho_scan=None, nu_v_scan=None, er_tilde_scan=None, clear_database: bool = True) → NTXRuntimeScanTransportModel
_database_model() → NTXDatabaseTransportModel
__call__(state) → dict
build_local_particle_flux_evaluator(state)
evaluate_face_fluxes(state, face_state, **kwargs)
with_runtime_database() NTXRuntimeScanTransportModel
class NEOPAX.PowerAnalyticalTurbulentTransportModel

Bases: TransportFluxModelBase

Analytical power-scaled turbulent transport flux model (coefficients ~ P^0.75 / N_e). Output dict keys:

  • Gamma: particle flux

  • Q: heat flux

  • Upar: parallel flow

species: Any
field: Any
chi_t: Any
chi_n: Any
pressure_source_model: Any = None
total_power_mw: Any = None
with_transport_coeffs(*, chi_t=None, chi_n=None, pressure_source_model=None, total_power_mw=None) → PowerAnalyticalTurbulentTransportModel
_effective_total_power_mw(state)
__call__(state) → dict
build_local_particle_flux_evaluator(state)
evaluate_face_fluxes(state, face_state, **kwargs)
build_lagged_response(state, **kwargs)
evaluate_with_lagged_response(state, lagged_response, **kwargs)
NEOPAX.build_ntx_exact_lij_runtime_support(vmec_file, boozer_file, rho_center, rho_face, *, surface_backend='auto', n_theta=25, n_zeta=25, n_xi=64) → NTXExactLijRuntimeSupport
NEOPAX.build_ntx_exact_lij_runtime_transport_model(species, energy_grid, geometry, *, vmec_file, boozer_file, ntx_exact_n_theta=25, ntx_exact_n_zeta=25, ntx_exact_n_xi=64, ntx_exact_surface_backend='vmec', ntx_exact_face_response_mode='face_local_response', ntx_exact_radial_batch_size=None, ntx_exact_radial_batch_mode='simple', ntx_exact_scan_batch_size=None, ntx_exact_response_anchor_count=None, ntx_exact_use_remat=False, ntx_exact_er_v_floor=None, ntx_exact_lij_support=None, preload_support=False, collisionality_model='default', bc_density=None, bc_temperature=None, **kwargs)
NEOPAX.build_ntx_runtime_scan_channels(vmec_file, boozer_file, rho_scan) → NTXRuntimeScanChannels
class NEOPAX.ZeroTransportModel

Bases: TransportFluxModelBase

Transport flux model that returns zero fluxes. Output dict keys:

  • Gamma: particle flux

  • Q: heat flux

  • Upar: parallel flow

shape: Any = None
__call__(state) → dict
build_local_particle_flux_evaluator(state)
evaluate_face_fluxes(state, face_state, **kwargs)
build_lagged_response(state, **kwargs)
NEOPAX.build_fluxes_r_file_transport_model(species, geometry, *, fluxes_file=None, file=None, flux_file=None, neoclassical_file=None, turbulence_file=None, classical_file=None, grid_location='cell_centered', profile_location=None, **kwargs)
NEOPAX.build_ntx_runtime_scan_transport_model(species, energy_grid, geometry, *, vmec_file, boozer_file, ntx_scan_rho, ntx_scan_nu_v, ntx_scan_er_tilde, ntx_scan_n_theta=25, ntx_scan_n_zeta=25, ntx_scan_n_xi=64, ntx_scan_surface_backend='auto', ntx_scan_source_name='ntx_scan_runtime', collisionality_model='default', bc_density=None, bc_temperature=None, ntx_scan_channels=None, preload_channels=False, prebuild_database=True, **kwargs)
NEOPAX.build_transport_flux_model(neo_model: TransportFluxModelBase, turb_model: TransportFluxModelBase, classical_model: TransportFluxModelBase = None, *, include_turbulent_particle_flux: bool = True) → CombinedTransportFluxModel

Build the composed transport model from explicit model instances. All models must be constructed up front by the orchestrator.

NEOPAX.get_transport_flux_model_capabilities(name: str) → NEOPAX._model_api.ModelCapabilities
NEOPAX.get_transport_flux_model(name: str) → Callable[..., TransportFluxModelBase]
NEOPAX.register_transport_flux_model(name: str, builder: Callable[..., TransportFluxModelBase], *, capabilities: NEOPAX._model_api.ModelCapabilities | None = None, validate: bool = False, validation_context: NEOPAX._model_api.ModelValidationContext | None = None) → None
NEOPAX.transport_flux_model(name: str, **register_kwargs)
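register_transport_flux_model, get_transport_flux_model and the transport_flux_model decorator follow the usual name-to-builder registry pattern. A self-contained sketch of that pattern; all names, including the overwrite flag and the "constant_flux" model, are hypothetical, and it does not model NEOPAX's capabilities or validation options:

```python
from typing import Callable

_BUILDERS: dict[str, Callable] = {}


def register_model(name: str, builder: Callable, *, overwrite: bool = False) -> None:
    """Store a builder under a string name (overwrite is a hypothetical option)."""
    if name in _BUILDERS and not overwrite:
        raise KeyError(f"model {name!r} is already registered")
    _BUILDERS[name] = builder


def get_model(name: str) -> Callable:
    """Look up a registered builder by name."""
    if name not in _BUILDERS:
        raise KeyError(f"unknown model {name!r}; registered: {sorted(_BUILDERS)}")
    return _BUILDERS[name]


def model(name: str, **register_kwargs):
    """Decorator form: register the decorated builder and return it unchanged."""

    def decorator(builder: Callable) -> Callable:
        register_model(name, builder, **register_kwargs)
        return builder

    return decorator


@model("constant_flux")
def build_constant_flux(value: float = 1.0):
    # The builder returns a callable mapping a state to a flux dict.
    def flux_model(state):
        return {"Gamma": value, "Q": value, "Upar": 0.0}

    return flux_model


flux_model = get_model("constant_flux")(value=2.0)
print(flux_model(None))  # {'Gamma': 2.0, 'Q': 2.0, 'Upar': 0.0}
```

Returning the builder unchanged from the decorator keeps it importable and testable on its own.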
class NEOPAX.DiffraxSolver(integrator, t0, t1, dt, save_n=None, **integrator_kwargs)

Bases: TransportSolver

Transport solver that wraps a diffrax integrator. Implements the solve() method required of all transport solvers.

integrator: Callable
t0: float
t1: float
dt: float
integrator_kwargs: dict[str, Any]
solve(state, vector_field: Callable, *args, **kwargs)
class NEOPAX.NewtonThetaMethodSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, maxiter: int = 20, max_step: float | None = None, safety_factor: float = 0.9, min_step_factor: float = 0.5, max_step_factor: float = 2.0, target_nonlinear_iterations: int = 4, delta_reduction_factor: float = 0.5, tau_min: float = 0.01, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)

Bases: _ThetaNewtonSolverConfig

Theta-method transport solver with Newton iterations. Implements the solve() method required of all transport solvers.

solve(state, vector_field: Callable, *args, **kwargs)
class NEOPAX.RADAUSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, rtol: float = 1e-06, atol: float = 1e-08, max_step: float = 1.0, min_step: float = 1e-14, tol: float = 1e-08, maxiter: int = 20, error_estimator: str = 'embedded2', num_stages: int = 3, safety_factor: float = 0.9, min_step_factor: float = 0.1, max_step_factor: float = 5.0, jacobian_reuse_rtol: float = 0.1, max_jacobian_age: int = 8, rhs_mode: str = 'black_box', newton_divergence_mode: str = 'legacy', newton_residual_norm: str = 'raw', newton_tol_mode: str = 'residual', newton_fnewt_mode: str = 'tol', controller_mode: str = 'current', predictor_mode: str = 'current', lagged_response_reuse_mode: str = 'retry_only', lagged_response_reuse_rtol: float = 0.05, lagged_response_reuse_atol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, debug_stage_markers: bool = False, debug_walltime_attempts: bool = False, save_n=None)

Bases: _RadauSolverConfig

Implicit Radau transport solver with adaptive time stepping. Implements the solve() method required of all transport solvers.

solve(state, vector_field: Callable, *args, **kwargs)
class NEOPAX.ThetaMethodSolver(t0: float = 0.0, t1: float = 1.0, dt: float = 0.01, min_step: float = 1e-14, theta_implicit: float = 1.0, predictor_mode: str = 'linearized', rhs_mode: str = 'black_box', use_predictor_corrector: bool = False, n_corrector_steps: int = 1, tol: float = 1e-08, max_steps: int = 20000, stop_after_accepted_steps: int | None = None, save_n=None)

Bases: _ThetaSolverConfig

Theta-method transport solver. Implements the solve() method required of all transport solvers.

solve(state, vector_field: Callable, *args, **kwargs)
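ThetaMethodSolver and NewtonThetaMethodSolver expose theta_implicit (default 1.0) and tol, and NewtonThetaMethodSolver also exposes maxiter. For orientation, here is the textbook theta-method step for a scalar ODE y' = f(t, y), with the implicit equation solved by Newton iteration; theta=1 gives backward Euler and theta=0.5 gives Crank-Nicolson. It is a sketch only and omits adaptive stepping and the predictor/corrector options listed in the signatures.

```python
def theta_step(f, dfdy, y_n, t_n, dt, theta=1.0, tol=1e-12, maxiter=20):
    """One theta-method step for y' = f(t, y), solved with Newton iteration.

    Solves R(y) = y - y_n - dt * [(1 - theta) * f(t_n, y_n) + theta * f(t_n + dt, y)] = 0.
    """
    explicit_part = (1.0 - theta) * f(t_n, y_n)
    t_next = t_n + dt
    y = y_n  # initial guess: previous value
    for _ in range(maxiter):
        residual = y - y_n - dt * (explicit_part + theta * f(t_next, y))
        jacobian = 1.0 - dt * theta * dfdy(t_next, y)
        delta = residual / jacobian
        y -= delta
        if abs(delta) < tol:
            return y
    raise RuntimeError("Newton iteration did not converge")


def decay(t, y):
    return -10.0 * y


def decay_jac(t, y):
    return -10.0


print(theta_step(decay, decay_jac, 1.0, 0.0, 0.1, theta=1.0))  # 0.5 (backward Euler)
print(theta_step(decay, decay_jac, 1.0, 0.0, 0.1, theta=0.5))  # approximately 0.3333 (Crank-Nicolson)
```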
NEOPAX.build_time_solver(solver_parameters: Any, solver_override: Any = None) → TransportSolver

Create a time solver backend from runtime parameters/config.

solver_override can be either:

  • an instance with .solve(…) (used directly), or

  • a diffrax solver instance (wrapped in DiffraxSolver).
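The dispatch rule for solver_override can be expressed as a small duck-typed check. In this sketch, WrappedIntegrator is a hypothetical stand-in for DiffraxSolver, and the None branch (fall back to building a solver from solver_parameters) is an assumption:

```python
class WrappedIntegrator:
    """Hypothetical stand-in for DiffraxSolver: adapts an integrator object."""

    def __init__(self, integrator):
        self.integrator = integrator

    def solve(self, state, vector_field, *args, **kwargs):
        raise NotImplementedError("illustration only")


def resolve_time_solver(solver_override, build_default):
    """Apply the documented dispatch rule for solver_override."""
    if solver_override is None:
        return build_default()  # assumed: build from runtime parameters
    if callable(getattr(solver_override, "solve", None)):
        return solver_override  # already exposes .solve(...): use directly
    return WrappedIntegrator(solver_override)  # e.g. a diffrax solver instance


class EchoSolver:
    def solve(self, state, vector_field, *args, **kwargs):
        return state


custom = EchoSolver()
print(resolve_time_solver(custom, build_default=lambda: None) is custom)  # True
print(type(resolve_time_solver(object(), build_default=lambda: None)).__name__)  # WrappedIntegrator
```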

NEOPAX.load_config(*args, **kwargs)
NEOPAX.run_config(*args, **kwargs)
NEOPAX.run_config_path(*args, **kwargs)