Skip to content

Commit 36c8fa0

Browse files
committed
temp
1 parent 6ac9b9d commit 36c8fa0

7 files changed

+191
-1
lines changed

ufl/algorithms/compute_form_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def compute_form_data(
474474
new_integrals = []
475475
for integral in itg_data.integrals:
476476
integrand = replace(integral.integrand(), self.function_replace_map)
477-
integrand = apply_coefficient_split(integrand, self.coefficient_split)
477+
integrand = integrand.traverse_dag_apply_coefficient_split(self.coefficient_split, cache={})
478478
if not isinstance(integrand, Zero):
479479
new_integrals.append(integral.reconstruct(integrand=integrand))
480480
itg_data.integrals = new_integrals

ufl/coefficient.py

+78
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# Modified by Cecile Daversin-Catty, 2018.
1212
# Modified by Ignacia Fierro-Piccardo 2023.
1313

14+
import numpy as np
1415
from ufl.argument import Argument
1516
from ufl.core.terminal import FormArgument
1617
from ufl.core.ufl_type import ufl_type
@@ -20,6 +21,7 @@
2021
from ufl.split_functions import split
2122
from ufl.utils.counted import Counted
2223

24+
2325
# --- The Coefficient class represents a coefficient in a form ---
2426

2527

@@ -201,6 +203,82 @@ def __repr__(self):
201203
"""Representation."""
202204
return self._repr
203205

206+
def traverse_dag_apply_coefficient_split(
207+
self,
208+
coefficient_split,
209+
reference_value=False,
210+
reference_grad=0,
211+
restricted=None,
212+
cache=None,
213+
):
214+
from ufl.classes import (
215+
ComponentTensor,
216+
MultiIndex,
217+
NegativeRestricted,
218+
PositiveRestricted,
219+
ReferenceGrad,
220+
ReferenceValue,
221+
Zero,
222+
)
223+
from ufl.core.multiindex import indices
224+
from ufl.checks import is_cellwise_constant
225+
from ufl.domain import extract_unique_domain
226+
from ufl.tensors import as_tensor
227+
228+
if self not in coefficient_split:
229+
c = self
230+
if reference_value:
231+
c = ReferenceValue(c)
232+
for _ in range(reference_grad):
233+
# Return zero if expression is trivially constant. This has to
234+
# happen here because ReferenceGrad has no access to the
235+
# topological dimension of a literal zero.
236+
if is_cellwise_constant(c):
237+
dim = extract_unique_domain(subcoeff).topological_dimension()
238+
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
239+
else:
240+
c = ReferenceGrad(c)
241+
if restricted == "+":
242+
c = PositiveRestricted(c)
243+
elif restricted == "-":
244+
c = NegativeRestricted(c)
245+
elif restricted is not None:
246+
raise RuntimeError(f"Got unknown restriction: {restricted}")
247+
return c
248+
# Reference value expected
249+
if not reference_value:
250+
raise RuntimeError(f"ReferenceValue expected: got {o}")
251+
# Derivative indices
252+
beta = indices(reference_grad)
253+
components = []
254+
for subcoeff in coefficient_split[self]:
255+
c = subcoeff
256+
# Apply terminal modifiers onto the subcoefficient
257+
if reference_value:
258+
c = ReferenceValue(c)
259+
for _ in range(reference_grad):
260+
# Return zero if expression is trivially constant. This has to
261+
# happen here because ReferenceGrad has no access to the
262+
# topological dimension of a literal zero.
263+
if is_cellwise_constant(c):
264+
dim = extract_unique_domain(subcoeff).topological_dimension()
265+
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
266+
else:
267+
c = ReferenceGrad(c)
268+
if restricted == "+":
269+
c = PositiveRestricted(c)
270+
elif restricted == "-":
271+
c = NegativeRestricted(c)
272+
elif restricted is not None:
273+
raise RuntimeError(f"Got unknown restriction: {restricted}")
274+
# Collect components of the subcoefficient
275+
for alpha in np.ndindex(subcoeff.ufl_element().reference_value_shape):
276+
# New modified terminal: component[alpha + beta]
277+
components.append(c[alpha + beta])
278+
# Repack derivative indices to shape
279+
i, = indices(1)
280+
return ComponentTensor(as_tensor(components)[i], MultiIndex((i,) + beta))
281+
204282

205283
# --- Helper functions for subfunctions on mixed elements ---
206284

ufl/core/expr.py

+10
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,16 @@ def __round__(self, n=None):
388388
val = NotImplemented
389389
return val
390390

391+
def traverse_dag_apply_coefficient_split(self, *args, **kwargs):
392+
ops = [
393+
op.traverse_dag_apply_coefficient_split(*args, **kwargs)
394+
for op in self.ufl_operands
395+
]
396+
if all(a is b for a, b in zip(self.ufl_operands, ops)):
397+
return self
398+
else:
399+
return self._ufl_expr_reconstruct_(*ops)
400+
391401

392402
# Initializing traits here because Expr is not defined in the class
393403
# declaration

ufl/core/terminal.py

+42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import warnings
1515

16+
import numpy as np
1617
from ufl.core.expr import Expr
1718
from ufl.core.ufl_type import ufl_type
1819

@@ -92,6 +93,47 @@ def __eq__(self, other):
9293
"""Default comparison of terminals just compare repr strings."""
9394
return repr(self) == repr(other)
9495

96+
def traverse_dag_apply_coefficient_split(
97+
self,
98+
coefficient_split,
99+
reference_value=False,
100+
reference_grad=0,
101+
restricted=None,
102+
cache=None,
103+
):
104+
from ufl.classes import (
105+
ComponentTensor,
106+
MultiIndex,
107+
NegativeRestricted,
108+
PositiveRestricted,
109+
ReferenceGrad,
110+
ReferenceValue,
111+
Zero,
112+
)
113+
from ufl.core.multiindex import indices
114+
from ufl.checks import is_cellwise_constant
115+
from ufl.domain import extract_unique_domain
116+
from ufl.tensors import as_tensor
117+
118+
c = self
119+
if reference_value:
120+
c = ReferenceValue(c)
121+
for _ in range(reference_grad):
122+
# Return zero if expression is trivially constant. This has to
123+
# happen here because ReferenceGrad has no access to the
124+
# topological dimension of a literal zero.
125+
if is_cellwise_constant(c):
126+
dim = extract_unique_domain(subcoeff).topological_dimension()
127+
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
128+
else:
129+
c = ReferenceGrad(c)
130+
if restricted == "+":
131+
c = PositiveRestricted(c)
132+
elif restricted == "-":
133+
c = NegativeRestricted(c)
134+
elif restricted is not None:
135+
raise RuntimeError(f"Got unknown restriction: {restricted}")
136+
return c
95137

96138
# --- Subgroups of terminals ---
97139

ufl/differentiation.py

+18
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,24 @@ def __str__(self):
346346
"""Format as a string."""
347347
return "reference_grad(%s)" % self.ufl_operands[0]
348348

349+
def traverse_dag_apply_coefficient_split(
350+
self,
351+
coefficient_split,
352+
reference_value=False,
353+
reference_grad=0,
354+
restricted=None,
355+
cache=None,
356+
):
357+
op, = self.ufl_operands
358+
if not op._ufl_terminal_modifiers_:
359+
raise ValueError(f"Expecting a terminal modifier: got {op!r}.")
360+
return op.traverse_dag_apply_coefficient_split(
361+
coefficient_split,
362+
reference_value=reference_value,
363+
reference_grad=reference_grad + 1,
364+
restricted=restricted,
365+
cache=cache,
366+
)
349367

350368
@ufl_type(num_ops=1, inherit_indices_from_operand=0, is_terminal_modifier=True)
351369
class Div(CompoundDerivative):

ufl/referencevalue.py

+21
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,24 @@ def evaluate(self, x, mapping, component, index_values, derivatives=()):
3434
def __str__(self):
3535
"""Format as a string."""
3636
return f"reference_value({self.ufl_operands[0]})"
37+
38+
def traverse_dag_apply_coefficient_split(
39+
self,
40+
coefficient_split,
41+
reference_value=False,
42+
reference_grad=0,
43+
restricted=None,
44+
cache=None,
45+
):
46+
if reference_value:
47+
raise RuntimeError
48+
op, = self.ufl_operands
49+
if not op._ufl_terminal_modifiers_:
50+
raise ValueError(f"Expecting a terminal modifier: got {op!r}.")
51+
return op.traverse_dag_apply_coefficient_split(
52+
coefficient_split,
53+
reference_value=True,
54+
reference_grad=reference_grad,
55+
restricted=restricted,
56+
cache=cache,
57+
)

ufl/restriction.py

+21
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@ def __str__(self):
4848
"""Format as a string."""
4949
return f"{parstr(self.ufl_operands[0], self)}({self._side})"
5050

51+
def traverse_dag_apply_coefficient_split(
52+
self,
53+
coefficient_split,
54+
reference_value=False,
55+
reference_grad=0,
56+
restricted=None,
57+
cache=None,
58+
):
59+
if restricted is not None:
60+
raise RuntimeError
61+
op, = self.ufl_operands
62+
if not op._ufl_terminal_modifiers_:
63+
raise ValueError(f"Expecting a terminal modifier: got {op!r}.")
64+
return op.traverse_dag_apply_coefficient_split(
65+
coefficient_split,
66+
reference_value=reference_value,
67+
reference_grad=reference_grad,
68+
restricted=self._side,
69+
cache=cache,
70+
)
71+
5172

5273
@ufl_type(is_terminal_modifier=True)
5374
class PositiveRestricted(Restricted):

0 commit comments

Comments
 (0)