Skip to content

Commit 0c5ead3

Browse files
committed
Some minor changes after first round of own review
1 parent bd50a5d commit 0c5ead3

File tree

4 files changed

+24
-20
lines changed

4 files changed

+24
-20
lines changed

src/optimagic/optimization/history.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _get_next_batch_id(self) -> int:
9393
return batch
9494

9595
# ==================================================================================
96-
# Properties to access the history
96+
# Properties and methods to access the history
9797
# ==================================================================================
9898

9999
# Function data, function value, and monotone function value
@@ -243,10 +243,11 @@ def _get_time_per_task(
243243
244244
"""
245245
dummy_task = np.array([1 if t == task else 0 for t in self.task])
246+
factor: float | NDArray[np.float64]
246247
if cost_factor is None:
247-
factor: float | NDArray[np.float64] = np.array(
248-
self.stop_time, dtype=np.float64
249-
) - np.array(self.start_time, dtype=np.float64)
248+
factor = np.array(self.stop_time, dtype=np.float64) - np.array(
249+
self.start_time, dtype=np.float64
250+
)
250251
else:
251252
factor = cost_factor
252253

@@ -342,16 +343,16 @@ def _calculate_monotone_sequence(
342343
sequence: list[float | None], direction: Direction
343344
) -> NDArray[np.float64]:
344345
sequence_arr = np.array(sequence, dtype=np.float64) # converts None to nan
345-
none_mask = np.isnan(sequence_arr)
346+
nan_mask = np.isnan(sequence_arr)
346347

347348
if direction == Direction.MINIMIZE:
348-
sequence_arr[none_mask] = np.inf
349+
sequence_arr[nan_mask] = np.inf
349350
out = np.minimum.accumulate(sequence_arr)
350351
elif direction == Direction.MAXIMIZE:
351-
sequence_arr[none_mask] = -np.inf
352+
sequence_arr[nan_mask] = -np.inf
352353
out = np.maximum.accumulate(sequence_arr)
353354

354-
out[none_mask] = np.nan
355+
out[nan_mask] = np.nan
355356
return out
356357

357358

@@ -404,7 +405,7 @@ def _apply_to_batch(
404405
405406
"""
406407
batch_starts = _get_batch_start(batch_ids)
407-
batch_stops = [*batch_starts, len(data)][1:]
408+
batch_stops = [*batch_starts[1:], len(data)]
408409

409410
batch_results = []
410411
for batch, (start, stop) in zip(

src/optimagic/optimization/process_results.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def process_multistart_result(
112112
direction=extra_fields.direction,
113113
fun=[opt.fun for opt in info.local_optima],
114114
params=[opt.params for opt in info.local_optima],
115-
start_time=[np.nan for _ in info.local_optima],
116-
stop_time=[np.nan for _ in info.local_optima],
115+
start_time=len(info.local_optima) * [np.nan],
116+
stop_time=len(info.local_optima) * [np.nan],
117117
batches=list(range(len(info.local_optima))),
118118
task=len(info.local_optima) * [EvalTask.FUN],
119119
)

src/optimagic/visualization/history_plots.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,11 @@ def _extract_plotting_data_from_database(res, stack_multistart, show_exploration
407407
fun=_history["fun"],
408408
params=_history["params"],
409409
start_time=_history["time"],
410-
# TODO: This needs to be updated
410+
# TODO (@janosg): Retrieve that information from `hist` once it is available.
411+
# https://github.com/optimagic-dev/optimagic/pull/553
411412
stop_time=len(_history["fun"]) * [None],
412-
batches=len(_history["fun"]) * [None],
413413
task=len(_history["fun"]) * [None],
414+
batches=list(range(len(_history["fun"]))),
414415
)
415416

416417
data = {
@@ -449,8 +450,10 @@ def _get_stacked_local_histories(local_histories, direction, history=None):
449450
fun=stacked["criterion"],
450451
params=stacked["params"],
451452
start_time=stacked["runtime"],
452-
# TODO: This needs to be fixed
453+
# TODO (@janosg): Retrieve that information from `hist` once it is available
454+
# for the IterationHistory.
455+
# https://github.com/optimagic-dev/optimagic/pull/553
453456
stop_time=len(stacked["criterion"]) * [None],
454457
task=len(stacked["criterion"]) * [None],
455-
batches=len(stacked["criterion"]) * [None],
458+
batches=list(range(len(stacked["criterion"]))),
456459
)

tests/optimagic/optimization/test_history.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def history(history_data):
169169

170170

171171
@pytest.fixture
172-
def history_with_batch_data(history_data):
172+
def history_parallel(history_data):
173173
data = history_data.copy()
174174
data["batches"] = [0, 0, 1, 1, 2, 2]
175175
return History(direction=Direction.MINIMIZE, **data)
@@ -223,8 +223,8 @@ def test_history_fun_data_with_fun_evaluations_cost_model_and_monotone(history):
223223
assert_frame_equal(got, exp, check_dtype=False, check_categorical=False)
224224

225225

226-
def test_history_fun_data_with_fun_batches_cost_model(history_with_batch_data):
227-
got = history_with_batch_data.fun_data(
226+
def test_history_fun_data_with_fun_batches_cost_model(history_parallel):
227+
got = history_parallel.fun_data(
228228
cost_model=om.timing.fun_batches,
229229
monotone=False,
230230
)
@@ -385,8 +385,8 @@ def test_get_time_fun_batches(history):
385385
assert_array_equal(got, exp)
386386

387387

388-
def test_get_time_fun_batches_with_batch_data(history_with_batch_data):
389-
got = history_with_batch_data._get_time(cost_model=om.timing.fun_batches)
388+
def test_get_time_fun_batches_parallel(history_parallel):
389+
got = history_parallel._get_time(cost_model=om.timing.fun_batches)
390390
exp = np.array([1, 1, 2, 2, 3, 3])
391391
assert_array_equal(got, exp)
392392

0 commit comments

Comments
 (0)