Skip to content

Commit 16b5ee6

Browse files
authored
Merge pull request #386 from mj-will/support-glasflow-flows
Add experimental support for glasflow flows
2 parents b622fd0 + bf20c41 commit 16b5ee6

File tree

7 files changed

+227
-56
lines changed

7 files changed

+227
-56
lines changed

codecov.yml

+2
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ coverage:
44
default:
55
target: 90%
66
threshold: 1%
7+
ignore:
8+
- "nessai/experimental"

nessai/experimental/__init__.py

Whitespace-only changes.

nessai/experimental/flows/__init__.py

Whitespace-only changes.

nessai/experimental/flows/glasflow.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from functools import partial
2+
from glasflow.flows import CouplingNSF, RealNVP
3+
4+
from ...flows.base import BaseFlow
5+
6+
known_flows = {
7+
"nsf": CouplingNSF,
8+
"realnvp": RealNVP,
9+
}
10+
11+
12+
class GlasflowWrapper(BaseFlow):
13+
"""Wrapper for glasflow flow classes"""
14+
15+
def __init__(
16+
self,
17+
FlowClass,
18+
n_inputs,
19+
n_neurons,
20+
n_blocks,
21+
n_layers,
22+
**kwargs,
23+
) -> None:
24+
super().__init__()
25+
26+
n_conditional_inputs = kwargs.pop("context_features", None)
27+
self._flow = FlowClass(
28+
n_inputs=n_inputs,
29+
n_transforms=n_blocks,
30+
n_blocks_per_transform=n_layers,
31+
n_neurons=n_neurons,
32+
n_conditional_inputs=n_conditional_inputs,
33+
**kwargs,
34+
)
35+
36+
def forward(self, x, context=None):
37+
return self._flow.forward(x, conditional=context)
38+
39+
def inverse(self, z, context=None):
40+
return self._flow.inverse(z, conditional=context)
41+
42+
def log_prob(self, x, context=None):
43+
return self._flow.log_prob(x, conditional=context)
44+
45+
def sample(self, n, context=None):
46+
return self._flow.sample(n, conditional=context)
47+
48+
def forward_and_log_prob(self, x, context=None):
49+
return self._flow.forward_and_log_prob(x, conditional=context)
50+
51+
def sample_and_log_prob(self, n, context=None):
52+
return self._flow.sample_and_log_prob(n, conditional=context)
53+
54+
def sample_latent_distribution(self, n, context=None):
55+
if context is not None:
56+
raise ValueError
57+
return self._flow.distribution.sample(n)
58+
59+
def base_distribution_log_prob(self, z, context=None):
60+
if context is not None:
61+
raise ValueError("Context must be None")
62+
return self._flow.base_distribution_log_prob(z)
63+
64+
def freeze_transform(self):
65+
self._flow._transform.requires_grad_(False)
66+
67+
def unfreeze_transform(self):
68+
self._flow._transform.requires_grad_(True)
69+
70+
71+
def get_glasflow_class(name):
72+
"""Get the class for a glasflow flow.
73+
74+
Note: the name must start with :code:`glasflow.`
75+
"""
76+
name = name.lower()
77+
if "glasflow" not in name:
78+
raise ValueError("'glasflow' missing from name")
79+
short_name = name.replace("glasflow-", "")
80+
if short_name not in known_flows:
81+
raise ValueError(f"{name} is not a known glasflow flow")
82+
FlowClass = known_flows.get(short_name)
83+
return partial(GlasflowWrapper, FlowClass)

nessai/flows/utils.py

