Skip to content

Commit 30c7c6c

Browse files
Fix problems with Python interfaces and expanded docstrings
1 parent a224a13 commit 30c7c6c

File tree

4 files changed

+200
-17
lines changed

4 files changed

+200
-17
lines changed

oif/interfaces/python/oif/interfaces/ivp.py

+102-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
from typing import Callable
1+
"""This module defines the interface for solving initial-value problems
2+
for ordinary differential equations:
3+
4+
.. math::
5+
\\frac{dy}{dt} = f(t, y), \\quad y(t_0) = y_0.
6+
7+
"""
8+
9+
from collections.abc import Callable
10+
from typing import TypeAlias
211

312
import numpy as np
413
from oif.core import (
@@ -13,48 +22,128 @@
1322
unload_impl,
1423
)
1524

16-
rhs_fn_t = Callable[[float, np.ndarray, np.ndarray, object], int]
25+
RhsFn: TypeAlias = Callable[[float, np.ndarray, np.ndarray, object], int]
26+
"""Signature of the right-hand side (RHS) function :math:`f(t, y)`.
27+
28+
The function accepts four arguments:
29+
- `t`: current time,
30+
- `y`: state vector at time :math:`t`,
31+
- `ydot`: output array to which the result of function evalutation is stored,
32+
- `user_data`: additional context (user-defined data) that
33+
must be passed to the function (e.g., parameters of the system).
34+
35+
"""
1736

1837

1938
class IVP:
39+
"""Interface for solving initial value problems.
40+
41+
This class serves as a gateway to the implementations of the
42+
solvers for initial-value problems for ordinary differential equations.
43+
44+
Parameters
45+
----------
46+
impl : str
47+
Name of the desired implementation.
48+
49+
Examples
50+
--------
51+
52+
Let's solve the following initial value problem:
53+
54+
.. math::
55+
y'(t) = -y(t), \\quad y(0) = 1.
56+
57+
First, import the necessary modules:
58+
>>> import numpy as np
59+
>>> from oif.interfaces.ivp import IVP
60+
61+
Define the right-hand side function:
62+
63+
>>> def rhs(t, y, ydot, user_data):
64+
... ydot[0] = -y[0]
65+
... return 0 # No errors, optional
66+
67+
Now define the initial condition:
68+
69+
>>> y0, t0 = np.array([1.0]), 0.0
70+
71+
Create an instance of the IVP solver using the implementation "jl_diffeq",
72+
which is an adapter to the `OrdinaryDiffeq.jl` Julia package:
73+
74+
>>> s = IVP("jl_diffeq")
75+
76+
We set the initial value, the right-hand side function, and the tolerance:
77+
78+
>>> s.set_initial_value(y0, t0)
79+
>>> s.set_rhs_fn(rhs)
80+
>>> s.set_tolerances(1e-6, 1e-12)
81+
82+
Now we integrate to time `t = 1.0` in a loop, outputting the current value
83+
of `y` with time step `0.1`:
84+
85+
>>> t = t0
86+
>>> times = np.linspace(t0, t0 + 1.0, num=11)
87+
>>> for t in times[1:]:
88+
... s.integrate(t)
89+
... print(f"{t:.1f} {s.y[0]:.6f}")
90+
0.1 0.904837
91+
0.2 0.818731
92+
0.3 0.740818
93+
0.4 0.670320
94+
0.5 0.606531
95+
0.6 0.548812
96+
0.7 0.496585
97+
0.8 0.449329
98+
0.9 0.406570
99+
1.0 0.367879
100+
101+
"""
102+
20103
def __init__(self, impl: str):
21104
self._binding: OIFPyBinding = load_impl("ivp", impl, 1, 0)
22-
self.s = None
23-
self.N: int = 0
24-
self.y0: np.ndarray
25105
self.y: np.ndarray
106+
"""Current value of the state vector."""
107+
self._N: int = 0
26108

27109
def set_initial_value(self, y0: np.ndarray, t0: float):
110+
"""Set initial value y(t0) = y0."""
28111
y0 = np.asarray(y0, dtype=np.float64)
29112
self.y0 = y0
30113
self.y = np.empty_like(y0)
31-
self.N = len(self.y0)
114+
self._N = len(self.y0)
32115
t0 = float(t0)
33116
self._binding.call("set_initial_value", (y0, t0), ())
34117

35-
def set_rhs_fn(self, rhs_fn: rhs_fn_t):
36-
if self.N <= 0:
118+
def set_rhs_fn(self, rhs_fn: RhsFn):
119+
"""Specify right-hand side function f."""
120+
if self._N <= 0:
37121
raise RuntimeError("'set_initial_value' must be called before 'set_rhs_fn'")
38122

39123
self.wrapper = make_oif_callback(
40124
rhs_fn, (OIF_FLOAT64, OIF_ARRAY_F64, OIF_ARRAY_F64, OIF_USER_DATA), OIF_INT
41125
)
42126
self._binding.call("set_rhs_fn", (self.wrapper,), ())
43127

128+
def set_tolerances(self, rtol: float, atol: float):
129+
"""Specify relative and absolute tolerances, respectively."""
130+
self._binding.call("set_tolerances", (rtol, atol), ())
131+
44132
def set_user_data(self, user_data: object):
133+
"""Specify additional data that will be used for right-hand side function."""
45134
self.user_data = make_oif_user_data(user_data)
46135
self._binding.call("set_user_data", (self.user_data,), ())
47136

