Skip to content

Commit

Permalink
Convergence restructure (#2518)
Browse files Browse the repository at this point in the history
* Initial basic convergence class

* Remove convergence strategy type handling from simulation

Also fix formatting

* Docstrings

* Remove numba from function (for now?)

* Fix docstrings

* Adds tests to convergence solver

* black format

* Resolve comments

* Formatting fix

* Move to 1 shell for t_inner test
  • Loading branch information
andrewfullard authored Jun 7, 2024
1 parent 189b3a0 commit 41bcf24
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 57 deletions.
2 changes: 1 addition & 1 deletion tardis/io/configuration/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def parse_convergence_section(convergence_section_dict):
convergence_section_dict : dict
dictionary
"""
convergence_parameters = ["damping_constant", "threshold"]
convergence_parameters = ["damping_constant", "threshold", "type"]

for convergence_variable in ["t_inner", "t_rad", "w"]:
if convergence_variable not in convergence_section_dict:
Expand Down
86 changes: 30 additions & 56 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tardis.model.parse_input import initialize_packet_source
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.plasma.standard_plasmas import assemble_plasma
from tardis.simulation.convergence import ConvergenceSolver
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots

Expand Down Expand Up @@ -152,21 +153,21 @@ def __init__(
self.show_progress_bars = show_progress_bars
self.version = tardis.__version__

if convergence_strategy.type in ("damped"):
self.convergence_strategy = convergence_strategy
self.converged = False
self.consecutive_converges_count = 0
elif convergence_strategy.type in ("custom"):
raise NotImplementedError(
"Convergence strategy type is custom; "
"you need to implement your specific treatment!"
)
else:
raise ValueError(
f"Convergence strategy type is "
f"not damped or custom "
f"- input is {convergence_strategy.type}"
)
# Convergence
self.convergence_strategy = convergence_strategy
self.converged = False
self.consecutive_converges_count = 0

# Convergence solvers
self.t_rad_convergence_solver = ConvergenceSolver(
self.convergence_strategy.t_rad
)
self.w_convergence_solver = ConvergenceSolver(
self.convergence_strategy.w
)
self.t_inner_convergence_solver = ConvergenceSolver(
self.convergence_strategy.t_inner
)

if show_convergence_plots:
if not is_notebook():
Expand Down Expand Up @@ -215,48 +216,25 @@ def estimate_t_inner(

return input_t_inner * luminosity_ratios**t_inner_update_exponent

@staticmethod
def damped_converge(value, estimated_value, damping_factor):
# FIXME: Should convergence strategy have its own class containing this
# as a method
return value + damping_factor * (estimated_value - value)

def _get_convergence_status(
self, t_rad, w, t_inner, estimated_t_rad, estimated_w, estimated_t_inner
):
# FIXME: Move the convergence checking in its own class.
no_of_shells = self.simulation_state.no_of_shells

convergence_t_rad = (
abs(t_rad - estimated_t_rad) / estimated_t_rad
).value
convergence_w = abs(w - estimated_w) / estimated_w
convergence_t_inner = (
abs(t_inner - estimated_t_inner) / estimated_t_inner
).value

fraction_t_rad_converged = (
np.count_nonzero(
convergence_t_rad < self.convergence_strategy.t_rad.threshold
)
/ no_of_shells
)

t_rad_converged = (
fraction_t_rad_converged > self.convergence_strategy.fraction
t_rad_converged = self.t_rad_convergence_solver.get_convergence_status(
t_rad.value,
estimated_t_rad.value,
self.simulation_state.no_of_shells,
)

fraction_w_converged = (
np.count_nonzero(
convergence_w < self.convergence_strategy.w.threshold
)
/ no_of_shells
w_converged = self.w_convergence_solver.get_convergence_status(
w, estimated_w, self.simulation_state.no_of_shells
)

w_converged = fraction_w_converged > self.convergence_strategy.fraction

t_inner_converged = (
convergence_t_inner < self.convergence_strategy.t_inner.threshold
self.t_inner_convergence_solver.get_convergence_status(
t_inner.value,
estimated_t_inner.value,
1,
)
)

if np.all([t_rad_converged, w_converged, t_inner_converged]):
Expand Down Expand Up @@ -304,24 +282,20 @@ def advance_state(self):
)

# calculate_next_plasma_state equivalent
# FIXME: Should convergence strategy have its own class?
next_t_radiative = self.damped_converge(
next_t_radiative = self.t_rad_convergence_solver.converge(
self.simulation_state.t_radiative,
estimated_t_rad,
self.convergence_strategy.t_rad.damping_constant,
)
next_dilution_factor = self.damped_converge(
next_dilution_factor = self.w_convergence_solver.converge(
self.simulation_state.dilution_factor,
estimated_dilution_factor,
self.convergence_strategy.w.damping_constant,
)
if (
self.iterations_executed + 1
) % self.convergence_strategy.lock_t_inner_cycles == 0:
next_t_inner = self.damped_converge(
next_t_inner = self.t_inner_convergence_solver.converge(
self.simulation_state.t_inner,
estimated_t_inner,
self.convergence_strategy.t_inner.damping_constant,
)
else:
next_t_inner = self.simulation_state.t_inner
Expand Down
78 changes: 78 additions & 0 deletions tardis/simulation/convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np


class ConvergenceSolver:
def __init__(self, strategy):
"""Convergence solver. Sets convergence strategy and assigns a method
to the converge property.
Parameters
----------
strategy : string
Convergence strategy for the physical property
Raises
------
NotImplementedError
Custom convergence type specified
ValueError
Unknown convergence type specified
"""
self.convergence_strategy = strategy
self.damping_factor = self.convergence_strategy.damping_constant
self.threshold = self.convergence_strategy.threshold

if self.convergence_strategy.type in ("damped"):
self.converge = self.damped_converge
elif self.convergence_strategy.type in ("custom"):
raise NotImplementedError(
"Convergence strategy type is custom; "
"you need to implement your specific treatment!"
)
else:
raise ValueError(
f"Convergence strategy type is "
f"not damped or custom "
f"- input is {self.convergence_strategy.type}"
)

def damped_converge(self, value, estimated_value):
"""Damped convergence solver
Parameters
----------
value : np.float64
The current value of the physical property
estimated_value : np.float64
The estimated value of the physical property
Returns
-------
np.float64
The converged value
"""
return value + self.damping_factor * (estimated_value - value)

def get_convergence_status(self, value, estimated_value, no_of_cells):
"""Get the status of convergence for the physical property
Parameters
----------
value : np.float64, Quantity
The current value of the physical property
estimated_value : np.float64, Quantity
The estimated value of the physical property
no_of_cells : np.int64
The number of cells to measure convergence over
Returns
-------
bool
True if convergence is reached
"""
convergence = abs(value - estimated_value) / estimated_value

fraction_converged = (
np.count_nonzero(convergence < self.threshold) / no_of_cells
)
return fraction_converged > self.threshold
65 changes: 65 additions & 0 deletions tardis/simulation/tests/test_convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from pathlib import Path

import numpy as np
import numpy.testing as npt
import pytest

from tardis.io.configuration.config_reader import Configuration
from tardis.simulation.convergence import ConvergenceSolver


@pytest.fixture(scope="function")
def config(example_configuration_dir: Path):
return Configuration.from_yaml(
example_configuration_dir / "tardis_configv1_verysimple.yml"
)


@pytest.fixture(scope="function")
def strategy(config):
return config.montecarlo.convergence_strategy.t_rad


def test_convergence_solver_init_damped(strategy):
solver = ConvergenceSolver(strategy)
assert solver.damping_factor == 0.5
assert solver.threshold == 0.05
assert solver.converge == solver.damped_converge


def test_convergence_solver_init_custom(strategy):
strategy.type = "custom"
with pytest.raises(NotImplementedError):
ConvergenceSolver(strategy)


def test_convergence_solver_init_invalid(strategy):
strategy.type = "invalid"
with pytest.raises(ValueError):
ConvergenceSolver(strategy)


def test_damped_converge(strategy):
solver = ConvergenceSolver(strategy)
value = np.float64(10.0)
estimated_value = np.float64(20.0)
converged_value = solver.damped_converge(value, estimated_value)
npt.assert_almost_equal(converged_value, 15.0)


def test_get_convergence_status(strategy):
solver = ConvergenceSolver(strategy)
value = np.array([1.0, 2.0, 3.0], dtype=np.float64)
estimated_value = np.array([1.01, 2.02, 3.03], dtype=np.float64)
no_of_cells = np.int64(3)
is_converged = solver.get_convergence_status(
value, estimated_value, no_of_cells
)
assert is_converged

value = np.array([1.0, 2.0, 3.0], dtype=np.float64)
estimated_value = np.array([2.0, 3.0, 4.0], dtype=np.float64)
is_converged = solver.get_convergence_status(
value, estimated_value, no_of_cells
)
assert not is_converged

0 comments on commit 41bcf24

Please sign in to comment.