Skip to content

Commit da279ac

Browse files
committed
Handle batch case
1 parent c86421c commit da279ac

File tree

3 files changed

+161
-19
lines changed

3 files changed

+161
-19
lines changed

src/optimagic/optimization/history.py

+88-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import warnings
22
from dataclasses import dataclass
33
from functools import partial
4-
from typing import Any, Literal
4+
from typing import Any, Callable, Iterable, Literal
55

66
import numpy as np
77
import pandas as pd
@@ -192,6 +192,16 @@ def flat_param_names(self) -> list[str]:
192192
def _get_time(
193193
self, cost_model: CostModel | Literal["wall_time"]
194194
) -> NDArray[np.float64]:
195+
"""Return the cumulative time measure.
196+
197+
Args:
198+
cost_model: The cost model that is used to calculate the time measure. If
199+
"wall_time", the wall time is returned.
200+
201+
Returns:
202+
np.ndarray: The time measure.
203+
204+
"""
195205
if not isinstance(cost_model, CostModel) and cost_model != "wall_time":
196206
raise ValueError("cost_model must be a CostModel or 'wall_time'.")
197207

@@ -207,11 +217,31 @@ def _get_time(
207217
fun_and_jac_time = self._get_time_per_task(
208218
task=EvalTask.FUN_AND_JAC, cost_factor=cost_model.fun_and_jac
209219
)
210-
return fun_time + jac_time + fun_and_jac_time
220+
221+
time = fun_time + jac_time + fun_and_jac_time
222+
batch_time = _batch_apply(
223+
data=time,
224+
batch_ids=self.batches,
225+
func=cost_model.aggregate_batch_time,
226+
)
227+
return np.cumsum(batch_time)
211228

212229
def _get_time_per_task(
213230
self, task: EvalTask, cost_factor: float | None
214231
) -> NDArray[np.float64]:
232+
"""Return the time measure per task.
233+
234+
Args:
235+
task: The task for which the time is calculated.
236+
cost_factor: The cost factor used to calculate the time. If None, the time
237+
is the difference between the start and stop time, otherwise the time
238+
is given by the cost factor.
239+
240+
Returns:
241+
np.ndarray: The time per task. For entries where the task is not the
242+
requested task, the time is 0.
243+
244+
"""
215245
dummy_task = np.array([1 if t == task else 0 for t in self.task])
216246
if cost_factor is None:
217247
factor: float | NDArray[np.float64] = np.array(
@@ -220,7 +250,7 @@ def _get_time_per_task(
220250
else:
221251
factor = cost_factor
222252

223-
return np.cumsum(factor * dummy_task)
253+
return factor * dummy_task
224254

225255
@property
226256
def start_time(self) -> list[float]:
@@ -351,3 +381,58 @@ def _task_as_categorical(task: list[EvalTask]) -> pd.Categorical:
351381
return pd.Categorical(
352382
[t.value for t in task], categories=[t.value for t in EvalTask]
353383
)
384+
385+
386+
def _batch_apply(
387+
data: NDArray[np.float64],
388+
batch_ids: list[int],
389+
func: Callable[[Iterable[float]], float],
390+
) -> NDArray[np.float64]:
391+
"""Apply a reduction operator on batches of data.
392+
393+
Args:
394+
data: 1d array with data.
395+
batch_ids: A list whose length is equal to the size of data. Values need to be
396+
sorted and can be repeated.
397+
func: A reduction function that takes an iterable of floats as input (e.g., a
398+
numpy array or a list) and returns a scalar.
399+
400+
Returns:
401+
The transformed data. Has the same length as data. For each batch, the result of
402+
the reduction operation is stored at the first index of that batch, and all
403+
other values of that batch are set to zero.
404+
405+
"""
406+
batch_start = _get_batch_start(batch_ids)
407+
batch_stop = [*batch_start, len(data)][1:]
408+
409+
batch_result = []
410+
for batch, (start, stop) in zip(
411+
batch_ids, zip(batch_start, batch_stop, strict=False), strict=False
412+
):
413+
try:
414+
batch_data = data[start:stop]
415+
reduced = func(batch_data)
416+
batch_result.append(reduced)
417+
except Exception as e:
418+
msg = (
419+
f"Calling function {func.__name__} on batch {batch} of the History "
420+
f"History raised an Exception. Please verify that {func.__name__} is "
421+
"properly defined."
422+
)
423+
raise ValueError(msg) from e
424+
425+
out = np.zeros_like(data)
426+
out[batch_start] = batch_result
427+
return out
428+
429+
430+
def _get_batch_start(batch_ids: list[int]) -> list[int]:
431+
"""Get start indices of batch.
432+
433+
This function assumes that batch_ids non-empty and sorted.
434+
435+
"""
436+
ids_arr = np.array(batch_ids, dtype=np.int64)
437+
indices = np.where(ids_arr[:-1] != ids_arr[1:])[0] + 1
438+
return np.insert(indices, 0, 0).tolist()

src/optimagic/timing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Callable
2+
from typing import Callable, Iterable
33

44

55
@dataclass(frozen=True)
@@ -8,7 +8,7 @@ class CostModel:
88
jac: float | None
99
fun_and_jac: float | None
1010
label: str
11-
aggregate_batch_time: Callable[[list[float]], float]
11+
aggregate_batch_time: Callable[[Iterable[float]], float]
1212

1313

1414
evaluation_time = CostModel(

tests/optimagic/optimization/test_history.py

+71-14
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from optimagic.optimization.history import (
1111
History,
1212
HistoryEntry,
13+
_batch_apply,
1314
_calculate_monotone_sequence,
15+
_get_batch_start,
1416
_get_flat_param_names,
1517
_get_flat_params,
1618
_is_1d_array,
@@ -143,8 +145,8 @@ def params():
143145

144146

145147
@pytest.fixture
146-
def history(params):
147-
data = {
148+
def history_data(params):
149+
return {
148150
"fun": [10, None, 9, None, 2, 5],
149151
"task": [
150152
EvalTask.FUN,
@@ -157,9 +159,19 @@ def history(params):
157159
"start_time": [0, 2, 5, 7, 10, 12],
158160
"stop_time": [1, 4, 6, 9, 11, 14],
159161
"params": params,
160-
"batches": [0, 0, 1, 1, 2, 2],
162+
"batches": [0, 1, 2, 3, 4, 5],
161163
}
162164

165+
166+
@pytest.fixture
167+
def history(history_data):
168+
return History(direction=Direction.MINIMIZE, **history_data)
169+
170+
171+
@pytest.fixture
172+
def history_with_batch_data(history_data):
173+
data = history_data.copy()
174+
data["batches"] = [0, 0, 1, 1, 2, 2]
163175
return History(direction=Direction.MINIMIZE, **data)
164176

165177

@@ -211,9 +223,8 @@ def test_history_fun_data_with_fun_evaluations_cost_model_and_monotone(history):
211223
assert_frame_equal(got, exp, check_dtype=False, check_categorical=False)
212224

213225

214-
@pytest.mark.xfail(reason="Must be fixed!")
215-
def test_history_fun_data_with_fun_batches_cost_model(history):
216-
got = history.fun_data(
226+
def test_history_fun_data_with_fun_batches_cost_model(history_with_batch_data):
227+
got = history_with_batch_data.fun_data(
217228
cost_model=om.timing.fun_batches,
218229
monotone=False,
219230
)
@@ -328,23 +339,23 @@ def test_flat_param_names(history):
328339

329340
def test_get_time_per_task_fun(history):
330341
got = history._get_time_per_task(EvalTask.FUN, cost_factor=1)
331-
exp = np.array([1, 1, 2, 2, 3, 3])
342+
exp = np.array([1, 0, 1, 0, 1, 0])
332343
assert_array_equal(got, exp)
333344

334345

335-
def test_get_time_per_task_jac(history):
336-
got = history._get_time_per_task(EvalTask.JAC, cost_factor=1)
337-
exp = np.array([0, 1, 1, 2, 2, 2])
346+
def test_get_time_per_task_jac_cost_factor_none(history):
347+
got = history._get_time_per_task(EvalTask.JAC, cost_factor=None)
348+
exp = np.array([0, 2, 0, 2, 0, 0])
338349
assert_array_equal(got, exp)
339350

340351

341352
def test_get_time_per_task_fun_and_jac(history):
342-
got = history._get_time_per_task(EvalTask.FUN_AND_JAC, cost_factor=1)
343-
exp = np.array([0, 0, 0, 0, 0, 1])
353+
got = history._get_time_per_task(EvalTask.FUN_AND_JAC, cost_factor=-0.5)
354+
exp = np.array([0, 0, 0, 0, 0, -0.5])
344355
assert_array_equal(got, exp)
345356

346357

347-
def test_get_time_cost_model(history):
358+
def test_get_time_custom_cost_model(history):
348359
cost_model = om.timing.CostModel(
349360
fun=0.5, jac=1, fun_and_jac=2, label="test", aggregate_batch_time=sum
350361
)
@@ -362,6 +373,30 @@ def test_get_time_cost_model(history):
362373
assert_array_equal(got, exp)
363374

364375

376+
def test_get_time_fun_evaluations(history):
377+
got = history._get_time(cost_model=om.timing.fun_evaluations)
378+
exp = np.array([1, 1, 2, 2, 3, 4])
379+
assert_array_equal(got, exp)
380+
381+
382+
def test_get_time_fun_batches(history):
383+
got = history._get_time(cost_model=om.timing.fun_batches)
384+
exp = np.array([1, 1, 2, 2, 3, 4])
385+
assert_array_equal(got, exp)
386+
387+
388+
def test_get_time_fun_batches_with_batch_data(history_with_batch_data):
389+
got = history_with_batch_data._get_time(cost_model=om.timing.fun_batches)
390+
exp = np.array([1, 1, 2, 2, 3, 3])
391+
assert_array_equal(got, exp)
392+
393+
394+
def test_get_time_evaluation_time(history):
395+
got = history._get_time(cost_model=om.timing.evaluation_time)
396+
exp = np.array([1, 3, 4, 6, 7, 9])
397+
assert_array_equal(got, exp)
398+
399+
365400
def test_get_time_wall_time(history):
366401
got = history._get_time(cost_model="wall_time")
367402
exp = np.array([1, 4, 6, 9, 11, 14])
@@ -381,7 +416,7 @@ def test_stop_time_property(history):
381416

382417

383418
def test_batches_property(history):
384-
assert history.batches == [0, 0, 1, 1, 2, 2]
419+
assert history.batches == [0, 1, 2, 3, 4, 5]
385420

386421

387422
# Tasks
@@ -466,3 +501,25 @@ def test_task_as_categorical():
466501
got = _task_as_categorical(task)
467502
assert got.tolist() == ["fun", "jac", "fun_and_jac"]
468503
assert isinstance(got.dtype, pd.CategoricalDtype)
504+
505+
506+
def test_get_batch_start():
507+
batches = [0, 0, 1, 1, 1, 2, 2, 3]
508+
got = _get_batch_start(batches)
509+
assert got == [0, 2, 5, 7]
510+
511+
512+
def test_batch_apply_sum():
513+
data = np.array([0, 1, 2, 3, 4])
514+
batch_ids = [0, 0, 1, 1, 2]
515+
exp = np.array([1, 0, 5, 0, 4])
516+
got = _batch_apply(data, batch_ids, sum)
517+
assert_array_equal(exp, got)
518+
519+
520+
def test_batch_apply_max():
521+
data = np.array([0, 1, 2, 3, 4])
522+
batch_ids = [0, 0, 1, 1, 2]
523+
exp = np.array([1, 0, 3, 0, 4])
524+
got = _batch_apply(data, batch_ids, max)
525+
assert_array_equal(exp, got)

0 commit comments

Comments
 (0)