Skip to content

Commit 45f6d9e

Browse files
authored
PhysicallyMappedElement: implement hand-rolled basis transformation (#115)
* PhysicallyMappedElement: implement hand-rolled basis transformation * comments
1 parent acbd449 commit 45f6d9e

8 files changed

+11
-33
lines changed

finat/aw.py

-8
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@ def entity_dofs(self):
7878
1: {0: [0, 1, 2, 3], 1: [4, 5, 6, 7], 2: [8, 9, 10, 11]},
7979
2: {0: [12, 13, 14]}}
8080

81-
@property
82-
def index_shape(self):
83-
return (self.space_dimension(),)
84-
8581
def space_dimension(self):
8682
return 15
8783

@@ -129,9 +125,5 @@ def entity_dofs(self):
129125
1: {0: [9, 10, 11, 12], 1: [13, 14, 15, 16], 2: [17, 18, 19, 20]},
130126
2: {0: [21, 22, 23]}}
131127

132-
@property
133-
def index_shape(self):
134-
return (self.space_dimension(),)
135-
136128
def space_dimension(self):
137129
return 24

finat/bell.py

-4
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,5 @@ def basis_transformation(self, coordinate_mapping):
7373
def entity_dofs(self):
7474
return self._entity_dofs
7575

76-
@property
77-
def index_shape(self):
78-
return (18,)
79-
8076
def space_dimension(self):
8177
return 18

finat/fiat_elements.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def space_dimension(self):
7979

8080
@property
8181
def index_shape(self):
82-
return (self._element.space_dimension(),)
82+
return (self.space_dimension(),)
8383

8484
@property
8585
def value_shape(self):

finat/hct.py

-4
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,5 @@ def basis_transformation(self, coordinate_mapping):
9393
def entity_dofs(self):
9494
return self._entity_dofs
9595

96-
@property
97-
def index_shape(self):
98-
return (9,)
99-
10096
def space_dimension(self):
10197
return 9

finat/mtw.py

-4
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,5 @@ def basis_transformation(self, coordinate_mapping):
4242
def entity_dofs(self):
4343
return self._entity_dofs
4444

45-
@property
46-
def index_shape(self):
47-
return (self._space_dimension,)
48-
4945
def space_dimension(self):
5046
return self._space_dimension

finat/physically_mapped.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -268,15 +268,17 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
268268
assert coordinate_mapping is not None
269269

270270
M = self.basis_transformation(coordinate_mapping)
271-
M, = gem.optimise.constant_fold_zero((M,))
271+
# we expect M to be sparse with O(1) nonzeros per row
272+
# for each row, get the column index of each nonzero entry
273+
csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)]
274+
for i in range(M.shape[0])]
272275

273276
def matvec(table):
274-
table, = gem.optimise.constant_fold_zero((table,))
275-
i, j = gem.indices(2)
276-
value_indices = self.get_value_indices()
277-
table = gem.Indexed(table, (j, ) + value_indices)
278-
val = gem.ComponentTensor(gem.IndexSum(M[i, j]*table, (j,)), (i,) + value_indices)
279-
# Eliminate zeros
277+
# basis recombination using hand-rolled sparse-dense matrix multiplication
278+
table = [gem.partial_indexed(table, (j,)) for j in range(M.shape[1])]
279+
# the sum approach is faster than calling numpy.dot or gem.IndexSum
280+
expressions = [sum(M.array[i, j] * table[j] for j in js) for i, js in enumerate(csr)]
281+
val = gem.ListTensor(expressions)
280282
return gem.optimise.aggressive_unroll(val)
281283

282284
result = super().basis_evaluation(order, ps, entity=entity)

finat/piola_mapped.py

-4
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,6 @@ def __init__(self, fiat_element):
113113
def entity_dofs(self):
114114
return self._entity_dofs
115115

116-
@property
117-
def index_shape(self):
118-
return (self._space_dimension,)
119-
120116
def space_dimension(self):
121117
return self._space_dimension
122118

gem/gem.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ def __new__(cls, i, j, dtype=None):
969969

970970
# Fixed indices
971971
if isinstance(i, int) and isinstance(j, int):
972-
return Literal(int(i == j))
972+
return one if i == j else Zero()
973973

974974
self = super(Delta, cls).__new__(cls)
975975
self.i = i

0 commit comments

Comments
 (0)