Skip to content

Commit

Permalink
Avoid unnecessary context creation in HREX, using `BoundPotential::se…
Browse files Browse the repository at this point in the history
…t_params` (#1151)

* Default to 1 ps per hrex iteration

* Avoid context creation using set_params

* Allow initializing barostat with volume scale factor

Expose adaptive_scaling_enabled ctor argument in Python API

* Add getters for integrator, potentials, barostat to context

* Use fixed volume scale factor for HREX RBFE simulations

* Allow passing None for initial_volume_scale_factor

* Remove redundant constructor argument defaults

Prefer defining defaults in Python API

* Fix outdated/misleading docs
  • Loading branch information
mcwitt authored Sep 25, 2023
1 parent 5951327 commit 8492baf
Show file tree
Hide file tree
Showing 12 changed files with 263 additions and 201 deletions.
2 changes: 1 addition & 1 deletion tests/hrex/test_hrex_rbfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_hrex_rbfe_hif2a(hif2a_single_topology_leg):
lambda_interval=(0.0, 0.15),
n_windows=n_windows,
n_frames_bisection=100,
n_frames_per_iter=5,
n_frames_per_iter=1,
)

if DEBUG:
Expand Down
123 changes: 30 additions & 93 deletions tests/test_barostat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,7 @@ def test_barostat_validation():

# Invalid interval
with pytest.raises(RuntimeError, match="Barostat interval must be greater than 0"):
custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
[[0, 1]],
-1,
u_impls,
seed,
)
custom_ops.MonteCarloBarostat(coords.shape[0], pressure, temperature, [[0, 1]], -1, u_impls, seed, True, 0.0)

# Atom index over N
with pytest.raises(RuntimeError, match="Grouped indices must be between 0 and N"):
Expand All @@ -57,30 +49,20 @@ def test_barostat_validation():
barostat_interval,
u_impls,
seed,
True,
0.0,
)

# Atom index < 0
with pytest.raises(RuntimeError, match="Grouped indices must be between 0 and N"):
custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
[[-1, 0]],
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, [[-1, 0]], barostat_interval, u_impls, seed, True, 0.0
)

# Atom index in two groups
with pytest.raises(RuntimeError, match="All grouped indices must be unique"):
custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
[[0, 1], [1, 2]],
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, [[0, 1], [1, 2]], barostat_interval, u_impls, seed, True, 0.0
)


Expand Down Expand Up @@ -125,13 +107,7 @@ def test_barostat_with_clashes():
v_0 = sample_velocities(masses, temperature)

baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, barostat_interval, u_impls, seed, True, 0.0
)

# The clashes will result in overflows, so the box should never change as no move is accepted
Expand Down Expand Up @@ -170,23 +146,11 @@ def test_barostat_zero_interval():

with pytest.raises(RuntimeError):
custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
0,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, 0, u_impls, seed, True, 0.0
)
# Setting it to 1 should be valid.
baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
1,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, 1, u_impls, seed, True, 0.0
)
# Setting back to 0 should raise another error
with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -242,13 +206,7 @@ def test_barostat_partial_group_idxs():
v_0 = sample_velocities(masses, temperature)

baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, barostat_interval, u_impls, seed, True, 0.0
)

ctxt = custom_ops.Context(coords, v_0, complex_box, integrator_impl, u_impls, barostat=baro)
Expand Down Expand Up @@ -303,13 +261,7 @@ def test_barostat_is_deterministic():
v_0 = sample_velocities(masses, temperature)

baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, barostat_interval, u_impls, seed, True, 0.0
)

ctxt = custom_ops.Context(coords, v_0, host_box, integrator.impl(), u_impls, barostat=baro)
Expand All @@ -319,13 +271,7 @@ def test_barostat_is_deterministic():
assert compute_box_volume(atm_box) != compute_box_volume(host_box)

baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, barostat_interval, u_impls, seed, True, 0.0
)
ctxt = custom_ops.Context(coords, v_0, host_box, integrator.impl(), u_impls, barostat=baro)
ctxt.multiple_steps(15)
Expand Down Expand Up @@ -371,13 +317,7 @@ def test_barostat_varying_pressure():
v_0 = sample_velocities(masses, temperature)

baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, barostat_interval, u_impls, seed, True, 0.0
)

ctxt = custom_ops.Context(coords, v_0, complex_box, integrator_impl, u_impls, barostat=baro)
Expand Down Expand Up @@ -436,13 +376,7 @@ def test_barostat_recentering_upon_acceptance():
v_0 = sample_velocities(masses, temperature)

baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, barostat_interval, u_impls, seed, True, 0.0
)
ctxt = custom_ops.Context(coords, v_0, complex_box, integrator_impl, u_impls, barostat=baro)
# mini equilibriate the system to get barostat proposals to be reasonable
Expand Down Expand Up @@ -552,13 +486,7 @@ def test_molecular_ideal_gas():
new_box = complex_box * length_scale

baro = custom_ops.MonteCarloBarostat(
new_coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
new_coords.shape[0], pressure, temperature, group_indices, barostat_interval, u_impls, seed, True, 0.0
)

ctxt = custom_ops.Context(new_coords, v_0, new_box, integrator_impl, u_impls, barostat=baro)
Expand Down Expand Up @@ -665,13 +593,7 @@ def test_barostat_scaling_behavior():
v_0 = sample_velocities(masses, temperature)

baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
coords.shape[0], pressure, temperature, group_indices, barostat_interval, u_impls, seed, True, 0.0
)
# Initial volume scaling is 0
assert baro.get_volume_scale_factor() == 0.0
Expand Down Expand Up @@ -704,3 +626,18 @@ def test_barostat_scaling_behavior():
assert baro.get_adaptive_scaling()
ctxt.multiple_steps(100)
assert baro.get_volume_scale_factor() != 0.0

