"""esbmtk: A general purpose Earth Science box model toolkit.
Copyright (C), 2020 Ulrich G. Wortmann
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or (at
your option) any later version.
This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import logging
import sys
import tempfile
import time
import typing as tp
import warnings
from pathlib import Path
from time import process_time
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import psutil
from scipy.integrate import solve_ivp
from esbmtk.ode_backend_2 import (
build_eqs_matrix,
get_initial_conditions,
write_equations_3,
)
from .esbmtk_base import esbmtkBase
from .initialize_unit_registry import Q_, ureg
from .utility_functions import (
find_matching_strings,
get_delta_from_concentration,
get_delta_h,
plot_geometry,
)
# declare numpy types
NDArrayFloat = npt.NDArray[np.float64]
if tp.TYPE_CHECKING:
from .base_classes import Species
[docs]
def deprecated_keyword(model, message):
    """Emit a ``DeprecationWarning`` and count it on the model.

    Parameters
    ----------
    model : object
        Object carrying a numeric ``now`` attribute (the running warning
        counter); it is increased by one after the warning is issued.
    message : str
        Text of the deprecation warning.
    """
    # stacklevel=2 attributes the warning to the caller of this helper.
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    model.now += 1
[docs]
class ModelError(Exception):
    """Exception type for model-related errors."""

    def __init__(self, message):
        """Create the error; *message* is framed by blank lines for readability."""
        super().__init__(f"\n\n{message}\n")
[docs]
class SolverError(Exception):
    """Exception type for solver-related errors."""

    def __init__(self, message):
        """Create the error; *message* is framed by blank lines for readability."""
        super().__init__(f"\n\n{message}\n")
[docs]
class FluxNameError(Exception):
    """Exception type for flux lookup errors."""

    def __init__(self, message):
        """Create the error; *message* is framed by blank lines for readability."""
        super().__init__(f"\n\n{message}\n")
[docs]
class Model(esbmtkBase):
r"""Earth Science Box Model Toolkit (ESBMTK) Model class.
This class represents the main model framework for creating and running
Earth science box models. It handles initialization of model parameters,
management of reservoirs, fluxes, and species, and provides methods for
running simulations and visualizing results.
The user-facing methods of the model class are:
- Model_Name.info() - Display model information
- Model_Name.save_data() - Save model data to files
- Model_Name.plot([sb.DIC, sb.TA]) - Plot specified objects
- Model_Name.save\_state() - Save current model state
- Model_Name.read\_state() - Initialize with a previous model state
- Model_Name.run() - Run the model simulation
- Model_Name.list_species() - List all defined species
- Model_Name.flux_summary() - Display flux information
- Model_Name.connection_summary() - Display connection information
Parameters
----------
**kwargs : dict
A dictionary with key-value pairs for model configuration.
Examples
--------
>>> esbmtkModel(
... name="Test_Model", # required
... stop="10000 yrs", # end time
... max_timestep="1 yr", # maximum time step
... element=["Carbon", "Sulfur"]
... )
Important Parameters
-------------------
name : str
The model name, e.g., "M".
mass_unit : str
Base mass unit for the model, default is "mol".
volume_unit : str
Volume unit for the model, default is "liter".
element : list or str
One or more species names to include in the model.
max_timestep : str
Limit automatic step size increase (time resolution of the model).
Optional, defaults to model duration/100.
m_type : str
Controls isotope calculation for the entire model.
Options: "Not set" (default, isotopes calculated only for reservoirs
with isotope keyword), "mass_only", or "both" (overrides reservoir settings).
offset : str
Offset the time axis by the specified amount when plotting data.
For display purposes only, does not affect model calculations.
display_precision : float
Affects on-screen display of data and sets cutoff for graphical output.
opt_k_carbonic : int
See https://doi.org/10.5194/gmd-15-15-2022.
opt_pH_scale : int
pH scale setting: total=1, free=3.
debug: bool
output debug information
debug_equations_file: bool
write a debug version of the equations file.
"""
def __init__(self, **kwargs: dict[str, tp.Any]) -> None:
    """Initialize a model instance.

    Parameters
    ----------
    **kwargs : dict
        Model configuration keywords. The accepted keys, their defaults
        and their accepted types are listed in ``self.defaults`` below.
        ``stop`` and one of ``timestep`` / ``max_timestep`` are required.

    Notes
    -----
    The deprecated ``timestep`` keyword is accepted here without raising.
    ``_configure_time_parameters`` maps it onto ``max_timestep`` and issues
    the deprecation warning once the warning counter has been initialised.
    """
    from importlib.metadata import version

    from esbmtk.sealevel import hypsometry

    # keyword -> [default value, accepted type(s)]
    self.defaults: dict[str, list[tp.Any]] = {
        "start": ["0 yrs", (str, Q_)],
        "stop": ["None", (str, Q_)],
        "offset": ["0 yrs", (str, Q_)],  # deprecated
        "timestep": ["None", (str, Q_)],  # deprecated
        "max_timestep": ["None", (str, Q_)],
        "min_timestep": ["1 second", (str, Q_)],
        "element": ["None", (str, list)],
        "mass_unit": ["mol", (str)],
        "volume_unit": ["liter", (str)],
        "area_unit": ["m**2", (str)],
        "time_unit": ["year", (str)],
        "concentration_unit": ["mol/liter", (str)],
        "time_label": ["Years", (str)],
        "display_precision": [0.01, (float)],
        "plot_style": ["default", (str)],
        "m_type": ["Not Set", (str)],
        "step_limit": [1e9, (int, float, str)],
        "register": ["local", (str)],
        "save_flux_data": [False, (bool)],
        "full_name": ["None", (str)],
        "parent": ["None", (str)],
        "isotopes": [False, (bool)],
        "debug": [False, (bool)],
        "ideal_water": [True, (bool)],
        "use_ode": [True, (bool)],
        "debug_equations_file": [False, (bool)],
        "rtol": [1.0e-4, (float)],
        "bio_pump_functions": [0, (int)],  # custom/old
        "opt_k_carbonic": [15, (int)],
        "opt_pH_scale": [1, (int)],  # 1: total scale
        "opt_buffers_mode": [2, (int)],
        "display_steps": [1000, (int)],
    }

    # Required keywords; a nested list means "one of these".
    self.lrk: list = [
        "stop",
        ["timestep", "max_timestep"],
    ]

    # Initialize keyword variables from provided arguments
    self.__initialize_keyword_variables__(kwargs)

    # NOTE: a deprecated ``timestep`` is deliberately not rejected here.
    # Raising DeprecationWarning as an exception made the graceful handling
    # in _configure_time_parameters unreachable, and the warning counter
    # (self.now) does not exist yet at this point.

    # Set default model name
    self.name = "M"

    # Initialize model component containers (also resets self.now)
    self._initialize_model_containers()
    # Configure logging
    self._setup_logging()
    # Register with parent
    self.__register_with_parent__()
    # Set up unit definitions
    self._configure_units()
    # Process time parameters (handles the deprecated timestep keyword)
    self._configure_time_parameters()
    # Create time arrays
    self._create_time_arrays()
    # Handle step limit
    self._handle_step_limit()
    # Register elements and species with model
    self._register_elements_and_species()
    # Display warranty information
    self._display_warranty(version)
    # Initialize the hypsometry class
    hypsometry(name="hyp", model=self, register=self)
def _initialize_model_containers(self):
"""Initialize all model component containers."""
# Model objects
self.lmo: list = [] # List of all model objects
self.lmo2: list = [] # Secondary list of model objects
self.dmo: dict = {} # Dict of all model objects (for name lookups)
# Reservoirs and connections
self.lor: list = [] # List of all reservoir type objects
self.lic: list = [] # List reservoirs with initial conditions
# self.lis: list = [] # List of sources with initial conditions
self.loc: set = set() # Set of connection objects
# Elements and species
self.lel: list = [] # List of all element references
self.lsp: list = [] # List of all species references
# External data and signals
self.led: list = [] # List of all external data objects
self.los: list = [] # List of signal objects
self.lvd: list = [] # List of vector data objects
# Fluxes and processes
self.lof: list = [] # List of flux objects
self.lop: list = [] # List of flux processes
self.lpc_f: list = [] # List of external functions affecting fluxes
self.lpc_i: list = [] # List of external functions needed in ode_backend
self.lpc_r: list = [] # List of external functions affecting virtual reservoirs
self.lvr: list = [] # List of virtual reservoirs
# Other model components
self.ldf: list = [] # List of datafield objects
self.lrg: list = [] # List of reservoir groups
self.lto: list = [] # List of objects requiring delayed initialization
self.olkk: list = [] # Optional keywords for use in connector class
self.axd: dict = {} # list of axes objects
# Global parameters and constants
self.gpt: tuple = () # Global parameter list
self.toc: tuple = () # Global constant tuple
self.doc: dict = {} # Global dict of index values for toc
self.gcc: int = 0 # Constants counter
self.vpc: int = 0 # Parameter counter
self.luf: dict = {} # User functions and source
self.now: float = 0 # number of warnings
def _setup_logging(self):
    """Point the root logger at ``<model name>.log``.

    Existing root handlers are removed first so ``basicConfig`` takes
    effect. The log file is truncated on every call. The level is DEBUG
    when ``self.debug`` is truthy and INFO otherwise. Python warnings are
    redirected into the logging system.
    """
    root_logger = logging.root
    # Iterate over a copy: handlers are removed while looping.
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)

    level = logging.DEBUG if self.debug else logging.INFO
    logging.basicConfig(filename=f"{self.name}.log", filemode="w", level=level)

    # Redirect warnings to logging
    logging.captureWarnings(True)
def _configure_units(self):
    """Derive the model's unit objects from the unit keyword strings.

    Each keyword string is parsed with ``Q_`` and only its ``units`` part is
    kept. The display time unit is taken from the ``stop`` keyword.
    """

    def units_of(spec):
        # Parse the spec as a quantity and discard the magnitude.
        return Q_(spec).units

    self.l_unit = ureg.meter  # length
    self.t_unit = units_of(self.time_unit)  # time
    self.d_unit = units_of(self.stop)  # display time
    self.m_unit = units_of(self.mass_unit)  # mass
    self.v_unit = units_of(self.volume_unit)  # volume
    self.a_unit = units_of(self.area_unit)  # area
    self.c_unit = units_of(self.concentration_unit)  # concentration

    # Derived units
    self.f_unit = self.m_unit / self.t_unit  # flux as mass/time
    self.r_unit = self.v_unit / self.t_unit  # flux as volume/time
def _configure_time_parameters(self):
    """Convert the time keywords to magnitudes in the model time unit.

    Sets ``start``, ``stop``, ``max_timestep``, ``min_timestep``, ``dt`` and
    ``offset`` as plain magnitudes in ``self.t_unit``. It shifts start and
    stop by the offset, creates the legacy aliases, and derives the run
    length and the number of steps.
    """

    def in_model_time(value):
        # Quantity (or string) -> magnitude expressed in the model time unit.
        return self.ensure_q(value).to(self.t_unit).magnitude

    self.start = in_model_time(self.start)
    self.stop = in_model_time(self.stop)

    # A deprecated ``timestep`` takes the place of ``max_timestep``.
    if self.timestep != "None":
        self.max_timestep = in_model_time(self.timestep)
        deprecated_keyword(self, "timestep is deprecated. Please use max_timestep")
    else:
        self.max_timestep = in_model_time(self.max_timestep)

    self.min_timestep = in_model_time(self.min_timestep)
    self.dt = self.max_timestep

    # Shift the time window by the offset.
    self.offset = in_model_time(self.offset)
    self.start += self.offset
    self.stop += self.offset

    # Legacy variable names
    self.n = self.mo = self.name
    self.model = self
    self.plot_style: list = [self.plot_style]

    # Time axis label and discretisation
    self.xl = f"Time [{self.t_unit}]"
    self.length = int(abs(self.stop - self.start))
    self.steps = int(abs(round(self.length / self.dt)))
    self.number_of_datapoints = self.steps
def _create_time_arrays(self):
"""Create time arrays for model simulation."""
self.time_ode = np.linspace(
self.start,
self.stop - self.start,
num=self.number_of_datapoints + 1,
)
self.display_time = np.linspace(
self.start,
self.stop - self.start,
num=self.display_steps + 1,
)
self.time = self.time_ode
self.timec = np.empty(0)
self.executionstate = 0
# Set default stride
self.stride = 1
def _handle_step_limit(self):
"""Handle step limit configuration."""
if self.step_limit == "None":
self.number_of_solving_iterations: int = 0
elif self.step_limit > self.steps:
self.number_of_solving_iterations: int = 0
self.step_limit = "None"
else:
self.step_limit = int(self.step_limit)
self.number_of_solving_iterations = int(round(self.steps / self.step_limit))
self.reset_stride = int(round(self.steps / self.number_of_datapoints))
self.steps = self.step_limit
self.time = (np.arange(self.steps) * self.dt) + self.start
def _register_elements_and_species(self):
"""Register elements and species with the model."""
from importlib import import_module
if "element" in self.kwargs:
if isinstance(self.kwargs["element"], list):
element_list = self.kwargs["element"]
else:
element_list = [self.kwargs["element"]]
# Process each element
for element_name in element_list:
# Get function handle from species_definitions
element_handler = getattr(
import_module("esbmtk.species_definitions"), element_name
)
element_handler(self) # Register element with model
# Get element handle and register its species
element_handle = getattr(self, element_name)
element_handle.__register_species_with_model__()
def _display_warranty(self, version_func):
"""Display warranty and citation information."""
import datetime
warranty_text = (
f"\n"
f"ESBMTK {version_func('esbmtk')} \n Copyright (C) 2020 - "
f"{datetime.date.today().year} Ulrich G.Wortmann\n"
f"This program comes with ABSOLUTELY NO WARRANTY\n"
f"This is free software, and you are welcome to redistribute it\n"
f"under certain conditions; See the LICENSE file for details.\n\n"
f"If you use ESBMTK for your research, please cite:\n\n"
f"Wortmann et al. 2025, https://doi.org/10.5194/gmd-18-1155-2025\n"
)
print(warranty_text)
[docs]
def info(self, **kwargs) -> None:
    """Print an overview of the model.

    Prints the model itself, followed by every registered element and the
    names of the species attached to it.

    Parameters
    ----------
    **kwargs : dict
        Optional keyword arguments.
        indent : int, default=0
            Number of spaces placed in front of element and species lines.

    Returns
    -------
    None
        Output goes to stdout.
    """
    pad = " " * kwargs.get("indent", 0)
    offset = " "  # standard offset for nested items
    species_header = f"{offset} Defined SpeciesProperties:"

    # Basic model information
    print(self)

    # Elements and their species
    print("Currently defined elements and their species:")
    for element in self.lel:
        print(f"{pad}{element}")
        print(species_header)
        for species in element.lsp:
            print(f"{offset}{offset}{pad}{species.n}")
[docs]
def save_state(self, directory: str = "state", prefix: str = "state") -> None:
    """Write the final state of every reservoir to ``directory``.

    Only the tail of each time series (from index -2 to the end) is written,
    so the files capture the current state and not the whole run. An
    existing target directory is removed first.

    Parameters
    ----------
    directory : str, default="state"
        Target directory, relative to the current working directory.
    prefix : str, default="state"
        File name prefix; an underscore is appended automatically.

    Raises
    ------
    FileExistsError
        If a previous directory exists and could not be removed.
    """
    from pathlib import Path

    from esbmtk.utility_functions import rmtree

    target_path = Path.cwd() / directory

    # Start from a clean directory.
    if target_path.exists():
        logging.info(f"Found previous state directory, deleting {target_path}")
        rmtree(target_path)
        if target_path.exists():
            raise FileExistsError(
                f"Failed to delete existing directory: {target_path}"
            )

    write_options = {
        # The underscore is added here and not in the signature default,
        # which sphinx stumbles over.
        "prefix": f"{prefix}_",
        "start": -2,  # second-to-last index (to avoid boundary effects)
        "stop": None,  # run to the end of the series
        "stride": 1,  # use every value
        "append": False,
        "directory": directory,
    }
    for reservoir in self.lor:
        reservoir.__write_data__(**write_options)
[docs]
def save_data(self, directory: str = "./data") -> None:
    """Write the full model results to CSV files in ``directory``.

    Every reservoir (except those of type ``flux_only``), every signal and
    every vector data object writes its own file. An existing target
    directory is removed first.

    Parameters
    ----------
    directory : str, default="./data"
        Target directory, relative to the current working directory.

    Raises
    ------
    FileExistsError
        If a previous directory exists and could not be removed.
    """
    from pathlib import Path

    from esbmtk.utility_functions import rmtree

    target_path = Path.cwd() / directory

    # Start from a clean directory.
    if target_path.exists():
        logging.info(f"Found previous data directory, deleting {target_path}")
        rmtree(target_path)
        if target_path.exists():
            raise FileExistsError(
                f"Failed to delete existing directory: {target_path}"
            )

    # Options shared by every writer: the whole series at the model stride.
    write_options = {
        "prefix": "",
        "start": 0,
        "stop": len(self.time),
        "stride": self.stride,
        "append": False,
        "directory": directory,
    }

    # Regular reservoirs (flux-only types carry no data worth saving)
    for reservoir in self.lor:
        if reservoir.rtype == "flux_only":
            continue
        reservoir.__write_data__(**write_options)

    # Signals
    for signal in self.los:
        signal.__write_data__(**write_options)

    # Vector data objects
    for vector_data in self.lvd:
        vector_data.__write_data__(**write_options)
[docs]
def read_data(self, directory: str = "./data") -> None:
    """Load previously saved model results from ``directory``.

    Each ``Species`` or ``GasReservoir`` in the model reads its own data
    back. For reservoirs with isotopes, the delta values are recomputed from
    the loaded concentrations.

    Parameters
    ----------
    directory : str, default="./data"
        Directory containing the saved model data files.
    """
    from esbmtk import GasReservoir, Species

    readable_types = (Species, GasReservoir)
    prefix = ""
    logging.info(f"Reading data from {directory}")

    for reservoir in self.lor:
        # Other reservoir types have no state files to read.
        if not isinstance(reservoir, readable_types):
            continue

        reservoir.__read_state__(directory, prefix)

        if reservoir.isotopes:
            reservoir.d = get_delta_from_concentration(
                reservoir.c, reservoir.l, reservoir.sp.r
            )
[docs]
def read_state(self, directory="state"):
    """Initialize the model with the result of a previous run.

    This requires that ``save_state`` was called at the end of the earlier
    run, which creates the data files read here.

    Parameters
    ----------
    directory : str, default="state"
        Directory holding the saved state files.

    Raises
    ------
    FileNotFoundError
        If ``directory`` does not exist or is not a directory.
    """
    from pathlib import Path

    from esbmtk import GasReservoir, Species

    state_dir = Path(directory).resolve()
    if not (state_dir.exists() and state_dir.is_dir()):
        raise FileNotFoundError(
            f"The directory '{state_dir}' does not exist or is not a directory."
        )

    readable_types = (Species, GasReservoir)
    for reservoir in self.lor:
        if isinstance(reservoir, readable_types):
            reservoir.__read_state__(directory)

    # Refresh the swc object of every reservoir group that has one.
    for group in self.lrg:
        if hasattr(group, "swc"):
            group.swc.update_parameters(pos=0)
[docs]
def plot(self, pl: list = None, **kwargs) -> tuple | None:
    """Plot model objects and save results to a file.

    Creates a figure with one subplot per object and renders each object
    through its ``__plot__`` method.

    Parameters
    ----------
    pl : list or object, default=None
        ESBMTK instances (e.g., reservoirs) to plot. A single object is
        wrapped in a list; ``None`` is treated as an empty list.
    **kwargs : dict
        Optional plotting parameters:
        fn : str, default="{model_name}.pdf"
            Filename to save the plot.
        title : str, default=None
            Title for the plot window. ``None`` (or the legacy string
            ``"None"``) selects the default ``"<model name> Species"``.
        no_show : bool, default=False
            If True, don't display or save the figure; instead return the
            plt, fig, and axes handles for manual customization.
        reverse_time : bool, default=False
            If True, reverse the time axis and adjust tick labels.
        blocking : bool, default=True
            If True, block execution until plot window is closed.

    Returns
    -------
    tuple or None
        If no_show=True, returns (plt, fig, axes), otherwise None.

    Examples
    --------
    Basic usage:

    >>> M.plot([sb.PO4, sb.DIC], fn='test.pdf')

    Advanced usage with customization:

    >>> from esbmtk import data_summaries
    >>> species_names = [M.DIC, M.TA, M.pH, M.CO3, M.zcc, M.zsat, M.zsnow, M.PO4]
    >>> box_names = [M.L_b, M.H_b, M.D_b]
    >>> pl = data_summaries(M, species_names, box_names, M.L_b.DIC)
    >>> pl += [M.CO2_At]
    >>> plt, fig, axs = M.plot(
    >>>     pl,
    >>>     fn="steady_state.pdf",
    >>>     title="ESBMTK Preindustrial Steady State",
    >>>     no_show=True,
    >>> )
    """
    # Ensure pl is a list
    if pl is None:
        pl = []
    if not isinstance(pl, list):
        pl = [pl]

    # Extract plot configuration from kwargs
    filename = kwargs.get("fn", f"{self.n}.pdf")
    blocking = kwargs.get("blocking", True)
    plot_title = kwargs.get("title")
    reverse_time = kwargs.get("reverse_time", False)
    no_show = kwargs.get("no_show", False)

    # Determine layout based on number of plots
    num_plots = len(pl)
    size, geometry = plot_geometry(num_plots)
    row_count, col_count = geometry

    # Create figure and subplots
    fig, ax = plt.subplots(row_count, col_count)
    axs = self._normalize_axes_structure(ax, row_count, col_count)

    # Configure plot style and title
    plt.style.use(self.plot_style)
    # The documented default is None, so an explicit title=None must fall
    # back to the default title and not be handed to the window manager.
    if plot_title is None or plot_title == "None":
        window_title = f"{self.n} Species"
    else:
        window_title = plot_title
    fig.canvas.manager.set_window_title(window_title)
    fig.set_size_inches(size)

    # Plot each object in the appropriate subplot
    self._plot_objects_to_subplots(pl, axs, row_count, col_count, num_plots)

    # Adjust figure layout
    fig.subplots_adjust(top=0.88)

    # Handle time axis reversal if requested
    if reverse_time:
        self._reverse_time_axis(fig)

    # Hand the handles back for manual customization ...
    if no_show:
        return plt, fig, fig.get_axes()

    # ... or display and save the figure.
    fig.tight_layout()
    plt.show(block=blocking)
    fig.savefig(filename)
    return None
def _normalize_axes_structure(self, ax, row_count: int, col_count: int) -> list:
"""Normalize the axes structure based on subplot layout.
Parameters
----------
ax : matplotlib.axes.Axes or array of Axes
The axes object(s) returned by plt.subplots()
row_count : int
Number of rows in the subplot grid
col_count : int
Number of columns in the subplot grid
Returns
-------
list
Normalized axes structure for consistent handling
"""
if row_count == 1 and col_count == 1:
# Single subplot
return ax
elif row_count > 1 and col_count == 1:
# Multiple rows, one column
return [ax[i] for i in range(row_count)]
elif row_count == 1 and col_count > 1:
# One row, multiple columns
return [ax[i] for i in range(col_count)]
else:
# Multiple rows and columns
return ax
def _plot_objects_to_subplots(
self, plot_objects: list, axes, row_count: int, col_count: int, num_plots: int
) -> None:
"""Plot objects to their respective subplots.
Parameters
----------
plot_objects : list
List of objects to plot
axes : matplotlib.axes.Axes or array of Axes
The normalized axes structure
row_count : int
Number of rows in the subplot grid
col_count : int
Number of columns in the subplot grid
num_plots : int
Total number of objects to plot
"""
plot_index = 0 # Index of current plot object
for row in range(row_count):
if col_count > 1:
# Multi-column grid
for col in range(col_count):
if plot_index < num_plots:
plot_objects[plot_index].__plot__(self, axes[row][col])
plot_index += 1
else:
# Remove unused subplots
axes[row][col].remove()
elif row_count > 1:
# Single column, multiple rows
if plot_index < num_plots:
plot_objects[plot_index].__plot__(self, axes[row])
plot_index += 1
else:
# Single subplot
if plot_index < num_plots:
plot_objects[plot_index].__plot__(self, axes)
plot_index += 1
def _reverse_time_axis(self, fig) -> None:
    """Reverse the time axis of every subplot registered as reversible.

    An axes object counts as reversible when ``self.axd`` maps it to the
    string ``"reversible"``. Axes without an entry are left untouched.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure containing the axes to modify
    """
    from matplotlib.ticker import FuncFormatter

    from esbmtk import Q_

    from .utility_functions import reverse_tick_labels_factory

    # Last model time, expressed in display units.
    last_time = Q_(f"{self.time[-1]} {self.t_unit}")
    t_max = last_time.to(self.d_unit).magnitude

    for axis in fig.get_axes():
        try:
            if self.axd[axis] == "reversible":
                axis.invert_xaxis()
                # Each axis gets its own formatter instance.
                tick_formatter = FuncFormatter(reverse_tick_labels_factory(t_max))
                axis.xaxis.set_major_formatter(tick_formatter)
        except KeyError:
            # Axes that were never registered in self.axd are skipped.
            pass
[docs]
def run(self, **kwargs) -> None:
    """Run the model simulation.

    Solves the system of ordinary differential equations (ODEs) that
    describe the model, and reports timing and memory usage.

    Parameters
    ----------
    **kwargs : dict
        Passed unchanged to ``_ode_solver``, which reads:
        method : str, default="LSODA"
            The integration method handed to the ODE solver (e.g. "BDF"
            or "LSODA").

    Returns
    -------
    None
        Results are stored in the model instance, and ``executionstate``
        is set to 1.

    Raises
    ------
    SolverError
        If the solver fails to find a solution.

    Notes
    -----
    CPU and wall-clock time are printed. Start/stop times and memory usage
    are written to the log. If warnings were counted during the run, a
    pointer to the log file is printed.
    """
    import os
    import sys
    from datetime import datetime

    script_name = os.path.basename(sys.argv[0])
    separator = 80 * "="

    # Track execution time and resource usage
    wall_clock_start = time.time()
    logging.info(separator)
    logging.info(
        f"Integration started at {datetime.fromtimestamp(wall_clock_start)}"
    )
    cpu_start = process_time()

    # Run solver
    self._ode_solver(kwargs)
    # Mark model as executed
    self.executionstate = 1

    # Calculate and display performance metrics
    cpu_duration = process_time() - cpu_start
    wall_clock_duration = time.time() - wall_clock_start
    print(
        f"\n Execution took {cpu_duration:.2f} CPU seconds, "
        f"wall time = {wall_clock_duration:.2f} seconds\n"
    )

    logging.info(
        f"\n\n{script_name} stopped at {datetime.fromtimestamp(time.time())}"
    )

    # Resident memory of this process, in GB
    process = psutil.Process(os.getpid())
    memory_gb = process.memory_info().rss / 1e9
    logging.info(f"This run used {memory_gb:.2f} GB of memory\n")
    logging.info(separator)

    if self.now > 0:
        # The log file is named after the model, so point at that file.
        print(
            f"{separator}\n\n"
            f"There were {self.now} warnings, check {self.name}.log\n\n"
            f"{separator}\n"
        )
def _write_temp_equations(self, cwd, R, icl, cpl, ipl):
    """Write the equations to a temporary module and return its ``eqs`` function.

    A temporary ``.py`` file is created in ``cwd``, filled by
    ``write_equations_3`` and imported by name. The file is removed when the
    temporary-file context exits; the imported function stays usable.

    Parameters
    ----------
    cwd : str or Path
        Directory in which the temporary file is created. It must be
        importable (on ``sys.path``) for the import to succeed.
    R : ndarray
        Initial conditions
    icl : list
        List of initial condition objects
    cpl : list
        List of constant parameters
    ipl : list
        List of initial parameters

    Returns
    -------
    function
        The equations function imported from the temporary module
    """
    from pathlib import Path

    # Pass the directory explicitly. Assigning tempfile.tempdir would
    # redirect every later temporary file created anywhere in this process.
    with tempfile.NamedTemporaryFile(suffix=".py", dir=cwd) as tmp_file:
        equations_file_path = Path(tmp_file.name)
        # Generate equations module
        equations_module_name = write_equations_3(
            self, R, icl, cpl, ipl, equations_file_path
        )
        # Import while the file still exists.
        eqs = __import__(equations_module_name).eqs
    return eqs
def _ode_solver(self, kwargs: dict):
    """Initialize and run the ODE solver.

    Collects the initial conditions, builds the coefficient matrix,
    generates the equations file, and solves the system using scipy's
    solve_ivp.

    Parameters
    ----------
    kwargs : dict
        Solver options; only ``method`` (default ``"LSODA"``) is read here.

    Raises
    ------
    SolverError
        If the solver fails to find a solution
    """
    # Get initial conditions
    self.R_names_dict, icl, cpl, ipl, atol = get_initial_conditions(self, self.rtol)
    self.R_names = list(self.R_names_dict.keys())
    # Initial concentrations for each reservoir
    R = np.array(list(self.R_names_dict.values()))

    # icl = dict[Species, list[int, int]] where reservoir
    # indicates the reservoir handle, and the list contains the
    # index into the reservoir data. list[0] = concentration
    # list[1] concentration of the light isotope.
    self.icl = icl
    # cpl = list of reservoirs that use function to evaluate
    # reservoir data
    self.cpl = cpl
    # ipl = list of static reservoirs that serve as input
    self.ipl = ipl

    # Build coefficient matrix
    self.CM, self.F, self.F_names = build_eqs_matrix(self)

    # The generated module is imported by name from the working directory
    # (required on Windows). Add it only once so that repeated runs do not
    # keep growing sys.path.
    current_dir = Path.cwd()
    cwd_entry = str(current_dir)
    if cwd_entry not in sys.path:
        sys.path.append(cwd_entry)

    coeff_file_path = current_dir / "eqs_coeff.npz"
    equations_file_path = current_dir / "equations.py"

    # Keep the coefficients only when debugging the equations file.
    if self.debug_equations_file:
        np.savez(coeff_file_path, CM=self.CM, F=self.F)
    elif coeff_file_path.exists():
        coeff_file_path.unlink()

    # Handle equations file based on debug settings
    equations_set = self._handle_equations_file(
        equations_file_path, R, icl, cpl, ipl, current_dir
    )

    # Get solver configuration from kwargs
    method = kwargs.get("method", "LSODA")

    # Initialize carbonate chemistry tables if not present
    self._initialize_carbonate_tables()
    # Run the ODE solver
    self._run_solve_ivp(R, equations_set, method, atol)
    # Process results
    self._process_solver_results()
def _handle_equations_file(
self, equations_file_path, R, icl, cpl, ipl, current_dir
):
"""Handle equationsfile generation based on debug settings.
Parameters
----------
equations_file_path : Path
Path to the equationsfile
R, icl, cpl, ipl : various
Parameters for equationsgeneration
current_dir : Path
Current working directory
Returns
-------
function
The equations function
"""
if self.debug_equations_file: # If debugging equations is enabled
if equations_file_path.exists():
print(
"\n\n Warning re-using the equations file \n"
"\n type r to reuse old file or n to create a new one",
)
user_input = input("type r/n: ")
if user_input.lower() == "r": # Use existing file
equations_module_name = equations_file_path.stem
# Also load saved matrices if they exist
matrix_file = Path(
str(equations_file_path).replace(".py", "_matrices.npz")
)
if matrix_file.exists():
saved_data = np.load(matrix_file)
self.CM = saved_data["CM"]
self.F = saved_data["F"]
self.F_names = (
saved_data["F_names"].tolist()
if "F_names" in saved_data
else []
)
else:
print(
"Warning: Reusing equationsfile but matrix file not found. Results may be inconsistent."
)
else: # Create new file
equations_file_path.unlink() # Delete old file
equations_module_name = write_equations_3(
self, R, icl, cpl, ipl, equations_file_path
)
# Save matrices for future reuse
matrix_file = Path(
str(equations_file_path).replace(".py", "_matrices.npz")
)
self.matrix_file = matrix_file
np.savez(
matrix_file,
CM=self.CM,
F=self.F,
F_names=np.array(self.F_names),
)
else: # First run - create persistent file
equations_module_name = write_equations_3(
self, R, icl, cpl, ipl, equations_file_path
)
eqs = __import__(equations_module_name).eqs
else: # Use temporary file for equations
if equations_file_path.exists():
equations_file_path.unlink()
eqs = self._write_temp_equations(current_dir, R, icl, cpl, ipl)
return eqs # module reference
def _initialize_carbonate_tables(self):
"""Initialize carbonate chemistry tables with default values if not present."""
if not hasattr(self, "area_table"):
self.area_table = 0
self.area_dz_table = 0
self.Csat_table = 0
def _run_solve_ivp(self, R, equations_set, method, atol):
"""Run the solve_ivp ODE solver.
Parameters
----------
R : ndarray
Initial conditions
equations_set : function
The ODE function
method : str
Integration method
atol : float or ndarray
Absolute tolerance
"""
if self.debug:
logging.info(f"R: {R}")
logging.info(
f"self.gpt shape: {
np.shape(self.gpt) if hasattr(self.gpt, 'shape') else len(self.gpt)
}"
)
logging.info(
f"self.toc shape: {
np.shape(self.toc) if hasattr(self.toc, 'shape') else len(self.toc)
}"
)
logging.info(f"CM shape: {np.shape(self.CM)}")
logging.info(f"F shape: {np.shape(self.F)}")
logging.info(f"time_ode shape: {np.shape(self.time_ode)}")
# Add hash values for large arrays to verify content
logging.info(f"CM hash: {hash(str(self.CM))}")
logging.info(f"F hash: {hash(str(self.F))}")
logging.info(f"time_ode hash: {hash(str(self.time_ode))}")
self.results = solve_ivp(
equations_set,
(self.time[0], self.time[-1]),
R,
args=(
self,
self.gpt,
self.toc, # Tuple of constants
self.area_table,
self.area_dz_table,
self.Csat_table,
self.CM, # Coefficient matrix
self.F, # Flux vector
),
method=method,
atol=atol,
rtol=self.rtol,
# t_eval=self.time_ode,
t_eval=self.display_time,
first_step=self.min_timestep,
max_step=self.dt,
vectorized=False, # Flux equations would need to be adjusted
)
def _process_solver_results(self):
    """Report the solver outcome and hand successful results on.

    A solver status of ``0`` is treated as success: the solver statistics
    are logged, status and message are printed, and ``self.results`` is
    passed to ``self.post_process_data``. Any other status raises.

    Raises
    ------
    SolverError
        If the solver fails to find a solution
    """
    results = self.results

    if results.status != 0:
        # Raise error with helpful message for failed solutions
        error_message = (
            "---------------------- Warning ------------------------\n"
            "No solution was obtained, check "
            "https://esbmtk.readthedocs.io/en/latest/manual/manual-6.html\n"
            "---------------------- Warning ------------------------\n"
        )
        raise SolverError(error_message)

    # Log solver statistics
    logging.info(
        "Integration finished:\n "
        f"Number of evaluations of the right-hand side = {results.nfev}\n"
        f"Number of evaluations of the Jacobian = {results.njev}\n"
        f"Number of LU decompositions = {results.nlu}\n"
        f"status={results.status}\n"
        f"message={results.message}\n"
    )
    # Status and message are also echoed to stdout
    print(f"status={results.status}")
    print(f"message={results.message}\n")

    # Map the raw solver output back into the model data structures
    self.post_process_data(results)
[docs]
def get_delta_values(self) -> None:
    """Refresh mass and delta values of isotope-enabled reservoirs.

    Walks through ``self.lor``. For every reservoir whose ``isotopes``
    flag is set, the mass ``m`` is recomputed as concentration times
    volume and the delta value ``d`` is recomputed with ``get_delta_h``.
    Reservoirs without isotopes are left untouched.

    Returns
    -------
    None
        Reservoir objects are modified in place.
    """
    for res in self.lor:
        if not res.isotopes:
            continue
        # Mass follows from concentration and volume
        res.m = res.c * res.volume
        # Delta is derived from the reservoir state by get_delta_h
        res.d = get_delta_h(res)
[docs]
def sub_sample_data(self) -> None:
    """Thin out the stored time series by keeping every n-th point.

    The stride n is the integer part of ``len(self.time)`` divided by
    ``self.number_of_datapoints``. When n is greater than 1, the time
    array is replaced by ``self.time[2:-2:n]`` and every object in
    ``self.lor`` (reservoirs), ``self.lvr`` (virtual reservoirs) and
    ``self.lof`` (fluxes) is asked to subsample itself through its own
    ``__sub_sample_data__`` method, in that order.

    Returns
    -------
    None
        Model data are modified in place.

    Notes
    -----
    Nothing happens when the stride is 1 or smaller. The first two and
    the last two time points are dropped by the slice.
    """
    stride = int(len(self.time) / self.number_of_datapoints)
    if stride <= 1:
        # Already at (or below) the requested resolution
        return

    # Drop two points at either end, then keep every stride-th value
    self.time = self.time[2:-2:stride]

    # Reservoirs first, then virtual reservoirs, then fluxes
    for collection_name in ("lor", "lvr", "lof"):
        for model_object in getattr(self, collection_name):
            model_object.__sub_sample_data__(stride)
[docs]
def post_process_data(self, results) -> None:
    """Transfer the ODE solver output back into the model objects.

    Parameters
    ----------
    results : scipy.integrate._ivp.ivp.OdeResult
        The results object returned by the ODE solver, containing solution
        time points (t) and state variables (y)

    Returns
    -------
    None
        The method updates model data structures in-place

    Notes
    -----
    The sequence below is order dependent:

    1. Signals are interpolated onto the solver time points while
       ``self.time`` still holds the previous time vector.
    2. State variables are mapped onto the reservoirs.
    3. ``self.time`` and ``self.time_u`` are replaced by the solver time.
    4. Flux data are updated for the number of solver time points.
    5. Specialized post-processing runs last.

    Interpolation of external data is intentionally not performed here.
    """
    # Must run before self.time is overwritten further down
    self._interpolate_signals_to_solver_timepoints(results)

    # FIXME: Is this needed? It messes with external data
    # self._interpolate_external_data_to_solver_timepoints(results)

    self._map_state_variables_to_reservoirs(results)

    # From here on the model time axis is the solver time axis
    self.time = results.t
    self.time_u = self.time * self.t_unit

    # Flux data are updated using the number of solver time points
    self._update_flux_data(len(results.t))

    self._perform_specialized_post_processing(results)
def _interpolate_signals_to_solver_timepoints(self, results) -> None:
    """Interpolate signal data to match solver time points.

    Every signal in ``self.los`` has its mass data ``signal_data.m``
    linearly interpolated from the current model time axis ``self.time``
    onto ``results.t``. For signals with isotopes, ``signal_data.l`` is
    interpolated the same way and ``signal_data.d`` is then recomputed
    with ``get_delta_from_concentration``.

    Parameters
    ----------
    results : scipy.integrate._ivp.ivp.OdeResult
        The ODE solver results

    Returns
    -------
    None
        Signal data are replaced in place.
    """
    # get_delta_from_concentration comes from the module-level import;
    # no function-local re-import is needed.
    solver_time = results.t
    for signal in self.los:
        data = signal.signal_data
        # Interpolate mass data
        data.m = np.interp(solver_time, self.time, data.m)
        # Interpolate isotope data if present
        if signal.isotopes:
            data.l = np.interp(solver_time, self.time, data.l)
            data.d = get_delta_from_concentration(data.m, data.l, signal.species.r)
def _interpolate_external_data_to_solver_timepoints(self, results) -> None:
    """Resample external data sets onto the solver time points.

    For every entry of ``self.led`` the ``y`` values are replaced by a
    linear interpolation of the old ``(x, y)`` pairs evaluated at
    ``results.t``. The ``x`` attribute itself is not modified.

    Parameters
    ----------
    results : scipy.integrate._ivp.ivp.OdeResult
        The ODE solver results
    """
    for dataset in self.led:
        resampled_y = np.interp(results.t, dataset.x, dataset.y)
        dataset.y = resampled_y
def _map_state_variables_to_reservoirs(self, results) -> None:
    """Copy the solver state vector into the reservoir objects.

    The rows of ``results.y`` are consumed in the order of ``self.icl``:
    one row for the concentration of each reservoir, followed by one
    additional row for ``l`` when the reservoir has isotopes.

    Parameters
    ----------
    results : scipy.integrate._ivp.ivp.OdeResult
        The ODE solver results
    """
    from esbmtk import GasReservoir

    row = 0  # index of the next unread row in results.y
    for res in self.icl:
        # Concentration, plain and with units attached
        res.c = results.y[row]
        res.c_u = res.c * self.c_unit

        # Density scaling applies only to reservoirs with an swc attribute
        if hasattr(res, "swc"):
            density = res.swc.density.m / 1000
        else:
            density = 1

        if not isinstance(res, GasReservoir):
            # Mass from concentration and volume (assumes constant volume;
            # variable volumes would need a different treatment)
            res.m = results.y[row] * res.volume.to(self.v_unit).m * density
        # FIXME: GasReservoirs do have a volume and the current code
        # assumes a constant mass, so their m is left as is
        res.m_u = res.m * self.m_unit

        row += 1

        if res.isotopes:
            # The isotope row directly follows the concentration row
            res.l = results.y[row]
            row += 1
            res.d = get_delta_from_concentration(res.c, res.l, res.sp.r)
def _update_flux_data(self, steps: int) -> None:
    """Truncate stored flux data to the number of solver time steps.

    Only fluxes with ``save_flux_data`` set are touched. Their ``m``
    data are cut to the first ``steps`` entries; for fluxes with
    isotopes ``l`` is cut the same way and ``d`` is recomputed with
    ``get_delta_h``.

    Parameters
    ----------
    steps : int
        Number of time steps in the solver results
    """
    keep = slice(0, steps)
    for flux in self.lof:
        if not flux.save_flux_data:
            continue
        flux.m = flux.m[keep]
        if flux.isotopes:
            flux.l = flux.l[keep]
            flux.d = get_delta_h(flux)
def _perform_specialized_post_processing(self, results) -> None:
    """Run the per-reservoir-group checks and derived calculations.

    For each group in ``self.lrg``: groups that have an ``Hplus``
    attribute are passed to ``self.test_d_pH`` together with the solver
    time points, and groups with ``has_cs1`` set are passed to
    ``carbonate_system_1_pp``.

    Parameters
    ----------
    results : scipy.integrate._ivp.ivp.OdeResult
        The ODE solver results
    """
    from esbmtk import carbonate_system_1_pp

    for group in self.lrg:
        tracks_hydrogen = hasattr(group, "Hplus")
        if tracks_hydrogen:
            # pH stability check needs the H+ data of the group
            self.test_d_pH(group, results.t)
        if group.has_cs1:
            # Carbonate system post-processing for this group
            carbonate_system_1_pp(group)
[docs]
def test_d_pH(self, reservoir_group: Species, time_vector: NDArrayFloat) -> None:
    """Test for large changes in pH between time steps.

    Checks whether the magnitude of the pH change between consecutive
    entries of the solution exceeds 0.01 units, in either direction. A
    warning is issued for every step where the threshold is exceeded, and
    ``self.now`` is incremented by one for each warning.

    Parameters
    ----------
    reservoir_group : Species
        The reservoir group containing a Hplus species to be checked
    time_vector : NDArrayFloat
        Time vector as returned by the solver

    Returns
    -------
    None
        Issues warnings if large pH changes are detected

    Notes
    -----
    This is a crude test since the solver interpolates between integration
    steps, so it may not catch all problems. It only identifies pH changes
    that exceed 0.01 units between the specific time points in the solution.
    The pH is calculated as -log10([H+]), where [H+] is the hydrogen ion
    concentration.
    """
    pH_threshold = 0.01

    # pH from the hydrogen ion concentration, then step-to-step differences
    pH_values = -np.log10(reservoir_group.Hplus.c)
    pH_changes = np.diff(pH_values)

    # Compare the magnitude: a drop in pH is as relevant as a rise
    for i in np.flatnonzero(np.abs(pH_changes) > pH_threshold):
        self.now = self.now + 1
        # The signed change is reported, at the time of the step's first point
        warnings.warn(
            f"{reservoir_group.full_name} delta pH = {pH_changes[i]:.2f} "
            f"at t = {time_vector[i]:.2f} {self.t_unit:~P}",
            # stacklevel=2,
        )
[docs]
def list_species(self) -> None:
    """Print every element of the model together with its species.

    A header line is printed first. Then, for each element in
    ``self.lel``, the element itself is printed, followed by a
    sub-heading and the ``n`` attribute of every entry in the
    element's ``lsp`` list.

    Returns
    -------
    None
        Output goes to stdout only.

    Examples
    --------
    >>> model.list_species()
    Currently defined elements and their species:
    Carbon
     Defined SpeciesProperties:
     DIC
     CO2
    """
    print("\nCurrently defined elements and their species:")
    for elem in self.lel:
        # Element line followed by the sub-heading for its species
        print(f"{elem}", " Defined SpeciesProperties:", sep="\n")
        for species_properties in elem.lsp:
            print(f" {species_properties.n}")
[docs]
def flux_summary(self, **kwargs: dict) -> list | None:
    """Display or return a filtered summary of model fluxes.

    Creates a report of fluxes in the model, filtered by name patterns.
    Can either print the results to the console or return them as a list.

    Parameters
    ----------
    **kwargs : dict
        Optional keyword arguments:

        filter_by : str, default=""
            Filter fluxes by name or partial name. The string is split on
            whitespace and the resulting terms are passed to
            ``find_matching_strings`` together with the flux name.
        exclude : str, default=""
            Exclude any fluxes whose names contain this string.
        return_list : bool, default=False
            If True, return a list of flux objects instead of printing
            to console.

    Returns
    -------
    list or None
        If return_list=True, returns a list of flux objects matching the
        filters (an empty list when nothing matches).
        If return_list=False, returns None (results are printed to console).

    Raises
    ------
    ModelError
        If the deprecated "filter" parameter is used instead of "filter_by".

    Examples
    --------
    # Display all fluxes containing "PO4" in their name
    >>> model.flux_summary(filter_by="PO4")
    # Get a list of fluxes containing both "POP" and "A_sb" in their names
    >>> fluxes = model.flux_summary(filter_by="POP A_sb", return_list=True)
    # Display fluxes containing "PO4" but not "H_sb"
    >>> model.flux_summary(filter_by="PO4", exclude="H_sb")
    """
    # FIXME: This needs proper keyword parsing!
    # Reject the deprecated keyword before doing any other work
    if "filter" in kwargs:
        raise ModelError("use filter_by instead of filter")

    filter_terms = kwargs.get("filter_by", "").split()
    exclude_term = kwargs.get("exclude", "")
    return_as_list = kwargs.get("return_list", False)

    # Print header if displaying results
    if not return_as_list:
        print(f"\n --- Flux Summary -- filtered by {filter_terms}\n")

    # A flux is kept when its name matches the filter terms and does not
    # contain the (optional) exclude term
    matching_fluxes = [
        flux
        for flux in self.lof
        if find_matching_strings(flux.full_name, filter_terms)
        and not (exclude_term and exclude_term in flux.full_name)
    ]

    # NOTE(review): an earlier revision tested for an empty result directly
    # after appending a match, so that test could never fire. Callers have
    # therefore always received an empty list when nothing matches; that
    # behaviour is kept. Confirm whether an error would be preferable.
    if return_as_list:
        return matching_fluxes

    for flux in matching_fluxes:
        print(f"{flux.full_name}")
    return None
[docs]
def connection_summary(self, **kwargs) -> None:
    """Print an overview of the connections defined in the model.

    The connections are selected with ``self._filter_connections``. If
    none are selected, ``self._report_no_connections`` is called;
    otherwise an empty line is printed and each connection is handed to
    ``self._display_connection_info``.

    Parameters
    ----------
    **kwargs : dict
        Optional keyword arguments:

        filter_by : str, default=None
            If provided, only show connections containing this substring
            in their name.
        list_all : bool, default=False
            Forwarded to ``_display_connection_info`` as the flag that
            requests all connection attributes.

    Returns
    -------
    None
        This method prints to stdout but doesn't return a value.

    Examples
    --------
    >>> model.connection_summary()  # Show all connections
    >>> model.connection_summary(filter_by="DIC")  # Show only DIC connections
    >>> model.connection_summary(list_all=True)  # Show all connection details
    """
    name_filter = kwargs.get("filter_by")
    list_everything = kwargs.get("list_all", False)

    selected = self._filter_connections(name_filter)
    if selected:
        print("")
        for conn in selected:
            self._display_connection_info(conn, list_everything)
    else:
        # Nothing to show: tell the user which filter came up empty
        self._report_no_connections(name_filter)
def _filter_connections(self, name_filter: str = None) -> list:
    """Select the connections whose name contains a given substring.

    Parameters
    ----------
    name_filter : str, default=None
        Substring to match in connection names. With ``None`` every
        connection in ``self.loc`` is selected.

    Returns
    -------
    list
        New list with the selected connection objects, in the order in
        which they appear in ``self.loc``
    """
    return [
        conn
        for conn in self.loc
        # No filter selects everything; otherwise plain substring search
        if name_filter is None or name_filter in conn.name
    ]
def _report_no_connections(self, name_filter: str = None) -> None:
    """Print a notice that no connection was found.

    Parameters
    ----------
    name_filter : str, default=None
        The filter string that was used (if any). When given, it is
        quoted in the printed message.
    """
    message = (
        "No connections found"
        if name_filter is None
        else f"No connections with name '{name_filter}' found"
    )
    print(message)
def _display_connection_info(self, connection, show_all_attributes: bool) -> None:
    """Print one line describing a connection, followed by an empty line.

    Parameters
    ----------
    connection : Connection
        The connection object to display
    show_all_attributes : bool
        Whether to show all attributes of the connection. Currently not
        used, because the attribute listing below is disabled.
    """
    from esbmtk import Species2Species

    source = connection.source_name
    target = connection.sink_name

    if isinstance(connection, Species2Species):
        # Species-to-species connections also name the species on each side
        endpoints = (
            f"{source}.{connection.source.sp.n} -> {target}.{connection.sink.sp.n}"
        )
    else:
        # Any other connection type is shown by source and sink name only
        endpoints = f"{source} -> {target}"
    print(f"Connection: {connection.full_name}: {endpoints}")

    # Attribute listing is disabled:
    # self._display_connection_attributes(connection, show_all_attributes)

    # Blank line between connections for readability
    print("")
def _display_connection_attributes(
    self, connection, show_all_attributes: bool
) -> None:
    """Print the attributes stored on a connection object.

    Parameters
    ----------
    connection : Connection
        The connection object
    show_all_attributes : bool
        When true, the complete ``__dict__`` of the connection is printed
        on a single line. Otherwise one line per attribute is printed,
        leaving out names that start with an underscore, a fixed set of
        structural names, and every name listed in ``self.olkk``.
    """
    attributes = connection.__dict__

    if show_all_attributes:
        print(f" {attributes}")
        return

    hidden_names = (
        "source",
        "target",
        "flux",
        "source_name",
        "target_name",
        *self.olkk,  # Optional keywords to exclude
    )
    for key, value in attributes.items():
        if key[0] == "_" or key in hidden_names:
            continue
        print(f" {key}: {value}")
[docs]
def clear(self):
    """Delete all model objects.

    Every name listed in ``self.lmo`` is announced on stdout and then
    removed from the ``builtins`` namespace. Names that are no longer
    present there are skipped silently, so calling the method twice is
    safe.
    """
    # Go through the builtins module rather than ``__builtins__``: the
    # latter is a CPython implementation detail and is a module, not a
    # dict, when the code runs as ``__main__``.
    import builtins

    builtin_namespace = vars(builtins)
    for name in self.lmo:
        print(f"deleting {name}")
        builtin_namespace.pop(name, None)
def __init_dimensionalities__(self, ureg):
"""No longer needed."""
raise NotImplementedError()
"""Test the dimensionality of input data."""
self.substance_per_volume_d = ureg("mol/liter").dimensionality
self.substance_per_mass_d = ureg("mol/kg").dimensionality
self.substance_d = ureg("mol").dimensionality
self.mass_d = ureg("kg").dimensionality
self.length_d = ureg("m").dimensionality
self.flux_d = ureg("mol/s").dimensionality
self.time_d = ureg("s").dimensionality