Skip to content

Commit c1f0d2c

Browse files
committed
Move results processing into solve_internal_problem.
1 parent 8596bed commit c1f0d2c

File tree

5 files changed

+111
-116
lines changed

5 files changed

+111
-116
lines changed

src/optimagic/optimization/algorithm.py

+66-6
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010

1111
from optimagic.exceptions import InvalidAlgoInfoError, InvalidAlgoOptionError
1212
from optimagic.logging.types import StepStatus
13+
from optimagic.optimization.convergence_report import get_convergence_report
1314
from optimagic.optimization.history import History
1415
from optimagic.optimization.internal_optimization_problem import (
1516
InternalOptimizationProblem,
1617
)
18+
from optimagic.optimization.optimize_result import OptimizeResult
19+
from optimagic.parameters.conversion import Converter
1720
from optimagic.type_conversion import TYPE_CONVERTERS
18-
from optimagic.typing import AggregationLevel
21+
from optimagic.typing import AggregationLevel, Direction, ExtraResultFields
22+
from optimagic.utilities import isscalar
1923

2024

2125
@dataclass(frozen=True)
@@ -142,6 +146,56 @@ def __post_init__(self) -> None:
142146
)
143147
raise TypeError(msg)
144148

149+
def create_optimize_result(
150+
self,
151+
converter: Converter,
152+
solver_type: AggregationLevel,
153+
extra_fields: ExtraResultFields,
154+
) -> OptimizeResult:
155+
"""Process an internal optimizer result."""
156+
params = converter.params_from_internal(self.x)
157+
if isscalar(self.fun):
158+
fun = float(self.fun)
159+
elif solver_type == AggregationLevel.LIKELIHOOD:
160+
fun = float(np.sum(self.fun))
161+
elif solver_type == AggregationLevel.LEAST_SQUARES:
162+
fun = np.dot(self.fun, self.fun)
163+
164+
if extra_fields.direction == Direction.MAXIMIZE:
165+
fun = -fun
166+
167+
if self.history is not None:
168+
conv_report = get_convergence_report(
169+
history=self.history, direction=extra_fields.direction
170+
)
171+
else:
172+
conv_report = None
173+
174+
out = OptimizeResult(
175+
params=params,
176+
fun=fun,
177+
start_fun=extra_fields.start_fun,
178+
start_params=extra_fields.start_params,
179+
algorithm=extra_fields.algorithm,
180+
direction=extra_fields.direction.value,
181+
n_free=extra_fields.n_free,
182+
message=self.message,
183+
success=self.success,
184+
n_fun_evals=self.n_fun_evals,
185+
n_jac_evals=self.n_jac_evals,
186+
n_hess_evals=self.n_hess_evals,
187+
n_iterations=self.n_iterations,
188+
status=self.status,
189+
jac=self.jac,
190+
hess=self.hess,
191+
hess_inv=self.hess_inv,
192+
max_constraint_violation=self.max_constraint_violation,
193+
history=self.history,
194+
algorithm_output=self.info,
195+
convergence_report=conv_report,
196+
)
197+
return out
198+
145199

146200
class AlgorithmMeta(ABCMeta):
147201
"""Metaclass to get repr, algo_info and name for classes, not just instances."""
@@ -234,25 +288,31 @@ def solve_internal_problem(
234288
problem: InternalOptimizationProblem,
235289
x0: NDArray[np.float64],
236290
step_id: int,
237-
) -> InternalOptimizeResult:
291+
) -> OptimizeResult:
238292
problem = problem.with_new_history().with_step_id(step_id)
239293

240294
if problem.logger:
241295
problem.logger.step_store.update(
242296
step_id, {"status": str(StepStatus.RUNNING.value)}
243297
)
244298

245-
result = self._solve_internal_problem(problem, x0)
299+
raw_res = self._solve_internal_problem(problem, x0)
246300

247-
if (not self.algo_info.disable_history) and (result.history is None):
248-
result = replace(result, history=problem.history)
301+
if (not self.algo_info.disable_history) and (raw_res.history is None):
302+
raw_res = replace(raw_res, history=problem.history)
249303

250304
if problem.logger:
251305
problem.logger.step_store.update(
252306
step_id, {"status": str(StepStatus.COMPLETE.value)}
253307
)
254308

255-
return result
309+
res = raw_res.create_optimize_result(
310+
converter=problem.converter,
311+
solver_type=self.algo_info.solver_type,
312+
extra_fields=problem.static_result_fields,
313+
)
314+
315+
return res
256316

257317
def with_option_if_applicable(self, **kwargs: Any) -> Self:
258318
"""Call with_option only with applicable keyword arguments."""