+41-37
Original file line numberDiff line numberDiff line change
@@ -165,23 +165,45 @@ def get_n_neurons(
165165
return n
166166

167167

168-
def configure_model(config):
169-
"""
170-
Setup the flow form a configuration dictionary.
171-
"""
168+
def get_native_flow_class(name):
169+
"""Get a natively implemented flow class."""
170+
name = name.lower()
172171
from .realnvp import RealNVP
173172
from .maf import MaskedAutoregressiveFlow
174173
from .nsf import NeuralSplineFlow
175-
from ..flowmodel import config as fmconfig
176174

177-
kwargs = {}
178175
flows = {
179176
"realnvp": RealNVP,
180177
"maf": MaskedAutoregressiveFlow,
181178
"frealnvp": RealNVP,
182179
"spline": NeuralSplineFlow,
183180
"nsf": NeuralSplineFlow,
184181
}
182+
if name not in flows:
183+
raise ValueError(f"Unknown flow: {name}")
184+
return flows.get(name)
185+
186+
187+
def get_flow_class(name: str):
188+
"""Get the class to use for the normalizing flow from a string."""
189+
name = name.lower()
190+
if "glasflow" in name:
191+
from ..experimental.flows.glasflow import get_glasflow_class
192+
193+
logger.warning("Using experimental glasflow flow!")
194+
FlowClass = get_glasflow_class(name)
195+
else:
196+
FlowClass = get_native_flow_class(name)
197+
return FlowClass
198+
199+
200+
def configure_model(config):
201+
"""
202+
Setup the flow form a configuration dictionary.
203+
"""
204+
from ..flowmodel import config as fmconfig
205+
206+
kwargs = {}
185207
activations = {"relu": F.relu, "tanh": F.tanh, "swish": silu, "silu": silu}
186208

187209
config = config.copy()
@@ -218,39 +240,21 @@ def configure_model(config):
218240
if distribution:
219241
kwargs["distribution"] = distribution
220242

221-
fc = config.get("flow", None)
222-
ftype = config.get("ftype", None)
223-
if fc is not None:
224-
model = fc(
225-
config["n_inputs"],
226-
config["n_neurons"],
227-
config["n_blocks"],
228-
config["n_layers"],
229-
**kwargs,
230-
)
231-
elif ftype is not None:
232-
if ftype.lower() not in flows:
233-
raise RuntimeError(
234-
f"Unknown flow type: {ftype}. Choose from:" f"{flows.keys()}"
235-
)
236-
if (
237-
("mask" in kwargs and kwargs["mask"] is not None)
238-
or ("net" in kwargs and kwargs["net"] is not None)
239-
) and ftype.lower() not in ["realnvp", "frealnvp"]:
240-
raise RuntimeError(
241-
"Custom masks and networks are only " "supported for RealNVP"
242-
)
243-
244-
model = flows[ftype.lower()](
245-
config["n_inputs"],
246-
config["n_neurons"],
247-
config["n_blocks"],
248-
config["n_layers"],
249-
**kwargs,
250-
)
251-
else:
243+
FlowClass = config.get("flow")
244+
ftype = config.get("ftype")
245+
if FlowClass is None and ftype is None:
252246
raise RuntimeError("Must specify either 'flow' or 'ftype'.")
253247

248+
if FlowClass is None:
249+
FlowClass = get_flow_class(ftype)
250+
model = FlowClass(
251+
config["n_inputs"],
252+
config["n_neurons"],
253+
config["n_blocks"],
254+
config["n_layers"],
255+
**kwargs,
256+
)
257+
254258
device = torch.device(config.get("device_tag", "cpu"))
255259
if device != "cpu":
256260
try:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from nessai.flowmodel import FlowModel
2+
from nessai.experimental.flows.glasflow import (
3+
GlasflowWrapper,
4+
get_glasflow_class,
5+
known_flows,
6+
)
7+
import numpy as np
8+
import pytest
9+
10+
11+
@pytest.mark.parametrize("name", known_flows.keys())
12+
def test_get_glasflow_class(name):
13+
FlowClass = get_glasflow_class(f"glasflow-{name}")
14+
FlowClass(n_inputs=2, n_neurons=4, n_blocks=2, n_layers=1)
15+
16+
17+
def test_get_glasflow_class_missing_prefix():
18+
with pytest.raises(ValueError, match=r"'glasflow' missing from name"):
19+
get_glasflow_class("realnvp")
20+
21+
22+
def test_get_glasflow_class_invalid_flow():
23+
with pytest.raises(
24+
ValueError, match=r"invalid is not a known glasflow flow"
25+
):
26+
get_glasflow_class("glasflow.invalid")
27+
28+
29+
@pytest.mark.integration_test
30+
def test_glasflow_integration(tmp_path):
31+
32+
from glasflow.flows import RealNVP
33+
34+
config = dict(
35+
model_config=dict(
36+
ftype="glasflow-realnvp",
37+
n_inputs=2,
38+
kwargs=None,
39+
)
40+
)
41+
42+
flowmodel = FlowModel(config=config, output=tmp_path / "test")
43+
44+
flowmodel.initialise()
45+
46+
assert isinstance(flowmodel.model, GlasflowWrapper)
47+
assert isinstance(flowmodel.model._flow, RealNVP)
48+
49+
flowmodel.train(np.random.randn(100, 2))

tests/test_flows/test_flow_utils.py

+52-19
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,18 @@
1212
create_linear_transform,
1313
create_pre_transform,
1414
get_base_distribution,
15+
get_flow_class,
1516
get_n_neurons,
17+
get_native_flow_class,
1618
silu,
1719
reset_weights,
1820
reset_permutations,
1921
)
22+
from nessai.flows import (
23+
RealNVP,
24+
NeuralSplineFlow,
25+
MaskedAutoregressiveFlow,
26+
)
2027

