@@ -44,36 +44,16 @@ def dummy_builder():
44
44
early_stopping_lower_bounds = {"LR" : 1e-10 },
45
45
model_builders = [dummy_builder ],
46
46
)
47
- N_TRAIN_PERCENT = "75%"
48
- N_VAL_PERCENT = "15%"
49
- N_TRAIN_PERCENT_100 = "70%"
50
- N_VAL_PERCENT_100 = "30%"
51
47
52
48
53
- @pytest .fixture (scope = "function" )
54
- def trainer (float_tolerance ):
55
- """
56
- Generate a class instance with minimal configurations
57
- """
58
- conf = minimal_config .copy ()
59
- conf ["default_dtype" ] = str (torch .get_default_dtype ())[len ("torch." ) :]
60
- model = model_from_config (conf )
61
- with tempfile .TemporaryDirectory (prefix = "output" ) as path :
62
- conf ["root" ] = path
63
- c = Trainer (model = model , ** conf )
64
- yield c
65
-
66
-
67
- @pytest .fixture (scope = "function" )
68
- def trainer_w_percent_n_train_n_val (float_tolerance ):
49
+ def create_trainer (float_tolerance , ** kwargs ):
69
50
"""
70
51
Generate a class instance with minimal configurations,
71
- where n_train and n_val are given as percentage of the
72
- dataset size .
52
+ with the option to modify the configurations using
53
+ kwargs .
73
54
"""
74
55
conf = minimal_config .copy ()
75
- conf ["n_train" ] = N_TRAIN_PERCENT
76
- conf ["n_val" ] = N_VAL_PERCENT # note that summed percentages don't have to be 100%
56
+ conf .update (kwargs )
77
57
conf ["default_dtype" ] = str (torch .get_default_dtype ())[len ("torch." ) :]
78
58
model = model_from_config (conf )
79
59
with tempfile .TemporaryDirectory (prefix = "output" ) as path :
@@ -83,24 +63,11 @@ def trainer_w_percent_n_train_n_val(float_tolerance):
83
63
84
64
85
65
@pytest .fixture (scope = "function" )
86
- def trainer_w_percent_n_train_n_val_flooring (float_tolerance ):
66
+ def trainer (float_tolerance ):
87
67
"""
88
- Generate a class instance with minimal configurations,
89
- where n_train and n_val are given as percentage of the
90
- dataset size, summing to 100% but with a split that gives
91
- non-integer numbers of frames for n_train and n_val.
92
- (i.e. n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames,
93
- so final n_train is 6 and n_val is 2)
68
+ Generate a class instance with minimal configurations.
94
69
"""
95
- conf = minimal_config .copy ()
96
- conf ["n_train" ] = N_TRAIN_PERCENT_100
97
- conf ["n_val" ] = N_VAL_PERCENT_100
98
- conf ["default_dtype" ] = str (torch .get_default_dtype ())[len ("torch." ) :]
99
- model = model_from_config (conf )
100
- with tempfile .TemporaryDirectory (prefix = "output" ) as path :
101
- conf ["root" ] = path
102
- c = Trainer (model = model , ** conf )
103
- yield c
70
+ yield from create_trainer (float_tolerance )
104
71
105
72
106
73
class TestTrainerSetUp :
@@ -203,11 +170,25 @@ def test_split(self, trainer, nequip_dataset, mode):
203
170
assert n_samples == trainer .n_train
204
171
205
172
@pytest .mark .parametrize ("mode" , ["random" , "sequential" ])
173
+ @pytest .mark .parametrize (
174
+ "n_train_percent, n_val_percent" , [("75%" , "15%" ), ("20%" , "30%" )]
175
+ )
206
176
def test_split_w_percent_n_train_n_val (
207
- self , trainer_w_percent_n_train_n_val , nequip_dataset , mode
177
+ self , nequip_dataset , mode , float_tolerance , n_train_percent , n_val_percent
208
178
):
179
+ """
180
+ Test case where n_train and n_val are given as percentage of the
181
+ dataset size, and here they don't sum to 100%.
182
+ """
209
183
# nequip_dataset has 8 frames, so setting n_train to 75% and n_val to 15% should give 6 and 1
210
- # frames respectively
184
+ # frames respectively. Note that summed percentages don't have to be 100%
185
+ trainer_w_percent_n_train_n_val = next (
186
+ create_trainer (
187
+ float_tolerance = float_tolerance ,
188
+ n_train = n_train_percent ,
189
+ n_val = n_val_percent ,
190
+ )
191
+ )
211
192
trainer_w_percent_n_train_n_val .train_val_split = mode
212
193
trainer_w_percent_n_train_n_val .set_dataset (nequip_dataset )
213
194
for epoch_i in range (3 ):
@@ -222,29 +203,46 @@ def test_split_w_percent_n_train_n_val(
222
203
assert (
223
204
n_samples != trainer_w_percent_n_train_n_val .n_train
224
205
) # n_train now a percentage
225
- assert trainer_w_percent_n_train_n_val .n_train == N_TRAIN_PERCENT # 75%
206
+ assert trainer_w_percent_n_train_n_val .n_train == n_train_percent # 75%
226
207
assert n_samples == int (
227
- (float (N_TRAIN_PERCENT .strip ("%" )) / 100 ) * len (nequip_dataset )
208
+ (float (n_train_percent .strip ("%" )) / 100 ) * len (nequip_dataset )
228
209
) # 6
229
- assert trainer_w_percent_n_train_n_val .n_val == N_VAL_PERCENT # 15%
210
+ assert trainer_w_percent_n_train_n_val .n_val == n_val_percent # 15%
230
211
231
212
for i , batch in enumerate (trainer_w_percent_n_train_n_val .dl_val ):
232
213
n_val_samples += batch [AtomicDataDict .BATCH_PTR_KEY ].shape [0 ] - 1
233
214
234
215
assert (
235
216
n_val_samples != trainer_w_percent_n_train_n_val .n_val
236
217
) # n_val now a percentage
237
- assert trainer_w_percent_n_train_n_val .n_val == N_VAL_PERCENT # 15%
218
+ assert trainer_w_percent_n_train_n_val .n_val == n_val_percent # 15%
238
219
assert n_val_samples == int (
239
- (float (N_VAL_PERCENT .strip ("%" )) / 100 ) * len (nequip_dataset )
220
+ (float (n_val_percent .strip ("%" )) / 100 ) * len (nequip_dataset )
240
221
) # 1 (floored)
241
222
242
223
@pytest .mark .parametrize ("mode" , ["random" , "sequential" ])
224
+ @pytest .mark .parametrize (
225
+ "n_train_percent, n_val_percent" , [("70%" , "30%" ), ("55%" , "45%" )]
226
+ )
243
227
def test_split_w_percent_n_train_n_val_flooring (
244
- self , trainer_w_percent_n_train_n_val_flooring , nequip_dataset , mode
228
+ self , nequip_dataset , mode , float_tolerance , n_train_percent , n_val_percent
245
229
):
230
+ """
231
+ Test case where n_train and n_val are given as percentage of the
232
+ dataset size, summing to 100% but with a split that gives
233
+ non-integer numbers of frames for n_train and n_val.
234
+ (i.e. n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames,
235
+ so final n_train is 6 and n_val is 2)
236
+ """
246
237
# nequip_dataset has 8 frames, so n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames,
247
238
# so final n_train is 6 and n_val is 2
239
+ trainer_w_percent_n_train_n_val_flooring = next (
240
+ create_trainer (
241
+ float_tolerance = float_tolerance ,
242
+ n_train = n_train_percent ,
243
+ n_val = n_val_percent ,
244
+ )
245
+ )
248
246
trainer_w_percent_n_train_n_val_flooring .train_val_split = mode
249
247
trainer_w_percent_n_train_n_val_flooring .set_dataset (nequip_dataset )
250
248
for epoch_i in range (3 ):
@@ -267,23 +265,21 @@ def test_split_w_percent_n_train_n_val_flooring(
267
265
n_samples != trainer_w_percent_n_train_n_val_flooring .n_train
268
266
) # n_train now a percentage
269
267
assert (
270
- trainer_w_percent_n_train_n_val_flooring .n_train
271
- == N_TRAIN_PERCENT_100
268
+ trainer_w_percent_n_train_n_val_flooring .n_train == n_train_percent
272
269
) # 70%
273
270
# _not_ equal to the bare floored value now:
274
271
assert n_samples != int (
275
- (float (N_TRAIN_PERCENT_100 .strip ("%" )) / 100 ) * len (nequip_dataset )
272
+ (float (n_train_percent .strip ("%" )) / 100 ) * len (nequip_dataset )
276
273
) # 5
277
274
assert (
278
275
n_samples
279
276
== int ( # equal to floored value plus 1
280
- (float (N_TRAIN_PERCENT_100 .strip ("%" )) / 100 )
281
- * len (nequip_dataset )
277
+ (float (n_train_percent .strip ("%" )) / 100 ) * len (nequip_dataset )
282
278
)
283
279
+ 1
284
280
) # 6
285
281
assert (
286
- trainer_w_percent_n_train_n_val_flooring .n_val == N_VAL_PERCENT_100
282
+ trainer_w_percent_n_train_n_val_flooring .n_val == n_val_percent
287
283
) # 30%
288
284
289
285
for i , batch in enumerate (trainer_w_percent_n_train_n_val_flooring .dl_val ):
@@ -293,10 +289,10 @@ def test_split_w_percent_n_train_n_val_flooring(
293
289
n_val_samples != trainer_w_percent_n_train_n_val_flooring .n_val
294
290
) # n_val now a percentage
295
291
assert (
296
- trainer_w_percent_n_train_n_val_flooring .n_val == N_VAL_PERCENT_100
292
+ trainer_w_percent_n_train_n_val_flooring .n_val == n_val_percent
297
293
) # 30%
298
294
assert n_val_samples == int (
299
- (float (N_VAL_PERCENT_100 .strip ("%" )) / 100 ) * len (nequip_dataset )
295
+ (float (n_val_percent .strip ("%" )) / 100 ) * len (nequip_dataset )
300
296
) # 2 (floored)
301
297
302
298
assert n_samples + n_val_samples == len (nequip_dataset ) # 100% coverage
0 commit comments