src/optimagic/optimization/multistart.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
"""
1313

1414
import warnings
15-
from dataclasses import replace
1615
from typing import Literal
1716

1817
import numpy as np
@@ -21,7 +20,7 @@
2120

2221
from optimagic.logging.logger import LogStore
2322
from optimagic.logging.types import StepStatus
24-
from optimagic.optimization.algorithm import Algorithm, InternalOptimizeResult
23+
from optimagic.optimization.algorithm import Algorithm
2524
from optimagic.optimization.internal_optimization_problem import (
2625
InternalBounds,
2726
InternalOptimizationProblem,
@@ -30,6 +29,8 @@
3029
from optimagic.optimization.optimization_logging import (
3130
log_scheduled_steps_and_get_ids,
3231
)
32+
from optimagic.optimization.optimize_result import OptimizeResult
33+
from optimagic.optimization.process_results import process_multistart_result
3334
from optimagic.typing import AggregationLevel, ErrorHandling
3435
from optimagic.utilities import get_rng
3536

@@ -42,7 +43,7 @@ def run_multistart_optimization(
4243
options: InternalMultistartOptions,
4344
logger: LogStore | None,
4445
error_handling: ErrorHandling,
45-
) -> InternalOptimizeResult:
46+
) -> OptimizeResult:
4647
steps = determine_steps(options.n_samples, stopping_maxopt=options.stopping_maxopt)
4748

4849
scheduled_steps = log_scheduled_steps_and_get_ids(
@@ -159,6 +160,7 @@ def single_optimization(x0, step_id):
159160
results=batch_results,
160161
convergence_criteria=convergence_criteria,
161162
solver_type=local_algorithm.algo_info.solver_type,
163+
converter=internal_problem.converter,
162164
)
163165
opt_counter += len(batch)
164166
if is_converged:
@@ -176,7 +178,12 @@ def single_optimization(x0, step_id):
176178
}
177179

178180
raw_res = state["best_res"]
179-
res = replace(raw_res, multistart_info=multistart_info)
181+
res = process_multistart_result(
182+
raw_res=raw_res,
183+
converter=internal_problem.converter,
184+
extra_fields=internal_problem.static_result_fields,
185+
multistart_info=multistart_info,
186+
)
180187

181188
return res
182189

@@ -371,7 +378,12 @@ def get_batched_optimization_sample(sorted_sample, stopping_maxopt, batch_size):
371378

372379

373380
def update_convergence_state(
374-
current_state, starts, results, convergence_criteria, solver_type
381+
current_state,
382+
starts,
383+
results,
384+
convergence_criteria,
385+
solver_type,
386+
converter,
375387
):
376388
"""Update the state of all quantities related to convergence.
377389
@@ -389,6 +401,7 @@ def update_convergence_state(
389401
convergence_criteria (dict): Dict with the entries "xtol" and "max_discoveries"
390402
solver_type: The aggregation level of the local optimizer. Needed to
391403
interpret the output of the internal criterion function.
404+
converter: The converter to map between internal and external parameter spaces.
392405
393406
394407
Returns:
@@ -422,7 +435,7 @@ def update_convergence_state(
422435
# ==================================================================================
423436
valid_results = [results[i] for i in valid_indices]
424437
valid_starts = [starts[i] for i in valid_indices]
425-
valid_new_x = [res.x for res in valid_results]
438+
valid_new_x = [converter.params_to_internal(res.params) for res in valid_results]
426439
valid_new_y = []
427440

428441
# make the criterion output scalar if a least squares optimizer returns an

src/optimagic/optimization/optimize.py

+2-21
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,6 @@
4848
)
4949
from optimagic.optimization.optimization_logging import log_scheduled_steps_and_get_ids
5050
from optimagic.optimization.optimize_result import OptimizeResult
51-
from optimagic.optimization.process_results import (
52-
process_multistart_result,
53-
process_single_result,
54-
)
5551
from optimagic.parameters.bounds import Bounds
5652
from optimagic.parameters.conversion import (
5753
get_converter,
@@ -644,7 +640,7 @@ def _optimize(problem: OptimizationProblem) -> OptimizeResult:
644640
logger=logger,
645641
)[0]
646642

647-
raw_res = problem.algorithm.solve_internal_problem(internal_problem, x, step_id)
643+
res = problem.algorithm.solve_internal_problem(internal_problem, x, step_id)
648644

649645
else:
650646
multistart_options = get_internal_multistart_options_from_public(
@@ -658,7 +654,7 @@ def _optimize(problem: OptimizationProblem) -> OptimizeResult:
658654
upper=internal_params.soft_upper_bounds,
659655
)
660656

661-
raw_res = run_multistart_optimization(
657+
res = run_multistart_optimization(
662658
local_algorithm=problem.algorithm,
663659
internal_problem=internal_problem,
664660
x=x,
@@ -672,21 +668,6 @@ def _optimize(problem: OptimizationProblem) -> OptimizeResult:
672668
# Process the result
673669
# ==================================================================================
674670

675-
if problem.multistart is None:
676-
res = process_single_result(
677-
raw_res=raw_res,
678-
converter=converter,
679-
solver_type=problem.algorithm.algo_info.solver_type,
680-
extra_fields=extra_fields,
681-
)
682-
else:
683-
res = process_multistart_result(
684-
raw_res=raw_res,
685-
converter=converter,
686-
solver_type=problem.algorithm.algo_info.solver_type,
687-
extra_fields=extra_fields,
688-
)
689-
690671
log_reader: LogReader[Any] | None
691672
if logger is not None:
692673
assert problem.logging is not None

src/optimagic/optimization/process_results.py

+9-80
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,29 @@
1-
from dataclasses import replace
1+
import copy
22
from typing import Any
33

44
import numpy as np
55

6-
from optimagic.optimization.algorithm import InternalOptimizeResult
76
from optimagic.optimization.convergence_report import get_convergence_report
87
from optimagic.optimization.optimize_result import MultistartInfo, OptimizeResult
98
from optimagic.parameters.conversion import Converter
10-
from optimagic.typing import AggregationLevel, Direction, ExtraResultFields
11-
from optimagic.utilities import isscalar
12-
13-
14-
def process_single_result(
15-
raw_res: InternalOptimizeResult,
16-
converter: Converter,
17-
solver_type: AggregationLevel,
18-
extra_fields: ExtraResultFields,
19-
) -> OptimizeResult:
20-
"""Process an internal optimizer result."""
21-
params = converter.params_from_internal(raw_res.x)
22-
if isscalar(raw_res.fun):
23-
fun = float(raw_res.fun)
24-
elif solver_type == AggregationLevel.LIKELIHOOD:
25-
fun = float(np.sum(raw_res.fun))
26-
elif solver_type == AggregationLevel.LEAST_SQUARES:
27-
fun = np.dot(raw_res.fun, raw_res.fun)
28-
29-
if extra_fields.direction == Direction.MAXIMIZE:
30-
fun = -fun
31-
32-
if raw_res.history is not None:
33-
conv_report = get_convergence_report(
34-
history=raw_res.history, direction=extra_fields.direction
35-
)
36-
else:
37-
conv_report = None
38-
39-
out = OptimizeResult(
40-
params=params,
41-
fun=fun,
42-
start_fun=extra_fields.start_fun,
43-
start_params=extra_fields.start_params,
44-
algorithm=extra_fields.algorithm,
45-
direction=extra_fields.direction.value,
46-
n_free=extra_fields.n_free,
47-
message=raw_res.message,
48-
success=raw_res.success,
49-
n_fun_evals=raw_res.n_fun_evals,
50-
n_jac_evals=raw_res.n_jac_evals,
51-
n_hess_evals=raw_res.n_hess_evals,
52-
n_iterations=raw_res.n_iterations,
53-
status=raw_res.status,
54-
jac=raw_res.jac,
55-
hess=raw_res.hess,
56-
hess_inv=raw_res.hess_inv,
57-
max_constraint_violation=raw_res.max_constraint_violation,
58-
history=raw_res.history,
59-
algorithm_output=raw_res.info,
60-
convergence_report=conv_report,
61-
)
62-
return out
9+
from optimagic.typing import Direction, ExtraResultFields
6310

6411

6512
def process_multistart_result(
66-
raw_res: InternalOptimizeResult,
13+
raw_res: OptimizeResult,
6714
converter: Converter,
68-
solver_type: AggregationLevel,
6915
extra_fields: ExtraResultFields,
16+
multistart_info: dict[str, Any],
7017
) -> OptimizeResult:
7118
"""Process results of internal optimizers."""
72-
if raw_res.multistart_info is None:
73-
raise ValueError("Multistart info is missing.")
7419

7520
if isinstance(raw_res, str):
7621
res = _dummy_result_from_traceback(raw_res, extra_fields)
7722
else:
78-
res = process_single_result(
79-
raw_res=raw_res,
80-
converter=converter,
81-
solver_type=solver_type,
82-
extra_fields=extra_fields,
83-
)
84-
23+
res = raw_res
8524
info = _process_multistart_info(
86-
raw_res.multistart_info,
25+
multistart_info,
8726
converter=converter,
88-
solver_type=solver_type,
8927
extra_fields=extra_fields,
9028
)
9129

@@ -118,24 +56,15 @@ def process_multistart_result(
11856
def _process_multistart_info(
11957
info: dict[str, Any],
12058
converter: Converter,
121-
solver_type: AggregationLevel,
12259
extra_fields: ExtraResultFields,
12360
) -> MultistartInfo:
12461
starts = [converter.params_from_internal(x) for x in info["start_parameters"]]
12562

12663
optima = []
12764
for res, start in zip(info["local_optima"], starts, strict=False):
128-
replacements = {
129-
"start_params": start,
130-
"start_fun": None,
131-
}
132-
133-
processed = process_single_result(
134-
res,
135-
converter=converter,
136-
solver_type=solver_type,
137-
extra_fields=replace(extra_fields, **replacements),
138-
)
65+
processed = copy.copy(res)
66+
processed.start_params = start
67+
processed.start_fun = None
13968
optima.append(processed)
14069

14170
sample = [converter.params_from_internal(x) for x in info["exploration_sample"]]

0 commit comments

Comments
 (0)