Skip to content

Commit 6a67124

Browse files
jorgensdjhalechrisrichardsonmichalhaberagarth-wells
authored
Extend adjoint action to work with mixed spaces (#352)
* Update version number. * Fix * Correct tag ref * Use BaseArgument.__eq__ in Argument (#147) (cherry picked from commit e683148) * Oops! Bump version number. * Fix. * Bump version. * Updated to .md README (#275) * Kebab case in build-wheels.yml (#276) * Fixes for pypa packaging. * Correct documentation links (#277) * Remove unecessary pip pinning (#278) * Bump version. * Update version to 2024.2.0 * Extending ufl.lhs and ufl.rhs to mixed domains variational forms * Add arity optional argument to extract_blocks to be used in compute_form_rhs/lhs * Fix ruff * Fix ruff * Fix lhs / rhs : No need to extract blocks if parts=() * Extend `MixedFunctionSpace with `ufl.rhs`, `ufl.lhs` and thus `ufl.system`. * Add `ufl.action` for MixedFunctionSpace * Extend adjoint action to work with mixed spaces --------- Co-authored-by: Jack S. Hale <mail@jackhale.co.uk> Co-authored-by: Chris Richardson <chris@bpi.cam.ac.uk> Co-authored-by: Michal Habera <michal.habera@gmail.com> Co-authored-by: Garth N. Wells <gnw20@cam.ac.uk> Co-authored-by: Cécile <cecile@simula.no>
1 parent 5d249b0 commit 6a67124

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

test/test_mixed_function_space.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
TestFunctions,
1212
TrialFunctions,
1313
action,
14+
conj,
1415
dx,
1516
grad,
1617
inner,
@@ -23,6 +24,7 @@
2324
)
2425
from ufl.algorithms import expand_derivatives, renumbering
2526
from ufl.algorithms.formsplitter import extract_blocks
27+
from ufl.algorithms.formtransformations import compute_form_adjoint
2628
from ufl.finiteelement import FiniteElement
2729
from ufl.pullback import identity_pullback
2830
from ufl.sobolevspace import H1
@@ -134,7 +136,6 @@ def source2(f, v):
134136
f = Constant(domain)
135137
g = Constant(domain)
136138
F = mass(u, v) + mixed(u, q) + stiffness(p, q) + source1(f, v) + source2(g, q)
137-
138139
a = lhs(F)
139140
a_blocked = extract_blocks(a)
140141
L = rhs(F)
@@ -180,7 +181,6 @@ def source2(f, v):
180181
F_reduced_renumbered = renumbering.renumber_indices(expand_derivatives(F_reduced))
181182
assert len(F_reduced_renumbered.coefficients()) == 1
182183
inserted_coeff = F_reduced_renumbered.coefficients()[0]
183-
184184
# Create reference solution
185185
Fh = mass(inserted_coeff, v) + stiffness(inserted_coeff, q) + source1(f, v) + source2(g, q)
186186

@@ -197,3 +197,32 @@ def source2(f, v):
197197
J_ref = replace(J_exp, {u: inserted_coeff, v: coefficients[0], q: coefficients[1]})
198198
# Verify
199199
assert J_renumbered == J_ref
200+
201+
202+
def test_adjoint():
203+
V = FiniteElement("Lagrange", triangle, 1, (), identity_pullback, H1)
204+
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
205+
space0 = FunctionSpace(domain, V)
206+
space1 = FunctionSpace(domain, V)
207+
mixed_space = MixedFunctionSpace(space0, space1)
208+
209+
u, p = TrialFunctions(mixed_space)
210+
du, dp = TestFunctions(mixed_space)
211+
c = Coefficient(space0)
212+
Jh = (
213+
inner(grad(u), grad(du)) * dx
214+
+ inner(c * dp.dx(0), u) * dx
215+
- inner(du.dx(0), p.dx(1)) * dx
216+
+ inner(p, dp) * dx
217+
)
218+
Jh_adj = compute_form_adjoint(Jh)
219+
blocked_adj = extract_blocks(Jh_adj)
220+
221+
ref_adj = [
222+
[conj(inner(grad(du), grad(u))) * dx, conj(-inner(p.dx(0), du.dx(1))) * dx],
223+
[conj(inner(c * u.dx(0), dp)) * dx, conj(inner(dp, p)) * dx],
224+
]
225+
226+
for i in range(2):
227+
for j in range(2):
228+
assert ref_adj[i][j] == blocked_adj[i][j]

ufl/algorithms/formtransformations.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,39 @@ def compute_form_adjoint(form, reordered_arguments=None):
496496

497497
parts = [arg.part() for arg in arguments]
498498
if set(parts) - {None}:
499-
raise ValueError("compute_form_adjoint cannot handle parts.")
499+
J = extract_blocks(form, arity=2)
500+
num_blocks = len(J)
501+
J_adj = 0
502+
for i in range(num_blocks):
503+
for j in range(num_blocks):
504+
if J[i][j] is None:
505+
continue
506+
v, u = J[i][j].arguments()
507+
if reordered_arguments is None:
508+
reordered_u = Argument(u.ufl_function_space(), number=v.number(), part=v.part())
509+
reordered_v = Argument(v.ufl_function_space(), number=u.number(), part=u.part())
510+
else:
511+
reordered_u, reordered_v = reordered_arguments[i]
512+
513+
if reordered_u.number() >= reordered_v.number():
514+
raise ValueError("Ordering of new arguments is the same as the old arguments!")
515+
516+
if reordered_u.part() != v.part():
517+
raise ValueError("Ordering of new arguments is the same as the old arguments!")
518+
if reordered_v.part() != u.part():
519+
raise ValueError("Ordering of new arguments is the same as the old arguments!")
520+
521+
if reordered_u.ufl_function_space() != u.ufl_function_space():
522+
raise ValueError(
523+
"Element mismatch between new and old arguments (trial functions)."
524+
)
525+
if reordered_v.ufl_function_space() != v.ufl_function_space():
526+
raise ValueError(
527+
"Element mismatch between new and old arguments (test functions)."
528+
)
529+
530+
J_adj += map_integrands(Conj, replace(J[i][j], {v: reordered_v, u: reordered_u}))
531+
return J_adj
500532

501533
if len(arguments) != 2:
502534
raise ValueError("Expecting bilinear form.")

0 commit comments

Comments
 (0)