Skip to content

Commit 8471cb2

Browse files
committed
temp
1 parent 5da0ab9 commit 8471cb2

6 files changed

+313
-0
lines changed

ufl/coefficient.py

+71
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# Modified by Cecile Daversin-Catty, 2018.
1212
# Modified by Ignacia Fierro-Piccardo 2023.
1313

14+
import numpy as np
1415
from ufl.argument import Argument
16+
from ufl.core.expr import Expr
1517
from ufl.core.terminal import FormArgument
1618
from ufl.core.ufl_type import ufl_type
1719
from ufl.duals import is_dual, is_primal
@@ -20,6 +22,7 @@
2022
from ufl.split_functions import split
2123
from ufl.utils.counted import Counted
2224

25+
2326
# --- The Coefficient class represents a coefficient in a form ---
2427

2528

@@ -201,6 +204,74 @@ def __repr__(self):
201204
"""Representation."""
202205
return self._repr
203206

207+
@Expr.traverse_dag
208+
@Expr.traverse_dag_apply_coefficient_split_cache
209+
def traverse_dag_apply_coefficient_split(
210+
self,
211+
coefficient_split,
212+
reference_value=False,
213+
reference_grad=0,
214+
restricted=None,
215+
cache=None,
216+
):
217+
from ufl.classes import (
218+
Terminal,
219+
ComponentTensor,
220+
MultiIndex,
221+
NegativeRestricted,
222+
PositiveRestricted,
223+
ReferenceGrad,
224+
ReferenceValue,
225+
Zero,
226+
)
227+
from ufl.core.multiindex import indices
228+
from ufl.checks import is_cellwise_constant
229+
from ufl.domain import extract_unique_domain
230+
from ufl.tensors import as_tensor
231+
232+
if self not in coefficient_split:
233+
return Terminal.traverse_dag_apply_coefficient_split(
234+
self,
235+
coefficient_split,
236+
reference_value=reference_value,
237+
reference_grad=reference_grad,
238+
restricted=restricted,
239+
cache=cache,
240+
)
241+
# Reference value expected
242+
if not reference_value:
243+
raise RuntimeError(f"ReferenceValue expected: got {o}")
244+
# Derivative indices
245+
beta = indices(reference_grad)
246+
components = []
247+
for subcoeff in coefficient_split[self]:
248+
c = subcoeff
249+
# Apply terminal modifiers onto the subcoefficient
250+
if reference_value:
251+
c = ReferenceValue(c)
252+
for _ in range(reference_grad):
253+
# Return zero if expression is trivially constant. This has to
254+
# happen here because ReferenceGrad has no access to the
255+
# topological dimension of a literal zero.
256+
if is_cellwise_constant(c):
257+
dim = extract_unique_domain(subcoeff).topological_dimension()
258+
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
259+
else:
260+
c = ReferenceGrad(c)
261+
if restricted == "+":
262+
c = PositiveRestricted(c)
263+
elif restricted == "-":
264+
c = NegativeRestricted(c)
265+
elif restricted is not None:
266+
raise RuntimeError(f"Got unknown restriction: {restricted}")
267+
# Collect components of the subcoefficient
268+
for alpha in np.ndindex(subcoeff.ufl_element().reference_value_shape):
269+
# New modified terminal: component[alpha + beta]
270+
components.append(c[alpha + beta])
271+
# Repack derivative indices to shape
272+
i, = indices(1)
273+
return ComponentTensor(as_tensor(components)[i], MultiIndex((i,) + beta))
274+
204275

205276
# --- Helper functions for subfunctions on mixed elements ---
206277

ufl/core/expr.py

+82
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
# Modified by Anders Logg, 2008
1717
# Modified by Massimiliano Leoni, 2016
1818

19+
import functools
1920
import warnings
2021

2122
from ufl.core.ufl_type import UFLType, update_ufl_type_attributes
2223

2324