2128

2229
@pytest.fixture
@@ -196,6 +203,47 @@ def test_reset_permutation_lu():
196203
lu._initialize.assert_called_once_with(identity_init=True)
197204

198205

206+
@pytest.mark.parametrize(
207+
"name, expected_class",
208+
[
209+
("realnvp", RealNVP),
210+
("frealnvp", RealNVP),
211+
("spline", NeuralSplineFlow),
212+
("nsf", NeuralSplineFlow),
213+
("maf", MaskedAutoregressiveFlow),
214+
],
215+
)
216+
def test_get_native_flow_class(name, expected_class):
217+
assert get_native_flow_class(name) is expected_class
218+
219+
220+
def test_get_native_flow_class_error():
221+
with pytest.raises(ValueError, match=r"Unknown flow: invalid"):
222+
get_native_flow_class("invalid")
223+
224+
225+
def test_get_flow_class_glasflow():
226+
expected = object()
227+
with patch(
228+
"nessai.experimental.flows.glasflow.get_glasflow_class",
229+
return_value=expected,
230+
) as mock_get:
231+
out = get_flow_class("glasflow.realnvp")
232+
mock_get.assert_called_once_with("glasflow.realnvp")
233+
assert out is expected
234+
235+
236+
def test_get_flow_class_native():
237+
expected = object()
238+
with patch(
239+
"nessai.flows.utils.get_native_flow_class",
240+
return_value=expected,
241+
) as mock_get:
242+
out = get_flow_class("realnvp")
243+
mock_get.assert_called_once_with("realnvp")
244+
assert out is expected
245+
246+
199247
def test_configure_model_basic(config):
200248
"""Test configure model with the most basic config."""
201249
config["kwargs"] = dict(num_bins=2)
@@ -211,27 +259,12 @@ def test_configure_model_basic(config):
211259
)
212260

213261

214-
@pytest.mark.parametrize(
215-
"flow_inputs",
216-
[
217-
{"ftype": "realnvp", "expected": "realnvp.RealNVP"},
218-
{"ftype": "frealnvp", "expected": "realnvp.RealNVP"},
219-
{"ftype": "spline", "expected": "nsf.NeuralSplineFlow"},
220-
{"ftype": "nsf", "expected": "nsf.NeuralSplineFlow"},
221-
{"ftype": "maf", "expected": "maf.MaskedAutoregressiveFlow"},
222-
],
223-
)
224-
def test_configure_model_flows(config, flow_inputs):
262+
def test_configure_model_ftype(config):
225263
"""Test the different flows."""
226-
config["ftype"] = flow_inputs["ftype"]
227-
with patch(f"nessai.flows.{flow_inputs['expected']}") as mock_flow:
264+
config["ftype"] = "realnvp"
265+
with patch("nessai.flows.utils.get_native_flow_class") as mock_get:
228266
model, _ = configure_model(config)
229-
mock_flow.assert_called_with(
230-
config["n_inputs"],
231-
config["n_neurons"],
232-
config["n_blocks"],
233-
config["n_layers"],
234-
)
267+
mock_get.assert_called_with("realnvp")
235268

236269

237270
def test_configure_model_flow_class(config):

0 commit comments

Comments
 (0)