Skip to content

Commit 2e3b737

Browse files
committed
Remove component tensors
1 parent 7d7c676 commit 2e3b737

File tree

5 files changed

+142
-2
lines changed

5 files changed

+142
-2
lines changed

test/test_derivative.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def test_vector_coefficient_scalar_derivatives(self):
665665
integrand = inner(f, g)
666666

667667
i0, i1, i2, i3, i4 = [Index(count=c) for c in range(5)]
668-
expected = as_tensor(df[i1] * dv, (i1,))[i0] * g[i0]
668+
expected = as_tensor(df[i1], (i1,))[i0] * dv * g[i0]
669669

670670
F = integrand * dx
671671
J = derivative(F, u, dv, cd)
@@ -693,7 +693,7 @@ def test_vector_coefficient_derivatives(self):
693693
integrand = inner(f, g)
694694

695695
i0, i1, i2, i3, i4 = [Index(count=c) for c in range(5)]
696-
expected = as_tensor(df[i2, i1] * dv[i1], (i2,))[i0] * g[i0]
696+
expected = as_tensor(df[i2, i1], (i2,))[i0] * dv[i1] * g[i0]
697697

698698
F = integrand * dx
699699
J = derivative(F, u, dv, cd)

ufl/algebra.py

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ufl.core.operator import Operator
1414
from ufl.core.ufl_type import ufl_type
1515
from ufl.index_combination_utils import merge_unique_indices
16+
from ufl.indexed import Indexed
1617
from ufl.precedence import parstr
1718
from ufl.sorting import sorted_expr
1819

@@ -89,6 +90,11 @@ def __init__(self, a, b):
8990
"""Initialise."""
9091
Operator.__init__(self)
9192

93+
def _simplify_indexed(self, multiindex):
94+
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
95+
a, b = self.ufl_operands
96+
return Sum(Indexed(a, multiindex), Indexed(b, multiindex))
97+
9298
def evaluate(self, x, mapping, component, index_values):
9399
"""Evaluate."""
94100
return sum(o.evaluate(x, mapping, component, index_values) for o in self.ufl_operands)

ufl/algorithms/compute_form_data.py

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from ufl.algorithms.analysis import extract_coefficients, extract_sub_elements, unique_tuple
1515
from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering
16+
from ufl.algorithms.remove_component_tensors import remove_component_tensors
1617
from ufl.algorithms.apply_derivatives import apply_coordinate_derivatives, apply_derivatives
1718

