Skip to content

Commit

Permalink
Montecarlo globals (limited scope) (tardis-sn#2705)
Browse files Browse the repository at this point in the history
* Initial restructure of configs

* Continuum processes to global

* Fixes mistakes

* RPacket tracking global

* Split tests for rpacket tracking, move codecov detection for now
  • Loading branch information
andrewfullard authored and sarthak-dv committed Jul 16, 2024
1 parent da17151 commit 75cddc2
Show file tree
Hide file tree
Showing 33 changed files with 83 additions and 137 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,22 @@ jobs:
echo "TARDIS_PIP_PATH=$directory_path" >> $GITHUB_ENV
- name: Run tests
run: pytest tardis ${{ env.PYTEST_FLAGS }} -m "not continuum"
run: pytest tardis ${{ env.PYTEST_FLAGS }} -m "not (continuum or rpacket_tracking)"
working-directory: ${{ env.TARDIS_PIP_PATH }}
if: always()

- name: Upload to Codecov
run: bash <(curl -s https://codecov.io/bash)

- name: Run continuum tests
run: pytest tardis ${{ env.PYTEST_FLAGS }} -m continuum
working-directory: ${{ env.TARDIS_PIP_PATH }}
if: always()

- name: Upload to Codecov
run: bash <(curl -s https://codecov.io/bash)
- name: Run rpacket tracking tests
run: pytest tardis ${{ env.PYTEST_FLAGS }} -m rpacket_tracking
working-directory: ${{ env.TARDIS_PIP_PATH }}
if: always()

- name: Refdata Generation tests
run: pytest tardis ${{ env.PYTEST_FLAGS }} --generate-reference
Expand Down
32 changes: 12 additions & 20 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from tardis.simulation import Simulation
from tardis.tests.fixtures.atom_data import DEFAULT_ATOM_DATA_UUID
from tardis.tests.fixtures.regression_data import RegressionData
from tardis.transport.montecarlo import RPacket, montecarlo_configuration
from tardis.transport.montecarlo import RPacket
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.configuration.base import (
MonteCarloConfiguration,
)
from tardis.transport.montecarlo.estimators import radfield_mc_estimators
from tardis.transport.montecarlo.numba_interface import opacity_state_initialize
from tardis.transport.montecarlo.packet_collections import (
Expand Down Expand Up @@ -235,9 +239,7 @@ def packet(self):

@property
def verysimple_packet_collection(self):
return (
self.nb_simulation_verysimple.transport.transport_state.packet_collection
)
return self.nb_simulation_verysimple.transport.transport_state.packet_collection

@property
def nb_simulation_verysimple(self):
Expand All @@ -259,7 +261,6 @@ def verysimple_opacity_state(self):
self.nb_simulation_verysimple.plasma,
line_interaction_type="macroatom",
disable_line_scattering=self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING,
continuum_processes_enabled=self.nb_simulation_verysimple.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)

@property
Expand All @@ -268,27 +269,19 @@ def verysimple_enable_full_relativity(self):

@property
def verysimple_disable_line_scattering(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING
)
return self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING

@property
def verysimple_continuum_processes_enabled(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED
)
return montecarlo_globals.CONTINUUM_PROCESSES_ENABLED

@property
def verysimple_tau_russian(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
)
return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN

@property
def verysimple_survival_probability(self):
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
)
return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY

@property
def static_packet(self):
Expand Down Expand Up @@ -359,10 +352,10 @@ def verysimple_radfield_mc_estimators(self):

@property
def montecarlo_configuration(self):
return montecarlo_configuration.MonteCarloConfiguration()
return MonteCarloConfiguration()

@property
def rpacket_tracker(self):
def rpacket_tracker(self):
return RPacketTracker(0)

@property
Expand Down Expand Up @@ -396,7 +389,6 @@ def geometry(self):
v_outer=np.array([-1, -1], dtype=np.float64),
)


@property
def estimators(self):
return radfield_mc_estimators.RadiationFieldMCEstimators(
Expand Down
2 changes: 0 additions & 2 deletions benchmarks/transport_montecarlo_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from asv_runner.benchmarks.mark import parameterize



class BenchmarkMontecarloMontecarloNumbaInteraction(BenchmarkBase):
"""
Class to benchmark the numba interaction function.
Expand Down Expand Up @@ -52,7 +51,6 @@ def time_line_scatter(self, line_interaction_type):
line_interaction_type,
self.verysimple_opacity_state,
self.verysimple_enable_full_relativity,
self.verysimple_continuum_processes_enabled,
)

@parameterize(
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/transport_montecarlo_numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,4 @@ def time_opacity_state_initialize(self, input_params):
plasma,
line_interaction_type,
self.verysimple_disable_line_scattering,
self.verysimple_continuum_processes_enabled,
)
)
27 changes: 4 additions & 23 deletions benchmarks/transport_montecarlo_vpacket.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def v_packet(self):
next_line_id=0,
index=0,
)

@property
def r_packet(self):
return RPacket(
Expand Down Expand Up @@ -62,9 +62,6 @@ def time_trace_vpacket_within_shell(self):
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
enable_full_relativity = self.verysimple_enable_full_relativity
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)

# Give the vpacket a reasonable line ID
self.v_packet_initialize_line_id(
Expand All @@ -80,7 +77,6 @@ def time_trace_vpacket_within_shell(self):
verysimple_time_explosion,
verysimple_opacity_state,
enable_full_relativity,
continuum_processes_enabled,
)

def time_trace_vpacket(self):
Expand All @@ -91,9 +87,6 @@ def time_trace_vpacket(self):
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
enable_full_relativity = self.verysimple_enable_full_relativity
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)
tau_russian = self.verysimple_tau_russian
survival_probability = self.verysimple_survival_probability

Expand All @@ -116,7 +109,6 @@ def time_trace_vpacket(self):
tau_russian,
survival_probability,
enable_full_relativity,
continuum_processes_enabled,
)

@property
Expand All @@ -139,9 +131,6 @@ def time_trace_bad_vpacket(self):
enable_full_relativity = self.verysimple_enable_full_relativity
verysimple_time_explosion = self.verysimple_time_explosion
verysimple_opacity_state = self.verysimple_opacity_state
continuum_processes_enabled = (
self.verysimple_continuum_processes_enabled
)
tau_russian = self.verysimple_tau_russian
survival_probability = self.verysimple_survival_probability

Expand All @@ -153,20 +142,13 @@ def time_trace_bad_vpacket(self):
tau_russian,
survival_probability,
enable_full_relativity,
continuum_processes_enabled,
)

@parameterize(
{
"Paramters": [
{
"tau_russian": 10.0,
"survival_possibility": 0.0
},
{
"tau_russian": 15.0,
"survival_possibility": 0.1
},
{"tau_russian": 10.0, "survival_possibility": 0.0},
{"tau_russian": 15.0, "survival_possibility": 0.1},
]
}
)
Expand All @@ -180,6 +162,5 @@ def time_trace_vpacket_volley(self, parameters):
False,
parameters["tau_russian"],
parameters["survival_possibility"],
False
False,
)

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ text_file_format = "rst"
markers = [
# continuum tests
"continuum",
# rpacket tracking tests
"rpacket_tracking"
]

[tool.tardis]
Expand Down
2 changes: 1 addition & 1 deletion tardis/opacities/opacities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
SIGMA_THOMSON,
)

Expand Down
6 changes: 3 additions & 3 deletions tardis/opacities/opacity_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numba.experimental import jitclass

from tardis.opacities.tau_sobolev import calculate_sobolev_line_opacity
from tardis.transport.montecarlo.configuration import montecarlo_globals

opacity_state_spec = [
("electron_density", float64[:]),
Expand Down Expand Up @@ -110,7 +111,6 @@ def opacity_state_initialize(
plasma,
line_interaction_type,
disable_line_scattering,
continuum_processes_enabled,
):
"""
Initialize the OpacityState object and copy over the data over from TARDIS Plasma
Expand Down Expand Up @@ -156,7 +156,7 @@ def opacity_state_initialize(
)
# TODO: Fix setting of block references for non-continuum mode

if continuum_processes_enabled:
if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
macro_block_references = plasma.macro_block_references
else:
macro_block_references = plasma.atomic_data.macro_atom_references[
Expand All @@ -169,7 +169,7 @@ def opacity_state_initialize(
"destination_level_idx"
].values
transition_line_id = plasma.macro_atom_data["lines_idx"].values
if continuum_processes_enabled:
if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
bf_threshold_list_nu = plasma.nu_i.loc[
plasma.level2continuum_idx.index
].values
Expand Down
3 changes: 2 additions & 1 deletion tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tardis.plasma.standard_plasmas import assemble_plasma
from tardis.simulation.convergence import ConvergenceSolver
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots

Expand Down Expand Up @@ -199,7 +200,7 @@ def __init__(
self._callbacks = OrderedDict()
self._cb_next_id = 0

self.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED = (
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED = (
not self.plasma.continuum_interaction_species.empty
)

Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/frame_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.numba_config import C_SPEED_OF_LIGHT
from tardis.transport.montecarlo.configuration.constants import C_SPEED_OF_LIGHT


@njit(**njit_dict_no_parallel)
Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/geometry/calculate_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
C_SPEED_OF_LIGHT,
MISS_DISTANCE,
SIGMA_THOMSON,
Expand Down
18 changes: 9 additions & 9 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
from astropy import units as u
from numba import cuda, set_num_threads

import tardis.transport.montecarlo.configuration.constants as constants
from tardis import constants as const
from tardis.io.logger import montecarlo_tracking as mc_tracker
from tardis.io.util import HDFWriterMixin
from tardis.transport.montecarlo import (
montecarlo_main_loop,
numba_config,
)
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.configuration.base import (
MonteCarloConfiguration,
configuration_initialize,
)
from tardis.transport.montecarlo.estimators.radfield_mc_estimators import (
initialize_estimator_statistics,
)
from tardis.transport.montecarlo.formal_integral import FormalIntegrator
from tardis.transport.montecarlo.montecarlo_configuration import (
MonteCarloConfiguration,
configuration_initialize,
)
from tardis.transport.montecarlo.montecarlo_transport_state import (
MonteCarloTransportState,
)
Expand Down Expand Up @@ -116,7 +117,6 @@ def initialize_transport_state(
plasma,
self.line_interaction_type,
self.montecarlo_configuration.DISABLE_LINE_SCATTERING,
self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)
transport_state = MonteCarloTransportState(
packet_collection,
Expand Down Expand Up @@ -211,7 +211,7 @@ def run(
update_iterations_pbar(1)
refresh_packet_pbar()
# Condition for Checking if RPacket Tracking is enabled
if self.montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if self.enable_rpacket_tracking:
transport_state.rpacket_tracker = rpacket_trackers

if self.transport_state.rpacket_tracker is not None:
Expand Down Expand Up @@ -246,10 +246,10 @@ def from_config(
"Likely bug in formal integral - "
"will not give same results."
)
numba_config.SIGMA_THOMSON = 1e-200
constants.SIGMA_THOMSON = 1e-200
else:
logger.debug("Electron scattering switched on")
numba_config.SIGMA_THOMSON = const.sigma_T.to("cm^2").value
constants.SIGMA_THOMSON = const.sigma_T.to("cm^2").value

spectrum_frequency = quantity_linspace(
config.spectrum.stop.to("Hz", u.spectral()),
Expand Down
Empty file.
Loading

0 comments on commit 75cddc2

Please sign in to comment.