-
-
Notifications
You must be signed in to change notification settings - Fork 426
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial basic convergence class * Remove convergence strategy type handling from simulation Also fix formatting * Docstrings * Remove numba from function (for now?) * Fix docstrings * Adds tests to convergence solver * black format * Resolve comments * Formatting fix * Move to 1 shell for t_inner test
- Loading branch information
1 parent
189b3a0
commit 41bcf24
Showing
4 changed files
with
174 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
|
||
|
||
class ConvergenceSolver: | ||
def __init__(self, strategy): | ||
"""Convergence solver. Sets convergence strategy and assigns a method | ||
to the converge property. | ||
Parameters | ||
---------- | ||
strategy : string | ||
Convergence strategy for the physical property | ||
Raises | ||
------ | ||
NotImplementedError | ||
Custom convergence type specified | ||
ValueError | ||
Unknown convergence type specified | ||
""" | ||
self.convergence_strategy = strategy | ||
self.damping_factor = self.convergence_strategy.damping_constant | ||
self.threshold = self.convergence_strategy.threshold | ||
|
||
if self.convergence_strategy.type in ("damped"): | ||
self.converge = self.damped_converge | ||
elif self.convergence_strategy.type in ("custom"): | ||
raise NotImplementedError( | ||
"Convergence strategy type is custom; " | ||
"you need to implement your specific treatment!" | ||
) | ||
else: | ||
raise ValueError( | ||
f"Convergence strategy type is " | ||
f"not damped or custom " | ||
f"- input is {self.convergence_strategy.type}" | ||
) | ||
|
||
def damped_converge(self, value, estimated_value): | ||
"""Damped convergence solver | ||
Parameters | ||
---------- | ||
value : np.float64 | ||
The current value of the physical property | ||
estimated_value : np.float64 | ||
The estimated value of the physical property | ||
Returns | ||
------- | ||
np.float64 | ||
The converged value | ||
""" | ||
return value + self.damping_factor * (estimated_value - value) | ||
|
||
def get_convergence_status(self, value, estimated_value, no_of_cells): | ||
"""Get the status of convergence for the physical property | ||
Parameters | ||
---------- | ||
value : np.float64, Quantity | ||
The current value of the physical property | ||
estimated_value : np.float64, Quantity | ||
The estimated value of the physical property | ||
no_of_cells : np.int64 | ||
The number of cells to measure convergence over | ||
Returns | ||
------- | ||
bool | ||
True if convergence is reached | ||
""" | ||
convergence = abs(value - estimated_value) / estimated_value | ||
|
||
fraction_converged = ( | ||
np.count_nonzero(convergence < self.threshold) / no_of_cells | ||
) | ||
return fraction_converged > self.threshold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import numpy.testing as npt | ||
import pytest | ||
|
||
from tardis.io.configuration.config_reader import Configuration | ||
from tardis.simulation.convergence import ConvergenceSolver | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def config(example_configuration_dir: Path): | ||
return Configuration.from_yaml( | ||
example_configuration_dir / "tardis_configv1_verysimple.yml" | ||
) | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def strategy(config): | ||
return config.montecarlo.convergence_strategy.t_rad | ||
|
||
|
||
def test_convergence_solver_init_damped(strategy): | ||
solver = ConvergenceSolver(strategy) | ||
assert solver.damping_factor == 0.5 | ||
assert solver.threshold == 0.05 | ||
assert solver.converge == solver.damped_converge | ||
|
||
|
||
def test_convergence_solver_init_custom(strategy): | ||
strategy.type = "custom" | ||
with pytest.raises(NotImplementedError): | ||
ConvergenceSolver(strategy) | ||
|
||
|
||
def test_convergence_solver_init_invalid(strategy): | ||
strategy.type = "invalid" | ||
with pytest.raises(ValueError): | ||
ConvergenceSolver(strategy) | ||
|
||
|
||
def test_damped_converge(strategy): | ||
solver = ConvergenceSolver(strategy) | ||
value = np.float64(10.0) | ||
estimated_value = np.float64(20.0) | ||
converged_value = solver.damped_converge(value, estimated_value) | ||
npt.assert_almost_equal(converged_value, 15.0) | ||
|
||
|
||
def test_get_convergence_status(strategy): | ||
solver = ConvergenceSolver(strategy) | ||
value = np.array([1.0, 2.0, 3.0], dtype=np.float64) | ||
estimated_value = np.array([1.01, 2.02, 3.03], dtype=np.float64) | ||
no_of_cells = np.int64(3) | ||
is_converged = solver.get_convergence_status( | ||
value, estimated_value, no_of_cells | ||
) | ||
assert is_converged | ||
|
||
value = np.array([1.0, 2.0, 3.0], dtype=np.float64) | ||
estimated_value = np.array([2.0, 3.0, 4.0], dtype=np.float64) | ||
is_converged = solver.get_convergence_status( | ||
value, estimated_value, no_of_cells | ||
) | ||
assert not is_converged |