
Commit 27dcae0

Update trainer tests; don't duplicate trainer setup, and also test some more combos via pytest parametrize
1 parent e6b53e4 commit 27dcae0
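Note: the "more combos" come from stacking a second pytest.mark.parametrize decorator on top of the existing mode parametrization, so each test runs once per (mode, percentage pair) combination. A minimal, self-contained sketch of that stacking behavior (illustrative only, not code from this commit):

import pytest


@pytest.mark.parametrize("mode", ["random", "sequential"])
@pytest.mark.parametrize(
    "n_train_percent, n_val_percent", [("75%", "15%"), ("20%", "30%")]
)
def test_combo_sketch(mode, n_train_percent, n_val_percent):
    # 2 modes x 2 percentage pairs -> 4 test invocations
    assert mode in ("random", "sequential")
    assert n_train_percent.endswith("%") and n_val_percent.endswith("%")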

File tree: 1 file changed (+52, -56 lines)

tests/unit/trainer/test_trainer.py (+52, -56)
@@ -44,36 +44,16 @@ def dummy_builder():
     early_stopping_lower_bounds={"LR": 1e-10},
     model_builders=[dummy_builder],
 )
-N_TRAIN_PERCENT = "75%"
-N_VAL_PERCENT = "15%"
-N_TRAIN_PERCENT_100 = "70%"
-N_VAL_PERCENT_100 = "30%"
 
 
-@pytest.fixture(scope="function")
-def trainer(float_tolerance):
-    """
-    Generate a class instance with minimal configurations
-    """
-    conf = minimal_config.copy()
-    conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :]
-    model = model_from_config(conf)
-    with tempfile.TemporaryDirectory(prefix="output") as path:
-        conf["root"] = path
-        c = Trainer(model=model, **conf)
-        yield c
-
-
-@pytest.fixture(scope="function")
-def trainer_w_percent_n_train_n_val(float_tolerance):
+def create_trainer(float_tolerance, **kwargs):
     """
     Generate a class instance with minimal configurations,
-    where n_train and n_val are given as percentage of the
-    dataset size.
+    with the option to modify the configurations using
+    kwargs.
     """
     conf = minimal_config.copy()
-    conf["n_train"] = N_TRAIN_PERCENT
-    conf["n_val"] = N_VAL_PERCENT  # note that summed percentages don't have to be 100%
+    conf.update(kwargs)
     conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :]
     model = model_from_config(conf)
     with tempfile.TemporaryDirectory(prefix="output") as path:
@@ -83,24 +63,11 @@ def trainer_w_percent_n_train_n_val(float_tolerance):
 
 
 @pytest.fixture(scope="function")
-def trainer_w_percent_n_train_n_val_flooring(float_tolerance):
+def trainer(float_tolerance):
     """
-    Generate a class instance with minimal configurations,
-    where n_train and n_val are given as percentage of the
-    dataset size, summing to 100% but with a split that gives
-    non-integer numbers of frames for n_train and n_val.
-    (i.e. n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames,
-    so final n_train is 6 and n_val is 2)
+    Generate a class instance with minimal configurations.
     """
-    conf = minimal_config.copy()
-    conf["n_train"] = N_TRAIN_PERCENT_100
-    conf["n_val"] = N_VAL_PERCENT_100
-    conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :]
-    model = model_from_config(conf)
-    with tempfile.TemporaryDirectory(prefix="output") as path:
-        conf["root"] = path
-        c = Trainer(model=model, **conf)
-        yield c
+    yield from create_trainer(float_tolerance)
 
 
 class TestTrainerSetUp:
@@ -203,11 +170,25 @@ def test_split(self, trainer, nequip_dataset, mode):
         assert n_samples == trainer.n_train
 
     @pytest.mark.parametrize("mode", ["random", "sequential"])
+    @pytest.mark.parametrize(
+        "n_train_percent, n_val_percent", [("75%", "15%"), ("20%", "30%")]
+    )
     def test_split_w_percent_n_train_n_val(
-        self, trainer_w_percent_n_train_n_val, nequip_dataset, mode
+        self, nequip_dataset, mode, float_tolerance, n_train_percent, n_val_percent
     ):
+        """
+        Test case where n_train and n_val are given as percentage of the
+        dataset size, and here they don't sum to 100%.
+        """
         # nequip_dataset has 8 frames, so setting n_train to 75% and n_val to 15% should give 6 and 1
-        # frames respectively
+        # frames respectively. Note that summed percentages don't have to be 100%
+        trainer_w_percent_n_train_n_val = next(
+            create_trainer(
+                float_tolerance=float_tolerance,
+                n_train=n_train_percent,
+                n_val=n_val_percent,
+            )
+        )
         trainer_w_percent_n_train_n_val.train_val_split = mode
         trainer_w_percent_n_train_n_val.set_dataset(nequip_dataset)
         for epoch_i in range(3):
@@ -222,29 +203,46 @@ def test_split_w_percent_n_train_n_val(
         assert (
             n_samples != trainer_w_percent_n_train_n_val.n_train
         )  # n_train now a percentage
-        assert trainer_w_percent_n_train_n_val.n_train == N_TRAIN_PERCENT  # 75%
+        assert trainer_w_percent_n_train_n_val.n_train == n_train_percent  # 75%
         assert n_samples == int(
-            (float(N_TRAIN_PERCENT.strip("%")) / 100) * len(nequip_dataset)
+            (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset)
         )  # 6
-        assert trainer_w_percent_n_train_n_val.n_val == N_VAL_PERCENT  # 15%
+        assert trainer_w_percent_n_train_n_val.n_val == n_val_percent  # 15%
 
         for i, batch in enumerate(trainer_w_percent_n_train_n_val.dl_val):
             n_val_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1
 
         assert (
             n_val_samples != trainer_w_percent_n_train_n_val.n_val
         )  # n_val now a percentage
-        assert trainer_w_percent_n_train_n_val.n_val == N_VAL_PERCENT  # 15%
+        assert trainer_w_percent_n_train_n_val.n_val == n_val_percent  # 15%
         assert n_val_samples == int(
-            (float(N_VAL_PERCENT.strip("%")) / 100) * len(nequip_dataset)
+            (float(n_val_percent.strip("%")) / 100) * len(nequip_dataset)
         )  # 1 (floored)
 
     @pytest.mark.parametrize("mode", ["random", "sequential"])
+    @pytest.mark.parametrize(
+        "n_train_percent, n_val_percent", [("70%", "30%"), ("55%", "45%")]
+    )
     def test_split_w_percent_n_train_n_val_flooring(
-        self, trainer_w_percent_n_train_n_val_flooring, nequip_dataset, mode
+        self, nequip_dataset, mode, float_tolerance, n_train_percent, n_val_percent
     ):
+        """
+        Test case where n_train and n_val are given as percentage of the
+        dataset size, summing to 100% but with a split that gives
+        non-integer numbers of frames for n_train and n_val.
+        (i.e. n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames,
+        so final n_train is 6 and n_val is 2)
+        """
         # nequip_dataset has 8 frames, so n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames,
         # so final n_train is 6 and n_val is 2
+        trainer_w_percent_n_train_n_val_flooring = next(
+            create_trainer(
+                float_tolerance=float_tolerance,
+                n_train=n_train_percent,
+                n_val=n_val_percent,
+            )
+        )
         trainer_w_percent_n_train_n_val_flooring.train_val_split = mode
         trainer_w_percent_n_train_n_val_flooring.set_dataset(nequip_dataset)
         for epoch_i in range(3):
@@ -267,23 +265,21 @@ def test_split_w_percent_n_train_n_val_flooring(
             n_samples != trainer_w_percent_n_train_n_val_flooring.n_train
         )  # n_train now a percentage
         assert (
-            trainer_w_percent_n_train_n_val_flooring.n_train
-            == N_TRAIN_PERCENT_100
+            trainer_w_percent_n_train_n_val_flooring.n_train == n_train_percent
         )  # 70%
         # _not_ equal to the bare floored value now:
         assert n_samples != int(
-            (float(N_TRAIN_PERCENT_100.strip("%")) / 100) * len(nequip_dataset)
+            (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset)
        )  # 5
         assert (
             n_samples
             == int(  # equal to floored value plus 1
-                (float(N_TRAIN_PERCENT_100.strip("%")) / 100)
-                * len(nequip_dataset)
+                (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset)
             )
             + 1
         )  # 6
         assert (
-            trainer_w_percent_n_train_n_val_flooring.n_val == N_VAL_PERCENT_100
+            trainer_w_percent_n_train_n_val_flooring.n_val == n_val_percent
         )  # 30%
 
         for i, batch in enumerate(trainer_w_percent_n_train_n_val_flooring.dl_val):
@@ -293,10 +289,10 @@ def test_split_w_percent_n_train_n_val_flooring(
             n_val_samples != trainer_w_percent_n_train_n_val_flooring.n_val
         )  # n_val now a percentage
         assert (
-            trainer_w_percent_n_train_n_val_flooring.n_val == N_VAL_PERCENT_100
+            trainer_w_percent_n_train_n_val_flooring.n_val == n_val_percent
         )  # 30%
         assert n_val_samples == int(
-            (float(N_VAL_PERCENT_100.strip("%")) / 100) * len(nequip_dataset)
+            (float(n_val_percent.strip("%")) / 100) * len(nequip_dataset)
         )  # 2 (floored)
 
         assert n_samples + n_val_samples == len(nequip_dataset)  # 100% coverage
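
For context, a minimal sketch of the helper/fixture pattern the diff above introduces: create_trainer is a plain generator, the default trainer fixture delegates to it with yield from, and parametrized tests build their own variant with next(...). The code below is a simplified stand-in (a dict instead of the project's Trainer, hard-coded values instead of minimal_config), not the project's API:

import tempfile

import pytest


def create_trainer_sketch(**kwargs):
    """Generator: copy a minimal config, apply keyword overrides, yield inside a tmpdir."""
    conf = {"n_train": 6, "n_val": 2}  # stand-in for minimal_config.copy()
    conf.update(kwargs)                # e.g. n_train="75%", n_val="15%"
    with tempfile.TemporaryDirectory(prefix="output") as path:
        conf["root"] = path
        yield conf                     # the real helper yields a Trainer instance


@pytest.fixture(scope="function")
def trainer_sketch():
    # the default fixture simply delegates to the helper
    yield from create_trainer_sketch()


@pytest.mark.parametrize(
    "n_train_percent, n_val_percent", [("75%", "15%"), ("20%", "30%")]
)
def test_percent_split_sketch(n_train_percent, n_val_percent):
    # parametrized tests construct their own variant with next(...)
    trainer = next(create_trainer_sketch(n_train=n_train_percent, n_val=n_val_percent))
    assert trainer["n_train"] == n_train_percent
    assert trainer["n_val"] == n_val_percent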
