Skip to content

Commit

Permalink
Some suggested refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed Aug 16, 2024
1 parent e36ccd5 commit 79df0e3
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 49 deletions.
74 changes: 50 additions & 24 deletions tardis/workflows/standard_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
class StandardSimulation(
SimpleSimulation, PlasmaStateStorerMixin, HDFWriterMixin
):
convergence_plots = None
export_convergence_plots = False

hdf_properties = [
"simulation_state",
"plasma_solver",
Expand All @@ -41,6 +44,7 @@ def __init__(
self.log_level = log_level
self.specific_log_level = specific_log_level
self.enable_virtual_packet_logging = enable_virtual_packet_logging
self.convergence_plots_kwargs = convergence_plots_kwargs

SimpleSimulation.__init__(self, configuration)

Expand All @@ -53,31 +57,53 @@ def __init__(

# Convergence plots
if show_convergence_plots:
if not is_notebook():
raise RuntimeError(
"Convergence Plots cannot be displayed in command-line. Set show_convergence_plots "
"to False."
)
(
self.convergence_plots,
self.export_convergence_plots,
) = self.initialize_convergence_plots()

def initialize_convergence_plots(self):
"""Initialize the convergence plot attributes
self.convergence_plots = ConvergencePlots(
iterations=self.total_iterations, **convergence_plots_kwargs
Returns
-------
ConvergencePlots
The convergence plot instance
bool
If convergence plots are to be exported
Raises
------
RuntimeError
Raised if run outside a notebook
TypeError
Raised if export_convergence_plots is not a bool
"""
if not is_notebook():
raise RuntimeError(
"Convergence Plots cannot be displayed in command-line. Set show_convergence_plots "
"to False."
)
else:
self.convergence_plots = None

if "export_convergence_plots" in convergence_plots_kwargs:
convergence_plots = ConvergencePlots(
iterations=self.total_iterations, **self.convergence_plots_kwargs
)

if "export_convergence_plots" in self.convergence_plots_kwargs:
if not isinstance(
convergence_plots_kwargs["export_convergence_plots"],
self.convergence_plots_kwargs["export_convergence_plots"],
bool,
):
raise TypeError(
"Expected bool in export_convergence_plots argument"
)
self.export_convergence_plots = convergence_plots_kwargs[
export_convergence_plots = self.convergence_plots_kwargs[
"export_convergence_plots"
]
else:
self.export_convergence_plots = False
export_convergence_plots = False

return convergence_plots, export_convergence_plots

def get_convergence_estimates(self, transport_state):
"""Compute convergence estimates from the transport state
Expand Down Expand Up @@ -124,6 +150,16 @@ def get_convergence_estimates(self, transport_state):
self.luminosity_nu_end,
)

luminosity_ratios = (
(emitted_luminosity / self.luminosity_requested).to(1).value
)

estimated_t_inner = (
self.simulation_state.t_inner
* luminosity_ratios
** self.convergence_strategy.t_inner_update_exponent
)

if self.convergence_plots is not None:
plot_data = {
"t_inner": [self.simulation_state.t_inner.value, "value"],
Expand All @@ -134,24 +170,14 @@ def get_convergence_estimates(self, transport_state):
"Absorbed": [absorbed_luminosity.value, "value"],
"Requested": [self.luminosity_requested.value, "value"],
}
self.update_convergence_plot_data(plot_data)
self.update_convergence_plot_data(plot_data)

logger.info(
f"\n\tLuminosity emitted = {emitted_luminosity:.3e}\n"
f"\tLuminosity absorbed = {absorbed_luminosity:.3e}\n"
f"\tLuminosity requested = {self.luminosity_requested:.3e}\n"
)

luminosity_ratios = (
(emitted_luminosity / self.luminosity_requested).to(1).value
)

estimated_t_inner = (
self.simulation_state.t_inner
* luminosity_ratios
** self.convergence_strategy.t_inner_update_exponent
)

self.log_plasma_state(
self.simulation_state.t_radiative,
self.simulation_state.dilution_factor,
Expand Down
66 changes: 41 additions & 25 deletions tardis/workflows/workflow_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,34 +58,50 @@ def log_plasma_state(
plasma_state_log["next_w"] = next_dilution_factor
plasma_state_log.columns.name = "Shell No."

logger.info("\n\tPlasma stratification:")

if is_notebook():
logger.info("\n\tPlasma stratification:")

# Displaying the DataFrame only when the logging level is NOTSET, DEBUG or INFO
if logger.level <= logging.INFO:
if not logger.filters:
display(
plasma_state_log.iloc[::log_sampling].style.format(
"{:.3g}"
)
)
elif logger.filters[0].log_level == 20:
display(
plasma_state_log.iloc[::log_sampling].style.format(
"{:.3g}"
)
)
self.log_dataframe_notebook(plasma_state_log, log_sampling)
else:
output_df = ""
plasma_output = plasma_state_log.iloc[::log_sampling].to_string(
float_format=lambda x: f"{x:.3g}",
justify="center",
)
for value in plasma_output.split("\n"):
output_df = output_df + f"\t{value}\n"
logger.info("\n\tPlasma stratification:")
logger.info(f"\n{output_df}")
self.log_dataframe_console(plasma_state_log, log_sampling)

logger.info(
f"\n\tCurrent t_inner = {t_inner:.3f}\n\tExpected t_inner for next iteration = {next_t_inner:.3f}\n"
)

def log_dataframe_notebook(self, dataframe, step):
"""Logs a dataframe in a notebook with a step sample
Parameters
----------
dataframe : pd.DataFrame
Dataframe to display
step : int
Step to use when sampling the dataframe
"""
# Displaying the DataFrame only when the logging level is NOTSET, DEBUG or INFO
if logger.level <= logging.INFO:
if not logger.filters:
display(dataframe.iloc[::step].style.format("{:.3g}"))
elif logger.filters[0].log_level == 20:
display(dataframe.iloc[::step].style.format("{:.3g}"))

def log_dataframe_console(self, dataframe, step):
"""Logs a dataframe to console with a step sample
Parameters
----------
dataframe : pd.DataFrame
Dataframe to display
step : int
Step to use when sampling the dataframe
"""
output_df = ""
output = dataframe.iloc[::step].to_string(
float_format=lambda x: f"{x:.3g}",
justify="center",
)
for value in output.split("\n"):
output_df = output_df + f"\t{value}\n"

logger.info(f"\n{output_df}")

0 comments on commit 79df0e3

Please sign in to comment.