# Check that the adaptive_scaling_enabled, initial_volume_scale_factor constructor arguments works as expected
baro = custom_ops.MonteCarloBarostat(
coords.shape[0],
pressure,
temperature,
group_indices,
barostat_interval,
u_impls,
seed,
False,
initial_volume_scale_factor=1.23,
)
assert not baro.get_adaptive_scaling()
assert baro.get_volume_scale_factor() == 1.23
3 changes: 1 addition & 2 deletions tests/test_benchmark_free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def test_benchmark_hif2a_single_topology(hif2a_single_topology_leg, enable_hrex)
run_sims_hrex,
initial_states,
md_params,
n_frames_per_iter=5,
temperature=temperature,
n_frames_per_iter=1,
print_diagnostics_interval=None,
)
else:
Expand Down
6 changes: 4 additions & 2 deletions timemachine/cpp/src/barostat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ MonteCarloBarostat<RealType>::MonteCarloBarostat(
const int interval,
const std::vector<std::shared_ptr<BoundPotential>> bps,
const int seed,
const bool adaptive_scaling_enabled)
const bool adaptive_scaling_enabled,
const double initial_volume_scale_factor)
: N_(N), adaptive_scaling_enabled_(adaptive_scaling_enabled), bps_(bps), pressure_(pressure),
temperature_(temperature), interval_(interval), seed_(seed), group_idxs_(group_idxs), step_(0),
num_grouped_atoms_(0), runner_() {
Expand Down Expand Up @@ -82,7 +83,8 @@ MonteCarloBarostat<RealType>::MonteCarloBarostat(
cudaSafeMalloc(&d_volume_scale_, 1 * sizeof(*d_volume_scale_));
cudaSafeMalloc(&d_volume_delta_, 1 * sizeof(*d_volume_delta_));

gpuErrchk(cudaMemset(d_volume_scale_, 0, 1 * sizeof(*d_volume_scale_)));
gpuErrchk(cudaMemcpy(
d_volume_scale_, &initial_volume_scale_factor, 1 * sizeof(*d_volume_scale_), cudaMemcpyHostToDevice));

cudaSafeMalloc(&d_centroids_, num_mols * 3 * sizeof(*d_centroids_));
cudaSafeMalloc(&d_atom_idxs_, num_grouped_atoms_ * sizeof(*d_atom_idxs_));
Expand Down
3 changes: 2 additions & 1 deletion timemachine/cpp/src/barostat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ template <typename RealType> class MonteCarloBarostat {
const int interval,
std::vector<std::shared_ptr<BoundPotential>> bps,
const int seed,
const bool adapt_volume_scale_factor);
const bool adapt_volume_scale_factor,
const double initial_volume_scale_factor);

~MonteCarloBarostat();

Expand Down
9 changes: 9 additions & 0 deletions timemachine/cpp/src/context.cu
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#include "barostat.hpp"
#include "bound_potential.hpp"
#include "constants.hpp"
#include "context.hpp"

#include "fixed_point.hpp"
#include "flat_bottom_bond.hpp"
#include "gpu_utils.cuh"
#include "integrator.hpp"
#include "kernels/kernel_utils.cuh"
#include "langevin_integrator.hpp"
#include "local_md_potentials.hpp"
Expand Down Expand Up @@ -348,4 +351,10 @@ void Context::get_box(double *out_buffer) const {
gpuErrchk(cudaMemcpy(out_buffer, d_box_t_, 3 * 3 * sizeof(*out_buffer), cudaMemcpyDeviceToHost));
}

std::shared_ptr<Integrator> Context::get_integrator() const { return intg_; }

std::vector<std::shared_ptr<BoundPotential>> Context::get_potentials() const { return bps_; }

std::shared_ptr<MonteCarloBarostat<float>> Context::get_barostat() const { return barostat_; }

} // namespace timemachine
6 changes: 6 additions & 0 deletions timemachine/cpp/src/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ class Context {

void setup_local_md(double temperature, bool freeze_reference);

std::shared_ptr<Integrator> get_integrator() const;

std::vector<std::shared_ptr<BoundPotential>> get_potentials() const;

std::shared_ptr<MonteCarloBarostat<float>> get_barostat() const;

private:
int N_; // number of particles

Expand Down
34 changes: 25 additions & 9 deletions timemachine/cpp/src/wrap_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,17 @@ void declare_context(py::module &m) {
ctxt.get_v_t(buffer.mutable_data());
return buffer;
})
.def("get_box", [](timemachine::Context &ctxt) -> py::array_t<double, py::array::c_style> {
unsigned int D = 3;
py::array_t<double, py::array::c_style> buffer({D, D});
ctxt.get_box(buffer.mutable_data());
return buffer;
});
.def(
"get_box",
[](timemachine::Context &ctxt) -> py::array_t<double, py::array::c_style> {
unsigned int D = 3;
py::array_t<double, py::array::c_style> buffer({D, D});
ctxt.get_box(buffer.mutable_data());
return buffer;
})
.def("get_integrator", &timemachine::Context::get_integrator)
.def("get_potentials", &timemachine::Context::get_potentials)
.def("get_barostat", &timemachine::Context::get_barostat);
}

void declare_integrator(py::module &m) {
Expand Down Expand Up @@ -1209,8 +1214,18 @@ void declare_barostat(py::module &m) {
const int frequency,
std::vector<std::shared_ptr<timemachine::BoundPotential>> bps,
const int seed,
const bool adaptive_scaling_enabled) {
return new Class(N, pressure, temperature, group_idxs, frequency, bps, seed, adaptive_scaling_enabled);
const bool adaptive_scaling_enabled,
const double initial_volume_scale_factor) {
return new Class(
N,
pressure,
temperature,
group_idxs,
frequency,
bps,
seed,
adaptive_scaling_enabled,
initial_volume_scale_factor);
}),
py::arg("N"),
py::arg("pressure"),
Expand All @@ -1219,7 +1234,8 @@ void declare_barostat(py::module &m) {
py::arg("frequency"),
py::arg("bps"),
py::arg("seed"),
py::arg("adaptive_scaling_enabled") = true)
py::arg("adaptive_scaling_enabled"),
py::arg("initial_volume_scale_factor"))
.def("set_interval", &Class::set_interval, py::arg("interval"))
.def("get_interval", &Class::get_interval)
.def("set_volume_scale_factor", &Class::set_volume_scale_factor, py::arg("volume_scale_factor"))
Expand Down
Loading

0 comments on commit 8492baf

Please sign in to comment.