48-
def set_tolerances(self, rtol: float, atol: float):
49-
self._binding.call("set_tolerances", (rtol, atol), ())
137+
def set_integrator(self, integrator_name: str, integrator_params: dict = {}):
138+
"""Set integrator, if the name is recognizable."""
139+
self._binding.call("set_integrator", (integrator_name, integrator_params), ())
50140

51141
def integrate(self, t: float):
142+
"""Integrate to time `t` and write solution to `y`."""
52143
self._binding.call("integrate", (t,), (self.y,))
53144

54-
def set_integrator(self, integrator_name: str, integrator_params: dict = {}):
55-
self._binding.call("set_integrator", (integrator_name, integrator_params), ())
56-
57145
def print_stats(self):
146+
"""Print integration statistics."""
58147
self._binding.call("print_stats", (), ())
59148

60149
def __del__(self):

oif/interfaces/python/oif/interfaces/linear_solver.py

+38
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,50 @@
1+
"""This module defines the interface for solving linear systems of equations.
2+
3+
Problems to be solved are of the form:
4+
5+
.. math::
6+
A x = b,
7+
8+
where :math:`A` is a square matrix and :math:`b` is a vector.
9+
"""
10+
111
import numpy as np
212
from oif.core import OIFPyBinding, load_impl, unload_impl
313

414

515
class LinearSolver:
16+
"""Interface for solving linear systems of equations.
17+
18+
This class serves as a gateway to the implementations of the
19+
linear algebraic solvers.
20+
21+
Parameters
22+
----------
23+
impl : str
24+
Name of the desired implementation.
25+
26+
"""
27+
628
def __init__(self, impl: str):
729
self._binding: OIFPyBinding = load_impl("linsolve", impl, 1, 0)
830

931
def solve(self, A: np.ndarray, b: np.ndarray) -> np.ndarray:
32+
"""Solve the linear system of equations :math:`A x = b`.
33+
34+
Parameters
35+
----------
36+
A : np.ndarray of shape (n, n)
37+
Coefficient matrix.
38+
b : np.ndarray of shape (n,)
39+
Right-hand side vector.
40+
41+
Returns
42+
-------
43+
np.ndarray
44+
Result of the linear system solution after the invocation
45+
of the `solve` method.
46+
47+
"""
1048
result = np.empty((A.shape[1]))
1149
self._binding.call("solve_lin", (A, b), (result,))
1250
return result

oif/interfaces/python/oif/interfaces/qeq_solver.py

+52
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,64 @@
1+
"""This module defines the interface for solving a quatratic equation.
2+
3+
The quadratic equation is of the form:
4+
5+
.. math::
6+
a x^2 + b x + c = 0,
7+
8+
where :math:`a`, :math:`b`, and :math:`c` are the coefficients of the equation.
9+
10+
Of course, this is not very useful in scientific context to invoke
11+
such a solver.
12+
13+
It was developed as a prototype to ensure that the envisioned architecture
14+
of Open Interfaces is feasible.
15+
It is used as a simple text case as well.
16+
17+
"""
18+
119
import numpy as np
220
from oif.core import OIFPyBinding, load_impl, unload_impl
321

422

523
class QeqSolver:
24+
"""Interface for solving quadratic equations.
25+
26+
This class serves as a gateway to the implementations of the
27+
of the quadratic-equation solvers.
28+
29+
Example
30+
-------
31+
32+
Let's solve the following quadratic equation:
33+
34+
.. math::
35+
x^2 + 2 x + 1 = 0.
36+
37+
First, import the necessary modules:
38+
39+
>>> from oif.interfaces.qeq_solver import QeqSolver
40+
41+
Define the coefficients of the equation:
42+
43+
>>> a, b, c = 1.0, 2.0, 1.0
44+
45+
Create an instance of the solver:
46+
47+
>>> s = QeqSolver("py_qeq_solver")
48+
49+
Solve the equation:
50+
51+
>>> result = s.solve(a, b, c)
52+
>>> print(result)
53+
[-1. -1.]
54+
55+
"""
56+
657
def __init__(self, impl: str):
758
self._binding: OIFPyBinding = load_impl("qeq", impl, 1, 0)
859

960
def solve(self, a: float, b: float, c: float):
61+
"""Solve the quadratic equation :math:`a x^2 + b x + c = 0`."""
1062
result = np.array([11.0, 22.0])
1163
self._binding.call("solve_qeq", (a, b, c), (result,))
1264
return result

oif_impl/lang_python/oif/impl/ivp.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@ def set_rhs_fn(self, rhs: Callable) -> Union[int, None]:
1717
def set_tolerances(self, rtol: float, atol: float) -> Union[int, None]:
1818
"""Specify relative and absolute tolerances, respectively."""
1919

20-
@abc.abstractmethod
21-
def integrate(self, t: float, y: np.ndarray) -> Union[int, None]:
22-
"""Integrate to time `t` and write solution to `y`."""
23-
2420
@abc.abstractmethod
2521
def set_user_data(self, user_data: object) -> Union[int, None]:
2622
"""Specify additional data that will be used for right-hand side function."""
@@ -30,3 +26,11 @@ def set_integrator(
3026
self, integrator_name: str, integrator_params: Dict
3127
) -> Union[int, None]:
3228
"""Set integrator, if the name is recognizable."""
29+
30+
@abc.abstractmethod
31+
def integrate(self, t: float, y: np.ndarray) -> Union[int, None]:
32+
"""Integrate to time `t` and write solution to `y`."""
33+
34+
@abc.abstractmethod
35+
def print_stats(self):
36+
"""Print integration statistics."""

0 commit comments

Comments
 (0)