25+
PREFIX_TRAVERSE_DAG = "traverse_dag"
26+
27+
2428
class Expr(object, metaclass=UFLType):
2529
"""Base class for all UFL expression types.
2630
@@ -388,6 +392,84 @@ def __round__(self, n=None):
388392
val = NotImplemented
389393
return val
390394

395+
@staticmethod
396+
def traverse_dag(f):
397+
"""Decorate DAG traversal methods."""
398+
@functools.wraps(f)
399+
def wrapper(self, *arg, **kwargs):
400+
if not f.__name__.startswith(PREFIX_TRAVERSE_DAG):
401+
raise RuntimeError(f"""
402+
DAG traversal method name must be prefixed by
403+
'{PREFIX_TRAVERSE_DAG}': got {f.__name__}
404+
""")
405+
if 'cache' not in kwargs:
406+
raise RuntimeError("DAG traversal methods must have cache in kwargs")
407+
if kwargs['cache'] is None:
408+
# Make a new dict for caching if not passed by the caller.
409+
kwargs['cache'] = {}
410+
return f(self, *arg, **kwargs)
411+
return wrapper
412+
413+
@traverse_dag
414+
def traverse_dag_reuse_if_untouched(self, method_name: str, *args, **kwargs) -> "Expr":
415+
"""Reuse if untouched.
416+
417+
Args:
418+
method_name: name of the method to be applied to children.
419+
420+
Returns:
421+
Reconstructed Expr or `self` if children are untouched.
422+
423+
"""
424+
ops = [
425+
getattr(op, method_name)(*args, **kwargs)
426+
for op in self.ufl_operands
427+
]
428+
if all(a is b for a, b in zip(self.ufl_operands, ops)):
429+
return self
430+
else:
431+
return self._ufl_expr_reconstruct_(*ops)
432+
433+
@staticmethod
434+
def traverse_dag_apply_coefficient_split_cache(f):
435+
"""Use method specific key for caching."""
436+
@functools.wraps(f)
437+
def wrapper(
438+
self,
439+
coefficient_split,
440+
reference_value=False,
441+
reference_grad=0,
442+
restricted=None,
443+
cache=None,
444+
):
445+
if cache is None:
446+
raise RuntimeError(f"""
447+
Can not have cache=None.
448+
Must decorate {f} with ``Expr.traverse_dag``.
449+
""")
450+
key = (self, reference_value, reference_grad, restricted)
451+
if key in cache:
452+
return cache[key]
453+
else:
454+
result = f(
455+
self,
456+
coefficient_split,
457+
reference_value=reference_value,
458+
reference_grad=reference_grad,
459+
restricted=restricted,
460+
cache=cache,
461+
)
462+
cache[key] = result
463+
return result
464+
return wrapper
465+
466+
def __getattr__(self, name):
467+
if name.startswith(PREFIX_TRAVERSE_DAG):
468+
# Traverse DAG with reuse_if_untouched by default.
469+
return functools.partial(self.traverse_dag_reuse_if_untouched, name)
470+
else:
471+
raise AttributeError(f"{type(self)} has no attribute {name}")
472+
391473

392474
# Initializing traits here because Expr is not defined in the class
393475
# declaration

ufl/core/terminal.py

+44
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import warnings
1515

16+
import numpy as np
1617
from ufl.core.expr import Expr
1718
from ufl.core.ufl_type import ufl_type
1819

@@ -92,6 +93,49 @@ def __eq__(self, other):
9293
"""Default comparison of terminals just compare repr strings."""
9394
return repr(self) == repr(other)
9495

96+
@Expr.traverse_dag
97+
@Expr.traverse_dag_apply_coefficient_split_cache
98+
def traverse_dag_apply_coefficient_split(
99+
self,
100+
coefficient_split,
101+
reference_value=False,
102+
reference_grad=0,
103+
restricted=None,
104+
cache=None,
105+
):
106+
from ufl.classes import (
107+
ComponentTensor,
108+
MultiIndex,
109+
NegativeRestricted,
110+
PositiveRestricted,
111+
ReferenceGrad,
112+
ReferenceValue,
113+
Zero,
114+
)
115+
from ufl.core.multiindex import indices
116+
from ufl.checks import is_cellwise_constant
117+
from ufl.domain import extract_unique_domain
118+
from ufl.tensors import as_tensor
119+
120+
c = self
121+
if reference_value:
122+
c = ReferenceValue(c)
123+
for _ in range(reference_grad):
124+
# Return zero if expression is trivially constant. This has to
125+
# happen here because ReferenceGrad has no access to the
126+
# topological dimension of a literal zero.
127+
if is_cellwise_constant(c):
128+
dim = extract_unique_domain(subcoeff).topological_dimension()
129+
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
130+
else:
131+
c = ReferenceGrad(c)
132+
if restricted == "+":
133+
c = PositiveRestricted(c)
134+
elif restricted == "-":
135+
c = NegativeRestricted(c)
136+
elif restricted is not None:
137+
raise RuntimeError(f"Got unknown restriction: {restricted}")
138+
return c
95139

96140
# --- Subgroups of terminals ---
97141

ufl/differentiation.py

+36
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,42 @@ def __str__(self):
344344
"""Format as a string."""
345345
return "reference_grad(%s)" % self.ufl_operands[0]
346346

347+
@Expr.traverse_dag
348+
@Expr.traverse_dag_apply_coefficient_split_cache
349+
def traverse_dag_apply_coefficient_split(
350+
self,
351+
coefficient_split,
352+
reference_value=False,
353+
reference_grad=0,
354+
restricted=None,
355+
cache=None,
356+
):
357+
"""Split mixed coefficients.
358+
359+
Args:
360+
coefficient_split: `dict` that maps mixed coefficients to their components.
361+
reference_value: If `ReferenceValue` has been applied.
362+
reference_grad: Number of `ReferenceGrad`s that have been applied.
363+
restricted: '+', '-', or None.
364+
cache: `dict` for caching DAG nodes.
365+
366+
Returns:
367+
This node wrapped with `ReferenceValue` (if ``reference_value``),
368+
`ReferenceGrad` (``reference_grad`` times), and `Restricted` (if
369+
``restricted`` is '+' or '-'). The underlying terminal will be
370+
decomposed into components according to ``coefficient_split``.
371+
372+
"""
373+
op, = self.ufl_operands
374+
if not op._ufl_terminal_modifiers_:
375+
raise ValueError(f"Must be a terminal modifier: {op!r}.")
376+
return op.traverse_dag_apply_coefficient_split(
377+
coefficient_split,
378+
reference_value=reference_value,
379+
reference_grad=reference_grad + 1,
380+
restricted=restricted,
381+
cache=cache,
382+
)
347383

348384
@ufl_type(num_ops=1, inherit_indices_from_operand=0, is_terminal_modifier=True)
349385
class Div(CompoundDerivative):

ufl/referencevalue.py

+40
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#
66
# SPDX-License-Identifier: LGPL-3.0-or-later
77

8+
from ufl.core.expr import Expr
89
from ufl.core.operator import Operator
910
from ufl.core.terminal import FormArgument
1011
from ufl.core.ufl_type import ufl_type
@@ -34,3 +35,42 @@ def evaluate(self, x, mapping, component, index_values, derivatives=()):
3435
def __str__(self):
3536
"""Format as a string."""
3637
return f"reference_value({self.ufl_operands[0]})"
38+
39+
@Expr.traverse_dag
40+
@Expr.traverse_dag_apply_coefficient_split_cache
41+
def traverse_dag_apply_coefficient_split(
42+
self,
43+
coefficient_split: dict,
44+
reference_value: bool = False,
45+
reference_grad: int = 0,
46+
restricted: str | None = None,
47+
cache=None,
48+
):
49+
"""Split mixed coefficients.
50+
51+
Args:
52+
coefficient_split: `dict` that maps mixed coefficients to their components.
53+
reference_value: If `ReferenceValue` has been applied.
54+
reference_grad: Number of `ReferenceGrad`s that have been applied.
55+
restricted: '+', '-', or None.
56+
cache: `dict` for caching DAG nodes.
57+
58+
Returns:
59+
This node wrapped with `ReferenceValue` (if ``reference_value``),
60+
`ReferenceGrad` (``reference_grad`` times), and `Restricted` (if
61+
``restricted`` is '+' or '-'). The underlying terminal will be
62+
decomposed into components according to ``coefficient_split``.
63+
64+
"""
65+
if reference_value:
66+
raise RuntimeError(f"Can not apply ReferenceValue twice: {self}")
67+
op, = self.ufl_operands
68+
if not op._ufl_terminal_modifiers_:
69+
raise ValueError(f"Must be a terminal modifier: {op!r}.")
70+
return op.traverse_dag_apply_coefficient_split(
71+
coefficient_split,
72+
reference_value=True,
73+
reference_grad=reference_grad,
74+
restricted=restricted,
75+
cache=cache,
76+
)

ufl/restriction.py

+40
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# SPDX-License-Identifier: LGPL-3.0-or-later
77

88
from ufl.constantvalue import ConstantValue
9+
from ufl.core.expr import Expr
910
from ufl.core.operator import Operator
1011
from ufl.core.ufl_type import ufl_type
1112
from ufl.precedence import parstr
@@ -48,6 +49,45 @@ def __str__(self):
4849
"""Format as a string."""
4950
return f"{parstr(self.ufl_operands[0], self)}({self._side})"
5051

52+
@Expr.traverse_dag
53+
@Expr.traverse_dag_apply_coefficient_split_cache
54+
def traverse_dag_apply_coefficient_split(
55+
self,
56+
coefficient_split,
57+
reference_value=False,
58+
reference_grad=0,
59+
restricted=None,
60+
cache=None,
61+
):
62+
"""Split mixed coefficients.
63+
64+
Args:
65+
coefficient_split: `dict` that maps mixed coefficients to their components.
66+
reference_value: If `ReferenceValue` has been applied.
67+
reference_grad: Number of `ReferenceGrad`s that have been applied.
68+
restricted: '+', '-', or None.
69+
cache: `dict` for caching DAG nodes.
70+
71+
Returns:
72+
This node wrapped with `ReferenceValue` (if ``reference_value``),
73+
`ReferenceGrad` (``reference_grad`` times), and `Restricted` (if
74+
``restricted`` is '+' or '-'). The underlying terminal will be
75+
decomposed into components according to ``coefficient_split``.
76+
77+
"""
78+
if restricted is not None:
79+
raise RuntimeError
80+
op, = self.ufl_operands
81+
if not op._ufl_terminal_modifiers_:
82+
raise ValueError(f"Must be a terminal modifier: {op!r}.")
83+
return op.traverse_dag_apply_coefficient_split(
84+
coefficient_split,
85+
reference_value=reference_value,
86+
reference_grad=reference_grad,
87+
restricted=self._side,
88+
cache=cache,
89+
)
90+
5191

5292
@ufl_type(is_terminal_modifier=True)
5393
class PositiveRestricted(Restricted):

0 commit comments

Comments
 (0)