1819
# These are the main symbolic processing steps:
@@ -328,6 +329,8 @@ def compute_form_data(
328329

329330
form = apply_coordinate_derivatives(form)
330331

332+
form = remove_component_tensors(form)
333+
331334
# Propagate restrictions to terminals
332335
if do_apply_restrictions:
333336
form = apply_restrictions(form, apply_default=do_apply_default_restrictions)
+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Remove component tensors.
2+
3+
This module contains classes and functions to remove component tensors.
4+
"""
5+
# Copyright (C) 2008-2016 Martin Sandve Alnæs
6+
#
7+
# This file is part of UFL (https://www.fenicsproject.org)
8+
#
9+
# SPDX-License-Identifier: LGPL-3.0-or-later
10+
11+
from ufl.classes import (
12+
ComponentTensor,
13+
FixedIndex,
14+
Form,
15+
Index,
16+
MultiIndex,
17+
Zero,
18+
)
19+
from ufl.corealg.map_dag import map_expr_dag
20+
from ufl.corealg.multifunction import MultiFunction, memoized_handler
21+
22+
23+
class IndexReplacer(MultiFunction):
24+
"""Replace Indices."""
25+
26+
def __init__(self, fimap: dict):
27+
"""Initialise.
28+
29+
Args:
30+
fimap: map for index replacements.
31+
32+
"""
33+
MultiFunction.__init__(self)
34+
self.fimap = fimap
35+
self._object_cache = {}
36+
37+
expr = MultiFunction.reuse_if_untouched
38+
39+
@memoized_handler
40+
def zero(self, o):
41+
"""Handle Zero."""
42+
free_indices = []
43+
index_dimensions = []
44+
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
45+
if Index(i) in self.fimap:
46+
ind_j = self.fimap[Index(i)]
47+
if not isinstance(ind_j, FixedIndex):
48+
free_indices.append(ind_j.count())
49+
index_dimensions.append(d)
50+
else:
51+
free_indices.append(i)
52+
index_dimensions.append(d)
53+
return Zero(
54+
shape=o.ufl_shape,
55+
free_indices=tuple(free_indices),
56+
index_dimensions=tuple(index_dimensions),
57+
)
58+
59+
@memoized_handler
60+
def multi_index(self, o):
61+
"""Handle MultiIndex."""
62+
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices()))
63+
64+
65+
class IndexRemover(MultiFunction):
66+
"""Remove Indexed."""
67+
68+
def __init__(self):
69+
"""Initialise."""
70+
MultiFunction.__init__(self)
71+
self._object_cache = {}
72+
73+
expr = MultiFunction.reuse_if_untouched
74+
75+
@memoized_handler
76+
def _zero_simplify(self, o):
77+
"""Apply simplification for Zero()."""
78+
(operand,) = o.ufl_operands
79+
operand = map_expr_dag(self, operand)
80+
if isinstance(operand, Zero):
81+
return Zero(
82+
shape=o.ufl_shape,
83+
free_indices=o.ufl_free_indices,
84+
index_dimensions=o.ufl_index_dimensions,
85+
)
86+
return o._ufl_expr_reconstruct_(operand)
87+
88+
@memoized_handler
89+
def indexed(self, o):
90+
"""Simplify Indexed."""
91+
o1, i1 = o.ufl_operands
92+
if isinstance(o1, ComponentTensor):
93+
# Simplify Indexed ComponentTensor
94+
o2, i2 = o1.ufl_operands
95+
assert len(i2) == len(i1)
96+
fimap = dict(zip(i2, i1))
97+
rule = IndexReplacer(fimap)
98+
v = map_expr_dag(rule, o2)
99+
return map_expr_dag(self, v)
100+
101+
expr = map_expr_dag(self, o1)
102+
if expr is o1:
103+
# Reuse if untouched
104+
return o
105+
return o._ufl_expr_reconstruct_(expr, i1)
106+
107+
# Do something nicer
108+
positive_restricted = _zero_simplify
109+
negative_restricted = _zero_simplify
110+
reference_grad = _zero_simplify
111+
reference_value = _zero_simplify
112+
113+
114+
def remove_component_tensors(o):
115+
"""Remove component tensors."""
116+
if isinstance(o, Form):
117+
integrals = []
118+
for integral in o.integrals():
119+
integrand = remove_component_tensors(integral.integrand())
120+
if not isinstance(integrand, Zero):
121+
integrals.append(integral.reconstruct(integrand=integrand))
122+
return o._ufl_expr_reconstruct_(integrals)
123+
else:
124+
rule = IndexRemover()
125+
return map_expr_dag(rule, o)

ufl/indexsum.py

+6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ufl.core.multiindex import MultiIndex
1212
from ufl.core.operator import Operator
1313
from ufl.core.ufl_type import ufl_type
14+
from ufl.indexed import Indexed
1415
from ufl.precedence import parstr
1516

1617
# --- Sum over an index ---
@@ -69,6 +70,11 @@ def ufl_shape(self):
6970
"""Get UFL shape."""
7071
return self.ufl_operands[0].ufl_shape
7172

73+
def _simplify_indexed(self, multiindex):
74+
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
75+
A, i = self.ufl_operands
76+
return IndexSum(Indexed(A, multiindex), i)
77+
7278
def evaluate(self, x, mapping, component, index_values):
7379
"""Evaluate."""
7480
(i,) = self.ufl_operands[1]

0 commit comments

Comments
 (0)